diff --git a/lib/z3.jar b/lib/z3.jar index 56df8fc4590189cca80dd1fca9e0c3307d005df0..681895c92887696712b2e358f2dd4f8b700d923e 100644 Binary files a/lib/z3.jar and b/lib/z3.jar differ diff --git a/src/purescala/Analysis.scala b/src/purescala/Analysis.scala index a81e10545445a1192a1125061a3258d0d3715088..205de861b0473c43fd72fa71507a61011eceac1d 100644 --- a/src/purescala/Analysis.scala +++ b/src/purescala/Analysis.scala @@ -113,9 +113,9 @@ class Analysis(val program: Program) { Implies(prec.get, bodyAndPost) } + import Analysis._ val (newExpr1, sideExprs1) = rewriteSimplePatternMatching(newExpr) - - val (newExpr2, sideExprs2) = inlineFunctionsAndContracts(newExpr1) + val (newExpr2, sideExprs2) = inlineFunctionsAndContracts(program, newExpr1) if(sideExprs1.isEmpty && sideExprs2.isEmpty) { newExpr2 @@ -125,7 +125,10 @@ class Analysis(val program: Program) { } } - def inlineFunctionsAndContracts(expr: Expr) : (Expr, Seq[Expr]) = { +} + +object Analysis { + def inlineFunctionsAndContracts(program: Program, expr: Expr) : (Expr, Seq[Expr]) = { var extras : List[Expr] = Nil val isFunCall: Function[Expr,Boolean] = _.isInstanceOf[FunctionInvocation] @@ -215,7 +218,14 @@ class Analysis(val program: Program) { } } } - - (searchAndApply(isPMExpr, rewritePM, expr), extras.reverse) + + // this gets us "extras", but we will still need to clean these up. + val cleanerTree = searchAndApply(isPMExpr, rewritePM, expr) + val theExtras = extras.reverse + val onExtras: Seq[(Expr,Seq[Expr])] = theExtras.map(rewriteSimplePatternMatching(_)) + // the "moreExtras" should be cleaned up due to the recursive call.. + val (rewrittenExtras, moreExtras) = onExtras.unzip + + (cleanerTree, rewrittenExtras ++ moreExtras.flatten) } } diff --git a/src/purescala/Trees.scala b/src/purescala/Trees.scala index 8cef09757969a9e9a2615f1ea3b8c7d0a3eb2d8b..8593de5f03f90bcec664a1a380ac941cb5cfc9cc 100644 --- a/src/purescala/Trees.scala +++ b/src/purescala/Trees.scala @@ -20,7 +20,9 @@ object Trees { } /* Control flow */ - case class FunctionInvocation(funDef: FunDef, args: Seq[Expr]) extends Expr + case class FunctionInvocation(funDef: FunDef, args: Seq[Expr]) extends Expr with FixedType { + val fixedType = funDef.returnType + } case class IfExpr(cond: Expr, then: Expr, elze: Expr) extends Expr case class MatchExpr(scrutinee: Expr, cases: Seq[MatchCase]) extends Expr diff --git a/src/purescala/Z3Solver.scala b/src/purescala/Z3Solver.scala index 5414f6a3beca5e6e4acf35280cf08e702d6e6c15..52b9d91ad8ccbfc3202b9e334a922872514e2db6 100644 --- a/src/purescala/Z3Solver.scala +++ b/src/purescala/Z3Solver.scala @@ -39,6 +39,16 @@ class Z3Solver(reporter: Reporter) extends Solver(reporter) { //println(prog.transitiveCallGraph.map(p => (p._1.id.name, p._2.id.name).toString)) } + private object nextIntForSymbol { + private var counter = 0 + + def apply() : Int = { + val res = counter + counter = counter + 1 + res + } + } + private var intSort : Z3Sort = null private var boolSort : Z3Sort = null private var setSorts : Map[TypeTree,Z3Sort] = Map.empty @@ -85,10 +95,10 @@ class Z3Solver(reporter: Reporter) extends Solver(reporter) { root match { case c: CaseClassDef => { // we create a recursive type with exactly one constructor - (c.id.name, List(c.id.name), List(c.fields.map(f => (f.id.name, typeToSortRef(f.tpe))))) + (c.id.uniqueName, List(c.id.uniqueName), List(c.fields.map(f => (f.id.uniqueName, typeToSortRef(f.tpe))))) } case a: AbstractClassDef => { - (a.id.name, childrenList.map(ccd => ccd.id.name), childrenList.map(ccd => ccd.fields.map(f => (f.id.name, typeToSortRef(f.tpe))))) + (a.id.uniqueName, childrenList.map(ccd => ccd.id.uniqueName), childrenList.map(ccd => ccd.fields.map(f => (f.id.uniqueName, typeToSortRef(f.tpe))))) } } }) @@ -122,10 +132,42 @@ class Z3Solver(reporter: Reporter) extends Solver(reporter) { private var functionDefToDef : Map[FunDef,Z3FuncDecl] = Map.empty def prepareFunctions : Unit = { - for(funDef <- program.definedFunctions) /* if (program.isRecursive(funDef)) */ { + for(funDef <- program.definedFunctions) { val sortSeq = funDef.args.map(vd => typeToSort(vd.tpe).get) - val newSym = z3.mkStringSymbol(funDef.id.uniqueName) - functionDefToDef = functionDefToDef + (funDef -> z3.mkFuncDecl(newSym, sortSeq, typeToSort(funDef.returnType).get)) + functionDefToDef = functionDefToDef + (funDef -> z3.mkFreshFuncDecl(funDef.id.name, sortSeq, typeToSort(funDef.returnType).get)) + } + + // universally quantifies all functions ! + for(funDef <- program.definedFunctions) if(funDef.hasImplementation && funDef.args.size > 0) { + val argSorts: Seq[Z3Sort] = funDef.args.map(vd => typeToSort(vd.getType).get) + val boundVars = argSorts.zipWithIndex.map(p => z3.mkBound(p._2, p._1)) + val pattern: Z3Pattern = z3.mkPattern(functionDefToDef(funDef)(boundVars: _*)) + val nameTypePairs = argSorts.map(s => (z3.mkIntSymbol(nextIntForSymbol()), s)) + val fOfX: Expr = FunctionInvocation(funDef, funDef.args.map(_.toVariable)) + val toConvert: Expr = if(funDef.hasPrecondition) { + Implies(funDef.precondition.get, Equals(fOfX, funDef.body.get)) + } else { + Equals(fOfX, funDef.body.get) + } + val (newExpr1, sideExprs1) = Analysis.rewriteSimplePatternMatching(toConvert) + val (newExpr2, sideExprs2) = Analysis.inlineFunctionsAndContracts(program, newExpr1) + val finalToConvert = if(sideExprs1.isEmpty && sideExprs2.isEmpty) { + newExpr2 + } else { + Implies(And(sideExprs1 ++ sideExprs2), newExpr2) + } + val initialMap: Map[Identifier,Z3AST] = Map((funDef.args.map(_.id) zip boundVars):_*) + toZ3Formula(z3, finalToConvert, initialMap) match { + case Some(axiomTree) => { + val quantifiedAxiom = z3.mkForAll(0, List(pattern), nameTypePairs, axiomTree) + //z3.printAST(quantifiedAxiom) + z3.assertCnstr(quantifiedAxiom) + } + case None => { + reporter.warning("Could not generate forall axiom for " + funDef.id.name) + reporter.warning(finalToConvert) + } + } } } @@ -183,7 +225,7 @@ class Z3Solver(reporter: Reporter) extends Solver(reporter) { } case (Some(false),_) => Some(true) case (None,_) => { - reporter.error("Z3 couldn't run properly or does not know the answer :(") + reporter.warning("Z3 doesn't know because: " + z3.getSearchFailure.message) None } }) @@ -194,12 +236,10 @@ class Z3Solver(reporter: Reporter) extends Solver(reporter) { result } - private def toZ3Formula(z3: Z3Context, expr: Expr) : Option[Z3AST] = { + private def toZ3Formula(z3: Z3Context, expr: Expr, initialMap: Map[Identifier,Z3AST] = Map.empty) : Option[Z3AST] = { class CantTranslateException extends Exception - // because we create identifiers the first time we see them, this is - // convenient. - var z3Vars: Map[Identifier,Z3AST] = Map.empty + var z3Vars: Map[Identifier,Z3AST] = initialMap def rec(ex: Expr) : Z3AST = ex match { case Let(i,e,b) => { @@ -211,7 +251,7 @@ class Z3Solver(reporter: Reporter) extends Solver(reporter) { case None => { val newAST = typeToSort(v.getType) match { case Some(s) => { - z3.mkConst(z3.mkStringSymbol(id.uniqueName), s) + z3.mkFreshConst(id.name, s) } case None => { reporter.warning("Unsupported type in Z3 transformation: " + v.getType) diff --git a/testcases/BinarySearchTree.scala b/testcases/BinarySearchTree.scala index a01ccc3f6587816f2034a292e773ddadbb8a77ac..b2e8e773126423860d7614f841adae275b699204 100644 --- a/testcases/BinarySearchTree.scala +++ b/testcases/BinarySearchTree.scala @@ -16,7 +16,7 @@ object BinarySearchTree { } else { n } - }) ensuring(_ != Leaf()) //ensuring(result => contents(result) != Set.empty[Int]) + }) ensuring(contents(_) != Set.empty[Int]) def contains(tree: Tree, value: Int) : Boolean = tree match { case Leaf() => false