diff --git a/src/main/scala/leon/codegen/CodeGeneration.scala b/src/main/scala/leon/codegen/CodeGeneration.scala index 1534508922630c1f68bc67c968db96c42a5e5d95..dcea1804a8fcfe99214510334247ac01c8551aa4 100644 --- a/src/main/scala/leon/codegen/CodeGeneration.scala +++ b/src/main/scala/leon/codegen/CodeGeneration.scala @@ -7,6 +7,7 @@ import purescala.Common._ import purescala.Definitions._ import purescala.Trees._ import purescala.TypeTrees._ +import purescala.TypeTreeOps.instantiateType import utils._ import cafebabe._ @@ -19,19 +20,32 @@ import cafebabe.Flags._ trait CodeGeneration { self: CompilationUnit => - case class Locals(vars: Map[Identifier, Int]) { + /** A class providing information about the status of parameters in the function that is being currently compiled. + * vars is a mapping from local variables/ parameters to the offset of the respective JVM local register + * isStatic signifies if the current method is static (a function, in Leon terms) + */ + case class Locals(vars: Map[Identifier, Int], private val isStatic : Boolean ) { + /** Fetches the offset of a local variable/ parameter from its identifier */ def varToLocal(v: Identifier): Option[Int] = vars.get(v) + /** Adds some extra variables to the mapping */ def withVars(newVars: Map[Identifier, Int]) = { - Locals(vars ++ newVars) + Locals(vars ++ newVars, isStatic) } + /** Adds an extra variable to the mapping */ def withVar(nv: (Identifier, Int)) = { - Locals(vars + nv) + Locals(vars + nv, isStatic) } + + /** The index of the monitor object in this function */ + def monitorIndex = if (isStatic) 0 else 1 + } + + object NoLocals { + /** Make a $Locals object without any local variables */ + def apply(isStatic : Boolean) = new Locals(Map(), isStatic) } - - object NoLocals extends Locals(Map()) private[codegen] val BoxedIntClass = "java/lang/Integer" private[codegen] val BoxedBoolClass = "java/lang/Boolean" @@ -50,6 +64,10 @@ trait CodeGeneration { def idToSafeJVMName(id: Identifier) = id.uniqueName.replaceAll("\\.", "\\$") def defToJVMName(d : Definition) : String = "Leon$CodeGen$" + idToSafeJVMName(d.id) + /** Retrieve the name of the underlying lazy field from a lazy field accessor method */ + private[codegen] def underlyingField(lazyAccessor : String) = lazyAccessor + "$underlying" + + /** Return the respective JVM type from a Leon type */ def typeToJVM(tpe : TypeTree) : String = tpe match { case Int32Type => "I" @@ -78,15 +96,57 @@ trait CodeGeneration { case _ => throw CompilationException("Unsupported type : " + tpe) } - // Assumes the CodeHandler has never received any bytecode. - // Generates method body, and freezes the handler at the end. - def compileFunDef(funDef : FunDef, ch : CodeHandler) { - val newMapping = if (params.requireMonitor) { - funDef.params.map(_.id).zipWithIndex.toMap.mapValues(_ + 1) - } else { - funDef.params.map(_.id).zipWithIndex.toMap - } + /** Return the respective boxed JVM type from a Leon type */ + def typeToJVMBoxed(tpe : TypeTree) : String = tpe match { + case Int32Type => s"L$BoxedIntClass;" + case BooleanType | UnitType => s"L$BoxedBoolClass;" + case other => typeToJVM(other) + } + + /** + * Compiles a function/method definition. + * @param funDef The function definition to be compiled + * @param owner The module/class that contains $funDef + */ + def compileFunDef(funDef : FunDef, owner : Definition) { + + val isStatic = owner.isInstanceOf[ModuleDef] + + val cf = classes(owner) + val (_,mn,_) = leonFunDefToJVMInfo(funDef).get + + val paramsTypes = funDef.params.map(a => typeToJVM(a.tpe)) + + val realParams = if (params.requireMonitor) { + ("L" + MonitorClass + ";") +: paramsTypes + } else { + paramsTypes + } + val m = cf.addMethod( + typeToJVM(funDef.returnType), + mn, + realParams : _* + ) + m.setFlags(( + if (isStatic) + METHOD_ACC_PUBLIC | + METHOD_ACC_FINAL | + METHOD_ACC_STATIC + else + METHOD_ACC_PUBLIC | + METHOD_ACC_FINAL + ).asInstanceOf[U2]) + + val ch = m.codeHandler + + // An offset we introduce to the parameters: + // 1 if this is a method, so we need "this" in position 0 of the stack + // 1 if we are monitoring // FIXME + val paramsOffset = Seq(!isStatic, params.requireMonitor).count(x => x) + val newMapping = + funDef.params.map(_.id).zipWithIndex.toMap.mapValues(_ + paramsOffset) + val body = funDef.body.getOrElse(throw CompilationException("Can't compile a FunDef without body: "+funDef.id.name)) val bodyWithPre = if(funDef.hasPrecondition && params.checkContracts) { @@ -105,10 +165,11 @@ trait CodeGeneration { val exprToCompile = purescala.TreeOps.matchToIfThenElse(bodyWithPost) if (params.recordInvocations) { - ch << ALoad(0) << InvokeVirtual(MonitorClass, "onInvoke", "()V") + // index of monitor object will be before the first Scala parameter + ch << ALoad(paramsOffset-1) << InvokeVirtual(MonitorClass, "onInvoke", "()V") } - mkExpr(exprToCompile, ch)(Locals(newMapping)) + mkExpr(exprToCompile, ch)(Locals(newMapping, isStatic)) funDef.returnType match { case Int32Type | BooleanType | UnitType => @@ -184,6 +245,8 @@ trait CodeGeneration { throw CompilationException("Unknown class : " + cct.id) } ch << New(ccName) << DUP + if (params.requireMonitor) + ch << ALoad(locals.monitorIndex) for((a, vd) <- as zip cct.classDef.fields) { vd.tpe match { case TypeParameter(_) => @@ -306,13 +369,35 @@ trait CodeGeneration { mkExpr(e, ch) ch << Label(al) + // Strict static fields + case FunctionInvocation(tfd, as) if tfd.fd.canBeStrictField => + val (className, fieldName, _) = leonFunDefToJVMInfo(tfd.fd).getOrElse { + throw CompilationException("Unknown method : " + tfd.id) + } + + if (params.requireMonitor) { + // index of monitor object will be before the first Scala parameter + ch << ALoad(locals.monitorIndex) << InvokeVirtual(MonitorClass, "onInvoke", "()V") + } + + // Get static field + ch << GetStatic(className, fieldName, typeToJVM(tfd.fd.returnType)) + + // unbox field + (tfd.fd.returnType, tfd.returnType) match { + case (TypeParameter(_), tpe) => + mkUnbox(tpe, ch) + case _ => + } + + // Static lazy fields/ functions case FunctionInvocation(tfd, as) => val (cn, mn, ms) = leonFunDefToJVMInfo(tfd.fd).getOrElse { throw CompilationException("Unknown method : " + tfd.id) } - - if (params.requireMonitor) { - ch << ALoad(0) + + if (params.requireMonitor) { + ch << ALoad(locals.monitorIndex) } for((a, vd) <- as zip tfd.fd.params) { @@ -331,7 +416,63 @@ trait CodeGeneration { mkUnbox(tpe, ch) case _ => } - + + // Strict fields are handled as fields + case MethodInvocation(rec, _, tfd, _) if tfd.fd.canBeStrictField => + val (className, fieldName, _) = leonFunDefToJVMInfo(tfd.fd).getOrElse { + throw CompilationException("Unknown method : " + tfd.id) + } + + if (params.requireMonitor) { + // index of monitor object will be before the first Scala parameter + ch << ALoad(locals.monitorIndex) << InvokeVirtual(MonitorClass, "onInvoke", "()V") + } + // Load receiver + mkExpr(rec,ch) + + // Get field + ch << GetField(className, fieldName, typeToJVM(tfd.fd.returnType)) + + // unbox field + (tfd.fd.returnType, tfd.returnType) match { + case (TypeParameter(_), tpe) => + mkUnbox(tpe, ch) + case _ => + } + + // This is for lazy fields and real methods. + // To access a lazy field, we call its accessor function. + case MethodInvocation(rec, cd, tfd, as) => + val (className, methodName, sig) = leonFunDefToJVMInfo(tfd.fd).getOrElse { + throw CompilationException("Unknown method : " + tfd.id) + } + + // Receiver of the method call + mkExpr(rec,ch) + + if (params.requireMonitor) { + ch << ALoad(locals.monitorIndex) + } + + for((a, vd) <- as zip tfd.fd.params) { + vd.tpe match { + case TypeParameter(_) => + mkBoxedExpr(a, ch) + case _ => + mkExpr(a, ch) + } + } + + // No dynamic dispatching/overriding in Leon, + // so no need to take care of own vs. "super" methods + ch << InvokeVirtual(className, methodName, sig) + + (tfd.fd.returnType, tfd.returnType) match { + case (TypeParameter(_), tpe) => + mkUnbox(tpe, ch) + case _ => + } + // Arithmetic case Plus(l, r) => mkExpr(l, ch) @@ -442,7 +583,10 @@ trait CodeGeneration { ch << InvokeStatic(ChooseEntryPointClass, "invoke", "(I[Ljava/lang/Object;)Ljava/lang/Object;") mkUnbox(choose.getType, ch) - + + case This(ct) => + ch << ALoad(0) // FIXME what if doInstrument etc + case b if b.getType == BooleanType && canDelegateToMkBranch => val fl = ch.getFreshLabel("boolfalse") val al = ch.getFreshLabel("boolafter") @@ -608,7 +752,162 @@ trait CodeGeneration { } } - def compileAbstractClassDef(acd : AbstractClassDef) { + + /** + * Compiles a lazy field $lzy, owned by the module/ class $owner. + * + * To define a lazy field, we have to add an accessor method and an underlying field. + * The accessor method has the name of the original (Scala) lazy field and can be public. + * The underlying field has a different name, is private, and is of a boxed type + * to support null value (to signify uninitialized). + * + * @param lzy The lazy field to be compiled + * @param owner The module/class containing $lzy + */ + def compileLazyField(lzy : FunDef, owner : Definition) { + ctx.reporter.internalAssertion(lzy.canBeLazyField, s"Trying to compile non-lazy ${lzy.id.name} as a lazy field") + + val (_, accessorName, _ ) = leonFunDefToJVMInfo(lzy).get + val cf = classes(owner) + val cName = defToJVMName(owner) + + val isStatic = owner.isInstanceOf[ModuleDef] + + // Name of the underlying field + val underlyingName = underlyingField(accessorName) + // Underlying field is of boxed type + val underlyingType = typeToJVMBoxed(lzy.returnType) + + // Underlying field. It is of a boxed type + val fh = cf.addField(underlyingType,underlyingName) + fh.setFlags( if (isStatic) {( + FIELD_ACC_STATIC | + FIELD_ACC_PRIVATE + ).asInstanceOf[U2] } else { + FIELD_ACC_PRIVATE + }) // FIXME private etc? + + // accessor method + locally { + val parameters = if (params.requireMonitor) { + Seq("L" + MonitorClass + ";") + } else Seq() + + val accM = cf.addMethod(typeToJVM(lzy.returnType), accessorName, parameters : _*) + accM.setFlags( if (isStatic) {( + METHOD_ACC_STATIC | // FIXME other flags? Not always public? + METHOD_ACC_PUBLIC + ).asInstanceOf[U2] } else { + METHOD_ACC_PUBLIC + }) + val ch = accM.codeHandler + val body = purescala.TreeOps.matchToIfThenElse(lzy.body.getOrElse(throw CompilationException("Lazy field without body?"))) + val initLabel = ch.getFreshLabel("isInitialized") + + if (params.requireMonitor) { + ch << ALoad(if (isStatic) 0 else 1) << InvokeVirtual(MonitorClass, "onInvoke", "()V") + } + + if (isStatic) { + ch << GetStatic(cName, underlyingName, underlyingType) + } else { + ch << ALoad(0) << GetField(cName, underlyingName, underlyingType) // if (lzy == null) + } + // oldValue + ch << DUP << IfNonNull(initLabel) + // null + ch << POP + // + mkBoxedExpr(body,ch)(NoLocals(isStatic)) // lzy = <expr> + ch << DUP + // newValue, newValue + if (isStatic) { + ch << PutStatic(cName, underlyingName, underlyingType) + //newValue + } + else { + ch << ALoad(0) << SWAP + // newValue, object, newValue + ch << PutField (cName, underlyingName, underlyingType) + //newValue + } + ch << Label(initLabel) // return lzy + //newValue + lzy.returnType match { + case Int32Type | BooleanType | UnitType => + // Since the underlying field only has boxed types, we have to unbox them to return them + mkUnbox(lzy.returnType, ch)(NoLocals(isStatic)) + ch << IRETURN + case _ : ClassType | _ : TupleType | _ : SetType | _ : MapType | _ : ArrayType | _: TypeParameter => + ch << ARETURN + case other => throw CompilationException("Unsupported return type : " + other.getClass) + } + ch.freeze + } + } + + /** Compile the (strict) field $field which is owned by class $owner */ + def compileStrictField(field : FunDef, owner : Definition) = { + + ctx.reporter.internalAssertion(field.canBeStrictField, + s"Trying to compile ${field.id.name} as a strict field") + val (_, fieldName, _) = leonFunDefToJVMInfo(field).get + + val cf = classes(owner) + val fh = cf.addField(typeToJVM(field.returnType),fieldName) + fh.setFlags( owner match { + case _ : ModuleDef => ( + FIELD_ACC_STATIC | + FIELD_ACC_PUBLIC | // FIXME + FIELD_ACC_FINAL + ).asInstanceOf[U2] + case _ => ( + FIELD_ACC_PUBLIC | // FIXME + FIELD_ACC_FINAL + ).asInstanceOf[U2] + }) + } + + /** Initializes a lazy field to null + * @param ch the codehandler to add the initializing code to + * @param className the name of the class in which the field is initialized + * @param lzy the lazy field to be initialized + * @param isStatic true if this is a static field + */ + def initLazyField(ch: CodeHandler, className : String, lzy : FunDef, isStatic: Boolean) = { + val (_, name, _) = leonFunDefToJVMInfo(lzy).get + val underlyingName = underlyingField(name) + val jvmType = typeToJVMBoxed(lzy.returnType) + if (isStatic){ + ch << ACONST_NULL << PutStatic(className, underlyingName, jvmType) + } else { + ch << ALoad(0) << ACONST_NULL << PutField(className, underlyingName, jvmType) + } + } + + /** Initializes a (strict) field + * @param ch the codehandler to add the initializing code to + * @param className the name of the class in which the field is initialized + * @param field the field to be initialized + * @param isStatic true if this is a static field + */ + def initStrictField(ch : CodeHandler, className : String, field: FunDef, isStatic: Boolean) { + val (_, name , _) = leonFunDefToJVMInfo(field).get + val body = field.body.getOrElse(throw CompilationException("No body for field?")) + val jvmType = typeToJVM(field.returnType) + + mkExpr(purescala.TreeOps.matchToIfThenElse(body), ch)(NoLocals(isStatic)) // FIXME Locals? + + if (isStatic){ + ch << PutStatic(className, name, jvmType) + } else { + ch << ALoad(0) << SWAP << PutField (className, name, jvmType) + } + } + + + def compileAbstractClassDef(acd : AbstractClassDef) { + val cName = defToJVMName(acd) val cf = classes(acd) @@ -621,7 +920,56 @@ trait CodeGeneration { cf.addInterface(CaseClassClass) - cf.addDefaultConstructor + // add special monitor for method invocations + if (params.doInstrument) { + val fh = cf.addField("I", instrumentedField) + fh.setFlags(FIELD_ACC_PUBLIC) + } + + val (fields, methods) = acd.methods partition { _.canBeField } + val (strictFields, lazyFields) = fields partition { _.canBeStrictField } + + // Compile methods + for (method <- methods) { + compileFunDef(method,acd) + } + + // Compile lazy fields + for (lzy <- lazyFields) { + compileLazyField(lzy, acd) + } + + // Compile strict fields + for (field <- strictFields) { + compileStrictField(field, acd) + } + + // definition of the constructor + if (fields.isEmpty && !params.doInstrument && !params.requireMonitor) cf.addDefaultConstructor else { + + val constrParams = if (params.requireMonitor) { + Seq("L" + MonitorClass + ";") + } else Seq() + + val cch = cf.addConstructor(constrParams : _*).codeHandler + // Abstract classes are hierarchy roots, so call java.lang.Object constructor + cch << ALoad(0) + cch << InvokeSpecial("java/lang/Object", constructorName, "()V") + + // Initialize special monitor field + if (params.doInstrument) { + cch << ALoad(0) + cch << Ldc(0) + cch << PutField(cName, instrumentedField, "I") + } + + for (lzy <- lazyFields) { initLazyField(cch, cName, lzy, false) } + for (field <- strictFields) { initStrictField(cch, cName, field, false)} + + cch << RETURN + cch.freeze + } + } /** @@ -629,9 +977,8 @@ trait CodeGeneration { */ val instrumentedField = "__read" - def instrumentedGetField(ch: CodeHandler, cct: CaseClassType, id: Identifier)(implicit locals: Locals): Unit = { + def instrumentedGetField(ch: CodeHandler, cct: ClassType, id: Identifier)(implicit locals: Locals): Unit = { val ccd = cct.classDef - ccd.fields.zipWithIndex.find(_._1.id == id) match { case Some((f, i)) => val expType = cct.fields(i).tpe @@ -658,12 +1005,15 @@ trait CodeGeneration { } } + + def compileCaseClassDef(ccd: CaseClassDef) { val cName = defToJVMName(ccd) val pName = ccd.parent.map(parent => defToJVMName(parent.classDef)) + // An instantiation of ccd with its own type parameters val cct = CaseClassType(ccd, ccd.tparams.map(_.tp)) - + val cf = classes(ccd) cf.setFlags(( @@ -676,48 +1026,94 @@ trait CodeGeneration { cf.addInterface(CaseClassClass) } - val namesTypes = ccd.fields.map { vd => (vd.id.name, typeToJVM(vd.tpe)) } - - // definition of the constructor - if(!params.doInstrument && ccd.fields.isEmpty) { - cf.addDefaultConstructor - } else { - for((nme, jvmt) <- namesTypes) { - val fh = cf.addField(jvmt, nme) - fh.setFlags(( - FIELD_ACC_PUBLIC | - FIELD_ACC_FINAL - ).asInstanceOf[U2]) + locally { + + val (fields, methods) = ccd.methods partition { _.canBeField } + val (strictFields, lazyFields) = fields partition { _.canBeStrictField } + + // Compile methods + for (method <- methods) { + compileFunDef(method,ccd) } - - if (params.doInstrument) { - val fh = cf.addField("I", instrumentedField) - fh.setFlags(FIELD_ACC_PUBLIC) + + // Compile lazy fields + for (lzy <- lazyFields) { + compileLazyField(lzy, ccd) } - - val cch = cf.addConstructor(namesTypes.map(_._2).toList).codeHandler - - cch << ALoad(0) - cch << InvokeSpecial(pName.getOrElse("java/lang/Object"), constructorName, "()V") - - if (params.doInstrument) { - cch << ALoad(0) - cch << Ldc(0) - cch << PutField(cName, instrumentedField, "I") + + // Compile strict fields + for (field <- strictFields) { + compileStrictField(field, ccd) } + + // Case class parameters + val namesTypes = ccd.fields.map { vd => (vd.id.name, typeToJVM(vd.tpe)) } + + // definition of the constructor + if(!params.doInstrument && !params.requireMonitor && ccd.fields.isEmpty && ccd.methods.filter{ _.canBeField }.isEmpty) { + cf.addDefaultConstructor + } else { + for((nme, jvmt) <- namesTypes) { + val fh = cf.addField(jvmt, nme) + fh.setFlags(( + FIELD_ACC_PUBLIC | + FIELD_ACC_FINAL + ).asInstanceOf[U2]) + } + + if (params.doInstrument) { + val fh = cf.addField("I", instrumentedField) + fh.setFlags(FIELD_ACC_PUBLIC) + } + + // If we are monitoring function calls, we have an extra argument on the constructor + val realArgs = if (params.requireMonitor) { + ("L" + MonitorClass + ";") +: (namesTypes map (_._2)) + } else (namesTypes map (_._2)) + + // Offset of the first Scala parameter of the constructor + val paramOffset = if (params.requireMonitor) 2 else 1 + + val cch = cf.addConstructor(realArgs.toList).codeHandler + + if (params.doInstrument) { + cch << ALoad(0) + cch << Ldc(0) + cch << PutField(cName, instrumentedField, "I") + } + + var c = paramOffset + for((nme, jvmt) <- namesTypes) { + cch << ALoad(0) + cch << (jvmt match { + case "I" | "Z" => ILoad(c) + case _ => ALoad(c) + }) + cch << PutField(cName, nme, jvmt) + c += 1 + } + + // Call parent constructor AFTER initializing case class parameters + if (ccd.parent.isDefined) { + // Load this + cch << ALoad(0) + // Load monitor object + if (params.requireMonitor) cch << ALoad(1) + val constrSig = if (params.requireMonitor) "(L" + MonitorClass + ";)V" else "()V" + cch << InvokeSpecial(pName.get, constructorName, constrSig) + } else { + // Call constructor of java.lang.Object + cch << ALoad(0) + cch << InvokeSpecial("java/lang/Object", constructorName, "()V") + } - var c = 1 - for((nme, jvmt) <- namesTypes) { - cch << ALoad(0) - cch << (jvmt match { - case "I" | "Z" => ILoad(c) - case _ => ALoad(c) - }) - cch << PutField(cName, nme, jvmt) - c += 1 + + // Now initialize fields + for (lzy <- lazyFields) { initLazyField(cch, cName, lzy, false)} + for (field <- strictFields) { initStrictField(cch, cName , field, false)} + cch << RETURN + cch.freeze } - cch << RETURN - cch.freeze } locally { @@ -764,8 +1160,12 @@ trait CodeGeneration { pech << DUP pech << Ldc(i) pech << ALoad(0) - instrumentedGetField(pech, cct, f.id)(NoLocals) - mkBox(f.tpe, pech)(NoLocals) + // WARNING: Passing NoLocals(false) is kind of a hack, + // since there is no monitor object anywhere in this method. + // We are saved because it is not used anywhere, + // but beware if you decide to add any mkExpr and the like. + instrumentedGetField(pech, cct, f.id)(NoLocals(false)) + mkBox(f.tpe, pech)(NoLocals(false)) pech << AASTORE } @@ -798,10 +1198,14 @@ trait CodeGeneration { ech << ALoad(1) << CheckCast(cName) << AStore(castSlot) for(vd <- ccd.fields) { + // WARNING: Passing NoLocals(false) is kind of a hack, + // since there is no monitor object anywhere in this method. + // We are saved because it is not used anywhere, + // but beware if you decide to add any mkExpr and the like. ech << ALoad(0) - instrumentedGetField(ech, cct, vd.id)(NoLocals) + instrumentedGetField(ech, cct, vd.id)(NoLocals(false)) ech << ALoad(castSlot) - instrumentedGetField(ech, cct, vd.id)(NoLocals) + instrumentedGetField(ech, cct, vd.id)(NoLocals(false)) typeToJVM(vd.getType) match { case "I" | "Z" => diff --git a/src/main/scala/leon/codegen/CompilationUnit.scala b/src/main/scala/leon/codegen/CompilationUnit.scala index 165317b324c2cc07d434df3be5be689fe788112f..4e32e619b3f6e8dac3d4e3b1a91a12254856968e 100644 --- a/src/main/scala/leon/codegen/CompilationUnit.scala +++ b/src/main/scala/leon/codegen/CompilationUnit.scala @@ -25,8 +25,8 @@ class CompilationUnit(val ctx: LeonContext, val loader = new CafebabeClassLoader(classOf[CompilationUnit].getClassLoader) var classes = Map[Definition, ClassFile]() - var defToModule = Map[Definition, ModuleDef]() - + var defToModuleOrClass = Map[Definition, Definition]() + def defineClass(df: Definition) { val cName = defToJVMName(df) @@ -55,7 +55,8 @@ class CompilationUnit(val ctx: LeonContext, def leonClassToJVMInfo(cd: ClassDef): Option[(String, String)] = { classes.get(cd) match { case Some(cf) => - val sig = "(" + cd.fields.map(f => typeToJVM(f.tpe)).mkString("") + ")V" + val monitorType = if (params.requireMonitor) "L"+MonitorClass+";" else "" + val sig = "(" + monitorType + cd.fields.map(f => typeToJVM(f.tpe)).mkString("") + ")V" Some((cf.className, sig)) case _ => None } @@ -64,13 +65,20 @@ class CompilationUnit(val ctx: LeonContext, // Returns className, methodName, methodSignature private[this] var funDefInfo = Map[FunDef, (String, String, String)]() + + /** + * Returns (cn, mn, sig) where + * cn is the module name + * mn is the safe method name + * sig is the method signature + */ def leonFunDefToJVMInfo(fd: FunDef): Option[(String, String, String)] = { funDefInfo.get(fd).orElse { val monitorType = if (params.requireMonitor) "L"+MonitorClass+";" else "" val sig = "(" + monitorType + fd.params.map(a => typeToJVM(a.tpe)).mkString("") + ")" + typeToJVM(fd.returnType) - defToModule.get(fd).flatMap(m => classes.get(m)) match { + defToModuleOrClass.get(fd).flatMap(m => classes.get(m)) match { case Some(cf) => val res = (cf.className, idToSafeJVMName(fd.id), sig) funDefInfo += fd -> res @@ -232,7 +240,7 @@ class CompilationUnit(val ctx: LeonContext, val exprToCompile = purescala.TreeOps.matchToIfThenElse(e) - mkExpr(e, ch)(Locals(newMapping)) + mkExpr(e, ch)(Locals(newMapping, true)) e.getType match { case Int32Type | BooleanType => @@ -254,64 +262,105 @@ class CompilationUnit(val ctx: LeonContext, def compileModule(module: ModuleDef) { val cf = classes(module) - - cf.addDefaultConstructor - cf.setFlags(( CLASS_ACC_SUPER | CLASS_ACC_PUBLIC | CLASS_ACC_FINAL ).asInstanceOf[U2]) - - for(funDef <- module.definedFunctions; - (_,mn,_) <- leonFunDefToJVMInfo(funDef)) { - - val paramsTypes = funDef.params.map(a => typeToJVM(a.tpe)) - - val realParams = if (params.requireMonitor) { - ("L" + MonitorClass + ";") +: paramsTypes - } else { - paramsTypes + + /*if (false) { + // currently we do not handle object fields + // this treats all fields as functions + for (fun <- module.definedFunctions) { + compileFunDef(fun, module) } - - val m = cf.addMethod( - typeToJVM(funDef.returnType), - mn, - realParams : _* - ) - m.setFlags(( - METHOD_ACC_PUBLIC | - METHOD_ACC_FINAL | - METHOD_ACC_STATIC + } else {*/ + + val (fields, functions) = module.definedFunctions partition { _.canBeField } + val (strictFields, lazyFields) = fields partition { _.canBeStrictField } + + // Compile methods + for (function <- functions) { + compileFunDef(function,module) + } + + // Compile lazy fields + for (lzy <- lazyFields) { + compileLazyField(lzy, module) + } + + // Compile strict fields + for (field <- strictFields) { + compileStrictField(field, module) + } + + // Constructor + cf.addDefaultConstructor + + val cName = defToJVMName(module) + + // Add class initializer method + locally{ + val mh = cf.addMethod("V", "<clinit>") + mh.setFlags(( + METHOD_ACC_STATIC | + METHOD_ACC_PUBLIC ).asInstanceOf[U2]) - - compileFunDef(funDef, m.codeHandler) + + val ch = mh.codeHandler + /* + * FIXME : + * Dirty hack to make this compatible with monitoring of method invocations. + * Because we don't have access to the monitor object here, we initialize a new one + * that will get lost when this method returns, so we can't hope to count + * method invocations here :( + */ + ch << New(MonitorClass) << DUP + ch << Ldc(Int.MaxValue) // Allow "infinite" method calls + ch << InvokeSpecial(MonitorClass, cafebabe.Defaults.constructorName, "(I)V") + ch << AStore(ch.getFreshVar) // position 0 + for (lzy <- lazyFields) { initLazyField(ch, cName, lzy, true)} + for (field <- strictFields) { initStrictField(ch, cName , field, true)} + ch << RETURN + ch.freeze } - } + + + + } + /** Traverses the program to find all definitions, and stores those in global variables */ def init() { - // First define all classes + // First define all classes/ methods/ functions for (m <- program.modules) { - for ((parent, children) <- m.algebraicDataTypes) { - defineClass(parent) - - for (c <- children) { - defineClass(c) + for ( (parent, children) <- m.algebraicDataTypes; + cls <- Seq(parent) ++ children) { + defineClass(cls) + for (meth <- cls.methods) { + defToModuleOrClass += meth -> cls } } - - for(single <- m.singleCaseClasses) { + + for ( single <- m.singleCaseClasses ) { defineClass(single) + for (meth <- single.methods) { + defToModuleOrClass += meth -> single + } + } + + for(funDef <- m.definedFunctions) { + defToModuleOrClass += funDef -> m } - defineClass(m) } } + /** Compiles the program. Uses information provided by $init */ def compile() { // Compile everything for (m <- program.modules) { + for ((parent, children) <- m.algebraicDataTypes) { compileAbstractClassDef(parent) @@ -324,9 +373,7 @@ class CompilationUnit(val ctx: LeonContext, compileCaseClassDef(single) } - for(funDef <- m.definedFunctions) { - defToModule += funDef -> m - } + } for (m <- program.modules) { diff --git a/src/test/scala/leon/test/codegen/CodeGenTests.scala b/src/test/scala/leon/test/codegen/CodeGenTests.scala new file mode 100644 index 0000000000000000000000000000000000000000..dbbab962365adf12a98465efcbb7f41bcd1435c5 --- /dev/null +++ b/src/test/scala/leon/test/codegen/CodeGenTests.scala @@ -0,0 +1,340 @@ +package leon.test.codegen + +import leon._ +import leon.codegen._ +import leon.purescala.Definitions._ +import leon.purescala.Trees._ +import leon.evaluators.{CodeGenEvaluator,EvaluationResults} +import EvaluationResults._ + +import java.io._ + +case class TestCase( + name : String, + content : String, + expected : Expr, + args : Seq[Expr] = Seq(), + functionToTest : String = "test" +) + +class CodeGenTests extends test.LeonTestSuite { + + val catchAll = true + + val pipeline = + utils.TemporaryInputPhase andThen + frontends.scalac.ExtractionPhase andThen + utils.ScopingPhase andThen + purescala.MethodLifting andThen + utils.TypingPhase andThen + purescala.CompleteAbstractDefinitions andThen + purescala.RestoreMethods + + def compileTestFun(p : Program, toTest : String, ctx : LeonContext, requireMonitor : Boolean, doInstrument : Boolean) : ( Seq[Expr] => EvaluationResults.Result) = { + // We want to produce code that checks contracts + val evaluator = new CodeGenEvaluator(ctx, p, CodeGenParams( + maxFunctionInvocations = if (requireMonitor) 1000 else -1, // Monitor calls and abort execution if more than X calls + checkContracts = true, // Generate calls that checks pre/postconditions + doInstrument = doInstrument // Instrument reads to case classes (mainly for vanuatoo) + )) + + + val testFun = p.definedFunctions.find(_.id.name == toTest).getOrElse { + ctx.reporter.fatalError("Test function not defined!") + } + val params = testFun.params map { _.id } + val body = testFun.body.get + // Will apply test a number of times with the help of compileRec + evaluator.compile(body, params).getOrElse { + ctx.reporter.fatalError("Failed to compile test function!") + } + + } + + + private def testCodeGen(prog : TestCase, requireMonitor : Boolean, doInstrument : Boolean) { test(prog.name) { + import prog._ + val settings = testContext.settings.copy(injectLibrary = false) + val ctx = testContext.copy( + // We want a reporter that actually prints some output + reporter = new DefaultReporter(settings), + settings = settings + ) + + val ast = pipeline.run(ctx)( (content, List()) ) + + //ctx.reporter.info(purescala.ScalaPrinter(ast)) + + val compiled = compileTestFun(ast, functionToTest, ctx, requireMonitor, doInstrument) + try { compiled(args) match { + case Successful(res) if res == expected => + // Success + case RuntimeError(_) | EvaluatorError(_) if expected.isInstanceOf[Error] => + // Success + case Successful(res) => + ctx.reporter.fatalError(s""" + Program $name produced wrong output. + Output was ${res.toString} + Expected was ${expected.toString} + """.stripMargin) + case RuntimeError(mes) => + ctx.reporter.fatalError(s"Program $name threw runtime error with message $mes") + case EvaluatorError(res) => + ctx.reporter.fatalError(s"Evaluator failed for program $name with message $res") + }} catch { + // Currently, this is what we would like to catch and still succeed, but there might be more + case _ : LeonFatalError | _ : StackOverflowError if expected.isInstanceOf[Error] => + // Success + case th : Throwable => + if (catchAll) { + // This is to be able to continue testing after an error + ctx.reporter.fatalError(s"Program $name failed\n${th.printStackTrace()}")// with message ${th.getMessage()}") + } else { throw th } + } + }} + + + val programs = Seq( + + TestCase("simple", """ + object simple { + abstract class Abs + case class Conc(x : Int) extends Abs + def test = { + val c = Conc(1) + c.x + + } + }""", + IntLiteral(1) + ), + + TestCase("eager", """ + object eager { + abstract class Abs() { + val foo = 42 + } + case class Conc(x : Int) extends Abs() + def foo = { + val c = Conc(1) + c.foo + c.x + } + def test = foo + }""", + IntLiteral(43) + ), + + TestCase("this", """ + object thiss { + + case class Bar() { + def boo = this + def toRet = 42 + } + + def test = Bar().boo.toRet + } + """, + IntLiteral(42) + ), + + TestCase("oldStuff", """ + object oldStuff { + def test = 1 + case class Bar() { + def boo = 2 + } + }""", + IntLiteral(1) + ), + + TestCase("methSimple", """ + object methSimple { + + sealed abstract class Ab { + def f2(x : Int) = x + 5 + } + case class Con() extends Ab { } + + def test = Con().f2(5) + }""", + IntLiteral(10) + ), + + TestCase("methods", """ + object methods { + def f1 = 4 + sealed abstract class Ab { + def f2(x : Int) = Cs().f3(1,2) + f1 + x + 5 + } + case class Con() extends Ab {} + case class Cs() { + def f3(x : Int, y : Int) = x + y + } + def test = Con().f2(3) + }""", + IntLiteral(15) + ), + + + TestCase("lazy", """ + object lazyFields { + def foo = 1 + sealed abstract class Ab { + lazy val x : Int = this match { + case Conc(t) => t + 1 + case Conc2(t) => t+2 + } + } + case class Conc(t : Int) extends Ab { } + case class Conc2(t : Int) extends Ab { } + def test = foo + Conc(5).x + Conc2(6).x + } + """, + IntLiteral(1 + 5 + 1 + 6 + 2) + ), + + TestCase("modules", """ + object modules { + def foo = 1 + val bar = 2 + lazy val baz = 0 + def test = foo + bar + baz + } + """, + IntLiteral(1 + 2 + 0) + ), + + TestCase("lazyISLazy" , """ + object lazyISLazy { + abstract class Ab { lazy val x : Int = foo; def foo : Int = foo } + case class Conc() extends Ab { } + def test = { val willNotLoop = Conc(); 42 } + }""", + IntLiteral(42) + ), + + TestCase("ListWithSize" , """ + object list { + abstract class List[T] { + val length : Int = this match { + case Nil() => 0 + case Cons (_, xs ) => 1 + xs.length + } + + } + case class Cons[T](hd : T, tl : List[T]) extends List[T] + case class Nil[T]() extends List[T] + + + val l = Cons(1, Cons(2, Cons(3, Nil()))) + + def test = l.length + Nil().length + }""", + IntLiteral(3 ) + ), + + TestCase("ListWithSumMono" , """ + object ListWithSumMono { + abstract class List + case class Cons(hd : Int, tl : List) extends List + case class Nil() extends List + + def sum (l : List) : Int = l match { + case Nil() => 0 + case Cons(x, xs) => x + sum(xs) + } + + val l = Cons(1, Cons(2, Cons(3, Nil()))) + + def test = sum(l) + }""", + IntLiteral(1 + 2 + 3) + ), + + TestCase("poly" , """ + object poly { + case class Poly[T](poly : T) + def ex = Poly(42) + def test = ex.poly + }""", + IntLiteral(42) + ), + + TestCase("ListHead" , """ + object ListHead { + abstract class List[T] + case class Cons[T](hd : T, tl : List[T]) extends List[T] + case class Nil[T]() extends List[T] + + def l = Cons(1, Cons(2, Cons(3, Nil()))) + + def test = l.hd + }""", + IntLiteral(1) + ), + TestCase("ListWithSum" , """ + object ListWithSum { + abstract class List[T] + case class Cons[T](hd : T, tl : List[T]) extends List[T] + case class Nil[T]() extends List[T] + + def sum (l : List[Int]) : Int = l match { + case Nil() => 0 + case Cons(x, xs) => x + sum(xs) + } + + val l = Cons(1, Cons(2, Cons(3, Nil()))) + + def test = sum(l) + }""", + IntLiteral(1 + 2 + 3) + ), + + // This one loops! + TestCase("lazyLoops" , """ + object lazyLoops { + abstract class Ab { lazy val x : Int = foo; def foo : Int = foo } + case class Conc() extends Ab { } + def test = Conc().x + }""", + Error("Looping") + ), + + TestCase("Lazier" , """ + import leon.lang._ + object Lazier { + abstract class List[T] { + lazy val tail = this match { + case Nil() => error[List[T]]("Nil.tail") + case Cons(_, tl) => tl + } + } + case class Cons[T](hd : T, tl : List[T]) extends List[T] + case class Nil[T]() extends List[T] + + def sum (l : List[Int]) : Int = l match { + case Nil() => 0 + case c : Cons[Int] => c.hd + sum(c.tail) + } + + val l = Cons(1, Cons(2, Cons(3, Nil()))) + + def test = sum(l) + }""", + IntLiteral(1 + 2 + 3) + ) + ) + + + + for ( prog <- programs ; + requireMonitor <- Seq(false ,true ); + doInstrument <- Seq(false,true ) + + ) { + testCodeGen( + prog.copy(name = prog.name + (if (requireMonitor)"_M_" else "" ) + (if (doInstrument)"_I_" else "" )), + requireMonitor, doInstrument + )} +}