Skip to content
Snippets Groups Projects
Commit 015b52bc authored by Nicolas Voirol's avatar Nicolas Voirol
Browse files

Enforce model closure equalities

parent 7734a377
No related branches found
No related tags found
No related merge requests found
......@@ -158,9 +158,10 @@ trait AbstractUnrollingSolver
private def extractTotalModel(model: underlying.Model): Map[ValDef, Expr] = {
val wrapped = wrapModel(model)
val cache: MutableMap[Encoded, Expr] = MutableMap.empty
// maintain extracted functions to make sure equality is well-defined
var funExtractions: Seq[(Encoded, Lambda)] = Seq.empty
def extractValue(v: Encoded, tpe: Type): Expr = cache.getOrElseUpdate(v, {
def extractValue(v: Encoded, tpe: Type): Expr = {
def functionsOf(expr: Expr, selector: Expr): (Seq[(Expr, Expr)], Seq[Expr] => Expr) = {
def reconstruct(subs: Seq[(Seq[(Expr, Expr)], Seq[Expr] => Expr)],
recons: Seq[Expr] => Expr): (Seq[(Expr, Expr)], Seq[Expr] => Expr) =
......@@ -199,7 +200,7 @@ trait AbstractUnrollingSolver
val tpe = bestRealType(f.getType).asInstanceOf[FunctionType]
extractFunction(encoded, tpe)
})
})
}
object FiniteLambda {
def apply(params: Seq[Seq[ValDef]], mappings: Seq[(Expr, Expr)], dflt: Expr): Lambda = {
......@@ -236,7 +237,7 @@ trait AbstractUnrollingSolver
}
}
def extractFunction(f: Encoded, tpe: FunctionType): Expr = cache.getOrElseUpdate(f, {
def extractFunction(f: Encoded, tpe: FunctionType): Expr = {
def extractLambda(f: Encoded, tpe: FunctionType): Option[Lambda] = {
val optEqTemplate = templates.getLambdaTemplates(tpe).find { tmpl =>
wrapped.eval(tmpl.start, BooleanType) == Some(BooleanLiteral(true)) &&
......@@ -295,7 +296,26 @@ trait AbstractUnrollingSolver
}
}
(FiniteLambda(params, mappings, dflt), false)
val lambda = FiniteLambda(params, mappings, dflt)
// make sure `lambda` is not equal to any other distinct extracted first-class function
val res = (funExtractions.collectFirst {
case (encoded, `lambda`) =>
Right(encoded)
case (e, img) if
wrapped.eval(templates.mkEquals(e, f), BooleanType) == Some(BooleanLiteral(true)) =>
Left(img)
}) match {
case Some(Right(enc)) => wrapped.eval(enc, tpe).get match {
case Lambda(_, Let(_, IntegerLiteral(n), _)) => uniquateClosure(n, lambda)
case l => scala.sys.error("Unexpected extracted lambda format: " + l)
}
case Some(Left(img)) => img
case None => lambda
}
funExtractions :+= f -> res
(res, false)
}
}
}
......@@ -353,7 +373,7 @@ trait AbstractUnrollingSolver
extract(f, tpe, params, allArguments, default)._1
}
}
})
}
freeVars.toMap.map { case (v, idT) => v.toVal -> extractValue(idT, v.tpe) }
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment