diff --git a/src/main/scala/inox/ast/Constructors.scala b/src/main/scala/inox/ast/Constructors.scala index d9f55e26e58e924d3927a414ca8c59b2d415e75b..293c70394f11dfa14e5f71fe0a9b53c540821b44 100644 --- a/src/main/scala/inox/ast/Constructors.scala +++ b/src/main/scala/inox/ast/Constructors.scala @@ -121,10 +121,10 @@ trait Constructors { } var stop = false - val simpler = for(e <- flat if !stop && e != BooleanLiteral(true)) yield { - if(e == BooleanLiteral(false)) stop = true + val simpler = (for (e <- flat if !stop && e != BooleanLiteral(true)) yield { + if (e == BooleanLiteral(false)) stop = true e - } + }).distinct simpler match { case Seq() => BooleanLiteral(true) diff --git a/src/main/scala/inox/ast/ExprOps.scala b/src/main/scala/inox/ast/ExprOps.scala index a6c202ba0c46f0b997bcb3230f8c17e1480a1871..65262b2f9fd0902ca2fd0fea81cd239e6a21a696 100644 --- a/src/main/scala/inox/ast/ExprOps.scala +++ b/src/main/scala/inox/ast/ExprOps.scala @@ -66,6 +66,20 @@ trait ExprOps extends GenTreeOps { }(expr) } + /** Freshens all local variables */ + def freshenLocals(expr: Expr): Expr = { + def rec(expr: Expr, bindings: Map[Variable, Variable]): Expr = expr match { + case v: Variable => bindings(v) + case _ => + val (vs, es, tps, recons) = deconstructor.deconstruct(expr) + val newVs = vs.map(_.freshen) + val newBindings = bindings ++ (vs zip newVs) + recons(newVs, es map (rec(_, newBindings)), tps) + } + + rec(expr, variablesOf(expr).map(v => v -> v).toMap) + } + /** Returns true if the expression contains a function call */ def containsFunctionCalls(expr: Expr): Boolean = { exists{ diff --git a/src/main/scala/inox/ast/Expressions.scala b/src/main/scala/inox/ast/Expressions.scala index 55c8aad2a5e486675e9eb7ef0e9f8ba8d76a4f6d..5f1fd53facb7d87d2e9f66fdd1356b91458963d5 100644 --- a/src/main/scala/inox/ast/Expressions.scala +++ b/src/main/scala/inox/ast/Expressions.scala @@ -130,7 +130,7 @@ trait Expressions { self: Trees => def inlined(implicit s: Symbols): Expr = { val tfd = this.tfd - tfd.withParamSubst(args, tfd.fullBody) + exprOps.freshenLocals(tfd.withParamSubst(args, tfd.fullBody)) } } diff --git a/src/main/scala/inox/ast/SymbolOps.scala b/src/main/scala/inox/ast/SymbolOps.scala index 0099afd17ea8ea860ac550df44e0e9592f2f39c7..3ceabcb3f7b54eddfa22d6f7c9668bd2603000ee 100644 --- a/src/main/scala/inox/ast/SymbolOps.scala +++ b/src/main/scala/inox/ast/SymbolOps.scala @@ -246,10 +246,65 @@ trait SymbolOps { self: TypeOps => fixpoint(inline)(e) } - def inlineQuantifications(e: Expr): Expr = postMap { - case Forall(args1, Forall(args2, body)) => Some(Forall(args1 ++ args2, body)) - case _ => None - } (e) + def inlineQuantifications(e: Expr): Expr = { + def liftForalls(args: Seq[ValDef], es: Seq[Expr], recons: Seq[Expr] => Expr): Forall = { + val (allArgs, allBodies) = es.map { + case f: Forall => + val Forall(args, body) = freshenLocals(f) + (args, body) + case e => + (Seq[ValDef](), e) + }.unzip + + Forall(args ++ allArgs.flatten, recons(allBodies)) + } + + postMap { + case Forall(args1, Forall(args2, body)) => + Some(Forall(args1 ++ args2, body)) + + case Forall(args, And(es)) => + Some(liftForalls(args, es, andJoin)) + + case Forall(args, Or(es)) => + Some(liftForalls(args, es, orJoin)) + + case Forall(args, Implies(e1, e2)) => + Some(liftForalls(args, Seq(e1, e2), es => implies(es(0), es(1)))) + + case And(es) => Some(andJoin(SeqUtils.groupWhile(es)(_.isInstanceOf[Forall]).map { + case Seq(e) => e + case foralls => + val pairs = foralls.collect { case Forall(args, body) => (args, body) } + val (allArgs, allBodies) = pairs.foldLeft((Seq[ValDef](), Seq[Expr]())) { + case ((allArgs, bodies), (args, body)) => + val available = allArgs.groupBy(_.tpe).mapValues(_.sortBy(_.id.uniqueName)) + val (_, map) = args.foldLeft((available, Map[ValDef, ValDef]())) { + case ((available, map), vd) => available.get(vd.tpe) match { + case Some(x +: xs) => + val newAvailable = if (xs.isEmpty) { + available - vd.tpe + } else { + available + (vd.tpe -> xs) + } + (newAvailable, map + (vd -> x)) + case _ => + (available, map + (vd -> vd)) + } + } + + val newBody = replaceFromSymbols(map.mapValues(_.toVariable), body) + val newArgs = allArgs ++ map.map(_._2).filterNot(allArgs contains _) + val newBodies = bodies :+ newBody + (newArgs, newBodies) + } + + Forall(allArgs, andJoin(allBodies)) + })) + + case _ => None + } (e) + } /* Weaker variant of disjunctive normal form */ def normalizeClauses(e: Expr): Expr = e match {