diff --git a/src/purescala/Analysis.scala b/src/purescala/Analysis.scala
index bd8dc3252ec37aba3d3102564bcea482652288c7..0533b568d8cde5e844e44261471d0d044c6be39b 100644
--- a/src/purescala/Analysis.scala
+++ b/src/purescala/Analysis.scala
@@ -141,8 +141,7 @@ object Analysis {
   def inlineFunctionsAndContracts(program: Program, expr: Expr) : Expr = {
     var extras : List[Expr] = Nil
 
-    val isFunCall: Function[Expr,Boolean] = _.isInstanceOf[FunctionInvocation]
-    def applyToCall(e: Expr) : Expr = e match {
+    def applyToCall(e: Expr) : Option[Expr] = e match {
       case f @ FunctionInvocation(fd, args) => {
         val fArgsAsVars: List[Variable] = fd.args.map(_.toVariable).toList
         val fParamsAsLetVars: List[Identifier] = fd.args.map(a => FreshIdentifier("arg", true).setType(a.tpe)).toList
@@ -159,17 +158,17 @@ object Analysis {
             replace(substMap + (ResultVariable() -> newVar), fd.postcondition.get),
             Equals(newVar, FunctionInvocation(fd, fParamsAsLetVarVars).setType(fd.returnType))
           )) :: extras
-          newVar
+          Some(newVar)
         } else if(fd.hasImplementation && !program.isRecursive(fd)) { // means we can inline at least one level...
-          mkBigLet(replace(substMap, fd.body.get))
+          Some(mkBigLet(replace(substMap, fd.body.get)))
         } else { // we can't do much for calls to recursive functions or to functions with no bodies
-          f 
+          None
         }
       }
-      case o => o
+      case o => None
     }
 
-    val finalE = searchAndApply(isFunCall, applyToCall, expr)
+    val finalE = searchAndReplace(applyToCall)(expr)
     pulloutLets(Implies(And(extras.reverse), finalE))
   }
 
@@ -181,11 +180,7 @@ object Analysis {
       var extras : List[Expr] = Nil
 
       def urf(expr: Expr, left: Int) : Expr = {
-        def isRecursiveCall(e: Expr) = e match {
-          case f @ FunctionInvocation(fd, _) if fd.hasImplementation && program.isRecursive(fd) => true
-          case _ => false
-        }
-        def unrollCall(t: Int)(e: Expr) = e match {
+        def unrollCall(t: Int)(e: Expr) : Option[Expr] = e match {
           case f @ FunctionInvocation(fd, args) if fd.hasImplementation && program.isRecursive(fd) => {
             val newLetIDs = fd.args.map(a => FreshIdentifier(a.id.name, true).setType(a.tpe))
             val newLetVars = newLetIDs.map(Variable(_))
@@ -205,17 +200,17 @@ object Analysis {
               // println(" --- newVar is ------------------")
               // println(newVar)
               // println("*********************************")
-              newVar
+              Some(newVar)
             } else {
               val bigLet = (newLetIDs zip args).foldLeft(bodyWithLetVars)((e,p) => Let(p._1, p._2, e))
-              urf(bigLet, t-1)
+              Some(urf(bigLet, t-1))
             }
           }
-          case o => o
+          case o => None
         }
 
         if(left > 0)
-          searchAndApply(isRecursiveCall, unrollCall(left), expr, false)
+          searchAndReplace(unrollCall(left), false)(expr)
         else
           expr
       }
@@ -238,12 +233,8 @@ object Analysis {
     def rspm(expr: Expr) : (Expr,Seq[Expr]) = {
       var extras : List[Expr] = Nil
 
-      def isPMExpr(e: Expr) : Boolean = {
-        e.isInstanceOf[MatchExpr]
-      }
-
-      def rewritePM(e: Expr) : Expr = e.asInstanceOf[MatchExpr] match {
-        case SimplePatternMatching(scrutinee, classType, casesInfo) => {
+      def rewritePM(e: Expr) : Option[Expr] = e match {
+        case SimplePatternMatching(scrutinee, classType, casesInfo) => Some({
           val newVar = Variable(FreshIdentifier("pm", true)).setType(e.getType)
           val scrutAsLetID = FreshIdentifier("scrut", true).setType(scrutinee.getType)
           val lle : List[(Variable,List[Expr])] = casesInfo.map(cseInfo => {
@@ -256,11 +247,11 @@ object Analysis {
           val (newPVars, newExtras) = lle.unzip
           extras = Let(scrutAsLetID, scrutinee, And(Or(newPVars.map(Equals(Variable(scrutAsLetID), _))), And(newExtras.flatten))) :: extras
           newVar
-        }
-        case _ => e
+        })
+        case _ => None
       }
       
-      val cleanerTree = searchAndApply(isPMExpr, rewritePM, expr) 
+      val cleanerTree = searchAndReplace(rewritePM)(expr) 
       (cleanerTree, extras.reverse)
     }
     val (savedLets, naked) = pulloutAndKeepLets(expression)
diff --git a/src/purescala/Definitions.scala b/src/purescala/Definitions.scala
index 32b2d23d5310813ebac43b8cfa275afd299cb6ef..90d418bb7604291ea0b235c55bb427805821d50d 100644
--- a/src/purescala/Definitions.scala
+++ b/src/purescala/Definitions.scala
@@ -51,16 +51,15 @@ object Definitions {
       var resSet: Set[(FunDef,FunDef)] =
         new scala.collection.immutable.HashSet[(FunDef,FunDef)]()
 
-      def isFunCall(e: Expr) : Boolean = e.isInstanceOf[FunctionInvocation]
-      def applyToFunCall(f1: FunDef)(e: Expr) : Expr = e match {
-        case f @ FunctionInvocation(f2, _) => { resSet = resSet + ((f1,f2)); f }
-        case o => o
+      def applyToFunCall(f1: FunDef)(e: Expr) : Option[Expr] = e match {
+        case f @ FunctionInvocation(f2, _) => { resSet = resSet + ((f1,f2)); Some(f) }
+        case _ => None
       }
 
       for(funDef <- definedFunctions) {
-        funDef.precondition.map(searchAndApply(isFunCall, applyToFunCall(funDef), _))
-        funDef.body.map(searchAndApply(isFunCall, applyToFunCall(funDef), _))
-        funDef.postcondition.map(searchAndApply(isFunCall, applyToFunCall(funDef), _))
+        funDef.precondition.map(searchAndReplace(applyToFunCall(funDef))(_))
+        funDef.body.map(searchAndReplace(applyToFunCall(funDef))(_))
+        funDef.postcondition.map(searchAndReplace(applyToFunCall(funDef))(_))
       }
 
       var callers: Map[FunDef,Set[FunDef]] =
diff --git a/src/purescala/Trees.scala b/src/purescala/Trees.scala
index 91d0c15dea293ca881cc419bdfeeec500360bcb7..96418d14c5822672daec8205cf2a1765e427ace3 100644
--- a/src/purescala/Trees.scala
+++ b/src/purescala/Trees.scala
@@ -400,6 +400,80 @@ object Trees {
     rec(expr)
   }
 
+  def searchAndReplace(subst: Expr=>Option[Expr], recursive: Boolean=true)(expr: Expr) : Expr = {
+    def rec(ex: Expr, skip: Expr = null) : Expr = (if (ex == skip) None else subst(ex)) match {
+      case Some(newExpr) => {
+        if(newExpr.getType == NoType) {
+          Settings.reporter.warning("REPLACING IN EXPRESSION WITH AN UNTYPED TREE ! " + ex + " --to--> " + newExpr)
+        }
+        if(ex == newExpr)
+          if(recursive) rec(ex, ex) else ex
+        else
+          if(recursive) rec(newExpr) else newExpr
+      }
+      case None => ex match {
+        case l @ Let(i,e,b) => {
+          val re = rec(e)
+          val rb = rec(b)
+          if(re != e || rb != b)
+            Let(i, re, rb).setType(l.getType)
+          else
+            l
+        }
+        case n @ NAryOperator(args, recons) => {
+          var change = false
+          val rargs = args.map(a => {
+            val ra = rec(a)
+            if(ra != a) {
+              change = true  
+              ra
+            } else {
+              a
+            }            
+          })
+          if(change)
+            recons(rargs).setType(n.getType)
+          else
+            n
+        }
+        case b @ BinaryOperator(t1,t2,recons) => {
+          val r1 = rec(t1)
+          val r2 = rec(t2)
+          if(r1 != t1 || r2 != t2)
+            recons(r1,r2).setType(b.getType)
+          else
+            b
+        }
+        case u @ UnaryOperator(t,recons) => {
+          val r = rec(t)
+          if(r != t)
+            recons(r).setType(u.getType)
+          else
+            u
+        }
+        case i @ IfExpr(t1,t2,t3) => {
+          val r1 = rec(t1)
+          val r2 = rec(t2)
+          val r3 = rec(t3)
+          if(r1 != t1 || r2 != t2 || r3 != t3)
+            IfExpr(rec(t1),rec(t2),rec(t3)).setType(i.getType)
+          else
+            i
+        }
+        case m @ MatchExpr(scrut,cses) => MatchExpr(rec(scrut), cses.map(inCase(_))).setType(m.getType)
+        case t if t.isInstanceOf[Terminal] => t
+        case unhandled => scala.Predef.error("Non-terminal case should be handled in searchAndApply: " + unhandled)
+      }
+    }
+
+    def inCase(cse: MatchCase) : MatchCase = cse match {
+      case SimpleCase(pat, rhs) => SimpleCase(pat, rec(rhs))
+      case GuardedCase(pat, guard, rhs) => GuardedCase(pat, rec(guard), rec(rhs))
+    }
+
+    rec(expr)
+  }
+
   /* Simplifies let expressions:
    *  - removes lets when expression never occurs
    *  - simplifies when expressions occurs exactly once