diff --git a/src/purescala/Analysis.scala b/src/purescala/Analysis.scala index bd8dc3252ec37aba3d3102564bcea482652288c7..0533b568d8cde5e844e44261471d0d044c6be39b 100644 --- a/src/purescala/Analysis.scala +++ b/src/purescala/Analysis.scala @@ -141,8 +141,7 @@ object Analysis { def inlineFunctionsAndContracts(program: Program, expr: Expr) : Expr = { var extras : List[Expr] = Nil - val isFunCall: Function[Expr,Boolean] = _.isInstanceOf[FunctionInvocation] - def applyToCall(e: Expr) : Expr = e match { + def applyToCall(e: Expr) : Option[Expr] = e match { case f @ FunctionInvocation(fd, args) => { val fArgsAsVars: List[Variable] = fd.args.map(_.toVariable).toList val fParamsAsLetVars: List[Identifier] = fd.args.map(a => FreshIdentifier("arg", true).setType(a.tpe)).toList @@ -159,17 +158,17 @@ object Analysis { replace(substMap + (ResultVariable() -> newVar), fd.postcondition.get), Equals(newVar, FunctionInvocation(fd, fParamsAsLetVarVars).setType(fd.returnType)) )) :: extras - newVar + Some(newVar) } else if(fd.hasImplementation && !program.isRecursive(fd)) { // means we can inline at least one level... - mkBigLet(replace(substMap, fd.body.get)) + Some(mkBigLet(replace(substMap, fd.body.get))) } else { // we can't do much for calls to recursive functions or to functions with no bodies - f + None } } - case o => o + case o => None } - val finalE = searchAndApply(isFunCall, applyToCall, expr) + val finalE = searchAndReplace(applyToCall)(expr) pulloutLets(Implies(And(extras.reverse), finalE)) } @@ -181,11 +180,7 @@ object Analysis { var extras : List[Expr] = Nil def urf(expr: Expr, left: Int) : Expr = { - def isRecursiveCall(e: Expr) = e match { - case f @ FunctionInvocation(fd, _) if fd.hasImplementation && program.isRecursive(fd) => true - case _ => false - } - def unrollCall(t: Int)(e: Expr) = e match { + def unrollCall(t: Int)(e: Expr) : Option[Expr] = e match { case f @ FunctionInvocation(fd, args) if fd.hasImplementation && program.isRecursive(fd) => { val newLetIDs = fd.args.map(a => FreshIdentifier(a.id.name, true).setType(a.tpe)) val newLetVars = newLetIDs.map(Variable(_)) @@ -205,17 +200,17 @@ object Analysis { // println(" --- newVar is ------------------") // println(newVar) // println("*********************************") - newVar + Some(newVar) } else { val bigLet = (newLetIDs zip args).foldLeft(bodyWithLetVars)((e,p) => Let(p._1, p._2, e)) - urf(bigLet, t-1) + Some(urf(bigLet, t-1)) } } - case o => o + case o => None } if(left > 0) - searchAndApply(isRecursiveCall, unrollCall(left), expr, false) + searchAndReplace(unrollCall(left), false)(expr) else expr } @@ -238,12 +233,8 @@ object Analysis { def rspm(expr: Expr) : (Expr,Seq[Expr]) = { var extras : List[Expr] = Nil - def isPMExpr(e: Expr) : Boolean = { - e.isInstanceOf[MatchExpr] - } - - def rewritePM(e: Expr) : Expr = e.asInstanceOf[MatchExpr] match { - case SimplePatternMatching(scrutinee, classType, casesInfo) => { + def rewritePM(e: Expr) : Option[Expr] = e match { + case SimplePatternMatching(scrutinee, classType, casesInfo) => Some({ val newVar = Variable(FreshIdentifier("pm", true)).setType(e.getType) val scrutAsLetID = FreshIdentifier("scrut", true).setType(scrutinee.getType) val lle : List[(Variable,List[Expr])] = casesInfo.map(cseInfo => { @@ -256,11 +247,11 @@ object Analysis { val (newPVars, newExtras) = lle.unzip extras = Let(scrutAsLetID, scrutinee, And(Or(newPVars.map(Equals(Variable(scrutAsLetID), _))), And(newExtras.flatten))) :: extras newVar - } - case _ => e + }) + case _ => None } - val cleanerTree = searchAndApply(isPMExpr, rewritePM, expr) + val cleanerTree = searchAndReplace(rewritePM)(expr) (cleanerTree, extras.reverse) } val (savedLets, naked) = pulloutAndKeepLets(expression) diff --git a/src/purescala/Definitions.scala b/src/purescala/Definitions.scala index 32b2d23d5310813ebac43b8cfa275afd299cb6ef..90d418bb7604291ea0b235c55bb427805821d50d 100644 --- a/src/purescala/Definitions.scala +++ b/src/purescala/Definitions.scala @@ -51,16 +51,15 @@ object Definitions { var resSet: Set[(FunDef,FunDef)] = new scala.collection.immutable.HashSet[(FunDef,FunDef)]() - def isFunCall(e: Expr) : Boolean = e.isInstanceOf[FunctionInvocation] - def applyToFunCall(f1: FunDef)(e: Expr) : Expr = e match { - case f @ FunctionInvocation(f2, _) => { resSet = resSet + ((f1,f2)); f } - case o => o + def applyToFunCall(f1: FunDef)(e: Expr) : Option[Expr] = e match { + case f @ FunctionInvocation(f2, _) => { resSet = resSet + ((f1,f2)); Some(f) } + case _ => None } for(funDef <- definedFunctions) { - funDef.precondition.map(searchAndApply(isFunCall, applyToFunCall(funDef), _)) - funDef.body.map(searchAndApply(isFunCall, applyToFunCall(funDef), _)) - funDef.postcondition.map(searchAndApply(isFunCall, applyToFunCall(funDef), _)) + funDef.precondition.map(searchAndReplace(applyToFunCall(funDef))(_)) + funDef.body.map(searchAndReplace(applyToFunCall(funDef))(_)) + funDef.postcondition.map(searchAndReplace(applyToFunCall(funDef))(_)) } var callers: Map[FunDef,Set[FunDef]] = diff --git a/src/purescala/Trees.scala b/src/purescala/Trees.scala index 91d0c15dea293ca881cc419bdfeeec500360bcb7..96418d14c5822672daec8205cf2a1765e427ace3 100644 --- a/src/purescala/Trees.scala +++ b/src/purescala/Trees.scala @@ -400,6 +400,80 @@ object Trees { rec(expr) } + def searchAndReplace(subst: Expr=>Option[Expr], recursive: Boolean=true)(expr: Expr) : Expr = { + def rec(ex: Expr, skip: Expr = null) : Expr = (if (ex == skip) None else subst(ex)) match { + case Some(newExpr) => { + if(newExpr.getType == NoType) { + Settings.reporter.warning("REPLACING IN EXPRESSION WITH AN UNTYPED TREE ! " + ex + " --to--> " + newExpr) + } + if(ex == newExpr) + if(recursive) rec(ex, ex) else ex + else + if(recursive) rec(newExpr) else newExpr + } + case None => ex match { + case l @ Let(i,e,b) => { + val re = rec(e) + val rb = rec(b) + if(re != e || rb != b) + Let(i, re, rb).setType(l.getType) + else + l + } + case n @ NAryOperator(args, recons) => { + var change = false + val rargs = args.map(a => { + val ra = rec(a) + if(ra != a) { + change = true + ra + } else { + a + } + }) + if(change) + recons(rargs).setType(n.getType) + else + n + } + case b @ BinaryOperator(t1,t2,recons) => { + val r1 = rec(t1) + val r2 = rec(t2) + if(r1 != t1 || r2 != t2) + recons(r1,r2).setType(b.getType) + else + b + } + case u @ UnaryOperator(t,recons) => { + val r = rec(t) + if(r != t) + recons(r).setType(u.getType) + else + u + } + case i @ IfExpr(t1,t2,t3) => { + val r1 = rec(t1) + val r2 = rec(t2) + val r3 = rec(t3) + if(r1 != t1 || r2 != t2 || r3 != t3) + IfExpr(rec(t1),rec(t2),rec(t3)).setType(i.getType) + else + i + } + case m @ MatchExpr(scrut,cses) => MatchExpr(rec(scrut), cses.map(inCase(_))).setType(m.getType) + case t if t.isInstanceOf[Terminal] => t + case unhandled => scala.Predef.error("Non-terminal case should be handled in searchAndApply: " + unhandled) + } + } + + def inCase(cse: MatchCase) : MatchCase = cse match { + case SimpleCase(pat, rhs) => SimpleCase(pat, rec(rhs)) + case GuardedCase(pat, guard, rhs) => GuardedCase(pat, rec(guard), rec(rhs)) + } + + rec(expr) + } + /* Simplifies let expressions: * - removes lets when expression never occurs * - simplifies when expressions occurs exactly once