diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala
index 23dcdd43cf0adb469e65a86925a17ac20fb9613c..7b0a88a4e4fe08d113543121c4040ffaceed69cd 100644
--- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala
+++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala
@@ -32,9 +32,12 @@ trait CodeExtraction extends ASTExtractors {
     if (p.isRange) {
       val start = p.focusStart
       val end   = p.focusEnd
-      LeonRangePosition(start.line, start.column, end.line, end.column, p.source.file.file)
+      LeonRangePosition(start.line, start.column, start.point,
+                        end.line, end.column, end.point,
+                        p.source.file.file)
     } else {
-      LeonOffsetPosition(p.line, p.column, p.source.file.file)
+      LeonOffsetPosition(p.line, p.column, p.point,
+                         p.source.file.file)
     }
   }
 
diff --git a/src/main/scala/leon/purescala/ScalaPrinter.scala b/src/main/scala/leon/purescala/ScalaPrinter.scala
index bc186b293d25f5e0615f916cc3638aee6d6626b7..eaa7bca9a3011abb9c684af763ea174c78bf4f62 100644
--- a/src/main/scala/leon/purescala/ScalaPrinter.scala
+++ b/src/main/scala/leon/purescala/ScalaPrinter.scala
@@ -341,7 +341,7 @@ class ScalaPrinter(sb: StringBuffer = new StringBuffer) extends PrettyPrinter(sb
 
         fd.args.foreach(arg => {
           sb.append(idToString(arg.id))
-          sb.append(" : ")
+          sb.append(": ")
           pp(arg.tpe, p)
 
           if(c < sz - 1) {
@@ -350,7 +350,7 @@ class ScalaPrinter(sb: StringBuffer = new StringBuffer) extends PrettyPrinter(sb
           c = c + 1
         })
 
-        sb.append(") : ")
+        sb.append("): ")
         pp(fd.returnType, p)
         sb.append(" = {\n")
         ind(lvl+1)
diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala
index 9d7a569750af2896e5afc4a46485ebee3486ad5d..02b771757bba5f0d4866165ea67e4bfd59a539df 100644
--- a/src/main/scala/leon/purescala/TreeOps.scala
+++ b/src/main/scala/leon/purescala/TreeOps.scala
@@ -2142,4 +2142,62 @@ object TreeOps {
     areExaustive(Seq((m.scrutinee.getType, patterns)))
   }
 
+  def flattenFunctions(fdOuter: FunDef): FunDef = {
+    fdOuter.body match {
+      case Some(LetDef(fdInner, FunctionInvocation(fdInner2, args))) if fdInner == fdInner2 =>
+        val argsDef  = fdOuter.args.map(_.id)
+        val argsCall = args.collect { case Variable(id) => id }
+
+        if (argsDef.toSet == argsCall.toSet) {
+          val defMap = argsDef.zipWithIndex.toMap
+          val rewriteMap = argsCall.map(defMap)
+
+          val innerIdsToOuterIds = (fdInner.args.map(_.id) zip argsCall).toMap
+
+          def pre(e: Expr) = e match {
+            case FunctionInvocation(fd, args) if fd == fdInner =>
+              val newArgs = (args zip rewriteMap).sortBy(_._2)
+              FunctionInvocation(fdOuter, newArgs.map(_._1))
+            case Variable(id) =>
+              Variable(innerIdsToOuterIds.getOrElse(id, id))
+            case _ =>
+              e
+          }
+
+          def mergePre(outer: Option[Expr], inner: Option[Expr]): Option[Expr] = (outer, inner) match {
+            case (None, Some(ie)) =>
+              Some(simplePreTransform(pre)(ie))
+            case (Some(oe), None) =>
+              Some(oe)
+            case (None, None) =>
+              None
+            case (Some(oe), Some(ie)) =>
+              Some(And(oe, simplePreTransform(pre)(ie)))
+          }
+
+          def mergePost(outer: Option[(Identifier, Expr)], inner: Option[(Identifier, Expr)]): Option[(Identifier, Expr)] = (outer, inner) match {
+            case (None, Some((iid, ie))) =>
+              Some((iid, simplePreTransform(pre)(ie)))
+            case (Some(oe), None) =>
+              Some(oe)
+            case (None, None) =>
+              None
+            case (Some((oid, oe)), Some((iid, ie))) =>
+              Some((oid, And(oe, replaceFromIDs(Map(iid -> Variable(oid)), simplePreTransform(pre)(ie)))))
+          }
+
+          val newFd = fdOuter.duplicate
+
+          newFd.body          = fdInner.body.map(b => simplePreTransform(pre)(b))
+          newFd.precondition  = mergePre(fdOuter.precondition, fdInner.precondition)
+          newFd.postcondition = mergePost(fdOuter.postcondition, fdInner.postcondition)
+
+          newFd
+        } else {
+          fdOuter
+        }
+      case _ =>
+        fdOuter
+    }
+  }
 }
diff --git a/src/main/scala/leon/synthesis/FileInterface.scala b/src/main/scala/leon/synthesis/FileInterface.scala
index 145b5b132adc2cb318f0d4ccc103b3c28bf6163b..49ff882540e8c27798bf7e50c96cb50427d44cdd 100644
--- a/src/main/scala/leon/synthesis/FileInterface.scala
+++ b/src/main/scala/leon/synthesis/FileInterface.scala
@@ -8,7 +8,7 @@ import purescala.Common.Tree
 import purescala.Definitions.FunDef
 import purescala.ScalaPrinter
 
-import leon.utils.Position
+import leon.utils.RangePosition
 
 import java.io.File
 class FileInterface(reporter: Reporter) {
@@ -33,7 +33,7 @@ class FileInterface(reporter: Reporter) {
 
         var newCode = origCode
         for ( (ci, e) <- solutions) {
-          newCode = substitute(newCode, CodePattern.forChoose(ci), e)
+          newCode = substitute(newCode, ci.ch, e)
         }
 
         val out = new BufferedWriter(new FileWriter(newFile))
@@ -44,102 +44,23 @@ class FileInterface(reporter: Reporter) {
     }
   }
 
-  case class CodePattern(startWith: String, pos: Position, blocks: Int)
+  def substitute(str: String, fromTree: Tree, toTree: Tree): String = {
 
-  object CodePattern {
-    def forChoose(ci: ChooseInfo) = CodePattern("choose", ci.ch.getPos, 1)
-    def forFunDef(fd: FunDef) = CodePattern("def", fd.getPos, 2)
-  }
-
-  def substitute(str: String, pattern: CodePattern, subst: Tree): String = {
-    var lines = List[Int]()
-
-    // Compute line positions
-    var lastFound = -1
-    do {
-      lastFound = str.indexOf('\n', lastFound+1)
-
-      if (lastFound > -1) {
-        lines = lastFound :: lines
-      }
-    } while(lastFound> 0)
-    lines = lines.reverse;
-
-    def lineOf(offset: Int): (Int, Int) = {
-      lines.zipWithIndex.find(_._1 > offset) match {
-        case Some((off, no)) =>
-          (no+1, if (no > 0) lines(no-1) else 0)
-        case None =>
-          (lines.size+1, lines.lastOption.getOrElse(0))
-      }
-    }
+    fromTree.getPos match {
+      case rp: RangePosition =>
+        val from = rp.pointFrom
+        val to   = rp.pointTo
 
-    def getLineIndentation(offset: Int): Int = {
-      var i = str.lastIndexOf('\n', offset)+1
+        val before = str.substring(0, from)
+        val after  = str.substring(to, str.length)
 
-      var res = 0;
+        val newCode = ScalaPrinter(toTree, fromTree.getPos.col/2)
 
-      while (i < str.length) {
-        val c = str.charAt(i)
-        i += 1
-
-        if (c == ' ') {
-          res += 1
-        } else if (c == '\t') {
-          res += 4
-        } else {
-          i = str.length
-        }
-      }
+        before + newCode + after
 
-      res
+      case _ =>
+        sys.error("Substitution requires RangePos on the input tree: "+fromTree)
     }
-
-    lastFound = -1
-
-    var newStr = str
-    var newStrOffset = 0
-
-    do {
-      lastFound = str.indexOf(pattern.startWith, lastFound+1)
-
-      if (lastFound > -1) {
-        val (lineno, lineoffset) = lineOf(lastFound)
-        // compute scala equivalent of the position:
-        val scalaOffset = str.substring(lineoffset, lastFound).replaceAll("\t", " "*8).length
-
-        val indent = getLineIndentation(lastFound)
-
-        if (pattern.pos.line == lineno && pattern.pos.col == scalaOffset) {
-          var lvl      = 0;
-          var i        = lastFound + 6;
-          var continue = true;
-          do {
-            var blocksRemaining = pattern.blocks
-            val c = str.charAt(i)
-            if (c == '(' || c == '{') {
-              lvl += 1
-            } else if (c == ')' || c == '}') {
-              lvl -= 1
-              if (lvl == 0) {
-                blocksRemaining -= 1
-                if (blocksRemaining == 0) {
-                  continue = false
-                }
-              }
-            }
-            i += 1
-          } while(continue)
-
-          val newCode = ScalaPrinter(subst, indent/2)
-          newStr = (newStr.substring(0, lastFound+newStrOffset))+newCode+(newStr.substring(i+newStrOffset, newStr.length))
-
-          newStrOffset += -(i-lastFound)+newCode.length
-        }
-      }
-    } while(lastFound> 0)
-
-    newStr
   }
 
   def readFile(file: File): String = {
diff --git a/src/main/scala/leon/utils/Positions.scala b/src/main/scala/leon/utils/Positions.scala
index b44dc9a36c7924972b114d9b9e4d463b4f59952e..1ca2f9faf279537f9bff8965d85f1c2a7d5b716a 100644
--- a/src/main/scala/leon/utils/Positions.scala
+++ b/src/main/scala/leon/utils/Positions.scala
@@ -16,9 +16,11 @@ abstract class Position {
   def isDefined = true
 }
 
-case class OffsetPosition(line: Int, col: Int, file: File) extends Position
+case class OffsetPosition(line: Int, col: Int, point: Int, file: File) extends Position
 
-case class RangePosition(lineFrom: Int, colFrom: Int, lineTo: Int, colTo: Int, file: File) extends Position {
+case class RangePosition(lineFrom: Int, colFrom: Int, pointFrom: Int,
+                         lineTo: Int, colTo: Int, pointTo: Int,
+                         file: File) extends Position {
   val line = lineFrom
   val col  = colFrom