Skip to content
Snippets Groups Projects
Commit 45e6a8e6 authored by Emmanouil (Manos) Koukoutos's avatar Emmanouil (Manos) Koukoutos
Browse files

Code Generation for Fields

parent 0889f2f7
No related branches found
No related tags found
No related merge requests found
...@@ -7,6 +7,7 @@ import purescala.Common._ ...@@ -7,6 +7,7 @@ import purescala.Common._
import purescala.Definitions._ import purescala.Definitions._
import purescala.Trees._ import purescala.Trees._
import purescala.TypeTrees._ import purescala.TypeTrees._
import purescala.TypeTreeOps.instantiateType
import utils._ import utils._
import cafebabe._ import cafebabe._
...@@ -19,19 +20,32 @@ import cafebabe.Flags._ ...@@ -19,19 +20,32 @@ import cafebabe.Flags._
trait CodeGeneration { trait CodeGeneration {
self: CompilationUnit => self: CompilationUnit =>
case class Locals(vars: Map[Identifier, Int]) { /** A class providing information about the status of parameters in the function that is being currently compiled.
* vars is a mapping from local variables/ parameters to the offset of the respective JVM local register
* isStatic signifies if the current method is static (a function, in Leon terms)
*/
case class Locals(vars: Map[Identifier, Int], private val isStatic : Boolean ) {
/** Fetches the offset of a local variable/ parameter from its identifier */
def varToLocal(v: Identifier): Option[Int] = vars.get(v) def varToLocal(v: Identifier): Option[Int] = vars.get(v)
/** Adds some extra variables to the mapping */
def withVars(newVars: Map[Identifier, Int]) = { def withVars(newVars: Map[Identifier, Int]) = {
Locals(vars ++ newVars) Locals(vars ++ newVars, isStatic)
} }
/** Adds an extra variable to the mapping */
def withVar(nv: (Identifier, Int)) = { def withVar(nv: (Identifier, Int)) = {
Locals(vars + nv) Locals(vars + nv, isStatic)
} }
/** The index of the monitor object in this function */
def monitorIndex = if (isStatic) 0 else 1
}
object NoLocals {
/** Make a $Locals object without any local variables */
def apply(isStatic : Boolean) = new Locals(Map(), isStatic)
} }
object NoLocals extends Locals(Map())
private[codegen] val BoxedIntClass = "java/lang/Integer" private[codegen] val BoxedIntClass = "java/lang/Integer"
private[codegen] val BoxedBoolClass = "java/lang/Boolean" private[codegen] val BoxedBoolClass = "java/lang/Boolean"
...@@ -50,6 +64,10 @@ trait CodeGeneration { ...@@ -50,6 +64,10 @@ trait CodeGeneration {
def idToSafeJVMName(id: Identifier) = id.uniqueName.replaceAll("\\.", "\\$") def idToSafeJVMName(id: Identifier) = id.uniqueName.replaceAll("\\.", "\\$")
def defToJVMName(d : Definition) : String = "Leon$CodeGen$" + idToSafeJVMName(d.id) def defToJVMName(d : Definition) : String = "Leon$CodeGen$" + idToSafeJVMName(d.id)
/** Retrieve the name of the underlying lazy field from a lazy field accessor method */
private[codegen] def underlyingField(lazyAccessor : String) = lazyAccessor + "$underlying"
/** Return the respective JVM type from a Leon type */
def typeToJVM(tpe : TypeTree) : String = tpe match { def typeToJVM(tpe : TypeTree) : String = tpe match {
case Int32Type => "I" case Int32Type => "I"
...@@ -78,15 +96,57 @@ trait CodeGeneration { ...@@ -78,15 +96,57 @@ trait CodeGeneration {
case _ => throw CompilationException("Unsupported type : " + tpe) case _ => throw CompilationException("Unsupported type : " + tpe)
} }
// Assumes the CodeHandler has never received any bytecode. /** Return the respective boxed JVM type from a Leon type */
// Generates method body, and freezes the handler at the end. def typeToJVMBoxed(tpe : TypeTree) : String = tpe match {
def compileFunDef(funDef : FunDef, ch : CodeHandler) { case Int32Type => s"L$BoxedIntClass;"
val newMapping = if (params.requireMonitor) { case BooleanType | UnitType => s"L$BoxedBoolClass;"
funDef.params.map(_.id).zipWithIndex.toMap.mapValues(_ + 1) case other => typeToJVM(other)
} else { }
funDef.params.map(_.id).zipWithIndex.toMap
} /**
* Compiles a function/method definition.
* @param funDef The function definition to be compiled
* @param owner The module/class that contains $funDef
*/
def compileFunDef(funDef : FunDef, owner : Definition) {
val isStatic = owner.isInstanceOf[ModuleDef]
val cf = classes(owner)
val (_,mn,_) = leonFunDefToJVMInfo(funDef).get
val paramsTypes = funDef.params.map(a => typeToJVM(a.tpe))
val realParams = if (params.requireMonitor) {
("L" + MonitorClass + ";") +: paramsTypes
} else {
paramsTypes
}
val m = cf.addMethod(
typeToJVM(funDef.returnType),
mn,
realParams : _*
)
m.setFlags((
if (isStatic)
METHOD_ACC_PUBLIC |
METHOD_ACC_FINAL |
METHOD_ACC_STATIC
else
METHOD_ACC_PUBLIC |
METHOD_ACC_FINAL
).asInstanceOf[U2])
val ch = m.codeHandler
// An offset we introduce to the parameters:
// 1 if this is a method, so we need "this" in position 0 of the stack
// 1 if we are monitoring // FIXME
val paramsOffset = Seq(!isStatic, params.requireMonitor).count(x => x)
val newMapping =
funDef.params.map(_.id).zipWithIndex.toMap.mapValues(_ + paramsOffset)
val body = funDef.body.getOrElse(throw CompilationException("Can't compile a FunDef without body: "+funDef.id.name)) val body = funDef.body.getOrElse(throw CompilationException("Can't compile a FunDef without body: "+funDef.id.name))
val bodyWithPre = if(funDef.hasPrecondition && params.checkContracts) { val bodyWithPre = if(funDef.hasPrecondition && params.checkContracts) {
...@@ -105,10 +165,11 @@ trait CodeGeneration { ...@@ -105,10 +165,11 @@ trait CodeGeneration {
val exprToCompile = purescala.TreeOps.matchToIfThenElse(bodyWithPost) val exprToCompile = purescala.TreeOps.matchToIfThenElse(bodyWithPost)
if (params.recordInvocations) { if (params.recordInvocations) {
ch << ALoad(0) << InvokeVirtual(MonitorClass, "onInvoke", "()V") // index of monitor object will be before the first Scala parameter
ch << ALoad(paramsOffset-1) << InvokeVirtual(MonitorClass, "onInvoke", "()V")
} }
mkExpr(exprToCompile, ch)(Locals(newMapping)) mkExpr(exprToCompile, ch)(Locals(newMapping, isStatic))
funDef.returnType match { funDef.returnType match {
case Int32Type | BooleanType | UnitType => case Int32Type | BooleanType | UnitType =>
...@@ -184,6 +245,8 @@ trait CodeGeneration { ...@@ -184,6 +245,8 @@ trait CodeGeneration {
throw CompilationException("Unknown class : " + cct.id) throw CompilationException("Unknown class : " + cct.id)
} }
ch << New(ccName) << DUP ch << New(ccName) << DUP
if (params.requireMonitor)
ch << ALoad(locals.monitorIndex)
for((a, vd) <- as zip cct.classDef.fields) { for((a, vd) <- as zip cct.classDef.fields) {
vd.tpe match { vd.tpe match {
case TypeParameter(_) => case TypeParameter(_) =>
...@@ -306,13 +369,35 @@ trait CodeGeneration { ...@@ -306,13 +369,35 @@ trait CodeGeneration {
mkExpr(e, ch) mkExpr(e, ch)
ch << Label(al) ch << Label(al)
// Strict static fields
case FunctionInvocation(tfd, as) if tfd.fd.canBeStrictField =>
val (className, fieldName, _) = leonFunDefToJVMInfo(tfd.fd).getOrElse {
throw CompilationException("Unknown method : " + tfd.id)
}
if (params.requireMonitor) {
// index of monitor object will be before the first Scala parameter
ch << ALoad(locals.monitorIndex) << InvokeVirtual(MonitorClass, "onInvoke", "()V")
}
// Get static field
ch << GetStatic(className, fieldName, typeToJVM(tfd.fd.returnType))
// unbox field
(tfd.fd.returnType, tfd.returnType) match {
case (TypeParameter(_), tpe) =>
mkUnbox(tpe, ch)
case _ =>
}
// Static lazy fields/ functions
case FunctionInvocation(tfd, as) => case FunctionInvocation(tfd, as) =>
val (cn, mn, ms) = leonFunDefToJVMInfo(tfd.fd).getOrElse { val (cn, mn, ms) = leonFunDefToJVMInfo(tfd.fd).getOrElse {
throw CompilationException("Unknown method : " + tfd.id) throw CompilationException("Unknown method : " + tfd.id)
} }
if (params.requireMonitor) { if (params.requireMonitor) {
ch << ALoad(0) ch << ALoad(locals.monitorIndex)
} }
for((a, vd) <- as zip tfd.fd.params) { for((a, vd) <- as zip tfd.fd.params) {
...@@ -331,7 +416,63 @@ trait CodeGeneration { ...@@ -331,7 +416,63 @@ trait CodeGeneration {
mkUnbox(tpe, ch) mkUnbox(tpe, ch)
case _ => case _ =>
} }
// Strict fields are handled as fields
case MethodInvocation(rec, _, tfd, _) if tfd.fd.canBeStrictField =>
val (className, fieldName, _) = leonFunDefToJVMInfo(tfd.fd).getOrElse {
throw CompilationException("Unknown method : " + tfd.id)
}
if (params.requireMonitor) {
// index of monitor object will be before the first Scala parameter
ch << ALoad(locals.monitorIndex) << InvokeVirtual(MonitorClass, "onInvoke", "()V")
}
// Load receiver
mkExpr(rec,ch)
// Get field
ch << GetField(className, fieldName, typeToJVM(tfd.fd.returnType))
// unbox field
(tfd.fd.returnType, tfd.returnType) match {
case (TypeParameter(_), tpe) =>
mkUnbox(tpe, ch)
case _ =>
}
// This is for lazy fields and real methods.
// To access a lazy field, we call its accessor function.
case MethodInvocation(rec, cd, tfd, as) =>
val (className, methodName, sig) = leonFunDefToJVMInfo(tfd.fd).getOrElse {
throw CompilationException("Unknown method : " + tfd.id)
}
// Receiver of the method call
mkExpr(rec,ch)
if (params.requireMonitor) {
ch << ALoad(locals.monitorIndex)
}
for((a, vd) <- as zip tfd.fd.params) {
vd.tpe match {
case TypeParameter(_) =>
mkBoxedExpr(a, ch)
case _ =>
mkExpr(a, ch)
}
}
// No dynamic dispatching/overriding in Leon,
// so no need to take care of own vs. "super" methods
ch << InvokeVirtual(className, methodName, sig)
(tfd.fd.returnType, tfd.returnType) match {
case (TypeParameter(_), tpe) =>
mkUnbox(tpe, ch)
case _ =>
}
// Arithmetic // Arithmetic
case Plus(l, r) => case Plus(l, r) =>
mkExpr(l, ch) mkExpr(l, ch)
...@@ -442,7 +583,10 @@ trait CodeGeneration { ...@@ -442,7 +583,10 @@ trait CodeGeneration {
ch << InvokeStatic(ChooseEntryPointClass, "invoke", "(I[Ljava/lang/Object;)Ljava/lang/Object;") ch << InvokeStatic(ChooseEntryPointClass, "invoke", "(I[Ljava/lang/Object;)Ljava/lang/Object;")
mkUnbox(choose.getType, ch) mkUnbox(choose.getType, ch)
case This(ct) =>
ch << ALoad(0) // FIXME what if doInstrument etc
case b if b.getType == BooleanType && canDelegateToMkBranch => case b if b.getType == BooleanType && canDelegateToMkBranch =>
val fl = ch.getFreshLabel("boolfalse") val fl = ch.getFreshLabel("boolfalse")
val al = ch.getFreshLabel("boolafter") val al = ch.getFreshLabel("boolafter")
...@@ -608,7 +752,162 @@ trait CodeGeneration { ...@@ -608,7 +752,162 @@ trait CodeGeneration {
} }
} }
def compileAbstractClassDef(acd : AbstractClassDef) {
/**
* Compiles a lazy field $lzy, owned by the module/ class $owner.
*
* To define a lazy field, we have to add an accessor method and an underlying field.
* The accessor method has the name of the original (Scala) lazy field and can be public.
* The underlying field has a different name, is private, and is of a boxed type
* to support null value (to signify uninitialized).
*
* @param lzy The lazy field to be compiled
* @param owner The module/class containing $lzy
*/
def compileLazyField(lzy : FunDef, owner : Definition) {
ctx.reporter.internalAssertion(lzy.canBeLazyField, s"Trying to compile non-lazy ${lzy.id.name} as a lazy field")
val (_, accessorName, _ ) = leonFunDefToJVMInfo(lzy).get
val cf = classes(owner)
val cName = defToJVMName(owner)
val isStatic = owner.isInstanceOf[ModuleDef]
// Name of the underlying field
val underlyingName = underlyingField(accessorName)
// Underlying field is of boxed type
val underlyingType = typeToJVMBoxed(lzy.returnType)
// Underlying field. It is of a boxed type
val fh = cf.addField(underlyingType,underlyingName)
fh.setFlags( if (isStatic) {(
FIELD_ACC_STATIC |
FIELD_ACC_PRIVATE
).asInstanceOf[U2] } else {
FIELD_ACC_PRIVATE
}) // FIXME private etc?
// accessor method
locally {
val parameters = if (params.requireMonitor) {
Seq("L" + MonitorClass + ";")
} else Seq()
val accM = cf.addMethod(typeToJVM(lzy.returnType), accessorName, parameters : _*)
accM.setFlags( if (isStatic) {(
METHOD_ACC_STATIC | // FIXME other flags? Not always public?
METHOD_ACC_PUBLIC
).asInstanceOf[U2] } else {
METHOD_ACC_PUBLIC
})
val ch = accM.codeHandler
val body = purescala.TreeOps.matchToIfThenElse(lzy.body.getOrElse(throw CompilationException("Lazy field without body?")))
val initLabel = ch.getFreshLabel("isInitialized")
if (params.requireMonitor) {
ch << ALoad(if (isStatic) 0 else 1) << InvokeVirtual(MonitorClass, "onInvoke", "()V")
}
if (isStatic) {
ch << GetStatic(cName, underlyingName, underlyingType)
} else {
ch << ALoad(0) << GetField(cName, underlyingName, underlyingType) // if (lzy == null)
}
// oldValue
ch << DUP << IfNonNull(initLabel)
// null
ch << POP
//
mkBoxedExpr(body,ch)(NoLocals(isStatic)) // lzy = <expr>
ch << DUP
// newValue, newValue
if (isStatic) {
ch << PutStatic(cName, underlyingName, underlyingType)
//newValue
}
else {
ch << ALoad(0) << SWAP
// newValue, object, newValue
ch << PutField (cName, underlyingName, underlyingType)
//newValue
}
ch << Label(initLabel) // return lzy
//newValue
lzy.returnType match {
case Int32Type | BooleanType | UnitType =>
// Since the underlying field only has boxed types, we have to unbox them to return them
mkUnbox(lzy.returnType, ch)(NoLocals(isStatic))
ch << IRETURN
case _ : ClassType | _ : TupleType | _ : SetType | _ : MapType | _ : ArrayType | _: TypeParameter =>
ch << ARETURN
case other => throw CompilationException("Unsupported return type : " + other.getClass)
}
ch.freeze
}
}
/** Compile the (strict) field $field which is owned by class $owner */
def compileStrictField(field : FunDef, owner : Definition) = {
ctx.reporter.internalAssertion(field.canBeStrictField,
s"Trying to compile ${field.id.name} as a strict field")
val (_, fieldName, _) = leonFunDefToJVMInfo(field).get
val cf = classes(owner)
val fh = cf.addField(typeToJVM(field.returnType),fieldName)
fh.setFlags( owner match {
case _ : ModuleDef => (
FIELD_ACC_STATIC |
FIELD_ACC_PUBLIC | // FIXME
FIELD_ACC_FINAL
).asInstanceOf[U2]
case _ => (
FIELD_ACC_PUBLIC | // FIXME
FIELD_ACC_FINAL
).asInstanceOf[U2]
})
}
/** Initializes a lazy field to null
* @param ch the codehandler to add the initializing code to
* @param className the name of the class in which the field is initialized
* @param lzy the lazy field to be initialized
* @param isStatic true if this is a static field
*/
def initLazyField(ch: CodeHandler, className : String, lzy : FunDef, isStatic: Boolean) = {
val (_, name, _) = leonFunDefToJVMInfo(lzy).get
val underlyingName = underlyingField(name)
val jvmType = typeToJVMBoxed(lzy.returnType)
if (isStatic){
ch << ACONST_NULL << PutStatic(className, underlyingName, jvmType)
} else {
ch << ALoad(0) << ACONST_NULL << PutField(className, underlyingName, jvmType)
}
}
/** Initializes a (strict) field
* @param ch the codehandler to add the initializing code to
* @param className the name of the class in which the field is initialized
* @param field the field to be initialized
* @param isStatic true if this is a static field
*/
def initStrictField(ch : CodeHandler, className : String, field: FunDef, isStatic: Boolean) {
val (_, name , _) = leonFunDefToJVMInfo(field).get
val body = field.body.getOrElse(throw CompilationException("No body for field?"))
val jvmType = typeToJVM(field.returnType)
mkExpr(purescala.TreeOps.matchToIfThenElse(body), ch)(NoLocals(isStatic)) // FIXME Locals?
if (isStatic){
ch << PutStatic(className, name, jvmType)
} else {
ch << ALoad(0) << SWAP << PutField (className, name, jvmType)
}
}
def compileAbstractClassDef(acd : AbstractClassDef) {
val cName = defToJVMName(acd) val cName = defToJVMName(acd)
val cf = classes(acd) val cf = classes(acd)
...@@ -621,7 +920,56 @@ trait CodeGeneration { ...@@ -621,7 +920,56 @@ trait CodeGeneration {
cf.addInterface(CaseClassClass) cf.addInterface(CaseClassClass)
cf.addDefaultConstructor // add special monitor for method invocations
if (params.doInstrument) {
val fh = cf.addField("I", instrumentedField)
fh.setFlags(FIELD_ACC_PUBLIC)
}
val (fields, methods) = acd.methods partition { _.canBeField }
val (strictFields, lazyFields) = fields partition { _.canBeStrictField }
// Compile methods
for (method <- methods) {
compileFunDef(method,acd)
}
// Compile lazy fields
for (lzy <- lazyFields) {
compileLazyField(lzy, acd)
}
// Compile strict fields
for (field <- strictFields) {
compileStrictField(field, acd)
}
// definition of the constructor
if (fields.isEmpty && !params.doInstrument && !params.requireMonitor) cf.addDefaultConstructor else {
val constrParams = if (params.requireMonitor) {
Seq("L" + MonitorClass + ";")
} else Seq()
val cch = cf.addConstructor(constrParams : _*).codeHandler
// Abstract classes are hierarchy roots, so call java.lang.Object constructor
cch << ALoad(0)
cch << InvokeSpecial("java/lang/Object", constructorName, "()V")
// Initialize special monitor field
if (params.doInstrument) {
cch << ALoad(0)
cch << Ldc(0)
cch << PutField(cName, instrumentedField, "I")
}
for (lzy <- lazyFields) { initLazyField(cch, cName, lzy, false) }
for (field <- strictFields) { initStrictField(cch, cName, field, false)}
cch << RETURN
cch.freeze
}
} }
/** /**
...@@ -629,9 +977,8 @@ trait CodeGeneration { ...@@ -629,9 +977,8 @@ trait CodeGeneration {
*/ */
val instrumentedField = "__read" val instrumentedField = "__read"
def instrumentedGetField(ch: CodeHandler, cct: CaseClassType, id: Identifier)(implicit locals: Locals): Unit = { def instrumentedGetField(ch: CodeHandler, cct: ClassType, id: Identifier)(implicit locals: Locals): Unit = {
val ccd = cct.classDef val ccd = cct.classDef
ccd.fields.zipWithIndex.find(_._1.id == id) match { ccd.fields.zipWithIndex.find(_._1.id == id) match {
case Some((f, i)) => case Some((f, i)) =>
val expType = cct.fields(i).tpe val expType = cct.fields(i).tpe
...@@ -658,12 +1005,15 @@ trait CodeGeneration { ...@@ -658,12 +1005,15 @@ trait CodeGeneration {
} }
} }
def compileCaseClassDef(ccd: CaseClassDef) { def compileCaseClassDef(ccd: CaseClassDef) {
val cName = defToJVMName(ccd) val cName = defToJVMName(ccd)
val pName = ccd.parent.map(parent => defToJVMName(parent.classDef)) val pName = ccd.parent.map(parent => defToJVMName(parent.classDef))
// An instantiation of ccd with its own type parameters
val cct = CaseClassType(ccd, ccd.tparams.map(_.tp)) val cct = CaseClassType(ccd, ccd.tparams.map(_.tp))
val cf = classes(ccd) val cf = classes(ccd)
cf.setFlags(( cf.setFlags((
...@@ -676,48 +1026,94 @@ trait CodeGeneration { ...@@ -676,48 +1026,94 @@ trait CodeGeneration {
cf.addInterface(CaseClassClass) cf.addInterface(CaseClassClass)
} }
val namesTypes = ccd.fields.map { vd => (vd.id.name, typeToJVM(vd.tpe)) } locally {
// definition of the constructor val (fields, methods) = ccd.methods partition { _.canBeField }
if(!params.doInstrument && ccd.fields.isEmpty) { val (strictFields, lazyFields) = fields partition { _.canBeStrictField }
cf.addDefaultConstructor
} else { // Compile methods
for((nme, jvmt) <- namesTypes) { for (method <- methods) {
val fh = cf.addField(jvmt, nme) compileFunDef(method,ccd)
fh.setFlags((
FIELD_ACC_PUBLIC |
FIELD_ACC_FINAL
).asInstanceOf[U2])
} }
if (params.doInstrument) { // Compile lazy fields
val fh = cf.addField("I", instrumentedField) for (lzy <- lazyFields) {
fh.setFlags(FIELD_ACC_PUBLIC) compileLazyField(lzy, ccd)
} }
val cch = cf.addConstructor(namesTypes.map(_._2).toList).codeHandler // Compile strict fields
for (field <- strictFields) {
cch << ALoad(0) compileStrictField(field, ccd)
cch << InvokeSpecial(pName.getOrElse("java/lang/Object"), constructorName, "()V")
if (params.doInstrument) {
cch << ALoad(0)
cch << Ldc(0)
cch << PutField(cName, instrumentedField, "I")
} }
// Case class parameters
val namesTypes = ccd.fields.map { vd => (vd.id.name, typeToJVM(vd.tpe)) }
// definition of the constructor
if(!params.doInstrument && !params.requireMonitor && ccd.fields.isEmpty && ccd.methods.filter{ _.canBeField }.isEmpty) {
cf.addDefaultConstructor
} else {
for((nme, jvmt) <- namesTypes) {
val fh = cf.addField(jvmt, nme)
fh.setFlags((
FIELD_ACC_PUBLIC |
FIELD_ACC_FINAL
).asInstanceOf[U2])
}
if (params.doInstrument) {
val fh = cf.addField("I", instrumentedField)
fh.setFlags(FIELD_ACC_PUBLIC)
}
// If we are monitoring function calls, we have an extra argument on the constructor
val realArgs = if (params.requireMonitor) {
("L" + MonitorClass + ";") +: (namesTypes map (_._2))
} else (namesTypes map (_._2))
// Offset of the first Scala parameter of the constructor
val paramOffset = if (params.requireMonitor) 2 else 1
val cch = cf.addConstructor(realArgs.toList).codeHandler
if (params.doInstrument) {
cch << ALoad(0)
cch << Ldc(0)
cch << PutField(cName, instrumentedField, "I")
}
var c = paramOffset
for((nme, jvmt) <- namesTypes) {
cch << ALoad(0)
cch << (jvmt match {
case "I" | "Z" => ILoad(c)
case _ => ALoad(c)
})
cch << PutField(cName, nme, jvmt)
c += 1
}
// Call parent constructor AFTER initializing case class parameters
if (ccd.parent.isDefined) {
// Load this
cch << ALoad(0)
// Load monitor object
if (params.requireMonitor) cch << ALoad(1)
val constrSig = if (params.requireMonitor) "(L" + MonitorClass + ";)V" else "()V"
cch << InvokeSpecial(pName.get, constructorName, constrSig)
} else {
// Call constructor of java.lang.Object
cch << ALoad(0)
cch << InvokeSpecial("java/lang/Object", constructorName, "()V")
}
var c = 1
for((nme, jvmt) <- namesTypes) { // Now initialize fields
cch << ALoad(0) for (lzy <- lazyFields) { initLazyField(cch, cName, lzy, false)}
cch << (jvmt match { for (field <- strictFields) { initStrictField(cch, cName , field, false)}
case "I" | "Z" => ILoad(c) cch << RETURN
case _ => ALoad(c) cch.freeze
})
cch << PutField(cName, nme, jvmt)
c += 1
} }
cch << RETURN
cch.freeze
} }
locally { locally {
...@@ -764,8 +1160,12 @@ trait CodeGeneration { ...@@ -764,8 +1160,12 @@ trait CodeGeneration {
pech << DUP pech << DUP
pech << Ldc(i) pech << Ldc(i)
pech << ALoad(0) pech << ALoad(0)
instrumentedGetField(pech, cct, f.id)(NoLocals) // WARNING: Passing NoLocals(false) is kind of a hack,
mkBox(f.tpe, pech)(NoLocals) // since there is no monitor object anywhere in this method.
// We are saved because it is not used anywhere,
// but beware if you decide to add any mkExpr and the like.
instrumentedGetField(pech, cct, f.id)(NoLocals(false))
mkBox(f.tpe, pech)(NoLocals(false))
pech << AASTORE pech << AASTORE
} }
...@@ -798,10 +1198,14 @@ trait CodeGeneration { ...@@ -798,10 +1198,14 @@ trait CodeGeneration {
ech << ALoad(1) << CheckCast(cName) << AStore(castSlot) ech << ALoad(1) << CheckCast(cName) << AStore(castSlot)
for(vd <- ccd.fields) { for(vd <- ccd.fields) {
// WARNING: Passing NoLocals(false) is kind of a hack,
// since there is no monitor object anywhere in this method.
// We are saved because it is not used anywhere,
// but beware if you decide to add any mkExpr and the like.
ech << ALoad(0) ech << ALoad(0)
instrumentedGetField(ech, cct, vd.id)(NoLocals) instrumentedGetField(ech, cct, vd.id)(NoLocals(false))
ech << ALoad(castSlot) ech << ALoad(castSlot)
instrumentedGetField(ech, cct, vd.id)(NoLocals) instrumentedGetField(ech, cct, vd.id)(NoLocals(false))
typeToJVM(vd.getType) match { typeToJVM(vd.getType) match {
case "I" | "Z" => case "I" | "Z" =>
......
...@@ -25,8 +25,8 @@ class CompilationUnit(val ctx: LeonContext, ...@@ -25,8 +25,8 @@ class CompilationUnit(val ctx: LeonContext,
val loader = new CafebabeClassLoader(classOf[CompilationUnit].getClassLoader) val loader = new CafebabeClassLoader(classOf[CompilationUnit].getClassLoader)
var classes = Map[Definition, ClassFile]() var classes = Map[Definition, ClassFile]()
var defToModule = Map[Definition, ModuleDef]() var defToModuleOrClass = Map[Definition, Definition]()
def defineClass(df: Definition) { def defineClass(df: Definition) {
val cName = defToJVMName(df) val cName = defToJVMName(df)
...@@ -55,7 +55,8 @@ class CompilationUnit(val ctx: LeonContext, ...@@ -55,7 +55,8 @@ class CompilationUnit(val ctx: LeonContext,
def leonClassToJVMInfo(cd: ClassDef): Option[(String, String)] = { def leonClassToJVMInfo(cd: ClassDef): Option[(String, String)] = {
classes.get(cd) match { classes.get(cd) match {
case Some(cf) => case Some(cf) =>
val sig = "(" + cd.fields.map(f => typeToJVM(f.tpe)).mkString("") + ")V" val monitorType = if (params.requireMonitor) "L"+MonitorClass+";" else ""
val sig = "(" + monitorType + cd.fields.map(f => typeToJVM(f.tpe)).mkString("") + ")V"
Some((cf.className, sig)) Some((cf.className, sig))
case _ => None case _ => None
} }
...@@ -64,13 +65,20 @@ class CompilationUnit(val ctx: LeonContext, ...@@ -64,13 +65,20 @@ class CompilationUnit(val ctx: LeonContext,
// Returns className, methodName, methodSignature // Returns className, methodName, methodSignature
private[this] var funDefInfo = Map[FunDef, (String, String, String)]() private[this] var funDefInfo = Map[FunDef, (String, String, String)]()
/**
* Returns (cn, mn, sig) where
* cn is the module name
* mn is the safe method name
* sig is the method signature
*/
def leonFunDefToJVMInfo(fd: FunDef): Option[(String, String, String)] = { def leonFunDefToJVMInfo(fd: FunDef): Option[(String, String, String)] = {
funDefInfo.get(fd).orElse { funDefInfo.get(fd).orElse {
val monitorType = if (params.requireMonitor) "L"+MonitorClass+";" else "" val monitorType = if (params.requireMonitor) "L"+MonitorClass+";" else ""
val sig = "(" + monitorType + fd.params.map(a => typeToJVM(a.tpe)).mkString("") + ")" + typeToJVM(fd.returnType) val sig = "(" + monitorType + fd.params.map(a => typeToJVM(a.tpe)).mkString("") + ")" + typeToJVM(fd.returnType)
defToModule.get(fd).flatMap(m => classes.get(m)) match { defToModuleOrClass.get(fd).flatMap(m => classes.get(m)) match {
case Some(cf) => case Some(cf) =>
val res = (cf.className, idToSafeJVMName(fd.id), sig) val res = (cf.className, idToSafeJVMName(fd.id), sig)
funDefInfo += fd -> res funDefInfo += fd -> res
...@@ -232,7 +240,7 @@ class CompilationUnit(val ctx: LeonContext, ...@@ -232,7 +240,7 @@ class CompilationUnit(val ctx: LeonContext,
val exprToCompile = purescala.TreeOps.matchToIfThenElse(e) val exprToCompile = purescala.TreeOps.matchToIfThenElse(e)
mkExpr(e, ch)(Locals(newMapping)) mkExpr(e, ch)(Locals(newMapping, true))
e.getType match { e.getType match {
case Int32Type | BooleanType => case Int32Type | BooleanType =>
...@@ -254,64 +262,105 @@ class CompilationUnit(val ctx: LeonContext, ...@@ -254,64 +262,105 @@ class CompilationUnit(val ctx: LeonContext,
def compileModule(module: ModuleDef) { def compileModule(module: ModuleDef) {
val cf = classes(module) val cf = classes(module)
cf.addDefaultConstructor
cf.setFlags(( cf.setFlags((
CLASS_ACC_SUPER | CLASS_ACC_SUPER |
CLASS_ACC_PUBLIC | CLASS_ACC_PUBLIC |
CLASS_ACC_FINAL CLASS_ACC_FINAL
).asInstanceOf[U2]) ).asInstanceOf[U2])
for(funDef <- module.definedFunctions; /*if (false) {
(_,mn,_) <- leonFunDefToJVMInfo(funDef)) { // currently we do not handle object fields
// this treats all fields as functions
val paramsTypes = funDef.params.map(a => typeToJVM(a.tpe)) for (fun <- module.definedFunctions) {
compileFunDef(fun, module)
val realParams = if (params.requireMonitor) {
("L" + MonitorClass + ";") +: paramsTypes
} else {
paramsTypes
} }
} else {*/
val m = cf.addMethod(
typeToJVM(funDef.returnType), val (fields, functions) = module.definedFunctions partition { _.canBeField }
mn, val (strictFields, lazyFields) = fields partition { _.canBeStrictField }
realParams : _*
) // Compile methods
m.setFlags(( for (function <- functions) {
METHOD_ACC_PUBLIC | compileFunDef(function,module)
METHOD_ACC_FINAL | }
METHOD_ACC_STATIC
// Compile lazy fields
for (lzy <- lazyFields) {
compileLazyField(lzy, module)
}
// Compile strict fields
for (field <- strictFields) {
compileStrictField(field, module)
}
// Constructor
cf.addDefaultConstructor
val cName = defToJVMName(module)
// Add class initializer method
locally{
val mh = cf.addMethod("V", "<clinit>")
mh.setFlags((
METHOD_ACC_STATIC |
METHOD_ACC_PUBLIC
).asInstanceOf[U2]) ).asInstanceOf[U2])
compileFunDef(funDef, m.codeHandler) val ch = mh.codeHandler
/*
* FIXME :
* Dirty hack to make this compatible with monitoring of method invocations.
* Because we don't have access to the monitor object here, we initialize a new one
* that will get lost when this method returns, so we can't hope to count
* method invocations here :(
*/
ch << New(MonitorClass) << DUP
ch << Ldc(Int.MaxValue) // Allow "infinite" method calls
ch << InvokeSpecial(MonitorClass, cafebabe.Defaults.constructorName, "(I)V")
ch << AStore(ch.getFreshVar) // position 0
for (lzy <- lazyFields) { initLazyField(ch, cName, lzy, true)}
for (field <- strictFields) { initStrictField(ch, cName , field, true)}
ch << RETURN
ch.freeze
} }
}
}
/** Traverses the program to find all definitions, and stores those in global variables */
def init() { def init() {
// First define all classes // First define all classes/ methods/ functions
for (m <- program.modules) { for (m <- program.modules) {
for ((parent, children) <- m.algebraicDataTypes) { for ( (parent, children) <- m.algebraicDataTypes;
defineClass(parent) cls <- Seq(parent) ++ children) {
defineClass(cls)
for (c <- children) { for (meth <- cls.methods) {
defineClass(c) defToModuleOrClass += meth -> cls
} }
} }
for(single <- m.singleCaseClasses) { for ( single <- m.singleCaseClasses ) {
defineClass(single) defineClass(single)
for (meth <- single.methods) {
defToModuleOrClass += meth -> single
}
}
for(funDef <- m.definedFunctions) {
defToModuleOrClass += funDef -> m
} }
defineClass(m) defineClass(m)
} }
} }
/** Compiles the program. Uses information provided by $init */
def compile() { def compile() {
// Compile everything // Compile everything
for (m <- program.modules) { for (m <- program.modules) {
for ((parent, children) <- m.algebraicDataTypes) { for ((parent, children) <- m.algebraicDataTypes) {
compileAbstractClassDef(parent) compileAbstractClassDef(parent)
...@@ -324,9 +373,7 @@ class CompilationUnit(val ctx: LeonContext, ...@@ -324,9 +373,7 @@ class CompilationUnit(val ctx: LeonContext,
compileCaseClassDef(single) compileCaseClassDef(single)
} }
for(funDef <- m.definedFunctions) {
defToModule += funDef -> m
}
} }
for (m <- program.modules) { for (m <- program.modules) {
......
package leon.test.codegen
import leon._
import leon.codegen._
import leon.purescala.Definitions._
import leon.purescala.Trees._
import leon.evaluators.{CodeGenEvaluator,EvaluationResults}
import EvaluationResults._
import java.io._
case class TestCase(
name : String,
content : String,
expected : Expr,
args : Seq[Expr] = Seq(),
functionToTest : String = "test"
)
class CodeGenTests extends test.LeonTestSuite {
val catchAll = true
val pipeline =
utils.TemporaryInputPhase andThen
frontends.scalac.ExtractionPhase andThen
utils.ScopingPhase andThen
purescala.MethodLifting andThen
utils.TypingPhase andThen
purescala.CompleteAbstractDefinitions andThen
purescala.RestoreMethods
def compileTestFun(p : Program, toTest : String, ctx : LeonContext, requireMonitor : Boolean, doInstrument : Boolean) : ( Seq[Expr] => EvaluationResults.Result) = {
// We want to produce code that checks contracts
val evaluator = new CodeGenEvaluator(ctx, p, CodeGenParams(
maxFunctionInvocations = if (requireMonitor) 1000 else -1, // Monitor calls and abort execution if more than X calls
checkContracts = true, // Generate calls that checks pre/postconditions
doInstrument = doInstrument // Instrument reads to case classes (mainly for vanuatoo)
))
val testFun = p.definedFunctions.find(_.id.name == toTest).getOrElse {
ctx.reporter.fatalError("Test function not defined!")
}
val params = testFun.params map { _.id }
val body = testFun.body.get
// Will apply test a number of times with the help of compileRec
evaluator.compile(body, params).getOrElse {
ctx.reporter.fatalError("Failed to compile test function!")
}
}
private def testCodeGen(prog : TestCase, requireMonitor : Boolean, doInstrument : Boolean) { test(prog.name) {
import prog._
val settings = testContext.settings.copy(injectLibrary = false)
val ctx = testContext.copy(
// We want a reporter that actually prints some output
reporter = new DefaultReporter(settings),
settings = settings
)
val ast = pipeline.run(ctx)( (content, List()) )
//ctx.reporter.info(purescala.ScalaPrinter(ast))
val compiled = compileTestFun(ast, functionToTest, ctx, requireMonitor, doInstrument)
try { compiled(args) match {
case Successful(res) if res == expected =>
// Success
case RuntimeError(_) | EvaluatorError(_) if expected.isInstanceOf[Error] =>
// Success
case Successful(res) =>
ctx.reporter.fatalError(s"""
Program $name produced wrong output.
Output was ${res.toString}
Expected was ${expected.toString}
""".stripMargin)
case RuntimeError(mes) =>
ctx.reporter.fatalError(s"Program $name threw runtime error with message $mes")
case EvaluatorError(res) =>
ctx.reporter.fatalError(s"Evaluator failed for program $name with message $res")
}} catch {
// Currently, this is what we would like to catch and still succeed, but there might be more
case _ : LeonFatalError | _ : StackOverflowError if expected.isInstanceOf[Error] =>
// Success
case th : Throwable =>
if (catchAll) {
// This is to be able to continue testing after an error
ctx.reporter.fatalError(s"Program $name failed\n${th.printStackTrace()}")// with message ${th.getMessage()}")
} else { throw th }
}
}}
val programs = Seq(
TestCase("simple", """
object simple {
abstract class Abs
case class Conc(x : Int) extends Abs
def test = {
val c = Conc(1)
c.x
}
}""",
IntLiteral(1)
),
TestCase("eager", """
object eager {
abstract class Abs() {
val foo = 42
}
case class Conc(x : Int) extends Abs()
def foo = {
val c = Conc(1)
c.foo + c.x
}
def test = foo
}""",
IntLiteral(43)
),
TestCase("this", """
object thiss {
case class Bar() {
def boo = this
def toRet = 42
}
def test = Bar().boo.toRet
}
""",
IntLiteral(42)
),
TestCase("oldStuff", """
object oldStuff {
def test = 1
case class Bar() {
def boo = 2
}
}""",
IntLiteral(1)
),
TestCase("methSimple", """
object methSimple {
sealed abstract class Ab {
def f2(x : Int) = x + 5
}
case class Con() extends Ab { }
def test = Con().f2(5)
}""",
IntLiteral(10)
),
TestCase("methods", """
object methods {
def f1 = 4
sealed abstract class Ab {
def f2(x : Int) = Cs().f3(1,2) + f1 + x + 5
}
case class Con() extends Ab {}
case class Cs() {
def f3(x : Int, y : Int) = x + y
}
def test = Con().f2(3)
}""",
IntLiteral(15)
),
TestCase("lazy", """
object lazyFields {
def foo = 1
sealed abstract class Ab {
lazy val x : Int = this match {
case Conc(t) => t + 1
case Conc2(t) => t+2
}
}
case class Conc(t : Int) extends Ab { }
case class Conc2(t : Int) extends Ab { }
def test = foo + Conc(5).x + Conc2(6).x
}
""",
IntLiteral(1 + 5 + 1 + 6 + 2)
),
TestCase("modules", """
object modules {
def foo = 1
val bar = 2
lazy val baz = 0
def test = foo + bar + baz
}
""",
IntLiteral(1 + 2 + 0)
),
TestCase("lazyISLazy" , """
object lazyISLazy {
abstract class Ab { lazy val x : Int = foo; def foo : Int = foo }
case class Conc() extends Ab { }
def test = { val willNotLoop = Conc(); 42 }
}""",
IntLiteral(42)
),
TestCase("ListWithSize" , """
object list {
abstract class List[T] {
val length : Int = this match {
case Nil() => 0
case Cons (_, xs ) => 1 + xs.length
}
}
case class Cons[T](hd : T, tl : List[T]) extends List[T]
case class Nil[T]() extends List[T]
val l = Cons(1, Cons(2, Cons(3, Nil())))
def test = l.length + Nil().length
}""",
IntLiteral(3 )
),
TestCase("ListWithSumMono" , """
object ListWithSumMono {
abstract class List
case class Cons(hd : Int, tl : List) extends List
case class Nil() extends List
def sum (l : List) : Int = l match {
case Nil() => 0
case Cons(x, xs) => x + sum(xs)
}
val l = Cons(1, Cons(2, Cons(3, Nil())))
def test = sum(l)
}""",
IntLiteral(1 + 2 + 3)
),
TestCase("poly" , """
object poly {
case class Poly[T](poly : T)
def ex = Poly(42)
def test = ex.poly
}""",
IntLiteral(42)
),
TestCase("ListHead" , """
object ListHead {
abstract class List[T]
case class Cons[T](hd : T, tl : List[T]) extends List[T]
case class Nil[T]() extends List[T]
def l = Cons(1, Cons(2, Cons(3, Nil())))
def test = l.hd
}""",
IntLiteral(1)
),
TestCase("ListWithSum" , """
object ListWithSum {
abstract class List[T]
case class Cons[T](hd : T, tl : List[T]) extends List[T]
case class Nil[T]() extends List[T]
def sum (l : List[Int]) : Int = l match {
case Nil() => 0
case Cons(x, xs) => x + sum(xs)
}
val l = Cons(1, Cons(2, Cons(3, Nil())))
def test = sum(l)
}""",
IntLiteral(1 + 2 + 3)
),
// This one loops!
TestCase("lazyLoops" , """
object lazyLoops {
abstract class Ab { lazy val x : Int = foo; def foo : Int = foo }
case class Conc() extends Ab { }
def test = Conc().x
}""",
Error("Looping")
),
TestCase("Lazier" , """
import leon.lang._
object Lazier {
abstract class List[T] {
lazy val tail = this match {
case Nil() => error[List[T]]("Nil.tail")
case Cons(_, tl) => tl
}
}
case class Cons[T](hd : T, tl : List[T]) extends List[T]
case class Nil[T]() extends List[T]
def sum (l : List[Int]) : Int = l match {
case Nil() => 0
case c : Cons[Int] => c.hd + sum(c.tail)
}
val l = Cons(1, Cons(2, Cons(3, Nil())))
def test = sum(l)
}""",
IntLiteral(1 + 2 + 3)
)
)
for ( prog <- programs ;
requireMonitor <- Seq(false ,true );
doInstrument <- Seq(false,true )
) {
testCodeGen(
prog.copy(name = prog.name + (if (requireMonitor)"_M_" else "" ) + (if (doInstrument)"_I_" else "" )),
requireMonitor, doInstrument
)}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment