/* Copyright 2009-2013 EPFL, Lausanne */

package leon
package purescala

import leon.solvers.Solver

import scala.collection.concurrent.TrieMap

object TreeOps {
  import Common._
  import TypeTrees._
  import Definitions._
  import xlang.Trees.LetDef
  import Trees._
  import Extractors._

  /** Builds the logical negation of `expr`.
    *
    * Where a direct dual exists (comparisons, connectives, literals) the
    * negation is pushed inside the expression; otherwise the expression is
    * simply wrapped in `Not`.
    */
  def negate(expr: Expr) : Expr = expr match {
    // Double negation cancels out; literals are flipped directly.
    case Not(inner) => inner
    case BooleanLiteral(value) => BooleanLiteral(!value)

    // Binders and conditionals: negate only the result positions.
    case Let(binder, value, body) => Let(binder, value, negate(body))
    case ite @ IfExpr(cond, thenn, elze) =>
      IfExpr(cond, negate(thenn), negate(elze)).setType(ite.getType)

    // Connectives (De Morgan and friends).
    case And(conjuncts) => Or(conjuncts map negate)
    case Or(disjuncts) => And(disjuncts map negate)
    case Implies(lhs, rhs) => And(lhs, negate(rhs))
    case Iff(lhs, rhs) => Iff(negate(lhs), rhs)

    // Comparisons are replaced by their complement.
    case LessThan(lhs, rhs) => GreaterEquals(lhs, rhs)
    case LessEquals(lhs, rhs) => GreaterThan(lhs, rhs)
    case GreaterThan(lhs, rhs) => LessEquals(lhs, rhs)
    case GreaterEquals(lhs, rhs) => LessThan(lhs, rhs)

    case _ => Not(expr)
  }

  // Replaces sub-expressions of `expr` that are keys of `substs` by the
  // corresponding values, using the bottom-up traversal of searchAndReplaceDFS.
  // NOTE(review): the original comment warned that this "may loop forever if
  // the substitutions are not well-formed" -- confirm whether that still
  // applies to the DFS-based implementation.
  def replace(substs: Map[Expr,Expr], expr: Expr) : Expr = {
    val lookup: Expr => Option[Expr] = e => substs.get(e)
    searchAndReplaceDFS(lookup)(expr)
  }

  // Same as `replace`, but keyed by identifiers: each identifier is wrapped in
  // a Variable before delegating. This cannot be an overload of `replace`
  // because both Map parameters erase to the same type.
  def replaceFromIDs(substs: Map[Identifier,Expr], expr: Expr) : Expr = {
    val byVariable: Map[Expr,Expr] = substs.map { case (id, value) => ((Variable(id): Expr) -> value) }
    replace(byVariable, expr)
  }

  /** Top-down search and replace.
    *
    * At every node `subst` is tried first. If it yields a replacement:
    *  - an error is reported when the replacement is untyped;
    *  - if the replacement equals the node and `recursive` is set, the node is
    *    traversed again with `subst` disabled for that very node (so that its
    *    children are still visited);
    *  - otherwise, when `recursive` is set, the replacement itself is processed
    *    again (`subst` is re-applied to it); when `recursive` is false the
    *    replacement is returned as is.
    * If `subst` yields nothing, the children of the node are processed and the
    * node is rebuilt only if one of them changed.
    *
    * @param subst     the substitution to attempt at each node
    * @param recursive whether replacements are themselves searched again
    * @param expr      the expression to transform
    */
  def searchAndReplace(subst: Expr=>Option[Expr], recursive: Boolean=true)(expr: Expr) : Expr = {
    // `skip` designates a node on which `subst` must not be tried (see above).
    def rec(ex: Expr, skip: Expr = null) : Expr = (if (ex == skip) None else subst(ex)) match {
      case Some(newExpr) => {
        if(newExpr.getType == Untyped) {
          Settings.reporter.error("REPLACING IN EXPRESSION WITH AN UNTYPED TREE ! " + ex + " --to--> " + newExpr)
        }
        if(ex == newExpr)
          if(recursive) rec(ex, ex) else ex
        else
          if(recursive) rec(newExpr) else newExpr
      }
      case None => ex match {
        case l @ Let(i,e,b) => {
          val re = rec(e)
          val rb = rec(b)
          if(re != e || rb != b)
            Let(i, re, rb).setType(l.getType)
          else
            l
        }
        //case l @ LetDef(fd, b) => {
        //  //TODO, not sure, see comment for the next LetDef
        //  fd.body = fd.body.map(rec(_))
        //  fd.precondition = fd.precondition.map(rec(_))
        //  fd.postcondition = fd.postcondition.map(rec(_))
        //  LetDef(fd, rec(b)).setType(l.getType)
        //}

        case lt @ LetTuple(ids, expr, body) => {
          val re = rec(expr)
          val rb = rec(body)
          if (re != expr || rb != body) {
            LetTuple(ids, re, rb).setType(lt.getType)
          } else {
            lt
          }
        }
        case n @ NAryOperator(args, recons) => {
          var change = false
          val rargs = args.map(a => {
            val ra = rec(a)
            if(ra != a) {
              change = true
              ra
            } else {
              a
            }
          })
          if(change)
            recons(rargs).setType(n.getType)
          else
            n
        }
        case b @ BinaryOperator(t1,t2,recons) => {
          val r1 = rec(t1)
          val r2 = rec(t2)
          if(r1 != t1 || r2 != t2)
            recons(r1,r2).setType(b.getType)
          else
            b
        }
        case u @ UnaryOperator(t,recons) => {
          val r = rec(t)
          if(r != t)
            recons(r).setType(u.getType)
          else
            u
        }
        case i @ IfExpr(t1,t2,t3) => {
          val r1 = rec(t1)
          val r2 = rec(t2)
          val r3 = rec(t3)
          if(r1 != t1 || r2 != t2 || r3 != t3)
            // Reuse the results computed above: calling rec again would run
            // subst a second time on every sub-expression, which is wasteful
            // and may give a different result for non-deterministic
            // substitutions (e.g. ones creating fresh identifiers).
            IfExpr(r1,r2,r3).setType(i.getType)
          else
            i
        }
        // Match expressions are always rebuilt, even if nothing changed.
        case m @ MatchExpr(scrut,cses) => MatchExpr(rec(scrut), cses.map(inCase(_))).setType(m.getType).setPosInfo(m)

        case c @ Choose(args, body) =>
          val body2 = rec(body)

          if (body != body2) {
            Choose(args, body2).setType(c.getType)
          } else {
            c
          }

        case t if t.isInstanceOf[Terminal] => t
        case unhandled => scala.sys.error("Non-terminal case should be handled in searchAndReplace: " + unhandled)
      }
    }

    // Applies rec to the guard (if any) and right-hand side of a match case;
    // patterns are left untouched.
    def inCase(cse: MatchCase) : MatchCase = cse match {
      case SimpleCase(pat, rhs) => SimpleCase(pat, rec(rhs))
      case GuardedCase(pat, guard, rhs) => GuardedCase(pat, rec(guard), rec(rhs))
    }

    rec(expr)
  }

  /** Bottom-up search and replace; same as
    * `searchAndReplaceDFSandTrackChanges` but discarding the "changed" flag. */
  def searchAndReplaceDFS(subst: Expr=>Option[Expr])(expr: Expr) : Expr =
    searchAndReplaceDFSandTrackChanges(subst)(expr)._1

  /** Bottom-up (children first) search and replace.
    *
    * Every node has its children processed first and is rebuilt if any of them
    * changed; `subst` is then tried once on the resulting node. The result of
    * `subst` is not traversed again.
    *
    * @return the transformed expression, and whether `subst` returned a
    *         replacement at least once
    */
  def searchAndReplaceDFSandTrackChanges(subst: Expr=>Option[Expr])(expr: Expr) : (Expr,Boolean) = {
    // Becomes true as soon as subst produces a replacement.
    var substituted: Boolean = false

    // Tries subst on an (already rebuilt) node; warns on untyped replacements.
    def substituteAt(node: Expr) : Expr = subst(node) match {
      case Some(replacement) =>
        substituted = true
        if(replacement.getType == Untyped) {
          Settings.reporter.warning("REPLACING [" + node + "] WITH AN UNTYPED EXPRESSION !")
          Settings.reporter.warning("Here's the new expression: " + replacement)
        }
        replacement
      case None => node
    }

    def visit(node: Expr) : Expr = node match {
      case let @ Let(binder, value, body) =>
        val newValue = visit(value)
        val newBody = visit(body)
        val rebuilt: Expr =
          if(newValue == value && newBody == body) let
          else Let(binder, newValue, newBody).setType(let.getType)
        substituteAt(rebuilt)

      case let @ LetTuple(binders, value, body) =>
        val newValue = visit(value)
        val newBody = visit(body)
        val rebuilt: Expr =
          if(newValue == value && newBody == body) let
          else LetTuple(binders, newValue, newBody).setType(let.getType)
        substituteAt(rebuilt)

      //case l @ LetDef(fd,b) => {
      //  //TODO: Not sure: I actually need the replace to occurs even in the pre/post condition, hope this is correct
      //  fd.body = fd.body.map(rec(_))
      //  fd.precondition = fd.precondition.map(rec(_))
      //  fd.postcondition = fd.postcondition.map(rec(_))
      //  val rl = LetDef(fd, rec(b)).setType(l.getType)
      //  applySubst(rl)
      //}

      case nary @ NAryOperator(args, recons) =>
        // Pair every argument with its processed version; an argument that
        // compares equal to its processed version is kept as the original.
        val visited = args.map(a => (a, visit(a)))
        val changed = visited.exists(p => p._2 != p._1)
        val rebuilt: Expr =
          if(changed) recons(visited.map(p => if(p._2 != p._1) p._2 else p._1)).setType(nary.getType)
          else nary
        substituteAt(rebuilt)

      case binary @ BinaryOperator(left, right, recons) =>
        val newLeft = visit(left)
        val newRight = visit(right)
        val rebuilt: Expr =
          if(newLeft == left && newRight == right) binary
          else recons(newLeft, newRight).setType(binary.getType)
        substituteAt(rebuilt)

      case unary @ UnaryOperator(arg, recons) =>
        val newArg = visit(arg)
        val rebuilt: Expr =
          if(newArg == arg) unary
          else recons(newArg).setType(unary.getType)
        substituteAt(rebuilt)

      case ite @ IfExpr(cond, thenn, elze) =>
        val newCond = visit(cond)
        val newThen = visit(thenn)
        val newElse = visit(elze)
        val rebuilt: Expr =
          if(newCond == cond && newThen == thenn && newElse == elze) ite
          else IfExpr(newCond, newThen, newElse).setType(ite.getType)
        substituteAt(rebuilt)

      case m @ MatchExpr(scrut, cses) =>
        val newScrut = visit(scrut)
        val (newCases, caseChanged) = cses.map(visitCase(_)).unzip
        val rebuilt: Expr =
          if(newScrut != scrut || caseChanged.contains(true)) MatchExpr(newScrut, newCases).setType(m.getType).setPosInfo(m)
          else m
        substituteAt(rebuilt)

      case c @ Choose(args, body) =>
        val newBody = visit(body)
        val rebuilt: Expr =
          if(body == newBody) c
          else Choose(args, newBody).setType(c.getType).setPosInfo(c)
        substituteAt(rebuilt)

      case leaf if leaf.isInstanceOf[Terminal] => substituteAt(leaf)
      case unhandled => scala.sys.error("Non-terminal case should be handled in searchAndReplaceDFS: " + unhandled)
    }

    // Processes the guard (if any) and right-hand side of a case; also tells
    // whether anything changed. Patterns are left untouched.
    def visitCase(cse: MatchCase) : (MatchCase,Boolean) = cse match {
      case simple @ SimpleCase(pat, rhs) =>
        val newRhs = visit(rhs)
        if(newRhs == rhs) (simple, false)
        else (SimpleCase(pat, newRhs), true)

      case guarded @ GuardedCase(pat, guard, rhs) =>
        val newGuard = visit(guard)
        val newRhs = visit(rhs)
        if(newGuard == guard && newRhs == rhs) (guarded, false)
        else (GuardedCase(pat, newGuard, newRhs), true)
    }

    val result = visit(expr)
    (result, substituted)
  }

  // Rewrites pattern-matching expressions and let expressions so that they
  // bind fresh identifiers: each binder is replaced by a fresh identifier of
  // the same name and type, and occurrences of the old variable in the scope
  // of the binder (case guard / right-hand side, let body) are renamed.
  def freshenLocals(expr: Expr) : Expr = {
    // Renames the binders of a pattern according to sm. Every binder of the
    // pattern is expected to be a key of sm.
    def rewritePattern(p: Pattern, sm: Map[Identifier,Identifier]) : Pattern = p match {
      case InstanceOfPattern(Some(b), ctd) => InstanceOfPattern(Some(sm(b)), ctd)
      case WildcardPattern(Some(b)) => WildcardPattern(Some(sm(b)))
      case CaseClassPattern(ob, ccd, sps) => CaseClassPattern(ob.map(sm(_)), ccd, sps.map(rewritePattern(_, sm)))
      // Tuple patterns carry binders too (their own and those of their
      // sub-patterns). Without this case they would keep the old names while
      // the case body is rewritten to the fresh ones.
      case TuplePattern(ob, sps) => TuplePattern(ob.map(sm(_)), sps.map(rewritePattern(_, sm)))
      // Patterns without a binder are left as they are.
      case other => other
    }

    def freshenCase(cse: MatchCase) : MatchCase = {
      val allBinders: Set[Identifier] = cse.pattern.binders
      // old binder -> fresh binder (same name, same type)
      val subMap: Map[Identifier,Identifier] = Map(allBinders.map(i => (i, FreshIdentifier(i.name, true).setType(i.getType))).toSeq : _*)
      val subVarMap: Map[Expr,Expr] = subMap.map(kv => (Variable(kv._1) -> Variable(kv._2)))

      cse match {
        case SimpleCase(pattern, rhs) => SimpleCase(rewritePattern(pattern, subMap), replace(subVarMap, rhs))
        case GuardedCase(pattern, guard, rhs) => GuardedCase(rewritePattern(pattern, subMap), replace(subVarMap, guard), replace(subVarMap, rhs))
      }
    }

    def applyToTree(e : Expr) : Option[Expr] = e match {
      case m @ MatchExpr(s, cses) => Some(MatchExpr(s, cses.map(freshenCase(_))).setType(m.getType).setPosInfo(m))
      case l @ Let(i,e,b) => {
        val newID = FreshIdentifier(i.name, true).setType(i.getType)
        Some(Let(newID, e, replace(Map(Variable(i) -> Variable(newID)), b)))
      }
      case _ => None
    }

    searchAndReplaceDFS(applyToTree)(expr)
  }

  // `convert` describes how to compute a value for the leaves (that includes
  // functions with no args).
  // `combine` describes how to combine two values.
  // This variant applies no extra per-node computation: the combined value of
  // the children is used as is.
  def treeCatamorphism[A](convert: Expr=>A, combine: (A,A)=>A, expression: Expr) : A = {
    val keepCombined: (Expr,A)=>A = (node: Expr, combined: A) => combined
    treeCatamorphism(convert, combine, keepCombined, expression)
  }
  // Folds an expression tree into a single value:
  //  - `convert` gives the value of leaves (terminals and n-ary operators
  //    without arguments);
  //  - `combine` merges the values of the children of a node;
  //  - `compute` allows the catamorphism to change the combined value
  //    depending on the tree; it is called on every node, leaves included.
  def treeCatamorphism[A](convert: Expr=>A, combine: (A,A)=>A, compute: (Expr,A)=>A, expression: Expr) : A = {
    def rec(expr: Expr) : A = expr match {
      case l @ Let(_, e, b) => compute(l, combine(rec(e), rec(b)))
      // Handled explicitly, like Let, so that tuple lets are traversed the
      // same way the search-and-replace functions of this file traverse them
      // (bound expression, then body).
      case l @ LetTuple(_, e, b) => compute(l, combine(rec(e), rec(b)))
      //case l @ LetDef(fd, b) => {//TODO, still not sure about the semantic
      //  val exprs: Seq[Expr] = fd.precondition.toSeq ++ fd.body.toSeq ++ fd.postcondition.toSeq ++ Seq(b)
      //  compute(l, exprs.map(rec(_)).reduceLeft(combine))
      //}
      case n @ NAryOperator(args, _) => {
        if(args.size == 0)
          compute(n, convert(n))
        else
          compute(n, args.map(rec(_)).reduceLeft(combine))
      }
      case b @ BinaryOperator(a1,a2,_) => compute(b, combine(rec(a1),rec(a2)))
      case u @ UnaryOperator(a,_) => compute(u, rec(a))
      case i @ IfExpr(a1,a2,a3) => compute(i, combine(combine(rec(a1), rec(a2)), rec(a3)))
      // The scrutinee and every expression of every case are folded together.
      case m @ MatchExpr(scrut, cses) => compute(m, (scrut +: cses.flatMap(_.expressions)).map(rec(_)).reduceLeft(combine))
      case c @ Choose(args, body) => compute(c, rec(body))
      case t: Terminal => compute(t, convert(t))
      case unhandled => scala.sys.error("Non-terminal case should be handled in treeCatamorphism: " + unhandled)
    }

    rec(expression)
  }

  /** Tells whether `expr` contains an if-then-else expression anywhere. */
  def containsIfExpr(expr: Expr): Boolean = {
    def isIte(e: Expr) : Boolean = e match {
      case (_: IfExpr) => true
      case _ => false
    }
    def either(c1 : Boolean, c2 : Boolean) : Boolean = c1 || c2
    // A node is a hit if it is itself an IfExpr, or if one was found below it.
    def atNode(t : Expr, foundBelow : Boolean) : Boolean = if(isIte(t)) true else foundBelow

    treeCatamorphism(isIte, either, atNode, expr)
  }

  /** Collects the identifiers of the variables occurring in `expr`, minus the
    * identifiers bound by the binding constructs met on the way up (let,
    * tuple let, choose, pattern-matching binders).
    *
    * Note that a binder is removed from everything collected below the
    * binding node, which includes the bound expression itself.
    */
  def variablesOf(expr: Expr) : Set[Identifier] = {
    def convert(t: Expr) : Set[Identifier] = t match {
      case Variable(i) => Set(i)
      case _ => Set.empty
    }
    def combine(s1: Set[Identifier], s2: Set[Identifier]) = s1 ++ s2
    def compute(t: Expr, s: Set[Identifier]) = t match {
      case Let(i,_,_) => s -- Set(i)
      // Tuple lets bind several identifiers at once; they must not be
      // reported as variables of the expression either.
      case LetTuple(is,_,_) => s -- is
      case Choose(is,_) => s -- is
      case MatchExpr(_, cses) => s -- (cses.map(_.pattern.binders).foldLeft(Set[Identifier]())((a, b) => a ++ b))
      case _ => s
    }
    treeCatamorphism(convert, combine, compute, expr)
  }

  /** Tells whether `expr` contains a function invocation anywhere. */
  def containsFunctionCalls(expr : Expr) : Boolean = {
    def isCall(e : Expr) : Boolean = e match {
      case (_ : FunctionInvocation) => true
      case _ => false
    }
    def either(c1 : Boolean, c2 : Boolean) : Boolean = c1 || c2
    // A node is a hit if it is itself a call, or if one was found below it.
    def atNode(t : Expr, foundBelow : Boolean) : Boolean = if(isCall(t)) true else foundBelow

    treeCatamorphism(isCall, either, atNode, expr)
  }

  /** Collects the outermost invocations of functions that are not in
    * `barring`: when such an invocation is met, the invocations collected in
    * its arguments are discarded and only the invocation itself is kept.
    */
  def topLevelFunctionCallsOf(expr: Expr, barring : Set[FunDef] = Set.empty) : Set[FunctionInvocation] = {
    // The node itself, if it is an invocation of a non-barred function.
    def relevantCall(e: Expr) : Option[FunctionInvocation] = e match {
      case call @ FunctionInvocation(fd, _) if(!barring(fd)) => Some(call)
      case _ => None
    }

    def atLeaf(t: Expr) : Set[FunctionInvocation] = relevantCall(t) match {
      case Some(call) => Set(call)
      case None => Set.empty
    }
    def union(s1: Set[FunctionInvocation], s2: Set[FunctionInvocation]) : Set[FunctionInvocation] = s1 ++ s2
    def atNode(t: Expr, below: Set[FunctionInvocation]) : Set[FunctionInvocation] = relevantCall(t) match {
      // `below` is dropped on purpose: that is the difference with functionCallsOf
      case Some(call) => Set(call)
      case None => below
    }

    treeCatamorphism(atLeaf, union, atNode, expr)
  }

  /** Collects all invocations in `expr` whose callee satisfies
    * `program.isRecursive`.
    *
    * NOTE(review): the name says "non recursive" but the guard keeps the
    * invocations for which `isRecursive` holds. Either the name or the
    * condition looks inverted -- confirm against the callers before changing.
    */
  def allNonRecursiveFunctionCallsOf(expr: Expr, program: Program) : Set[FunctionInvocation] = {
    def atLeaf(t: Expr) : Set[FunctionInvocation] = t match {
      case call @ FunctionInvocation(fd, _) if program.isRecursive(fd) => Set(call)
      case _ => Set.empty
    }

    def union(s1: Set[FunctionInvocation], s2: Set[FunctionInvocation]) : Set[FunctionInvocation] = s1 ++ s2

    // Keeps what was collected below and adds the node itself when selected.
    def atNode(t: Expr, below: Set[FunctionInvocation]) : Set[FunctionInvocation] = t match {
      case call @ FunctionInvocation(fd, _) if program.isRecursive(fd) => Set(call) ++ below
      case _ => below
    }

    treeCatamorphism(atLeaf, union, atNode, expr)
  }

  /** Collects every function invocation occurring in `expr`, nested ones
    * included. */
  def functionCallsOf(expr: Expr) : Set[FunctionInvocation] = {
    def atLeaf(t: Expr) : Set[FunctionInvocation] = t match {
      case call @ FunctionInvocation(_, _) => Set(call)
      case _ => Set.empty
    }
    def union(s1: Set[FunctionInvocation], s2: Set[FunctionInvocation]) : Set[FunctionInvocation] = s1 ++ s2
    // Keeps what was collected below and adds the node itself if it is a call.
    def atNode(t: Expr, below: Set[FunctionInvocation]) : Set[FunctionInvocation] = t match {
      case call @ FunctionInvocation(_, _) => Set(call) ++ below
      case _ => below
    }

    treeCatamorphism(atLeaf, union, atNode, expr)
  }

  /** Tells whether `matcher` holds for at least one node of `expr`. */
  def contains(expr: Expr, matcher: Expr=>Boolean) : Boolean = {
    def either(b1: Boolean, b2: Boolean) : Boolean = b1 || b2
    // The matcher is only consulted on a node when nothing matched below it.
    def atNode(t: Expr, matchedBelow: Boolean) : Boolean = matchedBelow || matcher(t)

    treeCatamorphism[Boolean](matcher, either, atNode, expr)
  }

  /** Collects every identifier appearing in `expr`: variables as well as the
    * identifiers introduced by binding constructs (let, tuple let, choose,
    * and whatever `allIdentifiers` of the match cases reports).
    */
  def allIdentifiers(expr: Expr) : Set[Identifier] = expr match {
    case l @ Let(binder, e, b) => allIdentifiers(e) ++ allIdentifiers(b) + binder
    // Tuple lets and choose expressions were not covered here although the
    // other traversals of this file deal with them; their binders count as
    // identifiers just like the binder of a Let.
    case lt @ LetTuple(binders, e, b) => allIdentifiers(e) ++ allIdentifiers(b) ++ binders
    case c @ Choose(vars, body) => allIdentifiers(body) ++ vars
    //TODO: Cannot have LetVar nor LetDef here, should not be visible at this point
    //case l @ LetVar(binder, e, b) => allIdentifiers(e) ++ allIdentifiers(b) + binder
    //case l @ LetDef(fd, b) => allIdentifiers(fd.getBody) ++ allIdentifiers(b) + fd.id
    case n @ NAryOperator(args, _) =>
      (args map (TreeOps.allIdentifiers(_))).foldLeft(Set[Identifier]())((a, b) => a ++ b)
    case b @ BinaryOperator(a1,a2,_) => allIdentifiers(a1) ++ allIdentifiers(a2)
    case u @ UnaryOperator(a,_) => allIdentifiers(a)
    case i @ IfExpr(a1,a2,a3) => allIdentifiers(a1) ++ allIdentifiers(a2) ++ allIdentifiers(a3)
    case m @ MatchExpr(scrut, cses) =>
      (cses map (_.allIdentifiers)).foldLeft(Set[Identifier]())((a, b) => a ++ b) ++ allIdentifiers(scrut)
    // Must come before the generic Terminal case.
    case Variable(id) => Set(id)
    case t: Terminal => Set.empty
  }

  /** Collects the De Bruijn indices found at the leaves of `expr`. */
  def allDeBruijnIndices(expr: Expr) : Set[DeBruijnIndex] =  {
    def atLeaf(t: Expr) : Set[DeBruijnIndex] = t match {
      case dbi @ DeBruijnIndex(_) => Set(dbi)
      case _ => Set.empty
    }
    def union(s1: Set[DeBruijnIndex], s2: Set[DeBruijnIndex]) : Set[DeBruijnIndex] = s1 ++ s2

    treeCatamorphism(atLeaf, union, expr)
  }

  /* Simplifies let expressions:
   *  - removes lets when expression never occurs
   *  - simplifies when expressions occurs exactly once
   *  - expands when expression is just a variable.
   * Note that the code is simple but far from optimal (many traversals...)
   */
  def simplifyLets(expr: Expr) : Expr = {
    // Number of leaves of `in` that are the variable `id`.
    def occurrencesOf(id: Identifier, in: Expr) : Int = treeCatamorphism[Int](
      (e: Expr) => e match {
        case Variable(x) if x == id => 1
        case _ => 0
      },
      (x: Int, y: Int) => x + y,
      in)

    // Substitutes `value` for the variable `id` in `body`.
    def inline(id: Identifier, value: Expr, body: Expr) : Expr =
      replace(Map(((Variable(id): Expr) -> value)), body)

    // All rules below only fire when the body contains no choose expression.
    def simplerLet(t: Expr) : Option[Expr] = t match {

      // Bound to a terminal: always inlined.
      case Let(binder, value: Terminal, body) if !containsChoose(body) =>
        Some(inline(binder, value, body))

      // Otherwise: dropped when unused, inlined when used exactly once.
      case Let(binder, value, body) if !containsChoose(body) =>
        occurrencesOf(binder, body) match {
          case 0 => Some(body)
          case 1 => Some(inline(binder, value, body))
          case _ => None
        }

      // Tuple let bound to an explicit tuple: each component is treated like
      // an individual let; the components that cannot be removed are kept.
      case LetTuple(ids, Tuple(exprs), body) if !containsChoose(body) =>

        var newBody = body

        val (remIds, remExprs) = (ids zip exprs).filter {
          case (id, value: Terminal) =>
            // inlined, hence not kept
            newBody = inline(id, value, newBody)
            false
          case (id, value) =>
            // occurrences are counted in the original body
            occurrencesOf(id, body) match {
              case 0 =>
                false
              case 1 =>
                newBody = inline(id, value, newBody)
                false
              case _ =>
                true
            }
        }.unzip

        remIds match {
          case Seq() => Some(newBody)
          case Seq(single) => Some(Let(single, remExprs.head, newBody))
          case _ => Some(LetTuple(remIds, Tuple(remExprs), newBody))
        }

      // Tuple let bound to an arbitrary tuple-typed expression.
      case LetTuple(ids, tExpr, body) if !containsChoose(body) =>
        // Fails with a MatchError if the bound expression is not tuple-typed.
        val TupleType(_) = tExpr.getType
        val arity = ids.size
        // Each identifier of the LetTuple is mapped to a unit vector
        // (0, ..., 1, ..., 0) whose one sits at the index of the identifier.
        // Summing these vectors over the body yields the number of
        // occurrences of every identifier in a single traversal.
        val zeroVec = Seq.fill(arity)(0)
        val unitVecs: Map[Identifier,Seq[Int]] =
          ids.zipWithIndex.map { case (id, idx) => (id -> zeroVec.updated(idx, 1)) }.toMap

        val counts : Seq[Int] = treeCatamorphism[Seq[Int]](
          (e : Expr) => e match {
            case Variable(x) => unitVecs.getOrElse(x, zeroVec)
            case _ => zeroVec
          },
          (v1 : Seq[Int], v2 : Seq[Int]) => (v1 zip v2).map(p => p._1 + p._2),
          body)

        counts.sum match {
          case 0 =>
            Some(body)
          case 1 =>
            // A single use overall: select the component straight from the
            // bound expression (tuple selection is 1-based).
            val substMap : Map[Expr,Expr] = ids.zipWithIndex.map { case (id, idx) =>
              val v: Expr = Variable(id)
              (v, TupleSelect(tExpr, idx + 1).setType(v.getType): Expr)
            }.toMap

            Some(replace(substMap, body))
          case _ =>
            None
        }

      case _ => None
    }
    searchAndReplaceDFS(simplerLet)(expr)
  }

  // Pulls out all let constructs to the top level, and makes sure they're
  // properly ordered.
  // A let definition: the bound identifier and the expression it is bound to.
  private type DefPair  = (Identifier,Expr)
  private type DefPairs = List[DefPair]

  // Lists the (binder, bound expression) pair of every Let in expr. The pair
  // of a Let comes before the pairs collected in its sub-expressions.
  private def allLetDefinitions(expr: Expr) : DefPairs = {
    def atLeaf(e: Expr) : DefPairs = Nil
    def append(s1: DefPairs, s2: DefPairs) : DefPairs = s1 ::: s2
    def atNode(e: Expr, below: DefPairs) : DefPairs = e match {
      case Let(binder, value, _) => (binder, value) :: below
      case _ => below
    }

    treeCatamorphism[DefPairs](atLeaf, append, atNode, expr)
  }
  
  // Replaces every Let by its body, dropping the binding.
  private def killAllLets(expr: Expr) : Expr = {
    def bodyOfLet(e: Expr) : Option[Expr] = e match {
      case Let(_, _, body) => Some(body)
      case _ => None
    }
    searchAndReplaceDFS(bodyOfLet)(expr)
  }

  def liftLets(expr: Expr) : Expr = {
    // All let definitions of expr, with nested lets removed from the bound
    // expressions.
    val definitionPairs: DefPairs = allLetDefinitions(expr).map { case (id, value) => (id, killAllLets(value)) }

    // For each binder, the let-bound identifiers its definition refers to.
    val occursLists : Map[Identifier,Set[Identifier]] =
      Map(definitionPairs.map { case (id, value) => (id -> variablesOf(value).toSet.filter(_.isLetBinder)) } : _*)

    val toPlace = definitionPairs.size
    var ordered : DefPairs = Nil          // most recently placed definition first
    var placed  : Set[Identifier] = Set.empty
    var placedCount = 0
    var passes = 0

    // A definition is placed as soon as everything it depends on is placed.
    // Passes are repeated until all definitions are placed; too many passes
    // mean that no progress is possible.
    while(placedCount < toPlace) {
      if(passes > toPlace + 1) {
        scala.sys.error("Cycle in let definitions or multiple definition for the same identifier in liftLets : " + definitionPairs.mkString("\n"))
      }
      definitionPairs.foreach { case (id, value) =>
        if(!placed(id) && occursLists(id).forall(placed)) {
          placed += id
          ordered = (id, value) :: ordered
          placedCount += 1
        }
      }
      passes += 1
    }

    val withoutLets = killAllLets(expr)

    // The last placed definition is wrapped first and so ends up innermost.
    val lifted = ordered.foldLeft(withoutLets) { case (body, (id, value)) => Let(id, value, body) }
    simplifyLets(lifted)
  }

  /** Checks the let bindings of `tree` using `variablesOf` and
    * `allLetDefinitions`, reporting an error for every offending identifier
    * found:
    *  - identifiers reported by `variablesOf` that are also bound by a let of
    *    the tree ("escaped");
    *  - failing that, the first let-binder identifier reported by
    *    `variablesOf` ("lost").
    *
    * @return true if no problem was found
    */
  def wellOrderedLets(tree : Expr) : Boolean = {
    val definitions: Set[Identifier] = Set(allLetDefinitions(tree).map(_._1) : _*)
    val vars: Set[Identifier] = variablesOf(tree)
    val escaped = vars intersect definitions

    if(escaped.isEmpty) {
      // Only the first lost binder is reported.
      vars.find(_.isLetBinder) match {
        case Some(id) =>
          Settings.reporter.error("Variable with identifier '" + id + "' has lost its let-definition (it disappeared??)")
          false
        case None =>
          true
      }
    } else {
      for(id <- escaped) {
        Settings.reporter.error("Variable with identifier '" + id + "' has escaped its let-definition !")
      }
      false
    }
  }

  /* Fully expands all let expressions: every Let is replaced by its body in
   * which the bound variable is substituted by the (expanded) bound
   * expression. */
  def expandLets(expr: Expr) : Expr = {
    // env maps the let binders in scope to their expanded definitions.
    def expand(ex: Expr, env: Map[Identifier,Expr]) : Expr = ex match {
      case Variable(id) if env.isDefinedAt(id) => expand(env(id), env)
      case Let(binder, value, body) =>
        val expandedValue = expand(value, env)
        expand(body, env + (binder -> expandedValue))
      case ite @ IfExpr(cond, thenn, elze) =>
        IfExpr(expand(cond, env), expand(thenn, env), expand(elze, env)).setType(ite.getType)
      case m @ MatchExpr(scrut, cses) =>
        MatchExpr(expand(scrut, env), cses.map(expandCase(_, env))).setType(m.getType).setPosInfo(m)
      case nary @ NAryOperator(args, recons) =>
        // An argument that compares equal to its expansion is kept as is.
        val expanded = args.map(a => (a, expand(a, env)))
        if(expanded.exists(p => p._2 != p._1))
          recons(expanded.map(p => if(p._2 != p._1) p._2 else p._1)).setType(nary.getType)
        else
          nary
      case binary @ BinaryOperator(left, right, recons) =>
        val newLeft = expand(left, env)
        val newRight = expand(right, env)
        if(newLeft == left && newRight == right) binary
        else recons(newLeft, newRight).setType(binary.getType)
      case unary @ UnaryOperator(arg, recons) =>
        val newArg = expand(arg, env)
        if(newArg == arg) unary
        else recons(newArg).setType(unary.getType)
      case leaf if leaf.isInstanceOf[Terminal] => leaf
      case unhandled => scala.sys.error("Unhandled case in expandLets: " + unhandled)
    }

    // Expands the guard (if any) and right-hand side of a match case.
    def expandCase(cse: MatchCase, env: Map[Identifier,Expr]) : MatchCase = cse match {
      case SimpleCase(pat, rhs) => SimpleCase(pat, expand(rhs, env))
      case GuardedCase(pat, guard, rhs) => GuardedCase(pat, expand(guard, env), expand(rhs, env))
    }

    expand(expr, Map.empty)
  }

  /** Results of `matchToIfThenElse`, keyed by the input expression. */
  val cacheMtITE = new TrieMap[Expr, Expr]()

  /** Rewrites all pattern-matching expressions into if-then-else expressions,
   * with additional error conditions. Does not introduce additional variables.
   * Results are memoized in `cacheMtITE`.
   */
  def matchToIfThenElse(expr: Expr) : Expr = {
    cacheMtITE.getOrElse(expr, {
      // Not cached yet: convert, remember, return.
      val converted = convertMatchToIfThenElse(expr)
      cacheMtITE += expr -> converted
      converted
    })
  }

  /** Builds the boolean condition under which `in` matches `pattern`.
    *
    * @param in             the expression being matched
    * @param pattern        the pattern to test against
    * @param includeBinders when set, an equality `binder == matched value` is
    *                       added for every binder of the pattern
    */
  def conditionForPattern(in: Expr, pattern: Pattern, includeBinders: Boolean = false) : Expr = {
    // Constraint contributed by an optional binder: an equality when binders
    // are requested and present, `true` otherwise.
    def binderCondition(ob: Option[Identifier], to: Expr): Expr = ob match {
      case Some(id) if includeBinders => Equals(Variable(id), to)
      case _ => BooleanLiteral(true)
    }

    def conditionFor(scrutinee: Expr, pat: Pattern): Expr = pat match {
      case WildcardPattern(ob) =>
        binderCondition(ob, scrutinee)

      case InstanceOfPattern(_,_) =>
        scala.sys.error("InstanceOfPattern not yet supported.")

      case CaseClassPattern(ob, ccd, subps) =>
        assert(ccd.fields.size == subps.size)
        // One test per field, on the selection of that field.
        val fieldTests = (ccd.fields.map(_.id).toList zip subps.toList).map {
          case (fieldId, subPattern) => conditionFor(CaseClassSelector(ccd, scrutinee, fieldId), subPattern)
        }
        val together = And(binderCondition(ob, scrutinee) +: fieldTests)
        And(CaseClassInstanceOf(ccd, scrutinee), together)

      case TuplePattern(ob, subps) =>
        val TupleType(tpes) = scrutinee.getType
        assert(tpes.size == subps.size)
        // Tuple selection is 1-based.
        val componentTests = subps.zipWithIndex.map {
          case (subPattern, idx) => conditionFor(TupleSelect(scrutinee, idx + 1).setType(tpes(idx)), subPattern)
        }
        And(binderCondition(ob, scrutinee) +: componentTests)
    }

    conditionFor(in, pattern)
  }

  // Uncached implementation of matchToIfThenElse: every MatchExpr becomes a
  // chain of if-then-else expressions ending in an Error expression.
  private def convertMatchToIfThenElse(expr: Expr) : Expr = {
    // Maps every binder of the pattern to the part of `in` it denotes.
    def mapForPattern(in: Expr, pattern: Pattern) : Map[Identifier,Expr] = pattern match {
      case WildcardPattern(ob) => ob match {
        case Some(id) => Map(id -> in)
        case None => Map.empty
      }
      case InstanceOfPattern(ob, _) => ob match {
        case Some(id) => Map(id -> in)
        case None => Map.empty
      }
      case CaseClassPattern(ob, ccd, subps) =>
        assert(ccd.fields.size == subps.size)
        val fieldMaps = (ccd.fields.map(_.id).toList zip subps.toList).map {
          case (fieldId, subPattern) => mapForPattern(CaseClassSelector(ccd, in, fieldId), subPattern)
        }
        val together = fieldMaps.foldLeft(Map.empty[Identifier,Expr])(_ ++ _)
        ob match {
          case Some(id) => Map(id -> in) ++ together
          case None => together
        }
      case TuplePattern(ob, subps) =>
        val TupleType(tpes) = in.getType
        assert(tpes.size == subps.size)

        // Tuple selection is 1-based.
        val componentMaps = subps.zipWithIndex.map {
          case (subPattern, idx) => mapForPattern(TupleSelect(in, idx + 1).setType(tpes(idx)), subPattern)
        }
        val together = componentMaps.foldLeft(Map.empty[Identifier,Expr])(_ ++ _)
        ob match {
          case Some(id) => together + (id -> in)
          case None => together
        }
    }

    def rewritePM(e: Expr) : Option[Expr] = e match {
      case m @ MatchExpr(scrut, cases) =>
        // For every case: the condition under which it applies, and its
        // right-hand side with binders replaced by what they denote.
        val condsAndRhs = cases.map { cse =>
          val binderMap = mapForPattern(scrut, cse.pattern)
          val patCond = conditionForPattern(scrut, cse.pattern, includeBinders = false)
          val realCond = cse.theGuard match {
            case Some(g) => And(patCond, replaceFromIDs(binderMap, g))
            case None => patCond
          }
          val newRhs = replaceFromIDs(binderMap, cse.rhs)
          (realCond, newRhs)
        }

        val optCondsAndRhs =
          if(SimplePatternMatching.isSimple(m)) {
            // this is a hackish optimization: because we know all cases are covered, we replace the last condition by true (and that drops the check)
            val lastRhs = condsAndRhs.last._2
            condsAndRhs.dropRight(1) ++ Seq((BooleanLiteral(true), lastRhs))
          } else {
            condsAndRhs
          }

        // Built from the last case backwards; a case whose condition is the
        // literal `true` discards everything that follows it.
        val fallThrough: Expr = Error("non-exhaustive match").setType(bestRealType(m.getType)).setPosInfo(m)
        val bigIte = optCondsAndRhs.foldRight[Expr](fallThrough) { (condAndRhs, rest) =>
          val (cond, rhs) = condAndRhs
          if(cond == BooleanLiteral(true)) rhs
          else IfExpr(cond, rhs, rest).setType(m.getType)
        }

        Some(bigIte)

      case _ => None
    }

    searchAndReplaceDFS(rewritePM)(expr)
  }

  /** Results of `mapGetWithChecks`, keyed by the input expression. */
  val cacheMGWC = new TrieMap[Expr, Expr]()

  /** Rewrites all map accesses with additional error conditions.
   * Results are memoized in `cacheMGWC`.
   */
  def mapGetWithChecks(expr: Expr) : Expr = {
    cacheMGWC.getOrElse(expr, {
      // Not cached yet: convert, remember, return.
      val converted = convertMapGet(expr)
      cacheMGWC += expr -> converted
      converted
    })
  }

  // Uncached implementation of mapGetWithChecks: every map access `m(k)`
  // becomes `if (m isDefinedAt k) m(k) else error`.
  private def convertMapGet(expr: Expr) : Expr = {
    def guardAccess(e: Expr) : Option[Expr] = e match {
      case access @ MapGet(map, key) =>
        val tpe = access.getType
        val isDefined = MapIsDefinedAt(map, key)
        val failure = Error("key not found for map access").setType(tpe).setPosInfo(access)
        Some(IfExpr(isDefined, access, failure).setType(tpe))
      case _ => None
    }

    searchAndReplaceDFS(guardAccess)(expr)
  }

  // Computes the maximal number of nested case class selections found in the
  // expression; a let-bound variable counts for the depth of its definition.
  // prec: expression does not contain match expressions
  def measureADTChildrenDepth(expression: Expr) : Int = {
    import scala.math.max

    // depths: depth measured for the definition of each let binder in scope
    def depth(ex: Expr, depths: Map[Identifier,Int]) : Int = ex match {
      case Let(binder, value, body) =>
        val valueDepth = depth(value, depths)
        depth(body, depths + (binder -> valueDepth))
      case Variable(id) => depths.getOrElse(id, 0)
      case CaseClassSelector(_, selected, _) => 1 + depth(selected, depths)
      // Depths are never negative, so folding from 0 gives the maximum (and 0
      // when there are no arguments).
      case NAryOperator(args, _) => args.foldLeft(0)((deepest, arg) => max(deepest, depth(arg, depths)))
      case BinaryOperator(left, right, _) => max(depth(left, depths), depth(right, depths))
      case UnaryOperator(arg, _) => depth(arg, depths)
      case IfExpr(cond, thenn, elze) => max(max(depth(cond, depths), depth(thenn, depths)), depth(elze, depths))
      case _: Terminal => 0
      case _ => scala.sys.error("Not handled in measureChildrenDepth : " + ex)
    }

    depth(expression, Map.empty)
  }

  // Source of randomness used by randomValue.
  private val random = new scala.util.Random()

  /** A random value of the type of the given variable. */
  def randomValue(v: Variable) : Expr = {
    val tpe: TypeTree = v.getType
    randomValue(tpe)
  }

  /** The simplest value of the type of the given variable. */
  def simplestValue(v: Variable) : Expr = {
    val tpe: TypeTree = v.getType
    simplestValue(tpe)
  }

  /** Builds a random value of the given type.
    *
    * Supported types: booleans, 32-bit integers (drawn with `nextInt(42)`),
    * case classes (random fields) and abstract classes (random known child).
    * Any other type results in an exception.
    */
  def randomValue(tpe: TypeTree) : Expr = tpe match {
    case BooleanType => BooleanLiteral(random.nextBoolean())
    case Int32Type => IntLiteral(random.nextInt(42))
    case CaseClassType(cd) =>
      val fieldValues = cd.fields.map(field => randomValue(field.getType))
      CaseClass(cd, fieldValues)
    case AbstractClassType(acd) =>
      val children = acd.knownChildren
      val chosen = children(random.nextInt(children.size))
      randomValue(classDefToClassType(chosen))
    case _ => throw new Exception("I can't choose random value for type " + tpe)
  }

  /** Builds a "simplest" value of the given type: 0, false, empty set, empty
    * map, tuples / case classes / arrays of simplest values. For an abstract
    * class, a known case class child is chosen that has a parent, has no
    * field of the abstract class type or of its own type, and comes first
    * when sorting by number of fields.
    * Any unsupported type results in an exception.
    */
  def simplestValue(tpe: TypeTree) : Expr = tpe match {
    case Int32Type => IntLiteral(0)
    case BooleanType => BooleanLiteral(false)
    case AbstractClassType(acd) => {
      // Children that do not directly refer back to the abstract class or to
      // themselves in their fields.
      def refersBack(ccd: CaseClassDef, field: TypeTree) : Boolean = field match {
        case AbstractClassType(fieldAcd) => acd == fieldAcd
        case CaseClassType(fieldCcd) => ccd == fieldCcd
        case _ => false
      }
      val candidates = acd.knownChildren.filter {
        case ccd @ CaseClassDef(_, Some(_), fields) => !fields.exists(vd => refersBack(ccd, vd.getType))
        case _ => false
      }
      // NOTE(review): this uses <= (and true for non case classes), which is
      // not a strict "less than" as sortWith expects -- kept as is to preserve
      // which child gets picked among ties; confirm before changing.
      def orderByNumberOfFields(fst: ClassTypeDef, snd: ClassTypeDef) : Boolean = (fst, snd) match {
        case (CaseClassDef(_, _, flds1), CaseClassDef(_, _, flds2)) => flds1.size <= flds2.size
        case _ => true
      }
      val simplestChild = candidates.sortWith(orderByNumberOfFields).head
      simplestValue(classDefToClassType(simplestChild))
    }
    case CaseClassType(ccd) =>
      CaseClass(ccd, ccd.fields.map(field => simplestValue(field.getType)))
    case SetType(_) => FiniteSet(Seq()).setType(tpe)
    case MapType(_, _) => FiniteMap(Seq()).setType(tpe)
    case TupleType(componentTypes) => Tuple(componentTypes.map(simplestValue))
    case ArrayType(baseType) => ArrayFill(IntLiteral(0), simplestValue(baseType))
    case _ => throw new Exception("I can't choose simplest value for type " + tpe)
  }

  // Guarantees that all IfExpr will be at the top level: as soon as a
  // non-IfExpr is encountered, no more IfExpr can be found in its
  // sub-expressions.
  // Requires: no match expressions, no lets, and only pure code.
  def hoistIte(expr: Expr): Expr = {
    // Lifts one IfExpr found directly under an operator above that operator,
    // duplicating the operator in both branches.
    def liftOne(node: Expr): Option[Expr] = node match {
      case unary @ UnaryOperator(IfExpr(cond, thenn, elze), rebuild) =>
        val tpe = unary.getType
        Some(IfExpr(cond, rebuild(thenn).setType(tpe), rebuild(elze).setType(tpe)).setType(tpe))

      case binary @ BinaryOperator(IfExpr(cond, thenn, elze), right, rebuild) =>
        val tpe = binary.getType
        Some(IfExpr(cond, rebuild(thenn, right).setType(tpe), rebuild(elze, right).setType(tpe)).setType(tpe))

      case binary @ BinaryOperator(left, IfExpr(cond, thenn, elze), rebuild) =>
        val tpe = binary.getType
        Some(IfExpr(cond, rebuild(left, thenn).setType(tpe), rebuild(left, elze).setType(tpe)).setType(tpe))

      case nary @ NAryOperator(args, rebuild) =>
        // Only the first IfExpr argument is lifted; the others are dealt with
        // by later passes.
        val iteIndex = args.indexWhere{ case IfExpr(_, _, _) => true case _ => false }
        if(iteIndex < 0) {
          None
        } else {
          val tpe = nary.getType
          val (prefix, fromIte) = args.splitAt(iteIndex)
          val IfExpr(cond, thenn, elze) = fromIte.head
          val suffix = fromIte.tail
          Some(IfExpr(cond,
            rebuild(prefix ++ Seq(thenn) ++ suffix).setType(tpe),
            rebuild(prefix ++ Seq(elze) ++ suffix).setType(tpe)
          ).setType(tpe))
        }

      case _ => None
    }

    // Repeat whole-tree passes until a pass leaves the expression unchanged.
    val onePass: Expr => Expr = searchAndReplaceDFS(liftOne)
    var current = expr
    var next = onePass(current)
    while(current != next) {
      current = next
      next = onePass(current)
    }
    current
  }

  /** Generic traversal that rewrites an expression while threading a context
    * of type `C` through it.
    *
    * For each node, `pre` is applied first to the node and the incoming
    * context; the sub-expressions of the resulting node are then processed
    * with the context returned by `pre`; the node is rebuilt from the new
    * sub-expressions, the contexts returned for the children are merged with
    * `combiner`, and finally `post` is applied to the rebuilt node and the
    * merged context.
    *
    * @param pre      rewriting applied before descending into a node
    * @param post     rewriting applied after the node has been rebuilt
    * @param combiner merges the contexts returned for the children of a node
    * @param init     context given to the root
    * @param expr     expression to transform
    * @return the transformed expression paired with the resulting context
    */
  def genericTransform[C](pre:  (Expr, C) => (Expr, C),
                          post: (Expr, C) => (Expr, C),
                          combiner: (Seq[C]) => C)(init: C)(expr: Expr) = {

    def rec(eIn: Expr, cIn: C): (Expr, C) = {

      val (expr, ctx) = pre(eIn, cIn)

      val (newExpr, newC) = expr match {
        // NOTE(review): terminals hand the *incoming* context cIn to post,
        // not the context ctx returned by pre -- confirm this is intended.
        case t: Terminal =>
          (expr, cIn)

        case UnaryOperator(e, builder) =>
          val (e1, c) = rec(e, ctx)
          val newE = builder(e1)

          (newE, combiner(Seq(c)))

        case BinaryOperator(e1, e2, builder) =>
          // Both operands receive the same context ctx.
          val (ne1, c1) = rec(e1, ctx)
          val (ne2, c2) = rec(e2, ctx)
          val newE = builder(ne1, ne2)

          (newE, combiner(Seq(c1, c2)))

        case NAryOperator(es, builder) =>
          val (nes, cs) = es.map{ rec(_, ctx)}.unzip
          val newE = builder(nes)

          (newE, combiner(cs))

        case e =>
          sys.error("Expression "+e+" ["+e.getClass+"] is not extractable")
      }

      post(newExpr, newC)
    }

    rec(expr, init)
  }

  // Context combiner for traversals that carry no information (C = Unit).
  private def noCombiner(subCs: Seq[Unit]): Unit = {}

  /** Context-free variant of `genericTransform`: applies `pre` before
    * descending into each node and `post` once the node has been rebuilt.
    */
  def simpleTransform(pre: Expr => Expr, post: Expr => Expr)(expr: Expr) = {
    // Wrap a plain rewriting so that it carries a dummy Unit context.
    def withUnitContext(f: Expr => Expr): (Expr, Unit) => (Expr, Unit) =
      (tree, _) => (f(tree), ())

    val (transformed, _) =
      genericTransform[Unit](withUnitContext(pre), withUnitContext(post), noCombiner)(())(expr)
    transformed
  }

  /** Context-free traversal applying `pre` to every node before descending
    * into it; nothing is done on the way back up.
    */
  def simplePreTransform(pre: Expr => Expr)(expr: Expr) = {
    val liftedPre = (tree: Expr, ctx: Unit) => (pre(tree), ())
    // The post step returns its arguments untouched.
    val keepAsIs  = (tree: Expr, ctx: Unit) => (tree, ctx)

    val (transformed, _) = genericTransform[Unit](liftedPre, keepAsIs, noCombiner)(())(expr)
    transformed
  }

  /** Context-free traversal applying `post` to every node after its
    * sub-expressions have been processed and the node has been rebuilt.
    *
    * @param post rewriting applied bottom-up
    * @param expr expression to transform
    * @return the transformed expression
    */
  def simplePostTransform(post: Expr => Expr)(expr: Expr) = {
    val newPost = (e: Expr, c: Unit) => (post(e), ())

    // The pre step is the identity. The context type is Unit, so the second
    // component must be the unit value (it used to be `None`, which only
    // type-checked through value discarding).
    genericTransform[Unit]((e, c) => (e, ()), newPost, noCombiner)(())(expr)._1
  }

  /** Distributes a binary disjunction over a binary conjunction, top-down:
    * `l || (r1 && r2)` becomes `(l || r1) && (l || r2)`, and symmetrically
    * when the conjunction is the left operand. Other nodes are left alone.
    */
  def toCNF(e: Expr): Expr = {
    val distribute: Expr => Expr = {
      case Or(Seq(l, And(Seq(r1, r2)))) => And(Or(l, r1), Or(l, r2))
      case Or(Seq(And(Seq(l1, l2)), r)) => And(Or(l1, r), Or(l2, r))
      case other                        => other
    }

    simplePreTransform(distribute)(e)
  }

  /*
   * Transforms complicated Ifs into multiple nested if blocks
   * It will decompose every OR clause, and it will group together the AND
   * clauses that check isInstanceOf.
   *
   *  if (a.isInstanceof[T1] && a.tail.isInstanceof[T2] && a.head == a2 || C) {
   *     T
   *  } else {
   *     E
   *  }
   *
   * Becomes:
   *
   *  if (a.isInstanceof[T1] && a.tail.isInstanceof[T2]) {
   *    if (a.head == a2) {
   *      T
   *    } else {
   *      if(C) {
   *        T
   *      } else {
   *        E
   *      }
   *    }
   *  } else {
   *    if(C) {
   *      T
   *    } else {
   *      E
   *    }
   *  }
   * 
   * This transformation runs immediately before patternMatchReconstruction.
   */
  /** Splits the condition of `IfExpr` nodes into nested `IfExpr`s, top-down.
    *
    * Only conditions in which some disjunct contains a `CaseClassInstanceOf`
    * conjunct are rewritten; every other expression is returned unchanged.
    *
    * @param e the expression to rewrite
    * @return the rewritten expression
    */
  def decomposeIfs(e: Expr): Expr = {
    def pre(e: Expr): Expr = e match {
      case IfExpr(cond, thenn, elze) =>
        val TopLevelOrs(orcases) = cond

        if (orcases.exists{ case TopLevelAnds(ands) => ands.exists(_.isInstanceOf[CaseClassInstanceOf]) } ) {
          if (!orcases.tail.isEmpty) {
            // Several disjuncts: test the first one, and test the remaining
            // ones in the else branch. The new outer IfExpr is processed again
            // immediately so that its single-disjunct condition gets split.
            pre(IfExpr(orcases.head, thenn, IfExpr(Or(orcases.tail), thenn, elze)))
          } else {
            val TopLevelAnds(andcases) = orcases.head

            // Separate instance-of tests from the other conjuncts.
            val (andis, andnotis) = andcases.partition(_.isInstanceOf[CaseClassInstanceOf])

            if (andis.isEmpty || andnotis.isEmpty) {
              // Nothing to separate.
              e
            } else {
              // Instance-of tests go in the outer condition, the rest inside.
              IfExpr(And(andis), IfExpr(And(andnotis), thenn, elze), elze)
            }
          }
        } else {
          e
        }
      case _ =>
        e
    }

    simplePreTransform(pre)(e)
  }

  // This transformation assumes IfExpr of the form generated by decomposeIfs
  /** Rewrites `IfExpr` nodes whose condition is a conjunction made only of
    * `CaseClassInstanceOf` tests into a two-case `MatchExpr`: one case with
    * the reconstructed pattern and the rewritten then-branch, and a wildcard
    * case with the else-branch. Other expressions are returned unchanged.
    *
    * @param e the expression to rewrite
    * @return the rewritten expression
    */
  def patternMatchReconstruction(e: Expr): Expr = {
    def pre(e: Expr): Expr = e match {
      case IfExpr(cond, thenn, elze) =>
        val TopLevelAnds(cases) = cond

        if (cases.forall(_.isInstanceOf[CaseClassInstanceOf])) {
          // matchingOn might initially be: a : T1, a.tail : T2, b: T2
          // Number of nested case class selectors at the head of e.
          def selectorDepth(e: Expr): Int = e match {
            case cd: CaseClassSelector =>
              1+selectorDepth(cd.caseClass)
            case _ =>
              0
          }

          // Expressions that become scrutinees of the match.
          var scrutSet = Set[Expr]()
          // Case class tested for each expression appearing in an instance-of test.
          var conditions = Map[Expr, CaseClassDef]()

          // Shallowest tested expressions first, so that a test on "a" is
          // seen before a test on "a.tail".
          var matchingOn = cases.collect { case cc : CaseClassInstanceOf => cc } sortBy(cc => selectorDepth(cc.expr))
          for (CaseClassInstanceOf(cd, expr) <- matchingOn) {
            conditions += expr -> cd

            expr match {
              case cd: CaseClassSelector =>
                if (!scrutSet.contains(cd.caseClass)) {
                  // we found a test looking like "a.foo.isInstanceof[..]"
                  // without a check on "a".
                  scrutSet += cd
                }
              case e =>
                scrutSet += e
            }
          }

          // Substitutions to apply to the then-branch; filled in as a side
          // effect of computePatternFor.
          var substMap = Map[Expr, Expr]()


          // Builds the pattern matching `prefix` against case class `cd`,
          // recursing into the fields that are themselves tested in `conditions`.
          def computePatternFor(cd: CaseClassDef, prefix: Expr): Pattern = {

            // Name the binder after the selected field or variable when possible.
            val name = prefix match {
              case CaseClassSelector(_, _, id) => id.name
              case Variable(id) => id.name
              case _ => "tmp"
            }

            val binder = FreshIdentifier(name, true).setType(prefix.getType) // fresh binder standing for prefix in the pattern

            // prefix becomes binder
            substMap += prefix -> Variable(binder)
            substMap += CaseClassInstanceOf(cd, prefix) -> BooleanLiteral(true)

            val subconds = for (id <- cd.fieldsIds) yield {
              val fieldSel = CaseClassSelector(cd, prefix, id)
              if (conditions contains fieldSel) {
                computePatternFor(conditions(fieldSel), fieldSel)
              } else {
                // Untested field: bind it with a wildcard.
                val b = FreshIdentifier(id.name, true).setType(id.getType)
                substMap += fieldSel -> Variable(b)
                WildcardPattern(Some(b))
              }
            }

            CaseClassPattern(Some(binder), cd, subconds)
          }

          val (scrutinees, patterns) = scrutSet.toSeq.map(s => (s, computePatternFor(conditions(s), s))).unzip

          // Several scrutinees are matched together as a tuple.
          val (scrutinee, pattern) = if (scrutinees.size > 1) {
            (Tuple(scrutinees), TuplePattern(None, patterns))
          } else {
            (scrutinees.head, patterns.head)
          }

          // We use searchAndReplace to replace the biggest match first
          // (topdown).
          // So replacing using Map(a => b, CC(a) => d) will replace
          // "CC(a)" by "d" and not by "CC(b)"
          val newThen = searchAndReplace(substMap.get)(thenn)

          // Remove unused binders
          val vars = variablesOf(newThen)

          // Keeps a binder only if it occurs in the rewritten then-branch.
          def simplerBinder(oid: Option[Identifier]) = oid.filter(vars(_))

          def simplifyPattern(p: Pattern): Pattern = p match {
            case CaseClassPattern(ob, cd, subpatterns) =>
              CaseClassPattern(simplerBinder(ob), cd, subpatterns map simplifyPattern)
            case WildcardPattern(ob) =>
              WildcardPattern(simplerBinder(ob))
            case TuplePattern(ob, patterns) =>
              TuplePattern(simplerBinder(ob), patterns map simplifyPattern)
            case _ =>
              p
          }

          MatchExpr(scrutinee, Seq(SimpleCase(simplifyPattern(pattern), newThen), SimpleCase(WildcardPattern(None), elze))).setType(e.getType)
        } else {
          e
        }
      case _ =>
        e
    }

    simplePreTransform(pre)(e)
  }

  /** Uses the solver to simplify preconditions and conditionals, top-down.
    *
    * - For a `LetDef` whose function has a precondition: the precondition is
    *   dropped when the solver answers `Some(true)` for it, and replaced by
    *   `false` when the solver answers `Some(false)` for it and `Some(true)`
    *   for its negation. The function definition is mutated in place.
    * - For an `IfExpr`: it is replaced by its then-branch (resp. else-branch)
    *   under the same answers for the condition. Solver exceptions are
    *   swallowed here and the `IfExpr` is kept.
    */
  def simplifyTautologies(solver : Solver)(expr : Expr) : Expr = {
    def simplifyNode(tree : Expr) : Expr = tree match {
      case LetDef(fd, _) if fd.hasPrecondition =>
        val precond = fd.precondition.get

        solver.solve(precond) match {
          case Some(true) =>
            fd.precondition = None
          case Some(false) =>
            if (solver.solve(Not(precond)) == Some(true)) {
              fd.precondition = Some(BooleanLiteral(false))
            }
          case None =>
        }

        tree

      case IfExpr(cond, thenn, elze) =>
        try {
          solver.solve(cond) match {
            case Some(true) => thenn
            case Some(false) if solver.solve(Not(cond)) == Some(true) => elze
            case _ => tree
          }
        } catch {
          // let's give up when the solver crashes
          case _ : Exception => tree
        }

      case _ => tree
    }

    simplePreTransform(simplifyNode)(expr)
  }

  /** Returns a function simplifying expressions with a single
    * `SimplifierWithPaths` instance backed by the given solver.
    */
  def simplifyPaths(solver : Solver): Expr => Expr = {
    val simplifier = new SimplifierWithPaths(solver)
    (tree: Expr) => simplifier.transform(tree)
  }

  /** An expression-to-expression rewriting. */
  trait Transformer {
    /** Rewrites `e` and returns the resulting expression. */
    def transform(e: Expr): Expr
  }

  /** A traversal computing a result of type `T` from an expression. */
  trait Traverser[T] {
    /** Traverses `e` and returns the computed result. */
    def traverse(e: Expr): T
  }

  /** Transformer that rebuilds an expression while maintaining a path
    * context of abstract type `C`. Conditions that hold on the way to a
    * sub-expression are added to the context through `register`.
    */
  abstract class TransformerWithPC extends Transformer {
    /** Type of the path context. */
    type C

    /** Context used at the root of the expression. */
    protected val initC: C

    /** Returns the context `path` extended with the condition `cond`. */
    protected def register(cond: Expr, path: C): C

    /** Rebuilds `e`, processing each sub-expression under the context that
      * applies to it.
      */
    protected def rec(e: Expr, path: C): Expr = e match {
      case Let(i, e, b) =>
        // The body is processed knowing that i equals the rewritten value.
        val se = rec(e, path)
        val sb = rec(b, register(Equals(Variable(i), se), path))
        Let(i, se, sb)

      case MatchExpr(scrut, cases) =>
        val rs = rec(scrut, path)

        // Accumulates the negation of the conditions of the previous cases.
        var soFar = path

        MatchExpr(rs, cases.map { c =>
          val patternExpr = conditionForPattern(rs, c.pattern, includeBinders = true)

          val subPath = register(patternExpr, soFar)
          soFar = register(Not(patternExpr), soFar)

          c match {
            case SimpleCase(p, rhs) =>
              SimpleCase(p, rec(rhs, subPath))
            // The guard itself is kept as is and is not registered.
            case GuardedCase(p, g, rhs) =>
              GuardedCase(p, g, rec(rhs, subPath))
          }
        })

      case LetTuple(is, e, b) =>
        val se = rec(e, path)
        val sb = rec(b, register(Equals(Tuple(is.map(Variable(_))), se), path))
        LetTuple(is, se, sb)

      case IfExpr(cond, thenn, elze) =>
        val rc = rec(cond, path)

        // Then-branch under the condition, else-branch under its negation.
        IfExpr(rc, rec(thenn, register(rc, path)), rec(elze, register(Not(rc), path)))

      case And(es) => {
        // Each conjunct is processed assuming the previous ones.
        var soFar = path
        And(for(e <- es) yield {
          val se = rec(e, soFar)
          soFar = register(se, soFar)
          se
        })
      }

      case Or(es) => {
        // Each disjunct is processed assuming the negation of the previous ones.
        var soFar = path
        Or(for(e <- es) yield {
          val se = rec(e, soFar)
          soFar = register(Not(se), soFar)
          se
        })
      }


      // Remaining operators: the context is passed down unchanged.
      case UnaryOperator(e, builder) =>
        builder(rec(e, path))

      case BinaryOperator(e1, e2, builder) =>
        builder(rec(e1, path), rec(e2, path))

      case NAryOperator(es, builder) =>
        builder(es.map(rec(_, path)))

      case t : Terminal => t

      case _ =>
        sys.error("Expression "+e+" ["+e.getClass+"] is not extractable")
    }

    /** Rebuilds `e` starting from the initial context `initC`. */
    def transform(e: Expr): Expr = {
      rec(e, initC)
    }
  }

  /** Simplifies expressions using the conditions known to hold on the path
    * to each sub-expression, asking `solver` whether a condition is implied
    * or contradicted by that path. The path is a list of conditions, with the
    * most recently registered one first.
    */
  class SimplifierWithPaths(solver: Solver) extends TransformerWithPC {
    type C = List[Expr]

    val initC = Nil

    protected def register(e: Expr, c: C) = e :: c

    /** True only if the solver answers `Some(true)` for "the conjunction of
      * `path` implies `e`". Any other answer, or an exception thrown by the
      * solver, yields false.
      */
    def impliedBy(e : Expr, path : Seq[Expr]) : Boolean = try {
      solver.solve(Implies(And(path), e)) match {
        case Some(true) => true
        case _ => false
      }
    } catch {
      case _ : Exception => false
    }

    /** True only if the solver answers `Some(true)` for "the conjunction of
      * `path` implies the negation of `e`". Any other answer, or an exception
      * thrown by the solver, yields false.
      */
    def contradictedBy(e : Expr, path : Seq[Expr]) : Boolean = try {
      solver.solve(Implies(And(path), Not(e))) match {
        case Some(true) => true
        case _ => false
      }
    } catch {
      case _ : Exception => false
    }

    protected override def rec(e: Expr, path: C) = e match {
      case IfExpr(cond, thenn, elze) =>
        // Process the IfExpr generically, then drop the dead branch if the
        // condition was reduced to a literal.
        super.rec(e, path) match {
          case IfExpr(BooleanLiteral(true) , t, _) => t
          case IfExpr(BooleanLiteral(false), _, e) => e
          case ite => ite
        }

      case And(es) => {
        var soFar = path
        // Becomes false once a conjunct has been simplified to false; the
        // remaining conjuncts are then skipped.
        var continue = true
        var r = And(for(e <- es if continue) yield {
          val se = rec(e, soFar)
          if(se == BooleanLiteral(false)) continue = false
          soFar = register(se, soFar)
          se
        })

        if (continue) {
          r
        } else {
          BooleanLiteral(false)
        }
      }

      case MatchExpr(scrut, cases) =>
        val rs = rec(scrut, path)

        // Becomes false once a case whose pattern is implied by the path has
        // been kept: the cases after it are dropped.
        var stillPossible = true

        if (cases.exists(_.hasGuard)) {
          // unsupported for now
          e
        } else {
          MatchExpr(rs, cases.flatMap { c =>
            val patternExpr = conditionForPattern(rs, c.pattern, includeBinders = true)

            // Cases whose pattern is contradicted by the path are dropped.
            // Both solver checks use the incoming path only.
            if (stillPossible && !contradictedBy(patternExpr, path)) {

              if (impliedBy(patternExpr, path)) {
                stillPossible = false
              }

              c match {
                case SimpleCase(p, rhs) =>
                  Some(SimpleCase(p, rec(rhs, patternExpr +: path)))
                // Not reached: guarded cases are excluded by the test above.
                case GuardedCase(_, _, _) =>
                  sys.error("woot.")
              }
            } else {
              None
            }
          })
        }

      case Or(es) => {
        var soFar = path
        // Becomes false once a disjunct has been simplified to true; the
        // remaining disjuncts are then skipped.
        var continue = true
        var r = Or(for(e <- es if continue) yield {
          val se = rec(e, soFar)
          if(se == BooleanLiteral(true)) continue = false
          soFar = register(Not(se), soFar)
          se
        })

        if (continue) {
          r
        } else {
          BooleanLiteral(true)
        }
      }

      // Any other boolean expression decided by the path becomes a literal.
      case b if b.getType == BooleanType && impliedBy(b, path) =>
        BooleanLiteral(true)

      case b if b.getType == BooleanType && contradictedBy(b, path) =>
        BooleanLiteral(false)

      case _ =>
        super.rec(e, path)
    }
  }

  /** Collects every `Choose` of an expression together with the conjunction
    * of the conditions registered on the path leading to it.
    */
  class ChooseCollectorWithPaths extends TransformerWithPC with Traverser[Seq[(Choose, Expr)]] {
    type C = Seq[Expr]
    val initC = Nil

    // Pairs found so far, in the order in which they are encountered.
    var results: Seq[(Choose, Expr)] = Nil

    // Conditions are appended, so the path is in registration order.
    def register(e: Expr, path: C) = path :+ e

    override def rec(e: Expr, path: C) = e match {
      case choose : Choose =>
        // Record the choose and do not descend into it.
        val pathCondition = And(path)
        results :+= ((choose, pathCondition))
        choose
      case other =>
        super.rec(other, path)
    }

    /** Resets the collected results, traverses `e` and returns what was found. */
    def traverse(e: Expr) = {
      results = Nil
      rec(e, initC)
      results
    }
  }

  /** Transformer that renames the identifiers bound by `Let`, `LetTuple` and
    * match patterns, and rewrites the variables referring to them.
    */
  class ScopeSimplifier extends Transformer {

    /** Renaming state.
      *
      * @param inScope  new identifiers introduced so far
      * @param oldToNew mapping from original identifiers to their replacement
      * @param funDefs  mapping from original function definitions to their replacement
      */
    case class Scope(inScope: Set[Identifier] = Set(), oldToNew: Map[Identifier, Identifier] = Map(), funDefs: Map[FunDef, FunDef] = Map()) {

      /** Returns a scope in which `oldNew._1` is renamed to `oldNew._2`. */
      def register(oldNew: (Identifier, Identifier)): Scope = {
        val (oldId, newId) = oldNew
        copy(inScope = inScope + newId, oldToNew = oldToNew + oldNew)
      }

      /** Returns a scope in which `oldNew._1` is replaced by `oldNew._2`. */
      def registerFunDef(oldNew: (FunDef, FunDef)): Scope = {
        copy(funDefs = funDefs + oldNew)
      }
    }

    /** Creates the replacement for `id`: a fresh identifier with the same name
      * and type, built from the number of identifiers of that name already in
      * scope.
      */
    protected def genId(id: Identifier, scope: Scope): Identifier = {
      val existCount = scope.inScope.count(_.name == id.name)

      FreshIdentifier(id.name, existCount).setType(id.getType)
    }

    protected def rec(e: Expr, scope: Scope): Expr = e match {
      case Let(i, e, b) =>
        // The bound value is processed in the outer scope, the body in the
        // scope extended with the renamed binder.
        val si = genId(i, scope)
        val se = rec(e, scope)
        val sb = rec(b, scope.register(i -> si))
        Let(si, se, sb)

      case LetTuple(is, e, b) =>
        // Binders are renamed one after the other so that each one sees the
        // previous ones in scope.
        var newScope = scope
        val sis = for (i <- is) yield {
          val si = genId(i, newScope)
          newScope = newScope.register(i -> si)
          si
        }

        val se = rec(e, scope)
        val sb = rec(b, newScope)
        LetTuple(sis, se, sb)

      case MatchExpr(scrut, cases) =>
        val rs = rec(scrut, scope)

        // Renames the binders of a pattern (its own binder first, then the
        // sub-patterns from left to right) and returns the new pattern with
        // the scope containing all its binders.
        def trPattern(p: Pattern, scope: Scope): (Pattern, Scope) = {
          val (newBinder, newScope) = p.binder match {
            case Some(id) =>
              val newId = genId(id, scope)
              val newScope = scope.register(id -> newId)
              (Some(newId), newScope)
            case None =>
              (None, scope)
          }

          var curScope = newScope
          var newSubPatterns = for (sp <- p.subPatterns) yield {
            val (subPattern, subScope) = trPattern(sp, curScope)
            curScope = subScope
            subPattern
          }

          val newPattern = p match {
            case InstanceOfPattern(b, ctd) =>
              InstanceOfPattern(newBinder, ctd)
            case WildcardPattern(b) =>
              WildcardPattern(newBinder)
            case CaseClassPattern(b, ccd, sub) =>
              CaseClassPattern(newBinder, ccd, newSubPatterns)
            case TuplePattern(b, sub) =>
              TuplePattern(newBinder, newSubPatterns)
          }


          (newPattern, curScope)
        }

        // Each case starts again from the scope of the match expression.
        MatchExpr(rs, cases.map { c =>
          val (newP, newScope) = trPattern(c.pattern, scope)

          c match {
            case SimpleCase(p, rhs) =>
              SimpleCase(newP, rec(rhs, newScope))
            case GuardedCase(p, g, rhs) =>
              GuardedCase(newP, rec(g, newScope), rec(rhs, newScope))
          }
        })

      case Variable(id) =>
        // Identifiers that were not renamed are kept.
        Variable(scope.oldToNew.getOrElse(id, id))

      case FunctionInvocation(fd, args) =>
        val newFd = scope.funDefs.getOrElse(fd, fd)
        val newArgs = args.map(rec(_, scope))

        FunctionInvocation(newFd, newArgs)

      case UnaryOperator(e, builder) =>
        builder(rec(e, scope))

      case BinaryOperator(e1, e2, builder) =>
        builder(rec(e1, scope), rec(e2, scope))

      case NAryOperator(es, builder) =>
        builder(es.map(rec(_, scope)))

      case t : Terminal => t

      case _ =>
        sys.error("Expression "+e+" ["+e.getClass+"] is not extractable")
    }

    /** Renames the binders of `e`, starting from an empty scope. */
    def transform(e: Expr): Expr = {
      rec(e, Scope())
    }
  }

  // Eliminates tuples of arity 0 and 1. This function also affects types!
  // Only rewrites local fundefs (i.e. LetDef's).
  /** Replaces tuples of arity 0 by the unit literal and tuples of arity 1 by
    * their single component, top-down, adapting result variables and the
    * function definitions referenced by `LetDef` and `FunctionInvocation`.
    *
    * @param expr the expression to rewrite
    * @return the rewritten expression
    */
  def rewriteTuples(expr: Expr) : Expr = {
    // Returns Some(new type) when the type changes, and None when the type is
    // left untouched.
    // NOTE(review): the match has no default case; a TypeTree not listed below
    // would raise a MatchError -- confirm that all types are covered.
    def mapType(tt : TypeTree) : Option[TypeTree] = tt match {
      case TupleType(ts) => ts.size match {
        case 0 => Some(UnitType)
        case 1 => Some(ts(0))
        case _ =>
          // Rebuild the tuple only if at least one component changes.
          val tss = ts.map(mapType)
          if(tss.exists(_.isDefined)) {
            Some(TupleType((tss zip ts).map(p => p._1.getOrElse(p._2))))
          } else {
            None
          }
      }
      case ListType(t)           => mapType(t).map(ListType(_))
      case SetType(t)            => mapType(t).map(SetType(_))
      case MultisetType(t)       => mapType(t).map(MultisetType(_))
      case ArrayType(t)          => mapType(t).map(ArrayType(_))
      case MapType(f,t)          => 
        val (f2,t2) = (mapType(f),mapType(t))
        if(f2.isDefined || t2.isDefined) {
          Some(MapType(f2.getOrElse(f), t2.getOrElse(t)))
        } else {
          None
        }
      case a : AbstractClassType => None
      case c : CaseClassType     =>
        // This is really just one big assertion. We don't rewrite class defs.
        val ccd = c.classDef
        val fieldTypes = ccd.fields.map(_.tpe)
        if(fieldTypes.exists(t => t match {
          case TupleType(ts) if ts.size <= 1 => true
          case _ => false
        })) {
          scala.sys.error("Cannot rewrite case class def that contains degenerate tuple types.")
        } else {
          None
        }
      case Untyped | AnyType | BottomType | BooleanType | Int32Type | UnitType => None  
    }

    // Cache of the function definitions already processed by fd2fd.
    var funDefMap = Map.empty[FunDef,FunDef]

    // Returns the definition to use in place of funDef: a new definition with
    // the rewritten return type when that type changes, funDef itself
    // otherwise. Fails if an argument type would need rewriting.
    def fd2fd(funDef : FunDef) : FunDef = funDefMap.get(funDef) match {
      case Some(fd) => fd
      case None =>
        if(funDef.args.map(vd => mapType(vd.tpe)).exists(_.isDefined)) {
          scala.sys.error("Cannot rewrite function def that takes degenerate tuple arguments,")
        }
        val newFD = mapType(funDef.returnType) match {
          case None => funDef
          case Some(rt) =>
            val fd = new FunDef(FreshIdentifier(funDef.id.name, true), rt, funDef.args)
            // These will be taken care of in the recursive traversal.
            fd.body = funDef.body
            fd.precondition = funDef.precondition
            fd.postcondition = funDef.postcondition
            fd
        }
        funDefMap = funDefMap.updated(funDef, newFD)
        newFD
    }

    def pre(e : Expr) : Expr = e match {
      case Tuple(Seq()) => UnitLiteral

      // A 1-tuple is replaced by its (itself rewritten) component.
      case Tuple(Seq(s)) => pre(s)

      // Selecting the first component of a 1-tuple yields the tuple expression itself.
      case ts @ TupleSelect(t, 1) => t.getType match {
        case TupleOneType(_) => pre(t)
        case _ => ts
      }

      case LetTuple(bs, v, bdy) if bs.size == 1 =>
        Let(bs(0), v, bdy)

      case l @ LetDef(fd, bdy) =>
        LetDef(fd2fd(fd), bdy)

      case r @ ResultVariable() =>
        mapType(r.getType).map { newType =>
          ResultVariable().setType(newType)
        } getOrElse {
          r
        }

      case FunctionInvocation(fd, args) =>
        FunctionInvocation(fd2fd(fd), args)

      case _ => e
    }

    simplePreTransform(pre)(expr)
  }

  /** Number of nodes of an expression: a terminal counts for one, and an
    * operator counts for one plus the sizes of its operands.
    */
  def formulaSize(e: Expr): Int = e match {
    case _: Terminal                 => 1
    case UnaryOperator(sub, _)       => 1 + formulaSize(sub)
    case BinaryOperator(lhs, rhs, _) => 1 + formulaSize(lhs) + formulaSize(rhs)
    case NAryOperator(subs, _)       => 1 + subs.map(formulaSize).sum
  }

  /** Applies `f` to every sub-expression of `e` on which it is defined and
    * returns the collected results as a list.
    */
  def collect[C](f: PartialFunction[Expr, C])(e: Expr): List[C] = {
    // Post step: prepend the result for this node, if any, to what was
    // collected in its sub-expressions.
    def collectAt(tree: Expr, below: List[C]) = {
      val found = if (f.isDefinedAt(tree)) f(tree) :: below else below
      (tree, found)
    }

    // Concatenates, in order, the lists collected for the children.
    def concatAll(parts: Seq[List[C]]) = parts.flatten.toList

    val (_, collected) = genericTransform[List[C]]((_, _), collectAt, concatAll)(List())(e)
    collected
  }

  /** Lists the `Choose` expressions found by a `ChooseCollectorWithPaths`
    * traversal of `e`, without their path conditions.
    */
  def collectChooses(e: Expr): List[Choose] = {
    val collector = new ChooseCollectorWithPaths()
    val choosesWithPaths = collector.traverse(e)
    choosesWithPaths.map { case (choose, _) => choose }.toList
  }

  /** Tells whether a top-down traversal of `e` encounters a `Choose`. */
  def containsChoose(e: Expr): Boolean = {
    simplePreTransform{
      // Non-local return: leaves containsChoose as soon as a Choose is met,
      // abandoning the rest of the traversal.
      case Choose(_, _) => return true
      case e => e
    }(e)
    // The traversal completed without meeting a Choose.
    false
  }

  /** Value of `id` in `model`, or the simplest value of its type when the
    * model does not define it.
    */
  def valuateWithModel(model: Map[Identifier, Expr])(id: Identifier): Expr = {
    model.get(id) match {
      case Some(value) => value
      case None        => simplestValue(id.getType)
    }
  }

  /** Replaces in `expr` each variable of `vars` by its value according to
    * `valuateWithModel(model)`.
    */
  def valuateWithModelIn(expr: Expr, vars: Set[Identifier], model: Map[Identifier, Expr]): Expr = {
    val substitutions = vars.foldLeft(Map.empty[Expr, Expr]) { (acc, id) =>
      acc + (Variable(id) -> valuateWithModel(model)(id))
    }
    replace(substitutions, expr)
  }

  //simple, local simplifications on arithmetic
  //you should not assume anything smarter than some constant folding and simple cancelation
  //to avoid infinite cycle we only apply simplification that reduce the size of the tree
  //The only guarantee from this function is that it does not increase the size of the expression and that it is sound
  //(note that an identity function would meet this specification)
  /** Applies the local rewriting rules of `simplify0` bottom-up, and repeats
    * the whole pass until the expression no longer changes.
    *
    * @param expr the arithmetic expression to simplify
    * @return the simplified expression
    */
  def simplifyArithmetic(expr: Expr): Expr = {
    // A single rewriting step at the root of expr. Rules are tried in order;
    // the first matching one applies.
    def simplify0(expr: Expr): Expr = expr match {
      case Plus(IntLiteral(i1), IntLiteral(i2)) => IntLiteral(i1 + i2)
      case Plus(IntLiteral(0), e) => e
      case Plus(e, IntLiteral(0)) => e
      case Plus(e1, UMinus(e2)) => Minus(e1, e2)
      case Plus(Plus(e, IntLiteral(i1)), IntLiteral(i2)) => Plus(e, IntLiteral(i1+i2))
      case Plus(Plus(IntLiteral(i1), e), IntLiteral(i2)) => Plus(IntLiteral(i1+i2), e)

      case Minus(e, IntLiteral(0)) => e
      case Minus(IntLiteral(0), e) => UMinus(e)
      case Minus(IntLiteral(i1), IntLiteral(i2)) => IntLiteral(i1 - i2)
      case Minus(e1, UMinus(e2)) => Plus(e1, e2)
      case Minus(e1, Minus(UMinus(e2), e3)) => Plus(e1, Plus(e2, e3))

      case UMinus(IntLiteral(x)) => IntLiteral(-x)
      case UMinus(UMinus(x)) => x
      case UMinus(Plus(UMinus(e1), e2)) => Plus(e1, UMinus(e2))
      case UMinus(Minus(e1, e2)) => Minus(e2, e1)

      case Times(IntLiteral(i1), IntLiteral(i2)) => IntLiteral(i1 * i2)
      case Times(IntLiteral(1), e) => e
      case Times(IntLiteral(-1), e) => UMinus(e)
      case Times(e, IntLiteral(1)) => e
      case Times(IntLiteral(0), _) => IntLiteral(0)
      case Times(_, IntLiteral(0)) => IntLiteral(0)
      case Times(IntLiteral(i1), Times(IntLiteral(i2), t)) => Times(IntLiteral(i1*i2), t)
      case Times(IntLiteral(i1), Times(t, IntLiteral(i2))) => Times(IntLiteral(i1*i2), t)
      case Times(IntLiteral(i), UMinus(e)) => Times(IntLiteral(-i), e)
      case Times(UMinus(e), IntLiteral(i)) => Times(e, IntLiteral(-i))
      // NOTE(review): if Division truncates, this rule looks unsound when e is
      // not a multiple of i2 (4 * (3 / 2) = 4 but (4 / 2) * 3 = 6) -- confirm
      // that callers only rely on it for exact divisions.
      case Times(IntLiteral(i1), Division(e, IntLiteral(i2))) if i2 != 0 && i1 % i2 == 0 => Times(IntLiteral(i1/i2), e)

      case Division(IntLiteral(i1), IntLiteral(i2)) if i2 != 0 => IntLiteral(i1 / i2)
      case Division(e, IntLiteral(1)) => e

      //here we put more expensive rules
      //btw, I know those are not the most general rules, but they lead to good optimizations :)
      case Plus(UMinus(Plus(e1, e2)), e3) if e1 == e3 => UMinus(e2)
      case Plus(UMinus(Plus(e1, e2)), e3) if e2 == e3 => UMinus(e1)
      case Minus(e1, e2) if e1 == e2 => IntLiteral(0) 
      case Minus(Plus(e1, e2), Plus(e3, e4)) if e1 == e4 && e2 == e3 => IntLiteral(0)
      case Minus(Plus(e1, e2), Plus(Plus(e3, e4), e5)) if e1 == e4 && e2 == e3 => UMinus(e5)

      //default
      case e => e
    }
    // Applies f repeatedly until the result is equal to its input.
    def fix[A](f: (A) => A)(a: A): A = {
      val na = f(a)
      if(a == na) a else fix(f)(na)
    }
      

    val res = fix(simplePostTransform(simplify0))(expr)
    res
  }

  /** Rebuilds `expr` as a sum of `coefficient * term` products from the
    * coefficients computed by `TreeNormalizations.linearArithmeticForm`
    * (the first coefficient is paired with the literal `1`, the others with
    * the free variables of `expr`), then simplifies the result with
    * `simplifyArithmetic`. If the normalization fails, `expr` itself is
    * simplified instead.
    *
    * @param expr the arithmetic expression to expand and simplify
    * @return the simplified expression
    */
  def expandAndSimplifyArithmetic(expr: Expr): Expr = {
    val expr0 = try {
      val freeVars: Array[Identifier] = variablesOf(expr).toArray
      val coefs: Array[Expr] = TreeNormalizations.linearArithmeticForm(expr, freeVars)
      val terms = IntLiteral(1) :: freeVars.toList.map(Variable(_))
      coefs.toList.zip(terms).foldLeft[Expr](IntLiteral(0)) {
        // Products with a zero coefficient are left out of the sum.
        case (acc, (coef, term)) =>
          if(coef == IntLiteral(0)) acc else Plus(acc, Times(coef, term))
      }
    } catch {
      // Fall back to the original expression on ordinary failures only:
      // fatal VM errors and control-flow throwables must keep propagating.
      case scala.util.control.NonFatal(_) =>
        expr
    }
    simplifyArithmetic(expr0)
  }

  //If the formula consists of a top-level AND, find a top-level
  //Equals and extract it; return the remaining formula as well
  /** Splits the first `Equals` conjunct, if any, off a conjunction.
    *
    * @return the first `Equals` found among the conjuncts of a top-level
    *         `And` together with the conjunction of the other conjuncts in
    *         their original order; `(None, expr)` rebuilt as is when there is
    *         no such conjunct or when `expr` is not an `And`
    */
  def extractEquals(expr: Expr): (Option[Equals], Expr) = expr match {
    case And(es) =>
      // Everything before the first Equals, then the rest starting at it.
      val (before, fromEquals) = es.span {
        case Equals(_, _) => false
        case _            => true
      }

      fromEquals.headOption match {
        case Some(eq @ Equals(_, _)) => (Some(eq), And(before ++ fromEquals.tail))
        case _                       => (None, And(before))
      }

    case e => (None, e)
  }

  /** Checks with the solver whether `expr` is preserved when `on`, an
    * identifier of abstract class type, is replaced by any of its fields of
    * that same type.
    *
    * For every case class descendant and every field whose type is the type
    * of `on`, the formula "on is of that case class, expr holds, and expr
    * does not hold for the field" is submitted to `solveSAT`. The result is
    * true only if the solver answers `Some(false)` for all of them. It is
    * false when `on` is not of abstract class type.
    */
  def isInductiveOn(solver: Solver)(expr: Expr, on: Identifier): Boolean = on match {
    case IsTyped(_, AbstractClassType(cd)) =>
      val v = Variable(on)

      val toCheck = for {
        ccd      <- cd.knownDescendents.collect { case c: CaseClassDef => c }
        selector <- ccd.fieldsIds if selector.getType == on.getType
      } yield {
        val isType = CaseClassInstanceOf(ccd, v)
        And(And(isType, expr), Not(replace(Map(v -> CaseClassSelector(ccd, v, selector)), expr)))
      }

      // Anything other than a definite "unsatisfiable" answer is treated
      // pessimistically.
      toCheck.forall { cond =>
        solver.solveSAT(cond) match {
          case (Some(false), _) => true
          case _                => false
        }
      }

    case _ =>
      false
  }

}