diff --git a/lib/z3.jar b/lib/z3.jar index d99a37fc51a48beb3079d05fd603ae91706c4234..6416b0828d4ee9c354c0ad8d0cbf8b0586e36680 100644 Binary files a/lib/z3.jar and b/lib/z3.jar differ diff --git a/src/purescala/Definitions.scala b/src/purescala/Definitions.scala index ee59e388ad6ddb8981fa9c6d495f7bbdceb1b3dd..38196e50c1cb5afa8f07bfed087a8ac5237297d0 100644 --- a/src/purescala/Definitions.scala +++ b/src/purescala/Definitions.scala @@ -276,6 +276,7 @@ object Definitions { } @serializable class FunDef(val id: Identifier, val returnType: TypeTree, val args: VarDecls) extends Definition with ScalacPositional { var body: Option[Expr] = None + def implementation : Option[Expr] = body var precondition: Option[Expr] = None var postcondition: Option[Expr] = None @@ -300,4 +301,68 @@ object Definitions { def isPrivate : Boolean = annots.contains("private") } + + object Catamorphism { + // If a function is a catamorphism, this deconstructs it into the cases. Eg: + // def size(l : List) : Int = ... + // should return: + // List, + // Seq( + // (Nil(), 0) + // (Cons(x, xs), 1 + size(xs))) + // ...where x and xs are fresh (and could be unused in the expr) + import scala.collection.mutable.{Map=>MutableMap} + type CataRepr = (AbstractClassDef,Seq[(CaseClass,Expr)]) + private val unapplyCache : MutableMap[FunDef,CataRepr] = MutableMap.empty + + def unapply(funDef : FunDef) : Option[CataRepr] = if( + funDef == null || + funDef.args.size != 1 || + funDef.hasPrecondition || + !funDef.hasImplementation || + (funDef.hasPostcondition && functionCallsOf(funDef.postcondition.get) != Set.empty) + ) { + None + } else if(unapplyCache.isDefinedAt(funDef)) { + Some(unapplyCache(funDef)) + } else { + var moreConditions = true + val argVar = funDef.args(0).toVariable + val argVarType = argVar.getType + val body = funDef.body.get + val iteized = matchToIfThenElse(body) + val invocations = functionCallsOf(iteized) + moreConditions = moreConditions && invocations.forall(_ match { + case FunctionInvocation(fd, Seq(CaseClassSelector(_, e, _))) if fd == funDef && e == argVar => true + case _ => false + }) + moreConditions = moreConditions && argVarType.isInstanceOf[AbstractClassType] + var spmList : Seq[(CaseClassDef,Identifier,Seq[Identifier],Expr)] = Seq.empty + moreConditions = moreConditions && (body match { + case SimplePatternMatching(scrut, _, s) if scrut == argVar => spmList = s; true + case _ => false + }) + + val patternSeq : Seq[(CaseClass,Expr)] = if(moreConditions) { + spmList.map(tuple => { + val (ccd, id, ids, ex) = tuple + val ex2 = matchToIfThenElse(ex) + if(!(variablesOf(ex2) -- ids).isEmpty) { + moreConditions = false + } + (CaseClass(ccd, ids.map(Variable(_))), ex2) + }) + } else { + Seq.empty + } + + if(moreConditions) { + val finalResult = (argVarType.asInstanceOf[AbstractClassType].classDef, patternSeq) + unapplyCache(funDef) = finalResult + Some(finalResult) + } else { + None + } + } + } } diff --git a/src/purescala/FairZ3Solver.scala b/src/purescala/FairZ3Solver.scala index dd81fa3ea03917a8a308f9b92cfa30f4af06e86a..6371d0a309372832a1270b958ffe5e1e43f263ea 100644 --- a/src/purescala/FairZ3Solver.scala +++ b/src/purescala/FairZ3Solver.scala @@ -14,7 +14,7 @@ class FairZ3Solver(val reporter: Reporter) extends Solver(reporter) with Abstrac // have to comment this to use the solver for constraint solving... // assert(Settings.useFairInstantiator) - private final val UNKNOWNASSAT : Boolean = false + private final val UNKNOWNASSAT : Boolean = !Settings.noForallAxioms val description = "Fair Z3 Solver" override val shortDescription = "Z3-f" @@ -49,7 +49,7 @@ class FairZ3Solver(val reporter: Reporter) extends Solver(reporter) with Abstrac z3.delete } z3 = new Z3Context(z3cfg) - //z3.traceToStdout + // z3.traceToStdout exprToZ3Id = Map.empty z3IdToExpr = Map.empty @@ -185,6 +185,7 @@ class FairZ3Solver(val reporter: Reporter) extends Solver(reporter) with Abstrac private var functionMap: Map[FunDef, Z3FuncDecl] = Map.empty private var reverseFunctionMap: Map[Z3FuncDecl, FunDef] = Map.empty + private var axiomatizedFunctions : Set[FunDef] = Set.empty def prepareFunctions: Unit = { functionMap = Map.empty @@ -197,6 +198,61 @@ class FairZ3Solver(val reporter: Reporter) extends Solver(reporter) with Abstrac functionMap = functionMap + (funDef -> z3Decl) reverseFunctionMap = reverseFunctionMap + (z3Decl -> funDef) } + + if(!Settings.noForallAxioms) { + prepareAxioms + } + } + + private def prepareAxioms : Unit = { + assert(!Settings.noForallAxioms) + program.definedFunctions.foreach(_ match { + case fd @ Catamorphism(acd, cases) => { + assert(!fd.hasPrecondition && fd.hasImplementation) + for(cse <- cases) { + val (cc @ CaseClass(ccd, args), expr) = cse + assert(args.forall(_.isInstanceOf[Variable])) + val argsAsIDs = args.map(_.asInstanceOf[Variable].id) + assert(variablesOf(expr) -- argsAsIDs.toSet == Set.empty) + val axiom : Z3AST = if(args.isEmpty) { + val eq = Equals(FunctionInvocation(fd, Seq(cc)), expr) + toZ3Formula(z3, eq).get + } else { + val z3ArgSorts = argsAsIDs.map(i => typeToSort(i.getType)) + val boundVars = z3ArgSorts.zipWithIndex.map(p => z3.mkBound(p._2, p._1)) + val map : Map[Identifier,Z3AST] = (argsAsIDs zip boundVars).toMap + val eq = Equals(FunctionInvocation(fd, Seq(cc)), expr) + val z3IzedEq = toZ3Formula(z3, eq, map).get + val z3IzedCC = toZ3Formula(z3, cc, map).get + val pattern = z3.mkPattern(functionDefToDecl(fd)(z3IzedCC)) + val nameTypePairs = z3ArgSorts.map(s => (z3.mkFreshIntSymbol, s)) + z3.mkForAll(0, List(pattern), nameTypePairs, z3IzedEq) + } + //println("I'll assert now an axiom: " + axiom) + //println("Case axiom:") + //println(axiom) + z3.assertCnstr(axiom) + } + + if(fd.hasPostcondition) { + // we know it doesn't contain any function invocation + val cleaned = matchToIfThenElse(expandLets(fd.postcondition.get)) + val argSort = typeToSort(fd.args(0).getType) + val bound = z3.mkBound(0, argSort) + val subst = replace(Map(ResultVariable() -> FunctionInvocation(fd, Seq(fd.args(0).toVariable))), cleaned) + val z3IzedPost = toZ3Formula(z3, subst, Map(fd.args(0).id -> bound)).get + val pattern = z3.mkPattern(functionDefToDecl(fd)(bound)) + val nameTypePairs = Seq((z3.mkFreshIntSymbol, argSort)) + val postAxiom = z3.mkForAll(0, List(pattern), nameTypePairs, z3IzedPost) + //println("Post axiom:") + //println(postAxiom) + z3.assertCnstr(postAxiom) + } + + axiomatizedFunctions += fd + } + case _ => ; + }) } // assumes prepareSorts has been called.... @@ -306,13 +362,17 @@ class FairZ3Solver(val reporter: Reporter) extends Solver(reporter) with Abstrac // println("Bunch of blocking bools: " + sib.map(_.toString).mkString(", ")) // println("Basis : " + basis) - z3.assertCnstr(toZ3Formula(z3, basis).get) + val bb = toZ3Formula(z3, basis).get + // println("Base : " + bb) + z3.assertCnstr(bb) // println(toZ3Formula(z3, basis).get) // for(clause <- clauses) { // println("we're getting a new clause " + clause) // z3.assertCnstr(toZ3Formula(z3, clause).get) // } - z3.assertCnstr(toZ3Formula(z3, And(clauses)).get) + val cc = toZ3Formula(z3, And(clauses)).get + // println("CC : " + cc) + z3.assertCnstr(cc) blockingSet ++= Set(guards.map(p => if(p._2) Not(Variable(p._1)) else Variable(p._1)) : _*) @@ -594,10 +654,10 @@ class FairZ3Solver(val reporter: Reporter) extends Solver(reporter) with Abstrac // exprToZ3Id = Map.empty // z3IdToExpr = Map.empty - for((k, v) <- initialMap) { - exprToZ3Id += (k.toVariable -> v) - z3IdToExpr += (v -> k.toVariable) - } + // for((k, v) <- initialMap) { + // exprToZ3Id += (k.toVariable -> v) + // z3IdToExpr += (v -> k.toVariable) + // } def rec(ex: Expr): Z3AST = { //println("Stacking up call for:") @@ -630,6 +690,10 @@ class FairZ3Solver(val reporter: Reporter) extends Solver(reporter) with Abstrac newAST } } + + case ite @ IfExpr(c, t, e) => { + z3.mkITE(rec(c), rec(t), rec(e)) + } // case ite @ IfExpr(c, t, e) => { // val switch = z3.mkFreshConst("path", z3.mkBoolSort) // val placeHolder = z3.mkFreshConst("ite", typeToSort(ite.getType)) @@ -821,9 +885,15 @@ class FairZ3Solver(val reporter: Reporter) extends Solver(reporter) with Abstrac // Returns whether some invocations were actually blocked in the end. private def registerBlocked(blockingAtom : Identifier, polarity : Boolean, invocations : Set[FunctionInvocation]) : Boolean = { - // TODO - // val filtered = invocations -- "those who are axiomatized" - val filtered = invocations + val filtered = invocations.filter(i => { + val FunctionInvocation(fd, _) = i + if(axiomatizedFunctions(fd)) { + reporter.info("I'm not registering " + i + " as blocked because it's axiomatized.") + false + } else { + true + } + }) val pair = (blockingAtom, polarity) val alreadyBlocked = blocked.get(pair) @@ -843,11 +913,6 @@ class FairZ3Solver(val reporter: Reporter) extends Solver(reporter) with Abstrac everythingEverUnrolled += functionInvocation } - def closeUnrollings2(formula : Expr) : (Expr, Seq[Expr], Seq[(Identifier,Boolean)]) = { - - scala.Predef.error("wtf") - } - def closeUnrollings(formula : Expr) : (Expr, Seq[Expr], Seq[(Identifier,Boolean)]) = { var (basis, clauses, ite2Bools) = clausifyITE(formula) @@ -856,19 +921,17 @@ class FairZ3Solver(val reporter: Reporter) extends Solver(reporter) with Abstrac var treatedClauses : List[Expr] = Nil var blockers : List[(Identifier,Boolean)] = Nil - stillToUnroll = stillToUnroll ++ topLevelFunctionCallsOf(basis) + stillToUnroll = stillToUnroll ++ topLevelFunctionCallsOf(basis, axiomatizedFunctions) do { // We go through each clause and figure out what must be enrolled and // what must be blocked. We register everything. for(clause <- clauses) { clause match { case Iff(Variable(_), cond) => { - stillToUnroll = stillToUnroll ++ topLevelFunctionCallsOf(cond) + stillToUnroll = stillToUnroll ++ topLevelFunctionCallsOf(cond, axiomatizedFunctions) } - // TODO : sort out the functions that are not recursive and unroll - // them in any case case Implies(v @ Variable(id), then) => { - val calls = topLevelFunctionCallsOf(then) + val calls = topLevelFunctionCallsOf(then, axiomatizedFunctions) if(!calls.isEmpty) { assert(!blocked.isDefinedAt((id,true))) if(registerBlocked(id, true, calls)) //blocked((id,true)) = calls @@ -876,7 +939,7 @@ class FairZ3Solver(val reporter: Reporter) extends Solver(reporter) with Abstrac } } case Implies(Not(v @ Variable(id)), elze) => { - val calls = topLevelFunctionCallsOf(elze) + val calls = topLevelFunctionCallsOf(elze, axiomatizedFunctions) if(!calls.isEmpty) { assert(!blocked.isDefinedAt((id,false))) if(registerBlocked(id, false, calls)) //blocked((id,false)) = calls @@ -902,7 +965,7 @@ class FairZ3Solver(val reporter: Reporter) extends Solver(reporter) with Abstrac for(formula <- unrolled) { val (basis2, clauses2, _) = clausifyITE(formula) - stillToUnroll = stillToUnroll ++ topLevelFunctionCallsOf(basis2) + stillToUnroll = stillToUnroll ++ topLevelFunctionCallsOf(basis2, axiomatizedFunctions) clauses = clauses2 ::: clauses treatedClauses = basis2 :: treatedClauses } diff --git a/src/purescala/Trees.scala b/src/purescala/Trees.scala index 2ac55013ab184cf9bc025005b06889f3443ac348..ea891f76400fc2ed49e6ad353c153516ab3c33b1 100644 --- a/src/purescala/Trees.scala +++ b/src/purescala/Trees.scala @@ -681,14 +681,14 @@ object Trees { treeCatamorphism(convert, combine, compute, expr) } - def topLevelFunctionCallsOf(expr: Expr) : Set[FunctionInvocation] = { + def topLevelFunctionCallsOf(expr: Expr, barring : Set[FunDef] = Set.empty) : Set[FunctionInvocation] = { def convert(t: Expr) : Set[FunctionInvocation] = t match { - case f @ FunctionInvocation(_, _) => Set(f) + case f @ FunctionInvocation(fd, _) if(!barring(fd)) => Set(f) case _ => Set.empty } def combine(s1: Set[FunctionInvocation], s2: Set[FunctionInvocation]) = s1 ++ s2 def compute(t: Expr, s: Set[FunctionInvocation]) = t match { - case f @ FunctionInvocation(_, _) => Set(f) // ++ s that's the difference with the one below + case f @ FunctionInvocation(fd, _) if(!barring(fd)) => Set(f) // ++ s that's the difference with the one below case _ => s } treeCatamorphism(convert, combine, compute, expr) diff --git a/testcases/sas2011-testcases/RedBlackTree.scala b/testcases/sas2011-testcases/RedBlackTree.scala index 3a1d768b013ef495b69acb95582d61581344bfe1..da5dd22c0c71649e0dd6ec01a7ee30fa7490ef7f 100644 --- a/testcases/sas2011-testcases/RedBlackTree.scala +++ b/testcases/sas2011-testcases/RedBlackTree.scala @@ -20,10 +20,10 @@ object RedBlackTree { case Node(_, l, v, r) => content(l) ++ Set(v) ++ content(r) } - def size(t: Tree) : Int = t match { + def size(t: Tree) : Int = (t match { case Empty() => 0 case Node(_, l, v, r) => size(l) + 1 + size(r) - } + }) ensuring(_ >= 0) /* We consider leaves to be black by definition */ def isBlack(t: Tree) : Boolean = t match {