diff --git a/src/main/scala/inox/solvers/unrolling/UnrollingSolver.scala b/src/main/scala/inox/solvers/unrolling/UnrollingSolver.scala index 6ca90d715cff35ad5878e120bbbcb4c910a89780..6dcf5748568ebe8d5abfabe6ef21b65acb125967 100644 --- a/src/main/scala/inox/solvers/unrolling/UnrollingSolver.scala +++ b/src/main/scala/inox/solvers/unrolling/UnrollingSolver.scala @@ -158,9 +158,10 @@ trait AbstractUnrollingSolver private def extractTotalModel(model: underlying.Model): Map[ValDef, Expr] = { val wrapped = wrapModel(model) - val cache: MutableMap[Encoded, Expr] = MutableMap.empty + // maintain extracted functions to make sure equality is well-defined + var funExtractions: Seq[(Encoded, Lambda)] = Seq.empty - def extractValue(v: Encoded, tpe: Type): Expr = cache.getOrElseUpdate(v, { + def extractValue(v: Encoded, tpe: Type): Expr = { def functionsOf(expr: Expr, selector: Expr): (Seq[(Expr, Expr)], Seq[Expr] => Expr) = { def reconstruct(subs: Seq[(Seq[(Expr, Expr)], Seq[Expr] => Expr)], recons: Seq[Expr] => Expr): (Seq[(Expr, Expr)], Seq[Expr] => Expr) = @@ -199,7 +200,7 @@ trait AbstractUnrollingSolver val tpe = bestRealType(f.getType).asInstanceOf[FunctionType] extractFunction(encoded, tpe) }) - }) + } object FiniteLambda { def apply(params: Seq[Seq[ValDef]], mappings: Seq[(Expr, Expr)], dflt: Expr): Lambda = { @@ -236,7 +237,7 @@ trait AbstractUnrollingSolver } } - def extractFunction(f: Encoded, tpe: FunctionType): Expr = cache.getOrElseUpdate(f, { + def extractFunction(f: Encoded, tpe: FunctionType): Expr = { def extractLambda(f: Encoded, tpe: FunctionType): Option[Lambda] = { val optEqTemplate = templates.getLambdaTemplates(tpe).find { tmpl => wrapped.eval(tmpl.start, BooleanType) == Some(BooleanLiteral(true)) && @@ -295,7 +296,26 @@ trait AbstractUnrollingSolver } } - (FiniteLambda(params, mappings, dflt), false) + val lambda = FiniteLambda(params, mappings, dflt) + // make sure `lambda` is not equal to any other distinct extracted first-class function + val res = (funExtractions.collectFirst { + case (encoded, `lambda`) => + Right(encoded) + case (e, img) if + wrapped.eval(templates.mkEquals(e, f), BooleanType) == Some(BooleanLiteral(true)) => + Left(img) + }) match { + case Some(Right(enc)) => wrapped.eval(enc, tpe).get match { + case Lambda(_, Let(_, IntegerLiteral(n), _)) => uniquateClosure(n, lambda) + case l => scala.sys.error("Unexpected extracted lambda format: " + l) + } + case Some(Left(img)) => img + case None => lambda + } + + funExtractions :+= f -> res + + (res, false) } } } @@ -353,7 +373,7 @@ trait AbstractUnrollingSolver extract(f, tpe, params, allArguments, default)._1 } } - }) + } freeVars.toMap.map { case (v, idT) => v.toVal -> extractValue(idT, v.tpe) } }