diff --git a/src/main/scala/inox/ast/Constructors.scala b/src/main/scala/inox/ast/Constructors.scala index 46aeafd3c4fc4511d984a8d37bd60e7aad6bbd83..d9f55e26e58e924d3927a414ca8c59b2d415e75b 100644 --- a/src/main/scala/inox/ast/Constructors.scala +++ b/src/main/scala/inox/ast/Constructors.scala @@ -228,10 +228,24 @@ trait Constructors { Application(fn, realArgs) } + /** $encodingof simplified `assume(pred, body)` (assumption). + * Transforms + * {{{ assume(assume(pred1, pred2), body) }}} + * and + * {{{ assume(pred1, assume(pred2, body)) }}} + * into + * {{{ assume(pred1 && pred2, body) }}} + * @see [[purescala.Expressions.Assume Assume]] + */ + def assume(pred: Expr, body: Expr): Expr = (pred, body) match { + case (Assume(pred1, pred2), _) => assume(and(pred1, pred2), body) + case (_, Assume(pred2, body)) => assume(and(pred, pred2), body) + case (BooleanLiteral(true), body) => body + case _ => Assume(pred, body) + } + /** $encodingof simplified `... + ...` (plus). * @see [[purescala.Expressions.Plus Plus]] - * @see [[purescala.Expressions.BVPlus BVPlus]] - * @see [[purescala.Expressions.RealPlus RealPlus]] */ def plus(lhs: Expr, rhs: Expr): Expr = (lhs, rhs) match { case (IntegerLiteral(bi), _) if bi == 0 => rhs diff --git a/src/main/scala/inox/ast/SymbolOps.scala b/src/main/scala/inox/ast/SymbolOps.scala index 990bb892805b4c540277577933765479a1af6136..0099afd17ea8ea860ac550df44e0e9592f2f39c7 100644 --- a/src/main/scala/inox/ast/SymbolOps.scala +++ b/src/main/scala/inox/ast/SymbolOps.scala @@ -248,24 +248,6 @@ trait SymbolOps { self: TypeOps => def inlineQuantifications(e: Expr): Expr = postMap { case Forall(args1, Forall(args2, body)) => Some(Forall(args1 ++ args2, body)) - case a @ Assume(pred, body) => - val vars = variablesOf(a) - var assumptions: Seq[Expr] = Seq.empty - object transformer extends transformers.TransformerWithPC { - val trees: self.trees.type = self.trees - val symbols: self.symbols.type = self.symbols - val initEnv = Path.empty - - override protected def rec(e: Expr, path: Path): Expr = e match { - case Assume(pred, body) if (variablesOf(pred) ++ path.variables) subsetOf vars => - assumptions :+= path implies pred - rec(body, path withCond pred) - case _ => super.rec(e, path) - } - } - val newPred = transformer.transform(pred) - val newBody = transformer.transform(body) - Some(Assume(andJoin(newPred +: assumptions), newBody)) case _ => None } (e) @@ -709,9 +691,40 @@ trait SymbolOps { self: TypeOps => }) } + def simplifyAssumptions(expr: Expr): Expr = { + def lift(expr: Expr): Expr = { + val vars = variablesOf(expr) + var assumptions: Seq[Expr] = Seq.empty + + object transformer extends transformers.TransformerWithPC { + val trees: self.trees.type = self.trees + val symbols: self.symbols.type = self.symbols + val initEnv = Path.empty + + override protected def rec(e: Expr, path: Path): Expr = e match { + case Assume(pred, body) if (variablesOf(pred) ++ path.variables) subsetOf vars => + assumptions :+= path implies pred + rec(body, path withCond pred) + case _ => super.rec(e, path) + } + } + + val (vs, es, tps, recons) = deconstructor.deconstruct(expr) + val newEs = es.map(transformer.transform) + assume(andJoin(assumptions.toSeq), recons(vs, newEs, tps)) + } + + postMap(e => Some(lift(e)))(expr) + } + def simplifyFormula(e: Expr, simplify: Boolean = true): Expr = { if (simplify) { - fixpoint((e: Expr) => simplifyHOFunctions(simplifyByConstructors(simplifyQuantifications(e))))(e) + val simp: Expr => Expr = + ((e: Expr) => simplifyHOFunctions(e)) compose + ((e: Expr) => simplifyByConstructors(e)) compose + ((e: Expr) => simplifyAssumptions(e)) compose + ((e: Expr) => simplifyQuantifications(e)) + fixpoint(simp)(e) } else { simplifyHOFunctions(e, simplify = false) } diff --git a/src/main/scala/inox/solvers/unrolling/DatatypeTemplates.scala b/src/main/scala/inox/solvers/unrolling/DatatypeTemplates.scala index 96774c91d6adb63f903cbb31d014ab24b355e225..31b3739d033af6111b88e43b71bdd949e37672e3 100644 --- a/src/main/scala/inox/solvers/unrolling/DatatypeTemplates.scala +++ b/src/main/scala/inox/solvers/unrolling/DatatypeTemplates.scala @@ -601,7 +601,7 @@ trait DatatypeTemplates { self: Templates => def apply(dtpe: Type, containerType: FunctionType): CaptureTemplate = cache.getOrElseUpdate(dtpe -> containerType, { val (ps, idT, exprVars, condVars, condTree, clauses, types, funs) = tmplCache.getOrElseUpdate(dtpe, { - object b extends { val tpe = dtpe } with Builder + object b extends { val tpe = dtpe } with super[FunctionUnrolling].Builder with super[ADTUnrolling].Builder assert(b.calls.isEmpty, "Captured function templates shouldn't have any calls: " + b.calls) (b.pathVar -> b.pathVarT, b.idT, b.exprs, b.conds, b.tree, b.clauses, b.types, b.funs) }) diff --git a/src/main/scala/inox/solvers/unrolling/TemplateGenerator.scala b/src/main/scala/inox/solvers/unrolling/TemplateGenerator.scala index ff73de92107e96392645981f5ecf09d0a5cd9c88..1cbd408fb4ef53303e7de860bfed62f908e8a9bf 100644 --- a/src/main/scala/inox/solvers/unrolling/TemplateGenerator.scala +++ b/src/main/scala/inox/solvers/unrolling/TemplateGenerator.scala @@ -302,7 +302,7 @@ trait TemplateGenerator { self: Templates => val trArgs : Seq[Encoded] = idArgs.map(id => substMap.getOrElse(id, encodeSymbol(id))) val (struct, deps) = normalizeStructure(l) - val sortedDeps = deps.toSeq.sortBy(_._1.id.uniqueName) + val sortedDeps = exprOps.variablesOf(struct).map(v => v -> deps(v)).toSeq.sortBy(_._1.id.uniqueName) val isNormalForm: Boolean = { def extractBody(e: Expr): (Seq[ValDef], Expr) = e match {