From 5212f8852b9d46dfaa2776cef3fac011568245b7 Mon Sep 17 00:00:00 2001
From: Nicolas Voirol <voirol.nicolas@gmail.com>
Date: Fri, 4 Nov 2016 09:21:21 +0100
Subject: [PATCH] Fixes for lambda equality

---
 src/main/scala/inox/ast/SymbolOps.scala       | 20 ++++++++++++-----
 .../solvers/unrolling/TemplateGenerator.scala | 22 ++++++++++---------
 2 files changed, 27 insertions(+), 15 deletions(-)

diff --git a/src/main/scala/inox/ast/SymbolOps.scala b/src/main/scala/inox/ast/SymbolOps.scala
index 9ebfc07ba..deadd9826 100644
--- a/src/main/scala/inox/ast/SymbolOps.scala
+++ b/src/main/scala/inox/ast/SymbolOps.scala
@@ -130,32 +130,42 @@ trait SymbolOps { self: TypeOps =>
       newId
     }
 
-    def transformId(id: Identifier, tpe: Type): (Identifier, Type) = subst.get(Variable(id, tpe)) match {
-      case Some(Variable(newId, tpe)) => (newId, tpe)
+    def transformId(id: Identifier, tpe: Type): Identifier = subst.get(Variable(id, tpe)) match {
+      case Some(Variable(newId, _)) => newId
       case Some(_) => scala.sys.error("Should never happen!")
       case None => varSubst.get(id) match {
-        case Some(newId) => (newId, tpe)
+        case Some(newId) => newId
         case None =>
           val newId = getId(Variable(id, tpe))
           varSubst += id -> newId
-          (newId, tpe)
+          newId
       }
     }
 
     def rec(vars: Set[Variable], body: Expr): Expr = {
 
       class Normalizer extends SelfTreeTransformer {
-        override def transform(id: Identifier, tpe: Type): (Identifier, Type) = transformId(id, tpe)
+        override def transform(id: Identifier, tpe: Type): (Identifier, Type) = (transformId(id, tpe), tpe)
 
         override def transform(e: Expr): Expr = e match {
+          case Variable(id, tpe) =>
+            Variable(transformId(id, tpe), tpe)
+
+          case Let(vd, e, b) if (!onlySimple || isSimple(e)) && (variablesOf(e) & vars).nonEmpty =>
+            val newId = getId(e)
+            transform(replaceFromSymbols(Map(vd.toVariable -> Variable(newId, vd.tpe)), b))
+
           case expr if (!onlySimple || isSimple(expr)) && (variablesOf(expr) & vars).isEmpty =>
             Variable(getId(expr), expr.getType)
+
           case f: Forall =>
             val newBody = rec(vars ++ f.args.map(_.toVariable), f.body)
             Forall(f.args.map(vd => vd.copy(id = varSubst(vd.id))), newBody)
+
           case l: Lambda =>
             val newBody = rec(vars ++ l.args.map(_.toVariable), l.body)
             Lambda(l.args.map(vd => vd.copy(id = varSubst(vd.id))), newBody)
+
           case _ => super.transform(e)
         }
       }
diff --git a/src/main/scala/inox/solvers/unrolling/TemplateGenerator.scala b/src/main/scala/inox/solvers/unrolling/TemplateGenerator.scala
index 96cdd6408..dc519c719 100644
--- a/src/main/scala/inox/solvers/unrolling/TemplateGenerator.scala
+++ b/src/main/scala/inox/solvers/unrolling/TemplateGenerator.scala
@@ -291,21 +291,21 @@ trait TemplateGenerator { self: Templates =>
 
         val ids: (Variable, Encoded) = lid -> storeLambda(lid)
 
-        val (dependencies, (depConds, depExprs, depTree, depGuarded, depEqs, depLambdas, depQuants)) =
-          sortedDeps.foldLeft[(Seq[Encoded], TemplateClauses)](Seq.empty -> emptyClauses) {
-            case ((dependencies, clsSet), (_, expr)) =>
+        val (depSubst, (depConds, depExprs, depTree, depGuarded, depEqs, depLambdas, depQuants)) =
+          sortedDeps.foldLeft[(Map[Variable, Encoded], TemplateClauses)](localSubst -> emptyClauses) {
+            case ((depSubst, clsSet), (v, expr)) =>
               if (!exprOps.isSimple(expr)) {
-                val (e, cls @ (_, _, _, _, _, lmbds, quants)) = mkExprClauses(pathVar, expr, localSubst)
-                val clauseSubst = localSubst ++ lmbds.map(_.ids) ++ quants.flatMap(_.mapping)
-                (dependencies :+ mkEncoder(clauseSubst)(e), clsSet ++ cls)
+                val (e, cls @ (_, _, _, _, _, lmbds, quants)) = mkExprClauses(pathVar, expr, depSubst)
+                val clauseSubst = depSubst ++ lmbds.map(_.ids) ++ quants.flatMap(_.mapping)
+                (depSubst + (v -> mkEncoder(clauseSubst)(e)), clsSet ++ cls)
               } else {
-                (dependencies :+ mkEncoder(localSubst)(expr), clsSet)
+                (depSubst + (v -> mkEncoder(depSubst)(expr)), clsSet)
               }
           }
 
         val (depClauses, depCalls, depApps, depMatchers, _) = Template.encode(
           pathVar -> encodedCond(pathVar), Seq.empty,
-          depConds, depExprs, depGuarded, depEqs, depLambdas, depQuants, localSubst)
+          depConds, depExprs, depGuarded, depEqs, depLambdas, depQuants, depSubst)
 
         val depClosures: Seq[Encoded] = {
           var cls: Seq[Variable] = Seq.empty
@@ -313,16 +313,18 @@ trait TemplateGenerator { self: Templates =>
             val vars = exprOps.variablesOf(e).toSet
             exprOps.preTraversal { case v: Variable if vars(v) => cls :+= v case _ => } (e)
           }
-          cls.distinct.map(localSubst)
+          cls.distinct.map(depSubst)
         }
 
+        val dependencies = sortedDeps.map(p => depSubst(p._1))
+
         val structure = new LambdaStructure(
           struct, dependencies, pathVar -> encodedCond(pathVar), depClosures,
           depConds, depExprs, depTree, depClauses, depCalls, depApps, depMatchers, depLambdas, depQuants)
 
         val template = LambdaTemplate(ids, pathVar -> encodedCond(pathVar),
           idArgs zip trArgs, idDeps zip trDeps, lambdaConds, lambdaExprs, lambdaTree,
-          lambdaGuarded, lambdaEqs, lambdaTemplates, lambdaQuants, structure, localSubst, l)
+          lambdaGuarded, lambdaEqs, lambdaTemplates, lambdaQuants, structure, depSubst, l)
         registerLambda(template)
         lid
 
-- 
GitLab