From d2d70bba9504102b00f6e51e3be1b89c28d28373 Mon Sep 17 00:00:00 2001 From: Philippe Suter <philippe.suter@gmail.com> Date: Sat, 13 Nov 2010 17:53:04 +0000 Subject: [PATCH] some random commit --- pldi2011-testcases/LambdaEval.scala | 133 ++++++++------ src/purescala/Z3ModelReconstruction.scala | 4 +- src/purescala/Z3Solver.scala | 92 +++++++--- .../z3plugins/instantiator/Instantiator.scala | 164 ++++++++++++++++-- testcases/RedBlackTree.scala | 7 +- 5 files changed, 296 insertions(+), 104 deletions(-) diff --git a/pldi2011-testcases/LambdaEval.scala b/pldi2011-testcases/LambdaEval.scala index afe2ae81c..e7f12bde2 100644 --- a/pldi2011-testcases/LambdaEval.scala +++ b/pldi2011-testcases/LambdaEval.scala @@ -23,74 +23,97 @@ object LambdaEval { case Snd(_) => false } - def okPair(p: StoreExprPair): Boolean = p match { + def okPair(p: StoreExprPairAbs): Boolean = p match { case StoreExprPair(_, res) => ok(res) } sealed abstract class List - case class Cons(head: BindingPair, tail: List) extends List + case class Cons(head: BindingPairAbs, tail: List) extends List case class Nil() extends List - sealed abstract class AbstractPair - case class BindingPair(key: Int, value: Expr) extends AbstractPair - case class StoreExprPair(store: List, expr: Expr) extends AbstractPair + sealed abstract class BindingPairAbs + case class BindingPair(key: Int, value: Expr) extends BindingPairAbs + sealed abstract class StoreExprPairAbs + case class StoreExprPair(store: List, expr: Expr) extends StoreExprPairAbs + + def storeElems(store: List) : Set[Int] = store match { + case Nil() => Set.empty[Int] + case Cons(BindingPair(k,_), xs) => Set(k) ++ storeElems(xs) + } + + def freeVars(expr: Expr) : Set[Int] = expr match { + case Const(_) => Set.empty[Int] + case Plus(l,r) => freeVars(l) ++ freeVars(r) + case Lam(x, bdy) => freeVars(bdy) -- Set(x) + case Pair(f,s) => freeVars(f) ++ freeVars(s) + case Var(n) => Set(n) + case App(l,r) => freeVars(l) ++ freeVars(r) + case Fst(e) => freeVars(e) + case Snd(e) => freeVars(e) + } // Find first element in list that has first component 'x' and return its // second component, analogous to List.assoc in OCaml - def find(x: Int, l: List): Expr = l match { - case Cons(i, is) => if (i.key == x) i.value else find(x, is) + def find(x: Int, l: List): Expr = { + require(storeElems(l).contains(x)) + l match { + case Cons(BindingPair(k,v), is) => if (k == x) v else find(x, is) + } } // Evaluator - def eval(store: List, expr: Expr): StoreExprPair = (expr match { - case Const(i) => StoreExprPair(store, Const(i)) - case Var(x) => StoreExprPair(store, find(x, store)) - case Plus(e1, e2) => - val i1 = eval(store, e1) match { - case StoreExprPair(_, Const(i)) => i - } - val i2 = eval(store, e2) match { - case StoreExprPair(_, Const(i)) => i - } - StoreExprPair(store, Const(i1 + i2)) - case App(e1, e2) => - val store1 = eval(store, e1) match { - case StoreExprPair(resS,_) => resS - } - val x = eval(store, e1) match { - case StoreExprPair(_, Lam(resX, _)) => resX - } - val e = eval(store, e1) match { - case StoreExprPair(_, Lam(_, resE)) => resE - } - /* - val StoreExprPair(store1, Lam(x, e)) = eval(store, e1) match { - case StoreExprPair(resS, Lam(resX, resE)) => StoreExprPair(resS, Lam(resX, resE)) - } - */ - val v2 = eval(store, e2) match { - case StoreExprPair(_, v) => v - } - eval(Cons(BindingPair(x, v2), store1), e) - case Lam(x, e) => StoreExprPair(store, Lam(x, e)) - case Pair(e1, e2) => - val v1 = eval(store, e1) match { - case StoreExprPair(_, v) => v - } - val v2 = eval(store, e2) match { - case StoreExprPair(_, v) => v - } - StoreExprPair(store, Pair(v1, v2)) - case Fst(e) => - eval(store, e) match { - case StoreExprPair(_, Pair(v1, _)) => StoreExprPair(store, v1) - } - case Snd(e) => - eval(store, e) match { - case StoreExprPair(_, Pair(_, v2)) => StoreExprPair(store, v2) - } - }) ensuring(res => okPair(res)) + def eval(store: List, expr: Expr): StoreExprPair = { + require(freeVars(expr) subsetOf storeElems(store)) + expr match { + case Const(i) => StoreExprPair(store, Const(i)) + case Var(x) => StoreExprPair(store, find(x, store)) + case Plus(e1, e2) => + val i1 = eval(store, e1) match { + case StoreExprPair(_, Const(i)) => i + } + val i2 = eval(store, e2) match { + case StoreExprPair(_, Const(i)) => i + } + StoreExprPair(store, Const(i1 + i2)) + case App(e1, e2) => + val store1 = eval(store, e1) match { + case StoreExprPair(resS,_) => resS + } + val x = eval(store, e1) match { + case StoreExprPair(_, Lam(resX, _)) => resX + } + val e = eval(store, e1) match { + case StoreExprPair(_, Lam(_, resE)) => resE + } + /* + val StoreExprPair(store1, Lam(x, e)) = eval(store, e1) match { + case StoreExprPair(resS, Lam(resX, resE)) => StoreExprPair(resS, Lam(resX, resE)) + } + */ + val v2 = eval(store, e2) match { + case StoreExprPair(_, v) => v + } + eval(Cons(BindingPair(x, v2), store1), e) + case Lam(x, e) => StoreExprPair(store, Lam(x, e)) + case Pair(e1, e2) => + val v1 = eval(store, e1) match { + case StoreExprPair(_, v) => v + } + val v2 = eval(store, e2) match { + case StoreExprPair(_, v) => v + } + StoreExprPair(store, Pair(v1, v2)) + case Fst(e) => + eval(store, e) match { + case StoreExprPair(_, Pair(v1, _)) => StoreExprPair(store, v1) + } + case Snd(e) => + eval(store, e) match { + case StoreExprPair(_, Pair(_, v2)) => StoreExprPair(store, v2) + } + } + } ensuring(res => okPair(res)) /*ensuring(res => res match { case StoreExprPair(_, resExpr) => ok(resExpr) }) */ diff --git a/src/purescala/Z3ModelReconstruction.scala b/src/purescala/Z3ModelReconstruction.scala index 83ae2b1ee..6565549cd 100644 --- a/src/purescala/Z3ModelReconstruction.scala +++ b/src/purescala/Z3ModelReconstruction.scala @@ -17,8 +17,8 @@ trait Z3ModelReconstruction { val z3ID : Z3AST = exprToZ3Id(id.toVariable) expectedType match { - case BooleanType => model.evalAsBool(z3ID).map(BooleanLiteral(_)) - case Int32Type => model.evalAsInt(z3ID).map(IntLiteral(_)) + case BooleanType => model.evalAs[Boolean](z3ID).map(BooleanLiteral(_)) + case Int32Type => model.evalAs[Int](z3ID).map(IntLiteral(_)) case other => model.eval(z3ID) match { case None => None case Some(t) => softFromZ3Formula(t) diff --git a/src/purescala/Z3Solver.scala b/src/purescala/Z3Solver.scala index d372a45d0..bfadf1dc7 100644 --- a/src/purescala/Z3Solver.scala +++ b/src/purescala/Z3Solver.scala @@ -570,38 +570,80 @@ class Z3Solver(val reporter: Reporter) extends Solver(reporter) with Z3ModelReco class CantTranslateException(t: Z3AST) extends Exception("Can't translate from Z3 tree: " + t) def rec(t: Z3AST) : Expr = z3.getASTKind(t) match { - case Z3AppAST(_, args) if args.size == 0 && z3IdToExpr.isDefinedAt(t) => { - z3IdToExpr(t) - } - case Z3AppAST(decl, args) if isKnownDecl(decl) => { - val fd = functionDeclToDef(decl) - assert(fd.args.size == args.size) - FunctionInvocation(fd, args.map(rec(_))) - } - case Z3AppAST(decl, args) if args.size == 1 && reverseADTTesters.isDefinedAt(decl) => { - CaseClassInstanceOf(reverseADTTesters(decl), rec(args(0))) - } - case Z3AppAST(decl, args) if args.size == 1 && reverseADTFieldSelectors.isDefinedAt(decl) => { - val (ccd, fid) = reverseADTFieldSelectors(decl) - CaseClassSelector(ccd, rec(args(0)), fid) - } - case Z3AppAST(decl, args) if reverseADTConstructors.isDefinedAt(decl) => { - val ccd = reverseADTConstructors(decl) - assert(args.size == ccd.fields.size) - CaseClass(ccd, args.map(rec(_))) + case Z3AppAST(decl, args) => { + val argsSize = args.size + if(argsSize == 0 && z3IdToExpr.isDefinedAt(t)) { + z3IdToExpr(t) + } else if(isKnownDecl(decl)) { + val fd = functionDeclToDef(decl) + assert(fd.args.size == argsSize) + FunctionInvocation(fd, args.map(rec(_))) + } else if(argsSize == 1 && reverseADTTesters.isDefinedAt(decl)) { + CaseClassInstanceOf(reverseADTTesters(decl), rec(args(0))) + } else if(argsSize == 1 && reverseADTFieldSelectors.isDefinedAt(decl)) { + val (ccd, fid) = reverseADTFieldSelectors(decl) + CaseClassSelector(ccd, rec(args(0)), fid) + } else if(reverseADTConstructors.isDefinedAt(decl)) { + val ccd = reverseADTConstructors(decl) + assert(argsSize == ccd.fields.size) + CaseClass(ccd, args.map(rec(_))) + } else { + import Z3DeclKind._ + val rargs = args.map(rec(_)) + z3.getDeclKind(decl) match { + case OpTrue => BooleanLiteral(true) + case OpFalse => BooleanLiteral(false) + case OpEq => Equals(rargs(0), rargs(1)) + case OpITE => { + assert(argsSize == 3) + val r0 = rargs(0) + val r1 = rargs(1) + val r2 = rargs(2) + IfExpr(r0, r1, r2).setType(leastUpperBound(r1.getType, r2.getType)) + } + case OpAnd => And(rargs) + case OpOr => Or(rargs) + case OpIff => Iff(rargs(0), rargs(1)) + case OpXor => Not(Iff(rargs(0), rargs(1))) + case OpNot => Not(rargs(0)) + case OpImplies => Implies(rargs(0), rargs(1)) + case OpLE => LessEquals(rargs(0), rargs(1)) + case OpGE => GreaterEquals(rargs(0), rargs(1)) + case OpLT => LessThan(rargs(0), rargs(1)) + case OpGT => GreaterThan(rargs(0), rargs(1)) + case OpAdd => { + assert(argsSize == 2) + Plus(rargs(0), rargs(1)) + } + case OpSub => { + assert(argsSize == 2) + Minus(rargs(0), rargs(1)) + } + case OpUMinus => UMinus(rargs(0)) + case OpMul => { + assert(argsSize == 2) + Times(rargs(0), rargs(1)) + } + case other => { + System.err.println("Don't know what to do with this declKind : " + other) + throw new CantTranslateException(t) + } + } + } } + case Z3NumeralAST(Some(v)) => IntLiteral(v) case other @ _ => { - println("Don't know what this is " + other) + System.err.println("Don't know what this is " + other) if(useInstantiator) { instantiator.dumpFunctionMap } else { - println("REVERSE FUNCTION MAP:") - println(reverseFunctionMap.toSeq.mkString("\n")) + System.err.println("REVERSE FUNCTION MAP:") + System.err.println(reverseFunctionMap.toSeq.mkString("\n")) } - println("REVERSE CONS MAP:") - println(reverseADTConstructors.toSeq.mkString("\n")) - System.exit(-1) + System.err.println("REVERSE CONS MAP:") + System.err.println(reverseADTConstructors.toSeq.mkString("\n")) + // System.exit(-1) throw new CantTranslateException(t) } } diff --git a/src/purescala/z3plugins/instantiator/Instantiator.scala b/src/purescala/z3plugins/instantiator/Instantiator.scala index 2eecf5e87..9c84c857f 100644 --- a/src/purescala/z3plugins/instantiator/Instantiator.scala +++ b/src/purescala/z3plugins/instantiator/Instantiator.scala @@ -10,20 +10,22 @@ import purescala.Settings import purescala.Z3Solver +import scala.collection.mutable.{Map => MutableMap, Set => MutableSet} + class Instantiator(val z3Solver: Z3Solver) extends Z3Theory(z3Solver.z3, "Instantiator") { import z3Solver.{z3,program,typeToSort,fromZ3Formula,toZ3Formula} setCallbacks( // reduceApp = true, -// finalCheck = true, -// push = true, -// pop = true, + finalCheck = true, + push = true, + pop = true, newApp = true, newAssignment = true, newRelevant = true, // newEq = true, // newDiseq = true, -// reset = true, + reset = true, restart = true ) @@ -51,35 +53,39 @@ class Instantiator(val z3Solver: Z3Solver) extends Z3Theory(z3Solver.z3, "Instan reverseFunctionMap.getOrElse(decl, scala.Predef.error("No FunDef found for Z3 definition " + decl + " in Instantiator.")) } - override def newAssignment(ast: Z3AST, polarity: Boolean) : Unit = { + // The logic starts here. + private var stillToAssert : Set[(Int,Expr)] = Set.empty + + override def newAssignment(ast: Z3AST, polarity: Boolean) : Unit = safeBlockToAssertAxioms { } override def newApp(ast: Z3AST) : Unit = { + examineAndUnroll(ast) + } + override def newRelevant(ast: Z3AST) : Unit = { + examineAndUnroll(ast) } private var bodyInlined : Int = 0 - override def newRelevant(ast: Z3AST) : Unit = { + def examineAndUnroll(ast: Z3AST) : Unit = if(bodyInlined < Settings.unrollingLevel) { val aps = fromZ3Formula(ast) val fis = functionCallsOf(aps) println("As Purescala: " + aps) for(fi <- fis) { val FunctionInvocation(fd, args) = fi println("interesting function call : " + fi) - if(fd.hasPostcondition) { + if(bodyInlined < Settings.unrollingLevel && fd.hasPostcondition) { + bodyInlined += 1 val post = matchToIfThenElse(fd.postcondition.get) - // FIXME TODO we could use let identifiers here to speed things up a little bit... - // val resFresh = FreshIdentifier("resForPostOf" + fd.id.uniqueName, true).setType(fi.getType) - // val newLetIDs = fd.args.map(a => FreshIdentifier("argForPostOf" + fd.id.uniqueName, true).setType(a.tpe)).toList - // val substMap = Map[Expr,Expr]((fd.args.map(_.toVariable) zip newLetIDs.map(Variable(_))) : _*) + (ResultVariable() -> Variable(resFresh)) + val isSafe = functionCallsOf(post).isEmpty + val substMap = Map[Expr,Expr]((fd.args.map(_.toVariable) zip args) : _*) + (ResultVariable() -> fi) // println(substMap) val newBody = replace(substMap, post) println("I'm going to add this : " + newBody) - val newAxiom = toZ3Formula(z3, newBody).get - println("As Z3: " + newAxiom) - assertAxiom(newAxiom) + assertIfSafeOrDelay(newBody)//, isSafe) } if(bodyInlined < Settings.unrollingLevel && fd.hasBody) { @@ -87,13 +93,133 @@ class Instantiator(val z3Solver: Z3Solver) extends Z3Theory(z3Solver.z3, "Instan val body = matchToIfThenElse(fd.body.get) val substMap = Map[Expr,Expr]((fd.args.map(_.toVariable) zip args) : _*) val newBody = replace(substMap, body) - println("I'm going to add this : " + newBody) - val newAxiom = z3.mkEq(toZ3Formula(z3, fi).get, toZ3Formula(z3, newBody).get) - println("As Z3: " + newAxiom) - assertAxiom(newAxiom) + val theEquality = Equals(fi, newBody) + println("I'm going to add this : " + theEquality) + assertIfSafeOrDelay(theEquality) + } + } + } + + override def finalCheck : Boolean = safeBlockToAssertAxioms { + if(stillToAssert.isEmpty) { + true + } else { + for((lvl,ast) <- stillToAssert) { + assertAxiomASAP(ast, lvl) + // assertPermanently(ast) } + stillToAssert = Set.empty + true } } - override def restart : Unit = { } + // This is concerned with how many new function calls the assertion is going + // to introduce. + private def assertIfSafeOrDelay(ast: Expr, isSafe: Boolean = false) : Unit = { + stillToAssert += ((pushLevel, ast)) + } + + // Assert as soon as possible and keep asserting as long as level is >= lvl. + private def assertAxiomASAP(expr: Expr, lvl: Int) : Unit = assertAxiomASAP(toZ3Formula(z3, expr).get, lvl) + private def assertAxiomASAP(ast: Z3AST, lvl: Int) : Unit = { + if(canAssertAxiom) { + assertAxiomNow(ast) + if(lvl < pushLevel) { + // Remember to reassert when we backtrack. + if(pushLevel > 0) { + rememberToReassertAt(pushLevel - 1, lvl, ast) + } + } + } else { + toAssertASAP = toAssertASAP + ((lvl, ast)) + } + } + + private def assertAxiomFrom(ast: Z3AST, level: Int) : Unit = { + toAssertASAP = toAssertASAP + ((level, ast)) + } + +// private def assertPermanently(expr: Expr) : Unit = { +// val asZ3 = toZ3Formula(z3, expr).get +// +// if(canAssertAxiom) { +// assertAxiomNow(asZ3) +// } else { +// toAssertASAP = toAssertASAP + ((0, asZ3)) +// } +// } + + private def assertAxiomNow(ast: Z3AST) : Unit = { + if(!canAssertAxiom) + println("WARNING ! ASSERTING AXIOM WHEN NOT SAFE !") + + println("Now asserting : " + ast) + assertAxiom(ast) + } + + override def push : Unit = { + pushLevel += 1 + } + + override def pop : Unit = { + pushLevel -= 1 + + if(toReassertAt.isDefinedAt(pushLevel)) { + for((lvl,ax) <- toReassertAt(pushLevel)) { + assertAxiomFrom(ax, lvl) + } + toReassertAt(pushLevel).clear + } + + assert(pushLevel >= 0) + } + + override def restart : Unit = { + pushLevel = 0 + } + + override def reset : Unit = reinit + + // Below is all the machinery to be able to assert axioms in safe states. + + private var pushLevel : Int = _ + private var canAssertAxiom : Boolean = _ + private var toAssertASAP : Set[(Int,Z3AST)] = _ + private val toReassertAt : MutableMap[Int,MutableSet[(Int,Z3AST)]] = MutableMap.empty + + private def rememberToReassertAt(lvl: Int, axLvl: Int, ax: Z3AST) : Unit = { + if(toReassertAt.isDefinedAt(lvl)) { + toReassertAt(lvl) += ((axLvl, ax)) + } else { + toReassertAt(lvl) = MutableSet((axLvl, ax)) + } + } + + reinit + private def reinit : Unit = { + pushLevel = 0 + canAssertAxiom = false + toAssertASAP = Set.empty + stillToAssert = Set.empty + } + + private def safeBlockToAssertAxioms[A](block: => A): A = { + canAssertAxiom = true + + if (toAssertASAP.nonEmpty) { + for ((lvl, ax) <- toAssertASAP) { + if(lvl <= pushLevel) { + assertAxiomNow(ax) + if(lvl < pushLevel && pushLevel > 0) { + rememberToReassertAt(pushLevel - 1, lvl, ax) + } + } + } + toAssertASAP = Set.empty + } + + val result = block + canAssertAxiom = false + result + } } diff --git a/testcases/RedBlackTree.scala b/testcases/RedBlackTree.scala index 9cb55c574..04ccb08fa 100644 --- a/testcases/RedBlackTree.scala +++ b/testcases/RedBlackTree.scala @@ -26,9 +26,10 @@ object RedBlackTree { if (x < y) balance(c, ins(x, a), y, b) else if (x == y) Node(c,a,y,b) else balance(c,a,y,ins(x, b)) - }) ensuring (res => - content(res) == content(t) ++ Set(x) && - size(t) <= size(res) && size(res) < size(t) + 2) + }) ensuring (res => ( + content(res) == content(t) ++ Set(x) +// && size(t) <= size(res) && size(res) < size(t) + 2) + )) def add(x: Int, t: Tree): Tree = { makeBlack(ins(x, t)) -- GitLab