package leon
package solvers
package smtlib

import utils.Interruptible
import purescala._
import Common._
import Trees.{Assert => _, _}
import Extractors._
import TreeOps._
import TypeTrees._
import Definitions._
import utils.Bijection

import _root_.smtlib.common._
import _root_.smtlib.printer.{RecursivePrinter => SMTPrinter}
import _root_.smtlib.parser.Commands.{Constructor => SMTConstructor, _}
import _root_.smtlib.parser.Terms.{Identifier => SMTIdentifier, Let => SMTLet, _}
import _root_.smtlib.parser.CommandsResponses.{Error => ErrorResponse, _}
import _root_.smtlib.theories._
import _root_.smtlib.{Interpreter => SMTInterpreter}

trait SMTLIBTarget {
  this: SMTLIBSolver =>

  // Reporter taken from the solver's context; used here for debug output and fatal errors.
  val reporter = context.reporter

  // Short name of the concrete SMT target; it appears in dump file names and error messages.
  def targetName: String
  // Factory for the interpreter that commands are evaluated against; supplied by the concrete target.
  def getNewInterpreter(): SMTInterpreter

  // The interpreter instance every `sendCommand` call goes through.
  // NOTE(review): evaluated during trait initialisation, so `getNewInterpreter()` presumably must
  // not depend on subclass fields that are initialised later -- confirm against implementers.
  val interpreter = getNewInterpreter()

  // Writer for the SMT-LIB dump; only assigned inside the debug callback below.
  var out: java.io.FileWriter = _

  reporter.ifDebug { debug =>
    // The dump is named after the first input file (or "NA") and a per target/file counter.
    val inputName = context.files.headOption.fold("NA")(_.getName)
    val index     = VCNumbers.getNext(targetName+inputName)

    // Create the output directory on first use.
    val vcsDir = new java.io.File("vcs")
    if (!vcsDir.isDirectory) vcsDir.mkdir

    // Opened in append mode.
    out = new java.io.FileWriter(s"vcs/$targetName-$inputName-$index.smt2", true)
  }

  /** Builds the SMT symbol for a Leon identifier: its name and global id joined by `!`. */
  def id2sym(id: Identifier): SSymbol = SSymbol(s"${id.name}!${id.globalId}")

  // Two-way lookup tables between Leon-side entities and the SMT symbols/sorts declared for them.
  // constructors/selectors/testers are filled in by declareDatatypes; selectors are keyed by
  // (owning type, field index). variables, sorts and functions are filled in by
  // declareVariable, declareSort/declareDatatypes and declareFunction respectively.
  val constructors = new Bijection[TypeTree, SSymbol]()
  val selectors    = new Bijection[(TypeTree, Int), SSymbol]()
  val testers      = new Bijection[TypeTree, SSymbol]()
  val variables    = new Bijection[Identifier, SSymbol]()
  val sorts        = new Bijection[TypeTree, Sort]()
  val functions    = new Bijection[TypedFunDef, SSymbol]()

  /**
   * Canonical form of a type for sort lookup: a class type with a parent is
   * replaced by that parent, and tuple component types are normalized
   * recursively. Every other type is returned unchanged.
   */
  def normalizeType(t: TypeTree): TypeTree = t match {
    case TupleType(bases) => TupleType(bases.map(normalizeType))
    case ct: ClassType    => ct.parent.getOrElse(ct)
    case other            => other
  }

  // Type of an SMT array (a total map from `from` to `to`); this is NOT a Leon/Scala array.
  // Internal to the SMT translation: it should NEVER escape past SMT-world.
  case class RawArrayType(from: TypeTree, to: TypeTree) extends TypeTree

  // A raw SMT array value: explicit entries (`elems`) plus a `default` for all other keys.
  // It is coerced to a Leon expression depending on the target type (see fromRawArray).
  // Internal to the SMT translation: it should NEVER escape past SMT-world.
  case class RawArrayValue(keyTpe: TypeTree, elems: Map[Expr, Expr], default: Expr) extends Expr

  /**
   * Coerces a raw SMT array value to the Leon expression expected for `tpe`.
   *
   * A set type yields the finite set of the array's keys (the array must
   * default to `false` and be keyed by the set's base type); a raw array type
   * yields the value untouched. Any other target type is reported as unsupported.
   */
  def fromRawArray(r: RawArrayValue, tpe: TypeTree): Expr = tpe match {
    case SetType(base) =>
      assert(r.default == BooleanLiteral(false) && r.keyTpe == base)
      val members = r.elems.keySet
      FiniteSet(members).setType(tpe)

    case _: RawArrayType =>
      r

    case _ =>
      unsupported("Unable to extract from raw array for "+tpe)
  }

  /** Reports `str` as a fatal "unsupported" error, tagged with the target name. */
  def unsupported(str: Any) = reporter.fatalError("Unsupported in smt-" + targetName + ": " + str)

  /**
   * Returns the SMT sort for `t`, declaring it to the solver on first use.
   * The type is normalized first and the result is cached in `sorts`.
   */
  def declareSort(t: TypeTree): Sort = {
    val normalized = normalizeType(t)

    // Only evaluated when `normalized` has no cached sort yet.
    def freshSort: Sort = normalized match {
      case BooleanType => Core.BoolSort()
      case Int32Type   => Ints.IntSort()
      case CharType    => FixedSizeBitVectors.BitVectorSort(32)

      case RawArrayType(keyTpe, valueTpe) =>
        Sort(SMTIdentifier(SSymbol("Array")), Seq(declareSort(keyTpe), declareSort(valueTpe)))

      case MapType(keyTpe, valueTpe) =>
        declareMapSort(keyTpe, valueTpe)

      case TypeParameter(id) =>
        // Type parameters become uninterpreted sorts of arity 0.
        val sortSym = id2sym(id)
        sendCommand(DeclareSort(sortSym, 0))
        Sort(SMTIdentifier(sortSym))

      case _: ClassType | _: TupleType | _: ArrayType | UnitType =>
        declareStructuralSort(normalized)

      case _ =>
        unsupported("Sort "+t)
    }

    sorts.cachedB(normalized)(freshSort)
  }

  // Symbols of the parametric Map sort and Option datatype; None until declareMapSort /
  // declareOptionSort has sent the corresponding definition to the solver.
  var mapSort: Option[SSymbol] = None
  var optionSort: Option[SSymbol] = None

  /**
   * Returns the sort `Option[of]`. The first call also sends the parametric
   * `Option` datatype (constructors `Some`, with field `Some_v`, and `None`)
   * to the solver.
   */
  def declareOptionSort(of: TypeTree): Sort = {
    if (optionSort.isEmpty) {
      val tparam   = SSymbol("T")
      val optName  = SSymbol("Option")
      val someCase = SList(SSymbol("Some"), SList(SSymbol("Some_v"), tparam))
      val noneCase = SList(SSymbol("None"))

      // Parametric datatypes are sent as a raw s-expression command.
      val declaration = NonStandardCommand(
        SList(SSymbol("declare-datatypes"), SList(tparam), SList(SList(optName, someCase, noneCase)))
      )
      sendCommand(declaration)

      optionSort = Some(optName)
    }

    val optionId = SMTIdentifier(optionSort.get)
    Sort(optionId, Seq(declareSort(of)))
  }

  /**
   * Returns the sort `Map[from, to]`. The first call also defines the
   * parametric sort `Map A B` as `Array A (Option B)`.
   *
   * The definition must refer to the sort parameter `B`, not to the value sort
   * of whichever map type happens to be declared first: otherwise every later
   * map with a different value type would be given the wrong element sort.
   */
  def declareMapSort(from: TypeTree, to: TypeTree): Sort = {
    mapSort match {
      case None =>
        val m = SSymbol("Map")
        val a = SSymbol("A")
        val b = SSymbol("B")
        mapSort = Some(m)

        // Called only to make sure the parametric Option datatype exists (this sets
        // `optionSort`); the instantiated sort it returns is deliberately not used.
        declareOptionSort(to)

        val optionOfB = Sort(SMTIdentifier(optionSort.get), Seq(Sort(SMTIdentifier(b))))

        val arraySort = Sort(SMTIdentifier(SSymbol("Array")),
                             Seq(Sort(SMTIdentifier(a)), optionOfB))

        val cmd = DefineSort(m, Seq(a, b), arraySort)
        sendCommand(cmd)
      case _ =>
    }

    Sort(SMTIdentifier(mapSort.get), Seq(declareSort(from), declareSort(to)))
  }



  /** Fresh SMT symbol reusing the name of `id`. */
  def freshSym(id: Identifier): SSymbol = id2sym(FreshIdentifier(id.name))

  /** Fresh SMT symbol derived from a newly created identifier called `name`. */
  def freshSym(name: String): SSymbol = {
    val fresh = FreshIdentifier(name)
    id2sym(fresh)
  }

  /**
   * Returns the root of the class hierarchy `ct` belongs to, together with the
   * case classes of that hierarchy. A case class without a parent is its own
   * root and only member.
   */
  def getHierarchy(ct: ClassType): (ClassType, Seq[CaseClassType]) = ct match {
    case cct: CaseClassType =>
      cct.parent.map(p => getHierarchy(p)).getOrElse((cct, List(cct)))

    case act: AbstractClassType =>
      (act, act.knownCCDescendents)
  }


  // Description of one SMT datatype to declare: its sort symbol and its constructors.
  case class DataType(sym: SSymbol, cases: Seq[Constructor])
  // One datatype constructor: its symbol, the Leon type it builds, and its (selector symbol, field type) pairs.
  case class Constructor(sym: SSymbol, tpe: TypeTree, fields: Seq[(SSymbol, TypeTree)])

  /**
   * Collects the datatypes that must be declared together for the structural
   * type `t` (class hierarchies, tuples, unit, arrays), following field and
   * component types transitively.
   *
   * @param t   the type whose datatype dependencies are wanted
   * @param dts datatypes already collected during this traversal
   * @return `dts` extended with every datatype that is neither in `dts` nor
   *         already present in `sorts`
   */
  def findDependencies(t: TypeTree, dts: Map[TypeTree, DataType] = Map()): Map[TypeTree, DataType] = {
    // A type needs a datatype only if it was neither collected in this pass nor declared earlier.
    def isNew(tpe: TypeTree): Boolean = !((dts contains tpe) || (sorts containsA tpe))

    t match {
      case ct: ClassType =>
        val (root, descendants) = getHierarchy(ct)

        if (isNew(root)) {
          val dtSym = freshSym(ct.id)

          val cases = descendants.map { cct =>
            Constructor(freshSym(cct.id), cct, cct.fields.map(vd => (freshSym(vd.id), vd.tpe)))
          }

          val seeded = dts + (root -> DataType(dtSym, cases))

          // Visit the type of every field in the hierarchy, threading the accumulated map.
          (root +: descendants).flatMap(_.fields).foldLeft(seeded) { (acc, field) =>
            acc ++ findDependencies(field.tpe, acc)
          }
        } else {
          dts
        }

      case tt @ TupleType(bases) =>
        if (isNew(t)) {
          val dtSym   = freshSym("tuple"+bases.size)
          // The constructor symbol is created before the field symbols.
          val consSym = freshSym(dtSym.name)
          val fields  = bases.zipWithIndex.map {
            case (tpe, i) => (freshSym("_"+(i+1)), tpe)
          }

          val seeded = dts + (tt -> DataType(dtSym, Seq(Constructor(consSym, tt, fields))))

          bases.foldLeft(seeded) { (acc, base) =>
            acc ++ findDependencies(base, acc)
          }
        } else {
          dts
        }

      case UnitType =>
        if (isNew(t)) {
          val dtSym   = freshSym("Unit")
          val consSym = freshSym(dtSym.name)

          dts + (t -> DataType(dtSym, Seq(Constructor(consSym, t, Nil))))
        } else {
          dts
        }

      case at @ ArrayType(base) =>
        if (isNew(t)) {
          val dtSym      = freshSym("array")
          // An array is a record holding its size and a raw SMT array with its content.
          val consSym    = freshSym(dtSym.name)
          val sizeSym    = freshSym("size")
          val contentSym = freshSym("content")

          val cons = Constructor(consSym, at, List(
            (sizeSym, Int32Type),
            (contentSym, RawArrayType(Int32Type, base))
          ))

          findDependencies(base, dts + (at -> DataType(dtSym, Seq(cons))))
        } else {
          dts
        }

      case _ =>
        dts
    }
  }

  /**
   * Declares a group of datatypes to the solver in a single command and
   * records their sorts, constructors, testers and selectors.
   */
  def declareDatatypes(datatypes: Map[TypeTree, DataType]): Unit = {
    // Register all sorts up front, before any field sort is looked up below.
    datatypes.foreach { case (tpe, dt) =>
      sorts += tpe -> Sort(SMTIdentifier(dt.sym))
    }

    // Records the symbols of one constructor and builds its SMT declaration.
    def toDecl(c: Constructor): SMTConstructor = {
      testers += c.tpe -> SSymbol("is-"+c.sym.name)
      constructors += c.tpe -> c.sym

      val fieldDecls = c.fields.zipWithIndex.map {
        case ((fieldSym, fieldTpe), idx) =>
          selectors += (c.tpe, idx) -> fieldSym
          (fieldSym, declareSort(fieldTpe))
      }

      SMTConstructor(c.sym, fieldDecls)
    }

    val adts = datatypes.toList.map {
      case (_, DataType(sym, cases)) => (sym, cases.map(toDecl))
    }

    sendCommand(DeclareDatatypes(adts))
  }

  /**
   * Declares the datatype for the structural type `t` along with every
   * datatype it depends on, then returns the sort registered for `t`.
   */
  def declareStructuralSort(t: TypeTree): Sort = {
    declareDatatypes(findDependencies(t))
    sorts.toB(t)
  }

  /**
   * Returns the SMT symbol of variable `id`, declaring it as a zero-argument
   * function of the variable's sort the first time it is requested.
   */
  def declareVariable(id: Identifier): SSymbol = {
    variables.cachedB(id) {
      val sym = id2sym(id)
      sendCommand(DeclareFun(sym, List(), declareSort(id.getType)))
      sym
    }
  }

  /**
   * Returns the SMT symbol of the typed function `tfd`, declaring it on first
   * use. A function with type arguments is declared under a fresh identifier
   * carrying the same name.
   */
  def declareFunction(tfd: TypedFunDef): SSymbol = {
    functions.cachedB(tfd) {
      val declId = if (tfd.tps.nonEmpty) FreshIdentifier(tfd.id.name) else tfd.id
      val sym    = id2sym(declId)

      val paramSorts = tfd.params.map(p => declareSort(p.tpe))
      val resultSort = declareSort(tfd.returnType)

      sendCommand(DeclareFun(sym, paramSorts, resultSort))
      sym
    }
  }

  /**
   * Translates a Leon expression into an SMT-LIB term, declaring any sort or
   * function it needs along the way.
   *
   * @param e        the expression to translate
   * @param bindings terms to use for let-bound identifiers; identifiers not
   *                 found here are looked up in `variables`
   * @return the corresponding term; trees with no translation are reported
   *         through `unsupported` / `reporter.fatalError`
   */
  def toSMT(e: Expr)(implicit bindings: Map[Identifier, Term]): Term = {
    e match {
      case Variable(id) =>
        declareSort(e.getType)
        bindings.getOrElse(id, variables.toB(id))

      case UnitLiteral() =>
        declareSort(UnitType)

        declareVariable(FreshIdentifier("Unit").setType(UnitType))

      case IntLiteral(i) =>
        // SMT-LIB numerals are non-negative, so negative literals are emitted as (- n).
        // The negation happens on a BigInt because `-i` overflows for Int.MinValue.
        if (i >= 0) Ints.NumeralLit(i) else Ints.Neg(Ints.NumeralLit(-BigInt(i)))
      case CharLiteral(c) => FixedSizeBitVectors.BitVectorLit(Hexadecimal.fromInt(c.toInt))
      case BooleanLiteral(v) => Core.BoolConst(v)
      case StringLiteral(s) => SString(s)
      case Let(b,d,e) =>
        val id      = id2sym(b)
        val value   = toSMT(d)
        // The body sees the bound identifier as the let-bound symbol.
        val newBody = toSMT(e)(bindings + (b -> id))

        SMTLet(
          VarBinding(id, value),
          Seq(),
          newBody
        )

      case er @ Error(_) =>
        // An error is modelled as a fresh variable of the expected type.
        val s = declareVariable(FreshIdentifier("error_value").setType(er.getType))
        s

      case s @ CaseClassSelector(cct, e, id) =>
        declareSort(cct)
        val selector = selectors.toB((cct, s.selectorIndex))
        FunctionApplication(selector, Seq(toSMT(e)))

      case CaseClassInstanceOf(cct, e) =>
        declareSort(cct)
        val tester = testers.toB(cct)
        FunctionApplication(tester, Seq(toSMT(e)))

      case CaseClass(cct, es) =>
        declareSort(cct)
        val constructor = constructors.toB(cct)
        if (es.isEmpty) {
          // A constructor without fields is a constant, not an application.
          constructor
        } else {
          FunctionApplication(constructor, es.map(toSMT))
        }

      case t @ Tuple(es) =>
        val tpe = normalizeType(t.getType)
        declareSort(tpe)
        val constructor = constructors.toB(tpe)
        FunctionApplication(constructor, es.map(toSMT))

      case ts @ TupleSelect(t, i) =>
        val tpe = normalizeType(t.getType)
        declareSort(tpe)
        // Tuple selection is 1-based, selectors are 0-based.
        val selector = selectors.toB((tpe, i-1))
        FunctionApplication(selector, Seq(toSMT(t)))

      // Arrays are datatypes with two fields: the size (index 0) and the content (index 1).
      // As for tuples and case classes, the sort is declared before its symbols are looked up.
      case al @ ArrayLength(a) =>
        val tpe = normalizeType(a.getType)
        declareSort(tpe)
        val selector = selectors.toB((tpe, 0))

        FunctionApplication(selector, Seq(toSMT(a)))

      case al @ ArraySelect(a, i) =>
        val tpe = normalizeType(a.getType)
        declareSort(tpe)

        val scontent = FunctionApplication(selectors.toB((tpe, 1)), Seq(toSMT(a)))

        ArraysEx.Select(scontent, toSMT(i))

      case al @ ArrayUpdated(a, i, e) =>
        val tpe = normalizeType(a.getType)
        declareSort(tpe)

        val sa = toSMT(a)
        val ssize    = FunctionApplication(selectors.toB((tpe, 0)), Seq(sa))
        val scontent = FunctionApplication(selectors.toB((tpe, 1)), Seq(sa))

        val newcontent = ArraysEx.Store(scontent, toSMT(i), toSMT(e))

        // Rebuild the array with the same size and the updated content.
        val constructor = constructors.toB(tpe)
        FunctionApplication(constructor, Seq(ssize, newcontent))

      /**
       * ===== Map operations =====
       */
      case m @ FiniteMap(elems) =>
        val mt @ MapType(from, to) = m.getType
        val ms = declareSort(mt)

        val opt = declareOptionSort(to)

        // Start from the constant array mapping every key to None, then store each entry as Some(value).
        var res: Term = FunctionApplication(QualifiedIdentifier(SMTIdentifier(SSymbol("const")), Some(ms)), List(QualifiedIdentifier(SMTIdentifier(SSymbol("None")), Some(opt))))
        for ((k, v) <- elems) {
          res = ArraysEx.Store(res, toSMT(k), FunctionApplication(SSymbol("Some"), List(toSMT(v))))
        }

        res

      case MapGet(m, k) =>
        declareSort(m.getType)
        FunctionApplication(SSymbol("Some_v"), List(ArraysEx.Select(toSMT(m), toSMT(k))))

      case MapIsDefinedAt(m, k) =>
        declareSort(m.getType)
        FunctionApplication(SSymbol("is-Some"), List(ArraysEx.Select(toSMT(m), toSMT(k))))

      /**
       * ===== Everything else =====
       */

      case e @ UnaryOperator(u, _) =>
        e match {
          case (_: Not) => Core.Not(toSMT(u))
          case (_: UMinus) => Ints.Neg(toSMT(u))
          case _ => reporter.fatalError("Unhandled unary "+e)
        }

      case e @ BinaryOperator(a, b, _) =>
        e match {
          case (_: Equals) => Core.Equals(toSMT(a), toSMT(b))
          case (_: Implies) => Core.Implies(toSMT(a), toSMT(b))
          case (_: Iff) => Core.Equals(toSMT(a), toSMT(b))
          case (_: Plus) => Ints.Add(toSMT(a), toSMT(b))
          case (_: Minus) => Ints.Sub(toSMT(a), toSMT(b))
          case (_: Times) => Ints.Mul(toSMT(a), toSMT(b))
          case (_: Division) => Ints.Div(toSMT(a), toSMT(b))
          case (_: Modulo) => Ints.Mod(toSMT(a), toSMT(b))
          case (_: LessThan) => Ints.LessThan(toSMT(a), toSMT(b))
          case (_: LessEquals) => Ints.LessEquals(toSMT(a), toSMT(b))
          case (_: GreaterThan) => Ints.GreaterThan(toSMT(a), toSMT(b))
          case (_: GreaterEquals) => Ints.GreaterEquals(toSMT(a), toSMT(b))
          case _ => reporter.fatalError("Unhandled binary "+e)
        }

      case e @ NAryOperator(sub, _) =>
        e match {
          case (_: And) => Core.And(sub.map(toSMT): _*)
          case (_: Or) => Core.Or(sub.map(toSMT): _*)
          case (_: IfExpr) => Core.ITE(toSMT(sub(0)), toSMT(sub(1)), toSMT(sub(2)))
          case (f: FunctionInvocation) =>
            FunctionApplication(
              declareFunction(f.tfd),
              sub.map(toSMT)
            )
          case _ => reporter.fatalError("Unhandled nary "+e)
        }


      case o => unsupported("Tree: " + o)
    }
  }

  /** Tupled variant of `fromSMT(term, tpe)`, convenient when mapping over zipped (term, type) pairs. */
  def fromSMT(pair: (Term, TypeTree))(implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = {
    val (term, expectedTpe) = pair
    fromSMT(term, expectedTpe)
  }

  /**
   * Converts a term returned by the solver back into a Leon expression of the
   * expected type.
   *
   * @param s       the SMT term to convert
   * @param tpe     the Leon type the result is expected to have
   * @param lets    let-bound symbols currently in scope, with their defining terms
   * @param letDefs function definitions in scope (only passed along here)
   * @return the extracted expression; terms that cannot be extracted are
   *         reported through `unsupported`
   */
  def fromSMT(s: Term, tpe: TypeTree)(implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = (s, tpe) match {
    // Unit has a single value, so the term itself is irrelevant.
    case (_, UnitType) =>
      UnitLiteral()

    case (SHexadecimal(h), CharType) =>
      CharLiteral(h.toInt.toChar)

    case (SNumeral(n), Int32Type) =>
      IntLiteral(n.toInt)

    case (Core.True(), BooleanType)  => BooleanLiteral(true)
    case (Core.False(), BooleanType)  => BooleanLiteral(false)

    case (SHexadecimal(hexa), Int32Type) => IntLiteral(hexa.toInt)

    // A bare constructor symbol denotes a case class without fields.
    case (SimpleSymbol(s), _: ClassType) if constructors.containsB(s) =>
      constructors.toA(s) match {
        case cct: CaseClassType =>
          CaseClass(cct, Nil)
        case t =>
          unsupported("woot? for a single constructor for non-case-object: "+t)
      }

    case (SimpleSymbol(s), tpe) if lets contains s =>
      fromSMT(lets(s), tpe)

    case (SimpleSymbol(s), _) =>
      variables.getA(s).map(_.toVariable).getOrElse {
        unsupported("Unknown symbol: "+s)
      }

    case (FunctionApplication(SimpleSymbol(s), args), tpe) if constructors.containsB(s) =>
      constructors.toA(s) match {
        case cct: CaseClassType =>
          val rargs = args.zip(cct.fields.map(_.tpe)).map(fromSMT)
          CaseClass(cct, rargs)
        case tt: TupleType =>
          val rargs = args.zip(tt.bases).map(fromSMT)
          Tuple(rargs)

        case at: ArrayType =>
          // Arrays are encoded as (size, raw content); rebuild the finite array from both.
          val IntLiteral(size)                 = fromSMT(args(0), Int32Type)
          val RawArrayValue(_, elems, default) = fromSMT(args(1), RawArrayType(Int32Type, at.base))

          val entries = for (i <- 0 until size) yield elems.getOrElse(IntLiteral(i), default)

          FiniteArray(entries).setType(at)

        case t =>
          unsupported("Woot? structural type that is non-structural: "+t)
      }

    // EK: Since we have no type information, we cannot do type-directed
    // extraction of defs, instead, we expand them in smt-world
    case (SMTLet(binding, bindings, body), tpe) =>
      val defsMap: Map[SSymbol, Term] = (binding +: bindings).map {
        case VarBinding(s, value) => (s, value)
      }.toMap

      fromSMT(body, tpe)(lets ++ defsMap, letDefs)

    case (FunctionApplication(SimpleSymbol(SSymbol(app)), args), tpe) => {
      app match {
        case "-" =>
          // `args` is matched as a Seq so that any sequence implementation is accepted,
          // and an unexpected arity is reported instead of raising a MatchError.
          args match {
            case Seq(a) => UMinus(fromSMT(a, Int32Type))
            case Seq(a, b) => Minus(fromSMT(a, Int32Type), fromSMT(b, Int32Type))
            case _ => unsupported("Function "+app+" not handled in fromSMT: "+s)
          }

        case _ =>
          unsupported("Function "+app+" not handled in fromSMT: "+s)
      }
    }
    case (QualifiedIdentifier(id, sort), tpe) =>
      unsupported("Unhandled case in fromSMT: " + id +": "+sort +" ("+tpe+")")
    case _ =>
      unsupported("Unhandled case in fromSMT: " + (s, tpe))
  }

  /**
   * Sends `cmd` to the solver and returns its response. In debug mode the
   * command is also appended to the dump file.
   *
   * An error response is fatal unless the solver has been interrupted, in
   * which case it is returned to the caller like any other response.
   */
  def sendCommand(cmd: Command): CommandResponse = {
    reporter.ifDebug { debug =>
      SMTPrinter.printCommand(cmd, out)
      out.write("\n")
      out.flush
    }

    interpreter.eval(cmd) match {
      case err@ErrorResponse(msg) if !interrupted =>
        reporter.fatalError("Unexpected error from smt-"+targetName+" solver: "+msg)
      case res => res
    }
  }

  /** Asserts `expr` in the solver, after declaring every variable that occurs in it. */
  override def assertCnstr(expr: Expr): Unit = {
    for (id <- variablesOf(expr)) declareVariable(id)
    sendCommand(Assert(toSMT(expr)(Map())))
  }

  /**
   * Runs a satisfiability check: `Some(true)` for sat, `Some(false)` for
   * unsat, and `None` for unknown or any other response.
   */
  override def check: Option[Boolean] = {
    val response = sendCommand(CheckSat())

    response match {
      case CheckSatResponse(SatStatus)   => Some(true)
      case CheckSatResponse(UnsatStatus) => Some(false)
      case _                             => None
    }
  }

  /**
   * Queries the solver for the value of every declared variable and converts
   * the answers back to Leon expressions.
   *
   * @return the model, keyed by Leon identifier; empty when no variable has
   *         been declared (no query is sent in that case)
   */
  override def getModel: Map[Identifier, Expr] = {
    variables.bSet.toList match {
      case Nil =>
        // Nothing to ask about: taking the head of the empty list would throw.
        Map.empty[Identifier, Expr]

      case first :: rest =>
        val cmd: Command =
          GetValue(first,
                   rest.map(s => QualifiedIdentifier(SMTIdentifier(s)))
                  )

        val GetValueResponse(valuationPairs) = sendCommand(cmd)

        // Keep only the answers that correspond to variables declared here.
        valuationPairs.collect {
          case (SimpleSymbol(sym), value) if variables.containsB(sym) =>
            val id = variables.toA(sym)

            (id, fromSMT(value, id.getType)(Map(), Map()))
        }.toMap
    }
  }

  /** Pushes one assertion level on the solver. */
  override def push(): Unit = {
    val cmd = Push(1)
    sendCommand(cmd)
  }
  /**
   * Pops `lvl` assertion levels from the solver.
   *
   * @param lvl number of levels to pop (defaults to 1)
   */
  override def pop(lvl: Int = 1): Unit = {
    // Forward the requested depth rather than always popping a single level.
    sendCommand(Pop(lvl))
  }

  /** Extractor for terms that are a plain symbol: an identifier with no indices and no sort annotation. */
  protected object SimpleSymbol {

    def unapply(term: Term): Option[SSymbol] =
      Some(term).collect {
        case QualifiedIdentifier(SMTIdentifier(sym, Seq()), None) => sym
      }

  }

  import scala.language.implicitConversions
  /** Lets a bare symbol be used wherever a term is expected, as an unannotated identifier. */
  implicit def symbolToQualifiedId(s: SSymbol): QualifiedIdentifier = {
    val identifier = SMTIdentifier(s)
    QualifiedIdentifier(identifier)
  }

}

/** Hands out increasing sequence numbers, counted independently per key and starting at 1. */
object VCNumbers {
  // Last number handed out for each key; a missing key counts as 0.
  private var counters = Map.empty[String, Int]

  /** Returns the next number for `id`. */
  def getNext(id: String) = {
    val next = counters.getOrElse(id, 0) + 1
    counters = counters.updated(id, next)
    next
  }
}