/* Copyright 2009-2016 EPFL, Lausanne */

package leon
package solvers
package unrolling

import purescala.Common._
import purescala.Expressions._
import purescala.Extractors._
import purescala.ExprOps._
import purescala.Types._
import purescala.TypeOps.bestRealType
import purescala.Definitions._
import purescala.Constructors._
import purescala.Quantification._

import theories._
import utils.SeqUtils._
import Instantiation._

/** Builds unrolling templates ([[FunctionTemplate]], [[LambdaTemplate]] and
  * [[QuantificationTemplate]]) from Leon expressions and function definitions.
  *
  * Expressions are flattened into "clauses":
  *  - fresh condition variables organized in a tree rooted at a path variable,
  *  - fresh expression variables, and
  *  - boolean formulas that hold whenever a given guard variable is true.
  *
  * Lambdas and universal quantifications encountered along the way become their
  * own sub-templates.
  *
  * @param theories       encoder applied to raw expressions in `mkTemplate(raw)`
  *                       before a template is built for them
  * @param encoder        encodes identifiers into values of type `T`
  * @param assumePreHolds if true, a function's postcondition clause is
  *                       `pre && post`; otherwise it is `pre ==> post`
  * @tparam T the type produced by `encoder` for encoded identifiers
  */
class TemplateGenerator[T](val theories: TheoryEncoder,
                           val encoder: TemplateEncoder[T],
                           val assumePreHolds: Boolean) {
  // Templates already built for typed function definitions (keyed by `tfd` only).
  private var cache     = Map[TypedFunDef, FunctionTemplate[T]]()
  // Templates already built for raw expressions, together with the mapping from the
  // expression's free variables to their theory-encoded counterparts.
  private var cacheExpr = Map[Expr, (FunctionTemplate[T], Map[Identifier, Identifier])]()

  /* Clause sets produced while flattening expressions, in order:
   *  1. condition variables and their encodings,
   *  2. expression variables and their encodings,
   *  3. the condition tree (path variable -> condition variables introduced under it),
   *  4. guarded expressions (guard variable -> formulas that hold when it is true),
   *  5. lambda templates,
   *  6. quantification templates.
   */
  private type Clauses = (
    Map[Identifier,T],
    Map[Identifier,T],
    Map[Identifier, Set[Identifier]],
    Map[Identifier, Seq[Expr]],
    Seq[LambdaTemplate[T]],
    Seq[QuantificationTemplate[T]]
  )

  private def emptyClauses: Clauses = (Map.empty, Map.empty, Map.empty, Map.empty, Seq.empty, Seq.empty)

  // Componentwise concatenation of clause sets: `++` for variables and templates,
  // `merge` for the condition tree and the guarded expressions.
  private implicit class ClausesWrapper(clauses: Clauses) {
    def ++(that: Clauses): Clauses = {
      val (thisConds, thisExprs, thisTree, thisGuarded, thisLambdas, thisQuants) = clauses
      val (thatConds, thatExprs, thatTree, thatGuarded, thatLambdas, thatQuants) = that

      (thisConds ++ thatConds, thisExprs ++ thatExprs, thisTree merge thatTree,
        thisGuarded merge thatGuarded, thisLambdas ++ thatLambdas, thisQuants ++ thatQuants)
    }
  }

  // Shared by every function, lambda and quantification template built here.
  val manager = new QuantificationManager[T](encoder)

  /** Builds (or fetches from the cache) a template for a standalone expression.
    *
    * The free variables of `raw` are theory-encoded and become the parameters of a
    * fresh synthetic function whose body is the encoded `raw`. That function's
    * template is then built with `isRealFunDef = false`.
    *
    * @return the template, and the mapping from the free variables of `raw` to the
    *         identifiers used as the synthetic function's parameters
    */
  def mkTemplate(raw: Expr): (FunctionTemplate[T], Map[Identifier, Identifier]) =
    cacheExpr.get(raw) match {
      case Some(cached) =>
        cached

      case None =>
        val mapping = variablesOf(raw).map(id => id -> theories.encode(id)).toMap
        val body = theories.encode(raw)(mapping)

        val arguments = mapping.values.toSeq.map(ValDef(_))
        val fakeFunDef = new FunDef(FreshIdentifier("fake", alwaysShowUniqueID = true), Nil, arguments, body.getType)

        // Each parameter is constrained by the manager's type unroller.
        fakeFunDef.precondition = Some(andJoin(arguments.map(vd => manager.typeUnroller(vd.toVariable))))
        fakeFunDef.body = Some(body)

        val res = mkTemplate(fakeFunDef.typed, isRealFunDef = false)
        val p = (res, mapping)
        cacheExpr += raw -> p
        p
    }

  /** Builds (or fetches from the cache) the template of a typed function definition.
    *
    * The cache is keyed by `tfd` only, so `isRealFunDef` is ignored on a cache hit.
    *
    * @param tfd          the function whose template is requested
    * @param isRealFunDef `true` for program functions: the template encodes
    *                     `f(args) == body` (guarded by the precondition, if any).
    *                     `false` for the synthetic functions created by `mkTemplate(raw)`:
    *                     their precondition and body are asserted directly under the
    *                     template's start variable.
    */
  def mkTemplate(tfd: TypedFunDef, isRealFunDef: Boolean = true): FunctionTemplate[T] =
    cache.get(tfd) match {
      case Some(template) =>
        template

      case None =>
        val template = buildFunctionTemplate(tfd, isRealFunDef)
        cache += tfd -> template
        template
    }

  /** Builds a fresh [[FunctionTemplate]] for `tfd` without consulting the cache. */
  private def buildFunctionTemplate(tfd: TypedFunDef, isRealFunDef: Boolean): FunctionTemplate[T] = {
    // The precondition if it exists.
    val prec : Option[Expr] = tfd.precondition.map(p => simplifyHOFunctions(matchToIfThenElse(p)))

    val newBody : Option[Expr] = tfd.body.map(b => matchToIfThenElse(b))
    val lambdaBody : Option[Expr] = newBody.map(b => simplifyHOFunctions(b))

    val funDefArgs: Seq[Identifier] = tfd.paramIds
    // Fresh parameters for the chain of lambdas the body evaluates to (empty if none).
    val lambdaArguments: Seq[Identifier] = lambdaBody.map(lambdaArgs).toSeq.flatten
    val invocation : Expr = FunctionInvocation(tfd, funDefArgs.map(_.toVariable))

    // For real functions: `f(args) == body`, plus that equality lifted through every
    // level of the returned lambda chain, each implied by the precondition if present.
    val invocationEqualsBody : Seq[Expr] = lambdaBody match {
      case Some(body) if isRealFunDef =>
        val bs = liftedEquals(invocation, body, lambdaArguments) :+ Equals(invocation, body)
        prec.fold(bs)(p => bs.map(Implies(p, _)))

      case _ =>
        Seq.empty
    }

    val start : Identifier = FreshIdentifier("start", BooleanType, true)
    val pathVar : (Identifier, T) = start -> encoder.encodeId(start)

    val allArguments : Seq[Identifier] = funDefArgs ++ lambdaArguments
    val arguments : Seq[(Identifier, T)] = allArguments.map(id => id -> encoder.encodeId(id))

    val substMap : Map[Identifier, T] = arguments.toMap + pathVar

    // Synthetic functions (from `mkTemplate(raw)`) always have a body.
    val bodyFormulas : Seq[Expr] =
      if (isRealFunDef) invocationEqualsBody else prec.toSeq :+ lambdaBody.get

    val bodyClauses : Clauses =
      bodyFormulas.foldLeft(emptyClauses)((clsSet, cls) => clsSet ++ mkClauses(start, cls, substMap))

    // Now the postcondition.
    val allClauses : Clauses = tfd.postcondition match {
      case Some(post) =>
        val newPost : Expr = simplifyHOFunctions(application(matchToIfThenElse(post), Seq(invocation)))

        val postHolds : Expr = prec match {
          case Some(pre) if assumePreHolds => And(pre, newPost)
          case Some(pre) => Implies(pre, newPost)
          case None => newPost
        }

        bodyClauses ++ mkClauses(start, postHolds, substMap)

      case None =>
        bodyClauses
    }

    val (condVars, exprVars, condTree, guardedExprs, lambdas, quantifications) = allClauses

    FunctionTemplate(tfd, encoder, manager,
      pathVar, arguments, condVars, exprVars, condTree, guardedExprs, quantifications, lambdas, isRealFunDef)
  }

  /** Fresh copies of the parameters of every lambda in a chain `x => (y => ... body)`,
    * outermost lambda first. Fails on a function-typed expression that is not a lambda.
    */
  private def lambdaArgs(expr: Expr): Seq[Identifier] = expr match {
    case Lambda(args, body) => args.map(_.id.freshen) ++ lambdaArgs(body)
    case IsTyped(_, _: FunctionType) => sys.error("Only applicable on lambda chains")
    case _ => Seq.empty
  }

  /** Equalities between `invocation` and `body` applied to successive groups of `args`.
    *
    * At each function-typed level of `invocation`'s type, the next `from.size` arguments
    * are applied to both sides and the equality of the results is recorded. Equalities
    * for deeper levels come first in the result.
    *
    * If `inlineFirst` is set, the outermost application is built with the `application`
    * constructor instead of an explicit `Application` node.
    */
  private def liftedEquals(invocation: Expr, body: Expr, args: Seq[Identifier], inlineFirst: Boolean = false): Seq[Expr] = {
    def rec(i: Expr, b: Expr, args: Seq[Identifier], inline: Boolean): Seq[Expr] = i.getType match {
      case FunctionType(from, _) =>
        val (currArgs, nextArgs) = args.splitAt(from.size)
        val arguments = currArgs.map(_.toVariable)
        val apply = if (inline) application _ else Application
        val (appliedInv, appliedBody) = (apply(i, arguments), apply(b, arguments))
        rec(appliedInv, appliedBody, nextArgs, false) :+ Equals(appliedInv, appliedBody)
      case _ =>
        assert(args.isEmpty, "liftedEquals should consume all provided arguments")
        Seq.empty
    }

    rec(invocation, body, args, inlineFirst)
  }

  /** Flattens nested matchers over quantified variables in `conj`.
    *
    * A matcher (as recognized by `QuantificationMatcher`) qualifies when it mentions a
    * quantified variable but has no argument that is itself a quantified variable. In
    * such a matcher, every argument that is a matcher mentioning a quantified variable
    * is replaced by a fresh `flat` variable. Each fresh variable is added to the
    * quantified set, and the result is `(arg1 == flat1 && ...) ==> newConj`.
    *
    * @param inits the initially quantified variables
    * @return the extended set of quantified variables and the flattened conjunct
    */
  private def minimalFlattening(inits: Set[Identifier], conj: Expr): (Set[Identifier], Expr) = {
    var mapping: Map[Expr, Expr] = Map.empty
    var quantified: Set[Identifier] = inits
    var quantifierEqualities: Seq[(Expr, Identifier)] = Seq.empty

    val newConj = postMap {
      // Reuse the rewrite already computed for an identical matcher.
      case expr if mapping.isDefinedAt(expr) =>
        Some(mapping(expr))

      case expr @ QuantificationMatcher(c, args) =>
        val isMatcher = args.exists { case Variable(id) => quantified(id) case _ => false }
        val isRelevant = (variablesOf(expr) & quantified).nonEmpty
        if (!isMatcher && isRelevant) {
          val newArgs = args.map {
            case arg @ QuantificationMatcher(_, _) if (variablesOf(arg) & quantified).nonEmpty =>
              val id = FreshIdentifier("flat", arg.getType)
              quantifierEqualities :+= (arg -> id)
              quantified += id
              Variable(id)
            case arg => arg
          }

          val newExpr = replace((args zip newArgs).toMap, expr)
          mapping += expr -> newExpr
          Some(newExpr)
        } else {
          None
        }

      case _ => None
    } (conj)

    val flatConj = implies(andJoin(quantifierEqualities.map {
      case (arg, id) => Equals(arg, Variable(id))
    }), newConj)

    (quantified, flatConj)
  }

  /** Flattens `expr` into clauses under `pathVar`. The formula standing for `expr`
    * itself is prepended to the expressions guarded by `pathVar`.
    */
  def mkClauses(pathVar: Identifier, expr: Expr, substMap: Map[Identifier, T]): Clauses = {
    val (p, (condVars, exprVars, condTree, guardedExprs, lambdas, quantifications)) = mkExprClauses(pathVar, expr, substMap)
    val allGuarded = guardedExprs + (pathVar -> (p +: guardedExprs.getOrElse(pathVar, Seq.empty)))
    (condVars, exprVars, condTree, allGuarded, lambdas, quantifications)
  }

  /** Core of template generation: flattens `expr` into clauses.
    *
    * The following are replaced by fresh variables whose meaning is recorded in
    * guarded clauses:
    *  - non-lambda `let` bindings,
    *  - non-simple `&&` / `||` / `==>` and `if` expressions,
    *  - `choose`, lambdas and `forall`s.
    *
    * Branches get fresh condition variables registered under their parent in the
    * condition tree. Lambdas and quantifications are turned into [[LambdaTemplate]]s
    * and [[QuantificationTemplate]]s.
    *
    * @param pathVar  guard under which `expr` is evaluated
    * @param substMap encodings of identifiers already known to the caller
    * @return the expression standing for `expr`, and the clauses produced
    */
  private def mkExprClauses(pathVar: Identifier, expr: Expr, substMap: Map[Identifier, T]): (Expr, Clauses) = {

    var condVars = Map[Identifier, T]()
    var condTree = Map[Identifier, Set[Identifier]](pathVar -> Set.empty).withDefaultValue(Set.empty)
    // Registers `id` as a condition variable introduced under `pathVar`.
    def storeCond(pathVar: Identifier, id: Identifier) : Unit = {
      condVars += id -> encoder.encodeId(id)
      condTree += pathVar -> (condTree(pathVar) + id)
    }

    @inline def encodedCond(id: Identifier) : T = substMap.getOrElse(id, condVars(id))

    var exprVars = Map[Identifier, T]()
    @inline def storeExpr(id: Identifier) : Unit = exprVars += id -> encoder.encodeId(id)

    // Represents clauses of the form:
    //    id => expr && ... && expr
    var guardedExprs = Map[Identifier, Seq[Expr]]()
    def storeGuarded(guardVar: Identifier, expr: Expr) : Unit = {
      assert(expr.getType == BooleanType, expr.asString(Program.empty)(LeonContext.empty) + " is not of type Boolean. " + purescala.ExprOps.explainTyping(expr))

      val prev = guardedExprs.getOrElse(guardVar, Nil)
      guardedExprs += guardVar -> (expr +: prev)
    }

    var lambdaVars = Map[Identifier, T]()
    @inline def storeLambda(id: Identifier) : T = {
      val idT = encoder.encodeId(id)
      lambdaVars += id -> idT
      idT
    }

    var quantifications = Seq[QuantificationTemplate[T]]()
    @inline def registerQuantification(quantification: QuantificationTemplate[T]): Unit =
      quantifications :+= quantification

    var lambdas = Seq[LambdaTemplate[T]]()
    @inline def registerLambda(lambda: LambdaTemplate[T]) : Unit = lambdas :+= lambda

    def rec(pathVar: Identifier, expr: Expr): Expr = {
      expr match {
        case Assert(cond, err, body) =>
          rec(pathVar, IfExpr(cond, body, Error(body.getType, err getOrElse "assertion failed")))

        case e @ Ensuring(_, _) =>
          rec(pathVar, e.toAssert)

        case Let(i, e : Lambda, b) =>
          val re = rec(pathVar, e) // guaranteed variable!
          val rb = rec(pathVar, replace(Map(Variable(i) -> re), b))
          rb

        case Let(i, e, b) =>
          val newExpr : Identifier = FreshIdentifier("lt", i.getType, true)
          storeExpr(newExpr)
          val re = rec(pathVar, e)
          storeGuarded(pathVar, Equals(Variable(newExpr), re))
          val rb = rec(pathVar, replace(Map(Variable(i) -> Variable(newExpr)), b))
          rb

        /* TODO: maybe we want this specialization?
        case l @ LetTuple(is, e, b) =>
          val tuple : Identifier = FreshIdentifier("t", TupleType(is.map(_.getType)), true)
          storeExpr(tuple)
          val re = rec(pathVar, e)
          storeGuarded(pathVar, Equals(Variable(tuple), re))

          val mapping = for ((id, i) <- is.zipWithIndex) yield {
            val newId = FreshIdentifier("ti", id.getType, true)
            storeExpr(newId)
            storeGuarded(pathVar, Equals(Variable(newId), TupleSelect(Variable(tuple), i+1)))

            (Variable(id) -> Variable(newId))
          }

          val rb = rec(pathVar, replace(mapping.toMap, b))
          rb
        */
        case _ : MatchExpr => sys.error("'MatchExpr's should have been eliminated before generating templates.")
        case _ : Passes    => sys.error("'Passes's should have been eliminated before generating templates.")

        case i @ Implies(lhs, rhs) =>
          if (!isSimple(i)) {
            rec(pathVar, Or(Not(lhs), rhs))
          } else {
            implies(rec(pathVar, lhs), rec(pathVar, rhs))
          }

        case And(parts) =>
          val partitions = groupWhile(parts)(isSimple)
          partitions.map(andJoin) match {
            case Seq(e) => e
            case seq =>
              val newExpr : Identifier = FreshIdentifier("e", BooleanType, true)
              storeExpr(newExpr)

              // Short-circuit encoding: each leading part gets a condition variable `b`
              // equal to its value. If `b` is false, the whole conjunction is false.
              // Otherwise the remaining parts are encoded under `b`. Sequence patterns
              // are used so this does not depend on the concrete `Seq` implementation.
              def recAnd(pathVar: Identifier, partitions: Seq[Expr]): Unit = partitions match {
                case Seq(x) =>
                  storeGuarded(pathVar, Equals(Variable(newExpr), rec(pathVar, x)))

                case x +: xs =>
                  val newBool : Identifier = FreshIdentifier("b", BooleanType, true)
                  storeCond(pathVar, newBool)

                  val xrec = rec(pathVar, x)
                  storeGuarded(pathVar, Equals(Variable(newBool), xrec))
                  storeGuarded(pathVar, Implies(Not(Variable(newBool)), Not(Variable(newExpr))))

                  recAnd(newBool, xs)

                case _ => scala.sys.error("Should never happen!")
              }

              recAnd(pathVar, seq)
              Variable(newExpr)
          }

        case Or(parts) =>
          val partitions = groupWhile(parts)(isSimple)
          partitions.map(orJoin) match {
            case Seq(e) => e
            case seq =>
              val newExpr : Identifier = FreshIdentifier("e", BooleanType, true)
              storeExpr(newExpr)

              // Short-circuit encoding: each leading part gets a condition variable `b`
              // that is true iff the part is false. If `b` is false, the whole
              // disjunction is true. Otherwise the remaining parts are encoded under `b`.
              def recOr(pathVar: Identifier, partitions: Seq[Expr]): Unit = partitions match {
                case Seq(x) =>
                  storeGuarded(pathVar, Equals(Variable(newExpr), rec(pathVar, x)))

                case x +: xs =>
                  val newBool : Identifier = FreshIdentifier("b", BooleanType, true)
                  storeCond(pathVar, newBool)

                  val xrec = rec(pathVar, x)
                  storeGuarded(pathVar, Equals(Not(Variable(newBool)), xrec))
                  storeGuarded(pathVar, Implies(Not(Variable(newBool)), Variable(newExpr)))

                  recOr(newBool, xs)

                case _ => scala.sys.error("Should never happen!")
              }

              recOr(pathVar, seq)
              Variable(newExpr)
          }

        case i @ IfExpr(cond, thenn, elze) =>
          if (isSimple(i)) {
            i
          } else {
            val newBool1 : Identifier = FreshIdentifier("b", BooleanType, true)
            val newBool2 : Identifier = FreshIdentifier("b", BooleanType, true)
            val newExpr : Identifier = FreshIdentifier("e", i.getType, true)

            storeCond(pathVar, newBool1)
            storeCond(pathVar, newBool2)

            storeExpr(newExpr)

            val crec = rec(pathVar, cond)
            val trec = rec(newBool1, thenn)
            val erec = rec(newBool2, elze)

            // Exactly one of the two branch conditions holds under `pathVar`.
            storeGuarded(pathVar, or(Variable(newBool1), Variable(newBool2)))
            storeGuarded(pathVar, or(not(Variable(newBool1)), not(Variable(newBool2))))
            // TODO can we improve this? i.e. make it more symmetrical?
            // Probably it's symmetrical enough to Z3.
            storeGuarded(pathVar, Equals(Variable(newBool1), crec))
            storeGuarded(newBool1, Equals(Variable(newExpr), trec))
            storeGuarded(newBool2, Equals(Variable(newExpr), erec))
            Variable(newExpr)
          }

        case Choose(Lambda(params, cond)) =>
          val cs = params.map(_.id.freshen.toVariable)

          for (v <- cs) {
            storeExpr(v.id)
          }

          val freshMap = (params.map(_.id) zip cs).toMap

          storeGuarded(pathVar, replaceFromIDs(freshMap, cond))

          tupleWrap(cs)

        case FiniteLambda(mapping, dflt, FunctionType(from, _)) =>
          // Rewrite the finite mapping as a lambda whose body is an if-else chain.
          val args = from.map(tpe => FreshIdentifier("x", tpe))
          val body = mapping.toSeq.foldLeft(dflt) { case (elze, (exprs, res)) =>
            IfExpr(andJoin((args zip exprs).map(p => Equals(Variable(p._1), p._2))), res, elze)
          }

          rec(pathVar, Lambda(args.map(ValDef(_)), body))

        case l : Lambda =>
          val idArgs : Seq[Identifier] = lambdaArgs(l)
          val trArgs : Seq[T] = idArgs.map(id => substMap.getOrElse(id, encoder.encodeId(id)))

          val lid = FreshIdentifier("lambda", bestRealType(l.getType), true)
          val clauses = liftedEquals(Variable(lid), l, idArgs, inlineFirst = true)

          val localSubst: Map[Identifier, T] = substMap ++ condVars ++ exprVars ++ lambdaVars
          val clauseSubst: Map[Identifier, T] = localSubst ++ (idArgs zip trArgs)
          val (lambdaConds, lambdaExprs, lambdaTree, lambdaGuarded, lambdaTemplates, lambdaQuants) =
            clauses.foldLeft(emptyClauses)((clsSet, cls) => clsSet ++ mkClauses(pathVar, cls, clauseSubst))

          val ids: (Identifier, T) = lid -> storeLambda(lid)
          val dependencies: Map[Identifier, T] = variablesOf(l).map(id => id -> localSubst(id)).toMap
          val template = LambdaTemplate(ids, encoder, manager, pathVar -> encodedCond(pathVar),
            idArgs zip trArgs, lambdaConds, lambdaExprs, lambdaTree, lambdaGuarded, lambdaQuants, lambdaTemplates, localSubst, dependencies, l)
          registerLambda(template)

          Variable(lid)

        case Forall(args, body) =>
          val TopLevelAnds(conjuncts) = body

          // One quantification template per top-level conjunct of the body.
          val conjunctQs = conjuncts.map { conjunct =>
            val vars = variablesOf(conjunct)
            val inits = args.map(_.id).filter(vars).toSet
            val (quantifiers, flatConj) = minimalFlattening(inits, conjunct)

            val idQuantifiers : Seq[Identifier] = quantifiers.toSeq
            val trQuantifiers : Seq[T] = idQuantifiers.map(encoder.encodeId)

            val q: Identifier = FreshIdentifier("q", BooleanType, true)
            val q2: Identifier = FreshIdentifier("qo", BooleanType, true)
            val inst: Identifier = FreshIdentifier("inst", BooleanType, true)
            val guard: Identifier = FreshIdentifier("guard", BooleanType, true)

            // Only used in the assertion message below.
            val clause = Equals(Variable(inst), Implies(Variable(guard), flatConj))

            val qs: (Identifier, T) = q -> encoder.encodeId(q)
            val localSubst: Map[Identifier, T] = substMap ++ condVars ++ exprVars ++ lambdaVars
            val clauseSubst: Map[Identifier, T] = localSubst ++ (idQuantifiers zip trQuantifiers)
            val (p, (qConds, qExprs, qTree, qGuarded, qTemplates, qQuants)) = mkExprClauses(pathVar, flatConj, clauseSubst)
            assert(qQuants.isEmpty, "Unhandled nested quantification in "+clause)

            val allGuarded = qGuarded + (pathVar -> (Seq(
              Equals(Variable(inst), Implies(Variable(guard), p)),
              Equals(Variable(q), And(Variable(q2), Variable(inst)))
            ) ++ qGuarded.getOrElse(pathVar, Seq.empty)))

            val dependencies: Map[Identifier, T] = vars.filterNot(quantifiers).map(id => id -> localSubst(id)).toMap
            val template = QuantificationTemplate[T](encoder, manager, pathVar -> encodedCond(pathVar),
              qs, q2, inst, guard, idQuantifiers zip trQuantifiers, qConds, qExprs, qTree, allGuarded, qTemplates, localSubst,
              dependencies, Forall(quantifiers.toSeq.sortBy(_.uniqueName).map(ValDef(_)), flatConj))
            registerQuantification(template)
            Variable(q)
          }

          andJoin(conjunctQs)

        case Operator(as, r) => r(as.map(a => rec(pathVar, a)))
      }
    }

    val p = rec(pathVar, expr)
    (p, (condVars, exprVars, condTree, guardedExprs, lambdas, quantifications))
  }

}