diff --git a/src/funcheck/FunCheckPlugin.scala b/src/funcheck/FunCheckPlugin.scala index d038a8f51a614113f7b2c9aef37a02ad74efb62d..0b998ad30a988b0723564923d7dcea89e1fae436 100644 --- a/src/funcheck/FunCheckPlugin.scala +++ b/src/funcheck/FunCheckPlugin.scala @@ -24,6 +24,7 @@ class FunCheckPlugin(val global: Global) extends Plugin { " -P:funcheck:nodefaults Runs only the analyses provided by the extensions" + "\n" + " -P:funcheck:functions=fun1:... Only generates verification conditions for the specified functions" + "\n" + " -P:funcheck:unrolling=[0,1,2] Unrolling depth for recursive functions" + "\n" + + " -P:funcheck:noaxioms Don't generate forall axioms for recursive functions" + "\n" + " -P:funcheck:tolerant Silently extracts non-pure function bodies as ''unknown''" + "\n" + " -P:funcheck:quiet No info and warning messages from the extensions" ) @@ -39,6 +40,7 @@ class FunCheckPlugin(val global: Global) extends Plugin { case "tolerant" => silentlyTolerateNonPureBodies = true case "quiet" => purescala.Settings.quietExtensions = true case "nodefaults" => purescala.Settings.runDefaultExtensions = false + case "noaxioms" => purescala.Settings.noForallAxioms = true case s if s.startsWith("unrolling=") => purescala.Settings.unrollingLevel = try { s.substring("unrolling=".length, s.length).toInt } catch { case _ => 0 } case s if s.startsWith("functions=") => purescala.Settings.functionsToAnalyse = Set(splitList(s.substring("functions=".length, s.length)): _*) case s if s.startsWith("extensions=") => purescala.Settings.extensionNames = splitList(s.substring("extensions=".length, s.length)) diff --git a/src/purescala/Analysis.scala b/src/purescala/Analysis.scala index a45e05982f864d6bda25425299fc1a25b8c5b29b..b8ec92d9d72d7cd517101d2c20bb0751f2308693 100644 --- a/src/purescala/Analysis.scala +++ b/src/purescala/Analysis.scala @@ -115,15 +115,12 @@ class Analysis(val program: Program) { } import Analysis._ - val newExpr0 = unrollRecursiveFunctions(program, withPrec, Settings.unrollingLevel) - val (newExpr1, sideExprs1) = rewriteSimplePatternMatching(newExpr0) - val (newExpr2, sideExprs2) = inlineFunctionsAndContracts(program, newExpr1) - - if(sideExprs1.isEmpty && sideExprs2.isEmpty) { - newExpr2 - } else { - Implies(And(sideExprs1 ++ sideExprs2), newExpr2) - } + val (newExpr0, sideExprs0) = unrollRecursiveFunctions(program, withPrec, Settings.unrollingLevel) + val expr0 = Implies(And(sideExprs0), newExpr0) + val (newExpr1, sideExprs1) = rewriteSimplePatternMatching(expr0) + val expr1 = Implies(And(sideExprs1), newExpr1) + val (newExpr2, sideExprs2) = inlineFunctionsAndContracts(program, expr1) + Implies(And(sideExprs2), newExpr2) } } @@ -164,27 +161,44 @@ object Analysis { (searchAndApply(isFunCall, applyToCall, expr), extras.reverse) } - def unrollRecursiveFunctions(program: Program, expr: Expr, times: Int) : Expr = { - def isRecursiveCall(e: Expr) = e match { - case f @ FunctionInvocation(fd, _) if fd.hasImplementation && program.isRecursive(fd) => true - case _ => false - } - def unrollCall(t: Int)(e: Expr) = e match { - case f @ FunctionInvocation(fd, args) if fd.hasImplementation && program.isRecursive(fd) => { - val newLetIDs = fd.args.map(a => FreshIdentifier(a.id.name, true).setType(a.tpe)) - val newLetVars = newLetIDs.map(Variable(_)) - val substs: Map[Expr,Expr] = Map((fd.args.map(_.toVariable) zip newLetVars) :_*) - val bodyWithLetVars: Expr = replace(substs, fd.body.get) - val bigLet = (newLetIDs zip args).foldLeft(bodyWithLetVars)((e,p) => Let(p._1, p._2, e)) - unrollRecursiveFunctions(program, bigLet, t - 1) + def unrollRecursiveFunctions(program: Program, expression: Expr, times: Int) : (Expr,Seq[Expr]) = { + var extras : List[Expr] = Nil + + def urf(expr: Expr, left: Int) : Expr = { + def isRecursiveCall(e: Expr) = e match { + case f @ FunctionInvocation(fd, _) if fd.hasImplementation && program.isRecursive(fd) => true + case _ => false } - case o => o + def unrollCall(t: Int)(e: Expr) = e match { + case f @ FunctionInvocation(fd, args) if fd.hasImplementation && program.isRecursive(fd) => { + val newLetIDs = fd.args.map(a => FreshIdentifier(a.id.name, true).setType(a.tpe)) + val newLetVars = newLetIDs.map(Variable(_)) + val substs: Map[Expr,Expr] = Map((fd.args.map(_.toVariable) zip newLetVars) :_*) + val bodyWithLetVars: Expr = replace(substs, fd.body.get) + if(fd.hasPostcondition) { + val post = fd.postcondition.get + val newVar = Variable(FreshIdentifier("call", true)).setType(fd.returnType) + val newExtra1 = Equals(newVar, bodyWithLetVars) + val newExtra2 = replace(substs + (ResultVariable() -> newVar), post) + val bigLet = (newLetIDs zip args).foldLeft(And(newExtra1, newExtra2))((e,p) => Let(p._1, p._2, e)) + extras = bigLet :: extras + newVar + } else { + val bigLet = (newLetIDs zip args).foldLeft(bodyWithLetVars)((e,p) => Let(p._1, p._2, e)) + urf(bigLet, t-1) + } + } + case o => o + } + + if(left > 0) + searchAndApply(isRecursiveCall, unrollCall(left), expr, false) + else + expr } - if(times > 0) - searchAndApply(isRecursiveCall, unrollCall(times), expr, false) - else - expr + val finalE = urf(expression, times) + (finalE, extras.reverse) } // Rewrites pattern matching expressions where the cases simply correspond to diff --git a/src/purescala/Settings.scala b/src/purescala/Settings.scala index 5e591acf78b2154213abab700df97c79e9633091..1bd8fac0354ac09e2790fefe0239fd0c274bc8fd 100644 --- a/src/purescala/Settings.scala +++ b/src/purescala/Settings.scala @@ -9,5 +9,6 @@ object Settings { var reporter: Reporter = new DefaultReporter var quietReporter: Reporter = new QuietReporter var runDefaultExtensions: Boolean = true + var noForallAxioms: Boolean = false var unrollingLevel: Int = 0 } diff --git a/src/purescala/Trees.scala b/src/purescala/Trees.scala index 561a1f71528a37687463db35114babeb67634f09..ec25321c5ffc51e919761b070726c8c19d4fc435 100644 --- a/src/purescala/Trees.scala +++ b/src/purescala/Trees.scala @@ -51,7 +51,11 @@ object Trees { /* Propositional logic */ object And { - def apply(exprs: Seq[Expr]) : And = new And(exprs) + def apply(exprs: Seq[Expr]) : Expr = exprs.size match { + case 0 => BooleanLiteral(true) + case 1 => exprs.head + case _ => new And(exprs) + } def apply(l: Expr, r: Expr): Expr = (l,r) match { case (And(exs1), And(exs2)) => And(exs1 ++ exs2) @@ -61,7 +65,7 @@ object Trees { } def unapply(and: And) : Option[Seq[Expr]] = - if(and == null) None else Some(and.exprs) + if(and == null) None else Some(and.exprs) } class And(val exprs: Seq[Expr]) extends Expr with FixedType { @@ -69,7 +73,11 @@ object Trees { } object Or { - def apply(exprs: Seq[Expr]) : Or = new Or(exprs) + def apply(exprs: Seq[Expr]) : Expr = exprs.size match { + case 0 => BooleanLiteral(false) + case 1 => exprs.head + case _ => new Or(exprs) + } def apply(l: Expr, r: Expr): Expr = (l,r) match { case (Or(exs1), Or(exs2)) => Or(exs1 ++ exs2) @@ -79,7 +87,7 @@ object Trees { } def unapply(or: Or) : Option[Seq[Expr]] = - if(or == null) None else Some(or.exprs) + if(or == null) None else Some(or.exprs) } class Or(val exprs: Seq[Expr]) extends Expr with FixedType { @@ -90,7 +98,20 @@ object Trees { val fixedType = BooleanType } - case class Implies(left: Expr, right: Expr) extends Expr with FixedType { + object Implies { + def apply(left: Expr, right: Expr) : Expr = (left,right) match { + case (BooleanLiteral(false), _) => BooleanLiteral(true) + case (_, BooleanLiteral(true)) => BooleanLiteral(true) + case (BooleanLiteral(true), r) => r + case (l, BooleanLiteral(false)) => Not(l) + case (l1, Implies(l2, r2)) => Implies(And(l1, l2), r2) + case _ => new Implies(left, right) + } + def unapply(imp: Implies) : Option[(Expr,Expr)] = + if(imp == null) None else Some(imp.left, imp.right) + } + + class Implies(val left: Expr, val right: Expr) extends Expr with FixedType { val fixedType = BooleanType } @@ -227,7 +248,7 @@ object Trees { def unapply(expr: Expr) : Option[(Expr,Expr,(Expr,Expr)=>Expr)] = expr match { case Equals(t1,t2) => Some((t1,t2,Equals)) case Iff(t1,t2) => Some((t1,t2,Iff)) - case Implies(t1,t2) => Some((t1,t2,Implies)) + case Implies(t1,t2) => Some((t1,t2, ((e1,e2) => Implies(e1,e2)))) case Plus(t1,t2) => Some((t1,t2,Plus)) case Minus(t1,t2) => Some((t1,t2,Minus)) case Times(t1,t2) => Some((t1,t2,Times)) diff --git a/src/purescala/Z3Solver.scala b/src/purescala/Z3Solver.scala index b0fb46068c9a8c72267e74c5520b52caa8bcea67..1765358a8a40a7615f8999913d83120970f18d3e 100644 --- a/src/purescala/Z3Solver.scala +++ b/src/purescala/Z3Solver.scala @@ -138,37 +138,45 @@ class Z3Solver(reporter: Reporter) extends Solver(reporter) { } // universally quantifies all functions ! - for(funDef <- program.definedFunctions) if(funDef.hasImplementation && funDef.args.size > 0) { - val argSorts: Seq[Z3Sort] = funDef.args.map(vd => typeToSort(vd.getType).get) - val boundVars = argSorts.zipWithIndex.map(p => z3.mkBound(p._2, p._1)) - val pattern: Z3Pattern = z3.mkPattern(functionDefToDef(funDef)(boundVars: _*)) - val nameTypePairs = argSorts.map(s => (z3.mkIntSymbol(nextIntForSymbol()), s)) - val fOfX: Expr = FunctionInvocation(funDef, funDef.args.map(_.toVariable)) - val toConvert: Expr = if(funDef.hasPrecondition) { - Implies(funDef.precondition.get, Equals(fOfX, funDef.body.get)) - } else { - Equals(fOfX, funDef.body.get) - } - val (newExpr1, sideExprs1) = Analysis.rewriteSimplePatternMatching(toConvert) - val (newExpr2, sideExprs2) = (newExpr1, Nil) // Analysis.inlineFunctionsAndContracts(program, newExpr1) - val finalToConvert = if(sideExprs1.isEmpty && sideExprs2.isEmpty) { - newExpr2 - } else { - Implies(And(sideExprs1 ++ sideExprs2), newExpr2) - } - val initialMap: Map[Identifier,Z3AST] = Map((funDef.args.map(_.id) zip boundVars):_*) - toZ3Formula(z3, finalToConvert, initialMap) match { - case Some(axiomTree) => { - val quantifiedAxiom = z3.mkForAll(0, List(pattern), nameTypePairs, axiomTree) - //z3.printAST(quantifiedAxiom) - z3.assertCnstr(quantifiedAxiom) + if(!Settings.noForallAxioms) { + import Analysis.SimplePatternMatching + for(funDef <- program.definedFunctions) if(funDef.hasImplementation && program.isRecursive(funDef) && funDef.args.size > 0) { + funDef.body.get match { + case SimplePatternMatching(_,_,_) => reporter.info("There's a good opportunity for a good axiomatization of " + funDef.id.name) + case _ => ; } - case None => { - reporter.warning("Could not generate forall axiom for " + funDef.id.name) - reporter.warning(toConvert) - reporter.warning(newExpr1) - reporter.warning(newExpr2) - reporter.warning(finalToConvert) + + val argSorts: Seq[Z3Sort] = funDef.args.map(vd => typeToSort(vd.getType).get) + val boundVars = argSorts.zipWithIndex.map(p => z3.mkBound(p._2, p._1)) + val pattern: Z3Pattern = z3.mkPattern(functionDefToDef(funDef)(boundVars: _*)) + val nameTypePairs = argSorts.map(s => (z3.mkIntSymbol(nextIntForSymbol()), s)) + val fOfX: Expr = FunctionInvocation(funDef, funDef.args.map(_.toVariable)) + val toConvert: Expr = if(funDef.hasPrecondition) { + Implies(funDef.precondition.get, Equals(fOfX, funDef.body.get)) + } else { + Equals(fOfX, funDef.body.get) + } + val (newExpr1, sideExprs1) = Analysis.rewriteSimplePatternMatching(toConvert) + val (newExpr2, sideExprs2) = (newExpr1, Nil) // Analysis.inlineFunctionsAndContracts(program, newExpr1) + val finalToConvert = if(sideExprs1.isEmpty && sideExprs2.isEmpty) { + newExpr2 + } else { + Implies(And(sideExprs1 ++ sideExprs2), newExpr2) + } + val initialMap: Map[Identifier,Z3AST] = Map((funDef.args.map(_.id) zip boundVars):_*) + toZ3Formula(z3, finalToConvert, initialMap) match { + case Some(axiomTree) => { + val quantifiedAxiom = z3.mkForAll(0, List(pattern), nameTypePairs, axiomTree) + //z3.printAST(quantifiedAxiom) + z3.assertCnstr(quantifiedAxiom) + } + case None => { + reporter.warning("Could not generate forall axiom for " + funDef.id.name) + reporter.warning(toConvert) + reporter.warning(newExpr1) + reporter.warning(newExpr2) + reporter.warning(finalToConvert) + } } } }