diff --git a/src/main/scala/inox/solvers/unrolling/UnrollingSolver.scala b/src/main/scala/inox/solvers/unrolling/UnrollingSolver.scala
index 6ca90d715cff35ad5878e120bbbcb4c910a89780..6dcf5748568ebe8d5abfabe6ef21b65acb125967 100644
--- a/src/main/scala/inox/solvers/unrolling/UnrollingSolver.scala
+++ b/src/main/scala/inox/solvers/unrolling/UnrollingSolver.scala
@@ -158,9 +158,10 @@ trait AbstractUnrollingSolver
   private def extractTotalModel(model: underlying.Model): Map[ValDef, Expr] = {
     val wrapped = wrapModel(model)
 
-    val cache: MutableMap[Encoded, Expr] = MutableMap.empty
+    // maintain extracted functions to make sure equality is well-defined
+    var funExtractions: Seq[(Encoded, Lambda)] = Seq.empty
 
-    def extractValue(v: Encoded, tpe: Type): Expr = cache.getOrElseUpdate(v, {
+    def extractValue(v: Encoded, tpe: Type): Expr = {
       def functionsOf(expr: Expr, selector: Expr): (Seq[(Expr, Expr)], Seq[Expr] => Expr) = {
         def reconstruct(subs: Seq[(Seq[(Expr, Expr)], Seq[Expr] => Expr)],
                         recons: Seq[Expr] => Expr): (Seq[(Expr, Expr)], Seq[Expr] => Expr) =
@@ -199,7 +200,7 @@ trait AbstractUnrollingSolver
         val tpe = bestRealType(f.getType).asInstanceOf[FunctionType]
         extractFunction(encoded, tpe)
       })
-    })
+    }
 
     object FiniteLambda {
       def apply(params: Seq[Seq[ValDef]], mappings: Seq[(Expr, Expr)], dflt: Expr): Lambda = {
@@ -236,7 +237,7 @@ trait AbstractUnrollingSolver
       }
     }
 
-    def extractFunction(f: Encoded, tpe: FunctionType): Expr = cache.getOrElseUpdate(f, {
+    def extractFunction(f: Encoded, tpe: FunctionType): Expr = {
       def extractLambda(f: Encoded, tpe: FunctionType): Option[Lambda] = {
         val optEqTemplate = templates.getLambdaTemplates(tpe).find { tmpl =>
           wrapped.eval(tmpl.start, BooleanType) == Some(BooleanLiteral(true)) &&
@@ -295,7 +296,26 @@ trait AbstractUnrollingSolver
               }
             }
 
-            (FiniteLambda(params, mappings, dflt), false)
+            val lambda = FiniteLambda(params, mappings, dflt)
+            // make sure `lambda` is not equal to any other distinct extracted first-class function
+            val res = (funExtractions.collectFirst {
+              case (encoded, `lambda`) =>
+                Right(encoded)
+              case (e, img) if
+              wrapped.eval(templates.mkEquals(e, f), BooleanType) == Some(BooleanLiteral(true)) =>
+                Left(img)
+            }) match {
+              case Some(Right(enc)) => wrapped.eval(enc, tpe).get match {
+                case Lambda(_, Let(_, IntegerLiteral(n), _)) => uniquateClosure(n, lambda)
+                case l => scala.sys.error("Unexpected extracted lambda format: " + l)
+              }
+              case Some(Left(img)) => img
+              case None => lambda
+            }
+
+            funExtractions :+= f -> res
+
+            (res, false)
           }
         }
       }
@@ -353,7 +373,7 @@ trait AbstractUnrollingSolver
           extract(f, tpe, params, allArguments, default)._1
         }
       }
-    })
+    }
 
     freeVars.toMap.map { case (v, idT) => v.toVal -> extractValue(idT, v.tpe) }
   }