diff --git a/src/main/scala/leon/codegen/CodeGeneration.scala b/src/main/scala/leon/codegen/CodeGeneration.scala index 86405da600c7ec006c41da3c4c33433354775b66..d99a9a1a951edeb2260874d004e911de9628ab7d 100644 --- a/src/main/scala/leon/codegen/CodeGeneration.scala +++ b/src/main/scala/leon/codegen/CodeGeneration.scala @@ -82,7 +82,8 @@ trait CodeGeneration { private[codegen] val HashingClass = "leon/codegen/runtime/LeonCodeGenRuntimeHashing" private[codegen] val ChooseEntryPointClass = "leon/codegen/runtime/ChooseEntryPoint" private[codegen] val GenericValuesClass = "leon/codegen/runtime/GenericValues" - private[codegen] val MonitorClass = "leon/codegen/runtime/LeonCodeGenRuntimeMonitor" + private[codegen] val MonitorClass = "leon/codegen/runtime/Monitor" + private[codegen] val NoMonitorClass = "leon/codegen/runtime/NoMonitor" private[codegen] val HenkinClass = "leon/codegen/runtime/LeonCodeGenRuntimeHenkinMonitor" private[codegen] val StrOpsClass = "leon/codegen/runtime/StrOps" @@ -165,13 +166,7 @@ trait CodeGeneration { val cf = classes(owner) val (_,mn,_) = leonFunDefToJVMInfo(funDef).get - val paramsTypes = funDef.params.map(a => typeToJVM(a.getType)) - - val realParams = if (requireMonitor) { - ("L" + MonitorClass + ";") +: paramsTypes - } else { - paramsTypes - } + val realParams = ("L" + MonitorClass + ";") +: funDef.params.map(a => typeToJVM(a.getType)) val m = cf.addMethod( typeToJVM(funDef.returnType), @@ -194,7 +189,7 @@ trait CodeGeneration { // An offset we introduce to the parameters: // 1 if this is a method, so we need "this" in position 0 of the stack // 1 if we are monitoring - val idParams = (if (requireMonitor) Seq(monitorID) else Seq.empty) ++ funDef.paramIds + val idParams = monitorID +: funDef.paramIds val newMapping = idParams.zipWithIndex.toMap.mapValues(_ + (if (!isStatic) 1 else 0)) val body = if (params.checkContracts) { @@ -207,7 +202,7 @@ trait CodeGeneration { if (params.recordInvocations) { load(monitorID, ch)(locals) - ch << InvokeVirtual(MonitorClass, "onInvoke", "()V") + ch << InvokeVirtual(MonitorClass, "onInvocation", "()V") } mkExpr(body, ch)(locals) @@ -233,9 +228,7 @@ trait CodeGeneration { val closureIDs = purescala.ExprOps.variablesOf(nl).toSeq.sortBy(_.uniqueName) val closuresWithoutMonitor = closureIDs.map(id => id -> typeToJVM(id.getType)) - val closures = if (requireMonitor) { - (monitorID -> s"L$MonitorClass;") +: closuresWithoutMonitor - } else closuresWithoutMonitor + val closures = (monitorID -> s"L$MonitorClass;") +: closuresWithoutMonitor val afName = lambdaToClass.getOrElse(nl, { val afName = "Leon$CodeGen$Lambda$" + lambdaCounter.nextGlobal @@ -839,9 +832,7 @@ trait CodeGeneration { throw CompilationException("Unknown class : " + cct.id) } ch << New(ccName) << DUP - if (requireMonitor) { - load(monitorID, ch) - } + load(monitorID, ch) for((a, vd) <- as zip cct.classDef.fields) { vd.getType match { @@ -968,11 +959,6 @@ trait CodeGeneration { throw CompilationException("Unknown method : " + tfd.id) } - if (requireMonitor) { - load(monitorID, ch) - ch << InvokeVirtual(MonitorClass, "onInvoke", "()V") - } - // Get static field ch << GetStatic(className, fieldName, typeToJVM(tfd.fd.returnType)) @@ -1024,10 +1010,9 @@ trait CodeGeneration { ch << POP << POP // list, it, cons, cons, elem, list - if (requireMonitor) { - load(monitorID, ch) - ch << DUP_X2 << POP - } + load(monitorID, ch) + ch << DUP_X2 << POP + ch << InvokeSpecial(consName, constructorName, ccApplySig) // list, it, newList ch << DUP_X2 << POP << SWAP << POP @@ -1039,15 +1024,36 @@ trait CodeGeneration { ch << POP // list + case FunctionInvocation(tfd, as) if abstractFunDefs(tfd.fd.id) => + val id = registerAbstractFD(tfd.fd) + + load(monitorID, ch) + + ch << Ldc(id) + + ch << Ldc(as.size) + ch << NewArray(ObjectClass) + + for ((e, i) <- as.zipWithIndex) { + ch << DUP + ch << Ldc(i) + mkExpr(e, ch) + mkBox(e.getType, ch) + ch << AASTORE + } + + ch << InvokeVirtual(MonitorClass, "onAbstractInvocation", s"(I[L$ObjectClass;)L$ObjectClass;") + + mkUnbox(tfd.returnType, ch) + + // Static lazy fields/ functions case fi @ FunctionInvocation(tfd, as) => val (cn, mn, ms) = leonFunDefToJVMInfo(tfd.fd).getOrElse { throw CompilationException("Unknown method : " + tfd.id) } - if (requireMonitor) { - load(monitorID, ch) - } + load(monitorID, ch) for((a, vd) <- as zip tfd.fd.params) { vd.getType match { @@ -1072,10 +1078,6 @@ trait CodeGeneration { throw CompilationException("Unknown method : " + tfd.id) } - if (requireMonitor) { - load(monitorID, ch) - ch << InvokeVirtual(MonitorClass, "onInvoke", "()V") - } // Load receiver mkExpr(rec,ch) @@ -1097,11 +1099,9 @@ trait CodeGeneration { } // Receiver of the method call - mkExpr(rec,ch) + mkExpr(rec, ch) - if (requireMonitor) { - load(monitorID, ch) - } + load(monitorID, ch) for((a, vd) <- as zip tfd.fd.params) { vd.getType match { @@ -1410,7 +1410,10 @@ trait CodeGeneration { case choose: Choose => val prob = synthesis.Problem.fromSpec(choose.pred) - val id = runtime.ChooseEntryPoint.register(prob, this) + val id = registerProblem(prob) + + load(monitorID, ch) + ch << Ldc(id) ch << Ldc(prob.as.size) @@ -1424,7 +1427,7 @@ trait CodeGeneration { ch << AASTORE } - ch << InvokeStatic(ChooseEntryPointClass, "invoke", s"(I[L$ObjectClass;)L$ObjectClass;") + ch << InvokeVirtual(MonitorClass, "onChooseInvocation", s"(I[L$ObjectClass;)L$ObjectClass;") mkUnbox(choose.getType, ch) @@ -1731,9 +1734,7 @@ trait CodeGeneration { // accessor method locally { - val parameters = if (requireMonitor) { - Seq(monitorID -> s"L$MonitorClass;") - } else Seq() + val parameters = Seq(monitorID -> s"L$MonitorClass;") val paramMapping = parameters.map(_._1).zipWithIndex.toMap.mapValues(_ + (if (isStatic) 0 else 1)) val newLocs = NoLocals.withVars(paramMapping) @@ -1749,11 +1750,6 @@ trait CodeGeneration { val body = lzy.body.getOrElse(throw CompilationException("Lazy field without body?")) val initLabel = ch.getFreshLabel("isInitialized") - if (requireMonitor) { - load(monitorID, ch)(newLocs) - ch << InvokeVirtual(MonitorClass, "onInvoke", "()V") - } - if (isStatic) { ch << GetStatic(cName, underlyingName, underlyingType) } else { @@ -1890,9 +1886,7 @@ trait CodeGeneration { // definition of the constructor locally { - val constrParams = if (requireMonitor) { - Seq(monitorID -> s"L$MonitorClass;") - } else Seq() + val constrParams = Seq(monitorID -> s"L$MonitorClass;") val newLocs = NoLocals.withVars { constrParams.map(_._1).zipWithIndex.toMap.mapValues(_ + 1) @@ -1909,8 +1903,8 @@ trait CodeGeneration { case Some(parent) => val pName = defToJVMName(parent.classDef) // Load monitor object - if (requireMonitor) cch << ALoad(1) - val constrSig = if (requireMonitor) "(L" + MonitorClass + ";)V" else "()V" + cch << ALoad(1) + val constrSig = "(L" + MonitorClass + ";)V" cch << InvokeSpecial(pName, constructorName, constrSig) case None => @@ -1985,9 +1979,7 @@ trait CodeGeneration { // Case class parameters val fieldsTypes = ccd.fields.map { vd => (vd.id, typeToJVM(vd.getType)) } - val constructorArgs = if (requireMonitor) { - (monitorID -> s"L$MonitorClass;") +: fieldsTypes - } else fieldsTypes + val constructorArgs = (monitorID -> s"L$MonitorClass;") +: fieldsTypes val newLocs = NoLocals.withFields(constructorArgs.map { case (id, jvmt) => (id, (cName, id.name, jvmt)) @@ -2013,62 +2005,54 @@ trait CodeGeneration { } // definition of the constructor - if(!params.doInstrument && !requireMonitor && ccd.fields.isEmpty && !ccd.methods.exists(_.canBeField)) { - cf.addDefaultConstructor - } else { - for((id, jvmt) <- constructorArgs) { - val fh = cf.addField(jvmt, id.name) - fh.setFlags(( - FIELD_ACC_PUBLIC | - FIELD_ACC_FINAL - ).asInstanceOf[U2]) - } - - if (params.doInstrument) { - val fh = cf.addField("I", instrumentedField) - fh.setFlags(FIELD_ACC_PUBLIC) - } + for((id, jvmt) <- constructorArgs) { + val fh = cf.addField(jvmt, id.name) + fh.setFlags(( + FIELD_ACC_PUBLIC | + FIELD_ACC_FINAL + ).asInstanceOf[U2]) + } - val cch = cf.addConstructor(constructorArgs.map(_._2) : _*).codeHandler + if (params.doInstrument) { + val fh = cf.addField("I", instrumentedField) + fh.setFlags(FIELD_ACC_PUBLIC) + } - if (params.doInstrument) { - cch << ALoad(0) - cch << Ldc(0) - cch << PutField(cName, instrumentedField, "I") - } + val cch = cf.addConstructor(constructorArgs.map(_._2) : _*).codeHandler - var c = 1 - for((id, jvmt) <- constructorArgs) { - cch << ALoad(0) - cch << (jvmt match { - case "I" | "Z" => ILoad(c) - case _ => ALoad(c) - }) - cch << PutField(cName, id.name, jvmt) - c += 1 - } + if (params.doInstrument) { + cch << ALoad(0) + cch << Ldc(0) + cch << PutField(cName, instrumentedField, "I") + } - // Call parent constructor AFTER initializing case class parameters - if (ccd.parent.isDefined) { - cch << ALoad(0) - if (requireMonitor) { - cch << ALoad(1) - cch << InvokeSpecial(pName.get, constructorName, s"(L$MonitorClass;)V") - } else { - cch << InvokeSpecial(pName.get, constructorName, "()V") - } - } else { - // Call constructor of java.lang.Object - cch << ALoad(0) - cch << InvokeSpecial(ObjectClass, constructorName, "()V") - } + var c = 1 + for((id, jvmt) <- constructorArgs) { + cch << ALoad(0) + cch << (jvmt match { + case "I" | "Z" => ILoad(c) + case _ => ALoad(c) + }) + cch << PutField(cName, id.name, jvmt) + c += 1 + } - // Now initialize fields - for (lzy <- lazyFields) { initLazyField(cch, cName, lzy, isStatic = false)(newLocs) } - for (field <- strictFields) { initStrictField(cch, cName , field, isStatic = false)(newLocs) } - cch << RETURN - cch.freeze + // Call parent constructor AFTER initializing case class parameters + if (ccd.parent.isDefined) { + cch << ALoad(0) + cch << ALoad(1) + cch << InvokeSpecial(pName.get, constructorName, s"(L$MonitorClass;)V") + } else { + // Call constructor of java.lang.Object + cch << ALoad(0) + cch << InvokeSpecial(ObjectClass, constructorName, "()V") } + + // Now initialize fields + for (lzy <- lazyFields) { initLazyField(cch, cName, lzy, isStatic = false)(newLocs) } + for (field <- strictFields) { initStrictField(cch, cName , field, isStatic = false)(newLocs) } + cch << RETURN + cch.freeze } locally { diff --git a/src/main/scala/leon/codegen/CompilationUnit.scala b/src/main/scala/leon/codegen/CompilationUnit.scala index 26f976c095489f22cbc97dd9ee810d1d71ef9f99..18a0b607589f334fa021ab31ab4163201b5ba888 100644 --- a/src/main/scala/leon/codegen/CompilationUnit.scala +++ b/src/main/scala/leon/codegen/CompilationUnit.scala @@ -10,9 +10,8 @@ import purescala.ExprOps._ import purescala.Types._ import purescala.Extractors._ import purescala.Constructors._ -import codegen.runtime.LeonCodeGenRuntimeMonitor -import codegen.runtime.LeonCodeGenRuntimeHenkinMonitor import utils.UniqueCounter +import runtime.{Monitor, StdMonitor} import cafebabe._ import cafebabe.AbstractByteCodes._ @@ -24,22 +23,43 @@ import scala.collection.JavaConverters._ import java.lang.reflect.Constructor +import synthesis.Problem class CompilationUnit(val ctx: LeonContext, val program: Program, val params: CodeGenParams = CodeGenParams.default) extends CodeGeneration { + protected[codegen] val requireQuantification = program.definedFunctions.exists { fd => exists { case _: Forall => true case _ => false } (fd.fullBody) } - protected[codegen] val requireMonitor = params.requireMonitor || requireQuantification - val loader = new CafebabeClassLoader(classOf[CompilationUnit].getClassLoader) var classes = Map[Definition, ClassFile]() + var defToModuleOrClass = Map[Definition, Definition]() + val abstractFunDefs = program.definedFunctions.filter(_.body.isEmpty).map(_.id).toSet + + val runtimeCounter = new UniqueCounter[Unit] + + var runtimeProblemMap = Map[Int, Problem]() + + def registerProblem(p: Problem): Int = { + val id = runtimeCounter.nextGlobal + runtimeProblemMap += id -> p + id + } + + var runtimeAbstractMap = Map[Int, FunDef]() + + def registerAbstractFD(fd: FunDef): Int = { + val id = runtimeCounter.nextGlobal + runtimeAbstractMap += id -> fd + id + } + def defineClass(df: Definition) { val cName = defToJVMName(df) @@ -65,8 +85,7 @@ class CompilationUnit(val ctx: LeonContext, def leonClassToJVMInfo(cd: ClassDef): Option[(String, String)] = { classes.get(cd) match { case Some(cf) => - val monitorType = if (requireMonitor) "L"+MonitorClass+";" else "" - val sig = "(" + monitorType + cd.fields.map(f => typeToJVM(f.getType)).mkString("") + ")V" + val sig = "(L"+MonitorClass+";" + cd.fields.map(f => typeToJVM(f.getType)).mkString("") + ")V" Some((cf.className, sig)) case _ => None } @@ -84,9 +103,7 @@ class CompilationUnit(val ctx: LeonContext, */ def leonFunDefToJVMInfo(fd: FunDef): Option[(String, String, String)] = { funDefInfo.get(fd).orElse { - val monitorType = if (requireMonitor) "L"+MonitorClass+";" else "" - - val sig = "(" + monitorType + fd.params.map(a => typeToJVM(a.getType)).mkString("") + ")" + typeToJVM(fd.returnType) + val sig = "(L"+MonitorClass+";" + fd.params.map(a => typeToJVM(a.getType)).mkString("") + ")" + typeToJVM(fd.returnType) defToModuleOrClass.get(fd).flatMap(m => classes.get(m)) match { case Some(cf) => @@ -127,31 +144,36 @@ class CompilationUnit(val ctx: LeonContext, conss.last } - def modelToJVM(model: solvers.Model, maxInvocations: Int, check: Boolean): LeonCodeGenRuntimeMonitor = model match { - case hModel: solvers.HenkinModel => - val lhm = new LeonCodeGenRuntimeHenkinMonitor(maxInvocations, check) - for ((lambda, domain) <- hModel.doms.lambdas) { - val (afName, _, _) = compileLambda(lambda) - val lc = loader.loadClass(afName) - - for (args <- domain) { - // note here that it doesn't matter that `lhm` doesn't yet have its domains - // filled since all values in `args` should be grounded - val inputJvm = tupleConstructor.newInstance(args.map(valueToJVM(_)(lhm)).toArray).asInstanceOf[leon.codegen.runtime.Tuple] - lhm.add(lc, inputJvm) - } - } + def getMonitor(model: solvers.Model, maxInvocations: Int, check: Boolean): Monitor = { + val bodies = model.toSeq.filter { case (id, v) => abstractFunDefs(id) }.toMap - for ((tpe, domain) <- hModel.doms.tpes; args <- domain) { - val tpeId = typeId(tpe) - // same remark as above about valueToJVM(_)(lhm) - val inputJvm = tupleConstructor.newInstance(args.map(valueToJVM(_)(lhm)).toArray).asInstanceOf[leon.codegen.runtime.Tuple] - lhm.add(tpeId, inputJvm) - } - lhm - case _ => - new LeonCodeGenRuntimeMonitor(maxInvocations) + new StdMonitor(this, maxInvocations, bodies) } + // model match { + // case hModel: solvers.HenkinModel => + // val lhm = new LeonCodeGenRuntimeHenkinMonitor(maxInvocations, check) + // for ((lambda, domain) <- hModel.doms.lambdas) { + // val (afName, _, _) = compileLambda(lambda) + // val lc = loader.loadClass(afName) + + // for (args <- domain) { + // // note here that it doesn't matter that `lhm` doesn't yet have its domains + // // filled since all values in `args` should be grounded + // val inputJvm = tupleConstructor.newInstance(args.map(valueToJVM(_)(lhm)).toArray).asInstanceOf[leon.codegen.runtime.Tuple] + // lhm.add(lc, inputJvm) + // } + // } + + // for ((tpe, domain) <- hModel.doms.tpes; args <- domain) { + // val tpeId = typeId(tpe) + // // same remark as above about valueToJVM(_)(lhm) + // val inputJvm = tupleConstructor.newInstance(args.map(valueToJVM(_)(lhm)).toArray).asInstanceOf[leon.codegen.runtime.Tuple] + // lhm.add(tpeId, inputJvm) + // } + // lhm + // case _ => + // new LeonCodeGenRuntimeMonitor(maxInvocations) + //} /** Translates Leon values (not generic expressions) to JVM compatible objects. * @@ -159,7 +181,7 @@ class CompilationUnit(val ctx: LeonContext, * This means it is safe to return AnyRef (as opposed to primitive types), because * reflection needs this anyway. */ - def valueToJVM(e: Expr)(implicit monitor: LeonCodeGenRuntimeMonitor): AnyRef = e match { + def valueToJVM(e: Expr)(implicit monitor: Monitor): AnyRef = e match { case IntLiteral(v) => new java.lang.Integer(v) @@ -190,8 +212,8 @@ class CompilationUnit(val ctx: LeonContext, case CaseClass(cct, args) => caseClassConstructor(cct.classDef) match { case Some(cons) => - val realArgs = if (requireMonitor) monitor +: args.map(valueToJVM) else args.map(valueToJVM) - cons.newInstance(realArgs.toArray : _*).asInstanceOf[AnyRef] + val jvmArgs = monitor +: args.map(valueToJVM) + cons.newInstance(jvmArgs.toArray : _*).asInstanceOf[AnyRef] case None => ctx.reporter.fatalError("Case class constructor not found?!?") } @@ -271,10 +293,6 @@ class CompilationUnit(val ctx: LeonContext, case _ => throw CompilationException(s"Unexpected expression $e in valueToJVM") - - // Just slightly overkill... - //case _ => - // compileExpression(e, Seq()).evalToJVM(Seq(),monitor) } /** Translates JVM objects back to Leon values of the appropriate type */ @@ -390,11 +408,7 @@ class CompilationUnit(val ctx: LeonContext, val argsTypes = args.map(a => typeToJVM(a.getType)) - val realArgs = if (requireMonitor) { - ("L" + MonitorClass + ";") +: argsTypes - } else { - argsTypes - } + val realArgs = ("L" + MonitorClass + ";") +: argsTypes val m = cf.addMethod( typeToJVM(e.getType), @@ -410,11 +424,7 @@ class CompilationUnit(val ctx: LeonContext, val ch = m.codeHandler - val newMapping = if (requireMonitor) { - args.zipWithIndex.toMap.mapValues(_ + 1) + (monitorID -> 0) - } else { - args.zipWithIndex.toMap - } + val newMapping = Map(monitorID -> 0) ++ args.zipWithIndex.toMap.mapValues(_ + 1) mkExpr(e, ch)(NoLocals.withVars(newMapping)) @@ -480,10 +490,10 @@ class CompilationUnit(val ctx: LeonContext, * method invocations here :( */ val locals = NoLocals.withVar(monitorID -> ch.getFreshVar) - ch << New(MonitorClass) << DUP - ch << Ldc(Int.MaxValue) // Allow "infinite" method calls - ch << InvokeSpecial(MonitorClass, cafebabe.Defaults.constructorName, "(I)V") + ch << New(NoMonitorClass) << DUP + ch << InvokeSpecial(NoMonitorClass, cafebabe.Defaults.constructorName, "()V") ch << AStore(locals.varToLocal(monitorID).get) // position 0 + for (lzy <- lazyFields) { initLazyField(ch, cName, lzy, isStatic = true)(locals) } for (field <- strictFields) { initStrictField(ch, cName , field, isStatic = true)(locals) } ch << RETURN diff --git a/src/main/scala/leon/codegen/CompiledExpression.scala b/src/main/scala/leon/codegen/CompiledExpression.scala index f9fca911564c61ad984fa97c3f2ac0da7fc021b4..8ce2cb269546ef69f6edb1eefecd0d53123fd15b 100644 --- a/src/main/scala/leon/codegen/CompiledExpression.scala +++ b/src/main/scala/leon/codegen/CompiledExpression.scala @@ -8,7 +8,7 @@ import purescala.Expressions._ import cafebabe._ -import runtime.{LeonCodeGenRuntimeMonitor => LM} +import runtime.Monitor import java.lang.reflect.InvocationTargetException @@ -21,29 +21,21 @@ class CompiledExpression(unit: CompilationUnit, cf: ClassFile, expression: Expr, private val params = unit.params - def argsToJVM(args: Seq[Expr], monitor: LM): Seq[AnyRef] = { + def argsToJVM(args: Seq[Expr], monitor: Monitor): Seq[AnyRef] = { args.map(unit.valueToJVM(_)(monitor)) } - def evalToJVM(args: Seq[AnyRef], monitor: LM): AnyRef = { + def evalToJVM(args: Seq[AnyRef], monitor: Monitor): AnyRef = { assert(args.size == argsDecl.size) - val realArgs = if (unit.requireMonitor) { - monitor +: args - } else { - args - } + val allArgs = monitor +: args - if (realArgs.isEmpty) { - meth.invoke(null) - } else { - meth.invoke(null, realArgs.toArray : _*) - } + meth.invoke(null, allArgs.toArray : _*) } // This may throw an exception. We unwrap it if needed. // We also need to reattach a type in some cases (sets, maps). - def evalFromJVM(args: Seq[AnyRef], monitor: LM) : Expr = { + def evalFromJVM(args: Seq[AnyRef], monitor: Monitor) : Expr = { try { unit.jvmToValue(evalToJVM(args, monitor), exprType) } catch { @@ -53,7 +45,8 @@ class CompiledExpression(unit: CompilationUnit, cf: ClassFile, expression: Expr, def eval(model: solvers.Model, check: Boolean = false) : Expr = { try { - val monitor = unit.modelToJVM(model, params.maxFunctionInvocations, check) + val monitor = unit.getMonitor(model, params.maxFunctionInvocations, check) + evalFromJVM(argsToJVM(argsDecl.map(model), monitor), monitor) } catch { case ite : InvocationTargetException => throw ite.getCause diff --git a/src/main/scala/leon/codegen/runtime/ChooseEntryPoint.scala b/src/main/scala/leon/codegen/runtime/ChooseEntryPoint.scala deleted file mode 100644 index 84968a169370ff3c415bf9689f48c86577eed44e..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/codegen/runtime/ChooseEntryPoint.scala +++ /dev/null @@ -1,117 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -package leon -package codegen.runtime - -import utils._ -import purescala.Expressions._ -import purescala.ExprOps.valuateWithModel -import purescala.Constructors._ -import solvers.SolverFactory - -import java.util.WeakHashMap -import java.lang.ref.WeakReference -import scala.collection.mutable.{HashMap => MutableMap} -import scala.concurrent.duration._ - -import codegen.CompilationUnit - -import synthesis._ - -object ChooseEntryPoint { - implicit val debugSection = DebugSectionSynthesis - - private case class ChooseId(id: Int) { } - - private[this] val context = new WeakHashMap[ChooseId, (WeakReference[CompilationUnit], Problem)]() - private[this] val cache = new WeakHashMap[ChooseId, MutableMap[Seq[AnyRef], java.lang.Object]]() - - private[this] val ids = new WeakHashMap[CompilationUnit, MutableMap[Problem, ChooseId]]() - - private val intCounter = new UniqueCounter[Unit] - intCounter.nextGlobal // Start with 1 - - private def getUniqueId(unit: CompilationUnit, p: Problem): ChooseId = synchronized { - if (!ids.containsKey(unit)) { - ids.put(unit, new MutableMap()) - } - - if (ids.get(unit) contains p) { - ids.get(unit)(p) - } else { - val cid = new ChooseId(intCounter.nextGlobal) - ids.get(unit) += p -> cid - cid - } - } - - def register(p: Problem, unit: CompilationUnit): Int = { - val cid = getUniqueId(unit, p) - - context.put(cid, new WeakReference(unit) -> p) - - cid.id - } - - def invoke(i: Int, inputs: Array[AnyRef]): java.lang.Object = { - val id = ChooseId(i) - val (ur, p) = context.get(id) - val unit = ur.get - - val program = unit.program - val ctx = unit.ctx - - ctx.reporter.debug("Executing choose (codegen)!") - val is = inputs.toSeq - - if (!cache.containsKey(id)) { - cache.put(id, new MutableMap()) - } - - val chCache = cache.get(id) - - if (chCache contains is) { - chCache(is) - } else { - val tStart = System.currentTimeMillis - - val solverf = SolverFactory.default(ctx, program).withTimeout(10.second) - val solver = solverf.getNewSolver() - - val inputsMap = (p.as zip inputs).map { - case (id, v) => - Equals(Variable(id), unit.jvmToValue(v, id.getType)) - } - - solver.assertCnstr(andJoin(Seq(p.pc, p.phi) ++ inputsMap)) - - try { - solver.check match { - case Some(true) => - val model = solver.getModel - - val valModel = valuateWithModel(model) _ - - val res = p.xs.map(valModel) - val leonRes = tupleWrap(res) - - val total = System.currentTimeMillis-tStart - - ctx.reporter.debug("Synthesis took "+total+"ms") - ctx.reporter.debug("Finished synthesis with "+leonRes.asString(ctx)) - - val obj = unit.valueToJVM(leonRes)(new LeonCodeGenRuntimeMonitor(unit.params.maxFunctionInvocations)) - chCache += is -> obj - obj - case Some(false) => - throw new LeonCodeGenRuntimeException("Constraint is UNSAT") - case _ => - throw new LeonCodeGenRuntimeException("Timeout exceeded") - } - } finally { - solver.free() - solverf.shutdown() - } - } - } -} diff --git a/src/main/scala/leon/codegen/runtime/Monitor.scala b/src/main/scala/leon/codegen/runtime/Monitor.scala new file mode 100644 index 0000000000000000000000000000000000000000..a4fcbe8158d2b727a865969f36f19194fdfef56b --- /dev/null +++ b/src/main/scala/leon/codegen/runtime/Monitor.scala @@ -0,0 +1,131 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package codegen.runtime + +import utils._ +import purescala.Expressions._ +import purescala.Constructors._ +import purescala.Common._ +import purescala.ExprOps.valuateWithModel + +import codegen.CompilationUnit + +import scala.collection.immutable.{Map => ScalaMap} +import scala.collection.mutable.{HashMap => MutableMap} +import scala.concurrent.duration._ + +import solvers.SolverFactory + + +import synthesis._ + +abstract class Monitor { + def onInvocation(): Unit + + def onAbstractInvocation(id: Int, args: Array[AnyRef]): AnyRef + + def onChooseInvocation(id: Int, args: Array[AnyRef]): AnyRef +} + +class NoMonitor extends Monitor { + def onInvocation(): Unit = {} + + def onAbstractInvocation(id: Int, args: Array[AnyRef]): AnyRef = { + throw new LeonCodeGenEvaluationException("No monitor available."); + } + + def onChooseInvocation(id: Int, args: Array[AnyRef]): AnyRef = { + throw new LeonCodeGenEvaluationException("No monitor available."); + } +} + +class StdMonitor(unit: CompilationUnit, invocationsMax: Int, bodies: ScalaMap[Identifier, Expr]) extends Monitor { + + private[this] var invocations = 0 + + def onInvocation(): Unit = { + if(invocationsMax >= 0) { + if (invocations < invocationsMax) { + invocations += 1; + } else { + throw new LeonCodeGenEvaluationException("Maximum number of invocations reached ("+invocationsMax+")."); + } + } + } + + def onAbstractInvocation(id: Int, args: Array[AnyRef]): AnyRef = { + val fd = unit.runtimeAbstractMap(id) + + bodies.get(fd.id) match { + case Some(expr) => + throw new LeonCodeGenRuntimeException("Found body!") + + case None => + throw new LeonCodeGenRuntimeException("Did not find body!") + } + } + + + private[this] val cache = new MutableMap[(Int, Seq[AnyRef]), AnyRef]() + + def onChooseInvocation(id: Int, inputs: Array[AnyRef]): AnyRef = { + + implicit val debugSection = DebugSectionSynthesis + + val p = unit.runtimeProblemMap(id) + + val program = unit.program + val ctx = unit.ctx + + ctx.reporter.debug("Executing choose (codegen)!") + val is = inputs.toSeq + + if (cache contains ((id, is))) { + cache((id, is)) + } else { + val tStart = System.currentTimeMillis + + val solverf = SolverFactory.default(ctx, program).withTimeout(10.second) + val solver = solverf.getNewSolver() + + val inputsMap = (p.as zip inputs).map { + case (id, v) => + Equals(Variable(id), unit.jvmToValue(v, id.getType)) + } + + solver.assertCnstr(andJoin(Seq(p.pc, p.phi) ++ inputsMap)) + + try { + solver.check match { + case Some(true) => + val model = solver.getModel + + val valModel = valuateWithModel(model) _ + + val res = p.xs.map(valModel) + val leonRes = tupleWrap(res) + + val total = System.currentTimeMillis-tStart + + ctx.reporter.debug("Synthesis took "+total+"ms") + ctx.reporter.debug("Finished synthesis with "+leonRes.asString(ctx)) + + val obj = unit.valueToJVM(leonRes)(this) + cache += (id, is) -> obj + obj + case Some(false) => + throw new LeonCodeGenRuntimeException("Constraint is UNSAT") + case _ => + throw new LeonCodeGenRuntimeException("Timeout exceeded") + } + } finally { + solver.free() + solverf.shutdown() + } + } + } + + +} + diff --git a/src/main/scala/leon/datagen/VanuatooDataGen.scala b/src/main/scala/leon/datagen/VanuatooDataGen.scala index b82dc6a0f256fec3c3cf733addad76f899456f4a..af315aacade7e06eeb346dcbb94f56f53f9d5f87 100644 --- a/src/main/scala/leon/datagen/VanuatooDataGen.scala +++ b/src/main/scala/leon/datagen/VanuatooDataGen.scala @@ -13,7 +13,7 @@ import purescala.Constructors._ import codegen.CompilationUnit import codegen.CodeGenParams -import codegen.runtime.LeonCodeGenRuntimeMonitor +import codegen.runtime.StdMonitor import vanuatoo.{Pattern => VPattern, _} import evaluators._ @@ -262,7 +262,8 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { Some((args : Expr) => { try { - val monitor = new LeonCodeGenRuntimeMonitor(unit.params.maxFunctionInvocations) + val monitor = new StdMonitor(unit, unit.params.maxFunctionInvocations, Map()) + val jvmArgs = ce.argsToJVM(Seq(args), monitor) val result = ce.evalFromJVM(jvmArgs, monitor) diff --git a/src/main/scala/leon/evaluators/CodeGenEvaluator.scala b/src/main/scala/leon/evaluators/CodeGenEvaluator.scala index 533ba695ca27f478a03ac7b6cf53d46885e11309..e6bb722aedde0b6faecdcc2bf11651c0a7353859 100644 --- a/src/main/scala/leon/evaluators/CodeGenEvaluator.scala +++ b/src/main/scala/leon/evaluators/CodeGenEvaluator.scala @@ -38,13 +38,23 @@ class CodeGenEvaluator(ctx: LeonContext, val unit : CompilationUnit) extends Eva } } - def check(expression: Expr, model: solvers.Model) : CheckResult = { - compileExpr(expression, model.toSeq.map(_._1)).map { ce => + + def check(expression: Expr, fullModel: solvers.Model) : CheckResult = { + val (_, assign) = fullModel.toSeq.partition { + case (id, v) => unit.abstractFunDefs(id) + } + + compileExpr(expression, assign.map(_._1)).map { ce => ctx.timers.evaluators.codegen.runtime.start() + try { - val res = ce.eval(model, check = true) - if (res == BooleanLiteral(true)) EvaluationResults.CheckSuccess - else EvaluationResults.CheckValidityFailure + val res = ce.eval(fullModel, check = true) + + if (res == BooleanLiteral(true)) { + EvaluationResults.CheckSuccess + } else { + EvaluationResults.CheckValidityFailure + } } catch { case e : ArithmeticException => EvaluationResults.CheckRuntimeFailure(e.getMessage) diff --git a/src/main/scala/leon/evaluators/DualEvaluator.scala b/src/main/scala/leon/evaluators/DualEvaluator.scala index 4c405c8b6f216ee9b839101d8bff5574035b05f4..b61bcb5a038dc47482beb926f5a6bce43e2fccee 100644 --- a/src/main/scala/leon/evaluators/DualEvaluator.scala +++ b/src/main/scala/leon/evaluators/DualEvaluator.scala @@ -9,6 +9,7 @@ import purescala.Definitions._ import purescala.Types._ import codegen._ +import codegen.runtime.{StdMonitor, Monitor} class DualEvaluator(ctx: LeonContext, prog: Program, params: CodeGenParams) extends RecursiveEvaluator(ctx, prog, params.maxFunctionInvocations) @@ -19,10 +20,11 @@ class DualEvaluator(ctx: LeonContext, prog: Program, params: CodeGenParams) def initRC(mappings: Map[Identifier, Expr]): RC = DualRecContext(mappings) implicit val debugSection = utils.DebugSectionEvaluation - var monitor = new runtime.LeonCodeGenRuntimeMonitor(params.maxFunctionInvocations) val unit = new CompilationUnit(ctx, prog, params) + var monitor: Monitor = new StdMonitor(unit, params.maxFunctionInvocations, Map()) + val isCompiled = prog.definedFunctions.toSet case class DualRecContext(mappings: Map[Identifier, Expr], needJVMRef: Boolean = false) extends RecContext[DualRecContext] { @@ -126,7 +128,7 @@ class DualEvaluator(ctx: LeonContext, prog: Program, params: CodeGenParams) override def eval(ex: Expr, model: solvers.Model) = { - monitor = new runtime.LeonCodeGenRuntimeMonitor(params.maxFunctionInvocations) + monitor = unit.getMonitor(model, params.maxFunctionInvocations, false) super.eval(ex, model) }