diff --git a/src/main/scala/leon/codegen/CodeGeneration.scala b/src/main/scala/leon/codegen/CodeGeneration.scala index 493efc40ed9c89149587e45048b58541ba25c860..21291bd033a1eb1689ce8f17d7c3061b8c551d22 100644 --- a/src/main/scala/leon/codegen/CodeGeneration.scala +++ b/src/main/scala/leon/codegen/CodeGeneration.scala @@ -168,14 +168,14 @@ trait CodeGeneration { val body = funDef.body.getOrElse(throw CompilationException("Can't compile a FunDef without body: "+funDef.id.name)) val bodyWithPre = if(funDef.hasPrecondition && params.checkContracts) { - IfExpr(funDef.precondition.get, body, Error("Precondition failed")) + IfExpr(funDef.precondition.get, body, Error(body.getType, "Precondition failed")) } else { body } val bodyWithPost = if(funDef.hasPostcondition && params.checkContracts) { val Some((id, post)) = funDef.postcondition - Let(id, bodyWithPre, IfExpr(post, Variable(id), Error("Postcondition failed")) ) + Let(id, bodyWithPre, IfExpr(post, Variable(id), Error(id.getType, "Postcondition failed")) ) } else { bodyWithPre } @@ -207,7 +207,7 @@ trait CodeGeneration { load(id, ch) case Assert(cond, oerr, body) => - mkExpr(IfExpr(Not(cond), Error(oerr.getOrElse("Assertion failed @"+e.getPos)), body), ch) + mkExpr(IfExpr(Not(cond), Error(body.getType, oerr.getOrElse("Assertion failed @"+e.getPos)), body), ch) case Ensuring(body, id, post) => mkExpr(Let(id, body, Assert(post, Some("Ensuring failed"), Variable(id))), ch) @@ -664,7 +664,7 @@ trait CodeGeneration { } // Misc and boolean tests - case Error(desc) => + case Error(tpe, desc) => ch << New(ErrorClass) << DUP ch << Ldc(desc) ch << InvokeSpecial(ErrorClass, constructorName, "(Ljava/lang/String;)V") diff --git a/src/main/scala/leon/codegen/CompilationUnit.scala b/src/main/scala/leon/codegen/CompilationUnit.scala index e55f57777a4fb6a58c29cac0b602b157fae9ef7e..a29175fe11618832b18127cf9c0d2e21f9e0c762 100644 --- a/src/main/scala/leon/codegen/CompilationUnit.scala +++ b/src/main/scala/leon/codegen/CompilationUnit.scala @@ -166,7 +166,7 @@ class CompilationUnit(val ctx: LeonContext, case f @ FiniteLambda(dflt, els) => val l = new leon.codegen.runtime.FiniteLambda(exprToJVM(dflt)) for ((k,v) <- els) { - val jvmK = if (f.fixedType.from.size == 1) { + val jvmK = if (f.getType.from.size == 1) { exprToJVM(Tuple(Seq(k))) } else { exprToJVM(k) diff --git a/src/main/scala/leon/datagen/VanuatooDataGen.scala b/src/main/scala/leon/datagen/VanuatooDataGen.scala index 0173e39dd4a628e7eae5f1bd88e96469556112d7..6adab6929b9bac530248f97f769fb53c8e53e4a4 100644 --- a/src/main/scala/leon/datagen/VanuatooDataGen.scala +++ b/src/main/scala/leon/datagen/VanuatooDataGen.scala @@ -62,7 +62,7 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { case tt @ TupleType(parts) => constructors.getOrElse(tt, { - val cs = List(Constructor[Expr, TypeTree](parts, tt, s => Tuple(s).setType(tt), tt.toString)) + val cs = List(Constructor[Expr, TypeTree](parts, tt, s => Tuple(s), tt.toString)) constructors += tt -> cs cs }) diff --git a/src/main/scala/leon/evaluators/DualEvaluator.scala b/src/main/scala/leon/evaluators/DualEvaluator.scala index 67515e4c0c5da1c5da10bbc1efd0e633c3fdac93..be8f151fed01e0b135e03dc01ace0009d4fbae32 100644 --- a/src/main/scala/leon/evaluators/DualEvaluator.scala +++ b/src/main/scala/leon/evaluators/DualEvaluator.scala @@ -6,6 +6,7 @@ package evaluators import purescala.Common._ import purescala.Trees._ import purescala.Definitions._ +import purescala.TypeTrees.MutableTyped import codegen._ @@ -27,7 +28,7 @@ class DualEvaluator(ctx: LeonContext, prog: Program, params: CodeGenParams) exte def withVars(news: Map[Identifier, Expr]) = copy(news) } - case class RawObject(o: AnyRef) extends Expr + case class RawObject(o: AnyRef) extends Expr with MutableTyped def call(tfd: TypedFunDef, args: Seq[AnyRef]): Expr = { diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala index dbe7d651b8a13de82626f27b52de7751bd016646..2a374ff19ff67274d82f33fb9130c324b9f22926 100644 --- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala +++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala @@ -104,12 +104,12 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int e(b)(rctx.withNewVar(i, first), gctx) case Assert(cond, oerr, body) => - e(IfExpr(Not(cond), Error(oerr.getOrElse("Assertion failed @"+expr.getPos)), body)) + e(IfExpr(Not(cond), Error(expr.getType, oerr.getOrElse("Assertion failed @"+expr.getPos)), body)) case Ensuring(body, id, post) => e(Let(id, body, Assert(post, Some("Ensuring failed"), Variable(id)))) - case Error(desc) => + case Error(tpe, desc) => throw RuntimeError("Error reached in evaluation: " + desc) case IfExpr(cond, thenn, elze) => diff --git a/src/main/scala/leon/evaluators/TracingEvaluator.scala b/src/main/scala/leon/evaluators/TracingEvaluator.scala index e99cd8269b8d6f08777fd7034342bfb98f31939c..4647f3a036fd6ad448b1ba223fe67fe56a8e7862 100644 --- a/src/main/scala/leon/evaluators/TracingEvaluator.scala +++ b/src/main/scala/leon/evaluators/TracingEvaluator.scala @@ -98,13 +98,13 @@ class TracingEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int = 1000) ex } catch { case ee @ EvalError(e) => if (rctx.tracingFrames > 0) { - gctx.values ::= (expr -> Error(e)) + gctx.values ::= (expr -> Error(expr.getType, e)) } throw ee; case re @ RuntimeError(e) => if (rctx.tracingFrames > 0) { - gctx.values ::= (expr -> Error(e)) + gctx.values ::= (expr -> Error(expr.getType, e)) } throw re; } diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index cbaa9f400f7296721c536fdaf4b34ae73b467f36..a993e00d6c3d847520933c3cd2a13add1e876e62 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -1006,8 +1006,7 @@ trait CodeExtraction extends ASTExtractors { case passes @ ExPasses(sel, cses) => val rs = extractTree(sel) val rc = cses.map(extractMatchCase(_)) - val rt: LeonType = rc.map(_.rhs.getType).reduceLeft(leastUpperBound(_,_).get) - Passes(rs, rc).setType(rt) + Passes(rs, rc) case ExArrayLiteral(tpe, args) => @@ -1027,7 +1026,7 @@ trait CodeExtraction extends ASTExtractors { Tuple(tupleExprs) case ExErrorExpression(str, tpt) => - Error(str).setType(extractType(tpt)) + Error(extractType(tpt), str) case ExTupleExtract(tuple, index) => val tupleExpr = extractTree(tuple) @@ -1324,8 +1323,7 @@ trait CodeExtraction extends ASTExtractors { FiniteSet(args.map(extractTree(_)).toSet).setType(SetType(underlying)) case ExFiniteMultiset(tt, args) => - val underlying = extractType(tt) - FiniteMultiset(args.map(extractTree(_))).setType(MultisetType(underlying)) + FiniteMultiset(args.map(extractTree(_))) case ExEmptySet(tt) => val underlying = extractType(tt) @@ -1333,7 +1331,7 @@ trait CodeExtraction extends ASTExtractors { case ExEmptyMultiset(tt) => val underlying = extractType(tt) - EmptyMultiset(underlying).setType(MultisetType(underlying)) + EmptyMultiset(underlying) case ExEmptyMap(ft, tt) => val fromUnderlying = extractType(ft) @@ -1362,7 +1360,7 @@ trait CodeExtraction extends ASTExtractors { val underlying = extractType(baseType) val lengthRec = extractTree(length) val defaultValueRec = extractTree(defaultValue) - ArrayFill(lengthRec, defaultValueRec).setType(ArrayType(underlying)) + ArrayFill(lengthRec, defaultValueRec) case ExIfThenElse(t1,t2,t3) => val r1 = extractTree(t1) @@ -1374,7 +1372,7 @@ trait CodeExtraction extends ASTExtractors { val lub = leastUpperBound(r2.getType, r3.getType) lub match { case Some(lub) => - IfExpr(r1, r2, r3).setType(lub) + IfExpr(r1, r2, r3) case None => outOfSubsetError(tr, "Both branches of ifthenelse have incompatible types ("+r2.getType.asString(ctx)+" and "+r3.getType.asString(ctx)+")") @@ -1409,8 +1407,7 @@ trait CodeExtraction extends ASTExtractors { case pm @ ExPatternMatching(sel, cses) => val rs = extractTree(sel) val rc = cses.map(extractMatchCase(_)) - val rt: LeonType = rc.map(_.rhs.getType).reduceLeft(leastUpperBound(_,_).get) - MatchExpr(rs, rc).setType(rt) + MatchExpr(rs, rc) case t: This => extractType(t) match { @@ -1471,7 +1468,7 @@ trait CodeExtraction extends ASTExtractors { val newTps = tps.map(t => extractType(t)) - FunctionInvocation(fd.typed(newTps), args).setType(fd.returnType) + FunctionInvocation(fd.typed(newTps), args) case (IsTyped(rec, ct: ClassType), _, args) if isMethod(sym) => val fd = getFunDef(sym, c.pos) @@ -1533,22 +1530,22 @@ trait CodeExtraction extends ASTExtractors { // Set methods case (IsTyped(a1, SetType(b1)), "min", Nil) => - SetMin(a1).setType(b1) + SetMin(a1) case (IsTyped(a1, SetType(b1)), "max", Nil) => - SetMax(a1).setType(b1) + SetMax(a1) case (IsTyped(a1, SetType(b1)), "++", List(IsTyped(a2, SetType(b2)))) if b1 == b2 => - SetUnion(a1, a2).setType(SetType(b1)) + SetUnion(a1, a2) case (IsTyped(a1, SetType(b1)), "&", List(IsTyped(a2, SetType(b2)))) if b1 == b2 => - SetIntersection(a1, a2).setType(SetType(b1)) + SetIntersection(a1, a2) case (IsTyped(a1, SetType(b1)), "subsetOf", List(IsTyped(a2, SetType(b2)))) if b1 == b2 => SubsetOf(a1, a2) case (IsTyped(a1, SetType(b1)), "--", List(IsTyped(a2, SetType(b2)))) if b1 == b2 => - SetDifference(a1, a2).setType(SetType(b1)) + SetDifference(a1, a2) case (IsTyped(a1, SetType(b1)), "contains", List(a2)) => ElementOfSet(a2, a1) @@ -1556,37 +1553,37 @@ trait CodeExtraction extends ASTExtractors { // Multiset methods case (IsTyped(a1, MultisetType(b1)), "++", List(IsTyped(a2, MultisetType(b2)))) if b1 == b2 => - MultisetUnion(a1, a2).setType(MultisetType(b1)) + MultisetUnion(a1, a2) case (IsTyped(a1, MultisetType(b1)), "&", List(IsTyped(a2, MultisetType(b2)))) if b1 == b2 => - MultisetIntersection(a1, a2).setType(MultisetType(b1)) + MultisetIntersection(a1, a2) case (IsTyped(a1, MultisetType(b1)), "--", List(IsTyped(a2, MultisetType(b2)))) if b1 == b2 => - MultisetDifference(a1, a2).setType(MultisetType(b1)) + MultisetDifference(a1, a2) case (IsTyped(a1, MultisetType(b1)), "+++", List(IsTyped(a2, MultisetType(b2)))) if b1 == b2 => - MultisetPlus(a1, a2).setType(MultisetType(b1)) + MultisetPlus(a1, a2) case (IsTyped(_, MultisetType(b1)), "toSet", Nil) => - MultisetToSet(rrec).setType(b1) + MultisetToSet(rrec) // Array methods case (IsTyped(a1, ArrayType(vt)), "apply", List(a2)) => - ArraySelect(a1, a2).setType(vt) + ArraySelect(a1, a2) case (IsTyped(a1, at: ArrayType), "length", Nil) => ArrayLength(a1) case (IsTyped(a1, at: ArrayType), "clone", Nil) => - ArrayClone(a1).setType(at) + ArrayClone(a1) case (IsTyped(a1, at: ArrayType), "updated", List(k, v)) => - ArrayUpdated(a1, k, v).setType(at) + ArrayUpdated(a1, k, v) // Map methods case (IsTyped(a1, MapType(_, vt)), "apply", List(a2)) => - MapGet(a1, a2).setType(vt) + MapGet(a1, a2) case (IsTyped(a1, mt: MapType), "isDefinedAt", List(a2)) => MapIsDefinedAt(a1, a2) @@ -1595,10 +1592,10 @@ trait CodeExtraction extends ASTExtractors { MapIsDefinedAt(a1, a2) case (IsTyped(a1, mt: MapType), "updated", List(k, v)) => - MapUnion(a1, FiniteMap(Seq((k, v))).setType(mt)).setType(mt) + MapUnion(a1, FiniteMap(Seq((k, v))).setType(mt)) case (IsTyped(a1, mt1: MapType), "++", List(IsTyped(a2, mt2: MapType))) if mt1 == mt2 => - MapUnion(a1, a2).setType(mt1) + MapUnion(a1, a2) case (_, name, _) => outOfSubsetError(tr, "Unknown call to "+name) diff --git a/src/main/scala/leon/purescala/Common.scala b/src/main/scala/leon/purescala/Common.scala index e9501ab8aaa400396c78f9b318837ec25a371c38..a34958333e1649dedca8a451b38feb9c92c31211 100644 --- a/src/main/scala/leon/purescala/Common.scala +++ b/src/main/scala/leon/purescala/Common.scala @@ -8,14 +8,14 @@ import Definitions.Definition object Common { import Trees.Variable - import TypeTrees.Typed + import TypeTrees.{MutableTyped,Typed} abstract class Tree extends Positioned with Serializable { def copiedFrom(o: Tree): this.type = { setPos(o) (this, o) match { // do not force if already set - case (t1: Typed, t2: Typed) if !t1.isTyped => + case (t1: MutableTyped, t2: Typed) if !t1.isTyped => t1.setType(t2.getType) case _ => } @@ -30,7 +30,7 @@ object Common { } // the type is left blank (Untyped) for Identifiers that are not variables - class Identifier private[Common](val name: String, val globalId: Int, val id: Int, alwaysShowUniqueID: Boolean = false) extends Tree with Typed { + class Identifier private[Common](val name: String, val globalId: Int, val id: Int, alwaysShowUniqueID: Boolean = false) extends Tree with MutableTyped { self : Serializable => override def equals(other: Any): Boolean = { diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala index e89752798631dff69797584e2be86e680904bf9b..f6d1089cbaa1a4976e9b3c997918bd17208f30af 100644 --- a/src/main/scala/leon/purescala/Definitions.scala +++ b/src/main/scala/leon/purescala/Definitions.scala @@ -45,19 +45,15 @@ object Definitions { } /** A ValDef declares a new identifier to be of a certain type. */ - case class ValDef(id: Identifier, tpe: TypeTree) extends Definition with FixedType { + case class ValDef(id: Identifier, tpe: TypeTree) extends Definition with Typed { self: Serializable => - val fixedType = tpe + val getType = tpe def subDefinitions = Seq() - override def hashCode : Int = id.hashCode - override def equals(that : Any) : Boolean = that match { - case t : ValDef => t.id == this.id - case _ => false - } def toVariable : Variable = Variable(id).setType(tpe) + setSubDefOwners() } diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala index 501f9955174a8e6841de2f9331248ea2059831e4..4a83deef688a527c72996dcb76ddb85a118d4b39 100644 --- a/src/main/scala/leon/purescala/Extractors.scala +++ b/src/main/scala/leon/purescala/Extractors.scala @@ -59,7 +59,6 @@ object Extractors { case SetUnion(t1,t2) => Some((t1,t2,SetUnion)) case SetDifference(t1,t2) => Some((t1,t2,SetDifference)) case Multiplicity(t1,t2) => Some((t1,t2,Multiplicity)) - case SubmultisetOf(t1,t2) => Some((t1,t2,SubmultisetOf)) case MultisetIntersection(t1,t2) => Some((t1,t2,MultisetIntersection)) case MultisetUnion(t1,t2) => Some((t1,t2,MultisetUnion)) case MultisetPlus(t1,t2) => Some((t1,t2,MultisetPlus)) @@ -123,7 +122,10 @@ object Extractors { } case FiniteMultiset(args) => Some((args, FiniteMultiset)) case ArrayUpdated(t1, t2, t3) => Some((Seq(t1,t2,t3), (as: Seq[Expr]) => ArrayUpdated(as(0), as(1), as(2)))) - case FiniteArray(args) => Some((args, FiniteArray)) + case FiniteArray(args) => Some((args, { (as: Seq[Expr]) => + val tpe = leastUpperBound(as.map(_.getType)).map(ArrayType(_)).getOrElse(expr.getType) + FiniteArray(as).setType(tpe) + })) case Distinct(args) => Some((args, Distinct)) case Tuple(args) => Some((args, Tuple)) case IfExpr(cond, thenn, elze) => Some((Seq(cond, thenn, elze), (as: Seq[Expr]) => IfExpr(as(0), as(1), as(2)))) diff --git a/src/main/scala/leon/purescala/FunctionClosure.scala b/src/main/scala/leon/purescala/FunctionClosure.scala index 20058923c84bf8795eebd5745097d7dec35ed9d9..bb9e1ae10ff1f02ad639ca03928ffa1a9b5d17d8 100644 --- a/src/main/scala/leon/purescala/FunctionClosure.scala +++ b/src/main/scala/leon/purescala/FunctionClosure.scala @@ -162,7 +162,7 @@ object FunctionClosure extends TransformationPhase { } } val tpe = csesRec.head.rhs.getType - MatchExpr(scrutRec, csesRec).copiedFrom(m).setType(tpe) + MatchExpr(scrutRec, csesRec).copiedFrom(m) } case v @ Variable(id) => id2freshId.get(id) match { case None => v diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index f09e63815207be5ab5a6a7df2e90952c7e8f097d..429b8e1b9d9b54b7a3ea432e8371c351df2c3878 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -285,7 +285,7 @@ class PrettyPrinter(opts: PrinterOptions, val sb: StringBuffer = new StringBuffe case Tuple(exprs) => p"($exprs)" case TupleSelect(t, i) => p"${t}._$i" case Choose(vars, pred) => p"choose(($vars) => $pred)" - case e @ Error(err) => p"""error[${e.getType}]("$err")""" + case e @ Error(tpe, err) => p"""error[$tpe]("$err")""" case CaseClassInstanceOf(cct, e) => if (cct.classDef.isCaseObject) { p"($e == $cct)" diff --git a/src/main/scala/leon/purescala/SimplifierWithPaths.scala b/src/main/scala/leon/purescala/SimplifierWithPaths.scala index 6353c6d9069a859dc1c863ad5abb3f47fa130f80..f40946b8ff44b4cc8799f9c851a0346b307e0fb7 100644 --- a/src/main/scala/leon/purescala/SimplifierWithPaths.scala +++ b/src/main/scala/leon/purescala/SimplifierWithPaths.scala @@ -112,8 +112,8 @@ class SimplifierWithPaths(sf: SolverFactory[Solver]) extends TransformerWithPC { } } newCases match { - case List() => Error("Unreachable code").copiedFrom(e) - case List(theCase) if !scrut.getType.isInstanceOf[AbstractClassType] => + case List() => Error(e.getType, "Unreachable code").copiedFrom(e) + case List(theCase) if !scrut.getType.isInstanceOf[AbstractClassType] => // Avoid AbstractClassType as it may lead to invalid field accesses replaceFromIDs(mapForPattern(scrut, theCase.pattern), theCase.rhs) case _ => MatchExpr(rs, newCases).copiedFrom(e) diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala index 0ba220f7d257d1bb54868e74df7a127870802ea6..25e096264010da9143ba23f0931a3496d50a49d3 100644 --- a/src/main/scala/leon/purescala/TreeOps.scala +++ b/src/main/scala/leon/purescala/TreeOps.scala @@ -397,10 +397,10 @@ object TreeOps { case LessEquals(e1,e2) => GreaterThan(e1,e2) case GreaterThan(e1,e2) => LessEquals(e1,e2) case GreaterEquals(e1,e2) => LessThan(e1,e2) - case i @ IfExpr(c,e1,e2) => IfExpr(c, negate(e1), negate(e2)).setType(i.getType) + case i @ IfExpr(c,e1,e2) => IfExpr(c, negate(e1), negate(e2)) case BooleanLiteral(b) => BooleanLiteral(!b) case _ => Not(expr) - }).setType(expr.getType).setPos(expr) + }).setPos(expr) // rewrites pattern-matching expressions to use fresh variables for the binders def freshenLocals(expr: Expr) : Expr = { @@ -606,8 +606,8 @@ object TreeOps { def rec(ex: Expr, s: Map[Identifier,Expr]) : Expr = ex match { case v @ Variable(id) if s.isDefinedAt(id) => rec(s(id), s) case l @ Let(i,e,b) => rec(b, s + (i -> rec(e, s))) - case i @ IfExpr(t1,t2,t3) => IfExpr(rec(t1, s),rec(t2, s),rec(t3, s)).setType(i.getType) - case m @ MatchExpr(scrut,cses) => MatchExpr(rec(scrut, s), cses.map(inCase(_, s))).setType(m.getType).setPos(m) + case i @ IfExpr(t1,t2,t3) => IfExpr(rec(t1, s),rec(t2, s),rec(t3, s)) + case m @ MatchExpr(scrut,cses) => MatchExpr(rec(scrut, s), cses.map(inCase(_, s))).setPos(m) case n @ NAryOperator(args, recons) => { var change = false val rargs = args.map(a => { @@ -620,7 +620,7 @@ object TreeOps { } }) if(change) - recons(rargs).setType(n.getType) + recons(rargs) else n } @@ -628,14 +628,14 @@ object TreeOps { val r1 = rec(t1, s) val r2 = rec(t2, s) if(r1 != t1 || r2 != t2) - recons(r1,r2).setType(b.getType) + recons(r1,r2) else b } case u @ UnaryOperator(t,recons) => { val r = rec(t, s) if(r != t) - recons(r).setType(u.getType) + recons(r) else u } @@ -778,7 +778,7 @@ object TreeOps { case TuplePattern(ob, subps) => { val TupleType(tpes) = in.getType assert(tpes.size == subps.size) - val subTests = subps.zipWithIndex.map{case (p, i) => rec(TupleSelect(in, i+1).setType(tpes(i)), p)} + val subTests = subps.zipWithIndex.map{case (p, i) => rec(TupleSelect(in, i+1), p)} And(bind(ob, in) +: subTests) } case LiteralPattern(ob,lit) => And(Equals(in,lit), bind(ob,in)) @@ -807,7 +807,7 @@ object TreeOps { val TupleType(tpes) = in.getType assert(tpes.size == subps.size) - val maps = subps.zipWithIndex.map{case (p, i) => mapForPattern(TupleSelect(in, i+1).setType(tpes(i)), p)} + val maps = subps.zipWithIndex.map{case (p, i) => mapForPattern(TupleSelect(in, i+1), p)} val map = maps.foldLeft(Map.empty[Identifier,Expr])(_ ++ _) b match { case Some(id) => map + (id -> in) @@ -835,11 +835,11 @@ object TreeOps { (realCond, newRhs) } - val bigIte = condsAndRhs.foldRight[Expr](Error("Match is non-exhaustive").copiedFrom(m))((p1, ex) => { + val bigIte = condsAndRhs.foldRight[Expr](Error(m.getType, "Match is non-exhaustive").copiedFrom(m))((p1, ex) => { if(p1._1 == BooleanLiteral(true)) { p1._2 } else { - IfExpr(p1._1, p1._2, ex).setType(m.getType) + IfExpr(p1._1, p1._2, ex) } }) @@ -884,7 +884,7 @@ object TreeOps { val r = postMap({ case mg @ MapGet(m,k) => val ida = MapIsDefinedAt(m, k) - Some(IfExpr(ida, mg, Error("key not found for map access").copiedFrom(mg)).copiedFrom(mg)) + Some(IfExpr(ida, mg, Error(mg.getType, "Key not found for map access").copiedFrom(mg)).copiedFrom(mg)) case _=> None @@ -966,7 +966,7 @@ object TreeOps { Some(IfExpr(c, op(beforeIte ++ Seq(t) ++ afterIte).copiedFrom(nop), op(beforeIte ++ Seq(e) ++ afterIte).copiedFrom(nop) - ).setType(nop.getType)) + )) } } case _ => None @@ -1234,7 +1234,7 @@ object TreeOps { case Tuple(Seq()) => UnitLiteral() case Variable(id) if idMap contains id => Variable(idMap(id)) - case Error(err) => Error(err).setType(mapType(e.getType).getOrElse(e.getType)).copiedFrom(e) + case Error(tpe, err) => Error(mapType(tpe).getOrElse(e.getType), err).copiedFrom(e) case Tuple(Seq(s)) => pre(s) case ts @ TupleSelect(t, 1) => t.getType match { @@ -1916,9 +1916,9 @@ object TreeOps { case l @ Lambda(args, body) => val newBody = rec(body, true) extract(Lambda(args, newBody), build) - case NAryOperator(es, recons) => recons(es.map(rec(_, build))).setType(expr.getType) - case BinaryOperator(e1, e2, recons) => recons(rec(e1, build), rec(e2, build)).setType(expr.getType) - case UnaryOperator(e, recons) => recons(rec(e, build)).setType(expr.getType) + case NAryOperator(es, recons) => recons(es.map(rec(_, build))) + case BinaryOperator(e1, e2, recons) => recons(rec(e1, build), rec(e2, build)) + case UnaryOperator(e, recons) => recons(rec(e, build)) case t: Terminal => t } @@ -2268,10 +2268,10 @@ object TreeOps { Seq(c) }} - var finalMatch = MatchExpr(scrutinee, List(newCases.head)).setType(e.getType) + var finalMatch = MatchExpr(scrutinee, List(newCases.head)) for (toAdd <- newCases.tail if !isMatchExhaustive(finalMatch)) { - finalMatch = MatchExpr(scrutinee, finalMatch.cases :+ toAdd).setType(e.getType) + finalMatch = MatchExpr(scrutinee, finalMatch.cases :+ toAdd) } finalMatch diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala index ec3879c1fc912be52eaa5b40b57bd82cf9b61d60..4cc550547f117397fd95ecf0cb80c3a9fdd33b5e 100644 --- a/src/main/scala/leon/purescala/Trees.scala +++ b/src/main/scala/leon/purescala/Trees.scala @@ -21,96 +21,70 @@ object Trees { self: Expr => } - case class NoTree(tpe: TypeTree) extends Expr with Terminal with FixedType { - val fixedType = tpe + case class NoTree(tpe: TypeTree) extends Expr with Terminal with Typed { + val getType = tpe } /* This describes computational errors (unmatched case, taking min of an * empty set, division by zero, etc.). It should always be typed according to * the expected type. */ - case class Error(description: String) extends Expr with Terminal - - case class Require(pred: Expr, body: Expr) extends Expr with FixedType { - val fixedType = body.getType + case class Error(tpe: TypeTree, description: String) extends Expr with Terminal { + val getType = tpe } - case class Ensuring(body: Expr, id: Identifier, pred: Expr) extends Expr with FixedType { - val fixedType = body.getType + case class Require(pred: Expr, body: Expr) extends Expr with Typed { + def getType = body.getType } - case class Assert(pred: Expr, error: Option[String], body: Expr) extends Expr with FixedType { - val fixedType = body.getType + case class Ensuring(body: Expr, id: Identifier, pred: Expr) extends Expr { + def getType = body.getType } - case class Passes(scrut: Expr, tests : List[MatchCase]) extends Expr with FixedType { - val fixedType = leastUpperBound(tests.map(_.rhs.getType)).getOrElse{ - Untyped - } + case class Assert(pred: Expr, error: Option[String], body: Expr) extends Expr { + def getType = body.getType } - case class Choose(vars: List[Identifier], pred: Expr) extends Expr with FixedType with UnaryExtractable { + case class Passes(scrut: Expr, tests : List[MatchCase]) extends Expr { + def getType = leastUpperBound(tests.map(_.rhs.getType)).getOrElse(Untyped) + } + case class Choose(vars: List[Identifier], pred: Expr) extends Expr with UnaryExtractable { assert(!vars.isEmpty) - val fixedType = if (vars.size > 1) TupleType(vars.map(_.getType)) else vars.head.getType + def getType = if (vars.size > 1) TupleType(vars.map(_.getType)) else vars.head.getType def extract = { Some((pred, (e: Expr) => Choose(vars, e).setPos(this))) } } - // Provide an oracle (synthesizable, all-seeing choose) - case class WithOracle(oracles: List[Identifier], body: Expr) extends Expr with FixedType with UnaryExtractable { - assert(!oracles.isEmpty) - - val fixedType = body.getType - - def extract = { - Some((body, (e: Expr) => WithOracle(oracles, e).setPos(this))) - } - } - - case class Hole(fixedType: TypeTree, alts: Seq[Expr]) extends Expr with FixedType with NAryExtractable { - - def extract = { - Some((alts, (es: Seq[Expr]) => Hole(fixedType, es).setPos(this))) - } - } - - case class RepairHole(fixedType: TypeTree, components: Seq[Expr]) extends Expr with FixedType with NAryExtractable { - - def extract = { - Some((components, (es: Seq[Expr]) => RepairHole(fixedType, es).setPos(this))) - } - } - - /* Like vals */ - case class Let(binder: Identifier, value: Expr, body: Expr) extends Expr with FixedType { - val fixedType = body.getType + case class Let(binder: Identifier, value: Expr, body: Expr) extends Expr { + def getType = body.getType } - case class LetTuple(binders: Seq[Identifier], value: Expr, body: Expr) extends Expr with FixedType { + case class LetTuple(binders: Seq[Identifier], value: Expr, body: Expr) extends Expr { assert(value.getType.isInstanceOf[TupleType], "The definition value in LetTuple must be of some tuple type; yet we got [%s]. In expr: \n%s".format(value.getType, this)) - val fixedType = body.getType + def getType = body.getType } case class LetDef(fd: FunDef, body: Expr) extends Expr { - val et = body.getType - if(et != Untyped) - setType(et) - + def getType = body.getType } - - /* Control flow */ - case class FunctionInvocation(tfd: TypedFunDef, args: Seq[Expr]) extends Expr with FixedType { - val fixedType = tfd.returnType + case class FunctionInvocation(tfd: TypedFunDef, args: Seq[Expr]) extends Expr { + def getType = tfd.returnType } - case class MethodInvocation(rec: Expr, cd: ClassDef, tfd: TypedFunDef, args: Seq[Expr]) extends Expr with FixedType { - val fixedType = { + /** + * OO Trees + * + * Both MethodInvocation and This get removed by phase MethodLifting. Methods become functions, + * This becomes first argument, and MethodInvocation become FunctionInvocation. + */ + case class MethodInvocation(rec: Expr, cd: ClassDef, tfd: TypedFunDef, args: Seq[Expr]) extends Expr { + def getType = { // We need ot instanciate the type based on the type of the function as well as receiver val fdret = tfd.returnType val extraMap: Map[TypeParameterDef, TypeTree] = rec.getType match { @@ -124,13 +98,13 @@ object Trees { } } - case class Application(caller: Expr, args: Seq[Expr]) extends Expr with FixedType { + case class Application(caller: Expr, args: Seq[Expr]) extends Expr { assert(caller.getType.isInstanceOf[FunctionType]) - val fixedType = caller.getType.asInstanceOf[FunctionType].to + def getType = caller.getType.asInstanceOf[FunctionType].to } - case class Lambda(args: Seq[ValDef], body: Expr) extends Expr with FixedType { - val fixedType = FunctionType(args.map(_.tpe), body.getType) + case class Lambda(args: Seq[ValDef], body: Expr) extends Expr { + def getType = FunctionType(args.map(_.tpe), body.getType) } object FiniteLambda { @@ -179,25 +153,24 @@ object Trees { } } - case class Forall(args: Seq[ValDef], body: Expr) extends Expr with FixedType { + case class Forall(args: Seq[ValDef], body: Expr) extends Expr { assert(body.getType == BooleanType) - val fixedType = BooleanType + def getType = BooleanType } - case class This(ct: ClassType) extends Expr with FixedType with Terminal { - val fixedType = ct + case class This(ct: ClassType) extends Expr with Terminal { + def getType = ct } - case class IfExpr(cond: Expr, thenn: Expr, elze: Expr) extends Expr with FixedType { - val fixedType = leastUpperBound(thenn.getType, elze.getType).getOrElse{ - Untyped - } + case class IfExpr(cond: Expr, thenn: Expr, elze: Expr) extends Expr { + def getType = leastUpperBound(thenn.getType, elze.getType).getOrElse(Untyped) } - case class Tuple(exprs: Seq[Expr]) extends Expr with FixedType { - val fixedType = TupleType(exprs.map(_.getType)) + case class Tuple(exprs: Seq[Expr]) extends Expr { + def getType = TupleType(exprs.map(_.getType)) } + // TODO: ship this simplification to constructors object TupleSelect { def apply(tuple: Expr, index: Int): Expr = { tuple match { @@ -212,12 +185,12 @@ object Trees { } // This must be 1-indexed ! (So are methods of Scala Tuples) - class TupleSelect(val tuple: Expr, val index: Int) extends Expr with FixedType { + class TupleSelect(val tuple: Expr, val index: Int) extends Expr { assert(index >= 1) assert(tuple.getType.isInstanceOf[TupleType], "Applying TupleSelect on a non-tuple tree [%s] of type [%s].".format(tuple, tuple.getType)) - val fixedType : TypeTree = tuple.getType match { + def getType = tuple.getType match { case TupleType(ts) => assert(index <= ts.size) ts(index - 1) @@ -244,19 +217,17 @@ object Trees { })) case _: TupleType | Int32Type | BooleanType | UnitType => new MatchExpr(scrutinee, cases) - case _ => scala.sys.error("Constructing match expression on non-supported type.") + case t => scala.sys.error("Constructing match expression on non-supported type: "+t) } } def unapply(me: MatchExpr) : Option[(Expr,Seq[MatchCase])] = if (me == null) None else Some((me.scrutinee, me.cases)) } - class MatchExpr(val scrutinee: Expr, val cases: Seq[MatchCase]) extends Expr with FixedType { + class MatchExpr(val scrutinee: Expr, val cases: Seq[MatchCase]) extends Expr { assert(cases.nonEmpty) - val fixedType = leastUpperBound(cases.map(_.rhs.getType)).getOrElse{ - Untyped - } + def getType = leastUpperBound(cases.map(_.rhs.getType)).getOrElse(Untyped) def scrutineeClassType: ClassType = scrutinee.getType.asInstanceOf[ClassType] @@ -355,8 +326,8 @@ object Trees { if(and == null) None else Some(and.exprs) } - class And private (val exprs: Seq[Expr]) extends Expr with FixedType { - val fixedType = BooleanType + class And private (val exprs: Seq[Expr]) extends Expr { + def getType = BooleanType assert(exprs.size >= 2) @@ -398,8 +369,8 @@ object Trees { if(or == null) None else Some(or.exprs) } - class Or private[Trees] (val exprs: Seq[Expr]) extends Expr with FixedType { - val fixedType = BooleanType + class Or private[Trees] (val exprs: Seq[Expr]) extends Expr { + def getType = BooleanType assert(exprs.size >= 2) @@ -425,8 +396,8 @@ object Trees { } } - class Iff private[Trees] (val left: Expr, val right: Expr) extends Expr with FixedType { - val fixedType = BooleanType + class Iff private[Trees] (val left: Expr, val right: Expr) extends Expr { + def getType = BooleanType override def equals(that: Any): Boolean = (that != null) && (that match { case t: Iff => t.left == left && t.right == right @@ -449,8 +420,8 @@ object Trees { if(imp == null) None else Some(imp.left, imp.right) } - class Implies private[Trees] (val left: Expr, val right: Expr) extends Expr with FixedType { - val fixedType = BooleanType + class Implies private[Trees] (val left: Expr, val right: Expr) extends Expr { + def getType = BooleanType override def equals(that: Any): Boolean = (that != null) && (that match { case t: Implies => t.left == left && t.right == right @@ -472,8 +443,8 @@ object Trees { } } - class Not private[Trees] (val expr: Expr) extends Expr with FixedType { - val fixedType = BooleanType + class Not private[Trees] (val expr: Expr) extends Expr { + val getType = BooleanType override def equals(that: Any) : Boolean = (that != null) && (that match { case n : Not => n.expr == expr @@ -507,8 +478,8 @@ object Trees { } } - class Equals private[Trees] (val left: Expr, val right: Expr) extends Expr with FixedType { - val fixedType = BooleanType + class Equals private[Trees] (val left: Expr, val right: Expr) extends Expr { + val getType = BooleanType override def equals(that: Any): Boolean = (that != null) && (that match { case t: Equals => t.left == left && t.right == right @@ -519,8 +490,14 @@ object Trees { } case class Variable(id: Identifier) extends Expr with Terminal { - override def getType = id.getType - override def setType(tt: TypeTree) = { id.setType(tt); this } + private var _tpe = id.getType + + def setType(tpe: TypeTree): this.type = { + _tpe = tpe + this + } + + def getType = _tpe } /* Literals */ @@ -528,34 +505,35 @@ object Trees { val value: T } - case class GenericValue(tp: TypeParameter, id: Int) extends Expr with Terminal with FixedType { - val fixedType = tp + case class GenericValue(tp: TypeParameter, id: Int) extends Expr with Terminal { + val getType = tp } - case class CharLiteral(value: Char) extends Literal[Char] with FixedType { - val fixedType = CharType + case class CharLiteral(value: Char) extends Literal[Char] { + val getType = CharType } - case class IntLiteral(value: Int) extends Literal[Int] with FixedType { - val fixedType = Int32Type + case class IntLiteral(value: Int) extends Literal[Int] { + val getType = Int32Type } - case class BooleanLiteral(value: Boolean) extends Literal[Boolean] with FixedType { - val fixedType = BooleanType + case class BooleanLiteral(value: Boolean) extends Literal[Boolean] { + val getType = BooleanType } - case class StringLiteral(value: String) extends Literal[String] - case class UnitLiteral() extends Literal[Unit] with FixedType { - val fixedType = UnitType + case class StringLiteral(value: String) extends Literal[String] with MutableTyped + + case class UnitLiteral() extends Literal[Unit] { + val getType = UnitType val value = () } - case class CaseClass(ct: CaseClassType, args: Seq[Expr]) extends Expr with FixedType { - val fixedType = ct + case class CaseClass(ct: CaseClassType, args: Seq[Expr]) extends Expr { + val getType = ct } - case class CaseClassInstanceOf(classType: CaseClassType, expr: Expr) extends Expr with FixedType { - val fixedType = BooleanType + case class CaseClassInstanceOf(classType: CaseClassType, expr: Expr) extends Expr { + val getType = BooleanType } object CaseClassSelector { @@ -576,9 +554,9 @@ object Trees { } } - class CaseClassSelector(val classType: CaseClassType, val caseClass: Expr, val selector: Identifier) extends Expr with FixedType { + class CaseClassSelector(val classType: CaseClassType, val caseClass: Expr, val selector: Identifier) extends Expr { val selectorIndex = classType.classDef.selectorID2Index(selector) - val fixedType = classType.fieldsTypes(selectorIndex) + def getType = classType.fieldsTypes(selectorIndex) override def equals(that: Any): Boolean = (that != null) && (that match { case t: CaseClassSelector => (t.classType, t.caseClass, t.selector) == (classType, caseClass, selector) @@ -589,123 +567,144 @@ object Trees { } /* Arithmetic */ - case class Plus(lhs: Expr, rhs: Expr) extends Expr with FixedType { - val fixedType = Int32Type + case class Plus(lhs: Expr, rhs: Expr) extends Expr { + val getType = Int32Type } - case class Minus(lhs: Expr, rhs: Expr) extends Expr with FixedType { - val fixedType = Int32Type + case class Minus(lhs: Expr, rhs: Expr) extends Expr { + val getType = Int32Type } - case class UMinus(expr: Expr) extends Expr with FixedType { - val fixedType = Int32Type + case class UMinus(expr: Expr) extends Expr { + val getType = Int32Type } - case class Times(lhs: Expr, rhs: Expr) extends Expr with FixedType { - val fixedType = Int32Type + case class Times(lhs: Expr, rhs: Expr) extends Expr { + val getType = Int32Type } - case class Division(lhs: Expr, rhs: Expr) extends Expr with FixedType { - val fixedType = Int32Type + case class Division(lhs: Expr, rhs: Expr) extends Expr { + val getType = Int32Type } - case class Modulo(lhs: Expr, rhs: Expr) extends Expr with FixedType { - val fixedType = Int32Type + case class Modulo(lhs: Expr, rhs: Expr) extends Expr { + val getType = Int32Type } - case class LessThan(lhs: Expr, rhs: Expr) extends Expr with FixedType { - val fixedType = BooleanType + case class LessThan(lhs: Expr, rhs: Expr) extends Expr { + val getType = BooleanType } - case class GreaterThan(lhs: Expr, rhs: Expr) extends Expr with FixedType { - val fixedType = BooleanType + case class GreaterThan(lhs: Expr, rhs: Expr) extends Expr { + val getType = BooleanType } - case class LessEquals(lhs: Expr, rhs: Expr) extends Expr with FixedType { - val fixedType = BooleanType + case class LessEquals(lhs: Expr, rhs: Expr) extends Expr { + val getType = BooleanType } - case class GreaterEquals(lhs: Expr, rhs: Expr) extends Expr with FixedType { - val fixedType = BooleanType + case class GreaterEquals(lhs: Expr, rhs: Expr) extends Expr { + val getType = BooleanType } /* Set expressions */ - case class FiniteSet(elements: Set[Expr]) extends Expr { + case class FiniteSet(elements: Set[Expr]) extends Expr with MutableTyped { val tpe = if (elements.isEmpty) None else leastUpperBound(elements.toSeq.map(_.getType)) - tpe.foreach(t => setType(SetType(t))) + tpe.filter(_ != Untyped).foreach(t => setType(SetType(t))) } - // TODO : Figure out what evaluation order is, for this. - // Perhaps then rewrite as "contains". - case class ElementOfSet(element: Expr, set: Expr) extends Expr with FixedType { - val fixedType = BooleanType + + case class ElementOfSet(element: Expr, set: Expr) extends Expr { + val getType = BooleanType } - case class SetCardinality(set: Expr) extends Expr with FixedType { - val fixedType = Int32Type + case class SetCardinality(set: Expr) extends Expr { + val getType = Int32Type } - case class SubsetOf(set1: Expr, set2: Expr) extends Expr with FixedType { - val fixedType = BooleanType + case class SubsetOf(set1: Expr, set2: Expr) extends Expr { + val getType = BooleanType } case class SetIntersection(set1: Expr, set2: Expr) extends Expr { - leastUpperBound(Seq(set1, set2).map(_.getType)).foreach(setType _) + def getType = leastUpperBound(Seq(set1, set2).map(_.getType)).getOrElse(Untyped) } case class SetUnion(set1: Expr, set2: Expr) extends Expr { - leastUpperBound(Seq(set1, set2).map(_.getType)).foreach(setType _) + def getType = leastUpperBound(Seq(set1, set2).map(_.getType)).getOrElse(Untyped) } case class SetDifference(set1: Expr, set2: Expr) extends Expr { - leastUpperBound(Seq(set1, set2).map(_.getType)).foreach(setType _) + def getType = leastUpperBound(Seq(set1, set2).map(_.getType)).getOrElse(Untyped) } - case class SetMin(set: Expr) extends Expr with FixedType { - val fixedType = Int32Type + + @deprecated("SetMin is not supported by any solver", "2.3") + case class SetMin(set: Expr) extends Expr { + val getType = Int32Type } - case class SetMax(set: Expr) extends Expr with FixedType { - val fixedType = Int32Type + + @deprecated("SetMax is not supported by any solver", "2.3") + case class SetMax(set: Expr) extends Expr { + val getType = Int32Type } - /* Multiset expressions */ - case class EmptyMultiset(baseType: TypeTree) extends Expr with Terminal - case class FiniteMultiset(elements: Seq[Expr]) extends Expr - case class Multiplicity(element: Expr, multiset: Expr) extends Expr - case class MultisetCardinality(multiset: Expr) extends Expr with FixedType { - val fixedType = Int32Type + /* Multiset expressions !!! UNSUPPORTED / DEPRECATED !!! */ + case class EmptyMultiset(baseType: TypeTree) extends Expr with Terminal { + val getType = MultisetType(baseType) + } + case class FiniteMultiset(elements: Seq[Expr]) extends Expr { + assert(elements.size > 0) + def getType = MultisetType(leastUpperBound(elements.map(_.getType)).getOrElse(Untyped)) + } + case class Multiplicity(element: Expr, multiset: Expr) extends Expr { + val getType = Int32Type + } + case class MultisetCardinality(multiset: Expr) extends Expr { + val getType = Int32Type + } + case class MultisetIntersection(multiset1: Expr, multiset2: Expr) extends Expr { + def getType = leastUpperBound(Seq(multiset1, multiset2).map(_.getType)).getOrElse(Untyped) + } + case class MultisetUnion(multiset1: Expr, multiset2: Expr) extends Expr { + def getType = leastUpperBound(Seq(multiset1, multiset2).map(_.getType)).getOrElse(Untyped) + } + case class MultisetPlus(multiset1: Expr, multiset2: Expr) extends Expr { // disjoint union + def getType = leastUpperBound(Seq(multiset1, multiset2).map(_.getType)).getOrElse(Untyped) + } + case class MultisetDifference(multiset1: Expr, multiset2: Expr) extends Expr { + def getType = leastUpperBound(Seq(multiset1, multiset2).map(_.getType)).getOrElse(Untyped) + } + case class MultisetToSet(multiset: Expr) extends Expr { + def getType = multiset.getType match { + case MultisetType(base) => SetType(base) + case _ => Untyped + } } - case class SubmultisetOf(multiset1: Expr, multiset2: Expr) extends Expr - case class MultisetIntersection(multiset1: Expr, multiset2: Expr) extends Expr - case class MultisetUnion(multiset1: Expr, multiset2: Expr) extends Expr - case class MultisetPlus(multiset1: Expr, multiset2: Expr) extends Expr // disjoint union - case class MultisetDifference(multiset1: Expr, multiset2: Expr) extends Expr - case class MultisetToSet(multiset: Expr) extends Expr /* Map operations. */ - case class FiniteMap(singletons: Seq[(Expr, Expr)]) extends Expr - - case class MapGet(map: Expr, key: Expr) extends Expr - case class MapUnion(map1: Expr, map2: Expr) extends Expr - case class MapDifference(map: Expr, keys: Expr) extends Expr - case class MapIsDefinedAt(map: Expr, key: Expr) extends Expr with FixedType { - val fixedType = BooleanType + case class FiniteMap(singletons: Seq[(Expr, Expr)]) extends Expr with MutableTyped + case class MapGet(map: Expr, key: Expr) extends Expr { + def getType = map.getType match { + case MapType(from, to) => to + case _ => Untyped + } + } + case class MapUnion(map1: Expr, map2: Expr) extends Expr { + def getType = leastUpperBound(Seq(map1, map2).map(_.getType)).getOrElse(Untyped) + } + case class MapDifference(map: Expr, keys: Expr) extends Expr with MutableTyped + case class MapIsDefinedAt(map: Expr, key: Expr) extends Expr { + val getType = BooleanType } /* Array operations */ @deprecated("Unsupported Array operation with most solvers", "Leon 2.3") - case class ArrayFill(length: Expr, defaultValue: Expr) extends Expr with FixedType { - val fixedType = ArrayType(defaultValue.getType) + case class ArrayFill(length: Expr, defaultValue: Expr) extends Expr { + def getType = ArrayType(defaultValue.getType) } @deprecated("Unsupported Array operation with most solvers", "Leon 2.3") - case class ArrayMake(defaultValue: Expr) extends Expr with FixedType { - val fixedType = ArrayType(defaultValue.getType) + case class ArrayMake(defaultValue: Expr) extends Expr { + def getType = ArrayType(defaultValue.getType) } - case class ArraySelect(array: Expr, index: Expr) extends Expr with FixedType { - assert(array.getType.isInstanceOf[ArrayType], - "The array value in ArraySelect must of of array type; yet we got [%s]. In expr: \n%s".format(array.getType, array)) - - val fixedType = array.getType match { + case class ArraySelect(array: Expr, index: Expr) extends Expr { + def getType = array.getType match { case ArrayType(base) => base case _ => Untyped } - } - case class ArrayUpdated(array: Expr, index: Expr, newValue: Expr) extends Expr with FixedType { - assert(array.getType.isInstanceOf[ArrayType], - "The array value in ArrayUpdated must of of array type; yet we got [%s]. In expr: \n%s".format(array.getType, array)) - - val fixedType = array.getType match { + case class ArrayUpdated(array: Expr, index: Expr, newValue: Expr) extends Expr { + def getType = array.getType match { case ArrayType(base) => leastUpperBound(base, newValue.getType).map(ArrayType(_)).getOrElse(Untyped) case _ => @@ -713,20 +712,50 @@ object Trees { } } - case class ArrayLength(array: Expr) extends Expr with FixedType { - val fixedType = Int32Type + case class ArrayLength(array: Expr) extends Expr { + val getType = Int32Type } - case class FiniteArray(exprs: Seq[Expr]) extends Expr + + case class FiniteArray(exprs: Seq[Expr]) extends Expr with MutableTyped @deprecated("Unsupported Array operation with most solvers", "Leon 2.3") case class ArrayClone(array: Expr) extends Expr { - if(array.getType != Untyped) - setType(array.getType) + def getType = array.getType + } + + case class Distinct(exprs: Seq[Expr]) extends Expr { + val getType = BooleanType + } + + /* Special trees */ + + // Provide an oracle (synthesizable, all-seeing choose) + case class WithOracle(oracles: List[Identifier], body: Expr) extends Expr with UnaryExtractable { + assert(!oracles.isEmpty) + + def getType = body.getType + + def extract = { + Some((body, (e: Expr) => WithOracle(oracles, e).setPos(this))) + } + } + + case class Hole(tpe: TypeTree, alts: Seq[Expr]) extends Expr with NAryExtractable { + val getType = tpe + + def extract = { + Some((alts, (es: Seq[Expr]) => Hole(tpe, es).setPos(this))) + } } - /* Constraint programming */ - case class Distinct(exprs: Seq[Expr]) extends Expr with FixedType { - val fixedType = BooleanType + case class RepairHole(tpe: TypeTree, components: Seq[Expr]) extends Expr with NAryExtractable { + val getType = tpe + + def extract = { + Some((components, (es: Seq[Expr]) => RepairHole(tpe, es).setPos(this))) + } } + + } diff --git a/src/main/scala/leon/purescala/TypeTreeOps.scala b/src/main/scala/leon/purescala/TypeTreeOps.scala index 5f88073bc61e9a6d8be2306deb838c7a85e30d36..8e26404b4eb199a15e8d9f68a92753c955ead6b0 100644 --- a/src/main/scala/leon/purescala/TypeTreeOps.scala +++ b/src/main/scala/leon/purescala/TypeTreeOps.scala @@ -282,8 +282,8 @@ object TypeTreeOps { MatchExpr(srec(e), cases.map(trCase)).copiedFrom(m) - case Error(desc) => - Error(desc).setType(tpeSub(e.getType)).copiedFrom(e) + case Error(tpe, desc) => + Error(tpeSub(tpe), desc).copiedFrom(e) case s @ FiniteSet(elements) if elements.isEmpty => FiniteSet(Set()).setType(tpeSub(s.getType)).copiedFrom(s) diff --git a/src/main/scala/leon/purescala/TypeTrees.scala b/src/main/scala/leon/purescala/TypeTrees.scala index 6b283b8efc6d6386d688e632ebc77259c733e5b4..38074b978feb50eff0ffc43d9edaf4bcd1fa7981 100644 --- a/src/main/scala/leon/purescala/TypeTrees.scala +++ b/src/main/scala/leon/purescala/TypeTrees.scala @@ -3,13 +3,25 @@ package leon package purescala +import scala.language.implicitConversions + object TypeTrees { import Common._ import Trees._ import Definitions._ import TypeTreeOps._ - trait Typed extends Serializable { + /** + * HasType indicates that structure is typed + * + * setType not necessarily defined though + */ + trait Typed { + def getType: TypeTree + def isTyped : Boolean = (getType != Untyped) + } + + trait MutableTyped extends Typed { self => private var _type: Option[TypeTree] = None @@ -24,8 +36,6 @@ object TypeTrees { case Some(o) if o != tt => scala.sys.error("Resetting type information! Type [" + o + "] is modified to [" + tt) case _ => this } - - def isTyped : Boolean = (getType != Untyped) } class TypeErrorException(msg: String) extends Exception(msg) @@ -40,16 +50,9 @@ object TypeTrees { } } - trait FixedType extends Typed { - self => - - val fixedType: TypeTree - override def getType: TypeTree = fixedType - override def setType(tt2: TypeTree) : self.type = this + abstract class TypeTree extends Tree with Typed { + def getType = this } - - - abstract class TypeTree extends Tree case object Untyped extends TypeTree case object BooleanType extends TypeTree diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala index 7901e428fe605751a7162132face803f8ec9c805..850472a8056d9760411c3729e000850d19b12bd3 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala @@ -67,7 +67,9 @@ trait SMTLIBTarget { // Corresponds to a raw array value, which is coerced to a Leon expr depending on target type (set/array) // Should NEVER escape past SMT-world - case class RawArrayValue(keyTpe: TypeTree, elems: Map[Expr, Expr], default: Expr) extends Expr + case class RawArrayValue(keyTpe: TypeTree, elems: Map[Expr, Expr], default: Expr) extends Expr { + def getType = RawArrayType(keyTpe, default.getType) + } def fromRawArray(r: RawArrayValue, tpe: TypeTree): Expr = tpe match { case SetType(base) => @@ -343,8 +345,8 @@ trait SMTLIBTarget { newBody ) - case er @ Error(_) => - val s = declareVariable(FreshIdentifier("error_value").setType(er.getType)) + case er @ Error(tpe, _) => + val s = declareVariable(FreshIdentifier("error_value").setType(tpe)) s case s @ CaseClassSelector(cct, e, id) => diff --git a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala index efd68f68069a927035919d0e56b670ea7f9b6bb1..a6f1af7b73cb891beca4c36e22228c63e2a05665 100644 --- a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala +++ b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala @@ -261,11 +261,6 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T]) { } } - case h @ RepairHole(_, _) => - val hid = FreshIdentifier("hole", true).setType(h.getType) - exprVars += hid - Variable(hid) - case c @ Choose(ids, cond) => val cid = FreshIdentifier("choose", true).setType(c.getType) storeExpr(cid) @@ -297,9 +292,9 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T]) { Variable(lid) - case n @ NAryOperator(as, r) => r(as.map(a => rec(pathVar, a))).setType(n.getType) - case b @ BinaryOperator(a1, a2, r) => r(rec(pathVar, a1), rec(pathVar, a2)).setType(b.getType) - case u @ UnaryOperator(a, r) => r(rec(pathVar, a)).setType(u.getType) + case n @ NAryOperator(as, r) => r(as.map(a => rec(pathVar, a))) + case b @ BinaryOperator(a1, a2, r) => r(rec(pathVar, a1), rec(pathVar, a2)) + case u @ UnaryOperator(a, r) => r(rec(pathVar, a)) case t : Terminal => t } } diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala index 623b124457497dd53f1f2e5ab99c2d035e1ce408..1ab6287b79a079f3186a98ef157298bb2b96fff0 100644 --- a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala @@ -532,8 +532,7 @@ trait AbstractZ3Solver rb } case Waypoint(_, e) => rec(e) - case e @ Error(_) => { - val tpe = e.getType + case e @ Error(tpe, _) => { val newAST = z3.mkFreshConst("errorValue", typeToSort(tpe)) // Might introduce dupplicates (e), but no worries here variables += (e -> newAST) @@ -672,8 +671,8 @@ trait AbstractZ3Solver case arr @ FiniteArray(exprs) => { val ArrayType(innerType) = arr.getType val arrayType = arr.getType - val a: Expr = ArrayFill(IntLiteral(exprs.length), simplestValue(innerType)).setType(arrayType) - val u = exprs.zipWithIndex.foldLeft(a)((array, expI) => ArrayUpdated(array, IntLiteral(expI._2), expI._1).setType(arrayType)) + val a: Expr = ArrayFill(IntLiteral(exprs.length), simplestValue(innerType)) + val u = exprs.zipWithIndex.foldLeft(a)((array, expI) => ArrayUpdated(array, IntLiteral(expI._2), expI._1)) rec(u) } case Distinct(exs) => z3.mkDistinct(exs.map(rec(_)): _*) diff --git a/src/main/scala/leon/synthesis/Solution.scala b/src/main/scala/leon/synthesis/Solution.scala index 94b80d36b3d595020da0b4d5c7401c40b928b1c3..a6a6b29890ff76ff3b869ca743182facbfacdfb8 100644 --- a/src/main/scala/leon/synthesis/Solution.scala +++ b/src/main/scala/leon/synthesis/Solution.scala @@ -23,9 +23,9 @@ class Solution(val pre: Expr, val defs: Set[FunDef], val term: Expr, val isTrust if (pre == BooleanLiteral(true)) { term } else if (pre == BooleanLiteral(false)) { - Error("Impossible program").setType(term.getType) + Error(term.getType, "Impossible program") } else { - IfExpr(pre, term, Error("Precondition failed").setType(term.getType)) + IfExpr(pre, term, Error(term.getType, "Precondition failed")) } } @@ -67,7 +67,7 @@ object Solution { def unapply(s: Solution): Option[(Expr, Set[FunDef], Expr)] = if (s eq null) None else Some((s.pre, s.defs, s.term)) def choose(p: Problem): Solution = { - new Solution(BooleanLiteral(true), Set(), Choose(p.xs, p.phi).setType(TupleType(p.xs.map(_.getType)))) + new Solution(BooleanLiteral(true), Set(), Choose(p.xs, p.phi)) } // Generate the simplest, wrongest solution, used for complexity lowerbound @@ -80,6 +80,6 @@ object Solution { } def failed(p: Problem): Solution = { - new Solution(BooleanLiteral(true), Set(), Error("Failed").setType(TupleType(p.xs.map(_.getType)))) + new Solution(BooleanLiteral(true), Set(), Error(TupleType(p.xs.map(_.getType)), "Failed")) } } diff --git a/src/main/scala/leon/synthesis/rules/Cegis.scala b/src/main/scala/leon/synthesis/rules/Cegis.scala index 17a586dbee98a2f8704b0a9c2b6a9e019c330c94..520a9c3bfce975ed8964baa9d6adc84d539121a0 100644 --- a/src/main/scala/leon/synthesis/rules/Cegis.scala +++ b/src/main/scala/leon/synthesis/rules/Cegis.scala @@ -27,13 +27,15 @@ import codegen.CodeGenParams import utils._ -case object CEGIS extends CEGISLike("CEGIS") { +case object CEGIS extends CEGISLike[TypeTree]("CEGIS") { def getGrammar(sctx: SynthesisContext, p: Problem) = { ExpressionGrammars.default(sctx, p) } + + def getGrammarLabel(id: Identifier): TypeTree = id.getType } -case object CEGLESS extends CEGISLike("CEGLESS") { +case object CEGLESS extends CEGISLike[TypeTree]("CEGLESS") { override val maxUnfoldings = 3; def getGrammar(sctx: SynthesisContext, p: Problem) = { @@ -49,16 +51,20 @@ case object CEGLESS extends CEGISLike("CEGLESS") { val inputs = p.as.map(_.toVariable) - val guidedGrammar = guides.map(SimilarTo(_, inputs.toSet, Set(sctx.functionContext))).foldLeft[ExpressionGrammar](Empty)(_ || _) + val guidedGrammar = guides.map(SimilarTo(_, inputs.toSet, Set(sctx.functionContext))).foldLeft[ExpressionGrammar[TypeTree]](Empty)(_ || _) guidedGrammar || OneOf(inputs) || SafeRecCalls(sctx.program, p.pc) } + + def getGrammarLabel(id: Identifier): TypeTree = id.getType } -abstract class CEGISLike(name: String) extends Rule(name) { +abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { + + def getGrammar(sctx: SynthesisContext, p: Problem): ExpressionGrammar[T] - def getGrammar(sctx: SynthesisContext, p: Problem): ExpressionGrammar + def getGrammarLabel(id: Identifier): T val maxUnfoldings = 3 @@ -92,6 +98,8 @@ abstract class CEGISLike(name: String) extends Rule(name) { // b -> Set(c1, c2) means c1 and c2 are uninterpreted behind b, requires b to be closed private var guardedTerms: Map[Identifier, Set[Identifier]] = Map(initGuard -> p.xs.toSet) + private var labels: Map[Identifier, T] = Map() ++ p.xs.map(x => x -> getGrammarLabel(x)) + def isBClosed(b: Identifier) = guardedTerms.contains(b) /** @@ -192,7 +200,7 @@ abstract class CEGISLike(name: String) extends Rule(name) { // We compute the IF expression corresponding to each c val ifExpr = if (cases.isEmpty) { // This can happen with ADTs with only cases with arguments - Error("No valid clause available").setType(c.getType) + Error(c.getType, "No valid clause available") } else { cases.tail.foldLeft(cases.head._2) { case (elze, (b, thenn)) => IfExpr(Variable(b), thenn, elze) @@ -216,7 +224,7 @@ abstract class CEGISLike(name: String) extends Rule(name) { case Some(value) => res = Let(c, cToExprs(c), res) case None => - res = Let(c, Error("No value available").setType(c.getType), res) + res = Let(c, Error(c.getType, "No value available"), res) } for (dep <- cChildren(c) if !unreachableCs(dep)) { @@ -235,7 +243,7 @@ abstract class CEGISLike(name: String) extends Rule(name) { val ba = FreshIdentifier("bssArray").setType(ArrayType(BooleanType)) val bav = Variable(ba) val substMap : Map[Expr,Expr] = (bssOrdered.zipWithIndex.map { - case (b,i) => Variable(b) -> ArraySelect(bav, IntLiteral(i)).setType(BooleanType) + case (b,i) => Variable(b) -> ArraySelect(bav, IntLiteral(i)) }).toMap val forArray = replace(substMap, simplerRes) @@ -277,7 +285,7 @@ abstract class CEGISLike(name: String) extends Rule(name) { substAll(map.toMap, cClauses(c)) } - Tuple(p.xs.map(c => getCValue(c))).setType(TupleType(p.xs.map(_.getType))) + Tuple(p.xs.map(c => getCValue(c))) } @@ -354,7 +362,7 @@ abstract class CEGISLike(name: String) extends Rule(name) { for ((recId, parentGuards) <- cGroups) { - var alts = grammar.getProductions(recId.getType) + var alts = grammar.getProductions(labels(recId)) if (finalUnfolding) { alts = alts.filter(_.subTrees.isEmpty) } @@ -382,7 +390,10 @@ abstract class CEGISLike(name: String) extends Rule(name) { })(index) val cases = for((bid, gen) <- altsWithBranches.toList) yield { // b1 => E(gen1, gen2) [b1 -> {gen1, gen2}] - val rec = for ((t, i) <- gen.subTrees.zipWithIndex) yield { getC(t, i) } + val newLabels = for ((t, i) <- gen.subTrees.zipWithIndex) yield { getC(t.getType, i) -> t } + labels ++= newLabels + + val rec = newLabels.map(_._1) val ex = gen.builder(rec.map(_.toVariable)) if (!rec.isEmpty) { diff --git a/src/main/scala/leon/synthesis/rules/Ground.scala b/src/main/scala/leon/synthesis/rules/Ground.scala index c2bc50f51d3910b0f93eefd070c9b0721f7d36b6..9de60847a0b4f2b3c429d9c69c524ba916e02077 100644 --- a/src/main/scala/leon/synthesis/rules/Ground.scala +++ b/src/main/scala/leon/synthesis/rules/Ground.scala @@ -21,10 +21,10 @@ case object Ground extends Rule("Ground") { val result = solver.solveSAT(p.phi) match { case (Some(true), model) => - val sol = Solution(BooleanLiteral(true), Set(), Tuple(p.xs.map(valuateWithModel(model))).setType(tpe)) + val sol = Solution(BooleanLiteral(true), Set(), Tuple(p.xs.map(valuateWithModel(model)))) RuleClosed(sol) case (Some(false), model) => - val sol = Solution(BooleanLiteral(false), Set(), Error(p.phi+" is UNSAT!").setType(tpe)) + val sol = Solution(BooleanLiteral(false), Set(), Error(tpe, p.phi+" is UNSAT!")) RuleClosed(sol) case _ => RuleFailed() diff --git a/src/main/scala/leon/synthesis/rules/IntegerInequalities.scala b/src/main/scala/leon/synthesis/rules/IntegerInequalities.scala index 182f296762c3c9c0664e71d997b96ca3360bc999..0b69b8fd22135be7ebe8c5321f1ebe5efd242735 100644 --- a/src/main/scala/leon/synthesis/rules/IntegerInequalities.scala +++ b/src/main/scala/leon/synthesis/rules/IntegerInequalities.scala @@ -196,7 +196,7 @@ case object IntegerInequalities extends Rule("Integer Inequalities") { val funDef = new FunDef(FreshIdentifier("rec", true), Nil, returnType, Seq(ValDef(loopCounter.id, Int32Type)),DefType.MethodDef) val funBody = expandAndSimplifyArithmetic(IfExpr( LessThan(loopCounter, IntLiteral(0)), - Error("No solution exists"), + Error(returnType, "No solution exists"), IfExpr( concretePre, LetTuple(subProblemxs, concreteTerm, diff --git a/src/main/scala/leon/synthesis/rules/OptimisticGround.scala b/src/main/scala/leon/synthesis/rules/OptimisticGround.scala index d9f15a2a3307247abcf2698537eafb9f5fe75f8d..5a776437b85a2fd10ef7f797c6b3b04277c4824e 100644 --- a/src/main/scala/leon/synthesis/rules/OptimisticGround.scala +++ b/src/main/scala/leon/synthesis/rules/OptimisticGround.scala @@ -47,7 +47,7 @@ case object OptimisticGround extends Rule("Optimistic Ground") { predicates = valuateWithModelIn(phi, ass, invalidModel) +: predicates case (Some(false), _) => - result = Some(RuleClosed(Solution(BooleanLiteral(true), Set(), Tuple(p.xs.map(valuateWithModel(satModel))).setType(tpe)))) + result = Some(RuleClosed(Solution(BooleanLiteral(true), Set(), Tuple(p.xs.map(valuateWithModel(satModel)))))) case _ => continue = false @@ -56,7 +56,7 @@ case object OptimisticGround extends Rule("Optimistic Ground") { case (Some(false), _) => if (predicates.isEmpty) { - result = Some(RuleClosed(Solution(BooleanLiteral(false), Set(), Error(p.phi+" is UNSAT!").setType(tpe)))) + result = Some(RuleClosed(Solution(BooleanLiteral(false), Set(), Error(tpe, p.phi+" is UNSAT!")))) } else { continue = false result = None diff --git a/src/main/scala/leon/synthesis/rules/Tegis.scala b/src/main/scala/leon/synthesis/rules/Tegis.scala index e04dc0c718909ead413d1a024e61dfefed9e8938..c436b6ed0d3aced39e181c40ee29d6549a0574cd 100644 --- a/src/main/scala/leon/synthesis/rules/Tegis.scala +++ b/src/main/scala/leon/synthesis/rules/Tegis.scala @@ -34,6 +34,10 @@ case object TEGIS extends Rule("TEGIS") { def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { // check if the formula contains passes: + if (!sctx.program.library.passes.isDefined) { + return Nil; + } + val passes = sctx.program.library.passes.get val mayHaveTests = exists({ diff --git a/src/main/scala/leon/synthesis/rules/Unification.scala b/src/main/scala/leon/synthesis/rules/Unification.scala index 87597e4a020b632f376d94ba3decbad1241adf8f..54da1d47654b94d6a3df13ba6cf7da26df43227a 100644 --- a/src/main/scala/leon/synthesis/rules/Unification.scala +++ b/src/main/scala/leon/synthesis/rules/Unification.scala @@ -54,7 +54,7 @@ object Unification { if (isImpossible) { val tpe = TupleType(p.xs.map(_.getType)) - List(RuleInstantiation.immediateSuccess(p, this, Solution(BooleanLiteral(false), Set(), Error(p.phi+" is UNSAT!").setType(tpe)))) + List(RuleInstantiation.immediateSuccess(p, this, Solution(BooleanLiteral(false), Set(), Error(tpe, p.phi+" is UNSAT!")))) } else { Nil } diff --git a/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala b/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala index 5ed5990c91e3ebc9ee818ac3cc92e370c3548761..ea0003e5a9cd8a52c5ff6b1b137aa5a4c2b16cfa 100644 --- a/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala +++ b/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala @@ -18,12 +18,12 @@ import purescala.ScalaPrinter import scala.collection.mutable.{HashMap => MutableMap} -abstract class ExpressionGrammar { - type Gen = Generator[TypeTree, Expr] +abstract class ExpressionGrammar[T <% Typed] { + type Gen = Generator[T, Expr] - private[this] val cache = new MutableMap[TypeTree, Seq[Gen]]() + private[this] val cache = new MutableMap[T, Seq[Gen]]() - def getProductions(t: TypeTree): Seq[Gen] = { + def getProductions(t: T): Seq[Gen] = { cache.getOrElse(t, { val res = computeProductions(t) cache += t -> res @@ -31,15 +31,16 @@ abstract class ExpressionGrammar { }) } - def computeProductions(t: TypeTree): Seq[Gen] + def computeProductions(t: T): Seq[Gen] - final def ||(that: ExpressionGrammar): ExpressionGrammar = { + final def ||(that: ExpressionGrammar[T]): ExpressionGrammar[T] = { ExpressionGrammars.Or(Seq(this, that)) } + final def printProductions(printer: String => Unit) { for ((t, gs) <- cache; g <- gs) { - val subs = g.subTrees.map { tpe => FreshIdentifier(tpe.toString).setType(tpe).toVariable } + val subs = g.subTrees.map { t => FreshIdentifier(t.toString).setType(t.getType).toVariable} val gen = g.builder(subs) printer(f"$t%30s ::= "+gen) @@ -49,21 +50,21 @@ abstract class ExpressionGrammar { object ExpressionGrammars { - case class Or(gs: Seq[ExpressionGrammar]) extends ExpressionGrammar { - val subGrammars: Seq[ExpressionGrammar] = gs.flatMap { - case o: Or => o.subGrammars + case class Or[T <% Typed](gs: Seq[ExpressionGrammar[T]]) extends ExpressionGrammar[T] { + val subGrammars: Seq[ExpressionGrammar[T]] = gs.flatMap { + case o: Or[T] => o.subGrammars case g => Seq(g) } - def computeProductions(t: TypeTree): Seq[Gen] = + def computeProductions(t: T): Seq[Gen] = subGrammars.flatMap(_.getProductions(t)) } - case object Empty extends ExpressionGrammar { + case object Empty extends ExpressionGrammar[TypeTree] { def computeProductions(t: TypeTree): Seq[Gen] = Nil } - case object BaseGrammar extends ExpressionGrammar { + case object BaseGrammar extends ExpressionGrammar[TypeTree] { def computeProductions(t: TypeTree): Seq[Gen] = t match { case BooleanType => List( @@ -104,7 +105,7 @@ object ExpressionGrammars { } } - case class OneOf(inputs: Seq[Expr]) extends ExpressionGrammar { + case class OneOf(inputs: Seq[Expr]) extends ExpressionGrammar[TypeTree] { def computeProductions(t: TypeTree): Seq[Gen] = { inputs.collect { case i if isSubtypeOf(i.getType, t) => Generator[TypeTree, Expr](Nil, { _ => i }) @@ -112,7 +113,7 @@ object ExpressionGrammars { } } - case class SimilarTo(e: Expr, excludeExpr: Set[Expr] = Set(), excludeFCalls: Set[FunDef] = Set()) extends ExpressionGrammar { + case class SimilarTo(e: Expr, excludeExpr: Set[Expr] = Set(), excludeFCalls: Set[FunDef] = Set()) extends ExpressionGrammar[TypeTree] { lazy val allSimilar = computeSimilar(e).groupBy(_._1).mapValues(_.map(_._2)) def computeProductions(t: TypeTree): Seq[Gen] = { @@ -162,7 +163,7 @@ object ExpressionGrammars { } } - case class FunctionCalls(prog: Program, currentFunction: FunDef, types: Seq[TypeTree]) extends ExpressionGrammar { + case class FunctionCalls(prog: Program, currentFunction: FunDef, types: Seq[TypeTree]) extends ExpressionGrammar[TypeTree] { def computeProductions(t: TypeTree): Seq[Gen] = { def getCandidates(fd: FunDef): Seq[TypedFunDef] = { @@ -228,7 +229,7 @@ object ExpressionGrammars { } } - case class SafeRecCalls(prog: Program, pc: Expr) extends ExpressionGrammar { + case class SafeRecCalls(prog: Program, pc: Expr) extends ExpressionGrammar[TypeTree] { def computeProductions(t: TypeTree): Seq[Gen] = { val calls = terminatingCalls(prog, t, pc) @@ -242,14 +243,14 @@ object ExpressionGrammars { } } - def default(prog: Program, inputs: Seq[Expr], currentFunction: FunDef, pc: Expr): ExpressionGrammar = { + def default(prog: Program, inputs: Seq[Expr], currentFunction: FunDef, pc: Expr): ExpressionGrammar[TypeTree] = { BaseGrammar || OneOf(inputs) || FunctionCalls(prog, currentFunction, inputs.map(_.getType)) || SafeRecCalls(prog, pc) } - def default(sctx: SynthesisContext, p: Problem): ExpressionGrammar = { + def default(sctx: SynthesisContext, p: Problem): ExpressionGrammar[TypeTree] = { default(sctx.program, p.as.map(_.toVariable), sctx.functionContext, p.pc) } } diff --git a/src/main/scala/leon/utils/UnitElimination.scala b/src/main/scala/leon/utils/UnitElimination.scala index 5cdbdc3d582919924f2aa2081f868cffca6ac7b1..acb36b0725cd0d130ac8f3f11cb2af70b1f9366b 100644 --- a/src/main/scala/leon/utils/UnitElimination.scala +++ b/src/main/scala/leon/utils/UnitElimination.scala @@ -71,7 +71,7 @@ object UnitElimination extends TransformationPhase { case t@Tuple(args) => { val TupleType(tpes) = t.getType val (newTpes, newArgs) = tpes.zip(args).filterNot{ case (UnitType, _) => true case _ => false }.unzip - Tuple(newArgs.map(removeUnit)).setType(TupleType(newTpes)) + Tuple(newArgs.map(removeUnit)) } case ts@TupleSelect(t, index) => { val TupleType(tpes) = t.getType @@ -80,7 +80,7 @@ object UnitElimination extends TransformationPhase { case ((nbUnit, newIndex), (tpe, i)) => if(i == index-1) (nbUnit, index - nbUnit) else (if(tpe == UnitType) nbUnit + 1 else nbUnit, newIndex) } - TupleSelect(removeUnit(t), newIndex).setType(selectionType) + TupleSelect(removeUnit(t), newIndex) } case Let(id, e, b) => { if(id.getType == UnitType) @@ -126,16 +126,16 @@ object UnitElimination extends TransformationPhase { case ite@IfExpr(cond, tExpr, eExpr) => { val thenRec = removeUnit(tExpr) val elseRec = removeUnit(eExpr) - IfExpr(removeUnit(cond), thenRec, elseRec).setType(thenRec.getType) + IfExpr(removeUnit(cond), thenRec, elseRec) } case n @ NAryOperator(args, recons) => { - recons(args.map(removeUnit(_))).setType(n.getType) + recons(args.map(removeUnit(_))) } case b @ BinaryOperator(a1, a2, recons) => { - recons(removeUnit(a1), removeUnit(a2)).setType(b.getType) + recons(removeUnit(a1), removeUnit(a2)) } case u @ UnaryOperator(a, recons) => { - recons(removeUnit(a)).setType(u.getType) + recons(removeUnit(a)) } case v @ Variable(id) => if(id2FreshId.isDefinedAt(id)) Variable(id2FreshId(id)) else v case (t: Terminal) => t @@ -146,7 +146,7 @@ object UnitElimination extends TransformationPhase { case GuardedCase(pat, guard, rhs) => GuardedCase(pat, removeUnit(guard), removeUnit(rhs)) } val tpe = csesRec.head.rhs.getType - MatchExpr(scrutRec, csesRec).setType(tpe).setPos(m) + MatchExpr(scrutRec, csesRec).setPos(m) } case _ => sys.error("not supported: " + expr) } diff --git a/src/main/scala/leon/verification/DefaultTactic.scala b/src/main/scala/leon/verification/DefaultTactic.scala index f3c607f6e707a59cb1b36855c2c99bc8b2de1615..6031de1fa3fcc5b5ad646cf17178052bd0ddf0d2 100644 --- a/src/main/scala/leon/verification/DefaultTactic.scala +++ b/src/main/scala/leon/verification/DefaultTactic.scala @@ -51,10 +51,10 @@ class DefaultTactic(vctx: VerificationContext) extends Tactic(vctx) { fd.body match { case Some(body) => val calls = collectWithPC { - case e @ Error("Match is non-exhaustive") => + case e @ Error(_, "Match is non-exhaustive") => (e, VCExhaustiveMatch, BooleanLiteral(false)) - case e @ Error(_) => + case e @ Error(_, _) => (e, VCAssert, BooleanLiteral(false)) case a @ Assert(cond, Some(err), _) => diff --git a/src/main/scala/leon/xlang/ArrayTransformation.scala b/src/main/scala/leon/xlang/ArrayTransformation.scala index 0625936a969810a96045e4b9d13f9a67e50e5170..a73366f3c187dda931ff220acb67fea352fca953 100644 --- a/src/main/scala/leon/xlang/ArrayTransformation.scala +++ b/src/main/scala/leon/xlang/ArrayTransformation.scala @@ -38,7 +38,7 @@ object ArrayTransformation extends TransformationPhase { val ri = transform(i) val rv = transform(v) val Variable(id) = ra - Assignment(id, ArrayUpdated(ra, ri, rv).setType(ra.getType).setPos(up)) + Assignment(id, ArrayUpdated(ra, ri, rv).setPos(up)) } case ArrayClone(a) => { val ra = transform(a) @@ -75,7 +75,7 @@ object ArrayTransformation extends TransformationPhase { val rc = transform(c) val rt = transform(t) val re = transform(e) - IfExpr(rc, rt, re).setType(rt.getType) + IfExpr(rc, rt, re) } case m @ MatchExpr(scrut, cses) => { @@ -85,7 +85,7 @@ object ArrayTransformation extends TransformationPhase { case GuardedCase(pat, guard, rhs) => GuardedCase(pat, transform(guard), transform(rhs)) } val tpe = csesRec.head.rhs.getType - MatchExpr(scrutRec, csesRec).setType(tpe).setPos(m) + MatchExpr(scrutRec, csesRec).setPos(m) } case LetDef(fd, b) => { fd.precondition = fd.precondition.map(transform) @@ -94,9 +94,9 @@ object ArrayTransformation extends TransformationPhase { val rb = transform(b) LetDef(fd, rb) } - case n @ NAryOperator(args, recons) => recons(args.map(transform)).setType(n.getType) - case b @ BinaryOperator(a1, a2, recons) => recons(transform(a1), transform(a2)).setType(b.getType) - case u @ UnaryOperator(a, recons) => recons(transform(a)).setType(u.getType) + case n @ NAryOperator(args, recons) => recons(args.map(transform)) + case b @ BinaryOperator(a1, a2, recons) => recons(transform(a1), transform(a2)) + case u @ UnaryOperator(a, recons) => recons(transform(a)) case (t: Terminal) => t case unhandled => scala.sys.error("Non-terminal case should be handled in ArrayTransformation: " + unhandled) diff --git a/src/main/scala/leon/xlang/ImperativeCodeElimination.scala b/src/main/scala/leon/xlang/ImperativeCodeElimination.scala index 39334e495cb8d2ab54839a9fb93fd4c4f8299aeb..7eae711b869ea6dc87b7fae4881d24ed42e08f89 100644 --- a/src/main/scala/leon/xlang/ImperativeCodeElimination.scala +++ b/src/main/scala/leon/xlang/ImperativeCodeElimination.scala @@ -75,14 +75,14 @@ object ImperativeCodeElimination extends LeonPhase[Program, (Program, Set[FunDef val thenVal = if(modifiedVars.isEmpty) tRes else Tuple(tRes +: modifiedVars.map(vId => tFun.get(vId) match { case Some(newId) => newId.toVariable case None => vId.toVariable - })).setType(iteType) + })) val elseVal = if(modifiedVars.isEmpty) eRes else Tuple(eRes +: modifiedVars.map(vId => eFun.get(vId) match { case Some(newId) => newId.toVariable case None => vId.toVariable - })).setType(iteType) + })) - val iteExpr = IfExpr(cRes, replaceNames(cFun, tScope(thenVal)), replaceNames(cFun, eScope(elseVal))).setType(iteType).copiedFrom(ite) + val iteExpr = IfExpr(cRes, replaceNames(cFun, tScope(thenVal)), replaceNames(cFun, eScope(elseVal))).copiedFrom(ite) val scope = ((body: Expr) => { val tupleId = FreshIdentifier("t").setType(iteType) @@ -91,10 +91,10 @@ object ImperativeCodeElimination extends LeonPhase[Program, (Program, Set[FunDef if(freshIds.isEmpty) Let(resId, tupleId.toVariable, body) else - Let(resId, TupleSelect(tupleId.toVariable, 1).setType(iteRType), + Let(resId, TupleSelect(tupleId.toVariable, 1), freshIds.zipWithIndex.foldLeft(body)((b, id) => Let(id._1, - TupleSelect(tupleId.toVariable, id._2 + 2).setType(id._1.getType), + TupleSelect(tupleId.toVariable, id._2 + 2), b)))).copiedFrom(expr)) }) @@ -115,7 +115,7 @@ object ImperativeCodeElimination extends LeonPhase[Program, (Program, Set[FunDef case (cRes, cFun) => if(modifiedVars.isEmpty) cRes else Tuple(cRes +: modifiedVars.map(vId => cFun.get(vId) match { case Some(newId) => newId.toVariable case None => vId.toVariable - })).setType(matchType) + })) } val newRhs = csesVals.zip(csesScope).map{ @@ -124,7 +124,7 @@ object ImperativeCodeElimination extends LeonPhase[Program, (Program, Set[FunDef val matchExpr = MatchExpr(scrutRes, cses.zip(newRhs).map{ case (sc @ SimpleCase(pat, _), newRhs) => SimpleCase(pat, newRhs).setPos(sc) case (gc @ GuardedCase(pat, guard, _), newRhs) => GuardedCase(pat, replaceNames(scrutFun, guard), newRhs).setPos(gc) - }).setType(matchType).setPos(m) + }).setPos(m) val scope = ((body: Expr) => { val tupleId = FreshIdentifier("t").setType(matchType) @@ -136,7 +136,7 @@ object ImperativeCodeElimination extends LeonPhase[Program, (Program, Set[FunDef Let(resId, TupleSelect(tupleId.toVariable, 1), freshIds.zipWithIndex.foldLeft(body)((b, id) => Let(id._1, - TupleSelect(tupleId.toVariable, id._2 + 2).setType(id._1.getType), + TupleSelect(tupleId.toVariable, id._2 + 2), b))))) }) @@ -167,9 +167,9 @@ object ImperativeCodeElimination extends LeonPhase[Program, (Program, Set[FunDef condFun.get(modifiedVars.head).getOrElse(whileFunVars.head).toVariable else Tuple(modifiedVars.map(id => condFun.get(id).getOrElse(modifiedVars2WhileFunVars(id)).toVariable)) - ).setType(whileFunReturnType) + ) val whileFunBody = replaceNames(modifiedVars2WhileFunVars, - condScope(IfExpr(whileFunCond, whileFunRecursiveCall, whileFunBaseCase).setType(whileFunReturnType))) + condScope(IfExpr(whileFunCond, whileFunRecursiveCall, whileFunBaseCase))) whileFunDef.body = Some(whileFunBody) val resVar = Variable(FreshIdentifier("res").setType(whileFunReturnType)) @@ -177,7 +177,7 @@ object ImperativeCodeElimination extends LeonPhase[Program, (Program, Set[FunDef if(whileFunVars.size == 1) Map(whileFunVars.head.toVariable -> resVar) else - whileFunVars.zipWithIndex.map{ case (v, i) => (v.toVariable, TupleSelect(resVar, i+1).setType(v.getType)) }.toMap + whileFunVars.zipWithIndex.map{ case (v, i) => (v.toVariable, TupleSelect(resVar, i+1)) }.toMap val modifiedVars2ResultVars: Map[Expr, Expr] = modifiedVars.map(id => (id.toVariable, whileFunVars2ResultVars(modifiedVars2WhileFunVars(id).toVariable))).toMap @@ -206,7 +206,7 @@ object ImperativeCodeElimination extends LeonPhase[Program, (Program, Set[FunDef else finalVars.zipWithIndex.foldLeft(body)((b, id) => Let(id._1, - TupleSelect(tupleId.toVariable, id._2 + 1).setType(id._1.getType), + TupleSelect(tupleId.toVariable, id._2 + 1), b)))) }) diff --git a/src/main/scala/leon/xlang/TreeOps.scala b/src/main/scala/leon/xlang/TreeOps.scala index dd94605b99d533dded04e1402e4f589399b7ad61..40c23677e51a9eea40c40368c5409869b65df12f 100644 --- a/src/main/scala/leon/xlang/TreeOps.scala +++ b/src/main/scala/leon/xlang/TreeOps.scala @@ -46,7 +46,7 @@ object TreeOps { Some(nexprs match { case Seq() => UnitLiteral() case Seq(e) => e - case es => Block(es.init, es.last).setType(es.last.getType) + case es => Block(es.init, es.last) }) case _ => None diff --git a/src/main/scala/leon/xlang/Trees.scala b/src/main/scala/leon/xlang/Trees.scala index 6da2d30f7d87b032de6e8ef1ed16849302baca3d..11883a7f86b5829d4c2989717990aa5b1347329c 100644 --- a/src/main/scala/leon/xlang/Trees.scala +++ b/src/main/scala/leon/xlang/Trees.scala @@ -19,7 +19,7 @@ object Trees { sb } - case class Block(exprs: Seq[Expr], last: Expr) extends Expr with NAryExtractable with PrettyPrintable with FixedType { + case class Block(exprs: Seq[Expr], last: Expr) extends Expr with NAryExtractable with PrettyPrintable { def extract: Option[(Seq[Expr], (Seq[Expr])=>Expr)] = { val Block(args, rest) = this Some((args :+ rest, exprs => Block(exprs.init, exprs.last))) @@ -31,11 +31,11 @@ object Trees { |}""" } - val fixedType = last.getType + def getType = last.getType } - case class Assignment(varId: Identifier, expr: Expr) extends Expr with FixedType with UnaryExtractable with PrettyPrintable { - val fixedType = UnitType + case class Assignment(varId: Identifier, expr: Expr) extends Expr with UnaryExtractable with PrettyPrintable { + val getType = UnitType def extract: Option[(Expr, (Expr)=>Expr)] = { Some((expr, Assignment(varId, _))) @@ -46,8 +46,8 @@ object Trees { } } - case class While(cond: Expr, body: Expr) extends Expr with FixedType with BinaryExtractable with PrettyPrintable { - val fixedType = UnitType + case class While(cond: Expr, body: Expr) extends Expr with BinaryExtractable with PrettyPrintable { + val getType = UnitType var invariant: Option[Expr] = None def getInvariant: Expr = invariant.get @@ -72,7 +72,7 @@ object Trees { } } - case class Epsilon(pred: Expr) extends Expr with UnaryExtractable with PrettyPrintable { + case class Epsilon(pred: Expr) extends Expr with UnaryExtractable with PrettyPrintable with MutableTyped { def extract: Option[(Expr, (Expr)=>Expr)] = { Some((pred, (expr: Expr) => Epsilon(expr).setType(this.getType).setPos(this))) } @@ -82,7 +82,7 @@ object Trees { } } - case class EpsilonVariable(pos: Position) extends Expr with Terminal with PrettyPrintable{ + case class EpsilonVariable(pos: Position) extends Expr with Terminal with PrettyPrintable with MutableTyped { def printWith(implicit pctx: PrinterContext) { p"x${pos.line}_${pos.col}" @@ -91,9 +91,7 @@ object Trees { //same as let, buf for mutable variable declaration case class LetVar(binder: Identifier, value: Expr, body: Expr) extends Expr with BinaryExtractable with PrettyPrintable { - val et = body.getType - if(et != Untyped) - setType(et) + def getType = body.getType def extract: Option[(Expr, Expr, (Expr, Expr)=>Expr)] = { val LetVar(binders, expr, body) = this @@ -108,7 +106,7 @@ object Trees { } } - case class Waypoint(i: Int, expr: Expr) extends Expr with UnaryExtractable with PrettyPrintable { + case class Waypoint(i: Int, expr: Expr) extends Expr with UnaryExtractable with PrettyPrintable with MutableTyped { def extract: Option[(Expr, (Expr)=>Expr)] = { Some((expr, (e: Expr) => Waypoint(i, e))) } @@ -118,8 +116,8 @@ object Trees { } } - case class ArrayUpdate(array: Expr, index: Expr, newValue: Expr) extends Expr with FixedType with NAryExtractable with PrettyPrintable { - val fixedType = UnitType + case class ArrayUpdate(array: Expr, index: Expr, newValue: Expr) extends Expr with NAryExtractable with PrettyPrintable { + val getType = UnitType def extract: Option[(Seq[Expr], (Seq[Expr])=>Expr)] = { val ArrayUpdate(t1, t2, t3) = this diff --git a/src/test/scala/leon/test/codegen/CodeGenTests.scala b/src/test/scala/leon/test/codegen/CodeGenTests.scala index d4221906bb84c9d0c4dafdfac81e7c16d972a752..e3b3b6968362f3c4030318967cd78ebf82cadf87 100644 --- a/src/test/scala/leon/test/codegen/CodeGenTests.scala +++ b/src/test/scala/leon/test/codegen/CodeGenTests.scala @@ -6,6 +6,7 @@ import leon._ import leon.codegen._ import leon.purescala.Definitions._ import leon.purescala.Trees._ +import leon.purescala.TypeTrees._ import leon.evaluators.{CodeGenEvaluator,EvaluationResults} import EvaluationResults._ @@ -298,7 +299,7 @@ class CodeGenTests extends test.LeonTestSuite { case class Conc() extends Ab { } def test = Conc().x }""", - Error("Looping") + Error(Untyped, "Looping") ), TestCase("Lazier" , """