diff --git a/src/purescala/Analysis.scala b/src/purescala/Analysis.scala index 9f283850bef315acbea5e62211cb2e69820aff63..c57afcadf5c1755f3666d87f51769c009a342767 100644 --- a/src/purescala/Analysis.scala +++ b/src/purescala/Analysis.scala @@ -113,17 +113,17 @@ class Analysis(val program: Program) { } import Analysis._ - reporter.info("Before unrolling:") - reporter.info(withPrec) + //reporter.info("Before unrolling:") + //reporter.info(expandLets(withPrec)) val expr0 = unrollRecursiveFunctions(program, withPrec, Settings.unrollingLevel) - reporter.info("Before inlining:") - reporter.info(expr0) + //reporter.info("Before inlining:") + //reporter.info(expandLets(expr0)) val expr1 = inlineFunctionsAndContracts(program, expr0) - reporter.info("Before PM-rewriting:") - reporter.info(expr1) + //reporter.info("Before PM-rewriting:") + //reporter.info(expandLets(expr1)) val expr2 = rewriteSimplePatternMatching(expr1) - reporter.info("After PM-rewriting:") - reporter.info(expr2) + //reporter.info("After PM-rewriting:") + //reporter.info(expandLets(expr2)) expr2 } } @@ -173,50 +173,59 @@ object Analysis { // new variables and implications in a way that preserves the validity of the // formula. def unrollRecursiveFunctions(program: Program, expression: Expr, times: Int) : Expr = { - var extras : List[Expr] = Nil + def unroll(exx: Expr) : (Expr,Seq[Expr]) = { + var extras : List[Expr] = Nil - def urf(expr: Expr, left: Int) : Expr = { - def isRecursiveCall(e: Expr) = e match { - case f @ FunctionInvocation(fd, _) if fd.hasImplementation && program.isRecursive(fd) => true - case _ => false - } - def unrollCall(t: Int)(e: Expr) = e match { - case f @ FunctionInvocation(fd, args) if fd.hasImplementation && program.isRecursive(fd) => { - val newLetIDs = fd.args.map(a => FreshIdentifier(a.id.name, true).setType(a.tpe)) - val newLetVars = newLetIDs.map(Variable(_)) - val substs: Map[Expr,Expr] = Map((fd.args.map(_.toVariable) zip newLetVars) :_*) - val bodyWithLetVars: Expr = replace(substs, fd.body.get) - if(fd.hasPostcondition) { - val post = fd.postcondition.get - val newVar = Variable(FreshIdentifier("call", true)).setType(fd.returnType) - val newExtra1 = Equals(newVar, bodyWithLetVars) - val newExtra2 = replace(substs + (ResultVariable() -> newVar), post) - val bigLet = (newLetIDs zip args).foldLeft(And(newExtra1, newExtra2))((e,p) => Let(p._1, p._2, e)) - extras = urf(bigLet, t-1) :: extras - // println("*********************************") - // println(bigLet) - // println(" --- from -----------------------") - // println(f) - // println(" --- newVar is ------------------") - // println(newVar) - // println("*********************************") - newVar - } else { - val bigLet = (newLetIDs zip args).foldLeft(bodyWithLetVars)((e,p) => Let(p._1, p._2, e)) - urf(bigLet, t-1) + def urf(expr: Expr, left: Int) : Expr = { + def isRecursiveCall(e: Expr) = e match { + case f @ FunctionInvocation(fd, _) if fd.hasImplementation && program.isRecursive(fd) => true + case _ => false + } + def unrollCall(t: Int)(e: Expr) = e match { + case f @ FunctionInvocation(fd, args) if fd.hasImplementation && program.isRecursive(fd) => { + val newLetIDs = fd.args.map(a => FreshIdentifier(a.id.name, true).setType(a.tpe)) + val newLetVars = newLetIDs.map(Variable(_)) + val substs: Map[Expr,Expr] = Map((fd.args.map(_.toVariable) zip newLetVars) :_*) + val bodyWithLetVars: Expr = replace(substs, fd.body.get) + if(fd.hasPostcondition) { + val post = fd.postcondition.get + val newVar = Variable(FreshIdentifier("call", true)).setType(fd.returnType) + val newExtra1 = Equals(newVar, bodyWithLetVars) + val newExtra2 = replace(substs + (ResultVariable() -> newVar), post) + val bigLet = (newLetIDs zip args).foldLeft(And(newExtra1, newExtra2))((e,p) => Let(p._1, p._2, e)) + extras = urf(bigLet, t-1) :: extras + // println("*********************************") + // println(bigLet) + // println(" --- from -----------------------") + // println(f) + // println(" --- newVar is ------------------") + // println(newVar) + // println("*********************************") + newVar + } else { + val bigLet = (newLetIDs zip args).foldLeft(bodyWithLetVars)((e,p) => Let(p._1, p._2, e)) + urf(bigLet, t-1) + } } + case o => o } - case o => o - } - if(left > 0) - searchAndApply(isRecursiveCall, unrollCall(left), expr, false) - else - expr + if(left > 0) + searchAndApply(isRecursiveCall, unrollCall(left), expr, false) + else + expr + } + val finalE = urf(exx, times) + (finalE, extras) } - val finalE = urf(expression, times) - pulloutLets(Implies(And(extras.reverse), finalE)) + val (savedLets, naked) = pulloutAndKeepLets(expression) + val infoFromLets: Seq[(Expr,Seq[Expr])] = savedLets.map(_._2).map(unroll(_)) + val extrasFromLets: Seq[Expr] = infoFromLets.map(_._2).flatten + val newLetBodies: Seq[Expr] = infoFromLets.map(_._1) + val newSavedLets: Seq[(Identifier,Expr)] = savedLets.map(_._1) zip newLetBodies + val (cleaned, extras) = unroll(naked) + rebuildLets(newSavedLets, Implies(And(extrasFromLets ++ extras), cleaned)) } // Rewrites pattern matching expressions where the cases simply correspond to @@ -251,9 +260,12 @@ object Analysis { (cleanerTree, extras.reverse) } val (savedLets, naked) = pulloutAndKeepLets(expression) - val savedLets2 = savedLets.map(p => (p._1, rewriteSimplePatternMatching(p._2))) + val infoFromLets: Seq[(Expr,Seq[Expr])] = savedLets.map(_._2).map(rspm(_)) + val extrasFromLets: Seq[Expr] = infoFromLets.map(_._2).flatten + val newLetBodies: Seq[Expr] = infoFromLets.map(_._1) + val newSavedLets: Seq[(Identifier,Expr)] = savedLets.map(_._1) zip newLetBodies val (cleaned, extras) = rspm(naked) - rebuildLets(savedLets, Implies(And(extras), cleaned)) + rebuildLets(newSavedLets, Implies(And(extrasFromLets ++ extras), cleaned)) } } diff --git a/src/purescala/Z3Solver.scala b/src/purescala/Z3Solver.scala index db70fa1a36059ac1a3f8603fab442917464e202a..b896d77793bd6726e3d36a4bb91a2586313045bd 100644 --- a/src/purescala/Z3Solver.scala +++ b/src/purescala/Z3Solver.scala @@ -163,7 +163,7 @@ class Z3Solver(reporter: Reporter) extends Solver(reporter) { // } else { // Implies(And(sideExprs1 ++ sideExprs2), newExpr2) // } - // val initialMap: Map[Identifier,Z3AST] = Map((funDef.args.map(_.id) zip boundVars):_*) + // val initialMap: Map[String,Z3AST] = Map((funDef.args.map(_.id.uniqueName) zip boundVars):_*) // toZ3Formula(z3, finalToConvert, initialMap) match { // case Some(axiomTree) => { // val quantifiedAxiom = z3.mkForAll(0, List(pattern), nameTypePairs, axiomTree) @@ -247,17 +247,20 @@ class Z3Solver(reporter: Reporter) extends Solver(reporter) { result } - private def toZ3Formula(z3: Z3Context, expr: Expr, initialMap: Map[Identifier,Z3AST] = Map.empty) : Option[Z3AST] = { + private def toZ3Formula(z3: Z3Context, expr: Expr, initialMap: Map[String,Z3AST] = Map.empty) : Option[Z3AST] = { class CantTranslateException extends Exception - var z3Vars: Map[Identifier,Z3AST] = initialMap + var z3Vars: Map[String,Z3AST] = initialMap def rec(ex: Expr) : Z3AST = ex match { case Let(i,e,b) => { - z3Vars = z3Vars + (i -> rec(e)) - rec(b) + val re = rec(e) + z3Vars = z3Vars + (i.uniqueName -> re) + val rb = rec(b) + z3Vars = z3Vars - i.uniqueName + rb } - case v @ Variable(id) => z3Vars.get(id) match { + case v @ Variable(id) => z3Vars.get(id.uniqueName) match { case Some(ast) => ast case None => { val newAST = typeToSort(v.getType) match { @@ -269,7 +272,7 @@ class Z3Solver(reporter: Reporter) extends Solver(reporter) { throw new CantTranslateException } } - z3Vars = z3Vars + (id -> newAST) + z3Vars = z3Vars + (id.uniqueName -> newAST) newAST } } diff --git a/testcases/ListWithSize.scala b/testcases/ListWithSize.scala index 91a13d0239507b0290f023487a0bc7e1a76201f3..be0e8c9d59ec573a29f7a89e5a4099b059440ad4 100644 --- a/testcases/ListWithSize.scala +++ b/testcases/ListWithSize.scala @@ -3,10 +3,10 @@ object ListWithSize { case class Cons(head: Int, tail: List) extends List case class Nil() extends List - def size(l: List) : Int = l match { + def size(l: List) : Int = (l match { case Nil() => 0 case Cons(_, t) => 1 + size(t) - } + }) ensuring (_ >= 0) def append(x: Int, l: List) : List = (l match { case Nil() => Cons(x, Nil())