From 83e59c6f940afc047aa6e9f1cee98e12be4c8b6e Mon Sep 17 00:00:00 2001 From: Manos Koukoutos <emmanouil.koukoutos@epfl.ch> Date: Tue, 24 Feb 2015 13:15:58 +0100 Subject: [PATCH] Eliminate MutableTyped from Expr's. Improve how respective Expr's are handled. Eliminate MutableTyped from Expr's. Expr.getType is now a val. Variables don't have a mutable type. Separate representation of empty and nonempty Sets, Maps, Multisets, and Arrays. Introduce more generic constructors/ extractors for these types. Simplify Map builder in NAryOperator. Deprecate some deprecated Expr's. Represent String literals as Lists. Make some tests consistent with typing limitations in Leon. --- .../scala/leon/codegen/CodeGeneration.scala | 1 + .../scala/leon/codegen/CompilationUnit.scala | 6 +- .../scala/leon/datagen/VanuatooDataGen.scala | 19 ++- .../scala/leon/evaluators/DualEvaluator.scala | 18 +- .../leon/evaluators/RecursiveEvaluator.scala | 37 ++-- .../frontends/scalac/CodeExtraction.scala | 42 ++--- .../scala/leon/purescala/Constructors.scala | 29 ++++ .../scala/leon/purescala/Definitions.scala | 2 +- .../scala/leon/purescala/Extractors.scala | 122 ++++++++++---- .../scala/leon/purescala/PrettyPrinter.scala | 33 ++-- .../scala/leon/purescala/ScalaPrinter.scala | 1 + src/main/scala/leon/purescala/TreeOps.scala | 36 +--- src/main/scala/leon/purescala/Trees.scala | 159 +++++++++++------- .../scala/leon/purescala/TypeTreeOps.scala | 3 +- src/main/scala/leon/purescala/TypeTrees.scala | 15 +- .../leon/repair/rules/GuidedDecomp.scala | 2 +- .../solvers/smtlib/SMTLIBCVC4Target.scala | 10 +- .../leon/solvers/smtlib/SMTLIBTarget.scala | 12 +- .../leon/solvers/z3/AbstractZ3Solver.scala | 10 +- src/main/scala/leon/synthesis/Witnesses.scala | 4 +- .../leon/synthesis/rules/CegisLike.scala | 6 +- .../synthesis/utils/ExpressionGrammar.scala | 8 +- .../scala/leon/xlang/EpsilonElimination.scala | 4 +- src/main/scala/leon/xlang/TreeOps.scala | 2 +- src/main/scala/leon/xlang/Trees.scala | 20 ++- .../purescala/valid/LiteralMaps.scala | 2 +- .../evaluators/DefaultEvaluatorTests.scala | 21 ++- .../test/evaluators/EvaluatorsTests.scala | 14 +- 28 files changed, 389 insertions(+), 249 deletions(-) diff --git a/src/main/scala/leon/codegen/CodeGeneration.scala b/src/main/scala/leon/codegen/CodeGeneration.scala index fd704fd28..24dadb724 100644 --- a/src/main/scala/leon/codegen/CodeGeneration.scala +++ b/src/main/scala/leon/codegen/CodeGeneration.scala @@ -10,6 +10,7 @@ import purescala.TreeOps.{simplestValue, matchToIfThenElse} import purescala.TypeTrees._ import purescala.Constructors._ import purescala.TypeTreeOps.instantiateType +import purescala.Extractors._ import utils._ import cafebabe._ diff --git a/src/main/scala/leon/codegen/CompilationUnit.scala b/src/main/scala/leon/codegen/CompilationUnit.scala index 52509ab71..c522be854 100644 --- a/src/main/scala/leon/codegen/CompilationUnit.scala +++ b/src/main/scala/leon/codegen/CompilationUnit.scala @@ -7,6 +7,8 @@ import purescala.Common._ import purescala.Definitions._ import purescala.Trees._ import purescala.TypeTrees._ +import purescala.Extractors._ +import purescala.Constructors._ import codegen.runtime.LeonCodeGenRuntimeMonitor import cafebabe._ @@ -224,7 +226,7 @@ class CompilationUnit(val ctx: LeonContext, else GenericValue(tp, id).copiedFrom(gv) case (set : runtime.Set, SetType(b)) => - FiniteSet(set.getElements().asScala.map(jvmToExpr(_, b)).toSet).setType(SetType(b)) + finiteSet(set.getElements().asScala.map(jvmToExpr(_, b)).toSet, b) case (map : runtime.Map, MapType(from, to)) => val pairs = map.getElements().asScala.map { entry => @@ -232,7 +234,7 @@ class CompilationUnit(val ctx: LeonContext, val v = jvmToExpr(entry.getValue(), to) (k, v) } - FiniteMap(pairs.toSeq) + finiteMap(pairs.toSeq, from, to) case _ => throw CompilationException("Unsupported return value : " + e.getClass +" while expecting "+tpe) diff --git a/src/main/scala/leon/datagen/VanuatooDataGen.scala b/src/main/scala/leon/datagen/VanuatooDataGen.scala index 2ffc66cc0..368b56808 100644 --- a/src/main/scala/leon/datagen/VanuatooDataGen.scala +++ b/src/main/scala/leon/datagen/VanuatooDataGen.scala @@ -8,7 +8,8 @@ import purescala.Definitions._ import purescala.TreeOps._ import purescala.Trees._ import purescala.TypeTrees._ -import purescala.Extractors.TopLevelAnds +import purescala.Extractors._ +import purescala.Constructors._ import codegen.CompilationUnit import codegen.runtime.LeonCodeGenRuntimeMonitor @@ -60,7 +61,12 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { case at @ ArrayType(sub) => constructors.getOrElse(at, { val cs = for (size <- List(0, 1, 2, 5)) yield { - Constructor[Expr, TypeTree]((1 to size).map(i => sub).toList, at, s => FiniteArray(s).setType(at), at.toString+"@"+size) + Constructor[Expr, TypeTree]( + (1 to size).map(i => sub).toList, + at, + s => finiteArray(s, None, sub), + at.toString+"@"+size + ) } constructors += at -> cs cs @@ -69,7 +75,12 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { case st @ SetType(sub) => constructors.getOrElse(st, { val cs = for (size <- List(0, 1, 2, 5)) yield { - Constructor[Expr, TypeTree]((1 to size).map(i => sub).toList, st, s => FiniteSet(s.toSet).setType(st), st.toString+"@"+size) + Constructor[Expr, TypeTree]( + (1 to size).map(i => sub).toList, + st, + s => finiteSet(s.toSet, sub), + st.toString+"@"+size + ) } constructors += st -> cs cs @@ -87,7 +98,7 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { val cs = for (size <- List(0, 1, 2, 5)) yield { val subs = (1 to size).flatMap(i => List(from, to)).toList - Constructor[Expr, TypeTree](subs, mt, s => FiniteMap(s.grouped(2).map(t => (t(0), t(1))).toSeq).setType(mt), mt.toString+"@"+size) + Constructor[Expr, TypeTree](subs, mt, s => finiteMap(s.grouped(2).map(t => (t(0), t(1))).toSeq, from, to), mt.toString+"@"+size) } constructors += mt -> cs cs diff --git a/src/main/scala/leon/evaluators/DualEvaluator.scala b/src/main/scala/leon/evaluators/DualEvaluator.scala index 851731bb9..7903f8c5e 100644 --- a/src/main/scala/leon/evaluators/DualEvaluator.scala +++ b/src/main/scala/leon/evaluators/DualEvaluator.scala @@ -6,7 +6,7 @@ package evaluators import purescala.Common._ import purescala.Trees._ import purescala.Definitions._ -import purescala.TypeTrees.MutableTyped +import purescala.TypeTrees._ import codegen._ @@ -29,7 +29,9 @@ class DualEvaluator(ctx: LeonContext, prog: Program, params: CodeGenParams) exte def withVars(news: Map[Identifier, Expr]) = copy(news) } - case class RawObject(o: AnyRef) extends Expr with MutableTyped + case class RawObject(o: AnyRef, tpe: TypeTree) extends Expr { + val getType = tpe + } def call(tfd: TypedFunDef, args: Seq[AnyRef]): Expr = { @@ -50,7 +52,7 @@ class DualEvaluator(ctx: LeonContext, prog: Program, params: CodeGenParams) exte meth.invoke(null, allArgs : _*) } - RawObject(res).setType(tfd.returnType) + RawObject(res, tfd.returnType) } catch { case e: java.lang.reflect.InvocationTargetException => throw new RuntimeError(e.getCause.getMessage) @@ -74,7 +76,7 @@ class DualEvaluator(ctx: LeonContext, prog: Program, params: CodeGenParams) exte val res = field.get(null) - RawObject(res).setType(fd.returnType) + RawObject(res, fd.returnType) } catch { case e: java.lang.reflect.InvocationTargetException => throw new RuntimeError(e.getCause.getMessage) @@ -93,7 +95,7 @@ class DualEvaluator(ctx: LeonContext, prog: Program, params: CodeGenParams) exte if (!tfd.fd.canBeStrictField) { val rargs = args.map( e(_)(rctx.copy(needJVMRef = true), gctx) match { - case RawObject(obj) => obj + case RawObject(obj, _) => obj case _ => throw new EvalError("Failed to get JVM ref when requested") } ) @@ -111,12 +113,12 @@ class DualEvaluator(ctx: LeonContext, prog: Program, params: CodeGenParams) exte def jvmBarrier(e: Expr, returnJVMRef: Boolean): Expr = { e match { - case RawObject(obj) if returnJVMRef => + case RawObject(obj, _) if returnJVMRef => e - case RawObject(obj) if !returnJVMRef => + case RawObject(obj, _) if !returnJVMRef => unit.jvmToExpr(obj, e.getType) case e if returnJVMRef => - RawObject(unit.exprToJVM(e)(monitor)).setType(e.getType) + RawObject(unit.exprToJVM(e)(monitor), e.getType) case e if !returnJVMRef => e } diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala index 7db4c90ea..a741602dd 100644 --- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala +++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala @@ -386,7 +386,9 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int case SetUnion(s1,s2) => (e(s1), e(s2)) match { - case (f@FiniteSet(els1),FiniteSet(els2)) => FiniteSet(els1 ++ els2).setType(f.getType) + case (f@FiniteSet(els1),FiniteSet(els2)) => + val SetType(tpe) = f.getType + finiteSet(els1 ++ els2, tpe) case (le,re) => throw EvalError(typeErrorMsg(le, s1.getType)) } @@ -394,8 +396,8 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int (e(s1), e(s2)) match { case (f @ FiniteSet(els1), FiniteSet(els2)) => { val newElems = (els1 intersect els2) - val baseType = f.getType.asInstanceOf[SetType].base - FiniteSet(newElems).setType(f.getType) + val SetType(tpe) = f.getType + finiteSet(newElems, tpe) } case (le,re) => throw EvalError(typeErrorMsg(le, s1.getType)) } @@ -403,9 +405,9 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int case SetDifference(s1,s2) => (e(s1), e(s2)) match { case (f @ FiniteSet(els1),FiniteSet(els2)) => { + val SetType(tpe) = f.getType val newElems = els1 -- els2 - val baseType = f.getType.asInstanceOf[SetType].base - FiniteSet(newElems).setType(f.getType) + finiteSet(newElems, tpe) } case (le,re) => throw EvalError(typeErrorMsg(le, s1.getType)) } @@ -425,7 +427,9 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int case _ => throw EvalError(typeErrorMsg(sr, SetType(Untyped))) } - case f @ FiniteSet(els) => FiniteSet(els.map(e(_))).setType(f.getType) + case f @ FiniteSet(els) => + val SetType(tp) = f.getType + finiteSet(els.map(e(_)), tp) case i @ IntLiteral(_) => i case i @ InfiniteIntegerLiteral(_) => i case b @ BooleanLiteral(_) => b @@ -443,7 +447,8 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int val IntLiteral(index) = ri val FiniteArray(elems, default, length) = ra - FiniteArray(elems.updated(index, rv), default, length).setType(ra.getType) + val ArrayType(tp) = ra.getType + finiteArray(elems.updated(index, rv), default map { (_, length) }, tp) case ArraySelect(a, i) => val IntLiteral(index) = e(i) @@ -454,14 +459,17 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int case e : IndexOutOfBoundsException => throw RuntimeError(e.getMessage) } - case FiniteArray(elems, default, length) => - FiniteArray( + case f @ FiniteArray(elems, default, length) => + val ArrayType(tp) = f.getType + finiteArray( elems.map(el => (el._1, e(el._2))), - default.map(e), - e(length) - ).setType(expr.getType) + default.map{ d => (e(d), e(length)) }, + tp + ) - case f @ FiniteMap(ss) => FiniteMap(ss.map{ case (k, v) => (e(k), e(v)) }.distinct).setType(f.getType) + case f @ FiniteMap(ss) => + val MapType(kT, vT) = f.getType + finiteMap(ss.map{ case (k, v) => (e(k), e(v)) }.distinct, kT, vT) case g @ MapGet(m,k) => (e(m), e(k)) match { case (FiniteMap(ss), e) => ss.find(_._1 == e) match { case Some((_, v0)) => v0 @@ -473,7 +481,8 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int case (f1@FiniteMap(ss1), FiniteMap(ss2)) => { val filtered1 = ss1.filterNot(s1 => ss2.exists(s2 => s2._1 == s1._1)) val newSs = filtered1 ++ ss2 - FiniteMap(newSs).setType(f1.getType) + val MapType(kT, vT) = u.getType + finiteMap(newSs, kT, vT) } case (l, r) => throw EvalError(typeErrorMsg(l, m1.getType)) } diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index 86d7247ef..9a298342a 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -921,7 +921,15 @@ trait CodeExtraction extends ASTExtractors { case ExInt32Literal(i) => (LiteralPattern(binder, IntLiteral(i)), dctx) case ExBooleanLiteral(b) => (LiteralPattern(binder, BooleanLiteral(b)), dctx) case ExUnitLiteral() => (LiteralPattern(binder, UnitLiteral()), dctx) - case ExStringLiteral(s) => (LiteralPattern(binder, StringLiteral(s)), dctx) + case sLit@ExStringLiteral(s) => + val consClass = libraryCaseClass(sLit.pos, "leon.collection.Cons") + val nilClass = libraryCaseClass(sLit.pos, "leon.collection.Nil") + val nil = CaseClassPattern(None, CaseClassType(nilClass, Seq(CharType)), Seq()) + val consType = CaseClassType(consClass, Seq(CharType)) + def mkCons(hd: Pattern, tl: Pattern) = CaseClassPattern(None, consType, Seq(hd,tl)) + val chars = s.toCharArray()//.asInstanceOf[Seq[Char]] + def charPat(ch : Char) = LiteralPattern(None, CharLiteral(ch)) + (chars.foldRight(nil)( (ch: Char, p : Pattern) => mkCons( charPat(ch), p)), dctx) case _ => outOfSubsetError(p, "Unsupported pattern: "+p.getClass) @@ -1014,7 +1022,7 @@ trait CodeExtraction extends ASTExtractors { val UnwrapTuple(ines) = ine ines foreach { - case v : Variable if currentFunDef.params.map{ _.toVariable } contains v => + case v @ Variable(_) if currentFunDef.params.map{ _.toVariable } contains v => case LeonThis(_) => case other => ctx.reporter.fatalError(other.getPos, "Only i/o variables are allowed in i/o examples") } @@ -1030,7 +1038,7 @@ trait CodeExtraction extends ASTExtractors { gives(rs, rc) case ExArrayLiteral(tpe, args) => - FiniteArray(args.map(extractTree)).setType(ArrayType(extractType(tpe)(dctx, current.pos))) + finiteArray(args.map(extractTree), None, extractType(tpe)(dctx, current.pos)) case ExCaseObject(sym) => getClassDef(sym, current.pos) match { @@ -1168,17 +1176,17 @@ trait CodeExtraction extends ASTExtractors { case epsi @ ExEpsilonExpression(tpt, varSym, predBody) => val pstpe = extractType(tpt) - val nctx = dctx.withNewVar(varSym -> (() => EpsilonVariable(epsi.pos).setType(pstpe))) + val nctx = dctx.withNewVar(varSym -> (() => EpsilonVariable(epsi.pos, pstpe))) val c1 = extractTree(predBody)(nctx) if(containsEpsilon(c1)) { outOfSubsetError(epsi, "Usage of nested epsilon is not allowed") } - Epsilon(c1).setType(pstpe) + Epsilon(c1, pstpe) case ExWaypointExpression(tpt, i, tree) => val pstpe = extractType(tpt) val IntLiteral(ri) = extractTree(i) - Waypoint(ri, extractTree(tree)).setType(pstpe) + Waypoint(ri, extractTree(tree), pstpe) case update @ ExUpdate(lhs, index, newValue) => val lhsRec = extractTree(lhs) @@ -1371,14 +1379,14 @@ trait CodeExtraction extends ASTExtractors { case ExFiniteSet(tt, args) => val underlying = extractType(tt) - FiniteSet(args.map(extractTree(_)).toSet).setType(SetType(underlying)) + finiteSet(args.map(extractTree(_)).toSet, underlying) + case ExEmptySet(tt) => + val underlying = extractType(tt) + EmptySet(underlying) case ExFiniteMultiset(tt, args) => - FiniteMultiset(args.map(extractTree(_))) - - case ExEmptySet(tt) => val underlying = extractType(tt) - FiniteSet(Set()).setType(SetType(underlying)) + finiteMultiset(args.map(extractTree(_)),underlying) case ExEmptyMultiset(tt) => val underlying = extractType(tt) @@ -1387,14 +1395,11 @@ trait CodeExtraction extends ASTExtractors { case ExEmptyMap(ft, tt) => val fromUnderlying = extractType(ft) val toUnderlying = extractType(tt) - val tpe = MapType(fromUnderlying, toUnderlying) - - FiniteMap(Seq()).setType(tpe) + EmptyMap(fromUnderlying, toUnderlying) case ExLiteralMap(ft, tt, elems) => val fromUnderlying = extractType(ft) val toUnderlying = extractType(tt) - val tpe = MapType(fromUnderlying, toUnderlying) val singletons: Seq[(LeonExpr, LeonExpr)] = elems.collect { case ExTuple(tpes, trees) if (trees.size == 2) => @@ -1405,13 +1410,12 @@ trait CodeExtraction extends ASTExtractors { outOfSubsetError(tr, "Some map elements could not be extracted as Tuple2") } - FiniteMap(singletons).setType(tpe) + finiteMap(singletons, fromUnderlying, toUnderlying) case ExArrayFill(baseType, length, defaultValue) => - val underlying = extractType(baseType) val lengthRec = extractTree(length) val defaultValueRec = extractTree(defaultValue) - FiniteArray(Map(), Some(defaultValueRec), lengthRec).setType(ArrayType(underlying)) + NonemptyArray(Map(), Some(defaultValueRec, lengthRec)) case ExIfThenElse(t1,t2,t3) => val r1 = extractTree(t1) @@ -1661,7 +1665,7 @@ trait CodeExtraction extends ASTExtractors { MapIsDefinedAt(a1, a2) case (IsTyped(a1, mt: MapType), "updated", List(k, v)) => - MapUnion(a1, FiniteMap(Seq((k, v))).setType(mt)) + MapUnion(a1, NonemptyMap(Seq((k, v)))) case (IsTyped(a1, mt1: MapType), "++", List(IsTyped(a2, mt2: MapType))) if mt1 == mt2 => MapUnion(a1, a2) diff --git a/src/main/scala/leon/purescala/Constructors.scala b/src/main/scala/leon/purescala/Constructors.scala index 51e3e8da4..3c74cc29c 100644 --- a/src/main/scala/leon/purescala/Constructors.scala +++ b/src/main/scala/leon/purescala/Constructors.scala @@ -189,6 +189,35 @@ object Constructors { case _ => Implies(lhs, rhs) } + def finiteSet(els: Set[Expr], tpe: TypeTree) = { + if (els.isEmpty) EmptySet(tpe) + else NonemptySet(els) + } + + def finiteMultiset(els: Seq[Expr], tpe: TypeTree) = { + if (els.isEmpty) EmptyMultiset(tpe) + else NonemptyMultiset(els) + } + + def finiteMap(els: Seq[(Expr, Expr)], keyType: TypeTree, valueType: TypeTree) = { + if (els.isEmpty) EmptyMap(keyType, valueType) + else NonemptyMap(els.distinct) + } + + def finiteArray(els: Seq[Expr]): Expr = { + require(!els.isEmpty) + finiteArray(els, None, Untyped) // Untyped is not correct, but will not be used anyway + } + + def finiteArray(els: Seq[Expr], defaultLength: Option[(Expr, Expr)], tpe: TypeTree): Expr = { + finiteArray(els.zipWithIndex.map{ _.swap }.toMap, defaultLength, tpe) + } + + def finiteArray(els: Map[Int, Expr], defaultLength: Option[(Expr, Expr)], tpe: TypeTree): Expr = { + if (els.isEmpty && defaultLength.isEmpty) EmptyArray(tpe) + else NonemptyArray(els, defaultLength) + } + def finiteLambda(dflt: Expr, els: Seq[(Expr, Expr)], tpe: FunctionType): Lambda = { val args = tpe.from.zipWithIndex.map { case (tpe, idx) => ValDef(FreshIdentifier(s"x${idx + 1}").setType(tpe), tpe) diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala index 88cd305ba..2c459bc5f 100644 --- a/src/main/scala/leon/purescala/Definitions.scala +++ b/src/main/scala/leon/purescala/Definitions.scala @@ -56,7 +56,7 @@ object Definitions { def subDefinitions = Seq() - def toVariable : Variable = Variable(id).setType(tpe) + def toVariable : Variable = Variable(id, Some(tpe)) setSubDefOwners() } diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala index 52b434db6..291aa8487 100644 --- a/src/main/scala/leon/purescala/Extractors.scala +++ b/src/main/scala/leon/purescala/Extractors.scala @@ -101,51 +101,41 @@ object Extractors { case And(args) => Some((args, and)) case Or(args) => Some((args, or)) case FiniteSet(args) => - Some((args.toSeq, - { newargs => - if (newargs.isEmpty) { - FiniteSet(Set()).setType(expr.getType) - } else { - FiniteSet(newargs.toSet) - } - } - )) + val SetType(tpe) = expr.getType + Some((args.toSeq, els => finiteSet(els.toSet, tpe))) case FiniteMap(args) => { val subArgs = args.flatMap{case (k, v) => Seq(k, v)} val builder: (Seq[Expr]) => Expr = (as: Seq[Expr]) => { - val (keys, values, isKey) = as.foldLeft[(List[Expr], List[Expr], Boolean)]((Nil, Nil, true)){ - case ((keys, values, isKey), rExpr) => if(isKey) (rExpr::keys, values, false) else (keys, rExpr::values, true) + def rec(kvs: Seq[Expr]) : Seq[(Expr, Expr)] = kvs match { + case Seq(k, v, t@_*) => + (k,v) +: rec(t) + case Seq() => Seq() + case _ => sys.error("odd number of key/value expressions") } - assert(isKey) - val tpe = (keys, values) match { - case (Seq(), Seq()) => expr.getType - case _ => - MapType( - bestRealType(leastUpperBound(keys.map (_.getType)).get), - bestRealType(leastUpperBound(values.map(_.getType)).get) - ) - } - FiniteMap(keys.zip(values)).setType(tpe) + val MapType(keyType, valueType) = expr.getType + finiteMap(rec(as), keyType, valueType) } Some((subArgs, builder)) } - case FiniteMultiset(args) => Some((args, FiniteMultiset)) - case ArrayUpdated(t1, t2, t3) => Some((Seq(t1,t2,t3), (as: Seq[Expr]) => ArrayUpdated(as(0), as(1), as(2)))) + case FiniteMultiset(args) => + val MultisetType(tpe) = expr.getType + Some((args, finiteMultiset(_, tpe))) + case ArrayUpdated(t1, t2, t3) => Some((Seq(t1,t2,t3), (as: Seq[Expr]) => + ArrayUpdated(as(0), as(1), as(2)))) case FiniteArray(elems, default, length) => { val fixedElems: Seq[(Int, Expr)] = elems.toSeq val all: Seq[Expr] = fixedElems.map(_._2) ++ default ++ Seq(length) Some((all, (as: Seq[Expr]) => { - val tpe = leastUpperBound(as.map(_.getType)) - .map(ArrayType(_)) - .getOrElse(expr.getType) + val ArrayType(tpe) = expr.getType val (newElems, newDefault, newSize) = default match { case None => (as.init, None, as.last) case Some(_) => (as.init.init, Some(as.init.last), as.last) } - FiniteArray( + finiteArray( fixedElems.zip(newElems).map(p => (p._1._1, p._2)).toMap, - newDefault, - newSize).setType(tpe) + newDefault map ((_, newSize)), + tpe + ) })) } @@ -241,6 +231,39 @@ object Extractors { def extract: Option[(Seq[Expr], (Seq[Expr])=>Expr)]; } + object StringLiteral { + def unapply(e: Expr): Option[String] = e match { + case CaseClass(cct, args) => + DefOps.programOf(cct.classDef) flatMap { p => + val lib = p.library + + if (Some(cct.classDef) == lib.String) { + isListLiteral(args(0)) match { + case Some((_, chars)) => + val str = chars.map { + case CharLiteral(c) => Some(c) + case _ => None + } + + if (str.forall(_.isDefined)) { + Some(str.flatten.mkString) + } else { + None + } + case _ => + None + + } + } else { + None + } + } + case _ => + None + } + } + + object TopLevelOrs { // expr1 AND (expr2 AND (expr3 AND ..)) => List(expr1, expr2, expr3) def unapply(e: Expr): Option[Seq[Expr]] = e match { case Or(exprs) => @@ -269,7 +292,7 @@ object Extractors { def rec(body: Expr): Option[(Expr, Seq[(Expr, Expr)])] = body match { case _ : IntLiteral | _ : UMinus | _ : BooleanLiteral | _ : GenericValue | _ : Tuple | - _ : CaseClass | _ : FiniteArray | _ : FiniteSet | _ : FiniteMap | _ : Lambda => + _ : CaseClass | FiniteArray(_, _, _) | FiniteSet(_) | FiniteMap(_) | _ : Lambda => Some(body -> Seq.empty) case IfExpr(Equals(tpArgs, key), expr, elze) if tpArgs == argsTuple => rec(elze).map { case (dflt, mapping) => dflt -> ((key -> expr) +: mapping) } @@ -280,6 +303,43 @@ object Extractors { } } + object FiniteSet { + def unapply(e: Expr): Option[Set[Expr]] = e match { + case EmptySet(_) => Some(Set()) + case NonemptySet(els) => Some(els) + case _ => None + } + } + + object FiniteMultiset { + def unapply(e: Expr): Option[Seq[Expr]] = e match { + case EmptyMultiset(_) => Some(Seq()) + case NonemptyMultiset(els) => Some(els) + case _ => None + } + } + + object FiniteMap { + def unapply(e: Expr): Option[Seq[(Expr, Expr)]] = e match { + case EmptyMap(_, _) => Some(Seq()) + case NonemptyMap(pairs) => Some(pairs) + case _ => None + } + } + + object FiniteArray { + def unapply(e: Expr): Option[(Map[Int, Expr], Option[Expr], Expr)] = e match { + case EmptyArray(_) => + Some((Map(), None, IntLiteral(0))) + case NonemptyArray(els, Some((default, length))) => + Some((els, Some(default), length)) + case NonemptyArray(els, None) => + Some((els, None, IntLiteral(els.size))) + case _ => + None + } + } + object MatchLike { def unapply(m : MatchLike) : Option[(Expr, Seq[MatchCase], (Expr, Seq[MatchCase]) => Expr)] = { Option(m) map { m => @@ -361,6 +421,6 @@ object Extractors { case _ => None }} } - } + } } diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index 1559634ed..e8849a549 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -13,9 +13,9 @@ import utils._ import java.lang.StringBuffer import PrinterHelpers._ -import TreeOps.{isStringLiteral, isListLiteral, simplestValue, variablesOf} +import TreeOps.{isListLiteral, simplestValue, variablesOf} import TypeTreeOps.leastUpperBound -import Extractors.LetPattern +import Extractors._ import synthesis.Witnesses._ @@ -239,6 +239,14 @@ class PrettyPrinter(opts: PrinterOptions, val sb: StringBuffer = new StringBuffe case e @ CaseClass(cct, args) => isListLiteral(e) match { case Some((tpe, elems)) => + val chars = elems.collect{case CharLiteral(ch) => ch} + if (chars.length == elems.length) { + // String literal + val str = chars mkString "" + val q = '"'; + p"$q$str$q" + } + val elemTps = leastUpperBound(elems.map(_.getType)) if (elemTps == Some(tpe)) { p"List($elems)" @@ -246,19 +254,12 @@ class PrettyPrinter(opts: PrinterOptions, val sb: StringBuffer = new StringBuffe p"List[$tpe]($elems)" } - case None => - isStringLiteral(e) match { - case Some(str) => - val q = '"'; - p"$q$str$q" - - case None => - if (cct.classDef.isCaseObject) { - p"$cct" - } else { - p"$cct($args)" - } - } + case None => + if (cct.classDef.isCaseObject) { + p"$cct" + } else { + p"$cct($args)" + } } @@ -275,7 +276,6 @@ class PrettyPrinter(opts: PrinterOptions, val sb: StringBuffer = new StringBuffe case InfiniteIntegerLiteral(v) => p"BigInt($v)" case CharLiteral(v) => p"$v" case BooleanLiteral(v) => p"$v" - case StringLiteral(s) => p""""$s"""" case UnitLiteral() => p"()" case GenericValue(tp, id) => p"$tp#$id" case Tuple(exprs) => p"($exprs)" @@ -473,6 +473,7 @@ class PrettyPrinter(opts: PrinterOptions, val sb: StringBuffer = new StringBuffe case WildcardPattern(Some(id)) => p"$id" case CaseClassPattern(ob, cct, subps) => + // TODO specialize for strings ob.foreach { b => p"$b @ " } // Print only the classDef because we don't want type parameters in patterns printWithPath(cct.classDef) diff --git a/src/main/scala/leon/purescala/ScalaPrinter.scala b/src/main/scala/leon/purescala/ScalaPrinter.scala index 82ddb9cb8..ab9ff171f 100644 --- a/src/main/scala/leon/purescala/ScalaPrinter.scala +++ b/src/main/scala/leon/purescala/ScalaPrinter.scala @@ -8,6 +8,7 @@ import Trees._ import TypeTrees._ import Definitions._ import Constructors._ +import Extractors._ import PrinterHelpers._ diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala index 1a256bc82..857b82fc2 100644 --- a/src/main/scala/leon/purescala/TreeOps.scala +++ b/src/main/scala/leon/purescala/TreeOps.scala @@ -1002,10 +1002,10 @@ object TreeOps { case CharType => CharLiteral('a') case BooleanType => BooleanLiteral(false) case UnitType => UnitLiteral() - case SetType(baseType) => FiniteSet(Set()).setType(tpe) - case MapType(fromType, toType) => FiniteMap(Seq()).setType(tpe) + case SetType(baseType) => EmptySet(tpe) + case MapType(fromType, toType) => EmptyMap(fromType, toType) case TupleType(tpes) => Tuple(tpes.map(simplestValue)) - case ArrayType(tpe) => FiniteArray(Map(), Some(simplestValue(tpe)), IntLiteral(0)).setType(ArrayType(tpe)) + case ArrayType(tpe) => EmptyArray(tpe) case act @ AbstractClassType(acd, tpe) => val children = acd.knownChildren @@ -2090,36 +2090,6 @@ object TreeOps { (fds.values.toSet, res2) } - def isStringLiteral(e: Expr): Option[String] = e match { - case CaseClass(cct, args) => - programOf(cct.classDef) flatMap { p => - val lib = p.library - - if (Some(cct.classDef) == lib.String) { - isListLiteral(args(0)) match { - case Some((_, chars)) => - val str = chars.map { - case CharLiteral(c) => Some(c) - case _ => None - } - - if (str.forall(_.isDefined)) { - Some(str.flatten.mkString) - } else { - None - } - case _ => - None - - } - } else { - None - } - } - case _ => - None - } - def isListLiteral(e: Expr): Option[(TypeTree, List[Expr])] = e match { case CaseClass(cct, args) => programOf(cct.classDef) flatMap { p => diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala index 278d66f66..3ce5566d4 100644 --- a/src/main/scala/leon/purescala/Trees.scala +++ b/src/main/scala/leon/purescala/Trees.scala @@ -17,7 +17,10 @@ object Trees { /* EXPRESSIONS */ - abstract class Expr extends Tree with Typed with Serializable + abstract class Expr extends Tree with Typed with Serializable { + // All Expr's have constant type + override val getType: TypeTree + } trait Terminal { self: Expr => @@ -35,21 +38,21 @@ object Trees { } case class Require(pred: Expr, body: Expr) extends Expr with Typed { - def getType = body.getType + val getType = body.getType } case class Ensuring(body: Expr, id: Identifier, pred: Expr) extends Expr { - def getType = body.getType + val getType = body.getType } case class Assert(pred: Expr, error: Option[String], body: Expr) extends Expr { - def getType = body.getType + val getType = body.getType } case class Choose(vars: List[Identifier], pred: Expr, var impl: Option[Expr] = None) extends Expr with NAryExtractable { require(!vars.isEmpty) - def getType = if (vars.size > 1) TupleType(vars.map(_.getType)) else vars.head.getType + val getType = if (vars.size > 1) TupleType(vars.map(_.getType)) else vars.head.getType def extract = { Some((Seq(pred)++impl, (es: Seq[Expr]) => Choose(vars, es.head, es.tail.headOption).setPos(this))) @@ -58,15 +61,15 @@ object Trees { /* Like vals */ case class Let(binder: Identifier, value: Expr, body: Expr) extends Expr { - def getType = body.getType + val getType = body.getType } case class LetDef(fd: FunDef, body: Expr) extends Expr { - def getType = body.getType + val getType = body.getType } case class FunctionInvocation(tfd: TypedFunDef, args: Seq[Expr]) extends Expr { - def getType = tfd.returnType + val getType = tfd.returnType } /** @@ -76,7 +79,7 @@ object Trees { * This becomes first argument, and MethodInvocation become FunctionInvocation. */ case class MethodInvocation(rec: Expr, cd: ClassDef, tfd: TypedFunDef, args: Seq[Expr]) extends Expr { - def getType = { + val getType = { // We need ot instanciate the type based on the type of the function as well as receiver val fdret = tfd.returnType val extraMap: Map[TypeParameterDef, TypeTree] = rec.getType match { @@ -92,35 +95,35 @@ object Trees { case class Application(caller: Expr, args: Seq[Expr]) extends Expr { require(caller.getType.isInstanceOf[FunctionType]) - def getType = caller.getType.asInstanceOf[FunctionType].to + val getType = caller.getType.asInstanceOf[FunctionType].to } case class Lambda(args: Seq[ValDef], body: Expr) extends Expr { - def getType = FunctionType(args.map(_.tpe), body.getType) + val getType = FunctionType(args.map(_.tpe), body.getType) } case class Forall(args: Seq[ValDef], body: Expr) extends Expr { require(body.getType == BooleanType) - def getType = BooleanType + val getType = BooleanType } case class This(ct: ClassType) extends Expr with Terminal { - def getType = ct + val getType = ct } case class IfExpr(cond: Expr, thenn: Expr, elze: Expr) extends Expr { - def getType = leastUpperBound(thenn.getType, elze.getType).getOrElse(Untyped) + val getType = leastUpperBound(thenn.getType, elze.getType).getOrElse(Untyped) } case class Tuple(exprs: Seq[Expr]) extends Expr { - def getType = TupleType(exprs.map(_.getType)) + val getType = TupleType(exprs.map(_.getType)) } // Index is 1-based, first element of tuple is 1. case class TupleSelect(tuple: Expr, index: Int) extends Expr { require(index >= 1) - def getType = tuple.getType match { + val getType = tuple.getType match { case TupleType(ts) => require(index <= ts.size) ts(index - 1) @@ -133,7 +136,7 @@ object Trees { abstract sealed class MatchLike extends Expr { val scrutinee : Expr val cases : Seq[MatchCase] - def getType = leastUpperBound(cases.map(_.rhs.getType)).getOrElse(Untyped) + val getType = leastUpperBound(cases.map(_.rhs.getType)).getOrElse(Untyped) } case class MatchExpr(scrutinee: Expr, cases: Seq[MatchCase]) extends MatchLike { @@ -197,7 +200,7 @@ object Trees { /* Propositional logic */ case class And(exprs: Seq[Expr]) extends Expr { - def getType = BooleanType + val getType = BooleanType require(exprs.size >= 2) } @@ -207,7 +210,7 @@ object Trees { } case class Or(exprs: Seq[Expr]) extends Expr { - def getType = BooleanType + val getType = BooleanType require(exprs.size >= 2) } @@ -217,7 +220,7 @@ object Trees { } case class Implies(lhs: Expr, rhs: Expr) extends Expr { - def getType = BooleanType + val getType = BooleanType } case class Not(expr: Expr) extends Expr { @@ -228,15 +231,19 @@ object Trees { val getType = BooleanType } - case class Variable(id: Identifier) extends Expr with Terminal { - private var _tpe = id.getType - - def setType(tpe: TypeTree): this.type = { - _tpe = tpe - this - } + // tpe overrides the type of the identifier. + // This is useful for variables that represent class fields with instantiated types. + // E.g. list.head when list: List[Int] + // @mk: I know this breaks symmetry with the rest of the trees, but it does seem + // like a natural way to implement this. Feel free to rename the underlying class + // and define constructor/extractor + class Variable(val id: Identifier, val tpe: Option[TypeTree]) extends Expr with Terminal { + val getType = tpe getOrElse id.getType + } - def getType = _tpe + object Variable { + def apply(id: Identifier, tpe: Option[TypeTree] = None) = new Variable(id, tpe) + def unapply(v: Variable) = Some(v.id) } /* Literals */ @@ -263,8 +270,6 @@ object Trees { val getType = BooleanType } - case class StringLiteral(value: String) extends Literal[String] with MutableTyped - case class UnitLiteral() extends Literal[Unit] { val getType = UnitType val value = () @@ -298,7 +303,7 @@ object Trees { class CaseClassSelector(val classType: CaseClassType, val caseClass: Expr, val selector: Identifier) extends Expr { val selectorIndex = classType.classDef.selectorID2Index(selector) - def getType = classType.fieldsTypes(selectorIndex) + val getType = classType.fieldsTypes(selectorIndex) override def equals(that: Any): Boolean = (that != null) && (that match { case t: CaseClassSelector => (t.classType, t.caseClass, t.selector) == (classType, caseClass, selector) @@ -395,11 +400,15 @@ object Trees { } /* Set expressions */ - case class FiniteSet(elements: Set[Expr]) extends Expr with MutableTyped { - val tpe = if (elements.isEmpty) None else leastUpperBound(elements.toSeq.map(_.getType)) - tpe.filter(_ != Untyped).foreach(t => setType(SetType(t))) + case class NonemptySet(elements: Set[Expr]) extends Expr { + require(elements.nonEmpty) + val getType = SetType(leastUpperBound(elements.toSeq.map(_.getType))).unveilUntyped } - + + case class EmptySet(tpe: TypeTree) extends Expr { + val getType = SetType(tpe) + } + case class ElementOfSet(element: Expr, set: Expr) extends Expr { val getType = BooleanType } @@ -411,27 +420,43 @@ object Trees { } case class SetIntersection(set1: Expr, set2: Expr) extends Expr { - def getType = leastUpperBound(Seq(set1, set2).map(_.getType)).getOrElse(Untyped) + val getType = leastUpperBound(Seq(set1, set2).map(_.getType)).getOrElse(Untyped) } case class SetUnion(set1: Expr, set2: Expr) extends Expr { - def getType = leastUpperBound(Seq(set1, set2).map(_.getType)).getOrElse(Untyped) + val getType = leastUpperBound(Seq(set1, set2).map(_.getType)).getOrElse(Untyped) } case class SetDifference(set1: Expr, set2: Expr) extends Expr { - def getType = leastUpperBound(Seq(set1, set2).map(_.getType)).getOrElse(Untyped) + val getType = leastUpperBound(Seq(set1, set2).map(_.getType)).getOrElse(Untyped) } /* Map operations. */ - case class FiniteMap(singletons: Seq[(Expr, Expr)]) extends Expr with MutableTyped + case class NonemptyMap(singletons: Seq[(Expr, Expr)]) extends Expr { + require(singletons.nonEmpty) + val getType = { + val (keys, values) = singletons.unzip + MapType( + leastUpperBound(keys.map(_.getType)), + leastUpperBound(values.map(_.getType)) + ).unveilUntyped + } + } + + case class EmptyMap(keyType: TypeTree, valueType: TypeTree) extends Expr { + val getType = MapType(keyType, valueType).unveilUntyped + } + case class MapGet(map: Expr, key: Expr) extends Expr { - def getType = map.getType match { + val getType = map.getType match { case MapType(from, to) => to case _ => Untyped } } case class MapUnion(map1: Expr, map2: Expr) extends Expr { - def getType = leastUpperBound(Seq(map1, map2).map(_.getType)).getOrElse(Untyped) + val getType = leastUpperBound(Seq(map1, map2).map(_.getType)).getOrElse(Untyped) + } + case class MapDifference(map: Expr, keys: Expr) extends Expr { + val getType = map.getType } - case class MapDifference(map: Expr, keys: Expr) extends Expr with MutableTyped case class MapIsDefinedAt(map: Expr, key: Expr) extends Expr { val getType = BooleanType } @@ -439,7 +464,7 @@ object Trees { /* Array operations */ case class ArraySelect(array: Expr, index: Expr) extends Expr { - def getType = array.getType match { + val getType = array.getType match { case ArrayType(base) => base case _ => @@ -448,7 +473,7 @@ object Trees { } case class ArrayUpdated(array: Expr, index: Expr, newValue: Expr) extends Expr { - def getType = array.getType match { + val getType = array.getType match { case ArrayType(base) => leastUpperBound(base, newValue.getType).map(ArrayType(_)).getOrElse(Untyped) case _ => @@ -460,14 +485,13 @@ object Trees { val getType = Int32Type } - case class FiniteArray(elems: Map[Int, Expr], default: Option[Expr], length: Expr) extends Expr with MutableTyped + case class NonemptyArray(elems: Map[Int, Expr], defaultLength: Option[(Expr, Expr)]) extends Expr { + private val elements = elems.values.toList ++ defaultLength.map{_._1} + val getType = ArrayType(leastUpperBound(elements map { _.getType})).unveilUntyped + } - object FiniteArray { - def apply(elems: Seq[Expr]): FiniteArray = { - val res = FiniteArray(elems.zipWithIndex.map(_.swap).toMap, None, IntLiteral(elems.size)) - elems.headOption.foreach(e => res.setType(ArrayType(e.getType))) - res - } + case class EmptyArray(tpe: TypeTree) extends Expr { + val getType = ArrayType(tpe).unveilUntyped } /* Special trees */ @@ -476,7 +500,7 @@ object Trees { case class WithOracle(oracles: List[Identifier], body: Expr) extends Expr with UnaryExtractable { require(!oracles.isEmpty) - def getType = body.getType + val getType = body.getType def extract = { Some((body, (e: Expr) => WithOracle(oracles, e).setPos(this))) @@ -495,45 +519,56 @@ object Trees { * DEPRECATED TREES * These trees are not guaranteed to be supported by Leon. **/ + @deprecated("3.0", "Use NonemptyArray with default value") case class ArrayFill(length: Expr, defaultValue: Expr) extends Expr { - def getType = ArrayType(defaultValue.getType) + val getType = ArrayType(defaultValue.getType) } + @deprecated("3.0", "Leon does not guarantee to correctly handle this expression") case class SetMin(set: Expr) extends Expr { val getType = Int32Type } + @deprecated("3.0", "Leon does not guarantee to correctly handle this expression") case class SetMax(set: Expr) extends Expr { val getType = Int32Type } + @deprecated("3.0", "Leon does not guarantee to correctly handle this expression") case class EmptyMultiset(baseType: TypeTree) extends Expr with Terminal { - val getType = MultisetType(baseType) + val getType = MultisetType(baseType).unveilUntyped } - case class FiniteMultiset(elements: Seq[Expr]) extends Expr { - require(elements.nonEmpty) - def getType = MultisetType(leastUpperBound(elements.map(_.getType)).getOrElse(Untyped)) + @deprecated("3.0", "Leon does not guarantee to correctly handle this expression") + case class NonemptyMultiset(elements: Seq[Expr]) extends Expr { + val getType = MultisetType(leastUpperBound(elements.toSeq.map(_.getType))).unveilUntyped } + @deprecated("3.0", "Leon does not guarantee to correctly handle this expression") case class Multiplicity(element: Expr, multiset: Expr) extends Expr { val getType = Int32Type } + @deprecated("3.0", "Leon does not guarantee to correctly handle this expression") case class MultisetCardinality(multiset: Expr) extends Expr { val getType = Int32Type } + @deprecated("3.0", "Leon does not guarantee to correctly handle this expression") case class MultisetIntersection(multiset1: Expr, multiset2: Expr) extends Expr { - def getType = leastUpperBound(Seq(multiset1, multiset2).map(_.getType)).getOrElse(Untyped) + val getType = leastUpperBound(Seq(multiset1, multiset2).map(_.getType)).getOrElse(Untyped) } + @deprecated("3.0", "Leon does not guarantee to correctly handle this expression") case class MultisetUnion(multiset1: Expr, multiset2: Expr) extends Expr { - def getType = leastUpperBound(Seq(multiset1, multiset2).map(_.getType)).getOrElse(Untyped) + val getType = leastUpperBound(Seq(multiset1, multiset2).map(_.getType)).getOrElse(Untyped) } + @deprecated("3.0", "Leon does not guarantee to correctly handle this expression") case class MultisetPlus(multiset1: Expr, multiset2: Expr) extends Expr { // disjoint union - def getType = leastUpperBound(Seq(multiset1, multiset2).map(_.getType)).getOrElse(Untyped) + val getType = leastUpperBound(Seq(multiset1, multiset2).map(_.getType)).getOrElse(Untyped) } + @deprecated("3.0", "Leon does not guarantee to correctly handle this expression") case class MultisetDifference(multiset1: Expr, multiset2: Expr) extends Expr { - def getType = leastUpperBound(Seq(multiset1, multiset2).map(_.getType)).getOrElse(Untyped) + val getType = leastUpperBound(Seq(multiset1, multiset2).map(_.getType)).getOrElse(Untyped) } + @deprecated("3.0", "Leon does not guarantee to correctly handle this expression") case class MultisetToSet(multiset: Expr) extends Expr { - def getType = multiset.getType match { + val getType = multiset.getType match { case MultisetType(base) => SetType(base) case _ => Untyped } diff --git a/src/main/scala/leon/purescala/TypeTreeOps.scala b/src/main/scala/leon/purescala/TypeTreeOps.scala index 518d19370..f8ca94198 100644 --- a/src/main/scala/leon/purescala/TypeTreeOps.scala +++ b/src/main/scala/leon/purescala/TypeTreeOps.scala @@ -315,7 +315,8 @@ object TypeTreeOps { Ensuring(srec(body), newId, rec(idsMap + (id -> newId))(pred)).copiedFrom(ens) case s @ FiniteSet(elements) if elements.isEmpty => - FiniteSet(Set()).setType(tpeSub(s.getType)).copiedFrom(s) + val SetType(tp) = s.getType + EmptySet(tpeSub(tp)).copiedFrom(s) case v @ Variable(id) if idsMap contains id => Variable(idsMap(id)).copiedFrom(v) diff --git a/src/main/scala/leon/purescala/TypeTrees.scala b/src/main/scala/leon/purescala/TypeTrees.scala index 05c90501e..139729506 100644 --- a/src/main/scala/leon/purescala/TypeTrees.scala +++ b/src/main/scala/leon/purescala/TypeTrees.scala @@ -26,10 +26,7 @@ object TypeTrees { private var _type: Option[TypeTree] = None - def getType: TypeTree = _type match { - case None => Untyped - case Some(t) => t - } + def getType: TypeTree = _type getOrElse Untyped def setType(tt: TypeTree): self.type = _type match { case None => _type = Some(tt); this @@ -51,7 +48,12 @@ object TypeTrees { } abstract class TypeTree extends Tree with Typed { - def getType = this + val getType = this + def unveilUntyped: TypeTree = this match { + case NAryType(tps, builder) => + val subs = tps map { _.unveilUntyped } + if (subs contains Untyped) Untyped else builder(subs) + } } case object Untyped extends TypeTree @@ -155,4 +157,7 @@ object TypeTrees { case t => Some(Nil, fake => t) } } + + implicit def optTypeToType(tp: Option[TypeTree]) = tp getOrElse Untyped + } diff --git a/src/main/scala/leon/repair/rules/GuidedDecomp.scala b/src/main/scala/leon/repair/rules/GuidedDecomp.scala index 4297b616d..f69130658 100644 --- a/src/main/scala/leon/repair/rules/GuidedDecomp.scala +++ b/src/main/scala/leon/repair/rules/GuidedDecomp.scala @@ -66,7 +66,7 @@ case object GuidedDecomp extends Rule("Guided Decomp") { val subs = for ((c, cond) <- cs zip matchCasePathConditions(fullMatch, List(p.pc))) yield { - val localScrut = c.pattern.binder.map(Variable) getOrElse scrut + val localScrut = c.pattern.binder.map( Variable(_) ) getOrElse scrut val scrutConstraint = if (localScrut == scrut) BooleanLiteral(true) else Equals(localScrut, scrut) val substs = patternSubstitutions(localScrut, c.pattern) diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala index 6b906d640..7b4668d83 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala @@ -50,7 +50,7 @@ trait SMTLIBCVC4Target extends SMTLIBTarget { GenericValue(tp, n.toInt) case (QualifiedIdentifier(SMTIdentifier(SSymbol("emptyset"), Seq()), _), SetType(base)) => - FiniteSet(Set()).setType(tpe) + EmptySet(base) case (FunctionApplication(SimpleSymbol(SSymbol("__array_store_all__")), Seq(_, elem)), RawArrayType(k,v)) => RawArrayValue(k, Map(), fromSMT(elem, v)) @@ -67,17 +67,17 @@ trait SMTLIBCVC4Target extends SMTLIBTarget { finiteLambda(dflt, mapping :+ (fromSMT(key, TupleType(from)) -> fromSMT(elem, to)), ft) case (FunctionApplication(SimpleSymbol(SSymbol("singleton")), elems), SetType(base)) => - FiniteSet(elems.map(fromSMT(_, base)).toSet).setType(tpe) + finiteSet(elems.map(fromSMT(_, base)).toSet, base) case (FunctionApplication(SimpleSymbol(SSymbol("insert")), elems), SetType(base)) => val selems = elems.init.map(fromSMT(_, base)) val FiniteSet(se) = fromSMT(elems.last, tpe) - FiniteSet(se ++ selems).setType(tpe) + finiteSet(se ++ selems, base) case (FunctionApplication(SimpleSymbol(SSymbol("union")), elems), SetType(base)) => - FiniteSet(elems.map(fromSMT(_, tpe) match { + finiteSet(elems.map(fromSMT(_, tpe) match { case FiniteSet(elems) => elems - }).flatten.toSet).setType(tpe) + }).flatten.toSet, base) // FIXME (nicolas) // some versions of CVC4 seem to generate array constants with "as const" notation instead of the __array_store_all__ diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala index 527330376..8bfc9bc4b 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala @@ -69,14 +69,14 @@ trait SMTLIBTarget { // Corresponds to a raw array value, which is coerced to a Leon expr depending on target type (set/array) // Should NEVER escape past SMT-world case class RawArrayValue(keyTpe: TypeTree, elems: Map[Expr, Expr], default: Expr) extends Expr { - def getType = RawArrayType(keyTpe, default.getType) + val getType = RawArrayType(keyTpe, default.getType) } def fromRawArray(r: RawArrayValue, tpe: TypeTree): Expr = tpe match { case SetType(base) => assert(r.default == BooleanLiteral(false) && r.keyTpe == base) - FiniteSet(r.elems.keySet).setType(tpe) + finiteSet(r.elems.keySet, base) case RawArrayType(from, to) => r @@ -551,20 +551,20 @@ trait SMTLIBTarget { val rargs = args.zip(tt.bases).map(fromSMT) Tuple(rargs) - case at: ArrayType => + case ArrayType(baseType) => val IntLiteral(size) = fromSMT(args(0), Int32Type) - val RawArrayValue(_, elems, default) = fromSMT(args(1), RawArrayType(Int32Type, at.base)) + val RawArrayValue(_, elems, default) = fromSMT(args(1), RawArrayType(Int32Type, baseType)) if(size > 10) { val definedElements = elems.collect{ case (IntLiteral(i), value) => (i, value) }.toMap - FiniteArray(definedElements, Some(default), IntLiteral(size)).setType(at) + finiteArray(definedElements, Some(default, IntLiteral(size)), baseType) } else { val entries = for (i <- 0 to size-1) yield elems.getOrElse(IntLiteral(i), default) - FiniteArray(entries).setType(at) + finiteArray(entries, None, baseType) } case t => diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala index cb8954e29..6eca49c3a 100644 --- a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala @@ -495,7 +495,7 @@ trait AbstractZ3Solver } else { // FIXME TODO pleeeeeeeease make this cleaner. Ie. decide what set of // variable has to remain in a map etc. - variables.leonToZ3.filter(p => p._1.isInstanceOf[Variable]).map(p => (p._1.asInstanceOf[Variable].id -> p._2)) + variables.leonToZ3.collect{ case (Variable(id), p2) => id -> p2 } } def rec(ex: Expr): Z3AST = ex match { @@ -542,7 +542,7 @@ trait AbstractZ3Solver rb } - case Waypoint(_, e) => rec(e) + case Waypoint(_, e, _) => rec(e) case e @ Error(tpe, _) => { val newAST = z3.mkFreshConst("errorValue", typeToSort(tpe)) // Might introduce dupplicates (e), but no worries here @@ -800,7 +800,7 @@ trait AbstractZ3Solver (index -> rec(v)) } - FiniteArray(valuesMap, Some(elseValue), IntLiteral(length)).setType(at) + finiteArray(valuesMap, Some(elseValue, IntLiteral(length)), dt) } case LeonType(tpe @ MapType(kt, vt)) => @@ -812,7 +812,7 @@ trait AbstractZ3Solver (rec(k), rec(arg)) } - FiniteMap(values).setType(tpe) + finiteMap(values, kt, vt) } case LeonType(tpe @ FunctionType(fts, tt)) => @@ -829,7 +829,7 @@ trait AbstractZ3Solver case None => throw new CantTranslateException(t) case Some(set) => val elems = set.map(e => rec(e)) - FiniteSet(elems).setType(tpe) + finiteSet(elems, dt) } case LeonType(UnitType) => diff --git a/src/main/scala/leon/synthesis/Witnesses.scala b/src/main/scala/leon/synthesis/Witnesses.scala index cfb8e00d2..1c454f09b 100644 --- a/src/main/scala/leon/synthesis/Witnesses.scala +++ b/src/main/scala/leon/synthesis/Witnesses.scala @@ -10,7 +10,7 @@ import Trees.Expr object Witnesses { class Witness extends Expr { - def getType = BooleanType + val getType = BooleanType } case class Guide(e : Expr) extends Witness with UnaryExtractable { @@ -21,4 +21,4 @@ object Witnesses { def extract: Option[(Seq[Expr], Seq[Expr] => Expr)] = Some((args, Terminating(tfd, _))) } -} \ No newline at end of file +} diff --git a/src/main/scala/leon/synthesis/rules/CegisLike.scala b/src/main/scala/leon/synthesis/rules/CegisLike.scala index fefec3462..9c9927f06 100644 --- a/src/main/scala/leon/synthesis/rules/CegisLike.scala +++ b/src/main/scala/leon/synthesis/rules/CegisLike.scala @@ -337,7 +337,7 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { tester = { (ins: Seq[Expr], bValues: Set[Identifier]) => - val bsValue = FiniteArray(bsOrdered.map(b => BooleanLiteral(bValues(b)))).setType(ArrayType(BooleanType)) + val bsValue = finiteArray(bsOrdered.map(b => BooleanLiteral(bValues(b))), None, BooleanType) val args = ins :+ bsValue val fi = FunctionInvocation(phiFd.typed, args) @@ -647,7 +647,7 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { val solver = (new FairZ3Solver(ctx, programCTree) with TimeoutSolver).setTimeout(exSolverTo) val cnstr = FunctionInvocation(phiFd.typed, phiFd.params.map(_.id.toVariable)) - val fixedBs = FiniteArray(bsOrdered.map(_.toVariable)).setType(ArrayType(BooleanType)) + val fixedBs = finiteArray(bsOrdered.map(_.toVariable), None, BooleanType) val cnstrFixed = replaceFromIDs(Map(bArrayId -> fixedBs), cnstr) val toFind = and(p.pc, cnstrFixed) @@ -699,7 +699,7 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { val solver = (new FairZ3Solver(ctx, programCTree) with TimeoutSolver).setTimeout(cexSolverTo) val cnstr = FunctionInvocation(phiFd.typed, phiFd.params.map(_.id.toVariable)) - val fixedBs = FiniteArray(bsOrdered.map(b => BooleanLiteral(bs(b)))).setType(ArrayType(BooleanType)) + val fixedBs = finiteArray(bsOrdered.map(b => BooleanLiteral(bs(b))), None, BooleanType) val cnstrFixed = replaceFromIDs(Map(bArrayId -> fixedBs), cnstr) solver.assertCnstr(p.pc) diff --git a/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala b/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala index ac3f47b17..8a0a04832 100644 --- a/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala +++ b/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala @@ -15,6 +15,8 @@ import purescala.DefOps._ import purescala.TypeTreeOps._ import purescala.Extractors._ import purescala.ScalaPrinter +import purescala.Constructors.finiteSet + import scala.language.implicitConversions import scala.collection.mutable.{HashMap => MutableMap} @@ -122,7 +124,7 @@ object ExpressionGrammars { case st @ SetType(base) => List( - Generator(List(base), { case elems => FiniteSet(elems.toSet).setType(st) }), + Generator(List(base), { case elems => finiteSet(elems.toSet, base) }), Generator(List(st, st), { case Seq(a, b) => SetUnion(a, b) }), Generator(List(st, st), { case Seq(a, b) => SetIntersection(a, b) }), Generator(List(st, st), { case Seq(a, b) => SetDifference(a, b) }) @@ -175,8 +177,8 @@ object ExpressionGrammars { case st @ SetType(base) => List( - Generator(List(base), { case elems => FiniteSet(elems.toSet).setType(st) }), - Generator(List(base, base), { case elems => FiniteSet(elems.toSet).setType(st) }) + Generator(List(base), { case elems => finiteSet(elems.toSet, base) }), + Generator(List(base, base), { case elems => finiteSet(elems.toSet, base) }) ) case UnitType => diff --git a/src/main/scala/leon/xlang/EpsilonElimination.scala b/src/main/scala/leon/xlang/EpsilonElimination.scala index e77030ab0..7e49ccd27 100644 --- a/src/main/scala/leon/xlang/EpsilonElimination.scala +++ b/src/main/scala/leon/xlang/EpsilonElimination.scala @@ -21,10 +21,10 @@ object EpsilonElimination extends TransformationPhase { val allFuns = pgm.definedFunctions allFuns.foreach(fd => fd.body.map(body => { val newBody = postMap{ - case eps@Epsilon(pred) => + case eps@Epsilon(pred, tpe) => val freshName = FreshIdentifier("epsilon") val newFunDef = new FunDef(freshName, Nil, eps.getType, Seq(), DefType.MethodDef) - val epsilonVar = EpsilonVariable(eps.getPos) + val epsilonVar = EpsilonVariable(eps.getPos, tpe) val resId = FreshIdentifier("res").setType(eps.getType) val postcondition = replace(Map(epsilonVar -> Variable(resId)), pred) newFunDef.postcondition = Some((resId, postcondition)) diff --git a/src/main/scala/leon/xlang/TreeOps.scala b/src/main/scala/leon/xlang/TreeOps.scala index b3de6e91a..d8d283b88 100644 --- a/src/main/scala/leon/xlang/TreeOps.scala +++ b/src/main/scala/leon/xlang/TreeOps.scala @@ -23,7 +23,7 @@ object TreeOps { case LetVar(_, _, _) => true case LetDef(_, _) => true case ArrayUpdate(_, _, _) => true - case Epsilon(_) => true + case Epsilon(_, _) => true case _ => false }}(expr) } diff --git a/src/main/scala/leon/xlang/Trees.scala b/src/main/scala/leon/xlang/Trees.scala index 11883a7f8..ce2c32520 100644 --- a/src/main/scala/leon/xlang/Trees.scala +++ b/src/main/scala/leon/xlang/Trees.scala @@ -31,7 +31,7 @@ object Trees { |}""" } - def getType = last.getType + val getType = last.getType } case class Assignment(varId: Identifier, expr: Expr) extends Expr with UnaryExtractable with PrettyPrintable { @@ -72,26 +72,30 @@ object Trees { } } - case class Epsilon(pred: Expr) extends Expr with UnaryExtractable with PrettyPrintable with MutableTyped { + case class Epsilon(pred: Expr, tpe: TypeTree) extends Expr with UnaryExtractable with PrettyPrintable { def extract: Option[(Expr, (Expr)=>Expr)] = { - Some((pred, (expr: Expr) => Epsilon(expr).setType(this.getType).setPos(this))) + Some((pred, (expr: Expr) => Epsilon(expr, this.getType).setPos(this))) } def printWith(implicit pctx: PrinterContext) { p"epsilon(x${getPos.line}_${getPos.col}. $pred)" } + + val getType = tpe } - case class EpsilonVariable(pos: Position) extends Expr with Terminal with PrettyPrintable with MutableTyped { + case class EpsilonVariable(pos: Position, tpe: TypeTree) extends Expr with Terminal with PrettyPrintable { def printWith(implicit pctx: PrinterContext) { p"x${pos.line}_${pos.col}" } + + val getType = tpe } //same as let, buf for mutable variable declaration case class LetVar(binder: Identifier, value: Expr, body: Expr) extends Expr with BinaryExtractable with PrettyPrintable { - def getType = body.getType + val getType = body.getType def extract: Option[(Expr, Expr, (Expr, Expr)=>Expr)] = { val LetVar(binders, expr, body) = this @@ -106,14 +110,16 @@ object Trees { } } - case class Waypoint(i: Int, expr: Expr) extends Expr with UnaryExtractable with PrettyPrintable with MutableTyped { + case class Waypoint(i: Int, expr: Expr, tpe: TypeTree) extends Expr with UnaryExtractable with PrettyPrintable{ def extract: Option[(Expr, (Expr)=>Expr)] = { - Some((expr, (e: Expr) => Waypoint(i, e))) + Some((expr, (e: Expr) => Waypoint(i, e, tpe))) } def printWith(implicit pctx: PrinterContext) { p"waypoint_$i($expr)" } + + val getType = tpe } case class ArrayUpdate(array: Expr, index: Expr, newValue: Expr) extends Expr with NAryExtractable with PrettyPrintable { diff --git a/src/test/resources/regression/verification/purescala/valid/LiteralMaps.scala b/src/test/resources/regression/verification/purescala/valid/LiteralMaps.scala index ed194e7a6..d6c07c261 100644 --- a/src/test/resources/regression/verification/purescala/valid/LiteralMaps.scala +++ b/src/test/resources/regression/verification/purescala/valid/LiteralMaps.scala @@ -10,7 +10,7 @@ object LiteralMaps { } def test3(): Map[Int, Int] = { - Map() + Map[Int, Int]() } def test4(): Map[Int, Int] = { diff --git a/src/test/scala/leon/test/evaluators/DefaultEvaluatorTests.scala b/src/test/scala/leon/test/evaluators/DefaultEvaluatorTests.scala index fd0c09a2a..b7234f0d7 100644 --- a/src/test/scala/leon/test/evaluators/DefaultEvaluatorTests.scala +++ b/src/test/scala/leon/test/evaluators/DefaultEvaluatorTests.scala @@ -13,7 +13,7 @@ import leon.purescala.Definitions._ import leon.purescala.Trees._ import leon.purescala.DefOps._ import leon.purescala.TypeTrees._ - +import leon.purescala.Constructors._ class DefaultEvaluatorTests extends leon.test.LeonTestSuite { private implicit lazy val leonContext: LeonContext = createLeonContext() @@ -145,24 +145,24 @@ class DefaultEvaluatorTests extends leon.test.LeonTestSuite { test("eval literal array ops") { expectSuccessful( - defaultEvaluator.eval(FiniteArray(Map(), Some(IntLiteral(12)), IntLiteral(7)).setType(ArrayType(Int32Type))), - FiniteArray(Map(), Some(IntLiteral(12)), IntLiteral(7))) + defaultEvaluator.eval(finiteArray(Map[Int,Expr](), Some(IntLiteral(12), IntLiteral(7)), Int32Type)), + finiteArray(Map[Int,Expr](), Some(IntLiteral(12), IntLiteral(7)), Int32Type)) expectSuccessful( defaultEvaluator.eval( - ArrayLength(FiniteArray(Map(), Some(IntLiteral(12)), IntLiteral(7)).setType(ArrayType(Int32Type)))), + ArrayLength(finiteArray(Map[Int,Expr](), Some(IntLiteral(12), IntLiteral(7)), Int32Type))), IntLiteral(7)) expectSuccessful( defaultEvaluator.eval(ArraySelect( - FiniteArray(Seq(IntLiteral(2), IntLiteral(4), IntLiteral(7))), + finiteArray(Seq(IntLiteral(2), IntLiteral(4), IntLiteral(7))), IntLiteral(1))), IntLiteral(4)) expectSuccessful( defaultEvaluator.eval( ArrayUpdated( - FiniteArray(Seq(IntLiteral(2), IntLiteral(4), IntLiteral(7))), + finiteArray(Seq(IntLiteral(2), IntLiteral(4), IntLiteral(7))), IntLiteral(1), IntLiteral(42))), - FiniteArray(Seq(IntLiteral(2), IntLiteral(42), IntLiteral(7)))) + finiteArray(Seq(IntLiteral(2), IntLiteral(42), IntLiteral(7)))) } test("eval variable length of array") { @@ -170,8 +170,7 @@ class DefaultEvaluatorTests extends leon.test.LeonTestSuite { expectSuccessful( defaultEvaluator.eval( ArrayLength( - FiniteArray(Map(), Some(IntLiteral(12)), Variable(id)) - .setType(ArrayType(Int32Type))), + finiteArray(Map[Int, Expr](), Some(IntLiteral(12), Variable(id)), Int32Type)), Map(id -> IntLiteral(27))), IntLiteral(27)) } @@ -180,9 +179,9 @@ class DefaultEvaluatorTests extends leon.test.LeonTestSuite { val id = FreshIdentifier("id").setType(Int32Type) expectSuccessful( defaultEvaluator.eval( - FiniteArray(Map(), Some(Variable(id)), IntLiteral(7)).setType(ArrayType(Int32Type)), + finiteArray(Map[Int, Expr](), Some(Variable(id), IntLiteral(7)), Int32Type), Map(id -> IntLiteral(27))), - FiniteArray(Map(), Some(IntLiteral(27)), IntLiteral(7))) + finiteArray(Map[Int, Expr](), Some(IntLiteral(27), IntLiteral(7)), Int32Type)) } } diff --git a/src/test/scala/leon/test/evaluators/EvaluatorsTests.scala b/src/test/scala/leon/test/evaluators/EvaluatorsTests.scala index 5a0599b41..22cbcf472 100644 --- a/src/test/scala/leon/test/evaluators/EvaluatorsTests.scala +++ b/src/test/scala/leon/test/evaluators/EvaluatorsTests.scala @@ -13,6 +13,8 @@ import leon.purescala.Definitions._ import leon.purescala.Trees._ import leon.purescala.DefOps._ import leon.purescala.TypeTrees._ +import leon.purescala.Extractors._ +import leon.purescala.Constructors._ class EvaluatorsTests extends leon.test.LeonTestSuite { private implicit lazy val leonContext = testContext @@ -322,9 +324,9 @@ class EvaluatorsTests extends leon.test.LeonTestSuite { val nil = mkCaseClass("Nil") val cons12 = mkCaseClass("Cons", IL(1), mkCaseClass("Cons", IL(2), mkCaseClass("Nil"))) - val semp = FiniteSet(Set()).setType(SetType(Int32Type)) - val s123 = FiniteSet(Set(IL(1), IL(2), IL(3))).setType(SetType(Int32Type)) - val s246 = FiniteSet(Set(IL(2), IL(4), IL(6))).setType(SetType(Int32Type)) + val semp = EmptySet(Int32Type) + val s123 = NonemptySet(Set(IL(1), IL(2), IL(3))) + val s246 = NonemptySet(Set(IL(2), IL(4), IL(6))) for(e <- evaluators) { checkSetComp(e, mkCall("finite"), Set(1, 2, 3)) @@ -355,7 +357,7 @@ class EvaluatorsTests extends leon.test.LeonTestSuite { | case PCons(f,s,xs) => toMap(xs).updated(f, s) |} | - |def finite0() : Map[Int,Int] = Map() + |def finite0() : Map[Int,Int] = Map[Int, Int]() |def finite1() : Map[Int,Int] = Map(1 -> 2) |def finite2() : Map[Int,Int] = Map(2 -> 3, 1 -> 2) |def finite3() : Map[Int,Int] = finite1().updated(2, 3) @@ -394,8 +396,8 @@ class EvaluatorsTests extends leon.test.LeonTestSuite { implicit val progs = parseString(p) val evaluators = prepareEvaluators - val ba = FiniteArray(Seq(T, F)).setType(ArrayType(BooleanType)) - val ia = FiniteArray(Seq(IL(41), IL(42), IL(43))).setType(ArrayType(Int32Type)) + val ba = finiteArray(Seq(T, F)) + val ia = finiteArray(Seq(IL(41), IL(42), IL(43))) for(e <- evaluators) { checkComp(e, mkCall("boolArrayRead", ba, IL(0)), T) -- GitLab