From 5212f8852b9d46dfaa2776cef3fac011568245b7 Mon Sep 17 00:00:00 2001 From: Nicolas Voirol <voirol.nicolas@gmail.com> Date: Fri, 4 Nov 2016 09:21:21 +0100 Subject: [PATCH] Fixes for lambda equality --- src/main/scala/inox/ast/SymbolOps.scala | 20 ++++++++++++----- .../solvers/unrolling/TemplateGenerator.scala | 22 ++++++++++--------- 2 files changed, 27 insertions(+), 15 deletions(-) diff --git a/src/main/scala/inox/ast/SymbolOps.scala b/src/main/scala/inox/ast/SymbolOps.scala index 9ebfc07ba..deadd9826 100644 --- a/src/main/scala/inox/ast/SymbolOps.scala +++ b/src/main/scala/inox/ast/SymbolOps.scala @@ -130,32 +130,42 @@ trait SymbolOps { self: TypeOps => newId } - def transformId(id: Identifier, tpe: Type): (Identifier, Type) = subst.get(Variable(id, tpe)) match { - case Some(Variable(newId, tpe)) => (newId, tpe) + def transformId(id: Identifier, tpe: Type): Identifier = subst.get(Variable(id, tpe)) match { + case Some(Variable(newId, _)) => newId case Some(_) => scala.sys.error("Should never happen!") case None => varSubst.get(id) match { - case Some(newId) => (newId, tpe) + case Some(newId) => newId case None => val newId = getId(Variable(id, tpe)) varSubst += id -> newId - (newId, tpe) + newId } } def rec(vars: Set[Variable], body: Expr): Expr = { class Normalizer extends SelfTreeTransformer { - override def transform(id: Identifier, tpe: Type): (Identifier, Type) = transformId(id, tpe) + override def transform(id: Identifier, tpe: Type): (Identifier, Type) = (transformId(id, tpe), tpe) override def transform(e: Expr): Expr = e match { + case Variable(id, tpe) => + Variable(transformId(id, tpe), tpe) + + case Let(vd, e, b) if (!onlySimple || isSimple(e)) && (variablesOf(e) & vars).nonEmpty => + val newId = getId(e) + transform(replaceFromSymbols(Map(vd.toVariable -> Variable(newId, vd.tpe)), b)) + case expr if (!onlySimple || isSimple(expr)) && (variablesOf(expr) & vars).isEmpty => Variable(getId(expr), expr.getType) + case f: Forall => val newBody = rec(vars ++ f.args.map(_.toVariable), f.body) Forall(f.args.map(vd => vd.copy(id = varSubst(vd.id))), newBody) + case l: Lambda => val newBody = rec(vars ++ l.args.map(_.toVariable), l.body) Lambda(l.args.map(vd => vd.copy(id = varSubst(vd.id))), newBody) + case _ => super.transform(e) } } diff --git a/src/main/scala/inox/solvers/unrolling/TemplateGenerator.scala b/src/main/scala/inox/solvers/unrolling/TemplateGenerator.scala index 96cdd6408..dc519c719 100644 --- a/src/main/scala/inox/solvers/unrolling/TemplateGenerator.scala +++ b/src/main/scala/inox/solvers/unrolling/TemplateGenerator.scala @@ -291,21 +291,21 @@ trait TemplateGenerator { self: Templates => val ids: (Variable, Encoded) = lid -> storeLambda(lid) - val (dependencies, (depConds, depExprs, depTree, depGuarded, depEqs, depLambdas, depQuants)) = - sortedDeps.foldLeft[(Seq[Encoded], TemplateClauses)](Seq.empty -> emptyClauses) { - case ((dependencies, clsSet), (_, expr)) => + val (depSubst, (depConds, depExprs, depTree, depGuarded, depEqs, depLambdas, depQuants)) = + sortedDeps.foldLeft[(Map[Variable, Encoded], TemplateClauses)](localSubst -> emptyClauses) { + case ((depSubst, clsSet), (v, expr)) => if (!exprOps.isSimple(expr)) { - val (e, cls @ (_, _, _, _, _, lmbds, quants)) = mkExprClauses(pathVar, expr, localSubst) - val clauseSubst = localSubst ++ lmbds.map(_.ids) ++ quants.flatMap(_.mapping) - (dependencies :+ mkEncoder(clauseSubst)(e), clsSet ++ cls) + val (e, cls @ (_, _, _, _, _, lmbds, quants)) = mkExprClauses(pathVar, expr, depSubst) + val clauseSubst = depSubst ++ lmbds.map(_.ids) ++ quants.flatMap(_.mapping) + (depSubst + (v -> mkEncoder(clauseSubst)(e)), clsSet ++ cls) } else { - (dependencies :+ mkEncoder(localSubst)(expr), clsSet) + (depSubst + (v -> mkEncoder(depSubst)(expr)), clsSet) } } val (depClauses, depCalls, depApps, depMatchers, _) = Template.encode( pathVar -> encodedCond(pathVar), Seq.empty, - depConds, depExprs, depGuarded, depEqs, depLambdas, depQuants, localSubst) + depConds, depExprs, depGuarded, depEqs, depLambdas, depQuants, depSubst) val depClosures: Seq[Encoded] = { var cls: Seq[Variable] = Seq.empty @@ -313,16 +313,18 @@ trait TemplateGenerator { self: Templates => val vars = exprOps.variablesOf(e).toSet exprOps.preTraversal { case v: Variable if vars(v) => cls :+= v case _ => } (e) } - cls.distinct.map(localSubst) + cls.distinct.map(depSubst) } + val dependencies = sortedDeps.map(p => depSubst(p._1)) + val structure = new LambdaStructure( struct, dependencies, pathVar -> encodedCond(pathVar), depClosures, depConds, depExprs, depTree, depClauses, depCalls, depApps, depMatchers, depLambdas, depQuants) val template = LambdaTemplate(ids, pathVar -> encodedCond(pathVar), idArgs zip trArgs, idDeps zip trDeps, lambdaConds, lambdaExprs, lambdaTree, - lambdaGuarded, lambdaEqs, lambdaTemplates, lambdaQuants, structure, localSubst, l) + lambdaGuarded, lambdaEqs, lambdaTemplates, lambdaQuants, structure, depSubst, l) registerLambda(template) lid -- GitLab