/* Copyright 2009-2016 EPFL, Lausanne */

package leon
package purescala

import Common._
import Types._
import Definitions._
import Expressions._
import Extractors._
import Constructors._
import utils._
import solvers._
import scala.language.implicitConversions

/** Provides functions to manipulate [[purescala.Expressions]].
  *
  * This object provides a few generic operations on Leon expressions,
  * as well as some common operations.
  *
  * The generic operations let you apply an operation to a whole
  * expression tree. See:
  *   - [[GenTreeOps.fold foldRight]]
  *   - [[GenTreeOps.preTraversal preTraversal]]
  *   - [[GenTreeOps.postTraversal postTraversal]]
  *   - [[GenTreeOps.preMap preMap]]
  *   - [[GenTreeOps.postMap postMap]]
  *   - [[GenTreeOps.genericTransform genericTransform]]
  *
  * These operations usually take a higher order function that gets applied to the
  * expression tree in some strategy. They provide an expressive way to build complex
  * operations on Leon expressions.
  *
  */
object ExprOps extends GenTreeOps[Expr] {

  /** Extractor splitting an expression into its direct sub-expressions and a
    * builder that rebuilds the node from new sub-expressions (see its uses in
    * `preTransformWithBinders` and `expandLets`). Presumably this is also the
    * deconstructor the generic `GenTreeOps` traversals rely on — TODO confirm.
    */
  val Deconstructor = Operator

  /** Substitutes variables bottom-up.
    *
    * Every `Variable(id)` whose `id` is a key of `substs` is replaced by the
    * associated expression; every other node is left as is.
    */
  def replaceFromIDs(substs: Map[Identifier, Expr], expr: Expr) : Expr = {
    def lookup(node: Expr): Option[Expr] = node match {
      case Variable(id) => substs.get(id)
      case _            => None
    }
    postMap(lookup)(expr)
  }

  /** Transforms `e` top-down with `f`, giving `f` the identifiers bound at each position.
    *
    * `f` is applied to a node first. The result is then deconstructed, and the
    * traversal continues into its children. The binder set is extended by what the
    * node introduces:
    *  - the binder of `let`/`var` (visible in the body only, not in the value),
    *  - the parameters of local functions,
    *  - pattern binders (for guards and right-hand sides),
    *  - lambda/forall arguments.
    *
    * Note: the bodies of functions defined in a `LetDef` are updated in place
    * (`fd.fullBody` is reassigned), so this mutates those `FunDef`s.
    *
    * @param f           transformation receiving a node and the identifiers bound there
    * @param initBinders identifiers considered bound at the root
    * @param e           the expression to transform
    */
  def preTransformWithBinders(f: (Expr, Set[Identifier]) => Expr, initBinders: Set[Identifier] = Set())(e: Expr) = {
    import xlang.Expressions.LetVar

    // Pattern binders are in scope in the case's guard and right-hand side.
    def recCases(binders: Set[Identifier], cses: Seq[MatchCase]): Seq[MatchCase] = cses map {
      case mc @ MatchCase(pat, og, rhs) =>
        val newBs = binders ++ pat.binders
        MatchCase(pat, og map (rec(newBs, _)), rec(newBs, rhs)).copiedFrom(mc)
    }

    def rec(binders: Set[Identifier], e: Expr): Expr = f(e, binders) match {
      case ld @ LetDef(fds, bd) =>
        fds.foreach(fd => {
          fd.fullBody = rec(binders ++ fd.paramIds, fd.fullBody)
        })
        LetDef(fds, rec(binders, bd)).copiedFrom(ld)
      case l @ Let(i, v, b) =>
        // `i` is not in scope in its own (non-recursive) value, only in the body.
        Let(i, rec(binders, v), rec(binders + i, b)).copiedFrom(l)
      case lv @ LetVar(i, v, b) =>
        LetVar(i, rec(binders, v), rec(binders + i, b)).copiedFrom(lv)
      case m @ MatchExpr(scrut, cses) =>
        MatchExpr(rec(binders, scrut), recCases(binders, cses)).copiedFrom(m)
      case p @ Passes(in, out, cses) =>
        Passes(rec(binders, in), rec(binders, out), recCases(binders, cses)).copiedFrom(p)
      case l @ Lambda(args, bd) =>
        Lambda(args, rec(binders ++ args.map(_.id), bd)).copiedFrom(l)
      case fa @ Forall(args, bd) =>
        Forall(args, rec(binders ++ args.map(_.id), bd)).copiedFrom(fa)
      case d @ Deconstructor(subs, builder) =>
        builder(subs map (rec(binders, _))).copiedFrom(d)
    }

    rec(initBinders, e)
  }

  /** Returns the set of free variables in an expression.
    *
    * Computed bottom-up. A node's free variables are those of its children:
    *  - plus the variable itself for `Variable` and `Old`,
    *  - minus the identifiers the node binds: let/var binders, local function
    *    parameters, pattern binders, lambda/forall arguments.
    */
  def variablesOf(expr: Expr): Set[Identifier] = {
    import leon.xlang.Expressions._

    // Identifiers introduced by a binding construct; empty for any other node.
    def boundBy(node: Expr): Set[Identifier] = node match {
      case LetDef(fds, _)     => fds.flatMap(_.params.map(_.id)).toSet
      case Let(id, _, _)      => Set(id)
      case LetVar(id, _, _)   => Set(id)
      case MatchExpr(_, cses) => cses.flatMap(_.pattern.binders).toSet
      case Passes(_, _, cses) => cses.flatMap(_.pattern.binders).toSet
      case Lambda(args, _)    => args.map(_.id).toSet
      case Forall(args, _)    => args.map(_.id).toSet
      case _                  => Set.empty[Identifier]
    }

    fold[Set[Identifier]] { (node, childVars) =>
      val below = childVars.flatten.toSet
      node match {
        case Variable(id) => below + id
        case Old(id)      => below + id
        case _            => below -- boundBy(node)
      }
    }(expr)
  }

  /** Returns true if some sub-expression of `expr` is a [[FunctionInvocation]]. */
  def containsFunctionCalls(expr: Expr): Boolean = {
    def isCall(node: Expr): Boolean = node match {
      case _: FunctionInvocation => true
      case _                     => false
    }
    exists(isCall)(expr)
  }

  /** Collects every [[FunctionInvocation]] node occurring in `expr`. */
  def functionCallsOf(expr: Expr): Set[FunctionInvocation] =
    collect[FunctionInvocation] { node =>
      Set(node) collect { case call: FunctionInvocation => call }
    }(expr)
  
  /** Collects the functions defined by every `LetDef` reached by the
    * traversal of `expr`.
    */
  def nestedFunDefsOf(expr: Expr): Set[FunDef] = collect[FunDef] {
    case LetDef(localDefs, _) => localDefs.toSet
    case _                    => Set.empty[FunDef]
  }(expr)

  /** Returns the functions of the `LetDef`s in `e` that are not themselves
    * nested inside another local definition.
    *
    * At a `LetDef`, only the result computed for its last child is kept,
    * together with the functions it defines. That child is presumably the
    * `LetDef` body; this depends on the child order produced by the
    * deconstructor (TODO confirm).
    */
  def directlyNestedFunDefs(e: Expr): Set[FunDef] = {
    fold[Set[FunDef]] { (node, childDefs) =>
      node match {
        case LetDef(fds, _) => childDefs.last ++ fds
        case _              => childDefs.flatten.toSet
      }
    }(e)
  }

  /** Computes the negation of a boolean formula, with some simplifications.
    *
    * Simplifications applied:
    *  - double negations are removed,
    *  - negation is pushed through `&&`, `||`, `==>`, `let` bodies and `if` branches,
    *  - comparisons are flipped,
    *  - boolean literals are inverted.
    *
    * Anything else is wrapped in `Not`. The result gets the position of `expr`.
    */
  def negate(expr: Expr) : Expr = {
    //require(expr.getType == BooleanType)
    val negated: Expr = expr match {
      case Let(id, value, body)     => Let(id, value, negate(body))
      case Not(inner)               => inner
      case Implies(lhs, rhs)        => and(lhs, negate(rhs))
      case Or(disjuncts)            => and(disjuncts.map(negate): _*)
      case And(conjuncts)           => or(conjuncts.map(negate): _*)
      case LessThan(a, b)           => GreaterEquals(a, b)
      case LessEquals(a, b)         => GreaterThan(a, b)
      case GreaterThan(a, b)        => LessEquals(a, b)
      case GreaterEquals(a, b)      => LessThan(a, b)
      case IfExpr(cond, thn, els)   => IfExpr(cond, negate(thn), negate(els))
      case BooleanLiteral(value)    => BooleanLiteral(!value)
      case other                    => Not(other)
    }
    negated.setPos(expr)
  }

  /** Renames the binders of `pat`, at every depth, according to `subst`.
    *
    * Binders are looked up with `Map.apply`, so a binder missing from
    * `subst` makes this throw.
    */
  def replacePatternBinders(pat: Pattern, subst: Map[Identifier, Identifier]): Pattern = {
    def rename(binder: Option[Identifier]): Option[Identifier] = binder.map(subst)

    def go(p: Pattern): Pattern = p match {
      case WildcardPattern(b)             => WildcardPattern(rename(b))
      case LiteralPattern(b, lit)         => LiteralPattern(rename(b), lit)
      case InstanceOfPattern(b, ct)       => InstanceOfPattern(rename(b), ct)
      case TuplePattern(b, subs)          => TuplePattern(rename(b), subs.map(go))
      case CaseClassPattern(b, cct, subs) => CaseClassPattern(rename(b), cct, subs.map(go))
      case UnapplyPattern(b, obj, subs)   => UnapplyPattern(rename(b), obj, subs.map(go))
    }

    go(pat)
  }


  /** Replaces each node by the result of its smart constructor.
    *
    * The expression is remapped bottom-up, calling the constructor that
    * corresponds to each supported node. The constructors perform local
    * simplifications, so the result is a simplified expression. Nodes without
    * a listed constructor are left unchanged.
    */
  def simplifyByConstructors(expr: Expr): Expr = {
    val rebuild: PartialFunction[Expr, Expr] = {
      case Not(t)                           => not(t)
      case UMinus(t)                        => uminus(t)
      case BVUMinus(t)                      => uminus(t)
      case RealUMinus(t)                    => uminus(t)
      case CaseClassSelector(cd, obj, sel)  => caseClassSelector(cd, obj, sel)
      case AsInstanceOf(obj, ct)            => asInstOf(obj, ct)
      case Equals(t1, t2)                   => equality(t1, t2)
      case Implies(t1, t2)                  => implies(t1, t2)
      case Plus(t1, t2)                     => plus(t1, t2)
      case Minus(t1, t2)                    => minus(t1, t2)
      case Times(t1, t2)                    => times(t1, t2)
      case BVPlus(t1, t2)                   => plus(t1, t2)
      case BVMinus(t1, t2)                  => minus(t1, t2)
      case BVTimes(t1, t2)                  => times(t1, t2)
      case RealPlus(t1, t2)                 => plus(t1, t2)
      case RealMinus(t1, t2)                => minus(t1, t2)
      case RealTimes(t1, t2)                => times(t1, t2)
      case And(args)                        => andJoin(args)
      case Or(args)                         => orJoin(args)
      case Tuple(args)                      => tupleWrap(args)
      case MatchExpr(scrut, cases)          => matchExpr(scrut, cases)
      case Passes(in, out, cases)           => passes(in, out, cases)
    }
    postMap(rebuild.lift)(expr)
  }

  /** ATTENTION: Unused, and untested
    *
    * Rewrites binding constructs, bottom-up, to use fresh identifiers: the
    * pattern binders of every `match`/`passes` case and the binder of every
    * `let`.
    */
  def freshenLocals(expr: Expr) : Expr = {
    // Gives each binder of the case's pattern a fresh name and renames its
    // uses in the guard and the right-hand side accordingly.
    def freshenCase(cse: MatchCase) : MatchCase = {
      val binders: Set[Identifier] = cse.pattern.binders
      val renaming: Map[Identifier, Identifier] =
        binders.map(id => (id, FreshIdentifier(id.name, id.getType, true))).toMap
      val varRenaming: Map[Expr, Expr] = renaming.map { case (from, to) => Variable(from) -> Variable(to) }

      val newPattern = replacePatternBinders(cse.pattern, renaming)
      val newGuard = cse.optGuard.map(g => replace(varRenaming, g))
      val newRhs = replace(varRenaming, cse.rhs)
      MatchCase(newPattern, newGuard, newRhs)
    }

    def freshenNode(node: Expr): Option[Expr] = node match {
      case m @ MatchExpr(scrut, cses) =>
        Some(matchExpr(scrut, cses.map(freshenCase)).copiedFrom(m))

      case p @ Passes(in, out, cses) =>
        Some(Passes(in, out, cses.map(freshenCase)).copiedFrom(p))

      case l @ Let(id, value, body) =>
        val freshId = FreshIdentifier(id.name, id.getType, alwaysShowUniqueID = true).copiedFrom(id)
        val renamedBody = replaceFromIDs(Map(id -> Variable(freshId)), body)
        Some(Let(freshId, value, renamedBody).copiedFrom(l))

      case _ => None
    }

    postMap(freshenNode)(expr)
  }

  /** Applies `f` to the I/O constraint of `p` and rebuilds a `Passes` from the result.
    *
    * If the transformed constraint still has the shape `out == (in match { cases })`,
    * a `Passes` is rebuilt from it. Cases whose right-hand side equals `out` are
    * dropped. Any other result is returned unchanged.
    */
  def applyAsMatches(p : Passes, f : Expr => Expr) = {
    f(p.asConstraint) match {
      case Equals(newOut, MatchExpr(newIn, newCases)) =>
        val kept = newCases.filterNot(cse => newOut == cse.rhs)
        Passes(newIn, newOut, kept)
      case other =>
        other
    }
  }

  /** Normalizes `expr` by applying the local rewrite rules below, bottom-up,
    * until a fixpoint is reached.
    */
  def normalizeExpression(expr: Expr) : Expr = {
    val step: PartialFunction[Expr, Expr] = {
      // Push tuple selections into let bodies.
      case TupleSelect(Let(id, v, b), ts) =>
        Let(id, v, tupleSelect(b, ts, true))

      case TupleSelect(LetTuple(ids, v, b), ts) =>
        letTuple(ids, v, tupleSelect(b, ts, true))

      case sel @ CaseClassSelector(cct, cc: CaseClass, id) =>
        caseClassSelector(cct, cc, id).copiedFrom(sel)

      // Identical branches: the condition can be dropped only if it has no effects.
      case IfExpr(c, thenn, elze) if (thenn == elze) && isPurelyFunctional(c) =>
        thenn

      case IfExpr(c, BooleanLiteral(true), BooleanLiteral(false)) =>
        c

      case ite @ IfExpr(Not(c), thenn, elze) =>
        IfExpr(c, elze, thenn).copiedFrom(ite)

      case ite @ IfExpr(c, BooleanLiteral(false), BooleanLiteral(true)) =>
        Not(c).copiedFrom(ite)

      // Distribute a single-argument call over a conditional argument.
      case call @ FunctionInvocation(tfd, List(IfExpr(c, thenn, elze))) =>
        IfExpr(c, FunctionInvocation(tfd, List(thenn)), FunctionInvocation(tfd, List(elze))).copiedFrom(call)
    }

    fixpoint(postMap(step.lift))(expr)
  }

  /** Per-type pool of canonical identifiers used by `normalizeStructure`.
    *
    * Missing types map to an empty list. In the code visible here the pool is
    * only ever extended, never shrunk. This is shared mutable state on the
    * `ExprOps` object, which is why `normalizeStructure` is `synchronized`.
    */
  private val typedIds: scala.collection.mutable.Map[TypeTree, List[Identifier]] =
    scala.collection.mutable.Map.empty.withDefaultValue(List.empty)

  /** Normalizes identifiers in an expression to enable some notion of structural
    * equality between expressions on which usual equality doesn't make sense
    * (i.e. closures).
    *
    * Every identifier is collected in pre-order, each appearing once: both free
    * variables and those bound by lambdas, foralls, lets, local function
    * parameters and pattern binders. Per type, they are then mapped, in order,
    * onto the canonical identifiers of the shared `typedIds` pool, which is
    * extended with fresh identifiers when too short. Because of that shared
    * mutable state, the whole computation is `synchronized`.
    *
    * NOTE(review): parameters of `LetDef` functions appear in the returned
    * renaming, but the `FunDef`s themselves are not rewritten here — confirm
    * this is intended.
    *
    * @return the renamed expression and the renaming that was applied
    */
  def normalizeStructure(expr: Expr): (Expr, Map[Identifier, Identifier]) = synchronized {
    // Identifiers introduced or referenced directly at a node.
    def idsAt(node: Expr): Seq[Identifier] = node match {
      case Lambda(args, _)    => args.map(_.id)
      case Forall(args, _)    => args.map(_.id)
      case LetDef(fds, _)     => fds.flatMap(_.paramIds)
      case Let(i, _, _)       => Seq(i)
      case MatchExpr(_, cses) => cses.flatMap(_.pattern.binders)
      case Passes(_, _, cses) => cses.flatMap(_.pattern.binders)
      case Variable(id)       => Seq(id)
      case _                  => Seq.empty[Identifier]
    }

    // A node's own identifiers come before those of its children, left to right.
    val allVars: Seq[Identifier] =
      fold[Seq[Identifier]]((node, childIds) => idsAt(node) ++ childIds.flatten)(expr).distinct

    val subst = allVars.groupBy(_.getType).foldLeft(Map.empty[Identifier, Identifier]) {
      case (acc, (tpe, ids)) =>
        val pool = typedIds(tpe)
        val missing = ids.size - pool.size
        val canonical =
          if (missing <= 0) {
            pool
          } else {
            val extended = pool ++ List.fill(missing)(FreshIdentifier("x", tpe, true))
            typedIds += tpe -> extended
            extended
          }
        acc ++ (ids zip canonical)
    }

    def renameCases(cses: Seq[MatchCase]): Seq[MatchCase] =
      cses.map(cse => cse.copy(pattern = replacePatternBinders(cse.pattern, subst)))

    def rename(node: Expr): Option[Expr] = node match {
      case Lambda(args, body)         => Some(Lambda(args.map(vd => vd.copy(id = subst(vd.id))), body))
      case Forall(args, body)         => Some(Forall(args.map(vd => vd.copy(id = subst(vd.id))), body))
      case Let(i, value, body)        => Some(Let(subst(i), value, body))
      case m @ MatchExpr(scrut, cses) => Some(MatchExpr(scrut, renameCases(cses)).copiedFrom(m))
      case Passes(in, out, cses)      => Some(Passes(in, out, renameCases(cses)))
      case Variable(id)               => Some(Variable(subst(id)))
      case _                          => None
    }

    (postMap(rename)(expr), subst)
  }

  /** Returns '''true''' if the formula is ground.
    *
    * This means it has no free variable ([[purescala.ExprOps#variablesOf]]
    * is empty) and it is [[purescala.ExprOps#isDeterministic isDeterministic]].
    */
  def isGround(e: Expr): Boolean =
    if (variablesOf(e).nonEmpty) false
    else isDeterministic(e)

  /** Returns '''true''' if the formula is simple, i.e. needs no special
    * encoding for an unrolling solver.
    *
    * Concretely, it must contain none of the following: `choose`, holes,
    * assertions, postconditions, quantifiers, lambdas (finite or not),
    * function invocations or applications.
    */
  def isSimple(e: Expr): Boolean = {
    def needsEncoding(node: Expr): Boolean = node match {
      case _: Choose | _: Hole | _: Assert | _: Ensuring => true
      case _: Forall | _: Lambda | _: FiniteLambda       => true
      case _: FunctionInvocation | _: Application        => true
      case _                                             => false
    }
    !exists(needsEncoding)(e)
  }

  /** Returns a function that evaluates, top-down, every non-terminal ground
    * sub-expression in the context of `program`.
    *
    * When evaluation of a node fails, that node is not replaced.
    */
  def evalGround(ctx: LeonContext, program: Program): Expr => Expr = {
    import evaluators._

    val evaluator = new DefaultEvaluator(ctx, program)

    def tryEval(node: Expr): Option[Expr] = node match {
      case _: Terminal => None
      case _ =>
        if (isGround(node)) evaluator.eval(node).result // None if evaluation fails
        else None
    }

    preMap(tryEval)
  }

  /** Simplifies let expressions.
    *
    * Rewrites applied:
    *  - lets whose value is cheap to recompute (a selector on a variable, an
    *    empty set/map, any terminal) are always inlined,
    *  - lets whose variable never occurs in the body are removed,
    *  - lets whose variable occurs exactly once are inlined.
    *
    * Only purely functional values are inlined or dropped.
    *
    * In addition, a match case whose right-hand side immediately matches on
    * one of its own binders (`case ... x ... => x match { case p => body }`) is
    * merged into the outer pattern. This happens only when `x` is bound by a
    * wildcard sub-pattern.
    *
    * @note the code is simple but far from optimal (many traversals...)
    */
  def simplifyLets(expr: Expr) : Expr = {

    // Values that are cheap to recompute and easy to read.
    def freeComputable(e: Expr) = e match {
      case TupleSelect(Variable(_), _) => true
      case CaseClassSelector(_, Variable(_), _) => true
      case FiniteSet(els, _) => els.isEmpty
      case FiniteMap(els, _, _) => els.isEmpty
      case _: Terminal => true
      case _ => false
    }

    def simplerLet(t: Expr): Option[Expr] = t match {

      case Let(i, e, b) if freeComputable(e) && isPurelyFunctional(e) =>
        // computation is very quick and code easy to read, always inline
        Some(replaceFromIDs(Map(i -> e), b))

      case Let(i,e,b) if isPurelyFunctional(e) =>
        // computation may be slow, or code complicated to read, inline at most once
        val occurrences = count {
          case Variable(`i`) => 1
          case _ => 0
        }(b)

        if(occurrences == 0) {
          Some(b)
        } else if(occurrences == 1) {
          Some(replaceFromIDs(Map(i -> e), b))
        } else {
          None
        }

      /*case LetPattern(patt, e0, body) if isPurelyFunctional(e0) =>
        // Will turn the match-expression with a single case into a list of lets.
        // @mk it is not clear at all that we want this

        // Just extra safety...
        val e = (e0.getType, patt) match {
          case (_:AbstractClassType, CaseClassPattern(_, cct, _)) =>
            asInstOf(e0, cct)
          case (at: AbstractClassType, InstanceOfPattern(_, ct)) if at != ct =>
            asInstOf(e0, ct)
          case _ =>
            e0
        }

        // Sort lets in dependency order
        val lets = mapForPattern(e, patt).toSeq.sortWith {
          case ((id1, e1), (id2, e2)) => exists{ _ == Variable(id1) }(e2)
        }

        Some(lets.foldRight(body) {
          case ((id, e), bd) => Let(id, e, bd)
        })*/

      case MatchExpr(scrut, cases) =>
        // Merge match within match
        var changed = false
        val newCases = cases map {
          case mc @ MatchCase(patt, g, LetPattern(innerPatt, Variable(id), body)) if patt.binders contains id =>
            // `id` may be the binder of a non-wildcard sub-pattern (e.g. `id @ Foo(_)`).
            // Then nothing is substituted, and dropping the inner match would lose
            // both its condition and its binders, so the case is kept unchanged.
            var merged = false
            val newPatt = PatternOps.preMap {
              case WildcardPattern(Some(`id`)) =>
                merged = true
                Some(innerPatt.withBinder(id))
              case _ => None
            }(patt)
            if (merged) {
              changed = true
              MatchCase(newPatt, g, body)
            } else {
              mc
            }
          case other =>
            other
        }
        if(changed) Some(MatchExpr(scrut, newCases)) else None

      case _ => None
    }

    postMap(simplerLet, applyRec = true)(expr)
  }

  /** Fully expands all let expressions.
    *
    * Every let-bound variable is replaced by its expanded value, and the `Let`
    * nodes disappear. Match and `passes` cases are traversed with the current
    * substitution. Any other node is rebuilt through `Deconstructor`, and only
    * when one of its children changed.
    *
    * @throws LeonFatalError on a node that `Deconstructor` cannot split
    */
  def expandLets(expr: Expr) : Expr = {
    def rec(ex: Expr, s: Map[Identifier,Expr]) : Expr = ex match {
      // Values in `s` are already expanded (see the `Let` case) and are used as is.
      // Expanding them again with the current `s` would wrongly capture a free
      // variable of the value that an inner `let` happens to rebind, and would
      // re-walk the value at every occurrence.
      case Variable(id) if s.isDefinedAt(id) => s(id)
      case Let(i, e, b) => rec(b, s + (i -> rec(e, s)))
      case IfExpr(t1, t2, t3) => IfExpr(rec(t1, s), rec(t2, s), rec(t3, s))
      case m @ MatchExpr(scrut, cses) => matchExpr(rec(scrut, s), cses.map(inCase(_, s))).setPos(m)
      case p @ Passes(in, out, cses) => Passes(rec(in, s), rec(out, s), cses.map(inCase(_, s))).setPos(p)
      case n @ Deconstructor(args, recons) =>
        // Reuse each original child when it is unchanged, and the original node
        // when no child changed at all.
        val rargs = args.map { a =>
          val ra = rec(a, s)
          if (ra != a) ra else a
        }
        if (rargs.corresponds(args)(_ eq _)) n else recons(rargs)
      case unhandled => throw LeonFatalError("Unhandled case in expandLets: " + unhandled)
    }

    def inCase(cse: MatchCase, s: Map[Identifier,Expr]) : MatchCase = {
      import cse._
      MatchCase(pattern, optGuard map { rec(_, s) }, rec(rhs,s))
    }

    rec(expr, Map.empty)
  }

  /** Lifts lets to top level.
    *
    * All `Let` nodes are removed from the tree. Their bindings are gathered and
    * re-introduced, in order, as a chain of lets around the stripped
    * expression. Does not push any used variable out of scope.
    * Assumes no match expressions (i.e. matchToIfThenElse has been called on e).
    */
  def liftLets(e: Expr): Expr = {

    type Bindings = Seq[(Identifier, Expr)]

    // A Let contributes its value's bindings, then its own, then its body's.
    def gather(node: Expr, fromChildren: Seq[Bindings]): Bindings = (node, fromChildren) match {
      case (Let(id, value, _), Seq(valueBindings, bodyBindings)) =>
        valueBindings ++ ((id, value) +: bodyBindings)
      case _ =>
        fromChildren.flatten
    }

    def stripLet(node: Expr, acc: Bindings): (Expr, Bindings) = node match {
      case Let(_, _, body) => (body, acc)
      case other           => (other, acc)
    }

    val (core, bindings) = genericTransform[Bindings](noTransformer, stripLet, gather)(Seq.empty)(e)

    bindings.foldRight(core) { case ((id, value), body) => Let(id, value, body) }
  }

  /** Generates substitutions necessary to transform scrutinee to equivalent
    * specialized cases
    *
    * {{{
    *    e match {
    *     case CaseClass((a, 42), c) => expr
    *    }
    * }}}
    * will return, for the first pattern:
    * {{{
    *    Map(
    *     e -> CaseClass(t, c),
    *     t -> (a, b2),
    *     b2 -> 42,
    *    )
    * }}}
    * Here `t` stands for the selector of the first field of `e` and `b2` for
    * `t._2`. The result is actually a `Seq` of pairs, from which trivial
    * `x -> x` pairs are removed.
    *
    * @note UNUSED, is not maintained
    */
  def patternSubstitutions(in: Expr, pattern: Pattern): Seq[(Expr, Expr)] ={
    // Substitutions for `pattern` applied to `in`, ignoring `pattern`'s own binder.
    def rec(in: Expr, pattern: Pattern): Seq[(Expr, Expr)] = pattern match {
      // An instance test on a case class is handled as that case class with wildcard fields.
      case InstanceOfPattern(ob, cct: CaseClassType) =>
        val pt = CaseClassPattern(ob, cct, cct.fields.map { f =>
          WildcardPattern(Some(FreshIdentifier(f.id.name, f.getType)))
        })
        rec(in, pt)

      case TuplePattern(_, subps) =>
        val TupleType(subts) = in.getType
        // Each element is its binder when it has one, otherwise a selection on `in`.
        val subExprs = (subps zip subts).zipWithIndex map {
          case ((p, t), index) => p.binder.map(_.toVariable).getOrElse(tupleSelect(in, index+1, subps.size))
        }

        // Special case to get rid of (a,b) match { case (c,d) => .. }
        val subst0 = in match {
          case Tuple(ts) =>
            ts zip subExprs
          case _ =>
            Seq(in -> tupleWrap(subExprs))
        }

        subst0 ++ ((subExprs zip subps) flatMap {
          case (e, p) => recBinder(e, p)
        })

      case CaseClassPattern(_, cct, subps) =>
        // Each field is its binder when it has one, otherwise a selector on `in`.
        val subExprs = (subps zip cct.classDef.fields) map {
          case (p, f) => p.binder.map(_.toVariable).getOrElse(caseClassSelector(cct, in, f.id))
        }

        // Special case to get rid of Cons(a,b) match { case Cons(c,d) => .. }
        val subst0 = in match {
          case CaseClass(`cct`, args) =>
            args zip subExprs
          case _ =>
            Seq(in -> CaseClass(cct, subExprs))
        }

        subst0 ++ ((subExprs zip subps) flatMap {
          case (e, p) => recBinder(e, p)
        })

      case LiteralPattern(_, v) =>
        Seq(in -> v)

      // Wildcards, instance tests on non-case classes and unapply patterns yield nothing.
      case _ =>
        Seq()
    }

    // When the pattern has a binder, `in` is mapped to the binder variable, and the
    // sub-patterns are then expressed in terms of that variable.
    def recBinder(in: Expr, pattern: Pattern): Seq[(Expr, Expr)] = {
      (pattern, pattern.binder) match {
        case (_: WildcardPattern, Some(b)) =>
          Seq(in -> b.toVariable)
        case (p, Some(b)) =>
          val bv = b.toVariable
          Seq(in -> bv) ++ rec(bv, pattern)
        case _ =>
          rec(in, pattern)
      }
    }

    // Drop trivial `x -> x` substitutions.
    recBinder(in, pattern).filter{ case (a, b) => a != b }
  }

  /** Recursively transforms a pattern into a [[Path]] expressing the conditions
    * under which the input expression `in` matches it, optionally including the
    * pattern's binders as bindings of the path.
    *
    * For example, the following pattern on the input `i`
    * {{{
    * case m @ MyCaseClass(t: B, (_, 7)) =>
    * }}}
    * yields, with `includeBinders = true`, before simplification, and assuming
    * `B` has a parent class, a path that reads roughly as
    * {{{
    * i.isInstanceOf[MyCaseClass] && (m = i) && i.t.isInstanceOf[B] && (t = i.t) && i.k._2 == 7
    * }}}
    * Here `k` is the second field of `MyCaseClass` and `(x = e)` denotes a binding.
    * With `includeBinders = false` the bindings are omitted.
    *
    * Tuple patterns add no runtime test of their own. Their arity is only
    * `assert`ed against the static type of the input.
    *
    * @param in             the scrutinee
    * @param pattern        the pattern matched against `in`
    * @param includeBinders whether the pattern's binders become bindings of the path
    * @see [[purescala.Expressions.Pattern]]
    */
  def conditionForPattern(in: Expr, pattern: Pattern, includeBinders: Boolean = false): Path = {
    // With `includeBinders`, records the binder (if any) as bound to `to`; otherwise nothing.
    def bind(ob: Option[Identifier], to: Expr): Path = {
      if (!includeBinders) {
        Path.empty
      } else {
        ob.map(id => Path.empty withBinding (id -> to)).getOrElse(Path.empty)
      }
    }

    def rec(in: Expr, pattern: Pattern): Path = {
      pattern match {
        case WildcardPattern(ob) =>
          bind(ob, in)

        case InstanceOfPattern(ob, ct) =>
          // No runtime test is emitted for a class without a parent (presumably
          // the static type of `in` already guarantees the match — TODO confirm).
          if (ct.parent.isEmpty) {
            bind(ob, in)
          } else {
            Path(IsInstanceOf(in, ct)) merge bind(ob, in)
          }

        case CaseClassPattern(ob, cct, subps) =>
          assert(cct.classDef.fields.size == subps.size)
          val pairs = cct.classDef.fields.map(_.id).toList zip subps.toList
          val subTests = pairs.map(p => rec(caseClassSelector(cct, in, p._1), p._2))
          val together = subTests.foldLeft(bind(ob, in))(_ merge _)
          // The instance test is placed first, ahead of the field conditions.
          Path(IsInstanceOf(in, cct)) merge together

        case TuplePattern(ob, subps) =>
          val TupleType(tpes) = in.getType
          assert(tpes.size == subps.size)
          val subTests = subps.zipWithIndex.map{case (p, i) => rec(tupleSelect(in, i+1, subps.size), p)}
          subTests.foldLeft(bind(ob, in))(_ merge _)

        case up @ UnapplyPattern(ob, fd, subps) =>
          def someCase(e: Expr) = {
            // In the case where unapply returns a Some, it is enough that the subpatterns match
            val subTests = unwrapTuple(e, subps.size) zip subps map { case (ex, p) => rec(ex, p) }
            subTests.foldLeft(Path.empty)(_ merge _).toClause
          }
          // `false` is passed as the result for the non-matching case (presumably the None case of the unapply).
          Path(up.patternMatch(in, BooleanLiteral(false), someCase).setPos(in)) merge bind(ob, in)

        case LiteralPattern(ob, lit) =>
          Path(Equals(in, lit)) merge bind(ob, in)
      }
    }

    rec(in, pattern)
  }

  /** Maps every binder of `pattern`, applied to `in`, to the sub-expression of
    * `in` it denotes.
    *
    * Binders of case-class and instance-of patterns are bound to `in` cast to
    * the pattern's class.
    */
  def mapForPattern(in: Expr, pattern: Pattern) : Map[Identifier,Expr] = {
    // Binds `id` (if any) to `in`, optionally cast to `cast`.
    def bindIn(id: Option[Identifier], cast: Option[ClassType] = None): Map[Identifier,Expr] =
      id.fold(Map.empty[Identifier, Expr])(b => Map(b -> cast.fold(in)(asInstOf(in, _))))

    pattern match {
      case CaseClassPattern(b, cct, subps) =>
        assert(cct.fields.size == subps.size)
        val fieldIds = cct.classDef.fields.map(_.id).toList
        val fromFields = (fieldIds zip subps.toList).flatMap { case (fieldId, sub) =>
          mapForPattern(caseClassSelector(cct, asInstOf(in, cct), fieldId), sub)
        }.toMap
        bindIn(b, Some(cct)) ++ fromFields

      case TuplePattern(b, subps) =>
        val TupleType(tpes) = in.getType
        assert(tpes.size == subps.size)

        val fromElems = subps.zipWithIndex.flatMap { case (sub, i) =>
          mapForPattern(tupleSelect(in, i+1, subps.size), sub)
        }.toMap
        bindIn(b) ++ fromElems

      case up @ UnapplyPattern(b, _, subps) =>
        bindIn(b) ++ unwrapTuple(up.getUnsafe(in), subps.size).zip(subps).flatMap {
          case (part, sub) => mapForPattern(part, sub)
        }.toMap

      case InstanceOfPattern(b, ct) =>
        bindIn(b, Some(ct))

      case other =>
        bindIn(other.binder)
    }
  }

  /** Rewrites all pattern-matching expressions into if-then-else expressions.
    *
    * Each match becomes a chain `if (cond_1) rhs_1 else if ... else error`,
    * in which pattern binders are replaced by their expressions in terms of
    * the scrutinee. A case whose condition is literally `true` ends the chain.
    * `passes` expressions are first turned into their constraint, which
    * introduces a match. This adds error conditions but no new variables.
    */
  def matchToIfThenElse(expr: Expr): Expr = {

    def toIte(m: MatchExpr): Expr = {
      val MatchExpr(scrut, cases) = m

      val branches = cases.map { cse =>
        val bindings = mapForPattern(scrut, cse.pattern)
        val patCond = conditionForPattern(scrut, cse.pattern, includeBinders = false)
        val cond = cse.optGuard.fold(patCond)(g => patCond withCond replaceFromIDs(bindings, g))
        (cond.toClause, replaceFromIDs(bindings, cse.rhs))
      }

      val fallback: Expr = Error(m.getType, "Match is non-exhaustive").copiedFrom(m)
      branches.foldRight(fallback) { case ((cond, rhs), elze) =>
        if (cond == BooleanLiteral(true)) rhs else IfExpr(cond, rhs, elze)
      }
    }

    def rewritePM(e: Expr): Option[Expr] = e match {
      case m: MatchExpr => Some(toIte(m))
      case p: Passes    => Some(p.asConstraint) // This introduces a MatchExpr
      case _            => None
    }

    preMap(rewritePM)(expr)
  }

  /** For each case of the [[purescala.Expressions.MatchExpr MatchExpr]], returns
    * the path under which that case is taken.
    *
    * The path is `path`, plus the negated conditions of all previous cases,
    * plus this case's pattern condition (with binders) and guard.
    *
    * @see [[purescala.ExprOps#conditionForPattern conditionForPattern]]
    * @see [[purescala.ExprOps#mapForPattern mapForPattern]]
    */
  def matchExprCaseConditions(m: MatchExpr, path: Path) : Seq[Path] = {
    val MatchExpr(scrut, cases) = m

    val (_, casePaths) = cases.foldLeft((path, Vector.empty[Path])) {
      case ((pcSoFar, acc), c) =>
        val g = c.optGuard getOrElse BooleanLiteral(true)
        val cond = conditionForPattern(scrut, c.pattern, includeBinders = true)
        val localCond = pcSoFar merge (cond withCond g)

        // Binder-free version of this case's condition, negated for the next cases.
        val condSafe = conditionForPattern(scrut, c.pattern)
        val gSafe = replaceFromIDs(mapForPattern(scrut, c.pattern), g)
        (pcSoFar merge (condSafe withCond gSafe).negate, acc :+ localCond)
    }

    casePaths
  }

  /** Condition to pass the match case `c`, expressed w.r.t. `scrut` only.
    *
    * It is the pattern's condition (without binders), followed by the guard, if
    * any, with the pattern binders it mentions replaced by their expressions in
    * terms of `scrut`.
    */
  def matchCaseCondition(scrut: Expr, c: MatchCase): Path = {
    val patternC = conditionForPattern(scrut, c.pattern, includeBinders = false)

    c.optGuard.fold(patternC) { g =>
      // guard might refer to binders
      patternC withCond replaceFromIDs(mapForPattern(scrut, c.pattern), g)
    }
  }

  /** Returns the path conditions for each of the cases of `p`, by viewing it
    * as a match on its input.
    *
    * Each case holds the conditions on other previous cases as negative.
    */
  def passesPathConditions(p: Passes, pathCond: Path) : Seq[Path] = {
    val asMatch = MatchExpr(p.in, p.cases)
    matchExprCaseConditions(asMatch, pathCond)
  }

  /** Turns an expression into a pattern matching it, together with a guard.
    *
    * Conversion rules:
    *  - case classes and tuples become the corresponding patterns,
    *  - literals become literal patterns,
    *  - variables become wildcards binding that variable,
    *  - any other sub-expression becomes a wildcard on a fresh identifier `x`,
    *    and `x == subExpr` is conjoined to the guard.
    *
    * The guard is `true` when no such equality is needed.
    */
  def expressionToPattern(e : Expr) : (Pattern, Expr) = {
    // The pattern for `ex`, with the equalities it requires (left to right).
    def go(ex: Expr): (Pattern, Seq[Expr]) = ex match {
      case CaseClass(cct, fields) =>
        val (pats, conds) = fields.map(go).unzip
        (CaseClassPattern(None, cct, pats), conds.flatten)
      case Tuple(subs) =>
        val (pats, conds) = subs.map(go).unzip
        (TuplePattern(None, pats), conds.flatten)
      case l : Literal[_] =>
        (LiteralPattern(None, l), Seq.empty[Expr])
      case Variable(i) =>
        (WildcardPattern(Some(i)), Seq.empty[Expr])
      case other =>
        val id = FreshIdentifier("other", other.getType, true)
        (WildcardPattern(Some(id)), Seq(Equals(Variable(id), other)))
    }

    val (pattern, equalities) = go(e)
    val guard = equalities.foldLeft[Expr](BooleanLiteral(true))((acc, cond) => and(acc, cond))
    (pattern, guard)
  }

  /**
    * Takes a pattern and returns an expression that corresponds to it.
    * Also returns a sequence of `Identifier -> Expr` pairs which
    * represent the bindings for intermediate binders (from outermost to innermost)
    *
    * Wildcards become variables: their binder, or a fresh `binder` identifier
    * of the expected type. Unapply patterns are not supported. They yield
    * `NoTree(expectedType)`, and their binder is not recorded.
    *
    * @param p            the pattern to convert
    * @param expectedType the type of the value the pattern is matched against
    */
  def patternToExpression(p: Pattern, expectedType: TypeTree): (Expr, Seq[(Identifier, Expr)]) = {
    def fresh(tp : TypeTree) = FreshIdentifier("binder", tp, true)
    var ieMap = Seq[(Identifier, Expr)]()
    // Bindings are recorded after their sub-patterns and prepended, which puts
    // outer binders before inner ones in the final sequence.
    def addBinding(b : Option[Identifier], e : Expr) = b foreach { ieMap +:= (_, e) }
    def rec(p : Pattern, expectedType : TypeTree) : Expr = p match {
      case WildcardPattern(b) =>
        Variable(b getOrElse fresh(expectedType))
      case LiteralPattern(b, lit) =>
        addBinding(b,lit)
        lit
      case InstanceOfPattern(b, ct) => ct match {
        case act: AbstractClassType =>
          // @mk: This seems dubious, in the sense that it just binds the expression
          //      of the AbstractClassType to an id instead of going case-wise.
          //      I think this is sufficient for the use of this function though:
          //      it is only used to generate examples so it is followed by a type-aware enumerator.
          val e = Variable(fresh(act))
          addBinding(b, e)
          e

        case cct: CaseClassType =>
          // A case class with a fresh variable for each field.
          val fields = cct.fields map { f => Variable(fresh(f.getType)) }
          val e = CaseClass(cct, fields)
          addBinding(b, e)
          e
      }
      case TuplePattern(b, subs) =>
        val TupleType(subTypes) = expectedType
        val e = Tuple(subs zip subTypes map {
          case (sub, subType) => rec(sub, subType)
        })
        addBinding(b, e)
        e
      case CaseClassPattern(b, cct, subs) =>
        val e = CaseClass(cct, subs zip cct.fieldsTypes map { case (sub,tp) => rec(sub,tp) })
        addBinding(b, e)
        e
      case up@UnapplyPattern(b, fd, subs) =>
        // TODO: Support this
        NoTree(expectedType)
    }

    (rec(p, expectedType), ieMap)

  }


  /** Guards every map access `m(k)` with a check that `k` is defined in `m`.
    *
    * The check is `if (m.isDefinedAt(k)) m(k) else error`, applied bottom-up.
    */
  def mapGetWithChecks(expr: Expr): Expr = {
    def guardAccess(node: Expr): Option[Expr] = node match {
      case access @ MapApply(m, k) =>
        val isDefined = MapIsDefinedAt(m, k)
        val keyNotFound = Error(access.getType, "Key not found for map access").copiedFrom(access)
        Some(IfExpr(isDefined, access, keyNotFound).copiedFrom(access))
      case _ =>
        None
    }
    postMap(guardAccess)(expr)
  }

  /** Returns the simplest value of a given type.
    *
    * By type:
    *  - built-in types: zero, empty, `false`, `'a'` or `()`,
    *  - tuples and case classes: built field by field,
    *  - an abstract class: the simplest value of its known non-recursive case
    *    class descendant with the fewest fields,
    *  - a type parameter: a generic value,
    *  - a function type: a lambda with no explicit entries whose default is
    *    the simplest value of the result type.
    *
    * @throws LeonFatalError if an abstract class has no non-recursive known
    *         descendant, or for an unsupported type
    */
  def simplestValue(tpe: TypeTree) : Expr = tpe match {
    case StringType                => StringLiteral("")
    case Int32Type                 => IntLiteral(0)
    case RealType                  => FractionalLiteral(0, 1)
    case IntegerType               => InfiniteIntegerLiteral(0)
    case CharType                  => CharLiteral('a')
    case BooleanType               => BooleanLiteral(false)
    case UnitType                  => UnitLiteral()
    case SetType(baseType)         => FiniteSet(Set(), baseType)
    case BagType(baseType)         => FiniteBag(Map(), baseType)
    case MapType(fromType, toType) => FiniteMap(Map(), fromType, toType)
    case TupleType(tpes)           => Tuple(tpes.map(simplestValue))
    case ArrayType(elemType)       => EmptyArray(elemType)

    case act @ AbstractClassType(acd, _) =>
      // A child is recursive when one of its fields is directly typed by a class
      // of the same hierarchy (occurrences nested in other types are not detected).
      def isRecursive(cct: CaseClassType): Boolean =
        cct.fieldsTypes.exists {
          case AbstractClassType(fieldAcd, _) => acd.root == fieldAcd.root
          case CaseClassType(fieldCcd, _)     => acd.root == fieldCcd.root
          case _                              => false
        }

      val candidates = act.knownCCDescendants.filterNot(isRecursive).sortBy(_.fields.size)

      candidates.headOption
        .map(cct => simplestValue(cct))
        .getOrElse(throw LeonFatalError(act + " does not seem to be well-founded"))

    case cct: CaseClassType =>
      CaseClass(cct, cct.fieldsTypes.map(simplestValue))

    case tp: TypeParameter =>
      GenericValue(tp, 0)

    case ft @ FunctionType(_, to) =>
      FiniteLambda(Seq.empty, simplestValue(to), ft)

    case _ => throw LeonFatalError("I can't choose simplest value for type " + tpe)
  }

  /** Checks if a given expression is 'real', i.e. contains no
    * [[purescala.Expressions.GenericValue GenericValue]] anywhere in its tree.
    */
  def isRealExpr(v: Expr): Boolean = {
    val containsGeneric = exists {
      case _: GenericValue => true
      case _ => false
    }(v)
    !containsGeneric
  }

  /** Lazily enumerates the values of a type.
    *
    * Supported types are Boolean, Int, BigInt, Unit, type parameters, tuples, sets and
    * class types; any other type fails with a `MatchError`. Integers are enumerated as
    * 0, 1, -1, 2, -2, ...; composite types combine the enumerations of their components.
    */
  def valuesOf(tp: TypeTree): Stream[Expr] = {
    import utils.StreamUtils._
    tp match {
      case BooleanType =>
        Stream(BooleanLiteral(false), BooleanLiteral(true))
      case Int32Type =>
        Stream.iterate(0) { prev =>
          if (prev > 0) -prev else -prev + 1
        } map IntLiteral
      case IntegerType =>
        Stream.iterate(BigInt(0)) { prev =>
          if (prev > 0) -prev else -prev + 1
        } map InfiniteIntegerLiteral
      case UnitType =>
        Stream(UnitLiteral())
      case tp: TypeParameter =>
        Stream.from(0) map (GenericValue(tp, _))
      case TupleType(stps) =>
        cartesianProduct(stps map (tp => valuesOf(tp))) map Tuple
      case SetType(base) =>
        // Every finite subset is produced exactly once: after the empty set, each new
        // element `elem` contributes all subsets seen so far extended with `elem`.
        // (Flattening the successive powersets instead would repeat earlier subsets.)
        val emptySet: Expr = FiniteSet(Set(), base)

        def extendAll(known: Stream[Expr], remaining: Stream[Expr]): Stream[Expr] =
          if (remaining.isEmpty) {
            Stream.empty
          } else {
            val elem = remaining.head
            val extended: Stream[Expr] = known map {
              case FiniteSet(els, setTpe) => FiniteSet(els + elem, setTpe)
            }
            // `#:::` keeps the recursive call lazy, so infinite element streams are fine.
            extended #::: extendAll(known #::: extended, remaining.tail)
          }

        emptySet #:: extendAll(Stream(emptySet), valuesOf(base))
      case cct: CaseClassType =>
        cartesianProduct(cct.fieldsTypes map valuesOf) map (CaseClass(cct, _))
      case act: AbstractClassType =>
        interleave(act.knownCCDescendants.map(cct => valuesOf(cct)))
    }
  }


  /** Hoists all IfExpr at top level.
    *
    * Guarantees that all IfExpr will be at the top level and as soon as you
    * encounter a non-IfExpr, then no more IfExpr can be found in the
    * sub-expressions.
    *
    * Works bottom-up: whenever a node has an `IfExpr` child, the first such child is
    * pulled above the node (`op(.., if (c) t else e, ..)` becomes
    * `if (c) op(.., t, ..) else op(.., e, ..)`), and the rewrite is re-applied to the result.
    *
    * Assumes no match expressions.
    *
    * NOTE(review): binding constructs are not treated specially, so a condition mentioning a
    * variable bound by the enclosing node may end up outside that binder's scope — confirm
    * callers never pass such trees.
    */
  def hoistIte(expr: Expr): Expr = {
    def isIte(e: Expr): Boolean = e match {
      case _: IfExpr => true
      case _ => false
    }

    def liftFirstIte(node: Expr): Option[Expr] = node match {
      case _: IfExpr =>
        None

      case nop @ Deconstructor(children, rebuild) =>
        children.indexWhere(isIte) match {
          case -1 =>
            None
          case idx =>
            val IfExpr(cond, thenn, elze) = children(idx)
            // Rebuilds `nop` with the conditional child replaced by one of its branches.
            def withChild(branch: Expr): Expr = rebuild(children.updated(idx, branch)).copiedFrom(nop)
            Some(IfExpr(cond, withChild(thenn), withChild(elze)))
        }

      case _ =>
        None
    }

    postMap(liftFirstIte, applyRec = true)(expr)
  }

  /** Builds a path-sensitive simplifier backed by `sf`, seeded with the initial conditions `initC`.
    *
    * A single [[SimplifierWithPaths]] instance is created here and shared by every call of the
    * returned function.
    */
  def simplifyPaths(sf: SolverFactory[Solver], initC: List[Expr] = Nil): Expr => Expr = {
    val simplifier = new SimplifierWithPaths(sf, initC)
    (e: Expr) => simplifier.transform(e)
  }

  /** A traversal over an expression that produces a result of type `T`. */
  trait Traverser[T] {
    /** Traverses `e` and returns the traversal's result. */
    def traverse(e: Expr): T
  }

  /** Factory for [[CollectorWithPaths]] instances driven by a partial function. */
  object CollectorWithPaths {
    /** Builds a collector recording `p(e)`, paired with the path reaching `e`, for every
      * visited sub-expression `e` on which `p` is defined.
      */
    def apply[T](p: PartialFunction[Expr,T]): CollectorWithPaths[(T, Path)] = new CollectorWithPaths[(T, Path)] {
      def collect(e: Expr, path: Path): Option[(T, Path)] =
        if (p.isDefinedAt(e)) Some((p(e), path))
        else None
    }
  }

  /** A path-aware traversal that accumulates results while walking an expression.
    *
    * Implementors provide `collect`, which is called on every visited sub-expression with
    * the path reaching it. Each `Some` result is appended in visiting order. `walk` may take
    * over a node: when it returns `Some(r)`, `r` is used and the default recursion into that
    * node is skipped.
    *
    * NOTE(review): results live in a mutable field that `traverse` resets, so one instance
    * should not run overlapping traversals.
    */
  trait CollectorWithPaths[T] extends TransformerWithPC with Traverser[Seq[T]] {
    protected val initPath: Seq[Expr] = Nil

    private var results: Seq[T] = Nil

    def collect(e: Expr, path: Path): Option[T]

    def walk(e: Expr, path: Path): Option[Expr] = None

    override def rec(e: Expr, path: Path) = {
      for (found <- collect(e, path)) {
        results = results :+ found
      }
      walk(e, path) match {
        case Some(replacement) => replacement
        case None              => super.rec(e, path)
      }
    }

    def traverse(funDef: FunDef): Seq[T] = traverse(funDef.fullBody)

    def traverse(e: Expr): Seq[T] = traverse(e, initPath)

    def traverse(e: Expr, init: Expr): Seq[T] = traverse(e, List(init))

    def traverse(e: Expr, init: Seq[Expr]): Seq[T] = {
      results = Nil
      rec(e, Path(init))
      results
    }
  }

  /** Collects `f(e)` for every sub-expression `e` of `expr` on which `f` is defined,
    * paired with the [[Path]] under which `e` is reached.
    */
  def collectWithPC[T](f: PartialFunction[Expr, T])(expr: Expr): Seq[(T, Path)] = {
    val collector = CollectorWithPaths(f)
    collector.traverse(expr)
  }

  /** Size of `e` as computed by the inherited `formulaSize`, plus, for a match expression,
    * the sizes of its case patterns as given by [[PatternOps.formulaSize]].
    */
  override def formulaSize(e: Expr): Int = {
    val exprSize = super.formulaSize(e)
    val patternsSize = e match {
      case ml: MatchExpr => ml.cases.map(mc => PatternOps.formulaSize(mc.pattern)).sum
      case _             => 0
    }
    exprSize + patternsSize
  }

  /** Returns true if the expression is deterministic /
    * does not contain any [[purescala.Expressions.Choose Choose]]
    * or [[purescala.Expressions.Hole Hole]] or [[purescala.Expressions.WithOracle WithOracle]]
    */
  def isDeterministic(e: Expr): Boolean = {
    // `exists` is true as soon as one sub-expression matches, so we look for the
    // non-deterministic constructs and negate the result.
    !exists {
      case _ : Choose | _: Hole | _: WithOracle => true
      case _ => false
    }(e)
  }

  /** Returns true if this expression behaves as a purely functional construct,
    * i.e. always returns the same value (for the same environment) and has no side-effects.
    *
    * This holds when it contains no [[purescala.Expressions.Error Error]],
    * [[purescala.Expressions.Choose Choose]], [[purescala.Expressions.Hole Hole]]
    * or [[purescala.Expressions.WithOracle WithOracle]].
    */
  def isPurelyFunctional(e: Expr): Boolean = {
    // `exists` is true as soon as one sub-expression matches, so we look for the
    // offending constructs and negate the result.
    !exists {
      case _ : Error | _ : Choose | _: Hole | _: WithOracle => true
      case _ => false
    }(e)
  }

  /** Returns the value for an identifier given a model, falling back to the simplest value
    * of the identifier's type when the model has no entry for it.
    */
  def valuateWithModel(model: Model)(id: Identifier): Expr = {
    lazy val fallback = simplestValue(id.getType)
    model.getOrElse(id, fallback)
  }

  /** Substitutes (free) variables in an expression with values from a model.
    *
    * Each identifier of `vars` is replaced by its value in `model`; identifiers missing from
    * the model get the simplest value of their type.
    */
  def valuateWithModelIn(expr: Expr, vars: Set[Identifier], model: Model): Expr = {
    val substitution: Map[Expr, Expr] = vars.map { id =>
      (Variable(id): Expr) -> valuateWithModel(model)(id)
    }.toMap
    replace(substitution, expr)
  }
  
  /** Simple, local optimization on string expressions.
    *
    * Rewrites are applied bottom-up until a fixpoint:
    *  - concatenation with the empty literal is dropped,
    *  - concatenation, length and substring of literals are folded.
    *
    * A `SubString` of a literal is only folded when its bounds are valid for that literal
    * (`0 <= start <= end <= length`), so simplification never throws
    * `StringIndexOutOfBoundsException`; out-of-range substrings are left as they are.
    */
  def simplifyString(expr: Expr): Expr = {
    def simplify0(expr: Expr): Expr = (expr match {
      case StringConcat(StringLiteral(""), b) => b
      case StringConcat(b, StringLiteral("")) => b
      case StringConcat(StringLiteral(a), StringLiteral(b)) => StringLiteral(a + b)
      case StringLength(StringLiteral(a)) => InfiniteIntegerLiteral(a.length)
      case SubString(StringLiteral(a), InfiniteIntegerLiteral(start), InfiniteIntegerLiteral(end))
          if start >= 0 && start <= end && end <= a.length =>
        StringLiteral(a.substring(start.toInt, end.toInt))
      case _ => expr
    }).copiedFrom(expr)

    fixpoint(simplePostTransform(simplify0))(expr)
  }

  /** Simple, local simplification on arithmetic
    *
    * You should not assume anything smarter than some constant folding and
    * simple cancellation. To avoid infinite cycle we only apply simplification
    * that reduce the size of the tree. The only guarantee from this function is
    * to not augment the size of the expression and to be sound.
    *
    * Soundness note: `Division` on integers truncates, so `i1 * (e / i2)` is NOT rewritten
    * to `(i1 / i2) * e` (e.g. `2 * (3 / 2) = 2` whereas `1 * 3 = 3`).
    */
  def simplifyArithmetic(expr: Expr): Expr = {
    def simplify0(expr: Expr): Expr = (expr match {
      case Plus(InfiniteIntegerLiteral(i1), InfiniteIntegerLiteral(i2)) => InfiniteIntegerLiteral(i1 + i2)
      case Plus(InfiniteIntegerLiteral(zero), e) if zero == BigInt(0) => e
      case Plus(e, InfiniteIntegerLiteral(zero)) if zero == BigInt(0) => e
      case Plus(e1, UMinus(e2)) => Minus(e1, e2)
      case Plus(Plus(e, InfiniteIntegerLiteral(i1)), InfiniteIntegerLiteral(i2)) => Plus(e, InfiniteIntegerLiteral(i1+i2))
      case Plus(Plus(InfiniteIntegerLiteral(i1), e), InfiniteIntegerLiteral(i2)) => Plus(InfiniteIntegerLiteral(i1+i2), e)

      case Minus(e, InfiniteIntegerLiteral(zero)) if zero == BigInt(0) => e
      case Minus(InfiniteIntegerLiteral(zero), e) if zero == BigInt(0) => UMinus(e)
      case Minus(InfiniteIntegerLiteral(i1), InfiniteIntegerLiteral(i2)) => InfiniteIntegerLiteral(i1 - i2)
      case Minus(e1, UMinus(e2)) => Plus(e1, e2)
      case Minus(e1, Minus(UMinus(e2), e3)) => Plus(e1, Plus(e2, e3))

      case UMinus(InfiniteIntegerLiteral(x)) => InfiniteIntegerLiteral(-x)
      case UMinus(UMinus(x)) => x
      case UMinus(Plus(UMinus(e1), e2)) => Plus(e1, UMinus(e2))
      case UMinus(Minus(e1, e2)) => Minus(e2, e1)

      case Times(InfiniteIntegerLiteral(i1), InfiniteIntegerLiteral(i2)) => InfiniteIntegerLiteral(i1 * i2)
      case Times(InfiniteIntegerLiteral(one), e) if one == BigInt(1) => e
      case Times(InfiniteIntegerLiteral(mone), e) if mone == BigInt(-1) => UMinus(e)
      case Times(e, InfiniteIntegerLiteral(one)) if one == BigInt(1) => e
      case Times(InfiniteIntegerLiteral(zero), _) if zero == BigInt(0) => InfiniteIntegerLiteral(0)
      case Times(_, InfiniteIntegerLiteral(zero)) if zero == BigInt(0) => InfiniteIntegerLiteral(0)
      case Times(InfiniteIntegerLiteral(i1), Times(InfiniteIntegerLiteral(i2), t)) => Times(InfiniteIntegerLiteral(i1*i2), t)
      case Times(InfiniteIntegerLiteral(i1), Times(t, InfiniteIntegerLiteral(i2))) => Times(InfiniteIntegerLiteral(i1*i2), t)
      case Times(InfiniteIntegerLiteral(i), UMinus(e)) => Times(InfiniteIntegerLiteral(-i), e)
      case Times(UMinus(e), InfiniteIntegerLiteral(i)) => Times(e, InfiniteIntegerLiteral(-i))

      case Division(InfiniteIntegerLiteral(i1), InfiniteIntegerLiteral(i2)) if i2 != BigInt(0) => InfiniteIntegerLiteral(i1 / i2)
      case Division(e, InfiniteIntegerLiteral(one)) if one == BigInt(1) => e

      //here we put more expensive rules
      //btw, I know those are not the most general rules, but they lead to good optimizations :)
      case Plus(UMinus(Plus(e1, e2)), e3) if e1 == e3 => UMinus(e2)
      case Plus(UMinus(Plus(e1, e2)), e3) if e2 == e3 => UMinus(e1)
      case Minus(e1, e2) if e1 == e2 => InfiniteIntegerLiteral(0)
      case Minus(Plus(e1, e2), Plus(e3, e4)) if e1 == e4 && e2 == e3 => InfiniteIntegerLiteral(0)
      case Minus(Plus(e1, e2), Plus(Plus(e3, e4), e5)) if e1 == e4 && e2 == e3 => UMinus(e5)

      case StringConcat(StringLiteral(""), a) => a
      case StringConcat(a, StringLiteral("")) => a
      case StringConcat(StringLiteral(a), StringLiteral(b)) => StringLiteral(a+b)
      case StringConcat(StringLiteral(a), StringConcat(StringLiteral(b), c)) => StringConcat(StringLiteral(a+b), c)
      case StringConcat(StringConcat(c, StringLiteral(a)), StringLiteral(b)) => StringConcat(c, StringLiteral(a+b))
      case StringConcat(a, StringConcat(b, c)) => StringConcat(StringConcat(a, b), c)
      //default
      case e => e
    }).copiedFrom(expr)

    fixpoint(simplePostTransform(simplify0))(expr)
  }

  // ---- Helper methods for FractionalLiterals ----

  /** Reduces a fractional literal to lowest terms with a non-negative denominator.
    *
    * A literal whose numerator and denominator are both zero has no common divisor to reduce
    * by and is returned unchanged (instead of failing with a division by zero). Zero
    * denominators are left for callers such as [[floor]] to reject.
    */
  def normalizeFraction(fl: FractionalLiteral): FractionalLiteral = {
    val FractionalLiteral(num, denom) = fl
    // BigInt.gcd works on absolute values and is zero only when both operands are zero.
    val divisor = num.gcd(denom)
    if (divisor == 0) {
      fl
    } else {
      val simpNum = num / divisor
      val simpDenom = denom / divisor
      if (simpDenom < 0)
        FractionalLiteral(-simpNum, -simpDenom)
      else
        FractionalLiteral(simpNum, simpDenom)
    }
  }

  /** The fractional literal 0/1. */
  val realzero = FractionalLiteral(0, 1)

  /** Rounds a fractional literal down to the nearest integer (towards negative infinity),
    * returned as a literal with denominator 1.
    *
    * @throws IllegalStateException if the normalized denominator is zero
    */
  def floor(fl: FractionalLiteral): FractionalLiteral = {
    val FractionalLiteral(num, den) = normalizeFraction(fl)
    if (den == 0) throw new IllegalStateException("denominator zero")

    if (num == 0) {
      realzero
    } else {
      // BigInt division truncates towards zero; for a negative, non-integral value the
      // truncated quotient is one above the floor.
      val truncated = num / den
      val needsAdjust = num < 0 && num % den != 0
      FractionalLiteral(if (needsAdjust) truncated - 1 else truncated, 1)
    }
  }

  /** Checks whether a predicate is inductive on a certain identifier.
    *
    * isInductive(foo(a, b), a) where a: List will check whether
    *    foo(Nil, b) and
    *    foo(t, b) => foo(Cons(h,t), b)
    *
    * Concretely, for each case class `C` of `on`'s hierarchy and each field `s` of `C` whose
    * type is exactly `on.getType`, the solver is asked whether
    * `path && on.isInstanceOf[C] && !path[on := on.s]` is satisfiable. The result is `true`
    * only if every such query is proven unsatisfiable; an unknown answer counts as failure.
    * Identifiers that are not of an abstract class type are never considered inductive.
    *
    * NOTE(review): no query is built for case classes without such a field (the base case),
    * and each query checks `path(on) ==> path(on.s)`, which reads as the converse of the
    * implication in the example above — confirm which direction callers rely on.
    */
  def isInductiveOn(sf: SolverFactory[Solver])(path: Path, on: Identifier): Boolean = on match {
    case IsTyped(_, AbstractClassType(cd, tps)) =>
      val onVar = Variable(on)

      // One query per (case class, recursive field) pair.
      val toCheck: Seq[Expr] = cd.knownDescendants.flatMap {
        case ccd: CaseClassDef =>
          val cct = CaseClassType(ccd, tps)
          val isType = IsInstanceOf(onVar, cct)

          val recSelectors = (cct.classDef.fields zip cct.fieldsTypes) collect {
            case (vd, fieldType) if fieldType == on.getType => vd.id
          }

          recSelectors map { sel =>
            and(path and isType, not(replace(Map(onVar -> caseClassSelector(cct, onVar, sel)), path.toClause)))
          }

        case _ =>
          Seq()
      }

      val solver = SimpleSolverAPI(sf)

      // Every query must be proven UNSAT; SAT and unknown (None) both count as failure.
      toCheck forall { cond => solver.solveSAT(cond)._1.contains(false) }

    case _ =>
      false
  }
  
  /** Identifier correspondence threaded through [[canBeHomomorphic]]: maps identifiers of
    * the first expression to identifiers of the second one.
    */
  type Apriori = Map[Identifier, Identifier]

  /** Checks whether two expressions can be homomorphic and returns the corresponding mapping.
    *
    * Returns `None` when a mismatch is found. The comparison threads an [[Apriori]] context
    * through the sub-comparisons using the local `&&` / `mergeall` combinators.
    *
    * NOTE(review): `idHomo` and `fdHomo` return a fresh map instead of extending the incoming
    * context, so the returned mapping is not necessarily the complete identifier
    * correspondence — confirm this is intended.
    */
  def canBeHomomorphic(t1: Expr, t2: Expr): Option[Map[Identifier, Identifier]] = {
    val freeT1Variables = ExprOps.variablesOf(t1)
    val freeT2Variables = ExprOps.variablesOf(t2)

    // Runs the context step `b` on the context `a`, if there is one.
    def mergeContexts(
        a: Option[Apriori],
        b: Apriori => Option[Apriori]):
        Option[Apriori] = a.flatMap(b)

    // Matches a pair of expressions of the same runtime class.
    object Same {
      def unapply(tt: (Expr, Expr)): Option[(Expr, Expr)] = {
        if (tt._1.getClass == tt._2.getClass) {
          Some(tt)
        } else {
          None
        }
      }
    }
    // `ctx && step` runs `step` on `ctx`; `ctx -- ids` drops the bindings of `ids`.
    implicit class AugmentedContext(c: Option[Apriori]) {
      def &&(other: Apriori => Option[Apriori]): Option[Apriori] = mergeContexts(c, other)
      def --(other: Seq[Identifier]) =
        c.map(_ -- other)
    }
    // `cond && ctx` yields `ctx` (evaluated lazily) when `cond` holds, `None` otherwise.
    implicit class AugmentedBoolean(c: Boolean) {
      def &&(other:  => Option[Apriori]) = if(c) other else None
    }
    // `step1 && step2` composes two context steps, stopping at the first `None`.
    implicit class AugmentedFilter(c: Apriori => Option[Apriori]) {
      def &&(other: Apriori => Option[Apriori]):
        Apriori => Option[Apriori]
      = (m: Apriori) => c(m).flatMap(mp => other(mp))
    }
    // `xs mergeall p` threads the context through `p(x)` for every element `x`, in order.
    implicit class AugmentedSeq[T](c: Seq[T]) {
      def mergeall(p: T => Apriori => Option[Apriori])(apriori: Apriori) =
        (Option(apriori) /: c) {
          case (s, c) => s.flatMap(apriori => p(c)(apriori))
        }
    }
    // Lets `None` be used where a context step is expected: a step that always fails.
    implicit def noneToContextTaker(c: None.type) = {
      (m: Apriori) => None
    }


    // Accepts identifiers free in neither expression, identical identifiers, or identifiers
    // already related by `apriori`.
    def idHomo(i1: Identifier, i2: Identifier)(apriori: Apriori): Option[Apriori] = {
      if(!(freeT1Variables(i1) || freeT2Variables(i2)) || i1 == i2 || apriori.get(i1) == Some(i2)) Some(Map(i1 -> i2)) else None
    }
    // Optional binders must be both absent (the context is kept unchanged) or both present
    // (compared with `idHomo`). Binder-less patterns such as `case _ =>` must not fail here.
    def idOptionHomo(i1: Option[Identifier], i2: Option[Identifier])(apriori: Apriori): Option[Apriori] = {
      (i1.size == i2.size) && ((i1, i2) match {
        case (Some(b1), Some(b2)) => idHomo(b1, b2)(apriori)
        case _ => Some(apriori)
      })
    }

    // Same arity, then compare bodies under a context relating the function ids and parameters.
    def fdHomo(fd1: FunDef, fd2: FunDef)(apriori: Apriori): Option[Apriori] = {
      if(fd1.params.size == fd2.params.size) {
         val newMap = Map((
           (fd1.id -> fd2.id) +:
           (fd1.paramIds zip fd2.paramIds)): _*)
         Option(newMap) && isHomo(fd1.fullBody, fd2.fullBody)
      } else None
    }

    def isHomo(t1: Expr, t2: Expr)(apriori: Apriori): Option[Apriori] = {
      def casesMatch(cs1 : Seq[MatchCase], cs2 : Seq[MatchCase])(apriori: Apriori) : Option[Apriori] = {
        def patternHomo(p1: Pattern, p2: Pattern)(apriori: Apriori): Option[Apriori] = (p1, p2) match {
          case (InstanceOfPattern(ob1, cd1), InstanceOfPattern(ob2, cd2)) =>
            cd1 == cd2 && idOptionHomo(ob1, ob2)(apriori)

          case (WildcardPattern(ob1), WildcardPattern(ob2)) =>
            idOptionHomo(ob1, ob2)(apriori)

          case (CaseClassPattern(ob1, ccd1, subs1), CaseClassPattern(ob2, ccd2, subs2)) =>
            val m = idOptionHomo(ob1, ob2)(apriori)

            (ccd1 == ccd2 && subs1.size == subs2.size) && m &&
              ((subs1 zip subs2) mergeall { case (p1, p2) => patternHomo(p1, p2) })

          case (UnapplyPattern(ob1, TypedFunDef(fd1, ts1), subs1), UnapplyPattern(ob2, TypedFunDef(fd2, ts2), subs2)) =>
            val m = idOptionHomo(ob1, ob2)(apriori)

            (subs1.size == subs2.size && ts1 == ts2) && m && fdHomo(fd1, fd2) && (
              (subs1 zip subs2) mergeall { case (p1, p2) => patternHomo(p1, p2) })

          case (TuplePattern(ob1, subs1), TuplePattern(ob2, subs2)) =>
            val m = idOptionHomo(ob1, ob2)(apriori)

            (ob1.size == ob2.size && subs1.size == subs2.size) && m && (
              (subs1 zip subs2) mergeall { case (p1, p2) => patternHomo(p1, p2) })

          case (LiteralPattern(ob1, lit1), LiteralPattern(ob2,lit2)) =>
            lit1 == lit2 && idOptionHomo(ob1, ob2)(apriori)

          case _ =>
            None
        }

        (cs1 zip cs2).mergeall {
          case (MatchCase(p1, g1, e1), MatchCase(p2, g2, e2)) =>
            val h = patternHomo(p1, p2) _
            val g: Apriori => Option[Apriori] = (g1, g2) match {
              case (Some(g1), Some(g2)) => isHomo(g1, g2)(_)
              case (None, None) => (m: Apriori) => Some(m)
              case _ => None
            }
            val e = isHomo(e1, e2) _

            h && g && e
        }(apriori)
      }

      val res: Option[Apriori] = (t1, t2) match {
        case (Variable(i1), Variable(i2)) =>
          idHomo(i1, i2)(apriori)

        case (Let(id1, v1, e1), Let(id2, v2, e2)) =>

          isHomo(v1, v2)(apriori + (id1 -> id2)) &&
          isHomo(e1, e2)

        case (Hole(_, _), Hole(_, _)) =>
          None

        case (LetDef(fds1, e1), LetDef(fds2, e2)) =>
          fds1.size == fds2.size &&
          {
            val zipped = fds1.zip(fds2)
            (zipped mergeall (fds => fdHomo(fds._1, fds._2)))(apriori) &&
            isHomo(e1, e2)
          }

        case (MatchExpr(s1, cs1), MatchExpr(s2, cs2)) =>
          cs1.size == cs2.size && casesMatch(cs1,cs2)(apriori) && isHomo(s1, s2)

        case (Passes(in1, out1, cs1), Passes(in2, out2, cs2)) =>
          (cs1.size == cs2.size && casesMatch(cs1,cs2)(apriori)) && isHomo(in1,in2) && isHomo(out1,out2)

        case (FunctionInvocation(tfd1, args1), FunctionInvocation(tfd2, args2)) =>
          // Unknown function pairs are related by comparing their bodies; known pairs must agree.
          (if(tfd1 == tfd2) Some(apriori) else (apriori.get(tfd1.fd.id) match {
            case None =>
              isHomo(tfd1.fd.fullBody, tfd2.fd.fullBody)(apriori + (tfd1.fd.id -> tfd2.fd.id))
            case Some(fdid2) =>
              if(fdid2 == tfd2.fd.id) Some(apriori) else None
          })) &&
          tfd1.tps.zip(tfd2.tps).mergeall{
            case (t1, t2) => if(t1 == t2)
              (m: Apriori) => Option(m)
              else (m: Apriori) => None} &&
          (args1 zip args2).mergeall{ case (a1, a2) => isHomo(a1, a2) }

        case (Lambda(defs, body), Lambda(defs2, body2)) =>
          // We remove variables introduced by lambdas.
          ((defs zip defs2).mergeall{ case (ValDef(a1), ValDef(a2)) =>
            (m: Apriori) =>
              Some(m + (a1 -> a2)) }(apriori)
           && isHomo(body, body2)
          ) -- (defs.map(_.id))

        case (v1, v2) if isValue(v1) && isValue(v2) =>
          v1 == v2 && Some(apriori)

        case Same(Operator(es1, _), Operator(es2, _)) =>
          (es1.size == es2.size) &&
          (es1 zip es2).mergeall{ case (e1, e2) => isHomo(e1, e2) }(apriori)

        case _ =>
          None
      }

      res
    }

    isHomo(t1,t2)(Map())
  } // ensuring (res => res.isEmpty || isHomomorphic(t1, t2)(res.get))

  /** Checks whether two trees are homomorphic modulo an identifier map.
    *
    * Two identifiers correspond when they are equal or related by `map`. Binders (`let`,
    * local functions and their parameters, pattern binders) extend the map while their
    * scopes are compared. Recursive functions are handled coinductively: a call to a
    * function pair that is already related is not unfolded again.
    *
    * Used for transformation tests.
    */
  def isHomomorphic(t1: Expr, t2: Expr)(implicit map: Map[Identifier, Identifier]): Boolean = {
    // Matches a pair of expressions of the same runtime class.
    object Same {
      def unapply(tt: (Expr, Expr)): Option[(Expr, Expr)] = {
        if (tt._1.getClass == tt._2.getClass) {
          Some(tt)
        } else {
          None
        }
      }
    }

    def idHomo(i1: Identifier, i2: Identifier)(implicit map: Map[Identifier, Identifier]): Boolean = {
      i1 == i2 || map.get(i1).contains(i2)
    }

    // Functions correspond when they have the same arity and their bodies correspond once
    // the function ids and their parameters are related.
    def fdHomo(fd1: FunDef, fd2: FunDef)(implicit map: Map[Identifier, Identifier]): Boolean = {
      (fd1.params.size == fd2.params.size) && {
         val newMap = map +
           (fd1.id -> fd2.id) ++
           (fd1.paramIds zip fd2.paramIds)
         isHomo(fd1.fullBody, fd2.fullBody)(newMap)
      }
    }

    def isHomo(t1: Expr, t2: Expr)(implicit map: Map[Identifier,Identifier]): Boolean = {

      def casesMatch(cs1 : Seq[MatchCase], cs2 : Seq[MatchCase]) : Boolean = {
        // Returns whether the patterns correspond, and the binder correspondence they introduce.
        def patternHomo(p1: Pattern, p2: Pattern): (Boolean, Map[Identifier, Identifier]) = (p1, p2) match {
          case (InstanceOfPattern(ob1, cd1), InstanceOfPattern(ob2, cd2)) =>
            (ob1.size == ob2.size && cd1 == cd2, Map((ob1 zip ob2).toSeq : _*))

          case (WildcardPattern(ob1), WildcardPattern(ob2)) =>
            (ob1.size == ob2.size, Map((ob1 zip ob2).toSeq : _*))

          case (CaseClassPattern(ob1, ccd1, subs1), CaseClassPattern(ob2, ccd2, subs2)) =>
            val m = Map[Identifier, Identifier]() ++ (ob1 zip ob2)

            if (ob1.size == ob2.size && ccd1 == ccd2 && subs1.size == subs2.size) {
              (subs1 zip subs2).map { case (p1, p2) => patternHomo(p1, p2) }.foldLeft((true, m)) {
                case ((b1, m1), (b2,m2)) => (b1 && b2, m1 ++ m2)
              }
            } else {
              (false, Map())
            }

          case (UnapplyPattern(ob1, fd1, subs1), UnapplyPattern(ob2, fd2, subs2)) =>
            val m = Map[Identifier, Identifier]() ++ (ob1 zip ob2)

            if (ob1.size == ob2.size && fd1 == fd2 && subs1.size == subs2.size) {
              (subs1 zip subs2).map { case (p1, p2) => patternHomo(p1, p2) }.foldLeft((true, m)) {
                case ((b1, m1), (b2,m2)) => (b1 && b2, m1 ++ m2)
              }
            } else {
              (false, Map())
            }

          case (TuplePattern(ob1, subs1), TuplePattern(ob2, subs2)) =>
            val m = Map[Identifier, Identifier]() ++ (ob1 zip ob2)

            if (ob1.size == ob2.size && subs1.size == subs2.size) {
              (subs1 zip subs2).map { case (p1, p2) => patternHomo(p1, p2) }.foldLeft((true, m)) {
                case ((b1, m1), (b2,m2)) => (b1 && b2, m1 ++ m2)
              }
            } else {
              (false, Map())
            }

          case (LiteralPattern(ob1, lit1), LiteralPattern(ob2,lit2)) =>
            (ob1.size == ob2.size && lit1 == lit2, (ob1 zip ob2).toMap)

          case _ =>
            (false, Map())
        }

        (cs1 zip cs2).forall {
          case (MatchCase(p1, g1, e1), MatchCase(p2, g2, e2)) =>
            val (h, nm) = patternHomo(p1, p2)
            val g = (g1, g2) match {
              case (Some(g1), Some(g2)) => isHomo(g1,g2)(map ++ nm)
              case (None, None) => true
              case _ => false
            }
            val e = isHomo(e1, e2)(map ++ nm)

            g && e && h
        }

      }

      val res = (t1, t2) match {
        case (Variable(i1), Variable(i2)) =>
          idHomo(i1, i2)

        case (Let(id1, v1, e1), Let(id2, v2, e2)) =>
          isHomo(v1, v2) &&
          isHomo(e1, e2)(map + (id1 -> id2))

        case (LetDef(fds1, e1), LetDef(fds2, e2)) =>
          fds1.size == fds2.size &&
          {
            val zipped = fds1.zip(fds2)
            zipped.forall( fds =>
            fdHomo(fds._1, fds._2)
            ) &&
            isHomo(e1, e2)(map ++ zipped.map(fds => fds._1.id -> fds._2.id))
          }

        case (MatchExpr(s1, cs1), MatchExpr(s2, cs2)) =>
          cs1.size == cs2.size && isHomo(s1, s2) && casesMatch(cs1,cs2)

        case (Passes(in1, out1, cs1), Passes(in2, out2, cs2)) =>
          cs1.size == cs2.size && isHomo(in1,in2) && isHomo(out1,out2) && casesMatch(cs1,cs2)

        case (FunctionInvocation(tfd1, args1), FunctionInvocation(tfd2, args2)) =>
          // TODO: Check type params
          // Functions that are identical or already related by `map` (e.g. a recursive call
          // met while comparing that very function's body) are assumed to correspond:
          // unfolding them again would never terminate on recursive functions.
          (idHomo(tfd1.fd.id, tfd2.fd.id) || fdHomo(tfd1.fd, tfd2.fd)) &&
          args1.size == args2.size &&
          (args1 zip args2).forall{ case (a1, a2) => isHomo(a1, a2) }

        case Same(Deconstructor(es1, _), Deconstructor(es2, _)) =>
          (es1.size == es2.size) &&
          (es1 zip es2).forall{ case (e1, e2) => isHomo(e1, e2) }

        case _ =>
          false
      }

      res
    }

    isHomo(t1,t2)
  }

  /** Checks whether the match cases cover all possible inputs.
    *
    * Used when reconstructing pattern matching from ITE.
    *
    * e.g. The following:
    * {{{
    * list match {
    *   case Cons(_, Cons(_, a)) =>
    *
    *   case Cons(_, Nil) =>
    *
    *   case Nil =>
    *
    * }
    * }}}
    * is exhaustive.
    *
    * The check is conservative:
    *  - a match with any guarded case is reported as not exhaustive;
    *  - values of types that patterns cannot decompose (strings, integers, characters,
    *    reals, type parameters, ...) count as covered only by a wildcard.
    *
    * @note Unused and unmaintained
    */
  def isMatchExhaustive(m: MatchExpr): Boolean = {

    def hasWildcard(ps: Seq[Pattern]): Boolean = ps exists {
      case _: WildcardPattern => true
      case _ => false
    }

    /*
     * Takes the matrix of the cases per position/types:
     * e.g.
     * e match {   // Where e: (T1, T2, T3)
     *  case (P1, P2, P3) =>
     *  case (P4, P5, P6) =>
     *
     * becomes checked as:
     *   Seq( (T1, Seq(P1, P4)), (T2, Seq(P2, P5)), (T3, Seq(p3, p6)))
     *
     * We then check that P1+P4 covers every T1, etc..
     *
     * TODO: We ignore type parameters here, we might want to make sure it's
     * valid. What's Leon's semantics w.r.t. erasure?
     */
    def areExhaustive(pss: Seq[(TypeTree, Seq[Pattern])]): Boolean = pss.forall { case (tpe, ps) =>

      tpe match {
        case TupleType(tpes) =>
          // A wildcard covers the whole tuple; otherwise the tuple patterns are checked
          // column by column (and there must be at least one of them).
          hasWildcard(ps) || {
            val subs = ps.collect {
              case TuplePattern(_, bs) =>
                bs
            }

            subs.nonEmpty && areExhaustive(tpes zip subs.transpose)
          }

        case _: ClassType =>

          def typesOf(tpe: TypeTree): Set[CaseClassDef] = tpe match {
            case AbstractClassType(ctp, _) =>
              ctp.knownDescendants.collect { case c: CaseClassDef => c }.toSet

            case CaseClassType(ctd, _) =>
              Set(ctd)

            case _ =>
              Set()
          }

          var subChecks = typesOf(tpe).map(_ -> Seq[Seq[Pattern]]()).toMap

          for (p <- ps) p match {
            case w: WildcardPattern =>
              // (a) Wildcard covers everything, no type left to check
              subChecks = Map.empty

            case InstanceOfPattern(_, cct) =>
              // (a: B) covers all Bs
              subChecks --= typesOf(cct)

            case CaseClassPattern(_, cct, subs) =>
              val ccd = cct.classDef
              // We record the patterns per types, if they still need to be checked
              if (subChecks contains ccd) {
                subChecks += (ccd -> (subChecks(ccd) :+ subs))
              }

            case _ =>
              sys.error("Unexpected case: "+p)
          }

          subChecks.forall { case (ccd, subs) =>
            val tpes = ccd.fields.map(_.getType)

            if (subs.isEmpty) {
              false
            } else {
              areExhaustive(tpes zip subs.transpose)
            }
          }

        case BooleanType =>
          // make sure ps contains either
          // - Wildcard or
          // - both true and false
          hasWildcard(ps) || {
            var found = Set[Boolean]()
            ps foreach {
              case LiteralPattern(_, BooleanLiteral(b)) => found += b
              case _ => ()
            }
            (found contains true) && (found contains false)
          }

        case UnitType =>
          // Anything matches ()
          ps.nonEmpty

        case _ =>
          // Strings, integers (bounded or not), characters, reals, type parameters, ...:
          // literals cannot enumerate all values, so only a wildcard covers them.
          hasWildcard(ps)
      }
    }

    // A guard may reject any input, so guarded cases are never relied upon for coverage.
    val hasGuard = m.cases exists { case MatchCase(_, guard, _) => guard.isDefined }

    !hasGuard && areExhaustive(Seq((m.scrutinee.getType, m.cases.map(_.pattern))))
  }

  /** Flattens a function that contains a LetDef with a direct call to it
    *
    * Used for merging synthesis results.
    *
    * {{{
    * def foo(a, b) {
    *   def bar(c, d) {
    *      if (..) { bar(c, d) } else { .. }
    *   }
    *   bar(b, a)
    * }
    * }}}
    * becomes
    * {{{
    * def foo(a, b) {
    *   if (..) { foo(b, a) } else { .. }
    * }
    * }}}
    *
    * The rewrite only applies when the body is exactly one local function applied to the
    * outer parameters, each passed once as a plain variable (in any order). Otherwise
    * `fdOuter` is returned unchanged. The contracts of both functions are combined and
    * simplified with [[Simplifiers.bestEffort]].
    *
    * @param fdOuter the function to flatten
    * @param ctx     context used by the simplifier
    * @param p       program used by the simplifier
    * @return a duplicate of `fdOuter` with the inner function inlined, or `fdOuter` itself
    */
  def flattenFunctions(fdOuter: FunDef, ctx: LeonContext, p: Program): FunDef = {
    fdOuter.body match {
      case Some(LetDef(fdsInner, FunctionInvocation(tfdInner2, args))) if fdsInner.size == 1 && fdsInner.head == tfdInner2.fd =>
        val fdInner = fdsInner.head
        val argsDef  = fdOuter.paramIds
        val argsCall = args.collect { case Variable(id) => id }

        // Every argument must be a plain outer parameter, each outer parameter passed exactly
        // once. Otherwise recursive calls cannot be re-mapped onto `fdOuter`'s parameters
        // (e.g. `bar(a, b, a)` would yield a call to `foo` with the wrong arity).
        val callIsPermutationOfParams =
          argsCall.size == args.size &&
          argsCall.size == argsDef.size &&
          argsDef.toSet == argsCall.toSet

        if (callIsPermutationOfParams) {
          val defMap = argsDef.zipWithIndex.toMap
          // For each argument position of the inner call, the index of the outer parameter passed there.
          val rewriteMap = argsCall.map(defMap)

          val innerIdsToOuterIds = (fdInner.paramIds zip argsCall).toMap

          // Redirects recursive calls of the inner function to `fdOuter` (reordering their
          // arguments) and renames inner parameters to the corresponding outer ones.
          def pre(e: Expr): Expr = e match {
            case FunctionInvocation(tfd, args) if tfd.fd == fdInner =>
              val newArgs = (args zip rewriteMap).sortBy(_._2)
              FunctionInvocation(fdOuter.typed(tfd.tps), newArgs.map(_._1))
            case Variable(id) =>
              Variable(innerIdsToOuterIds.getOrElse(id, id))
            case _ =>
              e
          }

          def mergePre(outer: Option[Expr], inner: Option[Expr]): Option[Expr] = (outer, inner) match {
            case (None, Some(ie)) =>
              Some(simplePreTransform(pre)(ie))
            case (Some(oe), None) =>
              Some(oe)
            case (None, None) =>
              None
            case (Some(oe), Some(ie)) =>
              Some(and(oe, simplePreTransform(pre)(ie)))
          }

          def mergePost(outer: Option[Expr], inner: Option[Expr]): Option[Expr] = (outer, inner) match {
            case (None, Some(ie)) =>
              Some(simplePreTransform(pre)(ie))
            case (Some(oe), None) =>
              Some(oe)
            case (None, None) =>
              None
            case (Some(oe), Some(ie)) =>
              // Both postconditions must hold for the same result value.
              val res = FreshIdentifier("res", fdOuter.returnType, true)
              Some(Lambda(Seq(ValDef(res)), and(
                application(oe, Seq(Variable(res))),
                application(simplePreTransform(pre)(ie), Seq(Variable(res)))
              )))
          }

          val newFd = fdOuter.duplicate()

          val simp = Simplifiers.bestEffort(ctx, p) _

          newFd.body          = fdInner.body.map(b => simplePreTransform(pre)(b))
          newFd.precondition  = mergePre(fdOuter.precondition, fdInner.precondition).map(simp)
          newFd.postcondition = mergePost(fdOuter.postcondition, fdInner.postcondition).map(simp)

          newFd
        } else {
          fdOuter
        }
      case _ =>
        fdOuter
    }
  }

  /** Expands `expr` into a linear combination of its free variables, then simplifies it.
    *
    * The coefficients come from [[TreeNormalizations.linearArithmeticForm]]. The result is
    * built as `0 + c0*1 + c1*x1 + ... + cn*xn`, skipping zero coefficients, and passed
    * through [[simplifyArithmetic]]. If the normalization fails (presumably for input
    * outside its supported fragment — TODO confirm), `expr` itself is simplified instead.
    * Fatal errors are no longer swallowed.
    */
  def expandAndSimplifyArithmetic(expr: Expr): Expr = {
    import scala.util.control.NonFatal

    val expr0 = try {
      val freeVars: Array[Identifier] = variablesOf(expr).toArray
      val coefs: Array[Expr] = TreeNormalizations.linearArithmeticForm(expr, freeVars)
      // The first coefficient is the constant term; the others go with the free variables, in order.
      val terms: List[Expr] = InfiniteIntegerLiteral(1) :: freeVars.toList.map(Variable)
      coefs.toList.zip(terms).foldLeft[Expr](InfiniteIntegerLiteral(0)) {
        case (acc, (coef, term)) =>
          if (coef == InfiniteIntegerLiteral(0)) acc else Plus(acc, Times(coef, term))
      }
    } catch {
      case NonFatal(_) =>
        expr
    }
    simplifyArithmetic(expr0)
  }

  /* =================
   * Body manipulation
   * =================
   */

  /** Returns whether a particular [[Expressions.Expr]] contains specification
    * constructs, namely [[Expressions.Require]] and [[Expressions.Ensuring]].
    */
  /** Tells whether a specification node occurs anywhere inside `e`.
    *
    * A specification node is either a [[Expressions.Require]] or an [[Expressions.Ensuring]].
    */
  def hasSpec(e: Expr): Boolean = {
    def isSpecNode(node: Expr): Boolean = node match {
      case _: Require | _: Ensuring => true
      case _                        => false
    }
    exists(isSpecNode)(e)
  }

  /** Merges the given [[Path]] into the provided [[Expressions.Expr]].
    *
    * This method expects to run on a [[Definitions.FunDef.fullBody]] and merges into
    * existing pre- and postconditions.
    *
    * @param expr The current body
    * @param path The path that should be wrapped around the given body
    * @see [[Expressions.Ensuring]]
    * @see [[Expressions.Require]]
    */
  /** Wraps a function body with the given [[Path]], merging into its existing specs.
    *
    * Meant to be applied to a [[Definitions.FunDef.fullBody]]. Leading let-bindings are
    * absorbed into the path. An existing precondition and/or postcondition is handed to
    * `Path.specs` together with the bare body.
    *
    * @param expr The current (full) body
    * @param path The path that should be wrapped around the given body
    * @see [[Expressions.Ensuring]]
    * @see [[Expressions.Require]]
    */
  def withPath(expr: Expr, path: Path): Expr = expr match {
    case Let(binder, value, rest) =>
      withPath(rest, path.withBinding(binder -> value))
    case Require(pre, body) =>
      path.specs(body, pre)
    case Ensuring(inner, post) =>
      inner match {
        case Require(pre, body) => path.specs(body, pre, post)
        case body               => path.specs(body, post = post)
      }
    case body =>
      path.specs(body)
  }

  /** Replaces the precondition of an existing [[Expressions.Expr]] with a new one.
    *
    * If no precondition is provided, removes any existing precondition.
    * Else, wraps the expression with a [[Expressions.Require]] clause referring to the new precondition.
    *
    * @param expr The current expression
    * @param pred An optional precondition. Setting it to None removes any precondition.
    * @see [[Expressions.Ensuring]]
    * @see [[Expressions.Require]]
    */
  /** Replaces the precondition of an [[Expressions.Expr]].
    *
    * With `Some(pre)`, the body is guarded by a [[Expressions.Require]] on `pre`, placed
    * inside an existing [[Expressions.Ensuring]] if there is one. With `None`, an existing
    * `Require` is dropped. Leading lets whose body contains a spec are preserved.
    *
    * @param expr The current expression
    * @param pred An optional precondition. Setting it to None removes any precondition.
    * @see [[Expressions.Ensuring]]
    * @see [[Expressions.Require]]
    */
  def withPrecondition(expr: Expr, pred: Option[Expr]): Expr = expr match {
    case Let(binder, value, rest) if hasSpec(rest) =>
      Let(binder, value, withPrecondition(rest, pred))
    case Require(_, body) =>
      pred.fold(body)(req(_, body))
    case Ensuring(Require(_, body), post) =>
      Ensuring(pred.fold(body)(req(_, body)), post)
    case Ensuring(body, post) =>
      // Without a new precondition the Ensuring node is left exactly as it was.
      pred.fold(expr)(newPre => Ensuring(req(newPre, body), post))
    case other =>
      pred.fold(other)(req(_, other))
  }

  /** Replaces the postcondition of an existing [[Expressions.Expr]] with a new one.
    *
    * If no postcondition is provided, removes any existing postcondition.
    * Else, wraps the expression with a [[Expressions.Ensuring]] clause referring to the new postcondition.
    *
    * @param expr The current expression
    * @param oie An optional postcondition. Setting it to None removes any postcondition.
    * @see [[Expressions.Ensuring]]
    * @see [[Expressions.Require]]
    */
  /** Replaces the postcondition of an [[Expressions.Expr]].
    *
    * With `Some(post)`, the body (stripped of any previous [[Expressions.Ensuring]]) is
    * wrapped into a new `Ensuring` on `post`. With `None`, an existing `Ensuring` is removed.
    * Leading lets whose body contains a spec are preserved.
    *
    * @param expr The current expression
    * @param oie An optional postcondition. Setting it to None removes any postcondition.
    * @see [[Expressions.Ensuring]]
    * @see [[Expressions.Require]]
    */
  def withPostcondition(expr: Expr, oie: Option[Expr]): Expr = expr match {
    case Let(binder, value, rest) if hasSpec(rest) =>
      Let(binder, value, withPostcondition(rest, oie))
    case Ensuring(body, _) =>
      oie.fold(body)(ensur(body, _))
    case other =>
      oie.fold(other)(ensur(other, _))
  }

  /** Adds a body to a specification
    *
    * @param expr The specification expression [[Expressions.Ensuring]] or [[Expressions.Require]]. If none of these, the argument is discarded.
    * @param body An option of [[Expressions.Expr]] possibly containing an expression body.
    * @return The post/pre condition with the body. If no body is provided, returns [[Expressions.NoTree]]
    * @see [[Expressions.Ensuring]]
    * @see [[Expressions.Require]]
    */
  /** Replaces the body of an expression while keeping its specification.
    *
    * The [[Expressions.Require]] / [[Expressions.Ensuring]] wrappers of `expr`, and any
    * leading lets around them, are kept. Only the innermost body is swapped. When `body`
    * is None, a [[Expressions.NoTree]] of `expr`'s type is used as the body. When `expr`
    * has no specification, the result is just the new body (or that NoTree).
    *
    * @param expr The expression whose specification should be kept
    * @param body An option of [[Expressions.Expr]] possibly containing an expression body.
    * @see [[Expressions.Ensuring]]
    * @see [[Expressions.Require]]
    */
  def withBody(expr: Expr, body: Option[Expr]): Expr = {
    // Evaluated only where a replacement is actually needed, as in the match arms below.
    def newBody: Expr = body.getOrElse(NoTree(expr.getType))

    expr match {
      case Let(binder, value, rest) if hasSpec(rest) =>
        Let(binder, value, withBody(rest, body))
      case Require(pre, _) =>
        Require(pre, newBody)
      case Ensuring(inner, post) =>
        inner match {
          case Require(pre, _) => Ensuring(Require(pre, newBody), post)
          case _               => Ensuring(newBody, post)
        }
      case _ =>
        newBody
    }
  }

  /** Extracts the body without its specification
    *
    * [[Expressions.Expr]] trees contain its specifications as part of certain nodes.
    * This function helps extracting only the body part of an expression
    *
    * @return An option type with the resulting expression if not [[Expressions.NoTree]]
    * @see [[Expressions.Ensuring]]
    * @see [[Expressions.Require]]
    */
  /** Extracts the implementation part of an expression, dropping its specification.
    *
    * Leading lets are kept around the extracted body. [[Expressions.Require]] and
    * [[Expressions.Ensuring]] wrappers are peeled off.
    *
    * @return The body, or None when it is missing (a [[Expressions.NoTree]])
    * @see [[Expressions.Ensuring]]
    * @see [[Expressions.Require]]
    */
  def withoutSpec(expr: Expr): Option[Expr] = {
    def present(body: Expr): Option[Expr] = Option(body).filter {
      case _: NoTree => false
      case _         => true
    }

    expr match {
      case Let(binder, value, rest) => withoutSpec(rest).map(Let(binder, value, _))
      case Require(_, body)         => present(body)
      case Ensuring(inner, _) =>
        inner match {
          case Require(_, body) => present(body)
          case body             => present(body)
        }
      case body => present(body)
    }
  }

  /** Returns the precondition of an expression wrapped in Option */
  /** Returns the precondition of an expression, if any.
    *
    * Leading lets are kept around the returned precondition so that it stays closed.
    */
  def preconditionOf(expr: Expr): Option[Expr] = expr match {
    case Let(binder, value, rest) =>
      preconditionOf(rest).map(Let(binder, value, _))
    case Require(pre, _) =>
      Some(pre)
    case Ensuring(inner, _) =>
      inner match {
        case Require(pre, _) => Some(pre)
        case _               => None
      }
    case _ =>
      None
  }

  /** Returns the postcondition of an expression wrapped in Option */
  /** Returns the postcondition of an expression, if any.
    *
    * Leading lets are kept around the returned postcondition so that it stays closed.
    */
  def postconditionOf(expr: Expr): Option[Expr] = expr match {
    case Let(binder, value, rest) =>
      postconditionOf(rest).map(Let(binder, value, _))
    case Ensuring(_, post) =>
      Some(post)
    case _ =>
      None
  }

  /** Returns a tuple of precondition, the raw body and the postcondition of an expression */
  /** Splits an expression into (precondition, bare body, postcondition), each optional. */
  def breakDownSpecs(e : Expr): (Option[Expr], Option[Expr], Option[Expr]) = {
    val pre  = preconditionOf(e)
    val body = withoutSpec(e)
    val post = postconditionOf(e)
    (pre, body, post)
  }

  /** Pre-order traversal that also hands each node its parent.
    *
    * `f` is called on a node before any of its children. The root receives `initParent`;
    * every other node receives `Some(parent)`.
    */
  def preTraversalWithParent(f: (Expr, Option[Tree]) => Unit, initParent: Option[Tree] = None)(e: Expr): Unit = {
    f(e, initParent)

    val Deconstructor(children, _) = e
    for (child <- children) {
      preTraversalWithParent(f, Some(e))(child)
    }
  }

  /** Matches first-order function calls, including curried chains such as `f(a)(b)`.
    *
    * Yields the invoked function together with all arguments, flattened. Invocations or
    * applications whose result is itself a function are not matched.
    */
  object InvocationExtractor {
    /** Walks down the callers of nested [[Expressions.Application]]s until it reaches a
      * [[Expressions.FunctionInvocation]], collecting arguments innermost-first.
      */
    private def flatInvocation(expr: Expr): Option[(TypedFunDef, Seq[Expr])] = expr match {
      case FunctionInvocation(tfd, args) =>
        Some((tfd, args))
      case Application(caller, args) =>
        flatInvocation(caller).map { case (tfd, innerArgs) => (tfd, innerArgs ++ args) }
      case _ =>
        None
    }

    def unapply(expr: Expr): Option[(TypedFunDef, Seq[Expr])] = expr match {
      case IsTyped(_: FunctionInvocation | _: Application, _: FunctionType) => None
      case FunctionInvocation(tfd, args)                                     => Some(tfd -> args)
      case app: Application                                                  => flatInvocation(app)
      case _                                                                 => None
    }
  }

  /** Collects every first-order call in `expr`, as matched by [[InvocationExtractor]]. */
  def firstOrderCallsOf(expr: Expr): Set[(TypedFunDef, Seq[Expr])] = {
    def callAt(node: Expr): Set[(TypedFunDef, Seq[Expr])] = node match {
      case InvocationExtractor(tfd, args) => Set((tfd, args))
      case _                              => Set.empty
    }
    collect[(TypedFunDef, Seq[Expr])](callAt)(expr)
  }

  /** Matches first-order applications of function values, including curried chains.
    *
    * Yields the applied callee and all arguments, flattened. Chains rooted at a
    * [[Expressions.FunctionInvocation]], and applications whose result is itself a
    * function, are not matched.
    */
  object ApplicationExtractor {
    /** Flattens `c(a)(b)...` into `(c, a ++ b ++ ...)`. Returns None when the innermost
      * caller is a [[Expressions.FunctionInvocation]].
      */
    private def flatApplication(expr: Expr): Option[(Expr, Seq[Expr])] = expr match {
      case Application(_: FunctionInvocation, _) =>
        None
      case Application(inner: Application, args) =>
        flatApplication(inner).map { case (callee, innerArgs) => (callee, innerArgs ++ args) }
      case Application(callee, args) =>
        Some((callee, args))
      case _ =>
        None
    }

    def unapply(expr: Expr): Option[(Expr, Seq[Expr])] = expr match {
      case IsTyped(_: Application, _: FunctionType) => None
      case app: Application                         => flatApplication(app)
      case _                                        => None
    }
  }

  /** Collects every first-order application in `expr`, as matched by [[ApplicationExtractor]]. */
  def firstOrderAppsOf(expr: Expr): Set[(Expr, Seq[Expr])] = {
    def appAt(node: Expr): Set[(Expr, Seq[Expr])] = node match {
      case ApplicationExtractor(callee, args) => Set((callee, args))
      case _                                  => Set.empty
    }
    collect[(Expr, Seq[Expr])](appAt)(expr)
  }

  /** Rewrites `expr` so that function-typed subexpressions become explicit lambdas.
    *
    * Pattern matches are first compiled away with [[matchToIfThenElse]]. Then every
    * function-typed expression is eta-expanded into a [[Expressions.Lambda]] with fresh
    * parameters, unless it is already a Lambda or a [[Expressions.Variable]]. This applies
    * to the root, to arguments of applications/invocations and to lambda bodies. The
    * caller position of an [[Expressions.Application]] is not expanded.
    *
    * @param expr The expression to rewrite
    * @return The rewritten expression
    */
  def simplifyHOFunctions(expr: Expr) : Expr = {

    def liftToLambdas(expr: Expr) = {
      // Applies `expr` to `args`. The application is pushed into the branches of an
      // IfExpr and into the bodies of Let/LetTuple instead of wrapping the whole node.
      def apply(expr: Expr, args: Seq[Expr]): Expr = expr match {
        case IfExpr(cond, thenn, elze) =>
          IfExpr(cond, apply(thenn, args), apply(elze, args))
        case Let(i, e, b) =>
          Let(i, e, apply(b, args))
        case LetTuple(is, es, b) =>
          letTuple(is, es, apply(b, args))
        //case l @ Lambda(params, body) =>
        //  l.withParamSubst(args, body)
        case _ => Application(expr, args)
      }

      // Eta-expands a function-typed expression (other than a Lambda or a Variable) into
      // a Lambda over fresh parameters. It recurses on the applied body, so function-typed
      // results are expanded too. Non-function types are returned unchanged.
      def lift(expr: Expr): Expr = expr.getType match {
        case FunctionType(from, to) => expr match {
          case _ : Lambda => expr
          case _ : Variable => expr
          case e =>
            val args = from.map(tpe => ValDef(FreshIdentifier("x", tpe, true)))
            val application = apply(expr, args.map(_.toVariable))
            Lambda(args, lift(application))
        }
        case _ => expr
      }

      // Lifts `expr` only when the current position asks for it (`build == true`).
      def extract(expr: Expr, build: Boolean) = if (build) lift(expr) else expr

      // Rebuilds the tree bottom-up. `build` tells whether the current node should be
      // lifted: arguments and lambda bodies are visited with `true`, the caller of an
      // Application with `false`, and other nodes pass the current flag to their children.
      def rec(expr: Expr, build: Boolean): Expr = expr match {
        case Application(caller, args) =>
          val newArgs = args.map(rec(_, true))
          val newCaller = rec(caller, false)
          extract(Application(newCaller, newArgs), build)
        case FunctionInvocation(fd, args) =>
          val newArgs = args.map(rec(_, true))
          extract(FunctionInvocation(fd, newArgs), build)
        case l @ Lambda(args, body) =>
          val newBody = rec(body, true)
          extract(Lambda(args, newBody), build)
        case Deconstructor(es, recons) => recons(es.map(rec(_, build)))
      }

      rec(lift(expr), true)
    }

    liftToLambdas(
      matchToIfThenElse(
        expr
      )
    )
  }

  /** lift closures introduced by synthesis.
    *
    * Closures already define all
    * the necessary information as arguments, no need to close them.
    */
  /** Lifts closures introduced by synthesis.
    *
    * Closures already take all the necessary information as arguments, so there is no
    * need to close them.
    *
    * Every local function bound by a [[Expressions.LetDef]] is duplicated. Invocations of
    * an original local function that the first `preMap` pass reaches afterwards are
    * redirected to its duplicate. A second pass then removes all LetDefs.
    *
    * NOTE(review): the redirection relies on `preMap` visiting a LetDef before the
    * invocations it scopes. Whether invocations inside the duplicated functions' own bodies
    * are rewritten depends on whether `preMap` descends into `fd.fullBody` — confirm.
    *
    * @param e The expression to process
    * @return The duplicated local function definitions, and `e` with its LetDefs removed
    */
  def liftClosures(e: Expr): (Set[FunDef], Expr) = {
    // Original local FunDef -> its duplicate, filled in as LetDefs are encountered.
    var fds: Map[FunDef, FunDef] = Map()

    val res1 = preMap({
      case LetDef(lfds, b) =>
        val nfds = lfds.map(fd => fd -> fd.duplicate())

        fds ++= nfds

        Some(letDef(nfds.map(_._2), b))

      case FunctionInvocation(tfd, args) =>
        // Redirect calls to an already-duplicated local function, keeping its type arguments.
        if (fds contains tfd.fd) {
          Some(FunctionInvocation(fds(tfd.fd).typed(tfd.tps), args))
        } else {
          None
        }

      case _ =>
        None
    })(e)

    // we now remove LetDefs
    val res2 = preMap({
      case LetDef(fds, b) =>
        Some(b)
      case _ =>
        None
    }, applyRec = true)(res1)

    (fds.values.toSet, res2)
  }

  /** Recognizes a fully built library list, `Cons(a, Cons(b, ... Nil()))`.
    *
    * Matching compares against the program's library `Nil`/`Cons` class definitions.
    *
    * @return The element type and the list elements in order, or None if `e` is not such a literal
    */
  def isListLiteral(e: Expr)(implicit pgm: Program): Option[(TypeTree, List[Expr])] = e match {
    case CaseClass(CaseClassType(cd, Seq(elemTpe)), Nil) =>
      pgm.library.Nil
        .filter(leonNil => cd == leonNil)
        .map(_ => (elemTpe, Nil))
    case CaseClass(CaseClassType(cd, Seq(elemTpe)), Seq(head, tail)) =>
      pgm.library.Cons
        .filter(leonCons => cd == leonCons)
        .flatMap { _ =>
          isListLiteral(tail).map { case (_, rest) => (elemTpe, head :: rest) }
        }
    case _ =>
      None
  }


  /** Collects from within an expression all conditions under which the evaluation of the expression
    * will not fail (e.g. by violating a function precondition or evaluating to an error).
    *
    * Collection of preconditions of function invocations can be disabled
    * (mainly for [[leon.verification.Tactic]]).
    *
    * @param e The expression for which correctness conditions are calculated.
    * @param collectFIs Whether we also want to collect preconditions for function invocations
    * @return A sequence of pairs (expression, condition)
    */
  /** Collects from within an expression all conditions under which its evaluation will not
    * fail, e.g. by violating a function precondition or by evaluating to an error.
    *
    * Covered nodes are match expressions (some case must apply), errors (never safe),
    * assertions and, when `collectFIs` is set, invocations of functions that have a
    * precondition. Each condition is made conditional on the path leading to its node.
    * Collecting preconditions of function invocations can be disabled
    * (mainly for [[leon.verification.Tactic]]).
    *
    * @param e The expression for which correctness conditions are calculated.
    * @param collectFIs Whether we also want to collect preconditions for function invocations
    * @return A sequence of pairs (expression, condition)
    */
  def collectCorrectnessConditions(e: Expr, collectFIs: Boolean = false): Seq[(Expr, Expr)] = {
    val localConds = collectWithPC {
      case m @ MatchExpr(scrut, cases) =>
        (m, orJoin(cases map (matchCaseCondition(scrut, _).toClause)))

      case err @ Error(_, _) =>
        (err, BooleanLiteral(false))

      case a @ Assert(cond, _, _) =>
        (a, cond)

      // Ensuring/Require conditions are currently not collected:
      /*case e @ Ensuring(body, post) =>
        (e, application(post, Seq(body)))

      case r @ Require(pred, e) =>
        (r, pred)*/

      case fi @ FunctionInvocation(tfd, args) if tfd.hasPrecondition && collectFIs =>
        (fi, tfd.withParamSubst(args, tfd.precondition.get))
    }(e)

    for (((node, cond), path) <- localConds) yield (node, path implies cond)
  }


  /** Conjunction of all correctness conditions of `e` (without function preconditions),
    * simplified under `path` with solvers from `sf`.
    */
  def simpleCorrectnessCond(e: Expr, path: List[Expr], sf: SolverFactory[Solver]): Expr = {
    val conditions = collectCorrectnessConditions(e).map { case (_, cond) => cond }
    simplifyPaths(sf, path)(andJoin(conditions))
  }

  /** Turns a function of several arguments into a function of one tuple argument.
    *
    * For a `fun` of type `(A1, ..., An) => R` with `n > 1`, this returns
    * `(res: (A1, ..., An)) => res match { case (x1, ..., xn) => fun(x1, ..., xn) }`.
    * When `fun` is a lambda, its own parameter identifiers are reused as pattern binders.
    * Any other `fun` is returned unchanged.
    */
  def tupleWrapArg(fun: Expr): Expr = fun.getType match {
    case FunctionType(argTypes, _) if argTypes.size > 1 =>
      val binders = fun match {
        case Lambda(params, _) => params.map(_.id)
        case _                 => argTypes.map(tpe => FreshIdentifier("x", tpe.getType, alwaysShowUniqueID = true))
      }
      val tupleParam = FreshIdentifier("res", TupleType(argTypes.map(_.getType)), alwaysShowUniqueID = true)
      val pattern = TuplePattern(None, binders.map(id => WildcardPattern(Some(id))))
      val call = application(fun, binders.map(_.toVariable))
      Lambda(Seq(ValDef(tupleParam)), MatchExpr(tupleParam.toVariable, Seq(SimpleCase(pattern, call))))
    case _ =>
      fun
  }
  /** Explanations of failed checks, appended to by the `<` operator of [[BooleanAdder]].
    *
    * NOTE(review): shared mutable state on a global object, not synchronized and never
    * reset in this part of the file — confirm that callers clear it and don't use it concurrently.
    */
  var msgs: String = ""

  /** Enables `check < "explanation"`: evaluates to `check` unchanged and, when `check`
    * is false, appends the explanation to [[msgs]].
    */
  implicit class BooleanAdder(b: Boolean) {
    def <(msg: String): Boolean = {
      if (!b) {
        msgs = msgs + msg
      }
      b
    }
  }

  /** Returns true if expr is a value of type t */
  /** Returns true if `e` is a value (a fully evaluated expression) of type `t`.
    *
    * Some sub-checks go through the [[BooleanAdder]] `<` operator, which records an
    * explanation in [[msgs]] when the check fails.
    *
    * @param e The expression to test
    * @param t The type `e` is expected to be a value of
    * @return Whether `e` is a value of type `t`
    */
  def isValueOfType(e: Expr, t: TypeTree): Boolean = {
    // Strips one single-field case-class layer from map values before they are checked.
    // NOTE(review): this also unwraps genuine single-field case classes stored as map
    // values — confirm that map values are always wrapped.
    def unWrapSome(s: Expr) = s match {
      case CaseClass(_, Seq(a)) => a
      case _ => s
    }
    (e, t) match {
      case (StringLiteral(_), StringType) => true
      case (IntLiteral(_), Int32Type) => true
      case (InfiniteIntegerLiteral(_), IntegerType) => true
      case (CharLiteral(_), CharType) => true
      case (FractionalLiteral(_, _), RealType) => true
      case (BooleanLiteral(_), BooleanType) => true
      case (UnitLiteral(), UnitType) => true
      case (GenericValue(t, _), tp) => t == tp
      case (Tuple(elems), TupleType(bases)) =>
        // Check the arity explicitly: `zip` would silently ignore extra elements or types.
        elems.size == bases.size &&
        (elems zip bases forall (eb => isValueOfType(eb._1, eb._2)))
      case (FiniteSet(elems, tbase), SetType(base)) =>
        tbase == base &&
        (elems forall isValue)
      case (FiniteMap(elems, tk, tv), MapType(from, to)) =>
        (tk == from) < s"$tk not equal to $from" && (tv == to) < s"$tv not equal to $to" &&
        (elems forall (kv => isValueOfType(kv._1, from) < s"${kv._1} not a value of type ${from}" && isValueOfType(unWrapSome(kv._2), to) < s"${unWrapSome(kv._2)} not a value of type ${to}" ))
      case (NonemptyArray(elems, defaultLength), ArrayType(base)) =>
        // The optional (default element, length) pair must be evaluated too. Otherwise an
        // array whose default is still an unevaluated expression would count as a value.
        (elems.values forall (x => isValueOfType(x, base))) &&
        (defaultLength forall { case (dflt, len) => isValueOfType(dflt, base) && isValueOfType(len, Int32Type) })
      case (EmptyArray(tpe), ArrayType(base)) =>
        tpe == base
      case (CaseClass(ct, args), ct2@AbstractClassType(classDef, tps)) =>
        TypeOps.isSubtypeOf(ct, ct2) < s"$ct not a subtype of $ct2" &&
        ((args zip ct.fieldsTypes) forall (argstyped => isValueOfType(argstyped._1, argstyped._2) < s"${argstyped._1} not a value of type ${argstyped._2}" ))
      case (CaseClass(ct, args), ct2@CaseClassType(classDef, tps)) =>
        (ct == ct2) <  s"$ct not equal to $ct2" &&
        ((args zip ct.fieldsTypes) forall (argstyped => isValueOfType(argstyped._1, argstyped._2)))
      case (FiniteLambda(mapping, default, tpe), exTpe@FunctionType(ins, out)) =>
        tpe == exTpe
      case (Lambda(valdefs, body), FunctionType(ins, out)) =>
        (valdefs zip ins forall (vdin => (vdin._1.getType == vdin._2) < s"${vdin._1.getType} is not equal to ${vdin._2}")) &&
        (body.getType == out) < s"${body.getType} is not equal to ${out}"
      case (FiniteBag(elements, fbtpe), BagType(tpe)) =>
        fbtpe == tpe && elements.forall{ case (key, value) => isValueOfType(key, tpe) && isValueOfType(value, IntegerType) }
      case _ => false
    }
  }
    
  /** Returns true if expr is a value. Stronger than isGround */
  /** Returns true if `expr` is a value of its own type. Stronger than isGround. */
  val isValue: Expr => Boolean = expr => isValueOfType(expr, expr.getType)
  
  /** Returns a nested string explaining why this expression is typed the way it is.*/
  /** Returns a nested, human-readable explanation of why `e` has the type it has.
    *
    * Every node contributes `"<expr> is of type <type>"`, followed by its children's
    * explanations, each indented by two more spaces. For function invocations, the
    * type-parameter instantiation (`T:=Int`) and the declared signature are appended.
    *
    * @param e The expression to explain
    * @return A multi-line explanation
    */
  def explainTyping(e: Expr): String = {
    val lineBreak = "\n".r

    // One child per line, with every line of each child indented by two spaces.
    def indentChildren(children: Seq[String]): String =
      children.map(child => "\n  " + lineBreak.replaceAllIn(child, "\n  ")).mkString

    fold[String] { (expr, childExplanations) =>
      val base = s"$expr is of type ${expr.getType}" + indentChildren(childExplanations)
      expr match {
        case FunctionInvocation(tfd, _) =>
          // Pair each type parameter with the type it is instantiated with (not with the
          // value arguments of the call).
          val instantiation = tfd.fd.tparams.zip(tfd.tps).map { case (tp, tpe) => s"$tp:=$tpe" }.mkString(",")
          val signature = s"${tfd.fd.params.map(_.getType).mkString(",")} => ${tfd.fd.returnType}"
          base + s" because ${tfd.fd.id.name} was instantiated with $instantiation with type $signature"
        case _ =>
          base
      }
    }(e)
  }
}