diff --git a/src/main/scala/leon/solvers/templates/LambdaManager.scala b/src/main/scala/leon/solvers/templates/LambdaManager.scala
index cfecfacda32f2ff7724b7e4cb27eff28887ea701..7249313f16c1f171d1ef15888616ebff1108c468 100644
--- a/src/main/scala/leon/solvers/templates/LambdaManager.scala
+++ b/src/main/scala/leon/solvers/templates/LambdaManager.scala
@@ -38,19 +38,11 @@ class LambdaManager[T](encoder: TemplateEncoder[T]) {
     freeLambdasStack = map :: freeLambdasStack.tail
   }
 
-  private type StructuralMap = Map[Lambda, List[(T, LambdaTemplate[T])]]
-  private var structuralLambdasStack : List[StructuralMap] = List(Map.empty.withDefaultValue(List.empty))
-  private def structuralLambdas : StructuralMap = structuralLambdasStack.head
-  private def structuralLambdas_=(map: StructuralMap) : Unit = {
-    structuralLambdasStack = map :: structuralLambdasStack.tail
-  }
-
   def push(): Unit = {
     byIDStack = byID :: byIDStack
     byTypeStack = byType :: byTypeStack
     applicationsStack = applications :: applicationsStack
     freeLambdasStack = freeLambdas :: freeLambdasStack
-    structuralLambdasStack = structuralLambdas :: structuralLambdasStack
   }
 
   def pop(lvl: Int): Unit = {
@@ -58,7 +50,6 @@ class LambdaManager[T](encoder: TemplateEncoder[T]) {
     byTypeStack = byTypeStack.drop(lvl)
     applicationsStack = applicationsStack.drop(lvl)
     freeLambdasStack = freeLambdasStack.drop(lvl)
-    structuralLambdasStack = structuralLambdasStack.drop(lvl)
   }
 
   def registerFree(lambdas: Seq[(TypeTree, T)]): Unit = {
@@ -79,10 +70,8 @@ class LambdaManager[T](encoder: TemplateEncoder[T]) {
     }
 
     for (lambda @ (idT, template) <- lambdas) {
-      // get all lambda references...
-      val lambdaRefs = freeLambdas(template.tpe) ++ byType(template.tpe).map(_._1)
-      // ... and make sure the new lambda isn't equal to one of them!
-      clauses ++= lambdaRefs.map(pIdT => encoder.mkNot(encoder.mkEquals(pIdT, idT)))
+      // make sure the new lambda isn't equal to any free lambda var
+      clauses ++= freeLambdas(template.tpe).map(pIdT => encoder.mkNot(encoder.mkEquals(pIdT, idT)))
 
       byID += idT -> template
       byType += template.tpe -> (byType(template.tpe) + (idT -> template))
@@ -117,22 +106,15 @@ class LambdaManager[T](encoder: TemplateEncoder[T]) {
     (clauses, callBlockers, appBlockers)
   }
 
-  def equalityClauses(template: LambdaTemplate[T], idT: T, substMap: Map[T,T]): Seq[T] = {
-    val key : Lambda = template.key
-    val t : LambdaTemplate[T] = template.substitute(substMap)
-
-    val newClauses = structuralLambdas(key).map { case (thatIdT, that) =>
+  def equalityClauses(idT: T, template: LambdaTemplate[T]): Seq[T] = {
+    byType(template.tpe).map { case (thatIdT, that) =>
       val equals = encoder.mkEquals(idT, thatIdT)
-      if (t.dependencies.isEmpty) {
-        equals
-      } else {
-        encoder.mkImplies(t.contextEquality(that), equals)
+      template.contextEquality(that) match {
+        case None => encoder.mkNot(equals)
+        case Some(Seq()) => equals
+        case Some(seq) => encoder.mkImplies(encoder.mkAnd(seq : _*), equals)
       }
-    }
-
-    structuralLambdas += key -> (structuralLambdas(key) :+ (idT -> t))
-
-    newClauses
+    }.toSeq
   }
 
 }
diff --git a/src/main/scala/leon/solvers/templates/Templates.scala b/src/main/scala/leon/solvers/templates/Templates.scala
index fb73c599d2bf805f0f60b796ed8b0e38f92b15a8..c7e79263f807141565644ac864cafa59ad70a8b7 100644
--- a/src/main/scala/leon/solvers/templates/Templates.scala
+++ b/src/main/scala/leon/solvers/templates/Templates.scala
@@ -45,13 +45,7 @@ trait Template[T] { self =>
         subst
     }
 
-    val (lambdaSubstMap, lambdaClauses) = lambdas.foldLeft((Map.empty[T,T], Seq.empty[T])) {
-      case ((subst, clauses), (idT, lambda)) =>
-        val newIdT = encoder.encodeId(lambda.id)
-        val eqClauses = lambdaManager.equalityClauses(lambda, newIdT, baseSubstMap)
-
-        (subst + (idT -> newIdT), clauses ++ eqClauses)
-    }
+    val lambdaSubstMap = lambdas.map { case (idT, lambda) => idT -> encoder.encodeId(lambda.id) }
 
     val substMap : Map[T,T] = baseSubstMap ++ lambdaSubstMap + (start -> aVar)
     val substituter : T => T = encoder.substitute(substMap)
@@ -65,8 +59,12 @@ trait Template[T] { self =>
       substituter(b) -> fas.map(fa => fa.copy(caller = substituter(fa.caller), args = fa.args.map(substituter)))
     }
 
-    val newLambdas = lambdas.map { case (idT, lambda) =>
-      substituter(idT) -> lambda.substitute(substMap)
+    val (newLambdas, lambdaClauses) = lambdas.foldLeft((Map.empty[T,LambdaTemplate[T]], Seq.empty[T])) {
+      case ((newLambdas, clauses), (idT, lambda)) =>
+        val newIdT = substituter(idT)
+        val newTemplate = lambda.substitute(substMap)
+        val eqClauses = lambdaManager.equalityClauses(newIdT, newTemplate)
+        (newLambdas + (newIdT -> newTemplate), clauses ++ eqClauses)
     }
 
     val (appClauses, appBlockers, appApps) = lambdaManager.instantiate(newApplications, newLambdas)
@@ -229,7 +227,7 @@ object LambdaTemplate {
 
   private var typedIds : Map[TypeTree, List[Identifier]] = Map.empty.withDefaultValue(List.empty)
 
-  private def templateKey[T](lambda: LambdaTemplate[T]): Lambda = {
+  private def structuralKey[T](lambda: Lambda, dependencies: Map[Identifier, T]): (Lambda, Map[Identifier,T]) = {
 
     def closureIds(expr: Expr): Seq[Identifier] = {
       val vars = variablesOf(expr)
@@ -243,7 +241,7 @@ object LambdaTemplate {
       allVars.filter(vars(_)).distinct
     }
 
-    val grouped : Map[TypeTree, Seq[Identifier]] = closureIds(lambda.lambda).groupBy(_.getType)
+    val grouped : Map[TypeTree, Seq[Identifier]] = closureIds(lambda).groupBy(_.getType)
     val subst : Map[Identifier, Identifier] = grouped.foldLeft(Map.empty[Identifier,Identifier]) { case (subst, (tpe, ids)) =>
       val currentVars = typedIds(tpe)
 
@@ -259,7 +257,10 @@ object LambdaTemplate {
       subst ++ (ids zip typedVars)
     }
 
-    replaceFromIDs(subst.mapValues(_.toVariable), lambda.lambda).asInstanceOf[Lambda]
+    val structuralLambda = replaceFromIDs(subst.mapValues(_.toVariable), lambda).asInstanceOf[Lambda]
+    val newDeps = dependencies.map { case (id, idT) => subst(id) -> idT }
+
+    structuralLambda -> newDeps
   }
 
   def apply[T](
@@ -287,6 +288,8 @@ object LambdaTemplate {
       "Template for lambda " + ids._1 + ": " + lambda + " is :\n" + templateString()
     }
 
+    val (key, keyDeps) = structuralKey(lambda, dependencies)
+
     new LambdaTemplate[T](
       ids._1,
       encoder,
@@ -299,8 +302,8 @@ object LambdaTemplate {
       blockers,
       applications,
       lambdas,
-      dependencies,
-      lambda,
+      keyDeps,
+      key,
       lambdaString
     )
   }
@@ -319,7 +322,7 @@ class LambdaTemplate[T] private (
   val applications: Map[T, Set[App[T]]],
   val lambdas: Map[T, LambdaTemplate[T]],
   private[templates] val dependencies: Map[Identifier, T],
-  private val lambda: Lambda,
+  private val structuralKey: Lambda,
   stringRepr: () => String) extends Template[T] {
 
   val tpe = id.getType
@@ -356,7 +359,7 @@ class LambdaTemplate[T] private (
       newApplications,
       newLambdas,
       newDependencies,
-      lambda,
+      structuralKey,
       stringRepr
     )
   }
@@ -364,32 +367,33 @@ class LambdaTemplate[T] private (
   private lazy val str : String = stringRepr()
   override def toString : String = str
 
-  def contextEquality(that: LambdaTemplate[T]) : T = {
-    assert(key == that.key, "Can't generate equality clause for lambdas that don't share structure")
-    assert(dependencies.nonEmpty, "No closures implies obvious equality")
-
-    def rec(e1: Expr, e2: Expr): Seq[T] = (e1,e2) match {
-      case (Variable(id1), Variable(id2)) =>
-        if (dependencies.isDefinedAt(id1)) {
-          Seq(encoder.mkEquals(dependencies(id1), that.dependencies(id2)))
-        } else {
-          Seq.empty
-        }
-
-      case (NAryOperator(es1, _), NAryOperator(es2, _)) =>
-        (es1 zip es2).flatMap(p => rec(p._1, p._2))
-
-      case (BinaryOperator(e11, e12, _), BinaryOperator(e21, e22, _)) =>
-        rec(e11, e21) ++ rec(e12, e22)
-
-      case (UnaryOperator(ue1, _), UnaryOperator(ue2, _)) =>
-        rec(ue1, ue2)
+  def contextEquality(that: LambdaTemplate[T]) : Option[Seq[T]] = {
+    if (structuralKey != that.structuralKey) {
+      None // can't be equal
+    } else if (dependencies.isEmpty) {
+      Some(Seq.empty) // must be equal
+    } else {
+      def rec(e1: Expr, e2: Expr): Seq[T] = (e1,e2) match {
+        case (Variable(id1), Variable(id2)) =>
+          if (dependencies.isDefinedAt(id1)) {
+            Seq(encoder.mkEquals(dependencies(id1), that.dependencies(id2)))
+          } else {
+            Seq.empty
+          }
+
+        case (NAryOperator(es1, _), NAryOperator(es2, _)) =>
+          (es1 zip es2).flatMap(p => rec(p._1, p._2))
+
+        case (BinaryOperator(e11, e12, _), BinaryOperator(e21, e22, _)) =>
+          rec(e11, e21) ++ rec(e12, e22)
+
+        case (UnaryOperator(ue1, _), UnaryOperator(ue2, _)) =>
+          rec(ue1, ue2)
+
+        case _ => Seq.empty
+      }
 
-      case _ => Seq.empty
+      Some(rec(structuralKey, that.structuralKey))
     }
-
-    encoder.mkAnd(rec(lambda, that.lambda) : _*)
   }
-
-  def key : Lambda = LambdaTemplate.templateKey(this)
 }