package leon
package codegen

import purescala.Common._
import purescala.Definitions._
import purescala.Trees._
import purescala.TypeTrees._

import cafebabe._
import cafebabe.AbstractByteCodes._
import cafebabe.ByteCodes._
import cafebabe.ClassFileTypes._
import cafebabe.Defaults.constructorName
import cafebabe.Flags._

object CodeGeneration {
  // JVM internal names (slash-separated) of the boxed primitive classes used
  // when values are boxed to / unboxed from java.lang.Object.
  private val BoxedIntClass  = "java/lang/Integer"
  private val BoxedBoolClass = "java/lang/Boolean"

  // JVM internal names of the runtime support classes referenced by the
  // bytecode emitted in this object.
  private val TupleClass     = "leon/codegen/runtime/Tuple"
  private val SetClass       = "leon/codegen/runtime/Set"
  private val MapClass       = "leon/codegen/runtime/Map"
  // Interface added to every generated class hierarchy root.
  private val CaseClassClass = "leon/codegen/runtime/CaseClass"
  // Exception thrown by generated code for `Error` expressions.
  private val ErrorClass     = "leon/codegen/runtime/LeonCodeGenRuntimeException"
  // Exception thrown by generated code for expressions that cannot be executed (`Choose`).
  private val ImpossibleEvaluationClass = "leon/codegen/runtime/LeonCodeGenEvaluationException"

  /** Name of the JVM class generated for definition `d`: a fixed prefix
    * followed by the unique name of the definition's identifier.
    * The program `p` is not consulted.
    */
  def defToJVMName(p : Program, d : Definition) : String = {
    val prefix = "Leon$CodeGen$"
    prefix + d.id.uniqueName
  }

  /** Returns the JVM type descriptor used by the generated code for `tpe`.
    *
    * - `Int32Type` is `I`; `BooleanType` and `UnitType` are both `Z`.
    * - Class types map to `L<name>;` where `<name>` is the class registered in `env`.
    * - Tuples, sets and maps map to their runtime support classes.
    * - Arrays map to `[` followed by the descriptor of their base type.
    *
    * @throws CompilationException if `tpe` is a class type that is not registered
    *         in `env`, or a type with no JVM representation here.
    */
  def typeToJVM(tpe : TypeTree)(implicit env : CompilationEnvironment) : String = tpe match {
    case Int32Type => "I"

    case BooleanType => "Z"

    // Unit values are represented as booleans.
    case UnitType => "Z"

    case c : ClassType =>
      // An unknown class must abort compilation: returning the error message
      // as if it were a descriptor would silently produce an invalid class file.
      env.classDefToClass(c.classDef).map(n => "L" + n + ";").getOrElse {
        throw CompilationException("Unsupported class " + c.id)
      }

    case _ : TupleType =>
      "L" + TupleClass + ";"

    case _ : SetType =>
      "L" + SetClass + ";"

    case _ : MapType =>
      "L" + MapClass + ";"

    case ArrayType(base) =>
      "[" + typeToJVM(base)

    case _ => throw CompilationException("Unsupported type : " + tpe)
  }

  /** Emits the whole method body of `funDef` into `ch`.
    *
    * The handler is assumed not to have received any bytecode yet; it is frozen
    * before this method returns. The body is guarded by the precondition and
    * checked against the postcondition when they exist: a violation compiles to
    * an `Error` expression.
    *
    * @throws CompilationException if the return type has no supported return instruction.
    */
  def compileFunDef(funDef : FunDef, ch : CodeHandler)(implicit env : CompilationEnvironment) {
    // Formal parameters occupy the first local slots, in declaration order.
    val argSlots = funDef.args.map(_.id).zipWithIndex.toMap

    val guardedBody =
      if(!funDef.hasPrecondition) {
        funDef.getBody
      } else {
        IfExpr(funDef.precondition.get, funDef.getBody, Error("Precondition failed"))
      }

    val checkedBody =
      if(!funDef.hasPostcondition) {
        guardedBody
      } else {
        // The result is bound to a fresh identifier so the postcondition can refer to it.
        val resId = FreshIdentifier("result").setType(funDef.returnType)
        val checkedPost = purescala.TreeOps.replace(Map(ResultVariable() -> Variable(resId)), funDef.postcondition.get)
        Let(resId, guardedBody, IfExpr(checkedPost, Variable(resId), Error("Postcondition failed")) )
      }

    // Pattern matching is lowered to if-then-else before code generation.
    val lowered = purescala.TreeOps.matchToIfThenElse(checkedBody)

    mkExpr(lowered, ch)(env.withVars(argSlots))

    // Int-like values return with IRETURN, references with ARETURN.
    ch << (funDef.returnType match {
      case Int32Type | BooleanType | UnitType =>
        IRETURN

      case _ : ClassType | _ : TupleType | _ : SetType | _ : MapType | _ : ArrayType =>
        ARETURN

      case other =>
        throw CompilationException("Unsupported return type : " + other.getClass)
    })

    ch.freeze
  }

  /** Emits into `ch` the bytecode that evaluates `e`, leaving its value on top
    * of the operand stack (`Error` and `Choose` compile to a throw instead).
    *
    * Int, boolean and unit values are handled as JVM ints; everything else as a
    * reference.
    *
    * @param e the expression to compile
    * @param ch the handler receiving the bytecode
    * @param canDelegateToMkBranch when true, boolean expressions with no dedicated
    *        case are compiled through `mkBranch`; `mkBranch` passes false when it
    *        delegates here, so the two methods cannot call each other forever
    * @throws CompilationException on unknown classes, methods or variables and on
    *         unsupported expressions
    */
  private[codegen] def mkExpr(e : Expr, ch : CodeHandler, canDelegateToMkBranch : Boolean = true)(implicit env : CompilationEnvironment) {
    e match {
      case Variable(id) =>
        val slot = slotFor(id)
        val instr = id.getType match {
          case Int32Type | BooleanType | UnitType => ILoad(slot)
          case _ => ALoad(slot)
        }
        ch << instr

      case Let(i,d,b) =>
        mkExpr(d, ch)
        val slot = ch.getFreshVar
        val instr = i.getType match {
          case Int32Type | BooleanType | UnitType => IStore(slot)
          case _ => AStore(slot)
        }
        ch << instr
        mkExpr(b, ch)(env.withVars(Map(i -> slot)))

      case LetTuple(is,d,b) =>
        mkExpr(d, ch) // the tuple
        val withSlots = is.map(i => (i, ch.getFreshVar))
        for(((i,s), index) <- withSlots.zipWithIndex) {
          // Each component is read from a duplicate, so the tuple reference
          // stays on the stack for the next component.
          ch << DUP
          ch << Ldc(index)
          ch << InvokeVirtual(TupleClass, "get", "(I)Ljava/lang/Object;")
          mkUnbox(i.getType, ch)
          val instr = i.getType match {
            case Int32Type | BooleanType | UnitType => IStore(s)
            case _ => AStore(s)
          }
          ch << instr
        }
        // Drop the tuple reference itself: only the value of the body must
        // remain on the stack, otherwise the stack depth is off by one for
        // whatever expression encloses this one.
        ch << POP
        mkExpr(b, ch)(env.withVars(withSlots.toMap))

      case IntLiteral(v) =>
        ch << Ldc(v)

      case BooleanLiteral(v) =>
        ch << Ldc(if(v) 1 else 0)

      case UnitLiteral =>
        ch << Ldc(1)

      // Case classes
      case CaseClass(ccd, as) =>
        val ccName = env.classDefToClass(ccd).getOrElse {
          throw CompilationException("Unknown class : " + ccd.id)
        }
        // TODO FIXME It's a little ugly that we do it each time. Could be in env.
        val consSig = "(" + ccd.fields.map(f => typeToJVM(f.tpe)).mkString("") + ")V"
        ch << New(ccName) << DUP
        for(a <- as) {
          mkExpr(a, ch)
        }
        ch << InvokeSpecial(ccName, constructorName, consSig)

      case CaseClassInstanceOf(ccd, e) =>
        val ccName = env.classDefToClass(ccd).getOrElse {
          throw CompilationException("Unknown class : " + ccd.id)
        }
        mkExpr(e, ch)
        ch << InstanceOf(ccName)

      case CaseClassSelector(ccd, e, sid) =>
        mkExpr(e, ch)
        val ccName = env.classDefToClass(ccd).getOrElse {
          throw CompilationException("Unknown class : " + ccd.id)
        }
        ch << CheckCast(ccName)
        ch << GetField(ccName, sid.name, typeToJVM(sid.getType))

      // Tuples (note that instanceOf checks are in mkBranch)
      case Tuple(es) =>
        ch << New(TupleClass) << DUP
        ch << Ldc(es.size)
        ch << NewArray("java/lang/Object")
        for((e,i) <- es.zipWithIndex) {
          ch << DUP
          ch << Ldc(i)
          mkBoxedExpr(e, ch)
          ch << AASTORE
        }
        ch << InvokeSpecial(TupleClass, constructorName, "([Ljava/lang/Object;)V")

      case TupleSelect(t, i) =>
        // Selector indices are 1-based, the runtime tuple is 0-based.
        val TupleType(bs) = t.getType
        mkExpr(t,ch)
        ch << Ldc(i - 1)
        ch << InvokeVirtual(TupleClass, "get", "(I)Ljava/lang/Object;")
        mkUnbox(bs(i - 1), ch)

      // Sets
      case FiniteSet(es) =>
        ch << DefaultNew(SetClass)
        for(e <- es) {
          ch << DUP
          mkBoxedExpr(e, ch)
          ch << InvokeVirtual(SetClass, "add", "(Ljava/lang/Object;)V")
        }

      case ElementOfSet(e, s) =>
        mkExpr(s, ch)
        mkBoxedExpr(e, ch)
        ch << InvokeVirtual(SetClass, "contains", "(Ljava/lang/Object;)Z")

      case SetCardinality(s) =>
        mkExpr(s, ch)
        ch << InvokeVirtual(SetClass, "size", "()I")

      case SubsetOf(s1, s2) =>
        mkExpr(s1, ch)
        mkExpr(s2, ch)
        ch << InvokeVirtual(SetClass, "subsetOf", "(L%s;)Z".format(SetClass))

      case SetIntersection(s1, s2) =>
        mkExpr(s1, ch)
        mkExpr(s2, ch)
        ch << InvokeVirtual(SetClass, "intersect", "(L%s;)L%s;".format(SetClass,SetClass))

      case SetUnion(s1, s2) =>
        mkExpr(s1, ch)
        mkExpr(s2, ch)
        ch << InvokeVirtual(SetClass, "union", "(L%s;)L%s;".format(SetClass,SetClass))

      case SetDifference(s1, s2) =>
        mkExpr(s1, ch)
        mkExpr(s2, ch)
        ch << InvokeVirtual(SetClass, "minus", "(L%s;)L%s;".format(SetClass,SetClass))

      // Maps
      case FiniteMap(ss) =>
        ch << DefaultNew(MapClass)
        for((f,t) <- ss) {
          ch << DUP
          mkBoxedExpr(f, ch)
          mkBoxedExpr(t, ch)
          ch << InvokeVirtual(MapClass, "add", "(Ljava/lang/Object;Ljava/lang/Object;)V")
        }

      case MapGet(m, k) =>
        val MapType(_, tt) = m.getType
        mkExpr(m, ch)
        mkBoxedExpr(k, ch)
        ch << InvokeVirtual(MapClass, "get", "(Ljava/lang/Object;)Ljava/lang/Object;")
        mkUnbox(tt, ch)

      case MapIsDefinedAt(m, k) =>
        mkExpr(m, ch)
        mkBoxedExpr(k, ch)
        ch << InvokeVirtual(MapClass, "isDefinedAt", "(Ljava/lang/Object;)Z")

      case MapUnion(m1, m2) =>
        mkExpr(m1, ch)
        mkExpr(m2, ch)
        ch << InvokeVirtual(MapClass, "union", "(L%s;)L%s;".format(MapClass,MapClass))

      // Branching
      case IfExpr(c, t, e) =>
        val tl = ch.getFreshLabel("then")
        val el = ch.getFreshLabel("else")
        val al = ch.getFreshLabel("after")
        mkBranch(c, tl, el, ch)
        ch << Label(tl)
        mkExpr(t, ch)
        ch << Goto(al) << Label(el)
        mkExpr(e, ch)
        ch << Label(al)

      case FunctionInvocation(fd, as) =>
        val (cn, mn, ms) = env.funDefToMethod(fd).getOrElse {
          throw CompilationException("Unknown method : " + fd.id)
        }
        for(a <- as) {
          mkExpr(a, ch)
        }
        ch << InvokeStatic(cn, mn, ms)

      // Arithmetic
      case Plus(l, r) =>
        mkExpr(l, ch)
        mkExpr(r, ch)
        ch << IADD

      case Minus(l, r) =>
        mkExpr(l, ch)
        mkExpr(r, ch)
        ch << ISUB

      case Times(l, r) =>
        mkExpr(l, ch)
        mkExpr(r, ch)
        ch << IMUL

      case Division(l, r) =>
        mkExpr(l, ch)
        mkExpr(r, ch)
        ch << IDIV

      case Modulo(l, r) =>
        mkExpr(l, ch)
        mkExpr(r, ch)
        ch << IREM

      case UMinus(e) =>
        mkExpr(e, ch)
        ch << INEG

      case ArrayLength(a) =>
        mkExpr(a, ch)
        ch << ARRAYLENGTH

      case as @ ArraySelect(a,i) =>
        mkExpr(a, ch)
        mkExpr(i, ch)
        ch << (as.getType match {
          case Untyped => throw CompilationException("Cannot compile untyped array access.")
          case Int32Type => IALOAD
          // Unit is represented as a boolean (see typeToJVM), so arrays of
          // unit are boolean arrays as well.
          case BooleanType | UnitType => BALOAD
          case _ => AALOAD
        })

      case a @ FiniteArray(es) =>
        ch << Ldc(es.size)
        val storeInstr = a.getType match {
          case ArrayType(Int32Type) => ch << NewArray.primitive("T_INT"); IASTORE
          // Unit elements are pushed as ints, so they need a primitive
          // boolean array, consistent with the "[Z" descriptor of typeToJVM.
          case ArrayType(BooleanType) | ArrayType(UnitType) => ch << NewArray.primitive("T_BOOLEAN"); BASTORE
          case ArrayType(other) => ch << NewArray(typeToJVM(other)); AASTORE
          case other => throw CompilationException("Cannot compile finite array expression whose type is %s.".format(other))
        }
        for((e,i) <- es.zipWithIndex) {
          ch << DUP << Ldc(i)
          mkExpr(e, ch)
          ch << storeInstr
        }

      // Misc and boolean tests
      case Error(desc) =>
        ch << New(ErrorClass) << DUP
        ch << Ldc(desc)
        ch << InvokeSpecial(ErrorClass, constructorName, "(Ljava/lang/String;)V")
        ch << ATHROW

      case Choose(_, _) =>
        ch << New(ImpossibleEvaluationClass) << DUP
        ch << Ldc("Cannot execute choose.")
        ch << InvokeSpecial(ImpossibleEvaluationClass, constructorName, "(Ljava/lang/String;)V")
        ch << ATHROW

      case b if b.getType == BooleanType && canDelegateToMkBranch =>
        // Optimistically push `true`; the false branch replaces it with `false`.
        val fl = ch.getFreshLabel("boolfalse")
        val al = ch.getFreshLabel("boolafter")
        ch << Ldc(1)
        mkBranch(b, al, fl, ch, canDelegateToMkExpr = false)
        ch << Label(fl) << POP << Ldc(0) << Label(al)

      case _ => throw CompilationException("Unsupported expr. : " + e)
    }
  }

  /** Compiles `e` and leaves on the stack a value equal to it that is always
    * compatible with java.lang.Object: ints are wrapped in java.lang.Integer,
    * booleans and units in java.lang.Boolean, anything else is emitted as is.
    */
  private[codegen] def mkBoxedExpr(e : Expr, ch : CodeHandler)(implicit env : CompilationEnvironment) {
    // Allocates the wrapper first, then evaluates `e` as the constructor argument.
    def wrapIn(boxClass : String, consSig : String) {
      ch << New(boxClass) << DUP
      mkExpr(e, ch)
      ch << InvokeSpecial(boxClass, constructorName, consSig)
    }

    e.getType match {
      case Int32Type => wrapIn(BoxedIntClass, "(I)V")
      case BooleanType | UnitType => wrapIn(BoxedBoolClass, "(Z)V")
      case _ => mkExpr(e, ch)
    }
  }

  /** Boxes the value currently on top of the stack, assumed to be of type `tpe`,
    * so that it becomes compatible with java.lang.Object. Types other than
    * int, boolean and unit are left untouched.
    */
  private[codegen] def mkBox(tpe : TypeTree, ch : CodeHandler)(implicit env : CompilationEnvironment) {
    // Stack: v -> v, box -> box, v, box -> box, box, v -> box (after the constructor call).
    def wrapTop(boxClass : String, consSig : String) {
      ch << New(boxClass) << DUP_X1 << SWAP << InvokeSpecial(boxClass, constructorName, consSig)
    }

    tpe match {
      case Int32Type => wrapTop(BoxedIntClass, "(I)V")
      case BooleanType | UnitType => wrapTop(BoxedBoolClass, "(Z)V")
      case _ =>
    }
  }

  /** Converts the reference on top of the stack, which should hold a value of
    * type `tpe`, to the JVM representation used for that type.
    *
    * Ints, booleans and units are cast to their wrapper class and unwrapped;
    * every other supported type is only cast to its class. Arrays are cast to
    * the array class given by `typeToJVM`, so that every type `typeToJVM`
    * accepts can also be unboxed.
    *
    * @throws CompilationException for unknown class types and unsupported types
    */
  private[codegen] def mkUnbox(tpe : TypeTree, ch : CodeHandler)(implicit env : CompilationEnvironment) {
    tpe match {
      case Int32Type =>
        ch << CheckCast(BoxedIntClass) << InvokeVirtual(BoxedIntClass, "intValue", "()I")

      case BooleanType | UnitType =>
        ch << CheckCast(BoxedBoolClass) << InvokeVirtual(BoxedBoolClass, "booleanValue", "()Z")

      case ct : ClassType =>
        val cn = env.classDefToClass(ct.classDef).getOrElse {
          throw new CompilationException("Unsupported class type : " + ct)
        }
        ch << CheckCast(cn)

      case tt : TupleType =>
        ch << CheckCast(TupleClass)

      case st : SetType =>
        ch << CheckCast(SetClass)

      case mt : MapType =>
        ch << CheckCast(MapClass)

      case at : ArrayType =>
        // For array classes, the class reference is the array descriptor itself.
        ch << CheckCast(typeToJVM(at))

      case _ =>
        throw new CompilationException("Unsupported type in unboxing : " + tpe)
    }
  }

  /** Emits code that evaluates the boolean expression `cond` and jumps to the
    * label `then` when it holds, to the label `elze` otherwise. Nothing is left
    * on the stack.
    *
    * @param canDelegateToMkExpr when true, conditions with no dedicated case are
    *        evaluated with `mkExpr` and tested against zero; `mkExpr` passes
    *        false when it delegates here
    * @throws CompilationException if the condition cannot be compiled
    */
  private[codegen] def mkBranch(cond : Expr, then : String, elze : String, ch : CodeHandler, canDelegateToMkExpr : Boolean = true)(implicit env : CompilationEnvironment) {
    // Pushes the two operands of a binary test, left one first.
    def pushOperands(lhs : Expr, rhs : Expr) {
      mkExpr(lhs, ch)
      mkExpr(rhs, ch)
    }

    cond match {
      case BooleanLiteral(true) =>
        ch << Goto(then)

      case BooleanLiteral(false) =>
        ch << Goto(elze)

      // NOTE(review): the recursion on And(es.tail) / Or(es.tail) presumably
      // relies on these constructors simplifying short sequences - confirm.
      case And(es) =>
        val nextConjunct = ch.getFreshLabel("andnext")
        mkBranch(es.head, nextConjunct, elze, ch)
        ch << Label(nextConjunct)
        mkBranch(And(es.tail), then, elze, ch)

      case Or(es) =>
        val nextDisjunct = ch.getFreshLabel("ornext")
        mkBranch(es.head, then, nextDisjunct, ch)
        ch << Label(nextDisjunct)
        mkBranch(Or(es.tail), then, elze, ch)

      case Implies(l, r) =>
        mkBranch(Or(Not(l), r), then, elze, ch)

      // Negation simply swaps the two targets.
      case Not(c) =>
        mkBranch(c, elze, then, ch)

      case Variable(b) =>
        ch << ILoad(slotFor(b)) << IfEq(elze) << Goto(then)

      case Equals(l,r) =>
        pushOperands(l, r)
        l.getType match {
          case Int32Type | BooleanType | UnitType =>
            ch << If_ICmpEq(then) << Goto(elze)

          case _ =>
            // References are compared structurally through equals.
            ch << InvokeVirtual("java/lang/Object", "equals", "(Ljava/lang/Object;)Z")
            ch << IfEq(elze) << Goto(then)
        }

      case Iff(l,r) =>
        pushOperands(l, r)
        ch << If_ICmpEq(then) << Goto(elze)

      case LessThan(l,r) =>
        pushOperands(l, r)
        ch << If_ICmpLt(then) << Goto(elze)

      case GreaterThan(l,r) =>
        pushOperands(l, r)
        ch << If_ICmpGt(then) << Goto(elze)

      case LessEquals(l,r) =>
        pushOperands(l, r)
        ch << If_ICmpLe(then) << Goto(elze)

      case GreaterEquals(l,r) =>
        pushOperands(l, r)
        ch << If_ICmpGe(then) << Goto(elze)

      case other if canDelegateToMkExpr =>
        mkExpr(other, ch, canDelegateToMkBranch = false)
        ch << IfEq(elze) << Goto(then)

      case other => throw CompilationException("Unsupported branching expr. : " + other)
    }
  }

  /** Local variable slot bound to `id` in the current environment.
    *
    * @throws CompilationException if the environment has no slot for `id`
    */
  private[codegen] def slotFor(id : Identifier)(implicit env : CompilationEnvironment) : Int = {
    env.varToLocal(id) match {
      case Some(slot) => slot
      case None => throw CompilationException("Unknown variable : " + id)
    }
  }

  /** Builds the class file for the abstract class `acd`: a public abstract class
    * without explicit parent that implements the runtime case class interface
    * and only has a default constructor.
    */
  def compileAbstractClassDef(p : Program, acd : AbstractClassDef)(implicit env : CompilationEnvironment) : ClassFile = {
    val classFile = new ClassFile(defToJVMName(p, acd), None)

    val accessFlags = CLASS_ACC_SUPER | CLASS_ACC_PUBLIC | CLASS_ACC_ABSTRACT
    classFile.setFlags(accessFlags.asInstanceOf[U2])

    classFile.addInterface(CaseClassClass)
    classFile.addDefaultConstructor

    classFile
  }

  /** Builds the class file for the case class `ccd`.
    *
    * The generated class is public and final, extends the class generated for
    * its parent when it has one, and contains:
    *  - one public final field per case class field;
    *  - a constructor taking all fields in declaration order;
    *  - `productName`, returning the JVM name of the class;
    *  - `productElements`, returning the (boxed) field values as an Object array;
    *  - `equals`, comparing fields one by one;
    *  - `hashCode`, currently a constant.
    *
    * Field types are translated with `typeToJVM`, which throws
    * CompilationException on types it does not support.
    */
  def compileCaseClassDef(p : Program, ccd : CaseClassDef)(implicit env : CompilationEnvironment) : ClassFile = {

    val cName = defToJVMName(p, ccd)
    val pName = ccd.parent.map(parent => defToJVMName(p, parent))

    val cf = new ClassFile(cName, pName)
    cf.setFlags((
      CLASS_ACC_SUPER |
      CLASS_ACC_PUBLIC |
      CLASS_ACC_FINAL
    ).asInstanceOf[U2])

    // Without a parent, the class has to declare the runtime interface itself
    // (compileAbstractClassDef adds it to generated abstract classes).
    if(ccd.parent.isEmpty) {
      cf.addInterface(CaseClassClass)
    }

    // definition of the constructor
    if(ccd.fields.isEmpty) {
      cf.addDefaultConstructor
    } else {
      // (field name, JVM descriptor) pairs, in declaration order.
      val namesTypes = ccd.fields.map { vd => (vd.id.name, typeToJVM(vd.tpe)) }

      for((nme, jvmt) <- namesTypes) {
        val fh = cf.addField(jvmt, nme)
        fh.setFlags((
          FIELD_ACC_PUBLIC |
          FIELD_ACC_FINAL
        ).asInstanceOf[U2])
      }

      val cch = cf.addConstructor(namesTypes.map(_._2).toList).codeHandler

      // Call the no-argument constructor of the parent (or of Object).
      cch << ALoad(0)
      cch << InvokeSpecial(pName.getOrElse("java/lang/Object"), constructorName, "()V")

      // `c` is the local slot of the constructor argument being copied;
      // slot 0 is loaded as the receiver of each PutField.
      var c = 1
      for((nme, jvmt) <- namesTypes) {
        cch << ALoad(0)
        cch << (jvmt match {
          case "I" | "Z" => ILoad(c)
          case _ => ALoad(c)
        })
        cch << PutField(cName, nme, jvmt)
        c += 1
      }
      cch << RETURN
      cch.freeze
    }

    // definition of productName: returns the JVM name of this class
    locally {
      val pnm = cf.addMethod("Ljava/lang/String;", "productName")
      pnm.setFlags((
        METHOD_ACC_PUBLIC |
        METHOD_ACC_FINAL
      ).asInstanceOf[U2])

      val pnch = pnm.codeHandler

      pnch << Ldc(cName) << ARETURN

      pnch.freeze
    }

    // definition of productElements: returns a fresh Object array holding the
    // boxed value of every field, in declaration order
    locally {
      val pem = cf.addMethod("[Ljava/lang/Object;", "productElements")
      pem.setFlags((
        METHOD_ACC_PUBLIC |
        METHOD_ACC_FINAL
      ).asInstanceOf[U2])

      val pech = pem.codeHandler

      pech << Ldc(ccd.fields.size)
      pech << NewArray("java/lang/Object")

      for ((f, i) <- ccd.fields.zipWithIndex) {
        // array, index, boxed field value -> AASTORE; the DUP keeps the array.
        pech << DUP
        pech << Ldc(i)
        pech << ALoad(0)
        pech << GetField(cName, f.id.name, typeToJVM(f.tpe))
        mkBox(f.tpe, pech)
        pech << AASTORE
      }

      pech << ARETURN
      pech.freeze
    }

    // definition of equals
    locally {
      val emh = cf.addMethod("Z", "equals", "Ljava/lang/Object;")
      emh.setFlags((
        METHOD_ACC_PUBLIC |
        METHOD_ACC_FINAL
      ).asInstanceOf[U2])

      val ech = emh.codeHandler

      val notRefEq = ech.getFreshLabel("notrefeq")
      val notEq = ech.getFreshLabel("noteq")
      // Local slot holding the argument once cast to this class.
      val castSlot = ech.getFreshVar

      // If references are equal, trees are equal.
      ech << ALoad(0) << ALoad(1) << If_ACmpNe(notRefEq) << Ldc(1) << IRETURN << Label(notRefEq)

      // We check the type (this also checks against null)....
      ech << ALoad(1) << InstanceOf(cName) << IfEq(notEq)

      // ...finally, we compare fields one by one, shortcircuiting on disequalities.
      if(!ccd.fields.isEmpty) {
        ech << ALoad(1) << CheckCast(cName) << AStore(castSlot)

        val namesTypes = ccd.fields.map { vd => (vd.id.name, typeToJVM(vd.tpe)) }
        
        for((nme, jvmt) <- namesTypes) {
          ech << ALoad(0) << GetField(cName, nme, jvmt)
          ech << ALoad(castSlot) << GetField(cName, nme, jvmt)

          // Int-like fields are compared directly, references through equals.
          jvmt match {
            case "I" | "Z" =>
              ech << If_ICmpNe(notEq)

            case ot =>
              ech << InvokeVirtual("java/lang/Object", "equals", "(Ljava/lang/Object;)Z") << IfEq(notEq)
          }
        }
      } 

      // All fields matched: return true. The `notEq` label returns false.
      ech << Ldc(1) << IRETURN << Label(notEq) << Ldc(0) << IRETURN
      ech.freeze
    }

    // definition of hashcode
    locally {
      val hmh = cf.addMethod("I", "hashCode", "")
      hmh.setFlags((
        METHOD_ACC_PUBLIC |
        METHOD_ACC_FINAL
      ).asInstanceOf[U2])

      val hch = hmh.codeHandler
      // TODO FIXME. Look at Scala for inspiration.
      // Every instance currently hashes to the same constant.
      hch << Ldc(42) << IRETURN
      hch.freeze
    }

    cf
  }
}