Skip to content
Snippets Groups Projects
CompilationUnit.scala 14.02 KiB
/* Copyright 2009-2015 EPFL, Lausanne */

package leon
package codegen

import purescala.Common._
import purescala.Definitions._
import purescala.Expressions._
import purescala.ExprOps._
import purescala.Types._
import purescala.Extractors._
import purescala.Constructors._
import codegen.runtime.LeonCodeGenRuntimeMonitor
import codegen.runtime.LeonCodeGenRuntimeHenkinMonitor
import utils.UniqueCounter

import cafebabe._
import cafebabe.AbstractByteCodes._
import cafebabe.ByteCodes._
import cafebabe.ClassFileTypes._
import cafebabe.Flags._

import scala.collection.JavaConverters._

import java.lang.reflect.Constructor


/** Compiles a Leon `Program` to JVM bytecode and converts values between
  * their Leon and JVM representations.
  *
  * Construction has side effects: the last statements of the class body call
  * [[init]] and then [[compile]]. As a result, every class hierarchy and module
  * of `program` is translated and registered with [[loader]] as soon as an
  * instance exists.
  *
  * @param ctx     Leon context, used here for error reporting
  * @param program the program whose classes and modules are compiled
  * @param params  code-generation options (e.g. `requireMonitor`)
  */
class CompilationUnit(val ctx: LeonContext,
                      val program: Program,
                      val params: CodeGenParams = CodeGenParams.default) extends CodeGeneration {

  /** True if the full body of some function in the program contains a `Forall`. */
  protected[codegen] val requireQuantification = program.definedFunctions.exists { fd =>
    exists { case _: Forall => true case _ => false } (fd.fullBody)
  }

  /** True if generated methods/constructors take a leading monitor argument.
    *
    * This is either requested explicitly through `params` or forced by
    * quantifiers in the program.
    */
  protected[codegen] val requireMonitor = params.requireMonitor || requireQuantification

  /** Class loader into which all generated class files are registered. */
  val loader = new CafebabeClassLoader(classOf[CompilationUnit].getClassLoader)

  /** Generated class file for each Leon class or module definition (filled by [[defineClass]]). */
  var classes = Map[Definition, ClassFile]()

  /** Owning class or module of each method / module function (filled by [[init]]). */
  var defToModuleOrClass = Map[Definition, Definition]()

  /** Creates an empty class file for `df` and records it in [[classes]].
    *
    * A `ClassDef` extends the JVM class of its Leon parent, if it has one.
    * A `ModuleDef` becomes a class with no explicit parent.
    *
    * @throws RuntimeException (via `sys.error`) for any other kind of definition
    */
  def defineClass(df: Definition): Unit = {
    val cName = defToJVMName(df)

    val cf = df match {
      case cd: ClassDef =>
        val pName = cd.parent.map(parent => defToJVMName(parent.classDef))
        new ClassFile(cName, pName)

      case _: ModuleDef =>
        new ClassFile(cName, None)

      case _ =>
        sys.error("Unhandled definition type")
    }

    classes += df -> cf
  }

  /** Finds the Leon definition whose generated class has JVM name `name`.
    *
    * This is a linear scan over [[classes]].
    */
  def jvmClassToLeonClass(name: String): Option[Definition] = {
    classes.find(_._2.className == name).map(_._1)
  }

  /** Returns the JVM class name and constructor signature generated for `cd`.
    *
    * The signature lists the monitor first (when [[requireMonitor]]) followed by
    * the JVM types of the fields. Returns `None` if `cd` was never defined.
    */
  def leonClassToJVMInfo(cd: ClassDef): Option[(String, String)] = {
    classes.get(cd).map { cf =>
      val monitorType = if (requireMonitor) "L"+MonitorClass+";" else ""
      val sig = "(" + monitorType + cd.fields.map(f => typeToJVM(f.getType)).mkString("") + ")V"
      (cf.className, sig)
    }
  }

  // Cache of successful lookups: className, methodName, methodSignature
  private[this] var funDefInfo = Map[FunDef, (String, String, String)]()

  /**
   * Returns (cn, mn, sig) where
   *  cn is the module name
   *  mn is the safe method name
   *  sig is the method signature
   *
   * Returns `None` if `fd` has no owning class/module with a generated class file.
   * Only successful results are cached.
   */
  def leonFunDefToJVMInfo(fd: FunDef): Option[(String, String, String)] = {
    funDefInfo.get(fd).orElse {
      val monitorType = if (requireMonitor) "L"+MonitorClass+";" else ""

      val sig = "(" + monitorType + fd.params.map(a => typeToJVM(a.getType)).mkString("") + ")" + typeToJVM(fd.returnType)

      defToModuleOrClass.get(fd).flatMap(m => classes.get(m)).map { cf =>
        val res = (cf.className, idToSafeJVMName(fd.id), sig)
        funDefInfo += fd -> res
        res
      }
    }
  }

  // Cache of the Java constructor corresponding to each case class
  private[this] var ccdConstructors = Map[CaseClassDef, Constructor[_]]()

  /** Loads the generated class for `ccd` and returns (and caches) its constructor. */
  private[this] def caseClassConstructor(ccd: CaseClassDef): Option[Constructor[_]] = {
    ccdConstructors.get(ccd).orElse {
      classes.get(ccd).map { cf =>
        val klass = loader.loadClass(cf.className)
        // This is a hack: we pick the constructor with the most arguments.
        val conss = klass.getConstructors.sortBy(_.getParameterTypes.length)
        assert(conss.nonEmpty)
        val cons = conss.last

        ccdConstructors += ccd -> cons
        cons
      }
    }
  }

  /** Constructor of the runtime `Tuple` class with the most parameters (loaded lazily). */
  private[this] lazy val tupleConstructor: Constructor[_] = {
    val tc = loader.loadClass("leon.codegen.runtime.Tuple")
    val conss = tc.getConstructors.sortBy(_.getParameterTypes.length)
    assert(conss.nonEmpty)
    conss.last
  }

  /** Builds the runtime monitor used when evaluating under `model`.
    *
    * For a `HenkinModel`, the monitor is filled with the JVM-converted argument
    * tuples of every `(tpe, domain)` in `hModel.domains`. For any other model, a
    * plain monitor is created.
    *
    * @param maxInvocations invocation budget passed to the monitor
    */
  def modelToJVM(model: solvers.Model, maxInvocations: Int): LeonCodeGenRuntimeMonitor = model match {
    case hModel: solvers.HenkinModel =>
      val lhm = new LeonCodeGenRuntimeHenkinMonitor(maxInvocations)
      for ((tpe, domain) <- hModel.domains; args <- domain) {
        val tpeId = typeId(tpe)
        // note here that it doesn't matter that `lhm` doesn't yet have its domains
        // filled since all values in `args` should be grounded
        val inputJvm = tupleConstructor.newInstance(args.map(valueToJVM(_)(lhm)).toArray).asInstanceOf[leon.codegen.runtime.Tuple]
        lhm.add(tpeId, inputJvm)
      }
      lhm
    case _ =>
      new LeonCodeGenRuntimeMonitor(maxInvocations)
  }

  /** Translates Leon values (not generic expressions) to JVM compatible objects.
    *
    * Currently, this method is only used to prepare arguments to reflective calls.
    * This means it is safe to return AnyRef (as opposed to primitive types), because
    * reflection needs this anyway.
    *
    * Unit is encoded as `java.lang.Boolean(true)`. Generic values are passed
    * through unchanged.
    *
    * @throws CompilationException if `e` is not a supported value form
    */
  def valueToJVM(e: Expr)(implicit monitor: LeonCodeGenRuntimeMonitor): AnyRef = e match {
    case IntLiteral(v) =>
      new java.lang.Integer(v)

    case BooleanLiteral(v) =>
      new java.lang.Boolean(v)

    case UnitLiteral() =>
      new java.lang.Boolean(true)

    case CharLiteral(c) =>
      new Character(c)

    case InfiniteIntegerLiteral(v) =>
      new runtime.BigInt(v.toString)

    case RealLiteral(v) =>
      new runtime.Real(v.toString)

    case _: GenericValue =>
      e

    case Tuple(elems) =>
      // The element array is passed as the single (Object[]) constructor argument.
      tupleConstructor.newInstance(elems.map(valueToJVM).toArray).asInstanceOf[AnyRef]

    case CaseClass(cct, args) =>
      caseClassConstructor(cct.classDef) match {
        case Some(cons) =>
          // Generated constructors take the monitor first when monitoring is on.
          val realArgs = if (requireMonitor) monitor +: args.map(valueToJVM) else args.map(valueToJVM)
          cons.newInstance(realArgs.toArray : _*).asInstanceOf[AnyRef]
        case None =>
          ctx.reporter.fatalError("Case class constructor not found?!?")
      }

    // For now, we only treat boolean arrays separately.
    // We have a use for these, mind you.
    //case f @ FiniteArray(exprs) if f.getType == ArrayType(BooleanType) =>
    //  exprs.map(e => exprToJVM(e).asInstanceOf[java.lang.Boolean].booleanValue).toArray

    case FiniteSet(els, _) =>
      val s = new leon.codegen.runtime.Set()
      for (el <- els) {
        s.add(valueToJVM(el))
      }
      s

    case FiniteMap(els, _, _) =>
      val m = new leon.codegen.runtime.Map()
      for ((k, v) <- els) {
        m.add(valueToJVM(k), valueToJVM(v))
      }
      m

    case PartialLambda(mapping, _) =>
      val l = new leon.codegen.runtime.PartialLambda()
      for ((ks, v) <- mapping) {
        // Force tuple even with 1/0 elems.
        val kJvm = tupleConstructor.newInstance(ks.map(valueToJVM).toArray).asInstanceOf[leon.codegen.runtime.Tuple]
        val vJvm = valueToJVM(v)
        l.add(kJvm, vJvm)
      }
      l

    case _ =>
      throw CompilationException(s"Unexpected expression $e in valueToJVM")

    // Just slightly overkill...
    //case _ =>
    //  compileExpression(e, Seq()).evalToJVM(Seq(),monitor)
  }

  /** Translates JVM objects back to Leon values of the appropriate type.
    *
    * @param e   the JVM object (may be `null`, which is only accepted for `UnitType`)
    * @param tpe the expected Leon type; it drives the decoding
    * @throws CompilationException if `(e, tpe)` is not a supported combination
    */
  def jvmToValue(e: AnyRef, tpe: TypeTree): Expr = (e, tpe) match {
    case (i: Integer, Int32Type) =>
      IntLiteral(i.toInt)

    case (c: runtime.BigInt, IntegerType) =>
      InfiniteIntegerLiteral(BigInt(c.underlying))

    case (c: runtime.Real, RealType) =>
      RealLiteral(BigDecimal(c.underlying))

    case (b: java.lang.Boolean, BooleanType) =>
      BooleanLiteral(b.booleanValue)

    case (c: java.lang.Character, CharType) =>
      CharLiteral(c.toChar)

    case (cc: runtime.CaseClass, ct: ClassType) =>
      val fields = cc.productElements()

      // identify case class type of ct
      val cct = ct match {
        case knownCct: CaseClassType =>
          knownCct

        case _ =>
          // Abstract expected type: recover the concrete case class from the JVM class.
          jvmClassToLeonClass(cc.getClass.getName) match {
            case Some(ccd: CaseClassDef) =>
              CaseClassType(ccd, ct.tps)
            case _ =>
              throw CompilationException("Unable to identify class "+cc.getClass.getName+" to descendant of "+ct)
          }
      }

      CaseClass(cct, (fields zip cct.fieldsTypes).map { case (fieldVal, fieldTpe) => jvmToValue(fieldVal, fieldTpe) })

    case (tpl: runtime.Tuple, _) =>
      val stpe = unwrapTupleType(tpe, tpl.getArity)
      val elems = stpe.zipWithIndex.map { case (elemTpe, i) =>
        jvmToValue(tpl.get(i), elemTpe)
      }
      tupleWrap(elems)

    case (gv @ GenericValue(gtp, id), tp: TypeParameter) =>
      if (gtp == tp) gv
      else GenericValue(tp, id).copiedFrom(gv)

    case (set: runtime.Set, SetType(b)) =>
      FiniteSet(set.getElements.asScala.map(jvmToValue(_, b)).toSet, b)

    case (map: runtime.Map, MapType(from, to)) =>
      val pairs = map.getElements.asScala.map { entry =>
        val k = jvmToValue(entry.getKey, from)
        val v = jvmToValue(entry.getValue, to)
        (k, v)
      }
      FiniteMap(pairs.toSeq, from, to)

    case (lambda: runtime.Lambda, _: FunctionType) =>
      val cls = lambda.getClass

      // Rebuild the lambda by substituting the captured values read from the
      // closure object's public fields.
      val l = classToLambda(cls.getName)
      val closures = purescala.ExprOps.variablesOf(l).toSeq.sortBy(_.uniqueName)
      val closureVals = closures.map { id =>
        val fieldVal = lambda.getClass.getField(id.name).get(lambda)
        jvmToValue(fieldVal, id.getType)
      }

      purescala.ExprOps.replaceFromIDs((closures zip closureVals).toMap, l)

    case (_, UnitType) =>
      UnitLiteral()

    case (ar: Array[_], ArrayType(base)) =>
      if (ar.isEmpty) {
        EmptyArray(base)
      } else {
        // Convert every slot explicitly: a typed pattern such as `(e: AnyRef, i)`
        // in a generator would silently skip null elements and leave holes.
        val elems = ar.zipWithIndex.map { case (elem, i) =>
          i -> jvmToValue(elem.asInstanceOf[AnyRef], base)
        }

        NonemptyArray(elems.toMap, None)
      }

    case _ =>
      // `e` may be null here (typed patterns above never match null).
      val found = if (e == null) "null" else e.getClass.toString
      throw CompilationException("Unsupported return value : " + found +" while expecting "+tpe)
  }

  /** Compiles `e` into a fresh class holding a static `eval` method.
    *
    * The class is named `Leon$CodeGen$Expr$<n>`. `eval` takes the monitor first
    * (when [[requireMonitor]]) followed by `args`, in order, and returns the value
    * of `e`. The class is registered with [[loader]].
    *
    * Note: the implicit `ctx` parameter shadows the field of the same name in this method.
    *
    * @throws Unsupported if `e` is untyped
    */
  def compileExpression(e: Expr, args: Seq[Identifier])(implicit ctx: LeonContext): CompiledExpression = {
    if(e.getType == Untyped) {
      throw new Unsupported(e, "Cannot compile untyped expression.")
    }

    val id = exprCounter.nextGlobal

    val cName = "Leon$CodeGen$Expr$"+id

    val cf = new ClassFile(cName, None)
    cf.setFlags((
      CLASS_ACC_PUBLIC |
      CLASS_ACC_FINAL
    ).asInstanceOf[U2])

    cf.addDefaultConstructor

    val argsTypes = args.map(a => typeToJVM(a.getType))

    val realArgs = if (requireMonitor) {
      ("L" + MonitorClass + ";") +: argsTypes
    } else {
      argsTypes
    }

    val m = cf.addMethod(
      typeToJVM(e.getType),
      "eval",
      realArgs : _*
    )

    m.setFlags((
      METHOD_ACC_PUBLIC |
      METHOD_ACC_FINAL |
      METHOD_ACC_STATIC
    ).asInstanceOf[U2])

    val ch = m.codeHandler

    // Local-variable slots: the monitor takes slot 0 when present, shifting the
    // arguments by one. A strict map is used (instead of the lazy `mapValues`
    // view) so the offset is not recomputed on every lookup.
    val newMapping = if (requireMonitor) {
      args.zipWithIndex.map { case (argId, i) => argId -> (i + 1) }.toMap + (monitorID -> 0)
    } else {
      args.zipWithIndex.toMap
    }

    mkExpr(e, ch)(Locals(newMapping, Map.empty, Map.empty, isStatic = true))

    e.getType match {
      case ValueType() =>
        ch << IRETURN
      case _ =>
        ch << ARETURN
    }

    ch.freeze

    loader.register(cf)

    new CompiledExpression(this, cf, e, args)
  }

  /** Fills in the class file previously defined for `module`.
    *
    * Functions for which `canBeField` holds become fields. They are split into
    * strict and lazy fields by `canBeStrictField`, and all other functions become
    * methods. A default constructor and a static `<clinit>` initializing the
    * fields are added.
    */
  def compileModule(module: ModuleDef): Unit = {
    val cf = classes(module)
    cf.setFlags((
      CLASS_ACC_SUPER |
      CLASS_ACC_PUBLIC |
      CLASS_ACC_FINAL
    ).asInstanceOf[U2])

    val (fields, functions) = module.definedFunctions partition { _.canBeField }
    val (strictFields, lazyFields) = fields partition { _.canBeStrictField }

    // Compile methods
    for (function <- functions) {
      compileFunDef(function, module)
    }

    // Compile lazy fields
    for (lzy <- lazyFields) {
      compileLazyField(lzy, module)
    }

    // Compile strict fields
    for (field <- strictFields) {
      compileStrictField(field, module)
    }

    // Constructor
    cf.addDefaultConstructor

    val cName = defToJVMName(module)

    // Add class initializer method
    locally {
      val mh = cf.addMethod("V", "<clinit>")
      mh.setFlags((
        METHOD_ACC_STATIC |
        METHOD_ACC_PUBLIC
      ).asInstanceOf[U2])

      val ch = mh.codeHandler
      /*
       * FIXME :
       * Dirty hack to make this compatible with monitoring of method invocations.
       * Because we don't have access to the monitor object here, we initialize a new one
       * that will get lost when this method returns, so we can't hope to count
       * method invocations here :(
       */
      ch << New(MonitorClass) << DUP
      ch << Ldc(Int.MaxValue) // Allow "infinite" method calls
      ch << InvokeSpecial(MonitorClass, cafebabe.Defaults.constructorName, "(I)V")
      ch << AStore(ch.getFreshVar) // position 0
      for (lzy <- lazyFields) { initLazyField(ch, cName, lzy, isStatic = true) }
      for (field <- strictFields) { initStrictField(ch, cName, field, isStatic = true) }
      ch << RETURN
      ch.freeze
    }

  }

  /** Traverses the program to find all definitions, and stores those in global variables.
    *
    * Defines a class file for every class and module, and records the owner of
    * every method / module function in [[defToModuleOrClass]].
    */
  def init(): Unit = {
    // First define all classes/ methods/ functions
    for (u <- program.units) {

      for {
        ch  <- u.classHierarchies
        cls <- ch
      } {
        defineClass(cls)
        for (meth <- cls.methods) {
          defToModuleOrClass += meth -> cls
        }
      }

      for (m <- u.modules) {
        defineClass(m)
        for (funDef <- m.definedFunctions) {
          defToModuleOrClass += funDef -> m
        }
      }
    }
  }

  /** Compiles the program.
    *
    * Uses information provided by [[init]]. Every class hierarchy and module is
    * compiled, and then all class files are registered with [[loader]].
    */
  def compile(): Unit = {
    // Compile everything
    for (u <- program.units) {

      for {
        ch <- u.classHierarchies
        c  <- ch
      } {
        c match {
          case acd: AbstractClassDef =>
            compileAbstractClassDef(acd)
          case ccd: CaseClassDef =>
            compileCaseClassDef(ccd)
        }
      }

      for (m <- u.modules) compileModule(m)
    }

    classes.values.foreach(loader.register)
  }

  /** Writes each generated class file to `prefix + className + ".class"`. */
  def writeClassFiles(prefix: String): Unit = {
    for (cl <- classes.values) {
      cl.writeToFile(prefix + cl.className + ".class")
    }
  }

  // Must stay last: every field above has to be initialized before compiling.
  init()
  compile()
}
/** Counter whose `nextGlobal` value suffixes the names of the classes emitted
  * by `CompilationUnit.compileExpression` (`Leon$CodeGen$Expr$<n>`).
  */
private[codegen] object exprCounter extends UniqueCounter[Unit]

/** Unique-id counter for lambdas. It is not referenced in this file; it is
  * presumably used by lambda compilation elsewhere in `codegen` (TODO confirm).
  */
private[codegen] object lambdaCounter extends UniqueCounter[Unit]