diff --git a/src/main/scala/leon/codegen/runtime/ChooseEntryPoint.scala b/src/main/scala/leon/codegen/runtime/ChooseEntryPoint.scala index 2b31225ff407e6ce2e22dddef14852bf77eb521f..079fee06dcbc9684ee087a6ac181dbb9966a6dd8 100644 --- a/src/main/scala/leon/codegen/runtime/ChooseEntryPoint.scala +++ b/src/main/scala/leon/codegen/runtime/ChooseEntryPoint.scala @@ -13,29 +13,57 @@ import purescala.Constructors._ import solvers.TimeoutSolver import solvers.z3._ -import codegen.CompilationUnit +import java.util.WeakHashMap +import java.lang.ref.WeakReference +import scala.collection.mutable.{HashMap => MutableMap} -import scala.collection.immutable.{Map => ScalaMap} +import codegen.CompilationUnit import synthesis._ object ChooseEntryPoint { - private[this] var map = ScalaMap[Int, (Problem, CompilationUnit)]() - implicit val debugSection = DebugSectionSynthesis + private case class ChooseId(id: Int) { } + + private[this] val context = new WeakHashMap[ChooseId, (WeakReference[CompilationUnit], Problem)]() + private[this] val cache = new WeakHashMap[ChooseId, MutableMap[Seq[AnyRef], java.lang.Object]]() + + private[this] val ids = new WeakHashMap[CompilationUnit, MutableMap[Problem, ChooseId]]() + + private[this] var _next = 0 + private[this] def nextInt(): Int = { + _next += 1 + _next + } + + private def getUniqueId(unit: CompilationUnit, p: Problem): ChooseId = { + if (!ids.containsKey(unit)) { + ids.put(unit, new MutableMap()) + } + + if (ids.get(unit) contains p) { + ids.get(unit)(p) + } else { + val cid = new ChooseId(nextInt()) + ids.get(unit) += p -> cid + cid + } + } + + def register(p: Problem, unit: CompilationUnit): Int = { - val stored = (p, unit) - val hash = stored.## + val cid = getUniqueId(unit, p) - map += hash -> stored + context.put(cid, new WeakReference(unit) -> p) - hash + cid.id } - private[this] var cache = ScalaMap[(Int, Seq[AnyRef]), java.lang.Object]() def invoke(i: Int, inputs: Array[AnyRef]): java.lang.Object = { - val (p, unit) = map(i) + val id = ChooseId(i) + val (ur, p) = context.get(id) + val unit = ur.get val program = unit.program val ctx = unit.ctx @@ -43,8 +71,14 @@ object ChooseEntryPoint { ctx.reporter.debug("Executing choose (codegen)!") val is = inputs.toSeq - if (cache contains (i, is)) { - cache((i, is)) + if (!cache.containsKey(id)) { + cache.put(id, new MutableMap()) + } + + val chCache = cache.get(id) + + if (chCache contains is) { + chCache(is) } else { val tStart = System.currentTimeMillis @@ -73,7 +107,7 @@ object ChooseEntryPoint { ctx.reporter.debug("Finished synthesis with "+leonRes.asString(ctx)) val obj = unit.exprToJVM(leonRes)(new LeonCodeGenRuntimeMonitor(unit.params.maxFunctionInvocations)) - cache += (i, is) -> obj + chCache += is -> obj obj case Some(false) => throw new LeonCodeGenRuntimeException("Constraint is UNSAT")