/* Copyright 2009-2015 EPFL, Lausanne */

package leon
package codegen

import purescala.Common._
import purescala.Definitions._
import purescala.Expressions._
import purescala.Types._
import purescala.Extractors._
import purescala.Constructors._
import codegen.runtime.LeonCodeGenRuntimeMonitor

import cafebabe._
import cafebabe.AbstractByteCodes._
import cafebabe.ByteCodes._
import cafebabe.ClassFileTypes._
import cafebabe.Flags._

import scala.collection.JavaConverters._

import java.lang.reflect.Constructor


/**
 * Compiles a Leon [[Program]] to JVM class files (built with cafebabe) and
 * converts between Leon expressions and the JVM values that represent them.
 *
 * Construction is eager: the constructor runs [[init]], which allocates one
 * class file per class/module definition, and then [[compile]], which fills
 * them in and registers them with [[loader]]. Further standalone expressions
 * can be compiled afterwards with [[compileExpression]].
 *
 * @param ctx     Leon context; its reporter is used for fatal errors.
 * @param program the program whose classes and modules are compiled.
 * @param params  code-generation options, e.g. whether a runtime monitor is
 *                passed as the first argument of generated code.
 */
class CompilationUnit(val ctx: LeonContext,
                      val program: Program,
                      val params: CodeGenParams = CodeGenParams.default) extends CodeGeneration {

  /** Class loader into which every generated class file is registered. */
  val loader = new CafebabeClassLoader(classOf[CompilationUnit].getClassLoader)

  /** Class file generated for each Leon class and module definition. */
  var classes     = Map[Definition, ClassFile]()

  /** Owner (module or class definition) of every function/method seen by [[init]]. */
  var defToModuleOrClass = Map[Definition, Definition]()

  /**
   * JVM descriptor of the monitor parameter that is prepended to generated
   * signatures when `params.requireMonitor` holds, or "" otherwise.
   */
  private[this] def monitorArgDescriptor: String =
    if (params.requireMonitor) "L" + MonitorClass + ";" else ""

  /**
   * Allocates an empty class file for `df` and records it in [[classes]].
   * A case class's JVM superclass is the class generated for its parent, if
   * it has one. Fails (via `sys.error`) on definitions other than abstract
   * classes, case classes and modules.
   */
  def defineClass(df: Definition): Unit = {
    val cName = defToJVMName(df)

    val cf = df match {
      case ccd: CaseClassDef =>
        val pName = ccd.parent.map(parent => defToJVMName(parent.classDef))
        new ClassFile(cName, pName)

      case _: AbstractClassDef | _: ModuleDef =>
        new ClassFile(cName, None)

      case _ =>
        sys.error("Unhandled definition type")
    }

    classes += df -> cf
  }

  /** Finds the Leon definition whose generated class file is named `name`. */
  def jvmClassToLeonClass(name: String): Option[Definition] =
    classes.collectFirst { case (df, cf) if cf.className == name => df }

  /**
   * Returns the JVM class name of `cd` and the descriptor of its constructor:
   * the optional monitor followed by one parameter per field, returning void.
   * `None` if no class file was defined for `cd`.
   */
  def leonClassToJVMInfo(cd: ClassDef): Option[(String, String)] =
    classes.get(cd).map { cf =>
      val sig = "(" + monitorArgDescriptor + cd.fields.map(f => typeToJVM(f.getType)).mkString("") + ")V"
      (cf.className, sig)
    }

  // Cache of (className, methodName, methodSignature) per function.
  private[this] var funDefInfo = Map[FunDef, (String, String, String)]()

  /**
   * Returns (cn, mn, sig) where
   *  cn is the JVM class name of the module/class owning `fd`
   *  mn is the safe method name
   *  sig is the method signature (optional monitor first, then the parameters)
   *
   * Results are cached. `None` if `fd` has no known owner with a class file.
   */
  def leonFunDefToJVMInfo(fd: FunDef): Option[(String, String, String)] = {
    funDefInfo.get(fd).orElse {
      val sig = "(" + monitorArgDescriptor + fd.params.map(a => typeToJVM(a.getType)).mkString("") + ")" + typeToJVM(fd.returnType)

      val info = defToModuleOrClass.get(fd).flatMap(owner => classes.get(owner)).map { cf =>
        (cf.className, idToSafeJVMName(fd.id), sig)
      }
      info.foreach(res => funDefInfo += fd -> res)
      info
    }
  }

  /**
   * Returns the public constructor of `klass` with the most parameters.
   * This is a hack that relies on the widest constructor being the "real" one.
   * Among constructors of equal arity, the one sorted last (stable sort) is picked.
   */
  private[this] def widestConstructor(klass: Class[_]): Constructor[_] = {
    val conss = klass.getConstructors.sortBy(_.getParameterTypes.length)
    assert(conss.nonEmpty, "No public constructor found for " + klass.getName)
    conss.last
  }

  // Cache of the Java constructor used to build each case class.
  private[this] var ccdConstructors = Map[CaseClassDef, Constructor[_]]()

  /** Loads (once) and returns the Java constructor of the class generated for `ccd`. */
  private[this] def caseClassConstructor(ccd: CaseClassDef): Option[Constructor[_]] = {
    ccdConstructors.get(ccd).orElse {
      classes.get(ccd).map { cf =>
        val cons = widestConstructor(loader.loadClass(cf.className))
        ccdConstructors += ccd -> cons
        cons
      }
    }
  }

  /** Widest constructor of the runtime tuple class, loaded on first use. */
  private[this] lazy val tupleConstructor: Constructor[_] =
    widestConstructor(loader.loadClass("leon.codegen.runtime.Tuple"))

  /**
   * Converts a Leon value expression to the JVM object representing it.
   *
   * Currently only used to prepare arguments to reflective calls, so it is
   * safe to always return boxed values (reflection needs them anyway).
   * Expressions without a direct translation are compiled and evaluated.
   */
  def exprToJVM(e: Expr)(implicit monitor : LeonCodeGenRuntimeMonitor): AnyRef = e match {
    case IntLiteral(v) =>
      new java.lang.Integer(v)

    case InfiniteIntegerLiteral(v) =>
      new leon.codegen.runtime.BigInt(v.toString)

    case BooleanLiteral(v) =>
      new java.lang.Boolean(v)

    case CharLiteral(v) =>
      new java.lang.Character(v)

    case GenericValue(_, _) =>
      e

    case Tuple(elems) =>
      // The element array is passed as ONE constructor argument (no `: _*`).
      tupleConstructor.newInstance(elems.map(exprToJVM).toArray).asInstanceOf[AnyRef]

    case CaseClass(cct, args) =>
      caseClassConstructor(cct.classDef) match {
        case Some(cons) =>
          val jvmArgs = args.map(exprToJVM)
          // Generated constructors take the monitor first when one is required.
          val realArgs = if (params.requireMonitor) monitor +: jvmArgs else jvmArgs
          cons.newInstance(realArgs.toArray : _*).asInstanceOf[AnyRef]
        case None =>
          ctx.reporter.fatalError("Case class constructor not found?!?")
      }

    // For now, we only treat boolean arrays separately.
    // We have a use for these, mind you.
    //case f @ FiniteArray(exprs) if f.getType == ArrayType(BooleanType) =>
    //  exprs.map(e => exprToJVM(e).asInstanceOf[java.lang.Boolean].booleanValue).toArray

    case FiniteSet(els, _) =>
      val jvmSet = new leon.codegen.runtime.Set()
      for (el <- els) {
        jvmSet.add(exprToJVM(el))
      }
      jvmSet

    case FiniteMap(els, _, _) =>
      val jvmMap = new leon.codegen.runtime.Map()
      for ((k, v) <- els) {
        jvmMap.add(exprToJVM(k), exprToJVM(v))
      }
      jvmMap

    case f @ purescala.Extractors.FiniteLambda(dflt, els) =>
      val l = new leon.codegen.runtime.FiniteLambda(exprToJVM(dflt))

      // Number of lambda parameters; lazy so it is only computed if there are entries.
      lazy val arity = f.getType.asInstanceOf[FunctionType].from.size

      for ((k, v) <- els) {
        val ks = unwrapTuple(k, arity)
        // Force tuple even with 1/0 elems.
        val kJvm = tupleConstructor.newInstance(ks.map(exprToJVM).toArray).asInstanceOf[leon.codegen.runtime.Tuple]
        val vJvm = exprToJVM(v)
        l.add(kJvm, vJvm)
      }
      l

    // Just slightly overkill...
    case _ =>
      compileExpression(e, Seq()).evalToJVM(Seq(), monitor)
  }

  /**
   * Converts a JVM value `e`, expected to have Leon type `tpe`, back into a
   * Leon expression. Note that this may produce untyped expressions
   * (typically: sets, maps).
   *
   * @throws CompilationException if the value/type combination is unsupported
   *         or a case class instance cannot be mapped back to a Leon class.
   */
  def jvmToExpr(e: AnyRef, tpe: TypeTree): Expr = (e, tpe) match {
    case (i: Integer, Int32Type) =>
      IntLiteral(i.toInt)

    case (c: runtime.BigInt, IntegerType) =>
      InfiniteIntegerLiteral(BigInt(c.underlying))

    case (b: java.lang.Boolean, BooleanType) =>
      BooleanLiteral(b.booleanValue)

    case (c: java.lang.Character, CharType) =>
      CharLiteral(c.toChar)

    case (obj: runtime.CaseClass, ct: ClassType) =>
      val fields = obj.productElements()

      // An abstract class type does not tell which case class `obj` is an
      // instance of, so recover it from the JVM class of the object.
      val cct = ct match {
        case knownCct: CaseClassType =>
          knownCct

        case _ =>
          jvmClassToLeonClass(obj.getClass.getName) match {
            case Some(ccd: CaseClassDef) =>
              CaseClassType(ccd, ct.tps)
            case _ =>
              throw CompilationException("Unable to identify class "+obj.getClass.getName+" to descendent of "+ct)
          }
      }

      CaseClass(cct, (fields zip cct.fieldsTypes).map { case (fieldValue, fieldTpe) => jvmToExpr(fieldValue, fieldTpe) })

    case (tpl: runtime.Tuple, _) =>
      val elemTypes = unwrapTupleType(tpe, tpl.getArity)
      val elems = elemTypes.zipWithIndex.map { case (elemTpe, i) =>
        jvmToExpr(tpl.get(i), elemTpe)
      }
      tupleWrap(elems)

    case (gv @ GenericValue(gtp, id), tp: TypeParameter) =>
      if (gtp == tp) gv
      else GenericValue(tp, id).copiedFrom(gv)

    case (set : runtime.Set, SetType(b)) =>
      finiteSet(set.getElements.asScala.map(jvmToExpr(_, b)).toSet, b)

    case (map : runtime.Map, MapType(from, to)) =>
      val pairs = map.getElements.asScala.map { entry =>
        val k = jvmToExpr(entry.getKey, from)
        val v = jvmToExpr(entry.getValue, to)
        (k, v)
      }
      finiteMap(pairs.toSeq, from, to)

    case (_, UnitType) =>
      UnitLiteral()

    case _ =>
      throw CompilationException("Unsupported return value : " + e.getClass +" while expecting "+tpe)
  }

  /** Number of expressions compiled so far by [[compileExpression]]. */
  var compiledN = 0

  /**
   * Compiles `e` into a fresh class with a single public static `eval`
   * method. The method takes the optional monitor followed by `args` (in
   * order) and returns the value of `e`. The class is registered with
   * [[loader]].
   *
   * @throws IllegalArgumentException if `e` is untyped.
   * @throws CompilationException if the type of `e` cannot be returned.
   */
  def compileExpression(e: Expr, args: Seq[Identifier]): CompiledExpression = {
    if (e.getType == Untyped) {
      throw new IllegalArgumentException(s"Cannot compile untyped expression [$e].")
    }

    compiledN += 1

    val id = CompilationUnit.nextExprId

    val cName = "Leon$CodeGen$Expr$"+id

    val cf = new ClassFile(cName, None)
    cf.setFlags((
      CLASS_ACC_PUBLIC |
      CLASS_ACC_FINAL
    ).asInstanceOf[U2])

    cf.addDefaultConstructor

    val argsTypes = args.map(a => typeToJVM(a.getType))
    val realArgs = if (params.requireMonitor) monitorArgDescriptor +: argsTypes else argsTypes

    val m = cf.addMethod(
      typeToJVM(e.getType),
      "eval",
      realArgs : _*
    )

    m.setFlags((
      METHOD_ACC_PUBLIC |
      METHOD_ACC_FINAL |
      METHOD_ACC_STATIC
    ).asInstanceOf[U2])

    val ch = m.codeHandler

    // `eval` is static, so local slot 0 holds the monitor when one is
    // required and the expression arguments are shifted by one. The mapping is
    // built strictly (a `mapValues` view would recompute on every lookup).
    val argOffset = if (params.requireMonitor) 1 else 0
    val newMapping = args.zipWithIndex.map { case (arg, i) => arg -> (i + argOffset) }.toMap

    mkExpr(e, ch)(Locals(newMapping, Map.empty, Map.empty, true))

    e.getType match {
      // Char, like Int/Boolean (and Unit here), is returned as an int-category value.
      case Int32Type | BooleanType | CharType | UnitType =>
        ch << IRETURN

      case IntegerType | _: TupleType  | _: SetType | _: MapType | _: AbstractClassType | _: CaseClassType | _: ArrayType | _: FunctionType | _: TypeParameter =>
        ch << ARETURN

      case other =>
        throw CompilationException("Unsupported return type : " + other)
    }

    ch.freeze

    loader.register(cf)

    new CompiledExpression(this, cf, e, args)
  }

  /**
   * Fills in the class file of `module`. Functions that can be fields are
   * compiled as lazy or strict fields; all other functions are compiled as
   * methods. Also adds a default constructor and a static initializer that
   * initializes the lazy fields, then the strict fields.
   */
  def compileModule(module: ModuleDef): Unit = {
    val cf = classes(module)
    cf.setFlags((
      CLASS_ACC_SUPER |
      CLASS_ACC_PUBLIC |
      CLASS_ACC_FINAL
    ).asInstanceOf[U2])

    val (fields, functions) = module.definedFunctions partition { _.canBeField }
    val (strictFields, lazyFields) = fields partition { _.canBeStrictField }

    // Compile methods
    for (function <- functions) {
      compileFunDef(function, module)
    }

    // Compile lazy fields
    for (lzy <- lazyFields) {
      compileLazyField(lzy, module)
    }

    // Compile strict fields
    for (field <- strictFields) {
      compileStrictField(field, module)
    }

    // Constructor
    cf.addDefaultConstructor

    val cName = defToJVMName(module)

    // Add class initializer method
    locally {
      val mh = cf.addMethod("V", "<clinit>")
      mh.setFlags((
        METHOD_ACC_STATIC |
        METHOD_ACC_PUBLIC
      ).asInstanceOf[U2])

      val ch = mh.codeHandler
      /*
       * FIXME :
       * Dirty hack to make this compatible with monitoring of method invocations.
       * Because we don't have access to the monitor object here, we initialize a new one
       * that will get lost when this method returns, so we can't hope to count
       * method invocations here :(
       */
      ch << New(MonitorClass) << DUP
      ch << Ldc(Int.MaxValue) // Allow "infinite" method calls
      ch << InvokeSpecial(MonitorClass, cafebabe.Defaults.constructorName, "(I)V")
      ch << AStore(ch.getFreshVar) // position 0
      for (lzy <- lazyFields) { initLazyField(ch, cName, lzy, true) }
      for (field <- strictFields) { initStrictField(ch, cName, field, true) }
      ch << RETURN
      ch.freeze
    }
  }

  /**
   * Traverses the program to find all definitions: defines a class file for
   * every class and module, and records the owner of every method/function
   * in [[defToModuleOrClass]].
   */
  def init(): Unit = {
    for (u <- program.units; m <- u.modules) {
      val (parents, children) = m.algebraicDataTypes.toSeq.unzip
      for (cls <- parents ++ children.flatten ++ m.singleCaseClasses) {
        defineClass(cls)
        for (meth <- cls.methods) {
          defToModuleOrClass += meth -> cls
        }
      }

      defineClass(m)
      for (funDef <- m.definedFunctions) {
        defToModuleOrClass += funDef -> m
      }
    }
  }

  /**
   * Compiles the program and registers every class file with [[loader]].
   * Uses the class files allocated by [[init]].
   */
  def compile(): Unit = {
    for (u <- program.units) {
      for ((parent, children) <- u.algebraicDataTypes) {
        compileAbstractClassDef(parent)

        for (c <- children) {
          compileCaseClassDef(c)
        }
      }

      for (single <- u.singleCaseClasses) {
        compileCaseClassDef(single)
      }

      for (m <- u.modules) compileModule(m)
    }

    classes.values.foreach(loader.register)
  }

  /** Writes every generated class file to `prefix + className + ".class"`. */
  def writeClassFiles(prefix: String): Unit = {
    for (cl <- classes.values) {
      cl.writeToFile(prefix + cl.className + ".class")
    }
  }

  init()
  compile()
}

object CompilationUnit {
  // Process-wide counters shared by every CompilationUnit instance; both are
  // only read and written while holding this object's monitor.
  private var exprIdCounter = 0
  private var lambdaIdCounter = 0

  /**
   * Hands out the next expression id (1, 2, 3, ...). Used to give each class
   * generated by `compileExpression` a distinct name.
   */
  private[codegen] def nextExprId: Int = synchronized {
    exprIdCounter = exprIdCounter + 1
    exprIdCounter
  }

  /**
   * Hands out the next lambda id (1, 2, 3, ...), presumably used to name
   * generated lambda classes elsewhere in `codegen` — confirm at call sites.
   */
  private[codegen] def nextLambdaId: Int = synchronized {
    lambdaIdCounter = lambdaIdCounter + 1
    lambdaIdCounter
  }
}