diff --git a/src/purescala/Analysis.scala b/src/purescala/Analysis.scala index 11b4389392c89b994e118b7b25c30378e4aeff4d..d8e930b71d9894fd9c1dbcaea3872ed7a710fe2b 100644 --- a/src/purescala/Analysis.scala +++ b/src/purescala/Analysis.scala @@ -140,6 +140,7 @@ class Analysis(val program: Program) { val expr2 = rewriteSimplePatternMatching(expr1) // reporter.info("After PM-rewriting:") // reporter.info(expandLets(expr2)) + assert(wellOrderedLets(expr2)) expr2 } } @@ -166,17 +167,6 @@ object Analysis { val substMap = Map[Expr,Expr]((fArgsAsVars zip fParamsAsLetVarVars) : _*) if(fd.hasPostcondition) { val newVar = Variable(FreshIdentifier("call", true)).setType(fd.returnType) - /* START CHANGE */ - // Code before - /* - extras = mkBigLet(And( - replace(substMap + (ResultVariable() -> newVar), fd.postcondition.get), - Equals(newVar, FunctionInvocation(fd, fParamsAsLetVarVars).setType(fd.returnType)) - )) :: extras - Some(newVar) - */ - - // Fixed code ?!? extras = And( replace(substMap + (ResultVariable() -> newVar), fd.postcondition.get), Equals(newVar, FunctionInvocation(fd, fParamsAsLetVarVars).setType(fd.returnType)) @@ -193,7 +183,8 @@ object Analysis { } val finalE = searchAndReplace(applyToCall)(expr) - pulloutLets(Implies(And(extras.reverse), finalE)) + val toReturn = pulloutLets(Implies(And(extras.reverse), finalE)) + toReturn } // Warning: this should only be called on a top-level formula ! It will add @@ -248,7 +239,8 @@ object Analysis { val newLetBodies: Seq[Expr] = infoFromLets.map(_._1) val newSavedLets: Seq[(Identifier,Expr)] = savedLets.map(_._1) zip newLetBodies val (cleaned, extras) = unroll(naked) - rebuildLets(newSavedLets, Implies(And(extrasFromLets ++ extras), cleaned)) + val toReturn = rebuildLets(newSavedLets, Implies(And(extrasFromLets ++ extras), cleaned)) + toReturn } // Rewrites pattern matching expressions where the cases simply correspond to @@ -285,7 +277,7 @@ object Analysis { val newLetBodies: Seq[Expr] = infoFromLets.map(_._1) val newSavedLets: Seq[(Identifier,Expr)] = savedLets.map(_._1) zip newLetBodies val (cleaned, extras) = rspm(naked) - rebuildLets(newSavedLets, Implies(And(extrasFromLets ++ extras), cleaned)) + val toReturn = rebuildLets(newSavedLets, Implies(And(extrasFromLets ++ extras), cleaned)) + toReturn } - } diff --git a/src/purescala/Trees.scala b/src/purescala/Trees.scala index d4354853fcb9be999185349b31e4cfa26138b366..a4405a67cb926b4eb2a75c15d62c0b0f8049f6af 100644 --- a/src/purescala/Trees.scala +++ b/src/purescala/Trees.scala @@ -415,6 +415,31 @@ object Trees { rec(expr) } + def variablesOf(expr: Expr) : Set[Identifier] = { + def rec(ex: Expr, lets: Set[Identifier]) : Set[Identifier] = ex match { + case l @ Let(i,e,b) => { + val newLets = lets + i + rec(e, newLets) ++ rec(b, newLets) + } + case Variable(i) => if(lets(i)) Set.empty[Identifier] else Set(i) + case n @ NAryOperator(args, _) => if(args.isEmpty) Set.empty[Identifier] else args.map(rec(_, lets)).reduceLeft(_ ++ _) + case b @ BinaryOperator(t1,t2,_) => rec(t1,lets) ++ rec(t2,lets) + case u @ UnaryOperator(t,_) => rec(t, lets) + case i @ IfExpr(t1,t2,t3) => rec(t1,lets) ++ rec(t2,lets) ++ rec(t3,lets) + case m @ MatchExpr(scrut,cses) => rec(scrut, lets) ++ cses.map(inCase(_, lets)).reduceLeft(_ ++ _) + case t if t.isInstanceOf[Terminal] => Set.empty[Identifier] + case unhandled => scala.Predef.error("Non-terminal case should be handled in searchAndReplace: " + unhandled) + } + + // note that the identifiers in the patterns are not included if they don't show up on the rhs. + def inCase(cse: MatchCase, lets: Set[Identifier]) : Set[Identifier] = cse match { + case SimpleCase(pat, rhs) => rec(rhs, lets) + case GuardedCase(pat, guard, rhs) => rec(guard, lets) ++ rec(rhs, lets) + } + + rec(expr, Set.empty) + } + /* Simplifies let expressions: * - removes lets when expression never occurs * - simplifies when expressions occurs exactly once @@ -450,27 +475,6 @@ object Trees { val (storedLets, noLets) = pulloutAndKeepLets(expr) rebuildLets(storedLets, noLets) } - - /* START CHANGE */ - // Previous code (keep this if nested lets can only appear in the body) - /* - def pulloutAndKeepLets(expr: Expr) : (Seq[(Identifier,Expr)], Expr) = { - var storedLets: List[(Identifier,Expr)] = Nil - - def storeLet(t: Expr) : Option[Expr] = t match { - case l @ Let(i, e, b) => (storedLets = ((i,e)) :: storedLets); None - case _ => None - } - def killLet(t: Expr) : Option[Expr] = t match { - case l @ Let(i, e, b) => Some(b) - case _ => None - } - - searchAndReplace(storeLet)(expr) - val noLets = searchAndReplace(killLet)(expr) - (storedLets, noLets) - } - */ // new code (keep this if nested lets can appear in the value part, too) def pulloutAndKeepLets(expr: Expr) : (List[(Identifier,Expr)], Expr) = { @@ -478,11 +482,6 @@ object Trees { def storeLet(t: Expr) : Option[Expr] = t match { case l @ Let(i, e, b) => - // Easy fix, but breaks define-before-use order !! - //val noLets = searchAndReplace(storeLet)(e) - //storedLets ::= i -> noLets - - // Better fix, but please check val (stored, value) = pulloutAndKeepLets(e) storedLets :::= stored storedLets ::= i -> value @@ -492,7 +491,6 @@ object Trees { val noLets = searchAndReplace(storeLet)(expr) (storedLets, noLets) } - /* END CHANGE */ def rebuildLets(lets: Seq[(Identifier,Expr)], expr: Expr) : Expr = { lets.foldLeft(expr)((e,p) => Let(p._1, p._2, e)) @@ -667,4 +665,19 @@ object Trees { case _ => None } } + + // we use this when debugging our tree transformations... + def wellOrderedLets(tree : Expr) : Boolean = { + val (pairs, _) = pulloutAndKeepLets(tree) + val definitions: Set[Identifier] = Set(pairs.map(_._1) : _*) + val vars: Set[Identifier] = variablesOf(tree) + val intersection = vars intersect definitions + if(intersection.isEmpty) true + else { + intersection.foreach(id => { + Settings.reporter.error("Variable with identifier '" + id + "' has escaped its let-definition !") + }) + false + } + } }