diff --git a/src/purescala/Analysis.scala b/src/purescala/Analysis.scala index 5e8dd3a8cfba320657984836671e5b88a6bb5257..6d882a45c1b8aebd9b8e75e33e4ff6cac5aac14d 100644 --- a/src/purescala/Analysis.scala +++ b/src/purescala/Analysis.scala @@ -113,19 +113,29 @@ class Analysis(val program: Program) { } import Analysis._ - val (newExpr0, sideExprs0) = unrollRecursiveFunctions(program, withPrec, Settings.unrollingLevel) - val expr0 = simplifyLets(Implies(And(sideExprs0), newExpr0)) - val (newExpr1, sideExprs1) = inlineFunctionsAndContracts(program, expr0) - val expr1 = simplifyLets(Implies(And(sideExprs1), newExpr1)) + reporter.info("Before unrolling:") + reporter.info(withPrec) + val expr0 = unrollRecursiveFunctions(program, withPrec, Settings.unrollingLevel) + reporter.info("Before inlining:") + reporter.info(expr0) + val expr1 = inlineFunctionsAndContracts(program, expr0) + reporter.info("Before PM-rewriting:") + reporter.info(expr1) val (newExpr2, sideExprs2) = rewriteSimplePatternMatching(expr1) - simplifyLets(Implies(And(sideExprs2), newExpr2)) + val expr2 = (pulloutLets(Implies(And(sideExprs2), newExpr2))) + reporter.info("After PM-rewriting:") + reporter.info(expr2) + expr2 } } } object Analysis { - def inlineFunctionsAndContracts(program: Program, expr: Expr) : (Expr, Seq[Expr]) = { + // Warning: this should only be called on a top-level formula ! It will add + // new variables and implications in a way that preserves the validity of the + // formula. + def inlineFunctionsAndContracts(program: Program, expr: Expr) : Expr = { var extras : List[Expr] = Nil val isFunCall: Function[Expr,Boolean] = _.isInstanceOf[FunctionInvocation] @@ -156,10 +166,14 @@ object Analysis { case o => o } - (searchAndApply(isFunCall, applyToCall, expr), extras.reverse) + val finalE = searchAndApply(isFunCall, applyToCall, expr) + pulloutLets(Implies(And(extras.reverse), finalE)) } - def unrollRecursiveFunctions(program: Program, expression: Expr, times: Int) : (Expr,Seq[Expr]) = { + // Warning: this should only be called on a top-level formula ! It will add + // new variables and implications in a way that preserves the validity of the + // formula. + def unrollRecursiveFunctions(program: Program, expression: Expr, times: Int) : Expr = { var extras : List[Expr] = Nil def urf(expr: Expr, left: Int) : Expr = { @@ -180,13 +194,13 @@ object Analysis { val newExtra2 = replace(substs + (ResultVariable() -> newVar), post) val bigLet = (newLetIDs zip args).foldLeft(And(newExtra1, newExtra2))((e,p) => Let(p._1, p._2, e)) extras = urf(bigLet, t-1) :: extras - println("*********************************") - println(bigLet) - println(" --- from -----------------------") - println(f) - println(" --- newVar is ------------------") - println(newVar) - println("*********************************") + // println("*********************************") + // println(bigLet) + // println(" --- from -----------------------") + // println(f) + // println(" --- newVar is ------------------") + // println(newVar) + // println("*********************************") newVar } else { val bigLet = (newLetIDs zip args).foldLeft(bodyWithLetVars)((e,p) => Let(p._1, p._2, e)) @@ -203,7 +217,7 @@ object Analysis { } val finalE = urf(expression, times) - (finalE, extras.reverse) + pulloutLets(Implies(And(extras.reverse), finalE)) } // Rewrites pattern matching expressions where the cases simply correspond to diff --git a/src/purescala/Trees.scala b/src/purescala/Trees.scala index ec25321c5ffc51e919761b070726c8c19d4fc435..fb3eb3e5a014abc9b4d8d16a9b2da78083fd18c9 100644 --- a/src/purescala/Trees.scala +++ b/src/purescala/Trees.scala @@ -405,6 +405,34 @@ object Trees { searchAndApply(isLet,simplerLet,expr) } + /* Rewrites the expression so that all lets are at the top levels. */ + def pulloutLets(expr: Expr) : Expr = { + val (storedLets, noLets) = pulloutAndKeepLets(expr) + rebuildLets(storedLets, noLets) + } + + def pulloutAndKeepLets(expr: Expr) : (Seq[(Identifier,Expr)], Expr) = { + var storedLets: List[(Identifier,Expr)] = Nil + + val isLet = ((t: Expr) => t.isInstanceOf[Let]) + def storeLet(t: Expr) : Expr = t match { + case l @ Let(i, e, b) => (storedLets = ((i,e)) :: storedLets); l + case _ => t + } + def killLet(t: Expr) : Expr = t match { + case l @ Let(i, e, b) => b + case _ => t + } + + searchAndApply(isLet, storeLet, expr) + val noLets = searchAndApply(isLet, killLet, expr) + (storedLets, noLets) + } + + def rebuildLets(lets: Seq[(Identifier,Expr)], expr: Expr) : Expr = { + lets.foldLeft(expr)((e,p) => Let(p._1, p._2, e)) + } + /* Fully expands all let expressions. */ def expandLets(expr: Expr) : Expr = { def rec(ex: Expr, s: Map[Identifier,Expr]) : Expr = ex match { diff --git a/src/purescala/Z3Solver.scala b/src/purescala/Z3Solver.scala index 1765358a8a40a7615f8999913d83120970f18d3e..db70fa1a36059ac1a3f8603fab442917464e202a 100644 --- a/src/purescala/Z3Solver.scala +++ b/src/purescala/Z3Solver.scala @@ -139,46 +139,46 @@ class Z3Solver(reporter: Reporter) extends Solver(reporter) { // universally quantifies all functions ! if(!Settings.noForallAxioms) { - import Analysis.SimplePatternMatching - for(funDef <- program.definedFunctions) if(funDef.hasImplementation && program.isRecursive(funDef) && funDef.args.size > 0) { - funDef.body.get match { - case SimplePatternMatching(_,_,_) => reporter.info("There's a good opportunity for a good axiomatization of " + funDef.id.name) - case _ => ; - } + // import Analysis.SimplePatternMatching + // for(funDef <- program.definedFunctions) if(funDef.hasImplementation && program.isRecursive(funDef) && funDef.args.size > 0) { + // funDef.body.get match { + // case SimplePatternMatching(_,_,_) => reporter.info("There's a good opportunity for a good axiomatization of " + funDef.id.name) + // case _ => ; + // } - val argSorts: Seq[Z3Sort] = funDef.args.map(vd => typeToSort(vd.getType).get) - val boundVars = argSorts.zipWithIndex.map(p => z3.mkBound(p._2, p._1)) - val pattern: Z3Pattern = z3.mkPattern(functionDefToDef(funDef)(boundVars: _*)) - val nameTypePairs = argSorts.map(s => (z3.mkIntSymbol(nextIntForSymbol()), s)) - val fOfX: Expr = FunctionInvocation(funDef, funDef.args.map(_.toVariable)) - val toConvert: Expr = if(funDef.hasPrecondition) { - Implies(funDef.precondition.get, Equals(fOfX, funDef.body.get)) - } else { - Equals(fOfX, funDef.body.get) - } - val (newExpr1, sideExprs1) = Analysis.rewriteSimplePatternMatching(toConvert) - val (newExpr2, sideExprs2) = (newExpr1, Nil) // Analysis.inlineFunctionsAndContracts(program, newExpr1) - val finalToConvert = if(sideExprs1.isEmpty && sideExprs2.isEmpty) { - newExpr2 - } else { - Implies(And(sideExprs1 ++ sideExprs2), newExpr2) - } - val initialMap: Map[Identifier,Z3AST] = Map((funDef.args.map(_.id) zip boundVars):_*) - toZ3Formula(z3, finalToConvert, initialMap) match { - case Some(axiomTree) => { - val quantifiedAxiom = z3.mkForAll(0, List(pattern), nameTypePairs, axiomTree) - //z3.printAST(quantifiedAxiom) - z3.assertCnstr(quantifiedAxiom) - } - case None => { - reporter.warning("Could not generate forall axiom for " + funDef.id.name) - reporter.warning(toConvert) - reporter.warning(newExpr1) - reporter.warning(newExpr2) - reporter.warning(finalToConvert) - } - } - } + // val argSorts: Seq[Z3Sort] = funDef.args.map(vd => typeToSort(vd.getType).get) + // val boundVars = argSorts.zipWithIndex.map(p => z3.mkBound(p._2, p._1)) + // val pattern: Z3Pattern = z3.mkPattern(functionDefToDef(funDef)(boundVars: _*)) + // val nameTypePairs = argSorts.map(s => (z3.mkIntSymbol(nextIntForSymbol()), s)) + // val fOfX: Expr = FunctionInvocation(funDef, funDef.args.map(_.toVariable)) + // val toConvert: Expr = if(funDef.hasPrecondition) { + // Implies(funDef.precondition.get, Equals(fOfX, funDef.body.get)) + // } else { + // Equals(fOfX, funDef.body.get) + // } + // val (newExpr1, sideExprs1) = Analysis.rewriteSimplePatternMatching(toConvert) + // val (newExpr2, sideExprs2) = (newExpr1, Nil) // Analysis.inlineFunctionsAndContracts(program, newExpr1) + // val finalToConvert = if(sideExprs1.isEmpty && sideExprs2.isEmpty) { + // newExpr2 + // } else { + // Implies(And(sideExprs1 ++ sideExprs2), newExpr2) + // } + // val initialMap: Map[Identifier,Z3AST] = Map((funDef.args.map(_.id) zip boundVars):_*) + // toZ3Formula(z3, finalToConvert, initialMap) match { + // case Some(axiomTree) => { + // val quantifiedAxiom = z3.mkForAll(0, List(pattern), nameTypePairs, axiomTree) + // //z3.printAST(quantifiedAxiom) + // z3.assertCnstr(quantifiedAxiom) + // } + // case None => { + // reporter.warning("Could not generate forall axiom for " + funDef.id.name) + // reporter.warning(toConvert) + // reporter.warning(newExpr1) + // reporter.warning(newExpr2) + // reporter.warning(finalToConvert) + // } + // } + // } } }