diff --git a/src/main/scala/leon/purescala/ExprOps.scala b/src/main/scala/leon/purescala/ExprOps.scala index d099ed2b5dee76285f75ecfc8a3e3e46899894fb..7265fb1796a8b07095006c77372b315c5bec7d76 100644 --- a/src/main/scala/leon/purescala/ExprOps.scala +++ b/src/main/scala/leon/purescala/ExprOps.scala @@ -219,7 +219,7 @@ object ExprOps extends GenTreeOps[Expr] { case CaseClassSelector(cct, cc: CaseClass, id) => Some(caseClassSelector(cct, cc, id)) - case IfExpr(c, thenn, elze) if (thenn == elze) && isDeterministic(e) => + case IfExpr(c, thenn, elze) if (thenn == elze) && !evalOrderSensitive(c) => Some(thenn) case IfExpr(c, BooleanLiteral(true), BooleanLiteral(false)) => @@ -336,12 +336,12 @@ object ExprOps extends GenTreeOps[Expr] { def simplifyLets(expr: Expr) : Expr = { def simplerLet(t: Expr) : Option[Expr] = t match { - case letExpr @ Let(i, t: Terminal, b) if isDeterministic(b) => + case letExpr @ Let(i, t: Terminal, b) if !evalOrderSensitive(t) => Some(replaceFromIDs(Map(i -> t), b)) - case letExpr @ Let(i,e,b) if isDeterministic(b) => { + case letExpr @ Let(i,e,b) if !evalOrderSensitive(e) => val occurrences = count { - case Variable(x) if x == i => 1 + case Variable(`i`) => 1 case _ => 0 }(b) @@ -352,79 +352,33 @@ object ExprOps extends GenTreeOps[Expr] { } else { None } - } - - case letTuple @ LetTuple(ids, Tuple(exprs), body) if isDeterministic(body) => - var newBody = body - - val (remIds, remExprs) = (ids zip exprs).filter { - case (id, value: Terminal) => - newBody = replaceFromIDs(Map(id -> value), newBody) - //we replace, so we drop old - false - case (id, value) => - val occurences = count { - case Variable(x) if x == id => 1 - case _ => 0 - }(body) - - if(occurences == 0) { - false - } else if(occurences == 1) { - newBody = replace(Map(Variable(id) -> value), newBody) - false - } else { - true - } - }.unzip - Some(Constructors.letTuple(remIds, tupleWrap(remExprs), newBody)) + case LetPattern(patt, e0, body) if !evalOrderSensitive(e0) => + // Will turn the match-expression with a single case into a list of lets. - case l @ LetTuple(ids, tExpr: Terminal, body) if isDeterministic(body) => - val substMap : Map[Expr,Expr] = ids.map(Variable(_) : Expr).zipWithIndex.toMap.map { - case (v,i) => v -> tupleSelect(tExpr, i + 1, true).copiedFrom(v) + // Just extra safety... + val e = (e0.getType, patt) match { + case (_:AbstractClassType, CaseClassPattern(_, cct, _)) => + asInstOf(e0, cct) + case (at: AbstractClassType, InstanceOfPattern(_, ct)) if at != ct => + asInstOf(e0, ct) + case _ => + e0 } - Some(replace(substMap, body)) - - case l @ LetTuple(ids, tExpr, body) if isDeterministic(body) => - val arity = ids.size - val zeroVec = Seq.fill(arity)(0) - val idMap = ids.zipWithIndex.toMap.mapValues(i => zeroVec.updated(i, 1)) - - // A map containing vectors of the form (0, ..., 1, ..., 0) where - // the one corresponds to the index of the identifier in the - // LetTuple. The idea is that we can sum such vectors up to compute - // the occurences of all variables in one traversal of the - // expression. - - val occurences : Seq[Int] = fold[Seq[Int]]({ case (e, subs) => - e match { - case Variable(x) => idMap.getOrElse(x, zeroVec) - case _ => subs.foldLeft(zeroVec) { case (a1, a2) => - (a1 zip a2).map(p => p._1 + p._2) - } - } - })(body) - - val total = occurences.sum - - if(total == 0) { - Some(body) - } else if(total == 1) { - val substMap : Map[Expr,Expr] = ids.map(Variable(_) : Expr).zipWithIndex.toMap.map { - case (v,i) => v -> tupleSelect(tExpr, i + 1, ids.size).copiedFrom(v) - } - - Some(replace(substMap, body)) - } else { - None + // Sort lets in dependency order + val lets = mapForPattern(e, patt).toList.sortWith { + case ((id1, e1), (id2, e2)) => exists{ _ == Variable(id1) }(e2) } + Some(lets.foldRight(body) { + case ((id, e), bd) => Let(id, e, bd) + }) + case _ => None } - postMap(simplerLet)(expr) + postMap(simplerLet, applyRec = true)(expr) } /** Fully expands all let expressions. */ @@ -435,7 +389,7 @@ object ExprOps extends GenTreeOps[Expr] { case i @ IfExpr(t1,t2,t3) => IfExpr(rec(t1, s),rec(t2, s),rec(t3, s)) case m @ MatchExpr(scrut, cses) => matchExpr(rec(scrut, s), cses.map(inCase(_, s))).setPos(m) case p @ Passes(in, out, cses) => Passes(rec(in, s), rec(out,s), cses.map(inCase(_, s))).setPos(p) - case n @ Deconstructor(args, recons) => { + case n @ Deconstructor(args, recons) => var change = false val rargs = args.map(a => { val ra = rec(a, s) @@ -450,8 +404,7 @@ object ExprOps extends GenTreeOps[Expr] { recons(rargs) else n - } - case unhandled => scala.sys.error("Unhandled case in expandLets: " + unhandled) + case unhandled => throw LeonFatalError("Unhandled case in expandLets: " + unhandled) } def inCase(cse: MatchCase, s: Map[Identifier,Expr]) : MatchCase = { @@ -966,50 +919,6 @@ object ExprOps extends GenTreeOps[Expr] { postMap(transform, applyRec = true)(expr) } - /** Simplify If expressions when the branch is predetermined by the path condition */ - def simplifyTautologies(sf: SolverFactory[Solver])(expr : Expr) : Expr = { - val solver = SimpleSolverAPI(sf) - - def pre(e : Expr) = e match { - - case LetDef(fds, expr) => - for(fd <- fds if fd.hasPrecondition) { - val pre = fd.precondition.get - - solver.solveVALID(pre) match { - case Some(true) => - fd.precondition = None - - case Some(false) => solver.solveSAT(pre) match { - case (Some(false), _) => - fd.precondition = Some(BooleanLiteral(false).copiedFrom(e)) - case _ => - } - case None => - } - } - e - case IfExpr(cond, thenn, elze) => - try { - solver.solveVALID(cond) match { - case Some(true) => thenn - case Some(false) => solver.solveVALID(Not(cond)) match { - case Some(true) => elze - case _ => e - } - case None => e - } - } catch { - // let's give up when the solver crashes - case _ : Exception => e - } - - case _ => e - } - - simplePreTransform(pre)(expr) - } - def simplifyPaths(sf: SolverFactory[Solver], initC: List[Expr] = Nil): Expr => Expr = { new SimplifierWithPaths(sf, initC).transform } @@ -1069,16 +978,25 @@ object ExprOps extends GenTreeOps[Expr] { super.formulaSize(e) } - /** Returns true if the expression is deterministic / does not contain any [[purescala.Expressions.Choose Choose]] or [[purescala.Expressions.Hole Hole]]*/ + /** Returns true if the expression is deterministic / + * does not contain any [[purescala.Expressions.Choose Choose]] + * or [[purescala.Expressions.Hole Hole]] or [[purescala.Expressions.WithOracle]] + */ def isDeterministic(e: Expr): Boolean = { - preTraversal{ - case Choose(_) => return false - case Hole(_, _) => return false - //@EK FIXME: do we need it? - //case Error(_, _) => return false - case _ => + exists { + case _ : Choose | _: Hole | _: WithOracle => false + case _ => true + }(e) + } + + /** Returns if this expression would change the results of a program + * if its evaluation order in the program changed + */ + def evalOrderSensitive(e: Expr): Boolean = { + exists { + case _ : Error | _ : Choose | _: Hole | _: WithOracle => true + case _ => false }(e) - true } /** Returns the value for an identifier given a model. */ diff --git a/src/main/scala/leon/purescala/SimplifierWithPaths.scala b/src/main/scala/leon/purescala/SimplifierWithPaths.scala index a53b3b17698ccacd435978b8b36da139491cef17..40f100fcb07c84b995a2c0c82433c557b277c426 100644 --- a/src/main/scala/leon/purescala/SimplifierWithPaths.scala +++ b/src/main/scala/leon/purescala/SimplifierWithPaths.scala @@ -54,29 +54,38 @@ class SimplifierWithPaths(sf: SolverFactory[Solver], override val initC: List[Ex } protected override def rec(e: Expr, path: C) = e match { - case IfExpr(cond, thenn, elze) => - super.rec(e, path) match { - case IfExpr(BooleanLiteral(true) , t, _) => t - case IfExpr(BooleanLiteral(false), _, e) => e - case ite => ite - } + case Require(pre, body) if impliedBy(pre, path) => + body - case And(es) => - var soFar = path - var continue = true - val r = andJoin(for(e <- es if continue) yield { - val se = rec(e, soFar) - if(se == BooleanLiteral(false)) continue = false - soFar = register(se, soFar) - se - }).copiedFrom(e) - - if (continue) { - r + case IfExpr(cond, thenn, elze) => + if (impliedBy(cond, path)) { + rec(thenn, path) + } else if (contradictedBy(cond, path)) { + rec(elze, path) } else { - BooleanLiteral(false).copiedFrom(e) + super.rec(e, path) } + case And(e +: _) if contradictedBy(e, path) => + BooleanLiteral(false).copiedFrom(e) + + case And(e +: es) if impliedBy(e, path) => + val remaining = if (es.size > 1) And(es).copiedFrom(e) else es.head + rec(remaining, path) + + case Or(e +: _) if impliedBy(e, path) => + BooleanLiteral(true).copiedFrom(e) + + case Or(e +: es) if contradictedBy(e, path) => + val remaining = if (es.size > 1) Or(es).copiedFrom(e) else es.head + rec(remaining, path) + + case Implies(lhs, rhs) if impliedBy(lhs, path) => + rec(rhs, path) + + case Implies(lhs, rhs) if contradictedBy(lhs, path) => + BooleanLiteral(true).copiedFrom(e) + case me@MatchExpr(scrut, cases) => val rs = rec(scrut, path) @@ -91,7 +100,7 @@ class SimplifierWithPaths(sf: SolverFactory[Solver], override val initC: List[Ex stillPossible = false } - Some((cs match { + Seq((cs match { case SimpleCase(p, rhs) => SimpleCase(p, rec(rhs, cond)) case GuardedCase(p, g, rhs) => @@ -106,35 +115,18 @@ class SimplifierWithPaths(sf: SolverFactory[Solver], override val initC: List[Ex GuardedCase(p, newGuard, rec(rhs, cond)) }).copiedFrom(cs)) } else { - None + Seq() } } newCases match { - case List() => Error(e.getType, "Unreachable code").copiedFrom(e) - case List(theCase) if !scrut.getType.isInstanceOf[AbstractClassType] => - // Avoid AbstractClassType as it may lead to invalid field accesses - replaceFromIDs(mapForPattern(scrut, theCase.pattern), theCase.rhs) - case _ => matchExpr(rs, newCases).copiedFrom(e) - } - - case Or(es) => - var soFar = path - var continue = true - val r = orJoin(for(e <- es if continue) yield { - val se = rec(e, soFar) - if(se == BooleanLiteral(true)) continue = false - soFar = register(Not(se), soFar) - se - }).copiedFrom(e) - - if (continue) { - r - } else { - BooleanLiteral(true).copiedFrom(e) + case List() => + Error(e.getType, "Unreachable code").copiedFrom(e) + case _ => + matchExpr(rs, newCases).copiedFrom(e) } case a @ Assert(pred, _, body) if impliedBy(pred, path) => - body.copiedFrom(a) + body case b if b.getType == BooleanType && impliedBy(b, path) => BooleanLiteral(true).copiedFrom(b) diff --git a/src/main/scala/leon/utils/Simplifiers.scala b/src/main/scala/leon/utils/Simplifiers.scala index 26a62b2983d2627ac4bb7c3a925f1e08f38f3969..6fdcbc9da79101bbcacee65e3c16af637c9699b6 100644 --- a/src/main/scala/leon/utils/Simplifiers.scala +++ b/src/main/scala/leon/utils/Simplifiers.scala @@ -15,23 +15,14 @@ object Simplifiers { val solverf = SolverFactory.uninterpreted(ctx, p) try { - val simplifiers = List[Expr => Expr]( - simplifyTautologies(solverf)(_), - simplifyLets, - simplifyPaths(solverf)(_), - simplifyArithmetic, - evalGround(ctx, p), - normalizeExpression - ) - - val simple = { expr: Expr => - simplifiers.foldLeft(expr){ case (x, sim) => - sim(x) - } - } + val simplifiers = (simplifyLets _). + andThen(simplifyPaths(solverf)). + andThen(simplifyArithmetic). + andThen(evalGround(ctx, p)). + andThen(normalizeExpression) // Simplify first using stable simplifiers - val s = fixpoint(simple, 5)(e) + val s = fixpoint(simplifiers, 5)(e) // Clean up ids/names (new ScopeSimplifier).transform(s) @@ -44,21 +35,12 @@ object Simplifiers { val solverf = SolverFactory.uninterpreted(ctx, p) try { - val simplifiers = List[Expr => Expr]( - simplifyTautologies(solverf)(_), - simplifyArithmetic, - evalGround(ctx, p), - normalizeExpression - ) - - val simple = { expr: Expr => - simplifiers.foldLeft(expr){ case (x, sim) => - sim(x) - } - } + val simplifiers = (simplifyArithmetic _). + andThen(evalGround(ctx, p)). + andThen(normalizeExpression) // Simplify first using stable simplifiers - fixpoint(simple, 5)(e) + fixpoint(simplifiers, 5)(e) } finally { solverf.shutdown() } diff --git a/src/test/scala/leon/integration/purescala/SimplifyPathsSuite.scala b/src/test/scala/leon/integration/purescala/SimplifyPathsSuite.scala index f70543fc98504e3f8ea4bc689530d1d9d0cf7dd9..a560f9ca4cd97f38f65c91e8c04a1a39c7dbd336 100644 --- a/src/test/scala/leon/integration/purescala/SimplifyPathsSuite.scala +++ b/src/test/scala/leon/integration/purescala/SimplifyPathsSuite.scala @@ -50,10 +50,9 @@ class SimplifyPathsSuite extends LeonTestSuite { MatchCase(LiteralPattern(None, BooleanLiteral(false)), None, Not(cV)) )) - val exp = cV val out = simplifyPaths(ctx, in) - assert(out === exp) + assert(out.asInstanceOf[MatchExpr].cases.size == 1) } test("Simplify Paths 03 - ") { ctx =>