diff --git a/src/main/scala/inox/ast/Constructors.scala b/src/main/scala/inox/ast/Constructors.scala index 16363edb3d064ca90b9b1b61010f6155da5bdaa9..0852b113235ba5cb4aed1783ff811ffeba14d4b9 100644 --- a/src/main/scala/inox/ast/Constructors.scala +++ b/src/main/scala/inox/ast/Constructors.scala @@ -205,7 +205,7 @@ trait Constructors { * @see [[Expressions.Lambda Lambda]] * @see [[Expressions.Application Application]] */ - def application(fn: Expr, realArgs: Seq[Expr]) = fn match { + def application(fn: Expr, realArgs: Seq[Expr]): Expr = fn match { case Lambda(formalArgs, body) => assert(realArgs.size == formalArgs.size, "Invoking lambda with incorrect number of arguments") @@ -224,6 +224,9 @@ trait Constructors { case ((vd, bd), body) => let(vd, bd, body) } + case Assume(pred, l: Lambda) => + assume(pred, application(l, realArgs)) + case _ => Application(fn, realArgs) } diff --git a/src/main/scala/inox/ast/SymbolOps.scala b/src/main/scala/inox/ast/SymbolOps.scala index 583bf1cae2b63ef9ca3738b2d299796fe1ab6112..cb072f84ce2d8d6d602168c6f84b0cc25daaffc4 100644 --- a/src/main/scala/inox/ast/SymbolOps.scala +++ b/src/main/scala/inox/ast/SymbolOps.scala @@ -258,7 +258,7 @@ trait SymbolOps { self: TypeOps => forall(args ++ allArgs.flatten, recons(allBodies)) } - + postMap { case Forall(args1, Forall(args2, body)) => Some(forall(args1 ++ args2, body)) @@ -746,32 +746,35 @@ trait SymbolOps { self: TypeOps => }) } - def simplifyAssumptions(expr: Expr): Expr = { - def lift(expr: Expr): Expr = { - val vars = variablesOf(expr) - var assumptions: Seq[Expr] = Seq.empty - - object transformer extends transformers.TransformerWithPC { - val trees: self.trees.type = self.trees - val symbols: self.symbols.type = self.symbols - val initEnv = Path.empty - - override protected def rec(e: Expr, path: Path): Expr = e match { - case Assume(pred, body) if (variablesOf(pred) ++ path.variables) subsetOf vars => - assumptions :+= path implies pred - rec(body, path withCond pred) - case _ => super.rec(e, path) - } - } + def liftAssumptions(expr: Expr): (Seq[Expr], Expr) = { + val vars = variablesOf(expr) + var assumptions: Seq[Expr] = Seq.empty + + object transformer extends transformers.TransformerWithPC { + val trees: self.trees.type = self.trees + val symbols: self.symbols.type = self.symbols + val initEnv = Path.empty - val (vs, es, tps, recons) = deconstructor.deconstruct(expr) - val newEs = es.map(transformer.transform) - assume(andJoin(assumptions.toSeq), recons(vs, newEs, tps)) + override protected def rec(e: Expr, path: Path): Expr = e match { + case Assume(pred, body) if (variablesOf(pred) ++ path.variables) subsetOf vars => + assumptions :+= path implies pred + rec(body, path withCond pred) + case _ => super.rec(e, path) + } } - postMap(e => Some(lift(e)))(expr) + val newExpr = transformer.transform(expr) + (assumptions, newExpr) } + def simplifyAssumptions(expr: Expr): Expr = postMap { + case Assume(pred, body) => + val (predAssumptions, newPred) = liftAssumptions(pred) + val (bodyAssumptions, newBody) = liftAssumptions(body) + Some(assume(andJoin(predAssumptions ++ (newPred +: bodyAssumptions)), newBody)) + case _ => None + } (expr) + def simplifyFormula(e: Expr, simplify: Boolean = true): Expr = { if (simplify) { val simp: Expr => Expr = diff --git a/src/main/scala/inox/solvers/unrolling/TemplateGenerator.scala b/src/main/scala/inox/solvers/unrolling/TemplateGenerator.scala index 1cbd408fb4ef53303e7de860bfed62f908e8a9bf..2a12e422bdf1cd1ebdcbcb13108b7f0fa14f3ccb 100644 --- a/src/main/scala/inox/solvers/unrolling/TemplateGenerator.scala +++ b/src/main/scala/inox/solvers/unrolling/TemplateGenerator.scala @@ -301,7 +301,13 @@ trait TemplateGenerator { self: Templates => val idArgs : Seq[Variable] = lambdaArgs(l) val trArgs : Seq[Encoded] = idArgs.map(id => substMap.getOrElse(id, encodeSymbol(id))) - val (struct, deps) = normalizeStructure(l) + val (assumptions, without) = liftAssumptions(l) + + for (a <- assumptions) { + rec(pathVar, a, Some(true)) + } + + val (struct, deps) = normalizeStructure(without.asInstanceOf[Lambda]) val sortedDeps = exprOps.variablesOf(struct).map(v => v -> deps(v)).toSeq.sortBy(_._1.id.uniqueName) val isNormalForm: Boolean = {