diff --git a/src/main/java/leon/codegen/runtime/FiniteLambda.java b/src/main/java/leon/codegen/runtime/FiniteLambda.java new file mode 100644 index 0000000000000000000000000000000000000000..b9b8f0d72eea3542419f56ab225037add8146c8d --- /dev/null +++ b/src/main/java/leon/codegen/runtime/FiniteLambda.java @@ -0,0 +1,27 @@ +package leon.codegen.runtime; + +import java.util.HashMap; + +public final class FiniteLambda extends Lambda { + private final HashMap<Tuple, Object> _underlying = new HashMap<Tuple, Object>(); + private final Object dflt; + + public FiniteLambda(Object dflt) { + super(); + this.dflt = dflt; + } + + public void add(Tuple key, Object value) { + _underlying.put(key, value); + } + + @Override + public Object apply(Object[] args) { + Tuple tuple = new Tuple(args); + if (_underlying.containsKey(tuple)) { + return _underlying.get(tuple); + } else { + return dflt; + } + } +} diff --git a/src/main/java/leon/codegen/runtime/Lambda.java b/src/main/java/leon/codegen/runtime/Lambda.java new file mode 100644 index 0000000000000000000000000000000000000000..5a6d6d8ea868f9a792db89c832e293355f0c451f --- /dev/null +++ b/src/main/java/leon/codegen/runtime/Lambda.java @@ -0,0 +1,5 @@ +package leon.codegen.runtime; + +public abstract class Lambda { + public abstract Object apply(Object[] args); +} diff --git a/src/main/scala/leon/codegen/CodeGeneration.scala b/src/main/scala/leon/codegen/CodeGeneration.scala index 7712ab0853b191cacc9ada6068c9db22c9e44d85..f8f6b67c597465204c43512e0c241c0cbed5f491 100644 --- a/src/main/scala/leon/codegen/CodeGeneration.scala +++ b/src/main/scala/leon/codegen/CodeGeneration.scala @@ -24,27 +24,36 @@ trait CodeGeneration { * vars is a mapping from local variables/ parameters to the offset of the respective JVM local register * isStatic signifies if the current method is static (a function, in Leon terms) */ - case class Locals(vars: Map[Identifier, Int], private val isStatic : Boolean ) { + case class Locals( + vars : Map[Identifier, Int], + args : Map[Identifier, Int], + closures : Map[Identifier, (String,String,String)], + private val isStatic : Boolean + ) { /** Fetches the offset of a local variable/ parameter from its identifier */ def varToLocal(v: Identifier): Option[Int] = vars.get(v) + def varToArg(v: Identifier): Option[Int] = args.get(v) + + def varToClosure(v: Identifier): Option[(String,String,String)] = closures.get(v) + /** Adds some extra variables to the mapping */ - def withVars(newVars: Map[Identifier, Int]) = { - Locals(vars ++ newVars, isStatic) - } + def withVars(newVars: Map[Identifier, Int]) = Locals(vars ++ newVars, args, closures, isStatic) /** Adds an extra variable to the mapping */ - def withVar(nv: (Identifier, Int)) = { - Locals(vars + nv, isStatic) - } - + def withVar(nv: (Identifier, Int)) = Locals(vars + nv, args, closures, isStatic) + + def withArgs(newArgs: Map[Identifier, Int]) = Locals(vars, args ++ newArgs, closures, isStatic) + + def withClosures(newClosures: Map[Identifier,(String,String,String)]) = Locals(vars, args, closures ++ newClosures, isStatic) + /** The index of the monitor object in this function */ def monitorIndex = if (isStatic) 0 else 1 } object NoLocals { /** Make a $Locals object without any local variables */ - def apply(isStatic : Boolean) = new Locals(Map(), isStatic) + def apply(isStatic : Boolean) = new Locals(Map(), Map(), Map(), isStatic) } private[codegen] val BoxedIntClass = "java/lang/Integer" @@ -56,6 +65,7 @@ trait CodeGeneration { private[codegen] val SetClass = "leon/codegen/runtime/Set" private[codegen] val MapClass = "leon/codegen/runtime/Map" private[codegen] val CaseClassClass = "leon/codegen/runtime/CaseClass" + private[codegen] val LambdaClass = "leon/codegen/runtime/Lambda" private[codegen] val ErrorClass = "leon/codegen/runtime/LeonCodeGenRuntimeException" private[codegen] val ImpossibleEvaluationClass = "leon/codegen/runtime/LeonCodeGenEvaluationException" private[codegen] val HashingClass = "leon/codegen/runtime/LeonCodeGenRuntimeHashing" @@ -90,6 +100,9 @@ trait CodeGeneration { case _ : MapType => "L" + MapClass + ";" + case _ : FunctionType => + "L" + LambdaClass + ";" + case ArrayType(base) => "[" + typeToJVM(base) @@ -173,13 +186,13 @@ trait CodeGeneration { ch << ALoad(paramsOffset-1) << InvokeVirtual(MonitorClass, "onInvoke", "()V") } - mkExpr(exprToCompile, ch)(Locals(newMapping, isStatic)) + mkExpr(exprToCompile, ch)(Locals(newMapping, Map.empty, Map.empty, isStatic)) funDef.returnType match { case Int32Type | BooleanType | UnitType => ch << IRETURN - case _ : ClassType | _ : TupleType | _ : SetType | _ : MapType | _ : ArrayType | _: TypeParameter => + case _ : ClassType | _ : TupleType | _ : SetType | _ : MapType | _ : ArrayType | _ : FunctionType | _ : TypeParameter => ch << ARETURN case other => @@ -192,12 +205,7 @@ trait CodeGeneration { private[codegen] def mkExpr(e: Expr, ch: CodeHandler, canDelegateToMkBranch: Boolean = true)(implicit locals: Locals) { e match { case Variable(id) => - val slot = slotFor(id) - val instr = id.getType match { - case Int32Type | CharType | BooleanType | UnitType => ILoad(slot) - case _ => ALoad(slot) - } - ch << instr + load(id, ch) case Assert(cond, oerr, body) => mkExpr(IfExpr(Not(cond), Error(oerr.getOrElse("Assertion failed @"+e.getPos)), body), ch) @@ -479,6 +487,97 @@ trait CodeGeneration { mkUnbox(tpe, ch) case _ => } + + case app @ Application(caller, args) => + mkExpr(caller, ch) + ch << Ldc(args.size) << NewArray("java/lang/Object") + for ((arg,i) <- args.zipWithIndex) { + ch << DUP << Ldc(i) + mkBoxedExpr(arg, ch) + ch << AASTORE + } + + ch << InvokeVirtual(LambdaClass, "apply", "([Ljava/lang/Object;)Ljava/lang/Object;") + mkUnbox(app.getType, ch) + + case l @ Lambda(args, body) => + val afName = "Leon$CodeGen$Lambda$" + CompilationUnit.nextLambdaId + + val cf = new ClassFile(afName, Some(LambdaClass)) + + cf.setFlags(( + CLASS_ACC_SUPER | + CLASS_ACC_PUBLIC | + CLASS_ACC_FINAL + ).asInstanceOf[U2]) + + val closures = purescala.TreeOps.variablesOf(l).toSeq.sortBy(_.uniqueName) + val closureTypes = closures.map(id => id.name -> typeToJVM(id.getType)) + + if (closureTypes.isEmpty) { + cf.addDefaultConstructor + } else { + for ((nme, jvmt) <- closureTypes) { + val fh = cf.addField(jvmt, nme) + fh.setFlags(( + FIELD_ACC_PUBLIC | + FIELD_ACC_FINAL + ).asInstanceOf[U2]) + } + + val cch = cf.addConstructor(closureTypes.map(_._2).toList).codeHandler + + cch << ALoad(0) + cch << InvokeSpecial(LambdaClass, constructorName, "()V") + + var c = 1 + for ((nme, jvmt) <- closureTypes) { + cch << ALoad(0) + cch << (jvmt match { + case "I" | "Z" => ILoad(c) + case _ => ALoad(c) + }) + cch << PutField(afName, nme, jvmt) + c += 1 + } + + cch << RETURN + cch.freeze + } + + locally { + val argTypes = args.map(arg => typeToJVM(arg.tpe)) + + val apm = cf.addMethod("Ljava/lang/Object;", "apply", "[Ljava/lang/Object;") + + apm.setFlags(( + METHOD_ACC_PUBLIC | + METHOD_ACC_FINAL + ).asInstanceOf[U2]) + + val argMapping = args.map(_.id).zipWithIndex.toMap + val closureMapping = (closures zip closureTypes).map { case (id, (name, tpe)) => id -> (afName, name, tpe) }.toMap + + val newLocals = locals.withArgs(argMapping).withClosures(closureMapping) + + val apch = apm.codeHandler + + mkBoxedExpr(body, apch)(newLocals) + + apch << ARETURN + + apch.freeze + } + + loader.register(cf) + + val consSig = "(" + closures.map(id => typeToJVM(id.getType)).mkString("") + ")V" + + ch << New(afName) << DUP + for (a <- closures) { + mkExpr(Variable(a), ch) + } + ch << InvokeSpecial(afName, constructorName, consSig) // Arithmetic case Plus(l, r) => @@ -604,7 +703,7 @@ trait CodeGeneration { mkBranch(b, al, fl, ch, canDelegateToMkExpr = false) ch << Label(fl) << POP << Ldc(0) << Label(al) - case _ => throw CompilationException("Unsupported expr. : " + e) + case _ => throw CompilationException("Unsupported expr " + e + " : " + e.getClass) } } @@ -682,7 +781,10 @@ trait CodeGeneration { case mt : MapType => ch << CheckCast(MapClass) - case tp : TypeParameter => + case ft : FunctionType => + ch << CheckCast(LambdaClass) + + case tp : TypeParameter => case tp : ArrayType => ch << CheckCast(BoxedArrayClass) << InvokeVirtual(BoxedArrayClass, "arrayValue", "()%s".format(typeToJVM(tp))) @@ -720,7 +822,8 @@ trait CodeGeneration { mkBranch(c, elze, thenn, ch) case Variable(b) => - ch << ILoad(slotFor(b)) << IfEq(elze) << Goto(thenn) + load(b, ch) + ch << IfEq(elze) << Goto(thenn) case Equals(l,r) => mkExpr(l, ch) @@ -767,13 +870,27 @@ trait CodeGeneration { } } - private[codegen] def slotFor(id: Identifier)(implicit locals: Locals) : Int = { - locals.varToLocal(id).getOrElse { - throw CompilationException("Unknown variable: " + id) + private def load(id: Identifier, ch: CodeHandler)(implicit locals: Locals): Unit = { + locals.varToArg(id) match { + case Some(slot) => + ch << ALoad(1) << Ldc(slot) << AALOAD + mkUnbox(id.getType, ch) + case None => locals.varToClosure(id) match { + case Some((afName, nme, tpe)) => + ch << ALoad(0) << GetField(afName, nme, tpe) + case None => locals.varToLocal(id) match { + case Some(slot) => + val instr = id.getType match { + case Int32Type | CharType | BooleanType | UnitType => ILoad(slot) + case _ => ALoad(slot) + } + ch << instr + case None => throw CompilationException("Unknown variable : " + id) + } + } } } - /** * Compiles a lazy field $lzy, owned by the module/ class $owner. * diff --git a/src/main/scala/leon/codegen/CompilationUnit.scala b/src/main/scala/leon/codegen/CompilationUnit.scala index 9be718da8915d89cd5dfd86cf4af27ba7c85b80e..da2c32b0ff8c71d684e67df3a38d466473cb8846 100644 --- a/src/main/scala/leon/codegen/CompilationUnit.scala +++ b/src/main/scala/leon/codegen/CompilationUnit.scala @@ -163,6 +163,18 @@ class CompilationUnit(val ctx: LeonContext, } m + case f @ FiniteLambda(dflt, els) => + val l = new leon.codegen.runtime.FiniteLambda(exprToJVM(dflt)) + for ((k,v) <- els) { + val jvmK = if (f.fixedType.from.size == 1) { + exprToJVM(Tuple(Seq(k))) + } else { + exprToJVM(k) + } + l.add(jvmK.asInstanceOf[leon.codegen.runtime.Tuple], exprToJVM(v)) + } + l + // Just slightly overkill... case _ => compileExpression(e, Seq()).evalToJVM(Seq(),monitor) @@ -270,13 +282,13 @@ class CompilationUnit(val ctx: LeonContext, val exprToCompile = purescala.TreeOps.matchToIfThenElse(e) - mkExpr(e, ch)(Locals(newMapping, true)) + mkExpr(e, ch)(Locals(newMapping, Map.empty, Map.empty, true)) e.getType match { case Int32Type | BooleanType => ch << IRETURN - case UnitType | _: TupleType | _: SetType | _: MapType | _: AbstractClassType | _: CaseClassType | _: ArrayType | _: TypeParameter => + case UnitType | _: TupleType | _: SetType | _: MapType | _: AbstractClassType | _: CaseClassType | _: ArrayType | _: FunctionType | _: TypeParameter => ch << ARETURN case other => @@ -416,9 +428,15 @@ class CompilationUnit(val ctx: LeonContext, object CompilationUnit { private var _nextExprId = 0 - private def nextExprId = synchronized { + private[codegen] def nextExprId = synchronized { _nextExprId += 1 _nextExprId } + + private var _nextLambdaId = 0 + private[codegen] def nextLambdaId = synchronized { + _nextLambdaId += 1 + _nextLambdaId + } } diff --git a/src/main/scala/leon/datagen/VanuatooDataGen.scala b/src/main/scala/leon/datagen/VanuatooDataGen.scala index e6e9fe451fda7bc6af671817a49fb62232cf500b..0173e39dd4a628e7eae5f1bd88e96469556112d7 100644 --- a/src/main/scala/leon/datagen/VanuatooDataGen.scala +++ b/src/main/scala/leon/datagen/VanuatooDataGen.scala @@ -78,6 +78,24 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { cs }) + case ft @ FunctionType(from, to) => + constructors.getOrElse(ft, { + val cs = for (size <- List(1, 2, 3, 5)) yield { + val subs = (1 to size).flatMap(_ => from :+ to).toList + Constructor[Expr, TypeTree](subs, ft, { s => + val args = from.map(tpe => FreshIdentifier("x", true).setType(tpe)) + val argsTuple = Tuple(args.map(_.toVariable)) + val grouped = s.grouped(from.size + 1).toSeq + val body = grouped.init.foldRight(grouped.last.last) { case (t, elze) => + IfExpr(Equals(argsTuple, Tuple(t.init)), t.last, elze) + } + Lambda(args.map(id => ValDef(id, id.getType)), body) + }, ft.toString + "@" + size) + } + constructors += ft -> cs + cs + }) + case tp: TypeParameter => constructors.getOrElse(tp, { val cs = for (i <- List(1, 2)) yield { diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala index 58a86ac6a3ee57f7cd477ed877e06489793c3712..716193e0b914051be3829d1eb88fb66f1542a7e0 100644 --- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala +++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala @@ -79,6 +79,16 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int throw EvalError("No value for identifier " + id.name + " in mapping.") } + case Application(caller, args) => + e(caller) match { + case Lambda(params, body) => + val newArgs = args.map(e) + val mapping = (params.map(_.id) zip newArgs).toMap + e(body)(rctx.withVars(mapping), gctx) + case f => + throw EvalError("Cannot apply non-lambda function " + f) + } + case Tuple(ts) => val tsRec = ts.map(e) Tuple(tsRec) @@ -316,6 +326,7 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int case i @ IntLiteral(_) => i case b @ BooleanLiteral(_) => b case u @ UnitLiteral() => u + case l @ Lambda(_, _) => l case f @ ArrayFill(length, default) => val rDefault = e(default) diff --git a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala index a5c40f034a02411c8f52578d65f47470c28cf9d7..c0eb1cc9a8b55896dac73e1ed16894dfb17ab0b8 100644 --- a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala +++ b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala @@ -31,6 +31,10 @@ trait ASTExtractors { protected lazy val arraySym = classFromName("scala.Array") protected lazy val someClassSym = classFromName("scala.Some") protected lazy val function1TraitSym = classFromName("scala.Function1") + protected lazy val function2TraitSym = classFromName("scala.Function2") + protected lazy val function3TraitSym = classFromName("scala.Function3") + protected lazy val function4TraitSym = classFromName("scala.Function4") + protected lazy val function5TraitSym = classFromName("scala.Function5") protected lazy val byNameSym = classFromName("scala.<byname>") def isTuple2(sym : Symbol) : Boolean = sym == tuple2Sym @@ -65,10 +69,11 @@ trait ASTExtractors { sym == optionClassSym || sym == someClassSym } - def isFunction1TraitSym(sym : Symbol) : Boolean = { - sym == function1TraitSym - } - + def isFunction1(sym : Symbol) : Boolean = sym == function1TraitSym + def isFunction2(sym : Symbol) : Boolean = sym == function2TraitSym + def isFunction3(sym : Symbol) : Boolean = sym == function3TraitSym + def isFunction4(sym : Symbol) : Boolean = sym == function4TraitSym + def isFunction5(sym : Symbol) : Boolean = sym == function5TraitSym protected lazy val multisetTraitSym = try { classFromName("scala.collection.immutable.Multiset") @@ -273,7 +278,7 @@ trait ASTExtractors { case _ => None } } - + object ExLazyAccessorFunction { def unapply(dd: DefDef): Option[(Symbol, Type, Tree)] = dd match { case DefDef(_, name, tparams, vparamss, tpt, rhs) if( @@ -401,6 +406,23 @@ trait ASTExtractors { } } + object ExLambdaExpression { + def unapply(tree: Function) : Option[(Seq[ValDef], Tree)] = tree match { + case Function(vds, body) => Some((vds, body)) + case _ => None + } + } + + object ExForallExpression { + def unapply(tree: Apply) : Option[(List[(Tree, Symbol)], Tree)] = tree match { + case a @ Apply( + TypeApply(s @ ExSymbol("leon", "lang", "forall"), types), + Function(vds, predicateBody) :: Nil) => + Some(((types zip vds.map(_.symbol)).toList, predicateBody)) + case _ => None + } + } + object ExArrayUpdated { def unapply(tree: Apply): Option[(Tree,Tree,Tree)] = tree match { case Apply( diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index 1427ff10a6579a754baf2849f2785ff3214ddbc4..e31ffba62a8b57c473f5155bf130f649ad873d66 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -1242,6 +1242,38 @@ trait CodeExtraction extends ASTExtractors { Choose(vars, cBody) + case l @ ExLambdaExpression(args, body) => + val vds = args map { vd => + val aTpe = extractType(vd.tpt) + val newID = FreshIdentifier(vd.symbol.name.toString).setType(aTpe) + owners += (newID -> None) + LeonValDef(newID, aTpe) + } + + val newVars = (args zip vds).map { case (vd, lvd) => + vd.symbol -> (() => lvd.toVariable) + } + + val exBody = extractTree(body)(dctx.withNewVars(newVars)) + + Lambda(vds, exBody) + + case f @ ExForallExpression(args, body) => + val vds = args map { case (tpt, sym) => + val aTpe = extractType(tpt) + val newID = FreshIdentifier(sym.name.toString).setType(aTpe) + owners += (newID -> None) + LeonValDef(newID, aTpe) + } + + val newVars = (args zip vds) map { case ((_, sym), vd) => + sym -> (() => vd.toVariable) + } + + val exBody = extractTree(body)(dctx.withNewVars(newVars)) + + Forall(vds, exBody) + case ExCaseClassConstruction(tpt, args) => extractType(tpt) match { case cct: CaseClassType => @@ -1361,7 +1393,6 @@ trait CodeExtraction extends ASTExtractors { } } - case pm @ ExPatternMatching(sel, cses) => val rs = extractTree(sel) val rc = cses.map(extractMatchCase(_)) @@ -1437,6 +1468,9 @@ trait CodeExtraction extends ASTExtractors { MethodInvocation(rec, cd, fd.typed(newTps), args) + case (IsTyped(rec, ft: FunctionType), _, args) => + Application(rec, args) + case (IsTyped(rec, cct: CaseClassType), name, Nil) if cct.fields.exists(_.id.name == name) => val fieldID = cct.fields.find(_.id.name == name).get.id @@ -1616,6 +1650,21 @@ trait CodeExtraction extends ASTExtractors { case TypeRef(_, sym, btt :: Nil) if isArrayClassSym(sym) => ArrayType(extractType(btt)) + case TypeRef(_, sym, List(f1,to)) if isFunction1(sym) => + FunctionType(Seq(extractType(f1)), extractType(to)) + + case TypeRef(_, sym, List(f1,f2,to)) if isFunction2(sym) => + FunctionType(Seq(extractType(f1),extractType(f2)), extractType(to)) + + case TypeRef(_, sym, List(f1,f2,f3,to)) if isFunction3(sym) => + FunctionType(Seq(extractType(f1),extractType(f2),extractType(f3)), extractType(to)) + + case TypeRef(_, sym, List(f1,f2,f3,f4,to)) if isFunction4(sym) => + FunctionType(Seq(extractType(f1),extractType(f2),extractType(f3),extractType(f4)), extractType(to)) + + case TypeRef(_, sym, List(f1,f2,f3,f4,f5,to)) if isFunction5(sym) => + FunctionType(Seq(extractType(f1),extractType(f2),extractType(f3),extractType(f4),extractType(f5)), extractType(to)) + case TypeRef(_, sym, tps) if isByNameSym(sym) => extractType(tps.head) diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala index db86711715825840bc02ab8e1fba5c6470135f67..bf5efce46c23be9074ac4c08c13ced0a7cc51e16 100644 --- a/src/main/scala/leon/purescala/Extractors.scala +++ b/src/main/scala/leon/purescala/Extractors.scala @@ -27,6 +27,8 @@ object Extractors { case ArrayLength(a) => Some((a, ArrayLength)) case ArrayClone(a) => Some((a, ArrayClone)) case ArrayMake(t) => Some((t, ArrayMake)) + case Lambda(args, body) => Some((body, Lambda(args, _))) + case Forall(args, body) => Some((body, Forall(args, _))) case (ue: UnaryExtractable) => ue.extract case _ => None } @@ -85,6 +87,7 @@ object Extractors { def unapply(expr: Expr) : Option[(Seq[Expr],(Seq[Expr])=>Expr)] = expr match { case fi @ FunctionInvocation(fd, args) => Some((args, (as => FunctionInvocation(fd, as).setPos(fi)))) case mi @ MethodInvocation(rec, cd, tfd, args) => Some((rec +: args, (as => MethodInvocation(as.head, cd, tfd, as.tail).setPos(mi)))) + case fa @ Application(caller, args) => Some((caller +: args), (as => Application(as.head, as.tail).setPos(fa))) case CaseClass(cd, args) => Some((args, CaseClass(cd, _))) case And(args) => Some((args, And.apply)) case Or(args) => Some((args, Or.apply)) diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index bc9e2fd88f54c7627d40ac368fba25e08bfb1f94..422307e924cb623f3d0f88198a1c48d021f03f1f 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -315,6 +315,12 @@ class PrettyPrinter(opts: PrinterOptions, val sb: StringBuffer = new StringBuffe if (tfd.fd.isRealFunction) p"($args)" + case Application(caller, args) => + p"$caller($args)" + + case Lambda(args, body) => + optP { p"($args) => $body" } + case Plus(l,r) => optP { p"$l + $r" } case Minus(l,r) => optP { p"$l - $r" } case Times(l,r) => optP { p"$l * $r" } @@ -364,6 +370,12 @@ class PrettyPrinter(opts: PrinterOptions, val sb: StringBuffer = new StringBuffe case FiniteMap(rs) => p"{$rs}" + case IfExpr(c, t, ie : IfExpr) => + optP { + p"""|if ($c) { + | $t + |} else $ie""" + } case IfExpr(c, t, e) => optP { diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala index ee74c46b6053943697b74753f07cb3b1e2a49671..f1ce4fffda0a1071fa984e7d4bb0c18cb6deee48 100644 --- a/src/main/scala/leon/purescala/TreeOps.scala +++ b/src/main/scala/leon/purescala/TreeOps.scala @@ -350,6 +350,8 @@ object TreeOps { case Let(i,_,_) => subvs - i case Choose(is,_) => subvs -- is case MatchExpr(_, cses) => subvs -- (cses.map(_.pattern.binders).foldLeft(Set[Identifier]())((a, b) => a ++ b)) + case Lambda(args, body) => subvs -- args.map(_.id) + case Forall(args, body) => subvs -- args.map(_.id) case _ => subvs } })(expr) @@ -823,6 +825,10 @@ object TreeOps { case tp: TypeParameter => GenericValue(tp, 0) + case FunctionType(from, to) => + val args = from.map(tpe => ValDef(FreshIdentifier("x", true).setType(tpe), tpe)) + Lambda(args, simplestValue(to)) + case _ => throw new Exception("I can't choose simplest value for type " + tpe) } @@ -1989,6 +1995,157 @@ object TreeOps { } } + def functionAppsOf(expr: Expr): Set[Application] = { + collect[Application] { + case f: Application => Set(f) + case _ => Set() + }(expr) + } + + private val lambdaArgumentsCache = new TrieMap[TypeTree,Seq[Identifier]] + def lambdaArguments(tpe: TypeTree): Seq[Identifier] = lambdaArgumentsCache.get(tpe) match { + case Some(ids) => ids + case None => + val seq = tpe match { + case FunctionType(argTypes, returnType) => + argTypes.map(FreshIdentifier("x", true).setType(_)) ++ lambdaArguments(returnType) + case _ => Seq() + } + lambdaArgumentsCache(tpe) = seq + seq + } + + def functionApplication(expr: Expr, args: Seq[Expr]): Expr = expr.getType match { + case FunctionType(argTypes, returnType) => + val (currentArgs, nextArgs) = args.splitAt(argTypes.size) + val application = Application(expr, currentArgs) + functionApplication(application, nextArgs) + case tpe => + assert(args.isEmpty && !tpe.isInstanceOf[FunctionType]) + expr + } + + def createLambda(expr: Expr, args: Seq[Identifier]): Expr = expr.getType match { + case FunctionType(argTypes, returnType) => + val (currentArgs, nextArgs) = args.splitAt(argTypes.size) + val application = Application(expr, currentArgs.map(_.toVariable)) + Lambda(currentArgs.map(id => ValDef(id, id.getType)), createLambda(application, nextArgs)) + case tpe => + assert(args.isEmpty && !tpe.isInstanceOf[FunctionType]) + expr + } + + def lambdaTransform(expr: Expr) : Expr = { + + def hoistHOIte(expr: Expr) = { + def transform(expr: Expr): Option[Expr] = expr match { + case uop @ UnaryOperator(ife @ IfExpr(c, t, e), op) if ife.getType.isInstanceOf[FunctionType] => + Some(IfExpr(c, op(t).setType(uop.getType), op(e).setType(uop.getType)).setType(uop.getType)) + case bop @ BinaryOperator(ife @ IfExpr(c, t, e), t2, op) if ife.getType.isInstanceOf[FunctionType] => + Some(IfExpr(c, op(t, t2).setType(bop.getType), op(e, t2).setType(bop.getType)).setType(bop.getType)) + case bop @ BinaryOperator(t1, ife @ IfExpr(c, t, e), op) if ife.getType.isInstanceOf[FunctionType] => + Some(IfExpr(c, op(t1, t).setType(bop.getType), op(t1, e).setType(bop.getType)).setType(bop.getType)) + case nop @ NAryOperator(ts, op) => { + val iteIndex = ts.indexWhere { + case ife @ IfExpr(_, _, _) if ife.getType.isInstanceOf[FunctionType] => true + case _ => false + } + if(iteIndex == -1) None else { + val (beforeIte, startIte) = ts.splitAt(iteIndex) + val afterIte = startIte.tail + val IfExpr(c, t, e) = startIte.head + Some(IfExpr(c, + op(beforeIte ++ Seq(t) ++ afterIte).setType(nop.getType), + op(beforeIte ++ Seq(e) ++ afterIte).setType(nop.getType) + ).setType(nop.getType)) + } + } + case _ => None + } + + fixpoint(postMap(transform))(expr) + } + + def expandHOLets(expr: Expr) : Expr = { + def rec(ex: Expr, s: Map[Identifier,Expr]) : Expr = ex match { + case v @ Variable(id) if s.isDefinedAt(id) => rec(s(id), s) + case l @ Let(i,e,b) => + if (i.getType.isInstanceOf[FunctionType]) rec(b, s + (i -> rec(e, s))) + else Let(i, rec(e,s), rec(b,s)) + case i @ IfExpr(t1,t2,t3) => IfExpr(rec(t1,s), rec(t2,s), rec(t3,s)).setType(i.getType) + case m @ MatchExpr(scrut,cses) => MatchExpr(rec(scrut,s), cses.map(inCase(_, s))).setType(m.getType).setPos(m) + case n @ NAryOperator(args, recons) => { + var change = false + val rargs = args.map(a => { + val ra = rec(a, s) + if (ra != a) { + change = true + ra + } else { + a + } + }) + if (change) recons(rargs).setType(n.getType) + else n + } + case b @ BinaryOperator(t1,t2,recons) => { + val r1 = rec(t1, s) + val r2 = rec(t2, s) + if (r1 != t1 || r2 != t2) recons(r1, r2).setType(b.getType) + else b + } + case u @ UnaryOperator(t,recons) => { + val r = rec(t, s) + if (r != t) recons(r).setType(u.getType) + else u + } + case t: Terminal => t + case unhandled => scala.sys.error("Unhandled case in expandHOLets: " + unhandled) + } + + def inCase(cse: MatchCase, s: Map[Identifier,Expr]) : MatchCase = cse match { + case SimpleCase(pat, rhs) => SimpleCase(pat, rec(rhs, s)) + case GuardedCase(pat, guard, rhs) => GuardedCase(pat, rec(guard, s), rec(rhs, s)) + } + + rec(expr, Map.empty) + } + + def extractToLambda(expr: Expr) = { + def extract(expr: Expr, build: Boolean) = + if (build) createLambda(expr, lambdaArguments(expr.getType)) else expr + + def rec(expr: Expr, build: Boolean): Expr = expr match { + case Application(caller, args) => + val newArgs = args.map(rec(_, true)) + val newCaller = rec(caller, false) + extract(Application(newCaller, newArgs), build) + case FunctionInvocation(fd, args) => + val newArgs = args.map(rec(_, true)) + extract(FunctionInvocation(fd, newArgs), build) + case l @ Lambda(args, body) => l + case NAryOperator(es, recons) => recons(es.map(rec(_, build))) + case BinaryOperator(e1, e2, recons) => recons(rec(e1, build), rec(e2, build)) + case UnaryOperator(e, recons) => recons(rec(e, build)) + case t: Terminal => t + } + + rec(expr, true) + } + + extractToLambda( + hoistHOIte( + expandHOLets( + simplifyLets( + matchToIfThenElse( + expr + ) + ) + ) + ) + ) + } + /** * Used to lift closures introduced by synthesis. Closures already define all * the necessary information as arguments, no need to close them. diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala index 60d9155354019590d197ba069141c6250a3770b6..af6d1ad9e84bf8253eda0c39b03ba2e4e477ad24 100644 --- a/src/main/scala/leon/purescala/Trees.scala +++ b/src/main/scala/leon/purescala/Trees.scala @@ -111,6 +111,66 @@ object Trees { } } + case class Application(caller: Expr, args: Seq[Expr]) extends Expr with FixedType { + assert(caller.getType.isInstanceOf[FunctionType]) + val fixedType = caller.getType.asInstanceOf[FunctionType].to + } + + case class Lambda(args: Seq[ValDef], body: Expr) extends Expr with FixedType { + val fixedType = FunctionType(args.map(_.tpe), body.getType) + } + + object FiniteLambda { + def unapply(lambda: Lambda): Option[(Expr, Seq[(Expr, Expr)])] = { + val args = lambda.args.map(_.toVariable) + lazy val argsTuple = if (lambda.args.size > 1) Tuple(args) else args.head + + def rec(body: Expr): Option[(Expr, Seq[(Expr, Expr)])] = body match { + case _ : IntLiteral | _ : BooleanLiteral | _ : GenericValue | _ : Tuple | + _ : CaseClass | _ : FiniteArray | _ : FiniteSet | _ : FiniteMap | _ : Lambda => + Some(body -> Seq.empty) + case IfExpr(Equals(tpArgs, key), expr, elze) if tpArgs == argsTuple => + rec(elze).map { case (dflt, mapping) => dflt -> ((key -> expr) +: mapping) } + case _ => None + } + + rec(lambda.body) + } + + def apply(dflt: Expr, els: Seq[(Expr, Expr)], tpe: FunctionType): Lambda = { + val args = tpe.from.zipWithIndex.map { case (tpe, idx) => + ValDef(FreshIdentifier(s"x${idx + 1}").setType(tpe), tpe) + } + + assert(els.isEmpty || !tpe.from.isEmpty, "Can't provide finite mapping for lambda without parameters") + + lazy val (tupleArgs, tupleKey) = if (tpe.from.size > 1) { + val tpArgs = Tuple(args.map(_.toVariable)) + val key = (x: Expr) => x + (tpArgs, key) + } else { // note that value is lazy, so if tpe.from.size == 0, foldRight will never access (tupleArgs, tupleKey) + val tpArgs = args.head.toVariable + val key = (x: Expr) => { + if (isSubtypeOf(x.getType, tpe.from.head)) x + else if (isSubtypeOf(x.getType, TupleType(tpe.from))) x.asInstanceOf[Tuple].exprs.head + else throw new RuntimeException("Can't determine key tuple state : " + x + " of " + tpe) + } + (tpArgs, key) + } + + val body = els.toSeq.foldRight(dflt) { case ((k, v), elze) => + IfExpr(Equals(tupleArgs, tupleKey(k)), v, elze) + } + + Lambda(args, body) + } + } + + case class Forall(args: Seq[ValDef], body: Expr) extends Expr with FixedType { + assert(body.getType == BooleanType) + val fixedType = BooleanType + } + case class This(ct: ClassType) extends Expr with FixedType with Terminal { val fixedType = ct } diff --git a/src/main/scala/leon/purescala/TypeTreeOps.scala b/src/main/scala/leon/purescala/TypeTreeOps.scala index 2f6ff8c36e1caec52bfa29f5c94e2435356b7f83..fd35346f36b2a3e16758bcf85509396feb09a400 100644 --- a/src/main/scala/leon/purescala/TypeTreeOps.scala +++ b/src/main/scala/leon/purescala/TypeTreeOps.scala @@ -200,6 +200,14 @@ object TypeTreeOps { val newXs = xs.map(id => freshId(id, tpeSub(id.getType))) Choose(newXs, rec(idsMap ++ (xs zip newXs))(pred)).copiedFrom(c) + case l @ Lambda(args, body) => + val newArgs = args.map { arg => + val tpe = tpeSub(arg.tpe) + ValDef(freshId(arg.id, tpe), tpe) + } + val mapping = args.map(_.id) zip newArgs.map(_.id) + Lambda(newArgs, rec(idsMap ++ mapping)(body)).copiedFrom(l) + case m @ MatchExpr(e, cases) => val newTpe = tpeSub(e.getType) diff --git a/src/main/scala/leon/purescala/TypeTrees.scala b/src/main/scala/leon/purescala/TypeTrees.scala index 24249ad2aa3e75ae46595bda5efdad8ef86ee616..16274e0864eac990bae0e2ad102b2c0c56a10bc8 100644 --- a/src/main/scala/leon/purescala/TypeTrees.scala +++ b/src/main/scala/leon/purescala/TypeTrees.scala @@ -77,7 +77,7 @@ object TypeTrees { case class SetType(base: TypeTree) extends TypeTree case class MultisetType(base: TypeTree) extends TypeTree case class MapType(from: TypeTree, to: TypeTree) extends TypeTree - case class FunctionType(from: List[TypeTree], to: TypeTree) extends TypeTree + case class FunctionType(from: Seq[TypeTree], to: TypeTree) extends TypeTree case class ArrayType(base: TypeTree) extends TypeTree sealed abstract class ClassType extends TypeTree { diff --git a/src/main/scala/leon/solvers/SolverFactory.scala b/src/main/scala/leon/solvers/SolverFactory.scala index 7827dfd693859ce555ebb4fd223b154652775f3e..eacfbb77e2ce6c0e63ef11100dd4a16500ec7525 100644 --- a/src/main/scala/leon/solvers/SolverFactory.scala +++ b/src/main/scala/leon/solvers/SolverFactory.scala @@ -19,7 +19,7 @@ object SolverFactory { } } - val definedSolvers = Set("fairz3", "enum", "smt", "smt-z3", "smt-cvc4"); + val definedSolvers = Set("fairz3", "unrollz3", "enum", "smt", "smt-z3", "smt-cvc4"); def getFromSettings[S](ctx: LeonContext, program: Program): SolverFactory[TimeoutSolver] = { import combinators._ @@ -30,6 +30,9 @@ object SolverFactory { case "fairz3" => SolverFactory(() => new FairZ3Solver(ctx, program) with TimeoutSolver) + case "unrollz3" => + SolverFactory(() => new UnrollingSolver(ctx, program, new UninterpretedZ3Solver(ctx, program)) with TimeoutSolver) + case "enum" => SolverFactory(() => new EnumerationSolver(ctx, program) with TimeoutSolver) diff --git a/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala b/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala index 8f4d216f7197cc99296ffadf42672f057ccf6653..60bff5440f400f709d361880e190cba028c09399 100644 --- a/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala +++ b/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala @@ -37,17 +37,20 @@ class UnrollingSolver(val context: LeonContext, program: Program, underlying: In } } - private var lastCheckResult: (Boolean, Option[Boolean], Option[Map[Identifier,Expr]]) = (false, None, None) - private var varsInVC = List[Set[Identifier]](Set()) - private var constraints = List[List[Expr]](Nil) - private var interrupted: Boolean = false + private var lastCheckResult : (Boolean, Option[Boolean], Option[Map[Identifier,Expr]]) = (false, None, None) + + private var varsInVC = List[Set[Identifier]](Set()) + private var frameExpressions = List[List[Expr]](Nil) + + private var interrupted : Boolean = false val reporter = context.reporter def name = "U:"+underlying.name - def free {} - + def free { + underlying.free + } val templateGenerator = new TemplateGenerator(new TemplateEncoder[Expr] { def encodeId(id: Identifier): Expr= { @@ -62,8 +65,11 @@ class UnrollingSolver(val context: LeonContext, program: Program, underlying: In (e: Expr) => replace(substMap, e) } - def not(e: Expr) = Not(e) - def implies(l: Expr, r: Expr) = Implies(l, r) + def mkNot(e: Expr) = Not(e) + def mkOr(es: Expr*) = Or(es) + def mkAnd(es: Expr*) = And(es) + def mkEquals(l: Expr, r: Expr) = Equals(l, r) + def mkImplies(l: Expr, r: Expr) = Implies(l, r) }) val unrollingBank = new UnrollingBank(reporter, templateGenerator) @@ -71,7 +77,10 @@ class UnrollingSolver(val context: LeonContext, program: Program, underlying: In val solver = underlying def assertCnstr(expression: Expr) { + frameExpressions = (expression :: frameExpressions.head) :: frameExpressions.tail + val freeIds = variablesOf(expression) + varsInVC = (varsInVC.head ++ freeIds) :: varsInVC.tail val freeVars = freeIds.map(_.toVariable: Expr) @@ -82,24 +91,20 @@ class UnrollingSolver(val context: LeonContext, program: Program, underlying: In for (cl <- newClauses) { solver.assertCnstr(cl) } - - varsInVC = (varsInVC.head ++ freeIds) :: varsInVC.tail - constraints = (constraints.head ++ newClauses) :: constraints.tail } - def push() { unrollingBank.push() solver.push() varsInVC = Set[Identifier]() :: varsInVC - constraints = Nil :: constraints + frameExpressions = Nil :: frameExpressions } def pop(lvl: Int = 1) { unrollingBank.pop(lvl) solver.pop(lvl) varsInVC = varsInVC.drop(lvl) - constraints = constraints.drop(lvl) + frameExpressions = frameExpressions.drop(lvl) } def check: Option[Boolean] = { @@ -115,9 +120,10 @@ class UnrollingSolver(val context: LeonContext, program: Program, underlying: In def isValidModel(model: Map[Identifier, Expr], silenceErrors: Boolean = false): Boolean = { import EvaluationResults._ - val expr = And(constraints.flatten) + val expr = And(frameExpressions.flatten) + val allVars = varsInVC.flatten.toSet - val fullModel = variablesOf(expr).map(v => v -> model.getOrElse(v, simplestValue(v.getType))).toMap + val fullModel = allVars.map(v => v -> model.getOrElse(v, simplestValue(v.getType))).toMap evaluator.eval(expr, fullModel) match { case Successful(BooleanLiteral(true)) => @@ -170,7 +176,6 @@ class UnrollingSolver(val context: LeonContext, program: Program, underlying: In case Some(true) => // SAT val model = solver.getModel solver.pop() - foundAnswer(Some(true), Some(model)) case Some(false) if !unrollingBank.canUnroll => diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala index 7f14022d6186fa9060143267f5f3325cef833a54..185d8646f098e54f12f649e13b203c11422d9ea4 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala @@ -54,11 +54,17 @@ trait SMTLIBCVC4Target extends SMTLIBTarget { case (FunctionApplication(SimpleSymbol(SSymbol("__array_store_all__")), Seq(_, elem)), RawArrayType(k,v)) => RawArrayValue(k, Map(), fromSMT(elem, v)) + case (FunctionApplication(SimpleSymbol(SSymbol("__array_store_all__")), Seq(_, elem)), ft @ FunctionType(from,to)) => + FiniteLambda(fromSMT(elem, to), Seq.empty, ft) + case (FunctionApplication(SimpleSymbol(SSymbol("store")), Seq(arr, key, elem)), RawArrayType(k,v)) => val RawArrayValue(_, elems, base) = fromSMT(arr, tpe) - RawArrayValue(k, elems + (fromSMT(key, k) -> fromSMT(elem, v)), base) + case (FunctionApplication(SimpleSymbol(SSymbol("store")), Seq(arr, key, elem)), ft @ FunctionType(from,to)) => + val FiniteLambda(dflt, mapping) = fromSMT(arr, tpe) + FiniteLambda(dflt, mapping :+ (fromSMT(key, TupleType(from)) -> fromSMT(elem, to)), ft) + case (FunctionApplication(SimpleSymbol(SSymbol("singleton")), elems), SetType(base)) => FiniteSet(elems.map(fromSMT(_, base)).toSet).setType(tpe) diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala index 5c0276201141d5bd54303564a51fae4e80e6b1c1..96b8f34dd31a64f64af3c21ee469de82720e5382 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala @@ -38,7 +38,7 @@ abstract class SMTLIBSolver(val context: LeonContext, override def free() = { interpreter.free() - out.close + reporter.ifDebug { _ => out.close } } } diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala index 53264c59c1f1b43b8a556d2811df9b3e9a5f6fc2..536ad9e009ba6aee3eb622bf1133b1379d5421f6 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala @@ -10,7 +10,7 @@ import Extractors._ import TreeOps._ import TypeTrees._ import Definitions._ -import utils.Bijection +import utils.IncrementalBijection import _root_.smtlib.common._ import _root_.smtlib.printer.{RecursivePrinter => SMTPrinter} @@ -48,12 +48,12 @@ trait SMTLIBTarget { def id2sym(id: Identifier): SSymbol = SSymbol(id.name+"!"+id.globalId) // metadata for CC, and variables - val constructors = new Bijection[TypeTree, SSymbol]() - val selectors = new Bijection[(TypeTree, Int), SSymbol]() - val testers = new Bijection[TypeTree, SSymbol]() - val variables = new Bijection[Identifier, SSymbol]() - val sorts = new Bijection[TypeTree, Sort]() - val functions = new Bijection[TypedFunDef, SSymbol]() + val constructors = new IncrementalBijection[TypeTree, SSymbol]() + val selectors = new IncrementalBijection[(TypeTree, Int), SSymbol]() + val testers = new IncrementalBijection[TypeTree, SSymbol]() + val variables = new IncrementalBijection[Identifier, SSymbol]() + val sorts = new IncrementalBijection[TypeTree, Sort]() + val functions = new IncrementalBijection[TypedFunDef, SSymbol]() def normalizeType(t: TypeTree): TypeTree = t match { case ct: ClassType if ct.parent.isDefined => ct.parent.get @@ -98,6 +98,9 @@ trait SMTLIBTarget { case MapType(from, to) => declareMapSort(from, to) + case FunctionType(from, to) => + Sort(SMTIdentifier(SSymbol("Array")), Seq(declareSort(TupleType(from)), declareSort(to))) + case TypeParameter(id) => val s = id2sym(id) val cmd = DeclareSort(s, 0) @@ -424,6 +427,8 @@ trait SMTLIBTarget { /** * ===== Everything else ===== */ + case ap @ Application(caller, args) => + ArraysEx.Select(toSMT(caller), toSMT(Tuple(args))) case e @ UnaryOperator(u, _) => e match { @@ -596,9 +601,26 @@ trait SMTLIBTarget { } override def push(): Unit = { + constructors.push() + selectors.push() + testers.push() + variables.push() + sorts.push() + functions.push() + sendCommand(Push(1)) } + override def pop(lvl: Int = 1): Unit = { + assert(lvl == 1, "Current implementation only supports lvl = 1") + + constructors.pop() + selectors.pop() + testers.pop() + variables.pop() + sorts.pop() + functions.pop() + sendCommand(Pop(1)) } diff --git a/src/main/scala/leon/solvers/templates/FunctionTemplate.scala b/src/main/scala/leon/solvers/templates/FunctionTemplate.scala deleted file mode 100644 index 76fe5a4888cbcd097b1c06e3d2bb16afcbc6d091..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/solvers/templates/FunctionTemplate.scala +++ /dev/null @@ -1,131 +0,0 @@ -/* Copyright 2009-2014 EPFL, Lausanne */ - -package leon -package solvers -package templates - -import utils._ -import purescala.Common._ -import purescala.Trees._ -import purescala.Extractors._ -import purescala.TreeOps._ -import purescala.TypeTrees._ -import purescala.Definitions._ - -import evaluators._ - -class FunctionTemplate[T]( - val tfd: TypedFunDef, - val encoder: TemplateEncoder[T], - activatingBool: Identifier, - condVars: Set[Identifier], - exprVars: Set[Identifier], - guardedExprs: Map[Identifier,Seq[Expr]], - isRealFunDef: Boolean) { - - val evalGroundApps = false - - val clauses: Seq[Expr] = { - (for((b,es) <- guardedExprs; e <- es) yield { - Implies(Variable(b), e) - }).toSeq - } - - val trActivatingBool = encoder.encodeId(activatingBool) - - val trFunDefArgs = tfd.params.map( ad => encoder.encodeId(ad.id)) - val zippedCondVars = condVars.map(id => (id -> encoder.encodeId(id))) - val zippedExprVars = exprVars.map(id => (id -> encoder.encodeId(id))) - val zippedFunDefArgs = tfd.params.map(_.id) zip trFunDefArgs - - val idToTrId: Map[Identifier, T] = { - Map(activatingBool -> trActivatingBool) ++ - zippedCondVars ++ - zippedExprVars ++ - zippedFunDefArgs - } - - val encodeExpr = encoder.encodeExpr(idToTrId) _ - - val trClauses: Seq[T] = clauses.map(encodeExpr) - - val trBlockers: Map[T, Set[TemplateCallInfo[T]]] = { - val idCall = TemplateCallInfo[T](tfd, trFunDefArgs) - - Map((for((b, es) <- guardedExprs) yield { - val allCalls = es.map(functionCallsOf).flatten.toSet - val calls = (for (c <- allCalls) yield { - TemplateCallInfo[T](c.tfd, c.args.map(encodeExpr)) - }) - idCall - - if(calls.isEmpty) { - None - } else { - Some(idToTrId(b) -> calls) - } - }).flatten.toSeq : _*) - } - - // We use a cache to create the same boolean variables. - var cache = Map[Seq[T], Map[T, T]]() - - def instantiate(aVar: T, args: Seq[T]): (Seq[T], Map[T, Set[TemplateCallInfo[T]]]) = { - assert(args.size == tfd.params.size) - - // The "isRealFunDef" part is to prevent evaluation of "fake" - // function templates, as generated from FairZ3Solver. - //if(evalGroundApps && isRealFunDef) { - // val ga = args.view.map(solver.asGround) - // if(ga.forall(_.isDefined)) { - // val leonArgs = ga.map(_.get).force - // val invocation = FunctionInvocation(tfd, leonArgs) - // solver.getEvaluator.eval(invocation) match { - // case EvaluationResults.Successful(result) => - // val z3Invocation = z3.mkApp(solver.functionDefToDecl(tfd), args: _*) - // val z3Value = solver.toZ3Formula(result).get - // val asZ3 = z3.mkEq(z3Invocation, z3Value) - // return (Seq(asZ3), Map.empty) - - // case _ => throw new Exception("Evaluation of ground term should have succeeded.") - // } - // } - //} - // ...end of ground evaluation part. - - val baseSubstMap = cache.get(args) match { - case Some(m) => m - case None => - val newMap: Map[T, T] = - (zippedCondVars ++ zippedExprVars).map{ case (id, idT) => idT -> encoder.encodeId(id) }.toMap ++ - (trFunDefArgs zip args) - - cache += args -> newMap - newMap - } - - val substMap : Map[T, T] = baseSubstMap + (trActivatingBool -> aVar) - - val substituter = encoder.substitute(substMap) - - val newClauses = trClauses.map(substituter) - - val newBlockers = trBlockers.map { case (b, funs) => - val bp = substituter(b) - - val newFuns = funs.map(fi => fi.copy(args = fi.args.map(substituter))) - - bp -> newFuns - } - - (newClauses, newBlockers) - } - - override def toString : String = { - "Template for def " + tfd.signature + "(" + tfd.params.map(a => a.id + " : " + a.tpe).mkString(", ") + ") : " + tfd.returnType + " is :\n" + - " * Activating boolean : " + trActivatingBool + "\n" + - " * Control booleans : " + zippedCondVars.map(_._2.toString).mkString(", ") + "\n" + - " * Expression vars : " + zippedExprVars.map(_._2.toString).mkString(", ") + "\n" + - " * Clauses : " + "\n " +trClauses.mkString("\n ") + "\n" + - " * Block-map : " + trBlockers.toString - } -} diff --git a/src/main/scala/leon/solvers/templates/LambdaManager.scala b/src/main/scala/leon/solvers/templates/LambdaManager.scala new file mode 100644 index 0000000000000000000000000000000000000000..d82a03e5562bba04afe7c6d53f7026318f42a4d0 --- /dev/null +++ b/src/main/scala/leon/solvers/templates/LambdaManager.scala @@ -0,0 +1,88 @@ +/* Copyright 2009-2014 EPFL, Lausanne */ + +package leon +package solvers +package templates + +import purescala.Common._ +import purescala.Trees._ +import purescala.TreeOps._ +import purescala.TypeTrees._ + +class LambdaManager[T](encoder: TemplateEncoder[T]) { + private var byID : Map[T, LambdaTemplate[T]] = Map.empty + private var byType : Map[TypeTree, Set[(T, LambdaTemplate[T])]] = Map.empty.withDefaultValue(Set.empty) + private var quantified : Map[TypeTree, Set[T]] = Map.empty.withDefaultValue(Set.empty) + private var applications : Map[TypeTree, Set[(T, App[T])]] = Map.empty.withDefaultValue(Set.empty) + private var blockedApplications : Map[(T, App[T]), Set[T]] = Map.empty.withDefaultValue(Set.empty) + + private var globalBlocker : Option[T] = None + private var previousGlobals : Set[T] = Set.empty + + def quantify(args: Seq[(TypeTree, T)]): Unit = { + args.foreach(p => quantified += p._1 -> (quantified(p._1) + p._2)) + } + + def instantiate(apps: Map[T, Set[App[T]]], lambdas: Map[T, LambdaTemplate[T]]) : (Seq[T], Map[T, Set[TemplateInfo[T]]]) = { + var clauses : Seq[T] = Seq.empty + var blockers : Map[T, Set[TemplateInfo[T]]] = Map.empty.withDefaultValue(Set.empty) + + def mkBlocker(blockedApp: (T, App[T]), lambda: (T, LambdaTemplate[T])) : Unit = { + val (_, App(caller, tpe, args)) = blockedApp + val (idT, template) = lambda + + val unrollingBlocker = encoder.encodeId(FreshIdentifier("unrolled", true).setType(BooleanType)) + + val conj = encoder.mkAnd(encoder.mkEquals(idT, caller), template.start, unrollingBlocker) + + val templateBlocker = encoder.encodeId(FreshIdentifier("b", true).setType(BooleanType)) + val constraint = encoder.mkEquals(templateBlocker, conj) + + clauses :+= constraint + blockedApplications += (blockedApp -> (blockedApplications(blockedApp) + templateBlocker)) + blockers += (unrollingBlocker -> Set(TemplateAppInfo(template, templateBlocker, args))) + } + + for (lambda @ (idT, template) <- lambdas) { + byID += idT -> template + byType += template.tpe -> (byType(template.tpe) + (idT -> template)) + + for (guardedApp <- applications(template.tpe)) { + mkBlocker(guardedApp, lambda) + } + } + + for ((b, fas) <- apps; app @ App(caller, tpe, args) <- fas) { + if (byID contains caller) { + val (newClauses, newBlockers) = byID(caller).instantiate(b, args) + clauses ++= newClauses + newBlockers.foreach(p => blockers += p._1 -> (blockers(p._1) ++ p._2)) + } else { + for (lambda <- byType(tpe)) { + mkBlocker(b -> app, lambda) + } + + applications += tpe -> (applications(tpe) + (b -> app)) + } + } + + (clauses, blockers) + } + + def guards : Seq[T] = { + previousGlobals ++= globalBlocker + val globalGuard = encoder.encodeId(FreshIdentifier("lambda_phaser", true).setType(BooleanType)) + globalBlocker = Some(globalGuard) + + (for (((b, App(caller, tpe, _)), tbs) <- blockedApplications) yield { + val qbs = quantified(tpe).map(l => encoder.mkEquals(caller, l)) + val or = encoder.mkOr((tbs ++ qbs).toSeq : _*) + // TODO: get global blocker + val guard = encoder.mkAnd(globalGuard, encoder.mkNot(or)) + encoder.mkImplies(guard, encoder.mkNot(b)) + }).toSeq ++ previousGlobals.map(encoder.mkNot(_)) + } + + def assumption : T = globalBlocker.get +} + diff --git a/src/main/scala/leon/solvers/templates/TemplateCallInfo.scala b/src/main/scala/leon/solvers/templates/TemplateCallInfo.scala deleted file mode 100644 index ee5eb1b25ee3363cc9ceac06ae8ff78be65af2cf..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/solvers/templates/TemplateCallInfo.scala +++ /dev/null @@ -1,13 +0,0 @@ -/* Copyright 2009-2014 EPFL, Lausanne */ - -package leon -package solvers -package templates - -import purescala.Definitions.TypedFunDef - -case class TemplateCallInfo[T](tfd: TypedFunDef, args: Seq[T]) { - override def toString = { - tfd.signature+args.mkString("(", ", ", ")") - } -} diff --git a/src/main/scala/leon/solvers/templates/TemplateEncoder.scala b/src/main/scala/leon/solvers/templates/TemplateEncoder.scala index a7ba69fbb3a91c9c22ce2ef4d8fa07fa241b53d9..7460ae0cf321d10f6d43f9d42092003dd56f5cef 100644 --- a/src/main/scala/leon/solvers/templates/TemplateEncoder.scala +++ b/src/main/scala/leon/solvers/templates/TemplateEncoder.scala @@ -13,6 +13,9 @@ trait TemplateEncoder[T] { def substitute(map: Map[T, T]): T => T // Encodings needed for unrollingbank - def not(v: T): T - def implies(l: T, r: T): T + def mkNot(v: T): T + def mkOr(ts: T*): T + def mkAnd(ts: T*): T + def mkEquals(l: T, r: T): T + def mkImplies(l: T, r: T): T } diff --git a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala index cfe2c42218f49950d7a6fc24ed93dc0478106a38..114eafe8d4ceb12a50947605be3a1c16b934c671 100644 --- a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala +++ b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala @@ -18,6 +18,8 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T]) { private var cache = Map[TypedFunDef, FunctionTemplate[T]]() private var cacheExpr = Map[Expr, FunctionTemplate[T]]() + private[templates] val lambdaManager = new LambdaManager[T](encoder) + def mkTemplate(body: Expr): FunctionTemplate[T] = { if (cacheExpr contains body) { return cacheExpr(body); @@ -41,13 +43,82 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T]) { return cache(tfd) } - var condVars = Set[Identifier]() - var exprVars = Set[Identifier]() + // The precondition if it exists. + val prec : Option[Expr] = tfd.precondition.map(p => matchToIfThenElse(p)) + + val newBody : Option[Expr] = tfd.body.map(b => matchToIfThenElse(b)) + + val invocation : Expr = FunctionInvocation(tfd, tfd.params.map(_.toVariable)) + + val invocationEqualsBody : Option[Expr] = newBody match { + case Some(body) if isRealFunDef => + val b : Expr = Equals(invocation, body) + + Some(if(prec.isDefined) { + Implies(prec.get, b) + } else { + b + }) + + case _ => + None + } + + val start : Identifier = FreshIdentifier("start", true).setType(BooleanType) + val pathVar : (Identifier, T) = start -> encoder.encodeId(start) + val arguments : Seq[(Identifier, T)] = tfd.params.map(vd => vd.id -> encoder.encodeId(vd.id)) + val substMap : Map[Identifier, T] = arguments.toMap + pathVar + + val (bodyConds, bodyExprs, bodyGuarded, bodyLambdas) = if (isRealFunDef) { + invocationEqualsBody.map(expr => mkClauses(start, expr, substMap)).getOrElse { + (Map[Identifier,T](), Map[Identifier,T](), Map[Identifier,Seq[Expr]](), Map[T,LambdaTemplate[T]]()) + } + } else { + mkClauses(start, newBody.get, substMap) + } + + // Now the postcondition. + val (condVars, exprVars, guardedExprs, lambdas) = tfd.postcondition match { + case Some((id, post)) => + val newPost : Expr = replace(Map(Variable(id) -> invocation), matchToIfThenElse(post)) + + val postHolds : Expr = + if(tfd.hasPrecondition) { + Implies(prec.get, newPost) + } else { + newPost + } + + val (postConds, postExprs, postGuarded, postLambdas) = mkClauses(start, postHolds, substMap) + val allGuarded = (bodyGuarded.keys ++ postGuarded.keys).map { k => + k -> (bodyGuarded.getOrElse(k, Seq.empty) ++ postGuarded.getOrElse(k, Seq.empty)) + }.toMap + + (bodyConds ++ postConds, bodyExprs ++ postExprs, allGuarded, bodyLambdas ++ postLambdas) + + case None => + (bodyConds, bodyExprs, bodyGuarded, bodyLambdas) + } + + val template = FunctionTemplate(tfd, encoder, lambdaManager, + pathVar, arguments, condVars, exprVars, guardedExprs, lambdas, isRealFunDef) + cache += tfd -> template + template + } + + def mkClauses(pathVar: Identifier, expr: Expr, substMap: Map[Identifier, T]): + (Map[Identifier,T], Map[Identifier,T], Map[Identifier, Seq[Expr]], Map[T, LambdaTemplate[T]]) = { + + var condVars = Map[Identifier, T]() + @inline def storeCond(id: Identifier) : Unit = condVars += id -> encoder.encodeId(id) + @inline def encodedCond(id: Identifier) : T = substMap.getOrElse(id, condVars(id)) + + var exprVars = Map[Identifier, T]() + @inline def storeExpr(id: Identifier) : Unit = exprVars += id -> encoder.encodeId(id) // Represents clauses of the form: // id => expr && ... && expr var guardedExprs = Map[Identifier, Seq[Expr]]() - def storeGuarded(guardVar : Identifier, expr : Expr) : Unit = { assert(expr.getType == BooleanType) @@ -56,6 +127,9 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T]) { guardedExprs += guardVar -> (expr +: prev) } + var lambdas = Map[T, LambdaTemplate[T]]() + @inline def storeLambda(idT: T, lambda: LambdaTemplate[T]) : Unit = lambdas += idT -> lambda + // Group elements that satisfy p toghether // List(a, a, a, b, c, a, a), with p = _ == a will produce: // List(List(a,a,a), List(b), List(c), List(a, a)) @@ -81,7 +155,7 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T]) { def requireDecomposition(e: Expr) = { exists{ - case (_: FunctionInvocation) | (_: Assert) | (_: Ensuring) | (_: Choose) => true + case (_: FunctionInvocation) | (_: Assert) | (_: Ensuring) | (_: Choose) | (_: Application) => true case _ => false }(e) } @@ -97,7 +171,7 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T]) { case l @ Let(i, e, b) => val newExpr : Identifier = FreshIdentifier("lt", true).setType(i.getType) - exprVars += newExpr + storeExpr(newExpr) val re = rec(pathVar, e) storeGuarded(pathVar, Equals(Variable(newExpr), re)) val rb = rec(pathVar, replace(Map(Variable(i) -> Variable(newExpr)), b)) @@ -105,13 +179,13 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T]) { case l @ LetTuple(is, e, b) => val tuple : Identifier = FreshIdentifier("t", true).setType(TupleType(is.map(_.getType))) - exprVars += tuple + storeExpr(tuple) val re = rec(pathVar, e) storeGuarded(pathVar, Equals(Variable(tuple), re)) val mapping = for ((id, i) <- is.zipWithIndex) yield { val newId = FreshIdentifier("ti", true).setType(id.getType) - exprVars += newId + storeExpr(newId) storeGuarded(pathVar, Equals(Variable(newId), TupleSelect(Variable(tuple), i+1))) (Variable(id) -> Variable(newId)) @@ -139,10 +213,10 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T]) { val newBool2 : Identifier = FreshIdentifier("b", true).setType(BooleanType) val newExpr : Identifier = FreshIdentifier("e", true).setType(i.getType) - condVars += newBool1 - condVars += newBool2 + storeCond(newBool1) + storeCond(newBool2) - exprVars += newExpr + storeExpr(newExpr) val crec = rec(pathVar, cond) val trec = rec(newBool1, thenn) @@ -161,7 +235,7 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T]) { case c @ Choose(ids, cond) => val cid = FreshIdentifier("choose", true).setType(c.getType) - exprVars += cid + storeExpr(cid) val m: Map[Expr, Expr] = if (ids.size == 1) { Map(Variable(ids.head) -> Variable(cid)) @@ -172,70 +246,35 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T]) { storeGuarded(pathVar, replace(m, cond)) Variable(cid) - case n @ NAryOperator(as, r) => r(as.map(a => rec(pathVar, a))).setType(n.getType) - case b @ BinaryOperator(a1, a2, r) => r(rec(pathVar, a1), rec(pathVar, a2)).setType(b.getType) - case u @ UnaryOperator(a, r) => r(rec(pathVar, a)).setType(u.getType) - case t : Terminal => t - } - } + case l @ Lambda(args, body) => + val idArgs : Seq[Identifier] = args.map(_.id) + val trArgs : Seq[T] = idArgs.map(encoder.encodeId(_)) - // The precondition if it exists. - val prec : Option[Expr] = tfd.precondition.map(p => matchToIfThenElse(p)) + val lid = FreshIdentifier("lambda", true).setType(l.getType) + val clause = Equals(Application(Variable(lid), idArgs.map(Variable(_))), body) - val newBody : Option[Expr] = tfd.body.map(b => matchToIfThenElse(b)) + val localSubst : Map[Identifier, T] = substMap ++ condVars ++ exprVars + val clauseSubst : Map[Identifier, T] = localSubst ++ (idArgs zip trArgs) + val (lambdaConds, lambdaExprs, lambdaGuarded, lambdaTemplates) = mkClauses(pathVar, clause, clauseSubst) - val invocation : Expr = FunctionInvocation(tfd, tfd.params.map(_.toVariable)) + val ids: (Identifier, T) = lid -> encoder.encodeId(lid) + val dependencies: Set[T] = variablesOf(l).map(localSubst) + val template = LambdaTemplate(ids, encoder, lambdaManager, pathVar -> encodedCond(pathVar), idArgs zip trArgs, lambdaConds, lambdaExprs, lambdaGuarded, lambdaTemplates, localSubst, dependencies, l) + storeLambda(ids._2, template) - val invocationEqualsBody : Option[Expr] = newBody match { - case Some(body) if isRealFunDef => - val b : Expr = Equals(invocation, body) - - Some(if(prec.isDefined) { - Implies(prec.get, b) - } else { - b - }) - - case _ => - None - } + Variable(lid) - val activatingBool : Identifier = FreshIdentifier("start", true).setType(BooleanType) - - if (isRealFunDef) { - val finalPred : Option[Expr] = invocationEqualsBody.map(expr => rec(activatingBool, expr)) - finalPred.foreach(p => storeGuarded(activatingBool, p)) - } else { - val newFormula = rec(activatingBool, newBody.get) - storeGuarded(activatingBool, newFormula) + case n @ NAryOperator(as, r) => r(as.map(a => rec(pathVar, a))).setType(n.getType) + case b @ BinaryOperator(a1, a2, r) => r(rec(pathVar, a1), rec(pathVar, a2)).setType(b.getType) + case u @ UnaryOperator(a, r) => r(rec(pathVar, a)).setType(u.getType) + case t : Terminal => t + } } - // Now the postcondition. - tfd.postcondition match { - case Some((id, post)) => - val newPost : Expr = replace(Map(Variable(id) -> invocation), matchToIfThenElse(post)) - - val postHolds : Expr = - if(tfd.hasPrecondition) { - Implies(prec.get, newPost) - } else { - newPost - } - - val finalPred2 : Expr = rec(activatingBool, postHolds) - storeGuarded(activatingBool, finalPred2) - case None => - - } + val p = rec(pathVar, expr) + storeGuarded(pathVar, p) - val template = new FunctionTemplate[T](tfd, - encoder, - activatingBool, - Set(condVars.toSeq : _*), - Set(exprVars.toSeq : _*), - Map(guardedExprs.toSeq : _*), - isRealFunDef) - cache += tfd -> template - template + (condVars, exprVars, guardedExprs, lambdas) } + } diff --git a/src/main/scala/leon/solvers/templates/TemplateInfo.scala b/src/main/scala/leon/solvers/templates/TemplateInfo.scala new file mode 100644 index 0000000000000000000000000000000000000000..80ed670fd0844b673202748227a75b6bfc9b2a30 --- /dev/null +++ b/src/main/scala/leon/solvers/templates/TemplateInfo.scala @@ -0,0 +1,22 @@ +/* Copyright 2009-2014 EPFL, Lausanne */ + +package leon +package solvers +package templates + +import purescala.Definitions.TypedFunDef +import purescala.TypeTrees.TypeTree + +sealed abstract class TemplateInfo[T] + +case class TemplateCallInfo[T](tfd: TypedFunDef, args: Seq[T]) extends TemplateInfo[T] { + override def toString = { + tfd.signature+args.mkString("(", ", ", ")") + } +} + +case class TemplateAppInfo[T](template: LambdaTemplate[T], b: T, args: Seq[T]) extends TemplateInfo[T] { + override def toString = { + template.id + "|" + b + "|" + args.mkString("(", ",", ")") + } +} diff --git a/src/main/scala/leon/solvers/templates/Templates.scala b/src/main/scala/leon/solvers/templates/Templates.scala new file mode 100644 index 0000000000000000000000000000000000000000..05e7a4c01708938d0066c01bf751ba0170dc1f9d --- /dev/null +++ b/src/main/scala/leon/solvers/templates/Templates.scala @@ -0,0 +1,337 @@ +/* Copyright 2009-2014 EPFL, Lausanne */ + +package leon +package solvers +package templates + +import purescala.Common._ +import purescala.Trees._ +import purescala.Extractors._ +import purescala.TreeOps._ +import purescala.TypeTrees._ +import purescala.Definitions._ + +import evaluators._ + +case class App[T](caller: T, tpe: TypeTree, args: Seq[T]) { + override def toString = { + "(" + caller + " : " + tpe + ")" + args.mkString("(", ",", ")") + } +} + +trait Template[T] { self => + val encoder : TemplateEncoder[T] + val lambdaManager : LambdaManager[T] + + val start : T + val args : Seq[T] + val condVars : Map[Identifier, T] + val exprVars : Map[Identifier, T] + val clauses : Seq[T] + val blockers : Map[T, Set[TemplateCallInfo[T]]] + val applications : Map[T, Set[App[T]]] + val lambdas : Map[T, LambdaTemplate[T]] + + private var substCache : Map[Seq[T],Map[T,T]] = Map.empty + private var lambdaCache : Map[(T, Map[T,T]), T] = Map.empty + + def instantiate(aVar: T, args: Seq[T]): (Seq[T], Map[T, Set[TemplateInfo[T]]]) = { + + val baseSubstMap : Map[T,T] = substCache.get(args) match { + case Some(subst) => subst + case None => + val subst = (condVars ++ exprVars).map { case (id, idT) => idT -> encoder.encodeId(id) } ++ + (this.args zip args) + substCache += args -> subst + subst + } + + val (lambdaSubstMap, lambdaClauses) = lambdas.foldLeft((Map.empty[T,T], Seq.empty[T])) { + case ((subst, clauses), (idT, lambda)) => + val closureMap = lambda.dependencies.map(idT => idT -> baseSubstMap(idT)).toMap + val key : (T, Map[T,T]) = idT -> closureMap + + val newIdT = encoder.encodeId(lambda.id) + val prevIdT = lambdaCache.get(key) match { + case Some(id) => + Some(id) + case None => + lambdaCache += key -> newIdT + None + } + + val newClause = prevIdT.map(id => encoder.mkEquals(newIdT, id)) + (subst + (idT -> newIdT), clauses ++ newClause) + } + + val substMap : Map[T,T] = baseSubstMap ++ lambdaSubstMap + (start -> aVar) + val substituter : T => T = encoder.substitute(substMap) + + val newClauses = clauses.map(substituter) + val newBlockers = blockers.map { case (b,fis) => + substituter(b) -> fis.map(fi => fi.copy(args = fi.args.map(substituter))) + } + + val newApplications = applications.map { case (b,fas) => + substituter(b) -> fas.map(fa => fa.copy(caller = substituter(fa.caller), args = fa.args.map(substituter))) + } + + val newLambdas = lambdas.map { case (idT, lambda) => + substituter(idT) -> lambda.substitute(substMap) + } + + val (appClauses, appBlockers) = lambdaManager.instantiate(newApplications, newLambdas) + + val allClauses = newClauses ++ appClauses ++ lambdaClauses + val allBlockers = (newBlockers.keys ++ appBlockers.keys).map { k => + k -> (newBlockers.getOrElse(k, Set.empty) ++ appBlockers.getOrElse(k, Set.empty)) + }.toMap + + (allClauses, allBlockers) + } + + override def toString : String = "Instantiated template" +} + +object Template { + + def encode[T]( + encoder: TemplateEncoder[T], + pathVar: (Identifier, T), + arguments: Seq[(Identifier, T)], + condVars: Map[Identifier, T], + exprVars: Map[Identifier, T], + guardedExprs: Map[Identifier, Seq[Expr]], + lambdas: Map[T, LambdaTemplate[T]], + substMap: Map[Identifier, T] = Map.empty[Identifier, T], + optCall: Option[TypedFunDef] = None, + optApp: Option[(T, TypeTree)] = None + ) : (Seq[T], Map[T, Set[TemplateCallInfo[T]]], Map[T, Set[App[T]]], () => String) = { + + val idToTrId : Map[Identifier, T] = { + condVars ++ exprVars + pathVar ++ arguments ++ substMap ++ + lambdas.map { case (idT, template) => template.id -> idT } + } + + val encodeExpr : Expr => T = encoder.encodeExpr(idToTrId) _ + + val clauses : Seq[T] = (for ((b,es) <- guardedExprs; e <- es) yield { + encodeExpr(Implies(Variable(b), e)) + }).toSeq + + val blockers : Map[Identifier, Set[TemplateCallInfo[T]]] = { + val optIdCall = optCall.map(tfd => TemplateCallInfo[T](tfd, arguments.map(_._2))) + + Map((for ((b,es) <- guardedExprs) yield { + val calls = es.flatMap(e => functionCallsOf(e).map { fi => + TemplateCallInfo[T](fi.tfd, fi.args.map(encodeExpr)) + }).toSet -- optIdCall + + if (calls.isEmpty) None else Some(b -> calls) + }).flatten.toSeq : _*) + } + + val encodedBlockers : Map[T, Set[TemplateCallInfo[T]]] = blockers.map(p => idToTrId(p._1) -> p._2) + + val applications : Map[Identifier, Set[App[T]]] = { + val optIdApp = optApp.map { case (idT, tpe) => App(idT, tpe, arguments.map(_._2)) } + + Map((for ((b,es) <- guardedExprs) yield { + val apps = es.flatMap(e => functionAppsOf(e).map { fa => + App[T](encodeExpr(fa.caller), fa.caller.getType, fa.args.map(encodeExpr)) + }).toSet -- optIdApp + + if (apps.isEmpty) None else Some(b -> apps) + }).flatten.toSeq : _*) + } + + val encodedApps : Map[T, Set[App[T]]] = applications.map(p => idToTrId(p._1) -> p._2) + + val stringRepr : () => String = () => { + " * Activating boolean : " + pathVar._1 + "\n" + + " * Control booleans : " + condVars.keys.mkString(", ") + "\n" + + " * Expression vars : " + exprVars.keys.mkString(", ") + "\n" + + " * Clauses : " + + (for ((b,es) <- guardedExprs; e <- es) yield (b + " ==> " + e)).mkString("\n ") + "\n" + + " * Invocation-blocks :" + (if (blockers.isEmpty) "\n" else { + "\n " + blockers.map(p => p._1 + " ==> " + p._2).mkString("\n ") + "\n" + }) + + " * Application-blocks :" + (if (applications.isEmpty) "\n" else { + "\n " + applications.map(p => p._1 + " ==> " + p._2).mkString("\n ") + "\n" + }) + + " * Lambdas :\n" + lambdas.map { case (_, template) => + " +> " + template.toString.split("\n").mkString("\n ") + }.mkString("\n") + } + + (clauses, encodedBlockers, encodedApps, stringRepr) + } +} + +object FunctionTemplate { + + def apply[T]( + tfd: TypedFunDef, + encoder: TemplateEncoder[T], + lambdaManager: LambdaManager[T], + pathVar: (Identifier, T), + arguments: Seq[(Identifier, T)], + condVars: Map[Identifier, T], + exprVars: Map[Identifier, T], + guardedExprs: Map[Identifier, Seq[Expr]], + lambdas: Map[T, LambdaTemplate[T]], + isRealFunDef: Boolean + ) : FunctionTemplate[T] = { + + val (clauses, blockers, applications, templateString) = + Template.encode(encoder, pathVar, arguments, condVars, exprVars, guardedExprs, lambdas, + optCall = Some(tfd)) + + val funString : () => String = () => { + "Template for def " + tfd.signature + + "(" + tfd.params.map(a => a.id + " : " + a.tpe).mkString(", ") + ") : " + + tfd.returnType + " is :\n" + templateString() + } + + new FunctionTemplate[T]( + tfd, + encoder, + lambdaManager, + pathVar._2, + arguments.map(_._2), + condVars, + exprVars, + clauses, + blockers, + applications, + lambdas, + isRealFunDef, + funString + ) + } + +} + +class FunctionTemplate[T] private( + val tfd: TypedFunDef, + val encoder: TemplateEncoder[T], + val lambdaManager: LambdaManager[T], + val start: T, + val args: Seq[T], + val condVars: Map[Identifier, T], + val exprVars: Map[Identifier, T], + val clauses: Seq[T], + val blockers: Map[T, Set[TemplateCallInfo[T]]], + val applications: Map[T, Set[App[T]]], + val lambdas: Map[T, LambdaTemplate[T]], + isRealFunDef: Boolean, + stringRepr: () => String) extends Template[T] { + + private lazy val str : String = stringRepr() + override def toString : String = str +} + +object LambdaTemplate { + + def apply[T]( + ids: (Identifier, T), + encoder: TemplateEncoder[T], + lambdaManager: LambdaManager[T], + pathVar: (Identifier, T), + arguments: Seq[(Identifier, T)], + condVars: Map[Identifier, T], + exprVars: Map[Identifier, T], + guardedExprs: Map[Identifier, Seq[Expr]], + lambdas: Map[T, LambdaTemplate[T]], + baseSubstMap: Map[Identifier, T], + dependencies: Set[T], + lambda: Lambda + ) : LambdaTemplate[T] = { + + val id = ids._2 + val tpe = ids._1.getType + val (clauses, blockers, applications, templateString) = + Template.encode(encoder, pathVar, arguments, condVars, exprVars, guardedExprs, lambdas, + substMap = baseSubstMap + ids, optApp = Some(id -> tpe)) + + val lambdaString : () => String = () => { + "Template for lambda " + ids._1 + ": " + lambda + " is :\n" + templateString() + } + + new LambdaTemplate[T]( + ids._1, + encoder, + lambdaManager, + pathVar._2, + arguments.map(_._2), + condVars, + exprVars, + clauses, + blockers, + applications, + lambdas, + dependencies, + lambda, + lambdaString + ) + } +} + +class LambdaTemplate[T] private ( + val id: Identifier, + val encoder: TemplateEncoder[T], + val lambdaManager: LambdaManager[T], + val start: T, + val args: Seq[T], + val condVars: Map[Identifier, T], + val exprVars: Map[Identifier, T], + val clauses: Seq[T], + val blockers: Map[T, Set[TemplateCallInfo[T]]], + val applications: Map[T, Set[App[T]]], + val lambdas: Map[T, LambdaTemplate[T]], + val dependencies: Set[T], + val lambda: Lambda, + stringRepr: () => String) extends Template[T] { + + val tpe = id.getType + + def substitute(substMap: Map[T,T]): LambdaTemplate[T] = { + val substituter : T => T = encoder.substitute(substMap) + + val newStart = substituter(start) + val newClauses = clauses.map(substituter) + val newBlockers = blockers.map { case (b, fis) => + val bp = if (b == start) newStart else b + bp -> fis.map(fi => fi.copy(args = fi.args.map(substituter))) + } + + val newApplications = applications.map { case (b, fas) => + val bp = if (b == start) newStart else b + bp -> fas.map(fa => fa.copy(caller = substituter(fa.caller), args = fa.args.map(substituter))) + } + + val newLambdas = lambdas.map { case (idT, template) => idT -> template.substitute(substMap) } + + val newDependencies = dependencies.map(substituter) + + new LambdaTemplate[T]( + id, + encoder, + lambdaManager, + newStart, + args, + condVars, + exprVars, + newClauses, + newBlockers, + newApplications, + newLambdas, + newDependencies, + lambda, + stringRepr + ) + } + + private lazy val str : String = stringRepr() + override def toString : String = str +} diff --git a/src/main/scala/leon/solvers/templates/UnrollingBank.scala b/src/main/scala/leon/solvers/templates/UnrollingBank.scala index b579b88ec645b145a150995d744b2bd4ab494e29..8d8b604de83a4f31951580761dbf25f52b44a157 100644 --- a/src/main/scala/leon/solvers/templates/UnrollingBank.scala +++ b/src/main/scala/leon/solvers/templates/UnrollingBank.scala @@ -18,17 +18,18 @@ class UnrollingBank[T](reporter: Reporter, templateGenerator: TemplateGenerator[ implicit val debugSection = utils.DebugSectionSolver private val encoder = templateGenerator.encoder + private val lambdaManager = templateGenerator.lambdaManager // Keep which function invocation is guarded by which guard, // also specify the generation of the blocker. - private var blockersInfoStack = List[Map[T, (Int, Int, T, Set[TemplateCallInfo[T]])]](Map()) + private var blockersInfoStack = List[Map[T, (Int, Int, T, Set[TemplateInfo[T]])]](Map()) // Function instantiations have their own defblocker - private var defBlockers = Map[TemplateCallInfo[T], T]() + private var defBlockers = Map[TemplateInfo[T], T]() def blockersInfo = blockersInfoStack.head - def blockersInfo_= (v: Map[T, (Int, Int, T, Set[TemplateCallInfo[T]])]) = { + def blockersInfo_= (v: Map[T, (Int, Int, T, Set[TemplateInfo[T]])]) = { blockersInfoStack = v :: blockersInfoStack.tail } @@ -53,7 +54,7 @@ class UnrollingBank[T](reporter: Reporter, templateGenerator: TemplateGenerator[ def canUnroll = !blockersInfo.isEmpty - def currentBlockers = blockersInfo.map(_._2._3) + def currentBlockers = blockersInfo.map(_._2._3).toSeq :+ lambdaManager.assumption def getBlockersToUnlock: Seq[T] = { if (!blockersInfo.isEmpty) { @@ -65,8 +66,8 @@ class UnrollingBank[T](reporter: Reporter, templateGenerator: TemplateGenerator[ } } - private def registerBlocker(gen: Int, id: T, fis: Set[TemplateCallInfo[T]]) { - val notId = encoder.not(id) + private def registerBlocker(gen: Int, id: T, fis: Set[TemplateInfo[T]]) { + val notId = encoder.mkNot(id) blockersInfo.get(id) match { case Some((exGen, origGen, _, exFis)) => @@ -88,21 +89,25 @@ class UnrollingBank[T](reporter: Reporter, templateGenerator: TemplateGenerator[ // define an activating boolean... val template = templateGenerator.mkTemplate(expr) - val trArgs = template.tfd.params.map(vd => bindings(Variable(vd.id))) + lambdaManager.quantify(template.tfd.params.collect { + case vd if vd.tpe.isInstanceOf[FunctionType] => + vd.tpe -> bindings(vd.toVariable) + }) + // ...now this template defines clauses that are all guarded // by that activating boolean. If that activating boolean is // undefined (or false) these clauses have no effect... val (newClauses, newBlocks) = - template.instantiate(template.trActivatingBool, trArgs) + template.instantiate(template.start, trArgs) for((i, fis) <- newBlocks) { registerBlocker(nextGeneration(0), i, fis) } - + // ...so we must force it to true! - template.trActivatingBool +: newClauses + template.start +: (newClauses ++ lambdaManager.guards) } def nextGeneration(gen: Int) = gen + 3 @@ -137,35 +142,51 @@ class UnrollingBank[T](reporter: Reporter, templateGenerator: TemplateGenerator[ for (fi <- fis) { var newCls = Seq[T]() - val defBlocker = defBlockers.get(fi) match { - case Some(defBlocker) => - // we already have defBlocker => f(args) = body - defBlocker - case None => - // we need to define this defBlocker and link it to definition - val defBlocker = encoder.encodeId(FreshIdentifier("d").setType(BooleanType)) - defBlockers += fi -> defBlocker + fi match { + case TemplateCallInfo(tfd, args) => + val defBlocker = defBlockers.get(fi) match { + case Some(defBlocker) => + // we already have defBlocker => f(args) = body + defBlocker + + case None => + // we need to define this defBlocker and link it to definition + val defBlocker = encoder.encodeId(FreshIdentifier("d").setType(BooleanType)) + defBlockers += fi -> defBlocker + + val template = templateGenerator.mkTemplate(tfd) + reporter.debug(template) + + val (newExprs, newBlocks) = template.instantiate(defBlocker, args) + + for((i, fis2) <- newBlocks) { + registerBlocker(nextGeneration(gen), i, fis2) + } - val template = templateGenerator.mkTemplate(fi.tfd) + newCls ++= newExprs + defBlocker + } + + // We connect it to the defBlocker: blocker => defBlocker + if (defBlocker != id) { + newCls ++= List(encoder.mkImplies(id, defBlocker)) + } + + case TemplateAppInfo(template, b, args) => reporter.debug(template) - val (newExprs, newBlocks) = template.instantiate(defBlocker, fi.args) + val (newExprs, newBlocks) = template.instantiate(b, args) for((i, fis2) <- newBlocks) { registerBlocker(nextGeneration(gen), i, fis2) } newCls ++= newExprs - defBlocker - } - - // We connect it to the defBlocker: blocker => defBlocker - if (defBlocker != id) { - newCls ++= List(encoder.implies(id, defBlocker)) + newCls :+= id } reporter.debug("Unrolling behind "+fi+" ("+newCls.size+")") for (cl <- newCls) { - reporter.debug(" . "+cl) + reporter.debug(" . "+cl) } newClauses ++= newCls @@ -173,6 +194,8 @@ class UnrollingBank[T](reporter: Reporter, templateGenerator: TemplateGenerator[ } + newClauses ++= lambdaManager.guards + reporter.debug(s" - ${newClauses.size} new clauses") //context.reporter.ifDebug { debug => // debug(s" - new clauses:") diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala index 1529dc7f7d7b0c4ad14c87fb7ac41d284aa3a7a5..b9b614a126dea9ae2af2dc08e6441b5338d9109b 100644 --- a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala @@ -463,6 +463,14 @@ trait AbstractZ3Solver newTPSort } + case ft @ FunctionType(from, to) => + sorts.toZ3OrCompute(ft) { + val fromSort = typeToSort(TupleType(from)) + val toSort = typeToSort(to) + + z3.mkArraySort(fromSort, toSort) + } + case other => sorts.toZ3OrCompute(other) { reporter.warning(other.getPos, "Resorting to uninterpreted type for : " + other) @@ -579,6 +587,9 @@ trait AbstractZ3Solver case f @ FunctionInvocation(tfd, args) => z3.mkApp(functionDefToDecl(tfd), args.map(rec(_)): _*) + case fa @ Application(caller, args) => + z3.mkSelect(rec(caller), rec(Tuple(args))) + case SetEquals(s1, s2) => z3.mkEq(rec(s1), rec(s2)) case ElementOfSet(e, s) => z3.mkSetMember(rec(e), rec(s)) case SubsetOf(s1, s2) => z3.mkSetSubset(rec(s1), rec(s2)) @@ -736,7 +747,7 @@ trait AbstractZ3Solver model.getArrayValue(t) match { case None => throw new CantTranslateException(t) case Some((map, elseZ3Value)) => - var values = map.toSeq.map { case (k, v) => (k, z3.getASTKind(v)) }.collect { + val values = map.toSeq.map { case (k, v) => (k, z3.getASTKind(v)) }.collect { case (k, Z3AppAST(cons, arg :: Nil)) if cons == mapRangeSomeConstructors(vt) => (rec(k), rec(arg)) } @@ -744,6 +755,15 @@ trait AbstractZ3Solver FiniteMap(values).setType(tpe) } + case LeonType(tpe @ FunctionType(fts, tt)) => + model.getArrayValue(t) match { + case None => throw new CantTranslateException(t) + case Some((map, elseZ3Value)) => + val leonElseValue = rec(elseZ3Value) + val leonMap = map.toSeq.map(p => rec(p._1) -> rec(p._2)) + FiniteLambda(leonElseValue, leonMap, tpe) + } + case LeonType(tpe @ SetType(dt)) => model.getSetValue(t) match { case None => throw new CantTranslateException(t) diff --git a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala index bca62952fb346f45794cff863060d5d3fdaf1b79..0fd03c4403c0cd48498f4c83814b4df3a7b5b21e 100644 --- a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala @@ -149,8 +149,11 @@ class FairZ3Solver(val context : LeonContext, val program: Program) (c: Z3AST) => z3.substitute(c, fromArray, toArray) } - def not(e: Z3AST) = z3.mkNot(e) - def implies(l: Z3AST, r: Z3AST) = z3.mkImplies(l, r) + def mkNot(e: Z3AST) = z3.mkNot(e) + def mkOr(es: Z3AST*) = z3.mkOr(es : _*) + def mkAnd(es: Z3AST*) = z3.mkAnd(es : _*) + def mkEquals(l: Z3AST, r: Z3AST) = z3.mkEq(l, r) + def mkImplies(l: Z3AST, r: Z3AST) = z3.mkImplies(l, r) }) @@ -158,7 +161,7 @@ class FairZ3Solver(val context : LeonContext, val program: Program) val solver = z3.mkSolver - private var varsInVC = Set[Identifier]() + private var varsInVC = List[Set[Identifier]](Set()) private var frameExpressions = List[List[Expr]](Nil) @@ -167,12 +170,14 @@ class FairZ3Solver(val context : LeonContext, val program: Program) def push() { solver.push() unrollingBank.push() + varsInVC = Set[Identifier]() :: varsInVC frameExpressions = Nil :: frameExpressions } def pop(lvl: Int = 1) { solver.pop(lvl) unrollingBank.pop(lvl) + varsInVC = varsInVC.drop(lvl) frameExpressions = frameExpressions.drop(lvl) } @@ -191,7 +196,7 @@ class FairZ3Solver(val context : LeonContext, val program: Program) def assertCnstr(expression: Expr) { val freeVars = variablesOf(expression) - varsInVC ++= freeVars + varsInVC = (varsInVC.head ++ freeVars) :: varsInVC.tail // We make sure all free variables are registered as variables freeVars.foreach { v => @@ -260,6 +265,8 @@ class FairZ3Solver(val context : LeonContext, val program: Program) reporter.debug(" - Finished search with blocked literals") + lazy val allVars = varsInVC.flatten.toSet + res match { case None => reporter.ifDebug { debug => @@ -274,7 +281,7 @@ class FairZ3Solver(val context : LeonContext, val program: Program) val z3model = solver.getModel if (this.checkModels) { - val (isValid, model) = validateModel(z3model, entireFormula, varsInVC, silenceErrors = false) + val (isValid, model) = validateModel(z3model, entireFormula, allVars, silenceErrors = false) if (isValid) { foundAnswer(Some(true), model) @@ -284,7 +291,7 @@ class FairZ3Solver(val context : LeonContext, val program: Program) foundAnswer(None, model) } } else { - val model = modelToMap(z3model, varsInVC) + val model = modelToMap(z3model, allVars) //lazy val modelAsString = model.toList.map(p => p._1 + " -> " + p._2).mkString("\n") //reporter.debug("- Found a model:") @@ -357,7 +364,7 @@ class FairZ3Solver(val context : LeonContext, val program: Program) //reporter.debug("SAT WITHOUT Blockers") if (this.feelingLucky && !interrupted) { // we might have been lucky :D - val (wereWeLucky, cleanModel) = validateModel(solver.getModel, entireFormula, varsInVC, silenceErrors = true) + val (wereWeLucky, cleanModel) = validateModel(solver.getModel, entireFormula, allVars, silenceErrors = true) if(wereWeLucky) { foundAnswer(Some(true), cleanModel) diff --git a/src/main/scala/leon/solvers/z3/UninterpretedZ3Solver.scala b/src/main/scala/leon/solvers/z3/UninterpretedZ3Solver.scala index 521f181de603e4df47c6d94898cf632c1af5a6e2..b903921a2af79fcd4e2a47d5b5514212dc0dfb3b 100644 --- a/src/main/scala/leon/solvers/z3/UninterpretedZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/UninterpretedZ3Solver.scala @@ -44,33 +44,17 @@ class UninterpretedZ3Solver(val context : LeonContext, val program: Program) solver.push } - def pop(lvl: Int = 1) { solver.pop(lvl) } private var freeVariables = Set[Identifier]() - private var containsFunCalls = false - def assertCnstr(expression: Expr) { freeVariables ++= variablesOf(expression) - containsFunCalls ||= containsFunctionCalls(expression) solver.assertCnstr(toZ3Formula(expression).getOrElse(scala.sys.error("Failed to compile to Z3: "+expression))) } - override def check: Option[Boolean] = { - solver.check match { - case Some(true) => - if (containsFunCalls) { - None - } else { - Some(true) - } - - case r => - r - } - } + override def check: Option[Boolean] = solver.check override def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = { freeVariables ++= assumptions.flatMap(variablesOf(_)) diff --git a/src/main/scala/leon/utils/Bijection.scala b/src/main/scala/leon/utils/Bijection.scala index 12233df8ac96aba1b9cb1151b3a5663615c83479..027f3fa4c5c2928b05aa5f8d1f0a7df70475dcdf 100644 --- a/src/main/scala/leon/utils/Bijection.scala +++ b/src/main/scala/leon/utils/Bijection.scala @@ -1,8 +1,8 @@ package leon.utils class Bijection[A, B] { - var a2b = Map[A, B]() - var b2a = Map[B, A]() + protected var a2b = Map[A, B]() + protected var b2a = Map[B, A]() def +=(a: A, b: B): Unit = { a2b += a -> b diff --git a/src/main/scala/leon/utils/IncrementalBijection.scala b/src/main/scala/leon/utils/IncrementalBijection.scala new file mode 100644 index 0000000000000000000000000000000000000000..ea4ee59065e4408b1a4de9962600357626e1c84c --- /dev/null +++ b/src/main/scala/leon/utils/IncrementalBijection.scala @@ -0,0 +1,48 @@ +package leon.utils + +class IncrementalBijection[A,B] extends Bijection[A,B] { + private var a2bStack = List[Map[A,B]]() + private var b2aStack = List[Map[B,A]]() + + override def clear() : Unit = { + super.clear() + a2bStack = Nil + b2aStack = Nil + } + + private def recursiveGet[T,U](stack: List[Map[T,U]], t: T): Option[U] = stack match { + case t2u :: xs => t2u.get(t) orElse recursiveGet(xs, t) + case Nil => None + } + + override def getA(b: B) = b2a.get(b) match { + case s @ Some(a) => s + case None => recursiveGet(b2aStack, b) + } + + override def getB(a: A) = a2b.get(a) match { + case s @ Some(b) => s + case None => recursiveGet(a2bStack, a) + } + + override def containsA(a: A) = getB(a).isDefined + override def containsB(b: B) = getA(b).isDefined + + override def aSet = a2b.keySet ++ a2bStack.flatMap(_.keySet) + override def bSet = b2a.keySet ++ b2aStack.flatMap(_.keySet) + + def push(): Unit = { + a2bStack = a2b :: a2bStack + b2aStack = b2a :: b2aStack + a2b = Map() + b2a = Map() + } + + def pop(): Unit = { + a2b = a2bStack.head + b2a = b2aStack.head + a2bStack = a2bStack.tail + b2aStack = b2aStack.tail + } + +} diff --git a/src/test/resources/regression/verification/purescala/invalid/HOInvocations.scala b/src/test/resources/regression/verification/purescala/invalid/HOInvocations.scala new file mode 100644 index 0000000000000000000000000000000000000000..3d2e16835a7347ed4140e9902ec330ae3d058ea9 --- /dev/null +++ b/src/test/resources/regression/verification/purescala/invalid/HOInvocations.scala @@ -0,0 +1,16 @@ +import leon.lang._ + +object HOInvocations { + def switch(x: Int, f: (Int) => Int, g: (Int) => Int) = if(x > 0) f else g + + def failling_1(f: (Int) => Int) = { + switch(-10, (x: Int) => x + 1, f)(2) + } ensuring { res => res > 0 } + + def failling_2(x: Int, f: (Int) => Int, g: (Int) => Int) = { + require(x > 0) + switch(1, switch(x, f, g), g)(1) + } ensuring { res => res != f(1) } +} + +// vim: set ts=4 sw=4 et: diff --git a/src/test/resources/regression/verification/purescala/invalid/Lists.scala b/src/test/resources/regression/verification/purescala/invalid/Lists.scala new file mode 100644 index 0000000000000000000000000000000000000000..3286fd638dcf3af111bb77343a44498a2c97378a --- /dev/null +++ b/src/test/resources/regression/verification/purescala/invalid/Lists.scala @@ -0,0 +1,32 @@ +import leon.lang._ + +object Lists4 { + abstract class List[T] + case class Cons[T](head: T, tail: List[T]) extends List[T] + case class Nil[T]() extends List[T] + + def forall[T](list: List[T], f: T => Boolean): Boolean = list match { + case Cons(head, tail) => f(head) && forall(tail, f) + case Nil() => true + } + + def positive(list: List[Int]): Boolean = list match { + case Cons(head, tail) => if (head < 0) false else positive(tail) + case Nil() => true + } + + def gt(i: Int): Int => Boolean = x => x > i + + def positive_lemma(list: List[Int]): Boolean = { + positive(list) == forall(list, gt(0)) + } + + def failling_1(list: List[Int]): Boolean = { + list match { + case Nil() => positive_lemma(list) + case Cons(head, tail) => positive_lemma(list) && failling_1(tail) + } + }.holds +} + +// vim: set ts=4 sw=4 et: diff --git a/src/test/resources/regression/verification/purescala/invalid/PositiveMap.scala b/src/test/resources/regression/verification/purescala/invalid/PositiveMap.scala new file mode 100644 index 0000000000000000000000000000000000000000..fb1452f478b7e71f3d9d23093553a05aa0d3af1b --- /dev/null +++ b/src/test/resources/regression/verification/purescala/invalid/PositiveMap.scala @@ -0,0 +1,25 @@ +import leon.lang._ + +object PositiveMap { + + abstract class List + case class Cons(head: Int, tail: List) extends List + case class Nil() extends List + + def positive(list: List): Boolean = list match { + case Cons(head, tail) => if (head < 0) false else positive(tail) + case Nil() => true + } + + def positiveMap_failling_1(f: (Int) => Int, list: List): List = { + list match { + case Cons(head, tail) => + val fh = f(head) + val nh = if (fh < -1) 0 else fh + Cons(nh, positiveMap_failling_1(f, tail)) + case Nil() => Nil() + } + } ensuring { res => positive(res) } +} + +// vim: set ts=4 sw=4 et: diff --git a/src/test/resources/regression/verification/purescala/valid/Anonymous.scala b/src/test/resources/regression/verification/purescala/valid/Anonymous.scala new file mode 100644 index 0000000000000000000000000000000000000000..b8d3235ccf5e817d0202b0f8de3145d977e1989f --- /dev/null +++ b/src/test/resources/regression/verification/purescala/valid/Anonymous.scala @@ -0,0 +1,9 @@ +import leon.lang._ + +object Anonymous { + def test(x: Int) = { + require(x > 0) + val i = (a: Int) => a + 1 + i(x) + i(2) + } ensuring { res => res > 0 } +} diff --git a/src/test/resources/regression/verification/purescala/valid/Closures.scala b/src/test/resources/regression/verification/purescala/valid/Closures.scala new file mode 100644 index 0000000000000000000000000000000000000000..5e191cf814e4c66b2df0a3fa16fe19d78a1c343d --- /dev/null +++ b/src/test/resources/regression/verification/purescala/valid/Closures.scala @@ -0,0 +1,15 @@ +import leon.lang._ + +object Closures { + def addX(x: Int): Int => Int = { + (a: Int) => a + x + } + + def test(x: Int): Boolean = { + val add1 = addX(1) + val add2 = addX(2) + add1(add2(1)) == 4 + }.holds +} + +// vim: set ts=4 sw=4 et: diff --git a/src/test/resources/regression/verification/purescala/valid/Closures2.scala b/src/test/resources/regression/verification/purescala/valid/Closures2.scala new file mode 100644 index 0000000000000000000000000000000000000000..b791965d1e06edd872d747a68fe084efed52a206 --- /dev/null +++ b/src/test/resources/regression/verification/purescala/valid/Closures2.scala @@ -0,0 +1,35 @@ +import leon.lang._ + +object Closures2 { + def set(i: Int): Int => Boolean = x => x == i + + def union(s1: Int => Boolean, s2: Int => Boolean): Int => Boolean = x => s1(x) || s2(x) + + def intersection(s1: Int => Boolean, s2: Int => Boolean): Int => Boolean = x => s1(x) && s2(x) + + def diff(s1: Int => Boolean, s2: Int => Boolean): Int => Boolean = x => s1(x) && !s2(x) + + def set123(): Int => Boolean = union(set(1), union(set(2), set(3))) + + def test1(): Boolean = { + val s1 = set123() + val s2 = union(s1, set(4)) + s2(1) && s2(2) && s2(3) && s2(4) + }.holds + + def test2(): Boolean = { + val s1 = set123() + val s2 = intersection(s1, union(set(1), set(3))) + val s3 = diff(s1, s2) + s3(2) && !s3(1) && !s3(3) + }.holds + + def test3(): Boolean = { + val s1 = set123() + val s2 = set123() + val s3 = union(s1, s2) + s3(1) && s3(2) && s3(3) + }.holds +} + +// vim: set ts=4 sw=4 et: diff --git a/src/test/resources/regression/verification/purescala/valid/HOInvocations.scala b/src/test/resources/regression/verification/purescala/valid/HOInvocations.scala new file mode 100644 index 0000000000000000000000000000000000000000..0f2fbda2ff6e892e200b475f06d8e1876bd6ed64 --- /dev/null +++ b/src/test/resources/regression/verification/purescala/valid/HOInvocations.scala @@ -0,0 +1,17 @@ +import leon.lang._ + +object HOInvocations { + def switch(x: Int, f: (Int) => Int, g: (Int) => Int) = if(x > 0) f else g + + def passing_1(f: (Int) => Int) = { + switch(10, (x: Int) => x + 1, f)(2) + } ensuring { res => res > 0 } + + def passing_2(x: Int, f: (Int) => Int, g: (Int) => Int) = { + require(x > 0) + switch(1, switch(x, f, g), g)(1) + } ensuring { res => res == f(1) } + +} + +// vim: set ts=4 sw=4 et: diff --git a/src/test/resources/regression/verification/purescala/valid/Lists.scala b/src/test/resources/regression/verification/purescala/valid/Lists.scala new file mode 100644 index 0000000000000000000000000000000000000000..45c3adadc610f00fa2cf1494bacfe3a6ef6310a7 --- /dev/null +++ b/src/test/resources/regression/verification/purescala/valid/Lists.scala @@ -0,0 +1,31 @@ +import leon.lang._ + +object Lists { + abstract class List + case class Cons(head: Int, tail: List) extends List + case class Nil() extends List + + def exists(list: List, f: Int => Boolean): Boolean = list match { + case Cons(head, tail) => f(head) || exists(tail, f) + case Nil() => false + } + + def forall(list: List, f: Int => Boolean): Boolean = list match { + case Cons(head, tail) => f(head) && forall(tail, f) + case Nil() => true + } + + def exists_lemma(list: List, f: Int => Boolean): Boolean = { + exists(list, f) == !forall(list, x => !f(x)) + } + + def exists_lemma_induct(list: List, f: Int => Boolean): Boolean = { + list match { + case Nil() => exists_lemma(list, f) + case Cons(head, tail) => exists_lemma(list, f) && exists_lemma_induct(tail, f) + } + }.holds + +} + +// vim: set ts=4 sw=4 et: diff --git a/src/test/resources/regression/verification/purescala/valid/Lists2.scala b/src/test/resources/regression/verification/purescala/valid/Lists2.scala new file mode 100644 index 0000000000000000000000000000000000000000..da096d1d3370fb4978dff57a3d8ae598b9a94694 --- /dev/null +++ b/src/test/resources/regression/verification/purescala/valid/Lists2.scala @@ -0,0 +1,30 @@ +import leon.lang._ + +object Lists2 { + abstract class List[T] + case class Cons[T](head: T, tail: List[T]) extends List[T] + case class Nil[T]() extends List[T] + + def forall[T](list: List[T], f: T => Boolean): Boolean = list match { + case Cons(head, tail) => f(head) && forall(tail, f) + case Nil() => true + } + + def positive(list: List[Int]): Boolean = list match { + case Cons(head, tail) => if (head < 0) false else positive(tail) + case Nil() => true + } + + def positive_lemma(list: List[Int]): Boolean = { + positive(list) == forall(list, (x: Int) => x >= 0) + } + + def positive_lemma_induct(list: List[Int]): Boolean = { + list match { + case Nil() => positive_lemma(list) + case Cons(head, tail) => positive_lemma(list) && positive_lemma_induct(tail) + } + }.holds +} + +// vim: set ts=4 sw=4 et: diff --git a/src/test/resources/regression/verification/purescala/valid/Lists3.scala b/src/test/resources/regression/verification/purescala/valid/Lists3.scala new file mode 100644 index 0000000000000000000000000000000000000000..67060133a65e1661b32757ff98b0857e71897d12 --- /dev/null +++ b/src/test/resources/regression/verification/purescala/valid/Lists3.scala @@ -0,0 +1,34 @@ +import leon.lang._ + +object Lists3 { + abstract class List[T] + case class Cons[T](head: T, tail: List[T]) extends List[T] + case class Nil[T]() extends List[T] + + def forall[T](list: List[T], f: T => Boolean): Boolean = list match { + case Cons(head, tail) => f(head) && forall(tail, f) + case Nil() => true + } + + def positive(list: List[Int]): Boolean = list match { + case Cons(head, tail) => if (head < 0) false else positive(tail) + case Nil() => true + } + + def gt(i: Int): Int => Boolean = x => x > i + + def gte(i: Int): Int => Boolean = x => gt(i)(x) || x == i + + def positive_lemma(list: List[Int]): Boolean = { + positive(list) == forall(list, gte(0)) + } + + def positive_lemma_induct(list: List[Int]): Boolean = { + list match { + case Nil() => positive_lemma(list) + case Cons(head, tail) => positive_lemma(list) && positive_lemma_induct(tail) + } + }.holds +} + +// vim: set ts=4 sw=4 et: diff --git a/src/test/resources/regression/verification/purescala/valid/Lists4.scala b/src/test/resources/regression/verification/purescala/valid/Lists4.scala new file mode 100644 index 0000000000000000000000000000000000000000..02c24111d2a9936ba14764355f7c51d4cfc9e1fa --- /dev/null +++ b/src/test/resources/regression/verification/purescala/valid/Lists4.scala @@ -0,0 +1,26 @@ +import leon.lang._ + +object Lists4 { + abstract class List[T] + case class Cons[T](head: T, tail: List[T]) extends List[T] + case class Nil[T]() extends List[T] + + def map[F,T](list: List[F], f: F => T): List[T] = list match { + case Cons(head, tail) => Cons(f(head), map(tail, f)) + case Nil() => Nil() + } + + def map_lemma[A,B,C](list: List[A], f: A => B, g: B => C): Boolean = { + map(list, (x: A) => g(f(x))) == map(map(list, f), g) + } + + def map_lemma_induct[D,E,F](list: List[D], f: D => E, g: E => F): Boolean = { + list match { + case Nil() => map_lemma(list, f, g) + case Cons(head, tail) => map_lemma(list, f, g) && map_lemma_induct(tail, f, g) + } + }.holds + +} + +// vim: set ts=4 sw=4 et: diff --git a/src/test/resources/regression/verification/purescala/valid/Monads2.scala b/src/test/resources/regression/verification/purescala/valid/Monads2.scala new file mode 100644 index 0000000000000000000000000000000000000000..341c60379522981ce625818d08d316a079fa1fa8 --- /dev/null +++ b/src/test/resources/regression/verification/purescala/valid/Monads2.scala @@ -0,0 +1,35 @@ +import leon.lang._ + +object Monads2 { + abstract class Option[T] + case class Some[T](t: T) extends Option[T] + case class None[T]() extends Option[T] + + def flatMap[T,U](opt: Option[T], f: T => Option[U]): Option[U] = opt match { + case Some(x) => f(x) + case None() => None() + } + + def associative_law[T,U,V](opt: Option[T], f: T => Option[U], g: U => Option[V]): Boolean = { + flatMap(flatMap(opt, f), g) == flatMap(opt, (x: T) => flatMap(f(x), g)) + }.holds + + def left_unit_law[T,U](x: T, f: T => Option[U]): Boolean = { + flatMap(Some(x), f) == f(x) + }.holds + + def right_unit_law[T,U](opt: Option[T]): Boolean = { + flatMap(opt, (x: T) => Some(x)) == opt + }.holds + + /* + def associative_induct[T,U,V](opt: Option[T], f: T => Option[U], g: U => Option[V]): Boolean = { + opt match { + case Some(x) => associative(opt) + + } + } + */ +} + +// vim: set ts=4 sw=4 et: diff --git a/src/test/resources/regression/verification/purescala/valid/PositiveMap.scala b/src/test/resources/regression/verification/purescala/valid/PositiveMap.scala new file mode 100644 index 0000000000000000000000000000000000000000..7ee05ec9cffdc9789bfe4372586c20a882a84d37 --- /dev/null +++ b/src/test/resources/regression/verification/purescala/valid/PositiveMap.scala @@ -0,0 +1,36 @@ +import leon.lang._ + +object PositiveMap { + + abstract class List + case class Cons(head: Int, tail: List) extends List + case class Nil() extends List + + def positive(list: List): Boolean = list match { + case Cons(head, tail) => if (head < 0) false else positive(tail) + case Nil() => true + } + + def positiveMap_passing_1(f: (Int) => Int, list: List): List = { + list match { + case Cons(head, tail) => + val fh = f(head) + val nh = if (fh <= 0) -fh else fh + Cons(nh, positiveMap_passing_1(f, tail)) + case Nil() => Nil() + } + } ensuring { res => positive(res) } + + def positiveMap_passing_2(f: (Int) => Int, list: List): List = { + list match { + case Cons(head, tail) => + val fh = f(head) + val nh = if (fh < 0) -fh else fh + Cons(nh, positiveMap_passing_2(f, tail)) + case Nil() => Nil() + } + } ensuring { res => positive(res) } + +} + +// vim: set ts=4 sw=4 et: diff --git a/src/test/scala/leon/test/solvers/UnrollingSolverTests.scala b/src/test/scala/leon/test/solvers/UnrollingSolverTests.scala new file mode 100644 index 0000000000000000000000000000000000000000..94ed51f643d5991cd2db73dad9cfddf37a1fb751 --- /dev/null +++ b/src/test/scala/leon/test/solvers/UnrollingSolverTests.scala @@ -0,0 +1,47 @@ +package leon.test.solvers + +import leon._ +import leon.test._ +import leon.purescala.Trees._ +import leon.purescala.TypeTrees._ +import leon.purescala.Common._ +import leon.purescala.Definitions._ +import leon.solvers._ +import leon.solvers.z3._ +import leon.solvers.combinators._ + +class UnrollingSolverTests extends LeonTestSuite { + + private val fx : Identifier = FreshIdentifier("x").setType(Int32Type) + private val fres : Identifier = FreshIdentifier("res").setType(Int32Type) + private val fDef : FunDef = new FunDef(FreshIdentifier("f"), Nil, Int32Type, ValDef(fx, Int32Type) :: Nil, DefType.MethodDef) + fDef.body = Some(IfExpr(GreaterThan(Variable(fx), IntLiteral(0)), + Plus(Variable(fx), FunctionInvocation(fDef.typed, Seq(Minus(Variable(fx), IntLiteral(1))))), + IntLiteral(1) + )) + fDef.postcondition = Some(fres -> GreaterThan(Variable(fres), IntLiteral(0))) + + private val program = Program( + FreshIdentifier("Minimal"), + List(UnitDef( + FreshIdentifier("Minimal"), + List(ModuleDef(FreshIdentifier("Minimal"), Seq(fDef), false)) + )) + ) + + private val sf = SolverFactory(() => new UnrollingSolver(testContext, program, new UninterpretedZ3Solver(testContext, program))) + private def check(expr: Expr, expected: Option[Boolean], msg: String) : Unit = { + test(msg) { + val solver = sf.getNewSolver + solver.assertCnstr(expr) + assert(solver.check == expected) + solver.free + } + } + + check(BooleanLiteral(true), Some(true), "'true' should always be valid") + check(BooleanLiteral(false), Some(false), "'false' should never be valid") + + check(Not(GreaterThan(FunctionInvocation(fDef.typed, Seq(Variable(FreshIdentifier("toto").setType(Int32Type)))), IntLiteral(0))), + Some(false), "unrolling should enable recursive definition verification") +} diff --git a/src/test/scala/leon/test/solvers/z3/UninterpretedZ3SolverTests.scala b/src/test/scala/leon/test/solvers/z3/UninterpretedZ3SolverTests.scala index b56d3e0bad214d4e92a1b79a97a94d88d2df65c6..b2214f57304b2eca5ce241a09cd8c59ed1cda637 100644 --- a/src/test/scala/leon/test/solvers/z3/UninterpretedZ3SolverTests.scala +++ b/src/test/scala/leon/test/solvers/z3/UninterpretedZ3SolverTests.scala @@ -83,8 +83,12 @@ class UninterpretedZ3SolverTests extends LeonTestSuite { assertInvalid(solver, wrong2) // This is true, but that solver shouldn't know it. + // However, since the uninterpreted solver is a nice backend for the unrolling solver, + // it makes more sense to allow such formulas even if they are not completely known + /* private val unknown1 : Expr = Equals(f(x), Plus(x, IntLiteral(1))) assertUnknown(solver, unknown1) + */ assertValid(solver, Equals(g(x), g(x))) }