diff --git a/src/main/scala/inox/ast/SymbolOps.scala b/src/main/scala/inox/ast/SymbolOps.scala index e7d14301d2322fc9cccfcfba9edb4142df0935b6..1e01ffb56a2f07f6e44e02233e6788a53b865899 100644 --- a/src/main/scala/inox/ast/SymbolOps.scala +++ b/src/main/scala/inox/ast/SymbolOps.scala @@ -386,10 +386,11 @@ trait SymbolOps extends TreeOps { self: TypeOps => preMap(rewritePM)(expr) } - /** For each case in the [[purescala.Expressions.MatchExpr MatchExpr]], concatenates the path condition with the newly induced conditions. - * - * Each case holds the conditions on other previous cases as negative. - * + /** For each case in the [[purescala.Expressions.MatchExpr MatchExpr]], + * concatenates the path condition with the newly induced conditions. + * Each case holds the conditions on other previous cases as negative. + * @note The guard of the final case is NOT included in the Paths. + * * @see [[purescala.ExprOps#conditionForPattern conditionForPattern]] * @see [[purescala.ExprOps#mapForPattern mapForPattern]] */ @@ -398,12 +399,12 @@ trait SymbolOps extends TreeOps { self: TypeOps => var pcSoFar = path for (c <- cases) yield { - val g = c.optGuard getOrElse BooleanLiteral(true) val cond = conditionForPattern(scrut, c.pattern, includeBinders = true) - val localCond = pcSoFar merge (cond withCond g) + val localCond = pcSoFar merge cond // These contain no binders defined in this MatchCase val condSafe = conditionForPattern(scrut, c.pattern) + val g = c.optGuard getOrElse BooleanLiteral(true) val gSafe = replaceFromSymbols(mapForPattern(scrut, c.pattern), g) pcSoFar = pcSoFar merge (condSafe withCond gSafe).negate diff --git a/src/main/scala/inox/transformers/ScopeSimplifier.scala b/src/main/scala/inox/transformers/ScopeSimplifier.scala new file mode 100644 index 0000000000000000000000000000000000000000..e4a2404bf61c03424f2ee3987db4fb8bddbf0cc9 --- /dev/null +++ b/src/main/scala/inox/transformers/ScopeSimplifier.scala @@ -0,0 +1,100 @@ +/* Copyright 2009-2016 EPFL, Lausanne */ + +package inox +package transformers + +/** Simplifies variable ids in scope */ +trait ScopeSimplifier extends Transformer { + import trees._ + case class Scope(inScope: Set[ValDef] = Set(), oldToNew: Map[ValDef, ValDef] = Map(), funDefs: Map[Identifier, Identifier] = Map()) { + + def register(oldNew: (ValDef, ValDef)): Scope = { + val newId = oldNew._2 + copy(inScope = inScope + newId, oldToNew = oldToNew + oldNew) + } + + def register(oldNews: Seq[(ValDef, ValDef)]): Scope = { + (this /: oldNews){ case (oldScope, oldNew) => oldScope.register(oldNew) } + } + + def registerFunDef(oldNew: (Identifier, Identifier)): Scope = { + copy(funDefs = funDefs + oldNew) + } + } + + protected def genId(vd: ValDef, scope: Scope): ValDef = { + val ValDef(id, tp) = vd + val existCount = scope.inScope.count(_.id.name == id.name) + + ValDef(FreshIdentifier.forceId(id.name, existCount, existCount >= 1), tp) + } + + protected def rec(e: Expr, scope: Scope): Expr = e match { + case Let(i, e, b) => + val si = genId(i, scope) + val se = rec(e, scope) + val sb = rec(b, scope.register(i -> si)) + Let(si, se, sb) + + case MatchExpr(scrut, cases) => + val rs = rec(scrut, scope) + + def trPattern(p: Pattern, scope: Scope): (Pattern, Scope) = { + val (newBinder, newScope) = p.binder match { + case Some(id) => + val newId = genId(id, scope) + val newScope = scope.register(id -> newId) + (Some(newId), newScope) + case None => + (None, scope) + } + + var curScope = newScope + val newSubPatterns = for (sp <- p.subPatterns) yield { + val (subPattern, subScope) = trPattern(sp, curScope) + curScope = subScope + subPattern + } + + val newPattern = p match { + case InstanceOfPattern(b, ctd) => + InstanceOfPattern(newBinder, ctd) + case WildcardPattern(b) => + WildcardPattern(newBinder) + case CaseClassPattern(b, ccd, sub) => + CaseClassPattern(newBinder, ccd, newSubPatterns) + case TuplePattern(b, sub) => + TuplePattern(newBinder, newSubPatterns) + case UnapplyPattern(b, fd, obj, sub) => + UnapplyPattern(newBinder, fd, obj, newSubPatterns) + case LiteralPattern(_, lit) => + LiteralPattern(newBinder, lit) + } + + (newPattern, curScope) + } + + MatchExpr(rs, cases.map { c => + val (newP, newScope) = trPattern(c.pattern, scope) + MatchCase(newP, c.optGuard map {rec(_, newScope)}, rec(c.rhs, newScope)) + }) + + case v: Variable => + val vd = v.toVal + scope.oldToNew.getOrElse(vd, vd).toVariable + + // This only makes sense if we have Let-Defs at some point + case FunctionInvocation(id, tps, args) => + val newFd = scope.funDefs.getOrElse(id, id) + val newArgs = args.map(rec(_, scope)) + + FunctionInvocation(newFd, tps, newArgs) + + case Operator(es, builder) => + builder(es.map(rec(_, scope))) + + case _ => + sys.error("Expression "+e+" ["+e.getClass+"] is not extractable") + } + +} diff --git a/src/main/scala/inox/transformers/SimplifierWithPC.scala b/src/main/scala/inox/transformers/SimplifierWithPC.scala new file mode 100644 index 0000000000000000000000000000000000000000..8161240c7adb0f740f519c9838fc82c010c7a02f --- /dev/null +++ b/src/main/scala/inox/transformers/SimplifierWithPC.scala @@ -0,0 +1,101 @@ +/* Copyright 2009-2016 EPFL, Lausanne */ + +package inox +package transformers + +/** Uses solvers to perform PC-aware simplifications */ +trait SimplifierWithPC extends TransformerWithPC { + + import trees._ + import symbols.{Path, matchExpr, matchExprCaseConditions} + + implicit protected val s = symbols + + // FIXME: This needs to be changed when SolverAPI's are available + protected def impliedBy(e: Expr, path: Path) : Boolean + protected def contradictedBy(e: Expr, path: Path) : Boolean + protected def valid(e: Expr) : Boolean + protected def sat(e: Expr) : Boolean + + protected override def rec(e: Expr, path: Path) = e match { + case Require(pre, body) if impliedBy(pre, path) => + body + + case IfExpr(cond, thenn, _) if impliedBy(cond, path) => + rec(thenn, path) + + case IfExpr(cond, _, elze ) if contradictedBy(cond, path) => + rec(elze, path) + + case And(e +: _) if contradictedBy(e, path) => + BooleanLiteral(false).copiedFrom(e) + + case And(e +: es) if impliedBy(e, path) => + val remaining = if (es.size > 1) And(es).copiedFrom(e) else es.head + rec(remaining, path) + + case Or(e +: _) if impliedBy(e, path) => + BooleanLiteral(true).copiedFrom(e) + + case Or(e +: es) if contradictedBy(e, path) => + val remaining = if (es.size > 1) Or(es).copiedFrom(e) else es.head + rec(remaining, path) + + case Implies(lhs, rhs) if impliedBy(lhs, path) => + rec(rhs, path) + + case Implies(lhs, rhs) if contradictedBy(lhs, path) => + BooleanLiteral(true).copiedFrom(e) + + case me @ MatchExpr(scrut, cases) => + val rs = rec(scrut, path) + + var stillPossible = true + + val conds = matchExprCaseConditions(me, path) + + val newCases = cases.zip(conds).flatMap { case (cs, cond) => + if (stillPossible && sat(cond.toClause)) { + + if (valid(cond.toClause)) { + stillPossible = false + } + + Seq((cs match { + case SimpleCase(p, rhs) => + SimpleCase(p, rec(rhs, cond)) + case GuardedCase(p, g, rhs) => + val newGuard = rec(g, cond) + if (valid(newGuard)) + SimpleCase(p, rec(rhs,cond)) + else + GuardedCase(p, newGuard, rec(rhs, cond withCond newGuard)) + }).copiedFrom(cs)) + } else { + Seq() + } + } + + newCases match { + case List() => + Error(e.getType, "Unreachable code").copiedFrom(e) + case _ => + matchExpr(rs, newCases).copiedFrom(e) + } + + case a @ Assert(pred, _, body) if impliedBy(pred, path) => + body + + case a @ Assert(pred, msg, body) if contradictedBy(pred, path) => + Error(body.getType, s"Assertion failed: $msg").copiedFrom(a) + + case b if b.getType == BooleanType && impliedBy(b, path) => + BooleanLiteral(true).copiedFrom(b) + + case b if b.getType == BooleanType && contradictedBy(b, path) => + BooleanLiteral(false).copiedFrom(b) + + case _ => + super.rec(e, path) + } +}