/* Copyright 2009-2015 EPFL, Lausanne */

package leon
package codegen

import purescala.Common._
import purescala.Definitions._
import purescala.Expressions._
import purescala.ExprOps._
import purescala.Types._
import purescala.TypeOps.typeParamsOf
import purescala.Extractors._
import purescala.Constructors._

import utils.UniqueCounter
import runtime.{Monitor, StdMonitor}

import cafebabe._
import cafebabe.AbstractByteCodes._
import cafebabe.ByteCodes._
import cafebabe.ClassFileTypes._
import cafebabe.Flags._

import scala.collection.JavaConverters._

import java.lang.reflect.Constructor

import synthesis.Problem

/** Compiles a Leon [[Program]] to JVM class files and converts values between
  * Leon expressions and JVM objects.
  *
  * Constructing an instance runs [[init]] followed by [[compile]], so every class
  * and module of `program` is defined, compiled and registered with [[loader]]
  * before the constructor returns.
  *
  * @param ctx     the Leon context (used for reporting and pretty-printing)
  * @param program the program whose definitions are compiled
  * @param params  code generation parameters
  */
class CompilationUnit(val ctx: LeonContext,
                      val program: Program,
                      val params: CodeGenParams = CodeGenParams.default) extends CodeGeneration {

  /** True iff the full body of some defined function contains a `Forall`. */
  protected[codegen] val requireQuantification = program.definedFunctions.exists { fd =>
    exists {
      case _: Forall => true
      case _ => false
    } (fd.fullBody)
  }

  /** Class loader with which the generated class files are registered. */
  val loader = new CafebabeClassLoader(classOf[CompilationUnit].getClassLoader)

  /** Class file generated for each class / module definition. */
  var classes = Map[Definition, ClassFile]()

  /** Maps each function or method to the module or class that defines it. */
  var defToModuleOrClass = Map[Definition, Definition]()

  /** Identifiers of the functions that have no body. */
  val abstractFunDefs = program.definedFunctions.filter(_.body.isEmpty).map(_.id).toSet

  /** Source of the ids handed out by all the `register*` methods below. */
  val runtimeCounter = new UniqueCounter[Unit]

  var runtimeTypeToIdMap = Map[TypeTree, Int]()
  var runtimeIdToTypeMap = Map[Int, TypeTree]()

  /** Returns the id associated with `tpe`, allocating a fresh one on first use. */
  def registerType(tpe: TypeTree): Int = runtimeTypeToIdMap.get(tpe) match {
    case Some(id) =>
      id

    case None =>
      val id = runtimeCounter.nextGlobal
      runtimeTypeToIdMap += tpe -> id
      runtimeIdToTypeMap += id -> tpe
      id
  }

  var runtimeProblemMap = Map[Int, (Seq[TypeParameter], Problem)]()

  /** Stores `p` together with its type parameters under a fresh id, which is returned. */
  def registerProblem(p: Problem, tps: Seq[TypeParameter]): Int = {
    val id = runtimeCounter.nextGlobal
    runtimeProblemMap += id -> ((tps, p))
    id
  }

  var runtimeForallMap = Map[Int, (Seq[TypeParameter], Forall)]()

  /** Stores `f` together with its type parameters under a fresh id, which is returned. */
  def registerForall(f: Forall, tps: Seq[TypeParameter]): Int = {
    val id = runtimeCounter.nextGlobal
    runtimeForallMap += id -> ((tps, f))
    id
  }

  var runtimeAbstractMap = Map[Int, FunDef]()

  /** Stores `fd` under a fresh id, which is returned. */
  def registerAbstractFD(fd: FunDef): Int = {
    val id = runtimeCounter.nextGlobal
    runtimeAbstractMap += id -> fd
    id
  }

  /** Creates the (still empty) class file for a class or module definition and
    * records it in [[classes]]. Any other kind of definition is an error.
    */
  def defineClass(df: Definition): Unit = {
    val cName = defToJVMName(df)

    val cf = df match {
      case cd: ClassDef =>
        val pName = cd.parent.map(parent => defToJVMName(parent.classDef))
        new ClassFile(cName, pName)

      case _: ModuleDef =>
        new ClassFile(cName, None)

      case _ =>
        sys.error("Unhandled definition type")
    }

    classes += df -> cf
  }

  /** Finds the Leon definition whose generated class file is named `name`. */
  def jvmClassToLeonClass(name: String): Option[Definition] = {
    classes.find(_._2.className == name).map(_._1)
  }

  /** Returns the JVM class name and constructor signature of `cd`, if `cd` has been defined.
    *
    * The constructor takes the monitor followed by one argument per field.
    */
  def leonClassToJVMInfo(cd: ClassDef): Option[(String, String)] = {
    classes.get(cd).map { cf =>
      val sig = "(L" + MonitorClass + ";" + cd.fields.map(f => typeToJVM(f.getType)).mkString("") + ")V"
      (cf.className, sig)
    }
  }

  // Cache of (className, methodName, methodSignature) per function
  private[this] var funDefInfo = Map[FunDef, (String, String, String)]()

  /**
   * Returns (cn, mn, sig) where
   *  cn is the module name
   *  mn is the safe method name
   *  sig is the method signature
   *
   * The result is cached; None is returned when the class or module defining
   * `fd` is unknown.
   */
  def leonFunDefToJVMInfo(fd: FunDef): Option[(String, String, String)] = {
    funDefInfo.get(fd).orElse {
      // Monitor first, then the type-parameter id array (generic functions only),
      // then the declared parameters.
      val sig = "(L" + MonitorClass + ";" +
        (if (fd.tparams.nonEmpty) "[I" else "") +
        fd.params.map(a => typeToJVM(a.getType)).mkString("") + ")" +
        typeToJVM(fd.returnType)

      defToModuleOrClass.get(fd).flatMap(m => classes.get(m)) match {
        case Some(cf) =>
          val res = (cf.className, idToSafeJVMName(fd.id), sig)
          funDefInfo += fd -> res
          Some(res)

        case None =>
          None
      }
    }
  }

  // Cache of the Java constructor corresponding to each case class
  private[this] var ccdConstructors = Map[CaseClassDef, Constructor[_]]()

  /** Loads the generated class of `ccd` and returns its constructor (cached). */
  private[this] def caseClassConstructor(ccd: CaseClassDef): Option[Constructor[_]] = {
    ccdConstructors.get(ccd).orElse {
      classes.get(ccd) match {
        case Some(cf) =>
          val klass = loader.loadClass(cf.className)
          // This is a hack: we pick the constructor with the most arguments.
          val conss = klass.getConstructors.sortBy(_.getParameterTypes.length)
          assert(conss.nonEmpty)
          val cons = conss.last

          ccdConstructors += ccd -> cons
          Some(cons)

        case None =>
          None
      }
    }
  }

  /** Constructor of the runtime `Tuple` class (the one with the most arguments). */
  private[this] lazy val tupleConstructor: Constructor[_] = {
    val tc = loader.loadClass("leon.codegen.runtime.Tuple")
    val conss = tc.getConstructors.sortBy(_.getParameterTypes.length)
    assert(conss.nonEmpty)
    conss.last
  }

  /** Builds a monitor from the model entries that belong to body-less functions,
    * plus the domains of the model when it is a Henkin model.
    */
  def getMonitor(model: solvers.Model, maxInvocations: Int): Monitor = {
    val bodies = model.toSeq.filter { case (id, _) => abstractFunDefs(id) }.toMap

    val domains = model match {
      case hm: solvers.HenkinModel => Some(hm.doms)
      case _ => None
    }

    new StdMonitor(this, maxInvocations, bodies, domains)
  }

  /** Translates Leon values (not generic expressions) to JVM compatible objects.
    *
    * Currently, this method is only used to prepare arguments to reflective calls.
    * This means it is safe to return AnyRef (as opposed to primitive types), because
    * reflection needs this anyway.
    *
    * Throws a `CompilationException` for expressions that are not supported values.
    */
  def valueToJVM(e: Expr)(implicit monitor: Monitor): AnyRef = e match {
    case IntLiteral(v) =>
      new java.lang.Integer(v)

    case BooleanLiteral(v) =>
      new java.lang.Boolean(v)

    case UnitLiteral() =>
      new java.lang.Boolean(true)

    case CharLiteral(c) =>
      new Character(c)

    case InfiniteIntegerLiteral(v) =>
      new runtime.BigInt(v.toString)

    case FractionalLiteral(n, d) =>
      new runtime.Rational(n.toString, d.toString)

    case StringLiteral(v) =>
      new java.lang.String(v)

    case GenericValue(_, _) =>
      // Generic values are passed through as-is
      e

    case Tuple(elems) =>
      // The element array is passed as the single constructor argument
      tupleConstructor.newInstance(elems.map(valueToJVM).toArray).asInstanceOf[AnyRef]

    case CaseClass(cct, args) =>
      caseClassConstructor(cct.classDef) match {
        case Some(cons) =>
          // Generated constructors take the monitor as first argument
          val jvmArgs = monitor +: args.map(valueToJVM)
          cons.newInstance(jvmArgs.toArray : _*).asInstanceOf[AnyRef]

        case None =>
          ctx.reporter.fatalError("Case class constructor not found?!?")
      }

    // For now, we only treat boolean arrays separately.
    // We have a use for these, mind you.
    //case f @ FiniteArray(exprs) if f.getType == ArrayType(BooleanType) =>
    //  exprs.map(e => exprToJVM(e).asInstanceOf[java.lang.Boolean].booleanValue).toArray

    case FiniteSet(els, _) =>
      val jvmSet = new leon.codegen.runtime.Set()
      for (el <- els) {
        jvmSet.add(valueToJVM(el))
      }
      jvmSet

    case FiniteMap(els, _, _) =>
      val jvmMap = new leon.codegen.runtime.Map()
      for ((k, v) <- els) {
        jvmMap.add(valueToJVM(k), valueToJVM(v))
      }
      jvmMap

    case FiniteLambda(mapping, dflt, _) =>
      val jvmLambda = new leon.codegen.runtime.FiniteLambda(valueToJVM(dflt))

      for ((ks, v) <- mapping) {
        // Force tuple even with 1/0 elems.
        val kJvm = tupleConstructor.newInstance(ks.map(valueToJVM).toArray).asInstanceOf[leon.codegen.runtime.Tuple]
        val vJvm = valueToJVM(v)
        jvmLambda.add(kJvm, vJvm)
      }
      jvmLambda

    case l @ Lambda(_, _) =>
      val (afName, closures, _, _) = compileLambda(l)

      // Only the monitor and the type-parameter id array are accepted as closures here
      val closureArgs = closures.map { case (id, _) =>
        if (id == monitorID) monitor
        else if (id == tpsID) typeParamsOf(l).toSeq.sortBy(_.id.uniqueName).map(registerType).toArray
        else throw CompilationException(s"Unexpected closure $id in Lambda compilation")
      }

      val lc = loader.loadClass(afName)
      val conss = lc.getConstructors.sortBy(_.getParameterTypes.length)
      assert(conss.nonEmpty)
      val lambdaConstructor = conss.last
      lambdaConstructor.newInstance(closureArgs.toArray : _*).asInstanceOf[AnyRef]

    case f @ IsTyped(FiniteArray(elems, default, IntLiteral(length)), ArrayType(underlying)) =>
      if (length < 0) {
        throw LeonFatalError(
          s"Whoops! Array ${f.asString(ctx)} has length $length. " +
          default.map { df => s"default: ${df.asString(ctx)}" }.getOrElse("")
        )
      }

      import scala.reflect.ClassTag

      // Allocates an array of `length` cells, fills every cell with the converted
      // default (when there is one), then overwrites the explicitly given cells.
      def allocArray[A: ClassTag](f: Expr => A): Array[A] = {
        val arr = new Array[A](length)
        for {
          df <- default.toSeq
          v = f(df)
          i <- 0 until length
        } {
          arr(i) = v
        }
        for ((ind, v) <- elems) {
          arr(ind) = f(v)
        }
        arr
      }

      // The first four element types get unboxed arrays; everything else goes
      // through valueToJVM.
      underlying match {
        case Int32Type =>
          allocArray { case IntLiteral(v) => v }
        case BooleanType =>
          allocArray { case BooleanLiteral(b) => b }
        case UnitType =>
          allocArray { case UnitLiteral() => true }
        case CharType =>
          allocArray { case CharLiteral(c) => c }
        case _ =>
          allocArray(valueToJVM)
      }

    case _ =>
      throw CompilationException(s"Unexpected expression $e in valueToJVM")
  }

  /** Translates JVM objects back to Leon values of the appropriate type.
    *
    * Throws a `CompilationException` when `e` cannot be read back at type `tpe`.
    */
  def jvmToValue(e: AnyRef, tpe: TypeTree): Expr = (e, tpe) match {
    case (i: Integer, Int32Type) =>
      IntLiteral(i.toInt)

    case (c: runtime.BigInt, IntegerType) =>
      InfiniteIntegerLiteral(BigInt(c.underlying))

    case (c: runtime.Rational, RealType) =>
      val num = BigInt(c.numerator())
      val denom = BigInt(c.denominator())
      FractionalLiteral(num, denom)

    case (b: java.lang.Boolean, BooleanType) =>
      BooleanLiteral(b.booleanValue)

    case (c: java.lang.Character, CharType) =>
      CharLiteral(c.toChar)

    case (c: java.lang.String, StringType) =>
      StringLiteral(c)

    case (cc: runtime.CaseClass, ct: ClassType) =>
      val fields = cc.productElements()

      // identify case class type of ct
      val cct = ct match {
        case caseType: CaseClassType =>
          caseType

        case _ =>
          // ct is not a case class type: recover the concrete case class from the JVM class name
          jvmClassToLeonClass(cc.getClass.getName) match {
            case Some(ccd: CaseClassDef) =>
              CaseClassType(ccd, ct.tps)
            case _ =>
              throw CompilationException("Unable to identify class "+cc.getClass.getName+" to descendant of "+ct)
          }
      }

      CaseClass(cct, (fields zip cct.fieldsTypes).map { case (field, fieldTpe) => jvmToValue(field, fieldTpe) })

    case (tpl: runtime.Tuple, tpe) =>
      val stpe = unwrapTupleType(tpe, tpl.getArity)
      val elems = stpe.zipWithIndex.map { case (elemTpe, i) =>
        jvmToValue(tpl.get(i), elemTpe)
      }
      tupleWrap(elems)

    case (gv @ GenericValue(gtp, id), tp: TypeParameter) =>
      if (gtp == tp) gv
      else GenericValue(tp, id).copiedFrom(gv)

    case (set: runtime.Set, SetType(b)) =>
      FiniteSet(set.getElements.asScala.map(jvmToValue(_, b)).toSet, b)

    case (map: runtime.Map, MapType(from, to)) =>
      val pairs = map.getElements.asScala.map { entry =>
        val k = jvmToValue(entry.getKey, from)
        val v = jvmToValue(entry.getValue, to)
        (k, v)
      }.toMap
      FiniteMap(pairs, from, to)

    case (lambda: runtime.FiniteLambda, ft @ FunctionType(from, to)) =>
      val mapping = lambda.mapping.asScala.map { entry =>
        // Keys are read back as a tuple over `from`, then unwrapped to the argument list
        val k = jvmToValue(entry._1, tupleTypeWrap(from))
        val v = jvmToValue(entry._2, to)
        unwrapTuple(k, from.size) -> v
      }
      val dflt = jvmToValue(lambda.dflt, to)
      FiniteLambda(mapping.toSeq, dflt, ft)

    case (lambda: runtime.Lambda, _: FunctionType) =>
      val cls = lambda.getClass

      val l = classToLambda(cls.getName)
      val closures = purescala.ExprOps.variablesOf(l).toSeq.sortBy(_.uniqueName)
      // Each free variable of the lambda is read from the public field named after it
      val closureVals = closures.map { id =>
        val fieldVal = cls.getField(id.uniqueName).get(lambda)
        jvmToValue(fieldVal, id.getType)
      }

      purescala.ExprOps.replaceFromIDs((closures zip closureVals).toMap, l)

    case (_, UnitType) =>
      UnitLiteral()

    case (ar: Array[_], ArrayType(base)) =>
      if (ar.length == 0) {
        EmptyArray(base)
      } else {
        val elems = for ((elem: AnyRef, i) <- ar.zipWithIndex) yield {
          i -> jvmToValue(elem, base)
        }

        NonemptyArray(elems.toMap, None)
      }

    case _ =>
      throw CompilationException("Unsupported return value : " + e.getClass +" while expecting "+tpe)
  }

  /** Compiles `e` to a fresh class with a single static `eval` method.
    *
    * `eval` takes the monitor followed by one parameter per identifier in `args`,
    * in that order. The class is registered with [[loader]].
    *
    * Throws `Unsupported` when `e` is untyped.
    */
  def compileExpression(e: Expr, args: Seq[Identifier])(implicit ctx: LeonContext): CompiledExpression = {
    if (e.getType == Untyped) {
      throw new Unsupported(e, "Cannot compile untyped expression.")
    }

    val id = exprCounter.nextGlobal

    val cName = "Leon$CodeGen$Expr$"+id

    val cf = new ClassFile(cName, None)
    cf.setFlags((
      CLASS_ACC_PUBLIC |
      CLASS_ACC_FINAL
    ).asInstanceOf[U2])

    cf.addDefaultConstructor

    val argsTypes = args.map(a => typeToJVM(a.getType))

    val realArgs = ("L" + MonitorClass + ";") +: argsTypes

    val m = cf.addMethod(
      typeToJVM(e.getType),
      "eval",
      realArgs : _*
    )

    m.setFlags((
      METHOD_ACC_PUBLIC |
      METHOD_ACC_FINAL |
      METHOD_ACC_STATIC
    ).asInstanceOf[U2])

    val ch = m.codeHandler

    // Slot 0 holds the monitor; the i-th argument is mapped to slot i + 1
    val newMapping = Map(monitorID -> 0) ++ args.zipWithIndex.toMap.mapValues(_ + 1)

    mkExpr(e, ch)(NoLocals.withVars(newMapping))

    e.getType match {
      case ValueType() =>
        ch << IRETURN
      case _ =>
        ch << ARETURN
    }

    ch.freeze

    loader.register(cf)

    new CompiledExpression(this, cf, e, args)
  }

  /** Compiles the functions and fields of `module` into its class file and adds
    * a default constructor and a static initializer for the fields.
    */
  def compileModule(module: ModuleDef): Unit = {
    val cf = classes(module)

    cf.setFlags((
      CLASS_ACC_SUPER |
      CLASS_ACC_PUBLIC |
      CLASS_ACC_FINAL
    ).asInstanceOf[U2])

    val (fields, functions) = module.definedFunctions partition { _.canBeField }
    val (strictFields, lazyFields) = fields partition { _.canBeStrictField }

    // Compile methods
    for (function <- functions) {
      compileFunDef(function, module)
    }

    // Compile lazy fields
    for (lzy <- lazyFields) {
      compileLazyField(lzy, module)
    }

    // Compile strict fields
    for (field <- strictFields) {
      compileStrictField(field, module)
    }

    // Constructor
    cf.addDefaultConstructor

    val cName = defToJVMName(module)

    // Add class initializer method
    locally {
      val mh = cf.addMethod("V", "<clinit>")
      mh.setFlags((
        METHOD_ACC_STATIC |
        METHOD_ACC_PUBLIC
      ).asInstanceOf[U2])

      val ch = mh.codeHandler
      /*
       * FIXME :
       * Dirty hack to make this compatible with monitoring of method invocations.
       * Because we don't have access to the monitor object here, we initialize a new one
       * that will get lost when this method returns, so we can't hope to count
       * method invocations here :(
       */
      val locals = NoLocals.withVar(monitorID -> ch.getFreshVar)
      ch << New(NoMonitorClass) << DUP
      ch << InvokeSpecial(NoMonitorClass, cafebabe.Defaults.constructorName, "()V")
      ch << AStore(locals.varToLocal(monitorID).get) // position 0

      for (lzy <- lazyFields) {
        initLazyField(ch, cName, lzy, isStatic = true)(locals)
      }
      for (field <- strictFields) {
        initStrictField(ch, cName, field, isStatic = true)(locals)
      }
      ch << RETURN
      ch.freeze
    }
  }

  /** Traverses the program to find all definitions, and stores those in global variables */
  def init(): Unit = {
    // First define all classes/ methods/ functions
    for (u <- program.units) {

      for {
        ch <- u.classHierarchies
        cls <- ch
      } {
        defineClass(cls)
        for (meth <- cls.methods) {
          defToModuleOrClass += meth -> cls
        }
      }

      for (m <- u.modules) {
        defineClass(m)
        for (funDef <- m.definedFunctions) {
          defToModuleOrClass += funDef -> m
        }
      }
    }
  }

  /** Compiles the program.
    *
    * Uses information provided by [[init]].
    */
  def compile(): Unit = {
    // Compile everything
    for (u <- program.units) {

      for {
        ch <- u.classHierarchies
        c <- ch
      } c match {
        case acd: AbstractClassDef =>
          compileAbstractClassDef(acd)
        case ccd: CaseClassDef =>
          compileCaseClassDef(ccd)
      }

      for (m <- u.modules) compileModule(m)
    }

    classes.values.foreach(loader.register)
  }

  /** Writes every generated class file to `prefix + className + ".class"`. */
  def writeClassFiles(prefix: String): Unit = {
    for ((_, cl) <- classes) {
      cl.writeToFile(prefix + cl.className + ".class")
    }
  }

  // Define, then compile, everything as part of construction
  init()
  compile()
}

/** Counter naming the classes generated by `compileExpression`. */
private [codegen] object exprCounter extends UniqueCounter[Unit]

private [codegen] object lambdaCounter extends UniqueCounter[Unit]

private [codegen] object forallCounter extends UniqueCounter[Unit]