/* Copyright 2009-2014 EPFL, Lausanne */

package leon
package solvers.z3

import purescala.Common._
import purescala.Trees._
import purescala.Extractors._
import purescala.TreeOps._
import purescala.TypeTrees._
import purescala.Definitions._

import evaluators._

import z3.scala._

import scala.collection.mutable.{Set=>MutableSet,Map=>MutableMap}

/** A call to the function `tfd` whose actual arguments are Z3 ASTs.
  *
  * @param tfd  the (typed) function definition being invoked
  * @param args the Z3 terms passed as arguments, in parameter order
  */
case class Z3FunctionInvocation(tfd: TypedFunDef, args: Seq[Z3AST]) {
  /** Renders as the function signature followed by the parenthesised, comma-separated arguments. */
  override def toString = {
    val renderedArgs = args.mkString("(", ",", ")")
    tfd.signature + renderedArgs
  }
}

/** Pre-translated clause set for one function definition, ready to be
  * instantiated for concrete Z3 arguments.
  *
  * At construction time every guarded expression `(b, e)` is turned into the
  * clause `b ==> e`, every identifier involved gets a fresh Z3 id, and the
  * clauses are translated to Z3 once. [[instantiate]] then only has to
  * substitute Z3 terms inside those translated clauses.
  *
  * Instances are built through `FunctionTemplate.mkTemplate`.
  *
  * @param solver         solver providing the Z3 context, id/formula translation and the evaluator
  * @param tfd            the function this template describes
  * @param activatingBool identifier guarding the top-level clauses; replaced by `aVar` on instantiation
  * @param condVars       boolean identifiers introduced as guards; re-freshened for each new argument list
  * @param exprVars       expression identifiers introduced by the template; re-freshened likewise
  * @param guardedExprs   for each guard identifier, the boolean expressions that hold under it
  * @param isRealFunDef   when false, ground evaluation is never attempted in [[instantiate]]
  *                       (per the original note: such "fake" templates are generated from FairZ3Solver)
  */
class FunctionTemplate private(
  solver: FairZ3Solver,
  val tfd : TypedFunDef,
  activatingBool : Identifier,
  condVars : Set[Identifier],
  exprVars : Set[Identifier],
  guardedExprs : Map[Identifier,Seq[Expr]],
  isRealFunDef : Boolean) {

  // NOTE(review): not referenced anywhere in this file; kept as-is.
  private def isTerminatingForAllInputs : Boolean = (
       isRealFunDef
    && !tfd.hasPrecondition
    && solver.getTerminator.terminates(tfd.fd).isGuaranteed
  )

  private val z3 = solver.z3

  // One implication `guard ==> expr` per guarded expression.
  private val asClauses : Seq[Expr] = {
    (for((b,es) <- guardedExprs; e <- es) yield {
      Implies(Variable(b), e)
    }).toSeq
  }

  /** Z3 id standing for `activatingBool` inside the pre-translated clauses. */
  val z3ActivatingBool = solver.idToFreshZ3Id(activatingBool)

  // Z3 ids standing for the formal parameters inside the pre-translated clauses.
  private val z3FunDefArgs     = tfd.params.map( ad => solver.idToFreshZ3Id(ad.id))

  private val zippedCondVars   = condVars.map(id => (id, solver.idToFreshZ3Id(id)))
  private val zippedExprVars   = exprVars.map(id => (id, solver.idToFreshZ3Id(id)))
  private val zippedFunDefArgs = tfd.params.map(_.id) zip z3FunDefArgs

  /** Maps every identifier used by the template to the Z3 id created for it above. */
  val idToZ3Ids: Map[Identifier, Z3AST] = {
    Map(activatingBool -> z3ActivatingBool) ++
    zippedCondVars ++
    zippedExprVars ++
    zippedFunDefArgs
  }

  /** The clauses translated to Z3; construction fails if any clause cannot be translated. */
  val asZ3Clauses: Seq[Z3AST] = asClauses.map {
    cl => solver.toZ3Formula(cl, idToZ3Ids).getOrElse(sys.error("Could not translate to z3. Did you forget --xlang? @"+cl.getPos))
  }

  // For each guard, the function calls occurring in its expressions, excluding the
  // call of `tfd` on its own formal parameters. Guards without any call are left out.
  private val blockers : Map[Identifier,Set[FunctionInvocation]] = {
    val idCall = FunctionInvocation(tfd, tfd.params.map(_.toVariable))

    val callsPerGuard = guardedExprs.toSeq.map { case (b, es) =>
      b -> (es.foldLeft(Set.empty[FunctionInvocation])((s,e) => s ++ functionCallsOf(e)) - idCall)
    }

    callsPerGuard.filter(_._2.nonEmpty).toMap
  }

  /** Same as the block-map above, with guards and call arguments translated to Z3. */
  val z3Blockers: Map[Z3AST,Set[Z3FunctionInvocation]] = blockers.map {
    case (b, funs) =>
      (idToZ3Ids(b) -> funs.map(fi => Z3FunctionInvocation(fi.tfd, fi.args.map(solver.toZ3Formula(_, idToZ3Ids).get))))
  }

  // We use a cache to create the same boolean variables.
  private val cache : MutableMap[Seq[Z3AST],Map[Z3AST,Z3AST]] = MutableMap.empty

  /** Attempts to replace the invocation on `args` by its evaluated value.
    *
    * Only attempted when the solver enables ground-application evaluation and
    * this is a real function definition (the latter prevents evaluation of
    * "fake" function templates, as generated from FairZ3Solver).
    *
    * @return `Some(invocation == value)` when every argument is ground,
    *         `None` when ground evaluation does not apply
    * @throws Exception when all arguments are ground but evaluation does not succeed
    */
  private def groundEquation(args : Seq[Z3AST]) : Option[Z3AST] = {
    if(solver.evalGroundApps && isRealFunDef) {
      // Lazy view: stops converting at the first non-ground argument.
      val ga = args.view.map(solver.asGround)
      if(ga.forall(_.isDefined)) {
        val leonArgs = ga.map(_.get).force
        val invocation = FunctionInvocation(tfd, leonArgs)
        solver.getEvaluator.eval(invocation) match {
          case EvaluationResults.Successful(result) =>
            val z3Invocation = z3.mkApp(solver.functionDefToDecl(tfd), args: _*)
            val z3Value      = solver.toZ3Formula(result).get
            Some(z3.mkEq(z3Invocation, z3Value))

          case _ => throw new Exception("Evaluation of ground term should have succeeded.")
        }
      } else {
        None
      }
    } else {
      None
    }
  }

  /** Substitution (template Z3 id -> instance Z3 term) for `args`, without the
    * activating boolean. Computed once per distinct argument list so that
    * repeated instantiations reuse the same fresh variables.
    */
  private def baseSubstitutionFor(args : Seq[Z3AST]) : Map[Z3AST,Z3AST] = {
    cache.getOrElseUpdate(args, {
      val newMap : Map[Z3AST,Z3AST] =
        (zippedExprVars ++ zippedCondVars).map(p => p._2 -> solver.idToFreshZ3Id(p._1)).toMap ++
        (z3FunDefArgs zip args)
      newMap
    })
  }

  /** Instantiates the template for the actual arguments `args` under the guard `aVar`.
    *
    * @param aVar Z3 term substituted for the activating boolean
    * @param args actual arguments; must have as many elements as `tfd` has parameters
    * @return the instantiated clauses, and for each instantiated guard the
    *         function invocations it blocks. When ground evaluation applies,
    *         a single equation and an empty block-map are returned instead.
    */
  def instantiate(aVar : Z3AST, args : Seq[Z3AST]) : (Seq[Z3AST], Map[Z3AST,Set[Z3FunctionInvocation]]) = {
    assert(args.size == tfd.params.size)

    groundEquation(args) match {
      case Some(equation) =>
        (Seq(equation), Map.empty[Z3AST,Set[Z3FunctionInvocation]])

      case None =>
        val idSubstMap : Map[Z3AST,Z3AST] = baseSubstitutionFor(args) + (z3ActivatingBool -> aVar)

        // z3.substitute takes parallel arrays of sources and targets.
        val (from, to) = idSubstMap.unzip
        val (fromArray, toArray) = (from.toArray, to.toArray)

        val newClauses  = asZ3Clauses.map(z3.substitute(_, fromArray, toArray))
        val newBlockers = z3Blockers.map { case (b, funs) =>
          val bp = if (b == z3ActivatingBool) {
            aVar
          } else {
            idSubstMap(b)
          }

          val newFuns = funs.map(fi => fi.copy(args = fi.args.map(z3.substitute(_, fromArray, toArray))))

          bp -> newFuns
        }

        (newClauses, newBlockers)
    }
  }

  /** Multi-line, human-readable dump of the template (identifiers, clauses and block-map). */
  override def toString : String = {
    "Template for def " + tfd.id + "(" + tfd.params.map(a => a.id + " : " + a.tpe).mkString(", ") + ") : " + tfd.returnType + " is :\n" +
    " * Activating boolean : " + activatingBool + "\n" + 
    " * Control booleans   : " + condVars.toSeq.map(_.toString).mkString(", ") + "\n" +
    " * Expression vars    : " + exprVars.toSeq.map(_.toString).mkString(", ") + "\n" +
    " * \"Clauses\"          : " + "\n    " + asClauses.mkString("\n    ") + "\n" +
    " * Block-map          : " + blockers.toString
  }
}

object FunctionTemplate {
  /** Builds the template of `tfd`.
    *
    * The body, precondition and postcondition (after `matchToIfThenElse`) are
    * traversed by `rec`, which flattens lets, assertions, `ensuring`, `choose`
    * and those if-expressions that need decomposition into guarded boolean
    * expressions, introducing fresh identifiers along the way.
    *
    * For a real function definition with a body, the stored top-level fact is
    * `invocation == body` (under the precondition if there is one). Otherwise
    * (`isRealFunDef == false`) the body itself is stored as a formula.
    * If a postcondition exists, it is stored as well, with its result
    * identifier replaced by the invocation.
    *
    * @param solver       solver handed over to the created template
    * @param tfd          function (or formula wrapped as a function) to build the template for
    * @param isRealFunDef false when `tfd` only wraps a formula; in that case `tfd` must have a body
    * @throws NoSuchElementException if `isRealFunDef` is false and `tfd` has no body
    */
  def mkTemplate(solver: FairZ3Solver, tfd: TypedFunDef, isRealFunDef : Boolean = true) : FunctionTemplate = {
    val condVars : MutableSet[Identifier] = MutableSet.empty
    val exprVars : MutableSet[Identifier] = MutableSet.empty

    // Represents clauses of the form:
    //    id => expr && ... && expr
    val guardedExprs : MutableMap[Identifier,Seq[Expr]] = MutableMap.empty

    // Records that `expr` holds under `guardVar`. The newest expression is put first.
    def storeGuarded(guardVar : Identifier, expr : Expr) : Unit = {
      assert(expr.getType == BooleanType)
      guardedExprs(guardVar) = expr +: guardedExprs.getOrElse(guardVar, Seq.empty[Expr])
    }

    // An if-expression is only split into guarded branches when it contains
    // one of these constructs somewhere inside it.
    def requireDecomposition(e: Expr) = {
      exists{
        case (_: FunctionInvocation) | (_: Assert) | (_: Ensuring) | (_: Choose) => true
        case _ => false
      }(e)
    }

    // Rewrites `expr` as seen under the guard `pathVar`. Side conditions are
    // registered through `storeGuarded`; the returned expression stands for the
    // value of `expr`. The statement order below determines the order in which
    // fresh identifiers are created and clauses are stored.
    def rec(pathVar : Identifier, expr : Expr) : Expr = {
      expr match {
        // The asserted condition becomes a fact under the current guard.
        case Assert(cond, _, body) =>
          storeGuarded(pathVar, rec(pathVar, cond))
          rec(pathVar, body)

        // `body ensuring (id => post)` is handled as `let id = body in assert(post); id`.
        case Ensuring(body, id, post) =>
          rec(pathVar, Let(id, body, Assert(post, None, Variable(id))))

        // The bound value gets a fresh variable, constrained by an equality.
        case Let(i, e, b) =>
          val newExpr : Identifier = FreshIdentifier("lt", true).setType(i.getType)
          exprVars += newExpr
          val re = rec(pathVar, e)
          storeGuarded(pathVar, Equals(Variable(newExpr), re))
          val rb = rec(pathVar, replace(Map(Variable(i) -> Variable(newExpr)), b))
          rb

        // One fresh variable for the tuple, plus one per component (1-based selectors).
        case LetTuple(is, e, b) =>
          val tuple : Identifier = FreshIdentifier("t", true).setType(TupleType(is.map(_.getType)))
          exprVars += tuple
          val re = rec(pathVar, e)
          storeGuarded(pathVar, Equals(Variable(tuple), re))

          val mapping = for ((id, i) <- is.zipWithIndex) yield {
            val newId = FreshIdentifier("ti", true).setType(id.getType)
            exprVars += newId
            storeGuarded(pathVar, Equals(Variable(newId), TupleSelect(Variable(tuple), i+1)))

            (Variable(id) -> Variable(newId))
          }

          val rb = rec(pathVar, replace(mapping.toMap, b))
          rb

        case _ : MatchExpr => sys.error("MatchExpr's should have been eliminated.")

        case Implies(lhs, rhs) =>
          Implies(rec(pathVar, lhs), rec(pathVar, rhs))

        case And(parts) =>
          And(parts.map(rec(pathVar, _)))

        case Or(parts) =>
          Or(parts.map(rec(pathVar, _)))

        case i @ IfExpr(cond, thenn, elze) => {
          if(!requireDecomposition(i)) {
            i
          } else {
            // newBool1 guards the then-branch, newBool2 the else-branch,
            // newExpr stands for the value of the whole if-expression.
            val newBool1 : Identifier = FreshIdentifier("b", true).setType(BooleanType)
            val newBool2 : Identifier = FreshIdentifier("b", true).setType(BooleanType)
            val newExpr : Identifier = FreshIdentifier("e", true).setType(i.getType)

            condVars += newBool1
            condVars += newBool2

            exprVars += newExpr

            val crec = rec(pathVar, cond)
            val trec = rec(newBool1, thenn)
            val erec = rec(newBool2, elze)

            // Under pathVar exactly one of the two branch guards holds...
            storeGuarded(pathVar, Or(Variable(newBool1), Variable(newBool2)))
            storeGuarded(pathVar, Or(Not(Variable(newBool1)), Not(Variable(newBool2))))
            // TODO can we improve this? i.e. make it more symmetrical?
            // Probably it's symmetrical enough to Z3.
            // ...and the then-guard is equivalent to the condition.
            storeGuarded(pathVar, Iff(Variable(newBool1), crec)) 
            storeGuarded(newBool1, Equals(Variable(newExpr), trec))
            storeGuarded(newBool2, Equals(Variable(newExpr), erec))
            Variable(newExpr)
          }
        }

        // The chosen value gets a fresh variable; with several chosen
        // identifiers each one is a (1-based) component of that variable.
        // The condition is stored as-is, it is not traversed by `rec`.
        case c @ Choose(ids, cond) =>
          val cid = FreshIdentifier("choose", true).setType(c.getType)
          exprVars += cid

          val m: Map[Expr, Expr] = if (ids.size == 1) {
            Map(Variable(ids.head) -> Variable(cid))
          } else {
            ids.zipWithIndex.map{ case (id, i) => Variable(id) -> TupleSelect(Variable(cid), i+1) }.toMap
          }

          storeGuarded(pathVar, replace(m, cond))
          Variable(cid)

        // Everything else: rebuild the node from its recursively processed children.
        case n @ NAryOperator(as, r) => r(as.map(a => rec(pathVar, a))).setType(n.getType)
        case b @ BinaryOperator(a1, a2, r) => r(rec(pathVar, a1), rec(pathVar, a2)).setType(b.getType)
        case u @ UnaryOperator(a, r) => r(rec(pathVar, a)).setType(u.getType)
        case t : Terminal => t
      }
    }

    // The precondition if it exists.
    val prec : Option[Expr] = tfd.precondition.map(p => matchToIfThenElse(p))

    val newBody : Option[Expr] = tfd.body.map(b => matchToIfThenElse(b))

    // The function applied to its own formal parameters.
    val invocation : Expr = FunctionInvocation(tfd, tfd.params.map(_.toVariable))

    val invocationEqualsBody : Option[Expr] = newBody match {
      case Some(body) if isRealFunDef =>
        val b : Expr = Equals(invocation, body)

        Some(prec match {
          case Some(p) => Implies(p, b)
          case None    => b
        })

      case _ =>
        None
    }

    val activatingBool : Identifier = FreshIdentifier("start", true).setType(BooleanType)

    if (isRealFunDef) {
      // Nothing is stored for a real function without a body.
      invocationEqualsBody.foreach { expr =>
        val finalPred : Expr = rec(activatingBool, expr)
        storeGuarded(activatingBool, finalPred)
      }
    } else {
      val formula : Expr = newBody.getOrElse(
        throw new NoSuchElementException("Cannot build a formula template for " + tfd.id + ": it has no body.")
      )
      val newFormula = rec(activatingBool, formula)
      storeGuarded(activatingBool, newFormula)
    }

    // Now the postcondition.
    tfd.postcondition.foreach { case (id, post) =>
      val newPost : Expr = replace(Map(Variable(id) -> invocation), matchToIfThenElse(post))

      val postHolds : Expr =
        if(tfd.hasPrecondition) {
          Implies(prec.get, newPost)
        } else {
          newPost
        }

      val finalPred2 : Expr = rec(activatingBool,  postHolds)
      storeGuarded(activatingBool, finalPred2)
    }

    new FunctionTemplate(
      solver,
      tfd,
      activatingBool,
      Set(condVars.toSeq : _*),
      Set(exprVars.toSeq : _*),
      Map(guardedExprs.toSeq : _*),
      isRealFunDef
    )
  }
}