// Extractors.scala (14.35 KiB) -- authored by Manos Koukoutos
/* Copyright 2009-2015 EPFL, Lausanne */
package leon
package purescala
import Expressions._
import Common._
import Types._
import Constructors._
import ExprOps._
import Definitions.Program
object Extractors {
/** Generic one-level decomposition of an expression.
  *
  * `Operator.unapply(e)` yields the direct sub-expressions of `e` together with
  * a builder that, given replacement sub-expressions in the SAME order and
  * number, reconstructs an expression of the same kind. Where the builder calls
  * a lower-case helper (`not`, `plus`, `equality`, ...) the result is whatever
  * that helper from `Constructors` returns, which need not be the exact same
  * node class.
  *
  * Expressions that are not listed here and do not implement `Extractable`
  * yield `None`.
  */
object Operator {
  /** Splits `expr` into `(subExpressions, rebuild)`.
    *
    * @param expr the expression to decompose
    * @return `Some((subs, builder))` for every supported expression, `None` otherwise.
    *         The builder indexes into its argument positionally and will fail
    *         (e.g. with an index/match error) if given too few expressions.
    */
  def unapply(expr: Expr): Option[(Seq[Expr], (Seq[Expr]) => Expr)] = expr match {
    /* Unary operators */
    case Not(t) =>
      Some((Seq(t), (es: Seq[Expr]) => not(es.head)))
    case Choose(pred) =>
      Some((Seq(pred), (es: Seq[Expr]) => Choose(es.head)))
    case UMinus(t) =>
      Some((Seq(t), (es: Seq[Expr]) => UMinus(es.head)))
    case BVUMinus(t) =>
      Some((Seq(t), (es: Seq[Expr]) => BVUMinus(es.head)))
    case RealUMinus(t) =>
      Some((Seq(t), (es: Seq[Expr]) => RealUMinus(es.head)))
    case BVNot(t) =>
      Some((Seq(t), (es: Seq[Expr]) => BVNot(es.head)))
    case SetCardinality(t) =>
      Some((Seq(t), (es: Seq[Expr]) => SetCardinality(es.head)))
    case CaseClassSelector(cd, e, sel) =>
      Some((Seq(e), (es: Seq[Expr]) => caseClassSelector(cd, es.head, sel)))
    case IsInstanceOf(e, ct) =>
      Some((Seq(e), (es: Seq[Expr]) => IsInstanceOf(es.head, ct)))
    case AsInstanceOf(e, ct) =>
      Some((Seq(e), (es: Seq[Expr]) => asInstOf(es.head, ct)))
    case TupleSelect(t, i) =>
      Some((Seq(t), (es: Seq[Expr]) => TupleSelect(es.head, i)))
    case ArrayLength(a) =>
      Some((Seq(a), (es: Seq[Expr]) => ArrayLength(es.head)))
    case Lambda(args, body) =>
      Some((Seq(body), (es: Seq[Expr]) => Lambda(args, es.head)))
    case PartialLambda(mapping, tpe) =>
      // Each mapping entry is flattened to its arguments followed by its value,
      // so entries occupy `sze` consecutive slots in the sub-expression list.
      val sze = tpe.from.size + 1
      val subArgs = mapping.flatMap { case (args, v) => args :+ v }
      val builder = (as: Seq[Expr]) => {
        def rec(kvs: Seq[Expr]): Seq[(Seq[Expr], Expr)] = kvs match {
          case seq if seq.size >= sze =>
            val ((args :+ res), rest) = seq.splitAt(sze)
            (args -> res) +: rec(rest)
          case Seq() => Seq.empty
          case _ => sys.error("unexpected number of key/value expressions")
        }
        PartialLambda(rec(as), tpe)
      }
      Some((subArgs, builder))
    case Forall(args, body) =>
      Some((Seq(body), (es: Seq[Expr]) => Forall(args, es.head)))

    /* Binary operators */
    case LetDef(fd, body) => Some((
      Seq(fd.fullBody, body),
      (es: Seq[Expr]) => {
        // NOTE: the builder updates `fd.fullBody` IN PLACE before wrapping the
        // same `fd` in a new LetDef; anything else holding `fd` sees the change.
        fd.fullBody = es(0)
        LetDef(fd, es(1))
      }
    ))
    case Equals(t1, t2) =>
      Some((Seq(t1, t2), (es: Seq[Expr]) => equality(es(0), es(1))))
    case Implies(t1, t2) =>
      Some((Seq(t1, t2), (es: Seq[Expr]) => implies(es(0), es(1))))
    case Plus(t1, t2) =>
      Some((Seq(t1, t2), (es: Seq[Expr]) => plus(es(0), es(1))))
    case Minus(t1, t2) =>
      Some((Seq(t1, t2), (es: Seq[Expr]) => minus(es(0), es(1))))
    case Times(t1, t2) =>
      Some((Seq(t1, t2), (es: Seq[Expr]) => times(es(0), es(1))))
    case Division(t1, t2) =>
      Some((Seq(t1, t2), (es: Seq[Expr]) => Division(es(0), es(1))))
    case Remainder(t1, t2) =>
      Some((Seq(t1, t2), (es: Seq[Expr]) => Remainder(es(0), es(1))))
    case Modulo(t1, t2) =>
      Some((Seq(t1, t2), (es: Seq[Expr]) => Modulo(es(0), es(1))))
    case LessThan(t1, t2) =>
      Some((Seq(t1, t2), (es: Seq[Expr]) => LessThan(es(0), es(1))))
    case GreaterThan(t1, t2) =>
      Some((Seq(t1, t2), (es: Seq[Expr]) => GreaterThan(es(0), es(1))))
    case LessEquals(t1, t2) =>
      Some((Seq(t1, t2), (es: Seq[Expr]) => LessEquals(es(0), es(1))))
    case GreaterEquals(t1, t2) =>
      Some((Seq(t1, t2), (es: Seq[Expr]) => GreaterEquals(es(0), es(1))))
    // The bit-vector and real arithmetic nodes below are rebuilt through the
    // same `plus`/`minus`/`times` helpers as the integer ones.
    case BVPlus(t1, t2) =>
      Some((Seq(t1, t2), (es: Seq[Expr]) => plus(es(0), es(1))))
    case BVMinus(t1, t2) =>
      Some((Seq(t1, t2), (es: Seq[Expr]) => minus(es(0), es(1))))
    case BVTimes(t1, t2) =>
      Some((Seq(t1, t2), (es: Seq[Expr]) => times(es(0), es(1))))
    case BVDivision(t1, t2) =>
      Some((Seq(t1, t2), (es: Seq[Expr]) => BVDivision(es(0), es(1))))
    case BVRemainder(t1, t2) =>
      Some((Seq(t1, t2), (es: Seq[Expr]) => BVRemainder(es(0), es(1))))
    case BVAnd(t1, t2) =>
      Some((Seq(t1, t2), (es: Seq[Expr]) => BVAnd(es(0), es(1))))
    case BVOr(t1, t2) =>
      Some((Seq(t1, t2), (es: Seq[Expr]) => BVOr(es(0), es(1))))
    case BVXOr(t1, t2) =>
      Some((Seq(t1, t2), (es: Seq[Expr]) => BVXOr(es(0), es(1))))
    case BVShiftLeft(t1, t2) =>
      Some((Seq(t1, t2), (es: Seq[Expr]) => BVShiftLeft(es(0), es(1))))
    case BVAShiftRight(t1, t2) =>
      Some((Seq(t1, t2), (es: Seq[Expr]) => BVAShiftRight(es(0), es(1))))
    case BVLShiftRight(t1, t2) =>
      Some((Seq(t1, t2), (es: Seq[Expr]) => BVLShiftRight(es(0), es(1))))
    case RealPlus(t1, t2) =>
      Some((Seq(t1, t2), (es: Seq[Expr]) => plus(es(0), es(1))))
    case RealMinus(t1, t2) =>
      Some((Seq(t1, t2), (es: Seq[Expr]) => minus(es(0), es(1))))
    case RealTimes(t1, t2) =>
      Some((Seq(t1, t2), (es: Seq[Expr]) => times(es(0), es(1))))
    case RealDivision(t1, t2) =>
      Some((Seq(t1, t2), (es: Seq[Expr]) => RealDivision(es(0), es(1))))
    case ElementOfSet(t1, t2) =>
      Some((Seq(t1, t2), (es: Seq[Expr]) => ElementOfSet(es(0), es(1))))
    case SubsetOf(t1, t2) =>
      Some((Seq(t1, t2), (es: Seq[Expr]) => SubsetOf(es(0), es(1))))
    case SetIntersection(t1, t2) =>
      Some((Seq(t1, t2), (es: Seq[Expr]) => SetIntersection(es(0), es(1))))
    case SetUnion(t1, t2) =>
      Some((Seq(t1, t2), (es: Seq[Expr]) => SetUnion(es(0), es(1))))
    case SetDifference(t1, t2) =>
      Some((Seq(t1, t2), (es: Seq[Expr]) => SetDifference(es(0), es(1))))
    case MapApply(t1, t2) =>
      Some((Seq(t1, t2), (es: Seq[Expr]) => MapApply(es(0), es(1))))
    case MapUnion(t1, t2) =>
      Some((Seq(t1, t2), (es: Seq[Expr]) => MapUnion(es(0), es(1))))
    case MapDifference(t1, t2) =>
      Some((Seq(t1, t2), (es: Seq[Expr]) => MapDifference(es(0), es(1))))
    case MapIsDefinedAt(t1, t2) =>
      Some((Seq(t1, t2), (es: Seq[Expr]) => MapIsDefinedAt(es(0), es(1))))
    case ArraySelect(t1, t2) =>
      Some((Seq(t1, t2), (es: Seq[Expr]) => ArraySelect(es(0), es(1))))
    case Let(binder, e, body) =>
      Some((Seq(e, body), (es: Seq[Expr]) => Let(binder, es(0), es(1))))
    case Require(pre, body) =>
      Some((Seq(pre, body), (es: Seq[Expr]) => Require(es(0), es(1))))
    case Ensuring(body, post) =>
      Some((Seq(body, post), (es: Seq[Expr]) => Ensuring(es(0), es(1))))
    case Assert(const, oerr, body) =>
      Some((Seq(const, body), (es: Seq[Expr]) => Assert(es(0), oerr, es(1))))

    /* Other operators */
    case FunctionInvocation(fd, args) =>
      Some((args, FunctionInvocation(fd, _)))
    case MethodInvocation(rec, cd, tfd, args) =>
      // The receiver travels as the first sub-expression.
      Some((rec +: args, as => MethodInvocation(as.head, cd, tfd, as.tail)))
    case Application(caller, args) =>
      // The callee travels as the first sub-expression.
      Some((caller +: args, as => application(as.head, as.tail)))
    case CaseClass(cd, args) =>
      Some((args, CaseClass(cd, _)))
    case And(args) =>
      Some((args, and))
    case Or(args) =>
      Some((args, or))
    case FiniteSet(els, base) =>
      Some((els.toSeq, newEls => FiniteSet(newEls.toSet, base)))
    case FiniteMap(args, f, t) =>
      // Pairs are flattened to key, value, key, value, ...
      val subArgs = args.flatMap { case (k, v) => Seq(k, v) }.toSeq
      val builder = (as: Seq[Expr]) => {
        def rec(kvs: Seq[Expr]): Map[Expr, Expr] = kvs match {
          case Seq(k, v, rest @ _*) =>
            Map(k -> v) ++ rec(rest)
          case Seq() => Map()
          case _ => sys.error("odd number of key/value expressions")
        }
        FiniteMap(rec(as), f, t)
      }
      Some((subArgs, builder))
    case ArrayUpdated(t1, t2, t3) => Some((
      Seq(t1, t2, t3),
      (as: Seq[Expr]) => ArrayUpdated(as(0), as(1), as(2))
    ))
    case NonemptyArray(elems, Some((default, length))) =>
      // Remember each element's index so the rebuilt array keeps the same
      // (possibly sparse) index -> element association instead of being
      // renumbered 0..n-1 in map-iteration order.
      val (indices, values) = elems.toSeq.unzip
      Some((values :+ default :+ length, (as: Seq[Expr]) => {
        val l = as.length
        // Layout: elements..., default, length
        NonemptyArray(indices.zip(as.take(l - 2)).toMap, Some((as(l - 2), as(l - 1))))
      }))
    case NonemptyArray(elems, None) =>
      // Same index-preserving reconstruction as above, without default/length.
      val (indices, values) = elems.toSeq.unzip
      Some((values, (as: Seq[Expr]) => NonemptyArray(indices.zip(as).toMap, None)))
    case Tuple(args) =>
      Some((args, tupleWrap))
    case IfExpr(cond, thenn, elze) => Some((
      Seq(cond, thenn, elze),
      { case Seq(c, t, e) => IfExpr(c, t, e) }
    ))
    case MatchExpr(scrut, cases) => Some((
      // Layout: scrutinee, then per case either `rhs` or `guard, rhs`.
      scrut +: cases.flatMap {
        case SimpleCase(_, e) => Seq(e)
        case GuardedCase(_, e1, e2) => Seq(e1, e2)
      },
      (es: Seq[Expr]) => {
        // `i` is the cursor into `es`; slot 0 is the scrutinee.
        var i = 1
        val newcases = for (caze <- cases) yield caze match {
          case SimpleCase(b, _) => i += 1; SimpleCase(b, es(i - 1))
          case GuardedCase(b, _, _) => i += 2; GuardedCase(b, es(i - 2), es(i - 1))
        }
        matchExpr(es.head, newcases)
      }
    ))
    case Passes(in, out, cases) => Some((
      // Layout: in, out, then per case either `rhs` or `guard, rhs`.
      in +: out +: cases.flatMap {
        case SimpleCase(_, rhs) => Seq(rhs)
        case GuardedCase(_, guard, rhs) => Seq(guard, rhs)
      },
      {
        case Seq(newIn, newOut, es @ _*) =>
          // `i` is the cursor into the per-case expressions `es`.
          var i = 0
          val newcases = for (caze <- cases) yield caze match {
            case SimpleCase(b, _) => i += 1; SimpleCase(b, es(i - 1))
            // Guard first, then rhs: must mirror the extraction order above.
            case GuardedCase(b, _, _) => i += 2; GuardedCase(b, es(i - 2), es(i - 1))
          }
          passes(newIn, newOut, newcases)
      }
    ))

    /* Terminals */
    case t: Terminal =>
      Some((Seq[Expr](), (_: Seq[Expr]) => t))

    /* Expr's not handled here should implement this trait */
    case e: Extractable =>
      e.extract

    case _ =>
      None
  }
}
/** Hook for expressions that `Operator.unapply` does not handle explicitly:
  * `Operator` falls back to calling `extract` on any expression that mixes in
  * this trait.
  */
trait Extractable {
/** Returns the direct sub-expressions together with a function that rebuilds
  * the expression from replacement sub-expressions, or `None`.
  */
def extract: Option[(Seq[Expr], Seq[Expr] => Expr)]
}
/** Extractor recognising a library-`String` case class whose first argument is
  * a list literal made only of character literals.
  */
object StringLiteral {
  /** @param e   candidate expression
    * @param pgm program whose `library.String` class definition is compared against
    * @return the characters concatenated into a `String`, or `None` when `e`
    *         is not such a literal (or the library has no `String` class)
    */
  def unapply(e: Expr)(implicit pgm: Program): Option[String] = e match {
    case CaseClass(cct, args) =>
      pgm.library.String
        .filter(libS => cct.classDef == libS)
        // `args.head` is only touched once the class is known to be the library String
        .flatMap(_ => isListLiteral(args.head))
        .collect {
          case (_, chars) if chars.forall(_.isInstanceOf[CharLiteral]) =>
            chars.collect { case CharLiteral(c) => c }.mkString
        }
    case _ =>
      None
  }
}
/** Flattens nested disjunctions:
  * expr1 OR (expr2 OR (expr3 OR ..)) => Seq(expr1, expr2, expr3).
  *
  * A `Let` wrapping a disjunction is pushed inside: every disjunct of the body
  * is re-wrapped with the same binding via `let`. Always succeeds; a
  * non-disjunction yields a singleton sequence.
  */
object TopLevelOrs {
  def unapply(e: Expr): Option[Seq[Expr]] = e match {
    case Let(binder, value, TopLevelOrs(disjuncts)) =>
      Some(disjuncts.map(d => let(binder, value, d)))
    case Or(exprs) =>
      val flattened = exprs.flatMap(sub => unapply(sub).toSeq.flatten)
      Some(flattened)
    case other =>
      Some(Seq(other))
  }
}
/** Flattens nested conjunctions:
  * expr1 AND (expr2 AND (expr3 AND ..)) => Seq(expr1, expr2, expr3).
  *
  * A `Let` wrapping a conjunction is pushed inside: every conjunct of the body
  * is re-wrapped with the same binding via `let`. Always succeeds; a
  * non-conjunction yields a singleton sequence.
  */
object TopLevelAnds {
  def unapply(e: Expr): Option[Seq[Expr]] = e match {
    case Let(binder, value, TopLevelAnds(conjuncts)) =>
      Some(conjuncts.map(c => let(binder, value, c)))
    case And(exprs) =>
      val flattened = exprs.flatMap(sub => unapply(sub).toSeq.flatten)
      Some(flattened)
    case other =>
      Some(Seq(other))
  }
}
/** Extractor that pairs any typed tree with the result of its `getType`. */
object IsTyped {
  /** Always succeeds with `(e, e.getType)`. */
  def unapply[T <: Typed](e: T): Option[(T, TypeTree)] = {
    val tpe = e.getType
    Some(e -> tpe)
  }
}
/** Uniform view over array literals as `(elements, optional default, length)`. */
object FiniteArray {
  /** @return for `EmptyArray`: no elements, no default, length literal 0;
    *         for `NonemptyArray`: its element map, the default if one is
    *         given, and either the stated length or a literal of the element
    *         count; `None` for any other expression.
    */
  def unapply(e: Expr): Option[(Map[Int, Expr], Option[Expr], Expr)] = e match {
    case EmptyArray(_) =>
      Some((Map(), None, IntLiteral(0)))
    case NonemptyArray(els, defaultAndLength) =>
      val (default, length): (Option[Expr], Expr) = defaultAndLength match {
        case Some((d, len)) => (Some(d), len)
        case None => (None, IntLiteral(els.size))
      }
      Some((els, default, length))
    case _ =>
      None
  }
}
/** Constructor/extractor for match cases that carry no guard. */
object SimpleCase {
  /** Builds a `MatchCase` with pattern `p`, no guard, and right-hand side `rhs`. */
  def apply(p : Pattern, rhs : Expr) = MatchCase(p, None, rhs)

  /** Matches only unguarded cases, yielding `(pattern, rhs)`. */
  def unapply(c : MatchCase) = c match {
    case MatchCase(pat, guard, body) if guard.isEmpty => Some((pat, body))
    case _ => None
  }
}
/** Constructor/extractor for match cases that carry a guard. */
object GuardedCase {
  /** Builds a `MatchCase` with pattern `p`, guard `g`, and right-hand side `rhs`. */
  def apply(p : Pattern, g: Expr, rhs : Expr) = MatchCase(p, Some(g), rhs)

  /** Matches only guarded cases, yielding `(pattern, guard, rhs)`. */
  def unapply(c : MatchCase) = c match {
    case MatchCase(pat, optGuard, body) => optGuard.map(guard => (pat, guard, body))
    case _ => None
  }
}
/** Generic decomposition of a pattern into its binder, its sub-patterns, and a
  * rebuilder taking a new binder and new sub-patterns.
  */
object Pattern {
  /** @return `None` for a `null` pattern; otherwise `(binder, subPatterns, rebuild)`.
    *         Patterns without sub-patterns report an empty sequence and their
    *         rebuilder ignores its second argument.
    */
  def unapply(p : Pattern) : Option[(
    Option[Identifier],
    Seq[Pattern],
    (Option[Identifier], Seq[Pattern]) => Pattern
  )] = p match {
    case null =>
      None
    case InstanceOfPattern(binder, ct) =>
      Some((binder, Seq(), (newBinder, _) => InstanceOfPattern(newBinder, ct)))
    case WildcardPattern(binder) =>
      Some((binder, Seq(), (newBinder, _) => WildcardPattern(newBinder)))
    case CaseClassPattern(binder, ct, subs) =>
      Some((binder, subs, (newBinder, newSubs) => CaseClassPattern(newBinder, ct, newSubs)))
    case TuplePattern(binder, subs) =>
      Some((binder, subs, (newBinder, newSubs) => TuplePattern(newBinder, newSubs)))
    case LiteralPattern(binder, lit) =>
      Some((binder, Seq(), (newBinder, _) => LiteralPattern(newBinder, lit)))
    case UnapplyPattern(binder, fd, subs) =>
      Some((binder, subs, (newBinder, newSubs) => UnapplyPattern(newBinder, fd, newSubs)))
  }
}
/** Splits a tuple-typed expression into one selector per component.
  *
  * @param e       the expression to unwrap
  * @param isTuple when false, `e` is returned as a singleton regardless of its type;
  *                when true, `e` must have a `TupleType`, otherwise an error is raised
  */
def unwrapTuple(e: Expr, isTuple: Boolean): Seq[Expr] = e.getType match {
  case _ if !isTuple =>
    Seq(e)
  case TupleType(components) =>
    // Selector indices run from 1 to the number of components.
    (1 to components.size).map(index => tupleSelect(e, index, isTuple))
  case tp =>
    sys.error(s"Calling unwrapTuple on non-tuple $e of type $tp")
}

/** Same as above, treating `e` as a tuple exactly when `expectedSize > 1`. */
def unwrapTuple(e: Expr, expectedSize: Int): Seq[Expr] = {
  val treatAsTuple = expectedSize > 1
  unwrapTuple(e, treatAsTuple)
}
/** Splits a tuple type into its component types.
  *
  * @param tp      the type to unwrap
  * @param isTuple when false, `tp` is returned as a singleton; when true, `tp`
  *                must be a `TupleType`, otherwise an error is raised
  */
def unwrapTupleType(tp: TypeTree, isTuple: Boolean): Seq[TypeTree] = tp match {
  case single if !isTuple =>
    Seq(single)
  case TupleType(components) =>
    components
  case other =>
    sys.error(s"Calling unwrapTupleType on $other")
}

/** Same as above, treating `tp` as a tuple exactly when `expectedSize > 1`. */
def unwrapTupleType(tp: TypeTree, expectedSize: Int): Seq[TypeTree] = {
  val treatAsTuple = expectedSize > 1
  unwrapTupleType(tp, treatAsTuple)
}
/** Splits a tuple pattern into its sub-patterns.
  *
  * @param p       the pattern to unwrap
  * @param isTuple when false, `p` is returned as a singleton; when true, `p`
  *                must be a `TuplePattern`, otherwise an error is raised
  */
def unwrapTuplePattern(p: Pattern, isTuple: Boolean): Seq[Pattern] = p match {
  case single if !isTuple =>
    Seq(single)
  case TuplePattern(_, components) =>
    components
  case _ =>
    sys.error(s"Calling unwrapTuplePattern on $p")
}

/** Same as above, treating `p` as a tuple exactly when `expectedSize > 1`. */
def unwrapTuplePattern(p: Pattern, expectedSize: Int): Seq[Pattern] = {
  val treatAsTuple = expectedSize > 1
  unwrapTuplePattern(p, treatAsTuple)
}
/** View of a single-case, unguarded match as a pattern-binding `let`. */
object LetPattern {
  /** Builds a plain `Let` when `patt` is a wildcard with a binder, otherwise a
    * `MatchExpr` on `value` with `patt => body` as its only case.
    */
  def apply(patt : Pattern, value: Expr, body: Expr) : Expr = patt match {
    case WildcardPattern(Some(binder)) =>
      Let(binder, value, body)
    case _ =>
      val onlyCase = SimpleCase(patt, body)
      MatchExpr(value, List(onlyCase))
  }

  /** Succeeds with `(pattern, scrutinee, body)` for a match that has exactly
    * one unguarded case and for which `aliased(pattern.binders, variablesOf(scrut))`
    * is false. Yields `None` otherwise, including for `null`.
    */
  def unapply(me : MatchExpr) : Option[(Pattern, Expr, Expr)] =
    Option(me) collect {
      case MatchExpr(scrut, List(SimpleCase(pattern, body)))
        if !aliased(pattern.binders, variablesOf(scrut)) =>
        (pattern, scrut, body)
    }
}
/** View of a `LetPattern` whose pattern is an unbound tuple of bound wildcards,
  * i.e. a destructuring of a tuple into plain identifiers.
  */
object LetTuple {
  /** Succeeds with `(binders, value, body)` when `me` is a `LetPattern` over a
    * `TuplePattern` that has no binder of its own and whose sub-patterns are
    * all wildcards carrying a binder. Yields `None` otherwise, including for `null`.
    */
  def unapply(me : MatchExpr) : Option[(Seq[Identifier], Expr, Expr)] = me match {
    case LetPattern(TuplePattern(None, subPatts), value, body) =>
      // Keep only bound wildcards; if none were dropped, every sub-pattern qualified.
      val binders = subPatts.collect { case WildcardPattern(Some(id)) => id }
      if (binders.size == subPatts.size) Some((binders, value, body)) else None
    case _ =>
      None
  }
}
}