diff --git a/src/main/scala/inox/ast/Constructors.scala b/src/main/scala/inox/ast/Constructors.scala
index 16363edb3d064ca90b9b1b61010f6155da5bdaa9..0852b113235ba5cb4aed1783ff811ffeba14d4b9 100644
--- a/src/main/scala/inox/ast/Constructors.scala
+++ b/src/main/scala/inox/ast/Constructors.scala
@@ -205,7 +205,7 @@ trait Constructors {
     * @see [[Expressions.Lambda Lambda]]
     * @see [[Expressions.Application Application]]
     */
-  def application(fn: Expr, realArgs: Seq[Expr]) = fn match {
+  def application(fn: Expr, realArgs: Seq[Expr]): Expr = fn match {
      case Lambda(formalArgs, body) =>
       assert(realArgs.size == formalArgs.size, "Invoking lambda with incorrect number of arguments")
 
@@ -224,6 +224,9 @@ trait Constructors {
         case ((vd, bd), body) => let(vd, bd, body)
       }
 
+    case Assume(pred, l: Lambda) =>
+      assume(pred, application(l, realArgs))
+
     case _ =>
       Application(fn, realArgs)
    }
diff --git a/src/main/scala/inox/ast/SymbolOps.scala b/src/main/scala/inox/ast/SymbolOps.scala
index 583bf1cae2b63ef9ca3738b2d299796fe1ab6112..cb072f84ce2d8d6d602168c6f84b0cc25daaffc4 100644
--- a/src/main/scala/inox/ast/SymbolOps.scala
+++ b/src/main/scala/inox/ast/SymbolOps.scala
@@ -258,7 +258,7 @@ trait SymbolOps { self: TypeOps =>
 
         forall(args ++ allArgs.flatten, recons(allBodies))
       }
-      
+
       postMap {
         case Forall(args1, Forall(args2, body)) =>
           Some(forall(args1 ++ args2, body))
@@ -746,32 +746,35 @@ trait SymbolOps { self: TypeOps =>
     })
   }
 
-  def simplifyAssumptions(expr: Expr): Expr = {
-    def lift(expr: Expr): Expr = {
-      val vars = variablesOf(expr)
-      var assumptions: Seq[Expr] = Seq.empty
-
-      object transformer extends transformers.TransformerWithPC {
-        val trees: self.trees.type = self.trees
-        val symbols: self.symbols.type = self.symbols
-        val initEnv = Path.empty
-
-        override protected def rec(e: Expr, path: Path): Expr = e match {
-          case Assume(pred, body) if (variablesOf(pred) ++ path.variables) subsetOf vars =>
-            assumptions :+= path implies pred
-            rec(body, path withCond pred)
-          case _ => super.rec(e, path)
-        }
-      }
+  def liftAssumptions(expr: Expr): (Seq[Expr], Expr) = {
+    val vars = variablesOf(expr)
+    var assumptions: Seq[Expr] = Seq.empty
+
+    object transformer extends transformers.TransformerWithPC {
+      val trees: self.trees.type = self.trees
+      val symbols: self.symbols.type = self.symbols
+      val initEnv = Path.empty
 
-      val (vs, es, tps, recons) = deconstructor.deconstruct(expr)
-      val newEs = es.map(transformer.transform)
-      assume(andJoin(assumptions.toSeq), recons(vs, newEs, tps))
+      override protected def rec(e: Expr, path: Path): Expr = e match {
+        case Assume(pred, body) if (variablesOf(pred) ++ path.variables) subsetOf vars =>
+          assumptions :+= path implies pred
+          rec(body, path withCond pred)
+        case _ => super.rec(e, path)
+      }
     }
 
-    postMap(e => Some(lift(e)))(expr)
+    val newExpr = transformer.transform(expr)
+    (assumptions, newExpr)
   }
 
+  def simplifyAssumptions(expr: Expr): Expr = postMap {
+    case Assume(pred, body) =>
+      val (predAssumptions, newPred) = liftAssumptions(pred)
+      val (bodyAssumptions, newBody) = liftAssumptions(body)
+      Some(assume(andJoin(predAssumptions ++ (newPred +: bodyAssumptions)), newBody))
+    case _ => None
+  } (expr)
+
   def simplifyFormula(e: Expr, simplify: Boolean = true): Expr = {
     if (simplify) {
       val simp: Expr => Expr =
diff --git a/src/main/scala/inox/solvers/unrolling/TemplateGenerator.scala b/src/main/scala/inox/solvers/unrolling/TemplateGenerator.scala
index 1cbd408fb4ef53303e7de860bfed62f908e8a9bf..2a12e422bdf1cd1ebdcbcb13108b7f0fa14f3ccb 100644
--- a/src/main/scala/inox/solvers/unrolling/TemplateGenerator.scala
+++ b/src/main/scala/inox/solvers/unrolling/TemplateGenerator.scala
@@ -301,7 +301,13 @@ trait TemplateGenerator { self: Templates =>
         val idArgs : Seq[Variable] = lambdaArgs(l)
         val trArgs : Seq[Encoded] = idArgs.map(id => substMap.getOrElse(id, encodeSymbol(id)))
 
-        val (struct, deps) = normalizeStructure(l)
+        val (assumptions, without) = liftAssumptions(l)
+
+        for (a <- assumptions) {
+          rec(pathVar, a, Some(true))
+        }
+
+        val (struct, deps) = normalizeStructure(without.asInstanceOf[Lambda])
         val sortedDeps = exprOps.variablesOf(struct).map(v => v -> deps(v)).toSeq.sortBy(_._1.id.uniqueName)
 
         val isNormalForm: Boolean = {