diff --git a/demo-munch b/demo-munch index d523208364e68f025895bb6b036e06ff9641230f..9b1b6a7de9cfc05e7a4f0cd4914123d942a8e72d 100644 --- a/demo-munch +++ b/demo-munch @@ -1 +1 @@ -./scalac-funcheck -P:funcheck:extensions=multisets.Main testcases/MultiExample.scala +./scalac-funcheck -P:funcheck:nodefaults -P:funcheck:extensions=multisets.Main testcases/MultiExample.scala diff --git a/lib-bin/libz3.so b/lib-bin/libz3.so index bb6fcf57d8f85161e3e14738aae7a3c090b8cdb5..a6ea716dd3c70164f2308536703537ae61663166 100755 Binary files a/lib-bin/libz3.so and b/lib-bin/libz3.so differ diff --git a/lib/z3.jar b/lib/z3.jar index c9534e6d12ab40fd170b1b06ebafced18cf16689..33340bcc3861619181a3b03d4b43770713ef8c57 100644 Binary files a/lib/z3.jar and b/lib/z3.jar differ diff --git a/src/purescala/Analysis.scala b/src/purescala/Analysis.scala index af93704e39c458c2088586e8219dda485d20c862..b989ce0fd8cbf75cd62eedaeb764398f41eb310a 100644 --- a/src/purescala/Analysis.scala +++ b/src/purescala/Analysis.scala @@ -136,155 +136,142 @@ class Analysis(val program: Program) { } import Analysis._ - // reporter.info("Before unrolling:") - // reporter.info(expandLets(withPrec)) - val expr0 = unrollRecursiveFunctions(program, withPrec, Settings.unrollingLevel) - // reporter.info("Before inlining:") - // reporter.info(expandLets(expr0)) - val expr1 = inlineFunctionsAndContracts(program, expr0) - // reporter.info("Before PM-rewriting:") - // reporter.info(expandLets(expr1)) - val expr2 = rewriteSimplePatternMatching(expr1) - // reporter.info("After PM-rewriting:") - // reporter.info(expandLets(expr2)) - assert(wellOrderedLets(expr2)) - expr2 + + if(Settings.experimental) { + reporter.info("Raw:") + reporter.info(withPrec) + reporter.info("Raw, expanded:") + reporter.info(expandLets(withPrec)) + } + reporter.info(" - inlining...") + val expr0 = inlineNonRecursiveFunctions(program, withPrec) + if(Settings.experimental) { + reporter.info("Inlined:") + reporter.info(expr0) + reporter.info("Inlined, expanded:") + reporter.info(expandLets(expr0)) + } + reporter.info(" - unrolling...") + val expr1 = unrollRecursiveFunctions(program, expr0, Settings.unrollingLevel) + if(Settings.experimental) { + reporter.info("Unrolled:") + reporter.info(expr1) + reporter.info("Unrolled, expanded:") + reporter.info(expandLets(expr1)) + } + reporter.info(" - inlining contracts...") + val expr2 = inlineContracts(expr1) + if(Settings.experimental) { + reporter.info("Contract'ed:") + reporter.info(expr2) + reporter.info("Contract'ed, expanded:") + reporter.info(expandLets(expr2)) + } + reporter.info(" - converting pattern-matching...") + val expr3 = rewriteSimplePatternMatching(expr2) + if(Settings.experimental) { + reporter.info("Pattern'ed:") + reporter.info(expr3) + reporter.info("Pattern'ed, expanded:") + reporter.info(expandLets(expr3)) + } + expr3 } } - } object Analysis { - // Warning: this should only be called on a top-level formula ! It will add - // new variables and implications in a way that preserves the validity of the - // formula. - def inlineFunctionsAndContracts(program: Program, expr: Expr) : Expr = { - var extras : List[Expr] = Nil + private def inlineFunctionCall(f : FunctionInvocation) : Expr = { + val FunctionInvocation(fd, args) = f + val newLetIDs = fd.args.map(a => FreshIdentifier("arg_" + a.id.name, true).setType(a.tpe)).toList + val substMap = Map[Expr,Expr]((fd.args.map(_.toVariable) zip newLetIDs.map(Variable(_))) : _*) + val newBody = replace(substMap, fd.body.get) + simplifyLets((newLetIDs zip args).foldRight(newBody)((iap, e) => Let(iap._1, iap._2, e))) + } + def inlineNonRecursiveFunctions(program: Program, expression: Expr) : Expr = { def applyToCall(e: Expr) : Option[Expr] = e match { - case f @ FunctionInvocation(fd, args) => { - val fArgsAsVars: List[Variable] = fd.args.map(_.toVariable).toList - val fParamsAsLetVars: List[Identifier] = fd.args.map(a => FreshIdentifier("arg", true).setType(a.tpe)).toList - val fParamsAsLetVarVars = fParamsAsLetVars.map(Variable(_)) - - def mkBigLet(ex: Expr) : Expr = (fParamsAsLetVars zip args).foldRight(ex)((iap, e) => { - Let(iap._1, iap._2, e) - }) - - val substMap = Map[Expr,Expr]((fArgsAsVars zip fParamsAsLetVarVars) : _*) - if(fd.hasPostcondition) { - val newVar = Variable(FreshIdentifier("call", true)).setType(fd.returnType) - extras = And( - replace(substMap + (ResultVariable() -> newVar), fd.postcondition.get), - Equals(newVar, FunctionInvocation(fd, fParamsAsLetVarVars).setType(fd.returnType)) - ) :: extras - Some(mkBigLet(newVar)) - /* END CHANGE */ - } else if(fd.hasImplementation && !program.isRecursive(fd)) { // means we can inline at least one level... - Some(mkBigLet(replace(substMap, fd.body.get))) - } else { // we can't do much for calls to recursive functions or to functions with no bodies - None - } - } - case o => None + case f @ FunctionInvocation(fd, args) if fd.hasImplementation && !program.isRecursive(fd) => Some(inlineFunctionCall(f)) + case _ => None } - val finalE = searchAndReplace(applyToCall)(expr) - val toReturn = pulloutLets(Implies(And(extras.reverse), finalE)) + var change: Boolean = true + var toReturn: Expr = expression + while(change) { + val (t,c) = searchAndReplaceDFSandTrackChanges(applyToCall)(toReturn) + change = c + toReturn = t + } toReturn } - // Warning: this should only be called on a top-level formula ! It will add - // new variables and implications in a way that preserves the validity of the - // formula. def unrollRecursiveFunctions(program: Program, expression: Expr, times: Int) : Expr = { - def unroll(exx: Expr) : (Expr,Seq[Expr]) = { - var extras : List[Expr] = Nil + def applyToCall(e: Expr) : Option[Expr] = e match { + case f @ FunctionInvocation(fd, args) if fd.hasImplementation && program.isRecursive(fd) => Some(inlineFunctionCall(f)) + case _ => None + } - def urf(expr: Expr, left: Int) : Expr = { - def unrollCall(t: Int)(e: Expr) : Option[Expr] = e match { - case f @ FunctionInvocation(fd, args) if fd.hasImplementation && program.isRecursive(fd) => { - val newLetIDs = fd.args.map(a => FreshIdentifier(a.id.name, true).setType(a.tpe)) - val newLetVars = newLetIDs.map(Variable(_)) - val substs: Map[Expr,Expr] = Map((fd.args.map(_.toVariable) zip newLetVars) :_*) - val bodyWithLetVars: Expr = replace(substs, fd.body.get) - if(fd.hasPostcondition) { - val post = fd.postcondition.get - val newVar = Variable(FreshIdentifier("call", true)).setType(fd.returnType) - val newExtra1 = Equals(newVar, bodyWithLetVars) - val newExtra2 = replace(substs + (ResultVariable() -> newVar), post) - val bigLet = (newLetIDs zip args).foldLeft(And(newExtra1, newExtra2))((e,p) => Let(p._1, p._2, e)) - extras = urf(bigLet, t-1) :: extras - // println("*********************************") - // println(bigLet) - // println(" --- from -----------------------") - // println(f) - // println(" --- newVar is ------------------") - // println(newVar) - // println("*********************************") - Some(newVar) - } else { - val bigLet = (newLetIDs zip args).foldLeft(bodyWithLetVars)((e,p) => Let(p._1, p._2, e)) - Some(urf(bigLet, t-1)) - } - } - case o => None - } + var remaining = if(times < 0) 0 else times + var change: Boolean = true + var toReturn: Expr = expression + while(remaining > 0 && change) { + val (t,c) = searchAndReplaceDFSandTrackChanges(applyToCall)(toReturn) + change = c + toReturn = inlineNonRecursiveFunctions(program, t) + remaining = remaining - 1 + } + toReturn + } - if(left > 0) - searchAndReplace(unrollCall(left), false)(expr) - else - expr + def inlineContracts(expression: Expr) : Expr = { + var trueThings: List[Expr] = Nil + + def applyToCall(e: Expr) : Option[Expr] = e match { + case f @ FunctionInvocation(fd, args) if fd.hasPostcondition => { + val argsAsLet = fd.args.map(a => FreshIdentifier("parg_" + a.id.name, true).setType(a.tpe)).toList + val argsAsLetVars = argsAsLet.map(Variable(_)) + val resultAsLet = FreshIdentifier("call_" + fd.id.name, true).setType(f.getType) + val newFunCall = FunctionInvocation(fd, argsAsLetVars) + val substMap = Map[Expr,Expr]((fd.args.map(_.toVariable) zip argsAsLetVars) : _*) + (ResultVariable() -> Variable(resultAsLet)) + // this thing is full of let variables! We will need to lift the let + // defs. later to make sure they capture this + val trueFact = replace(substMap, fd.postcondition.get) + val defList: Seq[(Identifier,Expr)] = ((argsAsLet :+ resultAsLet) zip (args :+ newFunCall)) + trueThings = trueFact :: trueThings + // again: these let defs. need eventually to capture the "true thing" + Some(defList.foldRight[Expr](Variable(resultAsLet))((iap, e) => Let(iap._1, iap._2, e))) } - val finalE = urf(exx, times) - (finalE, extras) + case _ => None } - - val (savedLets, naked) = pulloutAndKeepLets(expression) - val infoFromLets: Seq[(Expr,Seq[Expr])] = savedLets.map(_._2).map(unroll(_)) - val extrasFromLets: Seq[Expr] = infoFromLets.map(_._2).flatten - val newLetBodies: Seq[Expr] = infoFromLets.map(_._1) - val newSavedLets: Seq[(Identifier,Expr)] = savedLets.map(_._1) zip newLetBodies - val (cleaned, extras) = unroll(naked) - val toReturn = rebuildLets(newSavedLets, Implies(And(extrasFromLets ++ extras), cleaned)) - toReturn + val result = searchAndReplaceDFS(applyToCall)(expression) + liftLets(Implies(And(trueThings.reverse), result)) } // Rewrites pattern matching expressions where the cases simply correspond to // the list of constructors def rewriteSimplePatternMatching(expression: Expr) : Expr = { - def rspm(expr: Expr) : (Expr,Seq[Expr]) = { - var extras : List[Expr] = Nil + var extras : List[Expr] = Nil - def rewritePM(e: Expr) : Option[Expr] = e match { - // case NotSoSimplePatternMatching(_) => None - case SimplePatternMatching(scrutinee, classType, casesInfo) => Some({ - val newVar = Variable(FreshIdentifier("pm", true)).setType(e.getType) - val scrutAsLetID = FreshIdentifier("scrut", true).setType(scrutinee.getType) - val lle : List[(Variable,List[Expr])] = casesInfo.map(cseInfo => { - val (ccd, newPID, argIDs, rhs) = cseInfo - val newPVar = Variable(newPID) - val argVars = argIDs.map(Variable(_)) - val (rewrittenRHS, moreExtras) = rspm(rhs) - (newPVar, List(Equals(newPVar, CaseClass(ccd, argVars)), Implies(Equals(Variable(scrutAsLetID), newPVar), Equals(newVar, rewrittenRHS))) ::: moreExtras.toList) - }).toList - val (newPVars, newExtras) = lle.unzip - extras = Let(scrutAsLetID, scrutinee, And(Or(newPVars.map(Equals(Variable(scrutAsLetID), _))), And(newExtras.flatten))) :: extras - newVar - }) - case _ => None + def rewritePM(e: Expr) : Option[Expr] = e match { + // case NotSoSimplePatternMatching(_) => None + case SimplePatternMatching(scrutinee, classType, casesInfo) => { + val newVar = Variable(FreshIdentifier("pm", true)).setType(e.getType) + val scrutAsLetID = FreshIdentifier("scrut", true).setType(scrutinee.getType) + val lle : List[(Variable,List[Expr])] = casesInfo.map(cseInfo => { + val (ccd, newPID, argIDs, rhs) = cseInfo + val newPVar = Variable(newPID) + val argVars = argIDs.map(Variable(_)) + (newPVar, List(Equals(newPVar, CaseClass(ccd, argVars)), Implies(Equals(Variable(scrutAsLetID), newPVar), Equals(newVar, rhs)))) + }).toList + val (newPVars, newExtras) = lle.unzip + extras = Let(scrutAsLetID, scrutinee, And(Or(newPVars.map(Equals(Variable(scrutAsLetID), _))), And(newExtras.flatten))) :: extras + Some(newVar) } - - val cleanerTree = searchAndReplace(rewritePM)(expr) - (cleanerTree, extras.reverse) + case m @ MatchExpr(s,_) => Settings.reporter.error("Untranslatable PM expression on type " + s.getType + " : " + m); None + case _ => None } - val (savedLets, naked) = pulloutAndKeepLets(expression) - val infoFromLets: Seq[(Expr,Seq[Expr])] = savedLets.map(_._2).map(rspm(_)) - val extrasFromLets: Seq[Expr] = infoFromLets.map(_._2).flatten - val newLetBodies: Seq[Expr] = infoFromLets.map(_._1) - val newSavedLets: Seq[(Identifier,Expr)] = savedLets.map(_._1) zip newLetBodies - val (cleaned, extras) = rspm(naked) - val toReturn = rebuildLets(newSavedLets, Implies(And(extrasFromLets ++ extras), cleaned)) - toReturn + + val newExpr = searchAndReplaceDFS(rewritePM)(expression) + liftLets(Implies(And(extras), newExpr)) } } diff --git a/src/purescala/Common.scala b/src/purescala/Common.scala index bd99628e711c6d76acec44b8cbed5a9e312aa5cc..7e326ca3c865a17ac0646e3aa9468cb8ddb0fa16 100644 --- a/src/purescala/Common.scala +++ b/src/purescala/Common.scala @@ -26,6 +26,10 @@ object Common { } def uniqueName : String = name + id + + private var _islb: Boolean = false + def markAsLetBinder : Identifier = { _islb = true; this } + def isLetBinder : Boolean = _islb } private object UniqueCounter { diff --git a/src/purescala/Definitions.scala b/src/purescala/Definitions.scala index 90d418bb7604291ea0b235c55bb427805821d50d..67fe9fd5b185891e7598a40f4888f1a720a8c4cb 100644 --- a/src/purescala/Definitions.scala +++ b/src/purescala/Definitions.scala @@ -48,19 +48,20 @@ object Definitions { lazy val classHierarchyRoots : Seq[ClassTypeDef] = defs.filter(_.isInstanceOf[ClassTypeDef]).map(_.asInstanceOf[ClassTypeDef]).filter(!_.hasParent) lazy val (callGraph, callers, callees) = { - var resSet: Set[(FunDef,FunDef)] = - new scala.collection.immutable.HashSet[(FunDef,FunDef)]() + type CallGraph = Set[(FunDef,FunDef)] - def applyToFunCall(f1: FunDef)(e: Expr) : Option[Expr] = e match { - case f @ FunctionInvocation(f2, _) => { resSet = resSet + ((f1,f2)); Some(f) } - case _ => None + val convert: Expr=>CallGraph = (_ => Set.empty) + val combine: (CallGraph,CallGraph)=>CallGraph = (s1,s2) => s1 ++ s2 + def compute(fd: FunDef)(e: Expr, g: CallGraph) : CallGraph = e match { + case f @ FunctionInvocation(f2, _) => g + ((fd, f2)) + case _ => g } - for(funDef <- definedFunctions) { - funDef.precondition.map(searchAndReplace(applyToFunCall(funDef))(_)) - funDef.body.map(searchAndReplace(applyToFunCall(funDef))(_)) - funDef.postcondition.map(searchAndReplace(applyToFunCall(funDef))(_)) - } + val resSet: CallGraph = (for(funDef <- definedFunctions) yield { + funDef.precondition.map(treeCatamorphism[CallGraph](convert, combine, compute(funDef)_, _)).getOrElse(Set.empty) ++ + funDef.body.map(treeCatamorphism[CallGraph](convert, combine, compute(funDef)_, _)).getOrElse(Set.empty) ++ + funDef.postcondition.map(treeCatamorphism[CallGraph](convert, combine, compute(funDef)_, _)).getOrElse(Set.empty) + }).reduceLeft(_ ++ _) var callers: Map[FunDef,Set[FunDef]] = new scala.collection.immutable.HashMap[FunDef,Set[FunDef]] diff --git a/src/purescala/Trees.scala b/src/purescala/Trees.scala index fbe596910f2977c7c2906cc89b393e6cb2aba6e3..2b0d5413032c4d50f740e3a63693eeb1758b0b71 100644 --- a/src/purescala/Trees.scala +++ b/src/purescala/Trees.scala @@ -16,6 +16,7 @@ object Trees { /* Like vals */ case class Let(binder: Identifier, value: Expr, body: Expr) extends Expr { + binder.markAsLetBinder val et = body.getType if(et != NoType) setType(et) @@ -26,7 +27,23 @@ object Trees { val fixedType = funDef.returnType } case class IfExpr(cond: Expr, then: Expr, elze: Expr) extends Expr - case class MatchExpr(scrutinee: Expr, cases: Seq[MatchCase]) extends Expr { + + object MatchExpr { + def apply(scrutinee: Expr, cases: Seq[MatchCase]) : MatchExpr = { + scrutinee.getType match { + case a: AbstractClassType => new MatchExpr(scrutinee, cases) + case c: CaseClassType => new MatchExpr(scrutinee, cases.filter(_.pattern match { + case CaseClassPattern(_, ccd, _) if ccd != c.classDef => false + case _ => true + })) + case _ => scala.Predef.error("Constructing match expression on non-class type.") + } + } + + def unapply(me: MatchExpr) : Option[(Expr,Seq[MatchCase])] = if (me == null) None else Some((me.scrutinee, me.cases)) + } + + class MatchExpr(val scrutinee: Expr, val cases: Seq[MatchCase]) extends Expr { def scrutineeClassType: ClassType = scrutinee.getType.asInstanceOf[ClassType] } @@ -35,13 +52,16 @@ object Trees { val rhs: Expr val theGuard: Option[Expr] def hasGuard = theGuard.isDefined + def expressions: Seq[Expr] } case class SimpleCase(pattern: Pattern, rhs: Expr) extends MatchCase { val theGuard = None + def expressions = List(rhs) } case class GuardedCase(pattern: Pattern, guard: Expr, rhs: Expr) extends MatchCase { val theGuard = Some(guard) + def expressions = List(guard, rhs) } sealed abstract class Pattern @@ -338,14 +358,14 @@ object Trees { // Warning ! This may loop forever if the substitutions are not // well-formed! def replace(substs: Map[Expr,Expr], expr: Expr) : Expr = { - searchAndReplace(substs.get)(expr) + searchAndReplaceDFS(substs.get)(expr) } def searchAndReplace(subst: Expr=>Option[Expr], recursive: Boolean=true)(expr: Expr) : Expr = { def rec(ex: Expr, skip: Expr = null) : Expr = (if (ex == skip) None else subst(ex)) match { case Some(newExpr) => { if(newExpr.getType == NoType) { - Settings.reporter.warning("REPLACING IN EXPRESSION WITH AN UNTYPED TREE ! " + ex + " --to--> " + newExpr) + Settings.reporter.error("REPLACING IN EXPRESSION WITH AN UNTYPED TREE ! " + ex + " --to--> " + newExpr) } if(ex == newExpr) if(recursive) rec(ex, ex) else ex @@ -415,29 +435,153 @@ object Trees { rec(expr) } - def variablesOf(expr: Expr) : Set[Identifier] = { - def rec(ex: Expr, lets: Set[Identifier]) : Set[Identifier] = ex match { + def searchAndReplaceDFS(subst: Expr=>Option[Expr])(expr: Expr) : Expr = { + val (res,_) = searchAndReplaceDFSandTrackChanges(subst)(expr) + res + } + + def searchAndReplaceDFSandTrackChanges(subst: Expr=>Option[Expr])(expr: Expr) : (Expr,Boolean) = { + var somethingChanged: Boolean = false + def applySubst(ex: Expr) : Expr = subst(ex) match { + case None => ex + case Some(newEx) => { + somethingChanged = true + if(newEx.getType == NoType) { + Settings.reporter.warning("REPLACING WITH AN UNTYPED EXPRESSION !") + } + newEx + } + } + + def rec(ex: Expr) : Expr = ex match { case l @ Let(i,e,b) => { - val newLets = lets + i - rec(e, newLets) ++ rec(b, newLets) + val re = rec(e) + val rb = rec(b) + applySubst(if(re != e || rb != b) { + Let(i, re, rb).setType(l.getType) + } else { + l + }) } - case Variable(i) => if(lets(i)) Set.empty[Identifier] else Set(i) - case n @ NAryOperator(args, _) => if(args.isEmpty) Set.empty[Identifier] else args.map(rec(_, lets)).reduceLeft(_ ++ _) - case b @ BinaryOperator(t1,t2,_) => rec(t1,lets) ++ rec(t2,lets) - case u @ UnaryOperator(t,_) => rec(t, lets) - case i @ IfExpr(t1,t2,t3) => rec(t1,lets) ++ rec(t2,lets) ++ rec(t3,lets) - case m @ MatchExpr(scrut,cses) => rec(scrut, lets) ++ cses.map(inCase(_, lets)).reduceLeft(_ ++ _) - case t if t.isInstanceOf[Terminal] => Set.empty[Identifier] - case unhandled => scala.Predef.error("Non-terminal case should be handled in searchAndReplace: " + unhandled) + case n @ NAryOperator(args, recons) => { + var change = false + val rargs = args.map(a => { + val ra = rec(a) + if(ra != a) { + change = true + ra + } else { + a + } + }) + applySubst(if(change) { + recons(rargs).setType(n.getType) + } else { + n + }) + } + case b @ BinaryOperator(t1,t2,recons) => { + val r1 = rec(t1) + val r2 = rec(t2) + applySubst(if(r1 != t1 || r2 != t2) { + recons(r1,r2).setType(b.getType) + } else { + b + }) + } + case u @ UnaryOperator(t,recons) => { + val r = rec(t) + applySubst(if(r != t) { + recons(r).setType(u.getType) + } else { + u + }) + } + case i @ IfExpr(t1,t2,t3) => { + val r1 = rec(t1) + val r2 = rec(t2) + val r3 = rec(t3) + applySubst(if(r1 != t1 || r2 != t2 || r3 != t3) { + IfExpr(rec(t1),rec(t2),rec(t3)).setType(i.getType) + } else { + i + }) + } + case m @ MatchExpr(scrut,cses) => { + val rscrut = rec(scrut) + val (newCses,changes) = cses.map(inCase(_)).unzip + applySubst(if(rscrut != scrut || changes.exists(res=>res)) { + MatchExpr(rscrut, newCses).setType(m.getType) + } else { + m + }) + } + case t if t.isInstanceOf[Terminal] => applySubst(t) + case unhandled => scala.Predef.error("Non-terminal case should be handled in searchAndReplaceDFS: " + unhandled) } - // note that the identifiers in the patterns are not included if they don't show up on the rhs. - def inCase(cse: MatchCase, lets: Set[Identifier]) : Set[Identifier] = cse match { - case SimpleCase(pat, rhs) => rec(rhs, lets) - case GuardedCase(pat, guard, rhs) => rec(guard, lets) ++ rec(rhs, lets) + def inCase(cse: MatchCase) : (MatchCase,Boolean) = cse match { + case s @ SimpleCase(pat, rhs) => { + val rrhs = rec(rhs) + if(rrhs != rhs) { + (SimpleCase(pat, rrhs), true) + } else { + (s, false) + } + } + case g @ GuardedCase(pat, guard, rhs) => { + val rguard = rec(guard) + val rrhs = rec(rhs) + if(rguard != guard || rrhs != rhs) { + (GuardedCase(pat, rguard, rrhs), true) + } else { + (g, false) + } + } + } + + val res = rec(expr) + (res, somethingChanged) + } + + // convert describes how to compute a value for the leaves (that includes + // functions with no args.) + // combine descriess how to combine two values + def treeCatamorphism[A](convert: Expr=>A, combine: (A,A)=>A, expression: Expr) : A = { + treeCatamorphism(convert, combine, (e:Expr,a:A)=>a, expression) + } + // compute allows the catamorphism to change the combined value depending on the tree + def treeCatamorphism[A](convert: Expr=>A, combine: (A,A)=>A, compute: (Expr,A)=>A, expression: Expr) : A = { + def rec(expr: Expr) : A = expr match { + case l @ Let(_, e, b) => compute(l, combine(rec(e), rec(b))) + case n @ NAryOperator(args, _) => { + if(args.size == 0) + compute(n, convert(n)) + else + compute(n, args.map(rec(_)).reduceLeft(combine)) + } + case b @ BinaryOperator(a1,a2,_) => compute(b, combine(rec(a1),rec(a2))) + case u @ UnaryOperator(a,_) => compute(u, rec(a)) + case i @ IfExpr(a1,a2,a3) => compute(i, combine(combine(rec(a1), rec(a2)), rec(a3))) + case m @ MatchExpr(scrut, cses) => compute(m, (scrut +: cses.flatMap(_.expressions)).map(rec(_)).reduceLeft(combine)) + case t: Terminal => compute(t, convert(t)) + case unhandled => scala.Predef.error("Non-terminal case should be handled in treeCatamorphism: " + unhandled) } - rec(expr, Set.empty) + rec(expression) + } + + def variablesOf(expr: Expr) : Set[Identifier] = { + def convert(t: Expr) : Set[Identifier] = t match { + case Variable(i) => Set(i) + case _ => Set.empty + } + def combine(s1: Set[Identifier], s2: Set[Identifier]) = s1 ++ s2 + def compute(t: Expr, s: Set[Identifier]) = t match { + case Let(i,_,_) => s -- Set(i) + case _ => s + } + treeCatamorphism(convert, combine, compute, expr) } /* Simplifies let expressions: @@ -448,15 +592,12 @@ object Trees { */ def simplifyLets(expr: Expr) : Expr = { def simplerLet(t: Expr) : Option[Expr] = t match { - case letExpr @ Let(i, Variable(v), b) => Some(replace(Map((Variable(i) -> Variable(v))), b)) - case letExpr @ Let(i, l: Literal[_], b) => Some(replace(Map((Variable(i) -> l)), b)) + case letExpr @ Let(i, t: Terminal, b) => Some(replace(Map((Variable(i) -> t)), b)) case letExpr @ Let(i,e,b) => { - var occurences = 0 - def incCount(tr: Expr) = tr match { - case Variable(x) if x == i => { occurences = occurences + 1; None } - case _ => None - } - searchAndReplace(incCount, false)(b) + val occurences = treeCatamorphism[Int]((e:Expr) => e match { + case Variable(x) if x == i => 1 + case _ => 0 + }, (x:Int,y:Int)=>x+y, b) if(occurences == 0) { Some(b) } else if(occurences == 1) { @@ -470,30 +611,72 @@ object Trees { searchAndReplace(simplerLet)(expr) } - /* Rewrites the expression so that all lets are at the top levels. */ - def pulloutLets(expr: Expr) : Expr = { - val (storedLets, noLets) = pulloutAndKeepLets(expr) - rebuildLets(storedLets, noLets) - } + // Pulls out all let constructs to the top level, and makes sure they're + // properly ordered. + private type DefPair = (Identifier,Expr) + private type DefPairs = List[DefPair] + private def allLetDefinitions(expr: Expr) : DefPairs = treeCatamorphism[DefPairs]( + (e: Expr) => Nil, + (s1: DefPairs, s2: DefPairs) => s1 ::: s2, + (e: Expr, dps: DefPairs) => e match { + case Let(i, e, _) => (i,e) :: dps + case _ => dps + }, + expr) - // new code (keep this if nested lets can appear in the value part, too) - def pulloutAndKeepLets(expr: Expr) : (List[(Identifier,Expr)], Expr) = { - var storedLets: List[(Identifier,Expr)] = Nil - - def storeLet(t: Expr) : Option[Expr] = t match { - case l @ Let(i, e, b) => - val (stored, value) = pulloutAndKeepLets(e) - storedLets :::= stored - storedLets ::= i -> value - Some(b) - case _ => None + private def killAllLets(expr: Expr) : Expr = searchAndReplaceDFS((e: Expr) => e match { + case Let(i,_,ex) => Some(ex) + case _ => None + })(expr) + + def liftLets(expr: Expr) : Expr = { + val initialDefinitionPairs = allLetDefinitions(expr) + val definitionPairs = initialDefinitionPairs.map(p => (p._1, killAllLets(p._2))) + val occursLists : Map[Identifier,Set[Identifier]] = Map(definitionPairs.map((dp: DefPair) => (dp._1 -> variablesOf(dp._2).toSet.filter(_.isLetBinder))) : _*) + var newList : DefPairs = Nil + var placed : Set[Identifier] = Set.empty + val toPlace = definitionPairs.size + var placedC = 0 + var traversals = 0 + + while(placedC < toPlace) { + if(traversals > toPlace + 1) { + scala.Predef.error("Cycle in let definitions or multiple definition for the same identifier in liftLets : " + definitionPairs.mkString("\n")) + } + for((id,ex) <- definitionPairs) if (!placed(id)) { + if((occursLists(id) -- placed) == Set.empty) { + placed = placed + id + newList = (id,ex) :: newList + placedC = placedC + 1 + } + } + traversals = traversals + 1 } - val noLets = searchAndReplace(storeLet)(expr) - (storedLets, noLets) + + val noLets = killAllLets(expr) + + val res = (newList.foldLeft(noLets)((e,iap) => Let(iap._1, iap._2, e))) + simplifyLets(res) } - def rebuildLets(lets: Seq[(Identifier,Expr)], expr: Expr) : Expr = { - lets.foldLeft(expr)((e,p) => Let(p._1, p._2, e)) + def wellOrderedLets(tree : Expr) : Boolean = { + val pairs = allLetDefinitions(tree) + val definitions: Set[Identifier] = Set(pairs.map(_._1) : _*) + val vars: Set[Identifier] = variablesOf(tree) + val intersection = vars intersect definitions + if(!intersection.isEmpty) { + intersection.foreach(id => { + Settings.reporter.error("Variable with identifier '" + id + "' has escaped its let-definition !") + }) + false + } else { + vars.forall(id => if(id.isLetBinder) { + Settings.reporter.error("Variable with identifier '" + id + "' has lost its let-definition (it disappeared??)") + false + } else { + true + }) + } } /* Fully expands all let expressions. */ @@ -666,18 +849,4 @@ object Trees { }) } - // we use this when debugging our tree transformations... - def wellOrderedLets(tree : Expr) : Boolean = { - val (pairs, _) = pulloutAndKeepLets(tree) - val definitions: Set[Identifier] = Set(pairs.map(_._1) : _*) - val vars: Set[Identifier] = variablesOf(tree) - val intersection = vars intersect definitions - if(intersection.isEmpty) true - else { - intersection.foreach(id => { - Settings.reporter.error("Variable with identifier '" + id + "' has escaped its let-definition !") - }) - false - } - } } diff --git a/src/purescala/Z3Solver.scala b/src/purescala/Z3Solver.scala index c828fcd5ebb364d309b5b69f6bf478b0bd46f304..114e46b0da8140c0ef1179f544f07327d0127122 100644 --- a/src/purescala/Z3Solver.scala +++ b/src/purescala/Z3Solver.scala @@ -282,6 +282,8 @@ class Z3Solver(reporter: Reporter) extends Solver(reporter) { private def toZ3Formula(z3: Z3Context, expr: Expr, initialMap: Map[String,Z3AST] = Map.empty) : Option[Z3AST] = { class CantTranslateException extends Exception + val varsInformula: Set[Identifier] = variablesOf(expr) + var z3Vars: Map[String,Z3AST] = initialMap def rec(ex: Expr) : Z3AST = ex match { @@ -295,6 +297,9 @@ class Z3Solver(reporter: Reporter) extends Solver(reporter) { case v @ Variable(id) => z3Vars.get(id.uniqueName) match { case Some(ast) => ast case None => { + if(id.isLetBinder) { + scala.Predef.error("Error in formula being translated to Z3: identifier " + id + " seems to have escaped its let-definition") + } val newAST = z3.mkFreshConst(id.name, typeToSort(v.getType)) z3Vars = z3Vars + (id.uniqueName -> newAST) newAST @@ -346,7 +351,11 @@ class Z3Solver(reporter: Reporter) extends Solver(reporter) { } try { - Some(rec(expr)) + val res = Some(rec(expr)) + val usedInZ3Form = z3Vars.keys.toSet + println("Variables in formula: " + varsInformula.map(_.uniqueName)) + println("Variables passed to Z3: " + usedInZ3Form) + res } catch { case e: CantTranslateException => None }