diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala
index f457bb4b39c266b1f45e530495187b5ee92fad5d..fa5ebb6c3c6458e14897341a7a582a5994ea2813 100644
--- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala
+++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala
@@ -47,13 +47,13 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int
   }
 
   def initRC(mappings: Map[Identifier, Expr]): RC
-  def initGC: GC
+  def initGC(): GC
 
   private[this] var clpCache = Map[(Choose, Seq[Expr]), Expr]()
 
   def eval(ex: Expr, mappings: Map[Identifier, Expr]) = {
     try {
-      lastGC = Some(initGC)
+      lastGC = Some(initGC())
       ctx.timers.evaluators.recursive.runtime.start()
       EvaluationResults.Successful(e(ex)(initRC(mappings), lastGC.get))
     } catch {
@@ -78,7 +78,7 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int
         case Some(v) =>
           v
         case None =>
-          throw EvalError("No value for identifier " + id.name + " in mapping.")
+          throw EvalError("No value for identifier " + id.asString(ctx) + " in mapping.")
       }
 
     case Application(caller, args) =>
@@ -145,8 +145,8 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int
       val evArgs = args map e
 
       // build a mapping for the function...
-      val frame = rctx.newVars(tfd.paramSubst(evArgs))
-      
+      val frame = rctx.withNewVars(tfd.paramSubst(evArgs))
+
       if(tfd.hasPrecondition) {
         e(tfd.precondition.get)(frame, gctx) match {
           case BooleanLiteral(true) =>
diff --git a/src/main/scala/leon/purescala/DefOps.scala b/src/main/scala/leon/purescala/DefOps.scala
index e7acc9a42f34f3b014ded101a86484d8747cac83..64a0dca50c0b66ceb685be331588fa703f2f62e3 100644
--- a/src/main/scala/leon/purescala/DefOps.scala
+++ b/src/main/scala/leon/purescala/DefOps.scala
@@ -283,14 +283,6 @@ object DefOps {
       fdMapCache(fd).getOrElse(fd)
     }
 
-    def replaceCalls(e: Expr): Expr = {
-      preMap {
-        case fi @ FunctionInvocation(TypedFunDef(fd, tps), args) =>
-          fiMapF(fi, fdMap(fd)).map(_.setPos(fi))
-        case _ =>
-          None
-      }(e)
-    }
 
     val newP = p.copy(units = for (u <- p.units) yield {
       u.copy(
@@ -300,7 +292,7 @@ object DefOps {
               df match {
                 case f : FunDef =>
                   val newF = fdMap(f)
-                  newF.fullBody = replaceCalls(newF.fullBody)
+                  newF.fullBody = replaceFunCalls(newF.fullBody, fdMap, fiMapF)
                   newF
                 case d =>
                   d
@@ -319,6 +311,15 @@ object DefOps {
     (newP, fdMapCache.collect{ case (ofd, Some(nfd)) => ofd -> nfd })
   }
 
+  def replaceFunCalls(e: Expr, fdMapF: FunDef => FunDef, fiMapF: (FunctionInvocation, FunDef) => Option[Expr] = defaultFiMap) = {
+    preMap {
+      case fi @ FunctionInvocation(TypedFunDef(fd, tps), args) =>
+        fiMapF(fi, fdMapF(fd)).map(_.setPos(fi))
+      case _ =>
+        None
+    }(e)
+  }
+
   def addFunDefs(p: Program, fds: Traversable[FunDef], after: FunDef): Program = {
     var found = false
     val res = p.copy(units = for (u <- p.units) yield {
diff --git a/src/main/scala/leon/repair/Repairman.scala b/src/main/scala/leon/repair/Repairman.scala
index 20b95a30bb868e93db3311e14c64740ce016afac..0449334d3702fd4164274b7b494642c647db994f 100644
--- a/src/main/scala/leon/repair/Repairman.scala
+++ b/src/main/scala/leon/repair/Repairman.scala
@@ -83,19 +83,7 @@ class Repairman(ctx: LeonContext, initProgram: Program, fd: FunDef, verifTimeout
             for ((sol, i) <- solutions.zipWithIndex) {
               reporter.info(ASCIIHelpers.subTitle("Solution "+(i+1)+":"))
               val expr = sol.toSimplifiedExpr(ctx, synth.program)
-              reporter.info(ScalaPrinter(expr))
-            }
-            reporter.info(ASCIIHelpers.title("In context:"))
-
-
-            for ((sol, i) <- solutions.zipWithIndex) {
-              reporter.info(ASCIIHelpers.subTitle("Solution "+(i+1)+":"))
-              val expr = sol.toSimplifiedExpr(ctx, synth.program)
-              val nfd = fd.duplicate
-
-              nfd.body = Some(expr)
-
-              reporter.info(ScalaPrinter(nfd))
+              reporter.info(expr.asString(ctx))
             }
           }
         } finally {
diff --git a/src/main/scala/leon/synthesis/rules/CEGISLike.scala b/src/main/scala/leon/synthesis/rules/CEGISLike.scala
index c0f44d23d2c02b3d2a26f95af835ffadf4063f1b..ca53ba5ae5f045bb1a451ae94508f16ecd6e42a3 100644
--- a/src/main/scala/leon/synthesis/rules/CEGISLike.scala
+++ b/src/main/scala/leon/synthesis/rules/CEGISLike.scala
@@ -23,13 +23,15 @@ import datagen._
 import codegen.CodeGenParams
 
 import utils._
+import utils.ExpressionGrammars.{SizeBoundedGrammar, SizedLabel}
+import bonsai.Generator
 
 abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) {
 
   case class CegisParams(
     grammar: ExpressionGrammar[T],
     rootLabel: TypeTree => T,
-    maxUnfoldings: Int = 3
+    maxUnfoldings: Int = 5
   )
 
   def getParams(sctx: SynthesisContext, p: Problem): CegisParams
@@ -42,7 +44,8 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) {
     val nProgramsLimit = 100000
 
     val sctx = hctx.sctx
-    val ctx  = sctx.context
+    implicit val ctx  = sctx.context
+
 
     // CEGIS Flags to activate or deactivate features
     val useOptTimeout         = sctx.settings.cegisUseOptTimeout.getOrElse(true)
@@ -63,9 +66,31 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) {
       return Nil
     }
 
-    class NonDeterministicProgram(val p: Problem) {
+    class NonDeterministicProgram(val p: Problem, initTermSize: Int = 1) {
+
+      private var termSize = 0;
+
+      val grammar = ExpressionGrammars.SizeBoundedGrammar(params.grammar)
+
+      def rootLabel(tpe: TypeTree) = SizedLabel(params.rootLabel(tpe), termSize)
+
+      def xLabels = p.xs.map(x => rootLabel(x.getType))
+
+      var nAltsCache = Map[SizedLabel[T], Int]()
+
+      def countAlternatives(l: SizedLabel[T]): Int = {
+        if (!(nAltsCache contains l)) {
+          val count = grammar.getProductions(l).map {
+            case Generator(subTrees, _) => subTrees.map(countAlternatives).product
+          }.sum
+          nAltsCache += l -> count
+        }
+        nAltsCache(l)
+      }
 
-      private val grammar = params.grammar
+      def allProgramsCount(): Int = {
+        xLabels.map(countAlternatives).product
+      }
 
       /**
        * Different view of the tree of expressions:
@@ -84,60 +109,113 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) {
        *         (b3, H(c7, c8), Set(c7, c8))
        *       )
        */
-      private var cTree: Map[Identifier, Seq[(Identifier, Expr, Set[Identifier])]] = Map()
+      private var cTree: Map[Identifier, Seq[(Identifier, Seq[Expr] => Expr, Seq[Identifier])]] = Map()
 
-      /**
-       * Computes dependencies of c's
-       *
-       * c1 -> Set(c2, c3, c4, c5)
-       */
-      private var cDeps: Map[Identifier, Set[Identifier]] = Map()
-
-      /**
-       * Keeps track of blocked Bs and which C are affected, assuming cs are undefined:
-       *
-       * b2 -> Set(c4)
-       * b3 -> Set(c4)
-       */
-      private var closedBs: Map[Identifier, Set[Identifier]] = Map()
 
-      /**
-       * Maps c identifiers to grammar labels
-       *
-       * Labels allows us to use grammars that are not only type-based
-       */
-      private var labels: Map[Identifier, T] = Map() ++ p.xs.map(x => x -> params.rootLabel(x.getType))
+      // C identifiers corresponding to p.xs
+      private var rootCs: Seq[Identifier]    = Seq()
 
       private var bs: Set[Identifier]        = Set()
+
       private var bsOrdered: Seq[Identifier] = Seq()
 
-      /**
-       * Checks if 'b' is closed (meaning it depends on uninterpreted terms)
-       */
-      def isBActive(b: Identifier) = !closedBs.contains(b)
 
 
-      def allProgramsCount(): Int = {
-        var nAltsCache = Map[Identifier, Int]()
-
-        def nAltsFor(c: Identifier): Int = {
-          if (!(nAltsCache contains c)) {
-            val subs = for ((b, _, subcs) <- cTree(c) if isBActive(b)) yield {
-              if (subcs.isEmpty) {
-                1
-              } else {
-                subcs.toSeq.map(nAltsFor).product
+
+      class CGenerator {
+        private var buffers = Map[SizedLabel[T], Stream[Identifier]]()
+
+        private var slots = Map[SizedLabel[T], Int]().withDefaultValue(0)
+
+        private def streamOf(t: SizedLabel[T]): Stream[Identifier] = {
+          FreshIdentifier(t.toString, t.getType, true) #:: streamOf(t)
+        }
+
+        def rewind(): Unit = {
+          slots = Map[SizedLabel[T], Int]().withDefaultValue(0)
+        }
+
+        def getNext(t: SizedLabel[T]) = {
+          if (!(buffers contains t)) {
+            buffers += t -> streamOf(t)
+          }
+
+          val n = slots(t)
+          slots += t -> (n+1)
+
+          buffers(t)(n)
+        }
+      }
+
+      def init(): Unit = {
+        updateCTree()
+      }
+
+
+      def updateCTree(): Unit = {
+        def freshB() = {
+          val id = FreshIdentifier("B", BooleanType, true)
+          bs += id
+          id
+        }
+
+        def defineCTreeFor(l: SizedLabel[T], c: Identifier): Unit = {
+          if (!(cTree contains c)) {
+            val cGen = new CGenerator()
+
+            var alts = grammar.getProductions(l)
+
+            val cTreeData = for (gen <- alts) yield {
+              val b = freshB()
+
+              // Optimize labels
+              cGen.rewind()
+
+              val subCs = for (sl <- gen.subTrees) yield {
+                val subC = cGen.getNext(sl)
+                defineCTreeFor(sl, subC)
+                subC
               }
+
+              (b, gen.builder, subCs)
             }
 
-            nAltsCache += c -> subs.sum
+            cTree += c -> cTreeData
           }
-          nAltsCache(c)
         }
 
-        p.xs.map(nAltsFor).product
+        val cGen = new CGenerator()
+
+        rootCs = for (l <- xLabels) yield {
+          val c = cGen.getNext(l)
+          defineCTreeFor(l, c)
+          c
+        }
+
+        sctx.reporter.ifDebug { printer =>
+          printer("Grammar so far:")
+          grammar.printProductions(printer)
+        }
+
+        bsOrdered    = bs.toSeq.sortBy(_.id)
+
+        setCExpr(computeCExpr())
       }
 
+      /**
+       * Keeps track of blocked Bs and which C are affected, assuming cs are undefined:
+       *
+       * b2 -> Set(c4)
+       * b3 -> Set(c4)
+       */
+      private var closedBs: Map[Identifier, Set[Identifier]] = Map()
+
+      /**
+       * Checks if 'b' is closed (meaning it depends on uninterpreted terms)
+       */
+      def isBActive(b: Identifier) = !closedBs.contains(b)
+
+
       /**
        * Returns all possible assignments to Bs in order to enumerate all possible programs
        */
@@ -151,7 +229,7 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) {
 
         var cache = Map[Identifier, Seq[Set[Identifier]]]()
 
-        def allProgramsFor(cs: Set[Identifier]): Seq[Set[Identifier]] = {
+        def allProgramsFor(cs: Seq[Identifier]): Seq[Set[Identifier]] = {
           val seqs = for (c <- cs.toSeq) yield {
             if (!(cache contains c)) {
               val subs = for ((b, _, subcs) <- cTree(c) if isBActive(b)) yield {
@@ -173,100 +251,97 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) {
           }
         }
 
-        allProgramsFor(p.xs.toSet)
+        allProgramsFor(rootCs)
       }
 
-      private def debugCExpr(cTree: Map[Identifier, Seq[(Identifier, Expr, Set[Identifier])]],
+      private def debugCTree(cTree: Map[Identifier, Seq[(Identifier, Seq[Expr] => Expr, Seq[Identifier])]],
                              markedBs: Set[Identifier] = Set()): Unit = {
         println(" -- -- -- -- -- ")
         for ((c, alts) <- cTree) {
           println
           println(f"$c%-4s :=")
-          for ((b, ex, cs) <- alts ) {
+          for ((b, builder, cs) <- alts ) {
             val active = if (isBActive(b)) " " else "тип"
             val markS   = if (markedBs(b)) Console.GREEN else ""
             val markE   = if (markedBs(b)) Console.RESET else ""
 
-            println(f"      $markS$active  $b%-4s => $ex%-40s [$cs]$markE")
+            val ex = builder(cs.map(_.toVariable)).asString
+
+            println(f"      $markS$active  ${b.asString}%-4s => $ex%-40s [${cs.map(_.asString).mkString(", ")}]$markE")
           }
         }
       }
 
-      private def computeCExpr(): Expr = {
+      private def computeCExpr(): (Expr, Seq[FunDef]) = {
+        var cToFd = Map[Identifier, FunDef]()
 
-        val lets = for ((c, alts) <- cTree) yield {
-          val activeAlts = alts.filter(a => isBActive(a._1))
+        def exprOf(alt: (Identifier, Seq[Expr] => Expr, Seq[Identifier])): Expr = {
+          val (_, builder, cs) = alt
 
-          val expr = activeAlts.foldLeft(simplestValue(c.getType): Expr) {
-            case (e, (b, ex, _)) => IfExpr(b.toVariable, ex, e)
-          }
+          val e = builder(cs.map { c =>
+            val fd = cToFd(c)
+            FunctionInvocation(fd.typed, fd.params.map(_.toVariable))
+          })
 
-          (c, expr)
+          outerExprToInnerExpr(e)
         }
 
-        // We order the lets base don dependencies
-        def defFor(c: Identifier): Expr = {
-          cDeps(c).filter(lets.contains).foldLeft(lets(c)) {
-            case (e, c) => Let(c, defFor(c), e)
-          }
+        // Define all C-def
+        for ((c, alts) <- cTree) yield {
+          cToFd += c -> new FunDef(FreshIdentifier(c.toString, alwaysShowUniqueID = true),
+                                   Seq(),
+                                   c.getType,
+                                   p.as.map(id => ValDef(id)))
         }
 
-        val res = tupleWrap(p.xs.map(defFor))
-
-        val substMap : Map[Expr,Expr] = bsOrdered.zipWithIndex.map {
-          case (b, i) => Variable(b) -> ArraySelect(bArrayId.toVariable, IntLiteral(i))
-        }.toMap
-
-        val simplerRes = simplifyLets(res)
-
-        replace(substMap, simplerRes)
-      }
+        // Fill C-def bodies
+        for ((c, alts) <- cTree) {
+          val activeAlts = alts.filter(a => isBActive(a._1))
 
+          val body = if (activeAlts.nonEmpty) {
+            activeAlts.init.foldLeft(exprOf(activeAlts.last)) {
+              case (e, alt) => IfExpr(alt._1.toVariable, exprOf(alt), e)
+            }
+          } else {
+            Error(c.getType, "Impossibru")
+          }
 
-      /**
-       * Information about the final Program representing CEGIS solutions at
-       * the current unfolding level
-       */
-      private val outerSolution = {
-        val part = new PartialSolution(hctx.search.g, true)
-        e : Expr => part.solutionAround(hctx.currentNode)(e).getOrElse {
-          sctx.reporter.fatalError("Unable to create outer solution")
+          cToFd(c).fullBody = body
         }
-      }
 
-      private val bArrayId = FreshIdentifier("bArray", ArrayType(BooleanType), true)
+        // Top-level expression for rootCs
+        val expr = tupleWrap(rootCs.map { c =>
+          val fd = cToFd(c)
+          FunctionInvocation(fd.typed, fd.params.map(_.toVariable))
+        })
 
-      private var cTreeFd = new FunDef(
-        FreshIdentifier("cTree", alwaysShowUniqueID = true),
-        Seq(),
-        p.outType,
-        p.as.map(id => ValDef(id))
-      )
+        (expr, cToFd.values.toSeq)
+      }
 
-      private var phiFd = new FunDef(
-        FreshIdentifier("phiFd", alwaysShowUniqueID = true),
-        Seq(),
-        BooleanType,
-        p.as.map(id => ValDef(id))
-      )
 
-      private var programCTree: Program = _
 
-      // Map functions from original program to cTree program
-      private var fdMapCTree: Map[FunDef, FunDef] = _
+      private val cTreeFd = new FunDef(FreshIdentifier("cTree", alwaysShowUniqueID = true),
+                               Seq(),
+                               p.outType,
+                               p.as.map(id => ValDef(id))
+                             )
 
-      private var tester: (Seq[Expr], Set[Identifier]) => EvaluationResults.Result = _
+      private val phiFd   = new FunDef(FreshIdentifier("phiFd", alwaysShowUniqueID = true),
+                               Seq(),
+                               BooleanType,
+                               p.as.map(id => ValDef(id))
+                             )
 
-      private def initializeCTreeProgram(): Unit = {
 
-        // CEGIS is solved by called cTree function (without bs yet)
-        val fullSol = outerSolution(FunctionInvocation(cTreeFd.typed, p.as.map(_.toVariable)))
+      private val (innerProgram, origFdMap) = {
 
+        val outerSolution = {
+          new PartialSolution(hctx.search.g, true)
+            .solutionAround(hctx.currentNode)(FunctionInvocation(cTreeFd.typed, p.as.map(_.toVariable)))
+            .getOrElse(ctx.reporter.fatalError("Unable to get outer solution"))
+        }
 
-        val chFd = hctx.ci.fd
-        val prog0 = hctx.program
-
-        val affected = prog0.callGraph.transitiveCallers(chFd) ++ Set(chFd, cTreeFd, phiFd) ++ fullSol.defs
+        val program0 = addFunDefs(sctx.program, Seq(cTreeFd, phiFd) ++ outerSolution.defs, hctx.ci.fd)
 
         cTreeFd.body = None
         phiFd.body   = Some(
@@ -275,83 +350,75 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) {
                    p.phi)
         )
 
-        val prog1 = addFunDefs(prog0, Seq(cTreeFd, phiFd) ++ fullSol.defs, chFd)
-
-        val (prog2, fdMap2) = replaceFunDefs(prog1)({
-          case fd if affected(fd) =>
-            // Add the b array argument to all affected functions
-            val nfd = new FunDef(
-              fd.id.freshen,
-              fd.tparams,
-              fd.returnType,
-              fd.params :+ ValDef(bArrayId)
-            )
-            nfd.copyContentFrom(fd)
-            nfd.copiedFrom(fd)
-
-            if (fd == chFd) {
-              nfd.fullBody = replace(Map(hctx.ci.ch -> fullSol.guardedTerm), nfd.fullBody)
-            }
+        replaceFunDefs(program0){
+          case fd if fd == hctx.ci.fd =>
+            val nfd = fd.duplicate
+
+            nfd.fullBody = postMap {
+              case ch if ch eq hctx.ci.ch =>
+                Some(outerSolution.term)
+
+              case _ => None
+            }(nfd.fullBody)
 
             Some(nfd)
 
-          case _ =>
-            None
-        }, {
-          case (FunctionInvocation(old, args), newfd) if old.fd != newfd =>
-            Some(FunctionInvocation(newfd.typed(old.tps), args :+ bArrayId.toVariable))
-          case _ =>
+          case `cTreeFd` | `phiFd` =>
             None
-        })
 
-        programCTree = prog2
-        cTreeFd      = fdMap2(cTreeFd)
-        phiFd        = fdMap2(phiFd)
-        fdMapCTree   = fdMap2
+          case fd =>
+            Some(fd.duplicate)
+        }
+
       }
 
-      private def setCExpr(cTree: Expr): Unit = {
+      /**
+       * Since CEGIS works with a copy of the outer program,
+       * it needs to map outer function calls to inner function calls
+       * and vice-versa. 'inner' refers to the CEGIS-specific program,
+       * 'outer' refers to the actual program on which we do synthesis.
+       */
+      private def outerExprToInnerExpr(e: Expr): Expr = {
+        replaceFunCalls(e, {fd => origFdMap.getOrElse(fd, fd) })
+      }
 
-        cTreeFd.body = Some(preMap{
-          case FunctionInvocation(TypedFunDef(fd, tps), args) if fdMapCTree contains fd =>
-            Some(FunctionInvocation(fdMapCTree(fd).typed(tps), args :+ bArrayId.toVariable))
-          case _ =>
-            None
-        }(cTree))
+      private val innerPc  = outerExprToInnerExpr(p.pc)
+      private val innerPhi = outerExprToInnerExpr(p.phi)
+
+      private var programCTree: Program = _
+      private var tester: (Seq[Expr], Set[Identifier]) => EvaluationResults.Result = _
+
+      private def setCExpr(cTreeInfo: (Expr, Seq[FunDef])): Unit = {
+        val (cTree, newFds) = cTreeInfo
+
+        cTreeFd.body = Some(cTree)
+        programCTree = addFunDefs(innerProgram, newFds, cTreeFd)
 
         //println("-- "*30)
-        //println(programCTree)
+        //println(programCTree.asString)
         //println(".. "*30)
 
-        val evaluator  = new DualEvaluator(sctx.context, programCTree, CodeGenParams.default)
-
+//        val evaluator  = new DualEvaluator(sctx.context, programCTree, CodeGenParams.default)
+        val evaluator  = new DefaultEvaluator(sctx.context, programCTree)
 
         tester =
           { (ins: Seq[Expr], bValues: Set[Identifier]) =>
-            val bsValue = finiteArray(bsOrdered.map(b => BooleanLiteral(bValues(b))), None, BooleanType)
-            val args = ins :+ bsValue
+            val envMap = bs.map(b => b -> BooleanLiteral(bValues(b))).toMap
 
-            val fi = FunctionInvocation(phiFd.typed, args)
+            val fi = FunctionInvocation(phiFd.typed, ins)
 
-            evaluator.eval(fi, Map())
+            evaluator.eval(fi, envMap)
           }
       }
 
 
-      private def updateCTree() {
-        if (programCTree eq null) {
-          initializeCTreeProgram()
-        }
-
-        setCExpr(computeCExpr())
-      }
-
       def testForProgram(bValues: Set[Identifier])(ins: Seq[Expr]): Boolean = {
         tester(ins, bValues) match {
           case EvaluationResults.Successful(res) =>
             res == BooleanLiteral(true)
 
           case EvaluationResults.RuntimeError(err) =>
+            sctx.reporter.warning("RE testing CE: "+err)
             false
 
           case EvaluationResults.EvaluatorError(err) =>
@@ -362,79 +429,77 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) {
 
 
 
+      // Returns the outer expression corresponding to a B-valuation
       def getExpr(bValues: Set[Identifier]): Expr = {
         def getCValue(c: Identifier): Expr = {
           cTree(c).find(i => bValues(i._1)).map {
-            case (b, ex, cs) =>
-              val map = for (c <- cs) yield {
-                c -> getCValue(c)
-              }
-
-              substAll(map.toMap, ex)
+            case (b, builder, cs) =>
+              builder(cs.map(getCValue))
           }.getOrElse {
             simplestValue(c.getType)
           }
         }
 
-        tupleWrap(p.xs.map(c => getCValue(c)))
+        tupleWrap(rootCs.map(c => getCValue(c)))
       }
 
+      /**
+       * Here we check the validity of a given program in isolation, we compute
+       * the corresponding expr and replace it in place of the C-tree
+       */
       def validatePrograms(bss: Set[Set[Identifier]]): Either[Stream[Solution], Seq[Seq[Expr]]] = {
-        try {
-          val cexs = for (bs <- bss.toSeq) yield {
-            val sol = getExpr(bs)
+        val origImpl = cTreeFd.fullBody
 
-            val fullSol = outerSolution(sol)
+        val cexs = for (bs <- bss.toSeq) yield {
+          val outerSol = getExpr(bs)
+          val innerSol = outerExprToInnerExpr(outerSol)
 
-            val prog = addFunDefs(hctx.program, fullSol.defs, hctx.ci.fd)
+          cTreeFd.fullBody = innerSol
 
-            hctx.ci.ch.impl = Some(fullSol.guardedTerm)
+          val cnstr = and(innerPc, letTuple(p.xs, innerSol, Not(innerPhi)))
 
-            val cnstr = and(p.pc, letTuple(p.xs, sol, Not(p.phi)))
-            //println("Solving for: "+cnstr)
+          //println("Solving for: "+cnstr.asString)
 
-            val solverf = SolverFactory.default(ctx, prog).withTimeout(cexSolverTo)
-            val solver  = solverf.getNewSolver()
-            try {
-              solver.assertCnstr(cnstr)
-              solver.check match {
-                case Some(true) =>
-                  excludeProgram(bs)
-                  val model = solver.getModel
-                  //println("Found counter example: ")
-                  //for ((s, v) <- model) {
-                  //  println(" "+s.asString+" -> "+v.asString)
-                  //}
+          val solverf = SolverFactory.default(ctx, innerProgram).withTimeout(cexSolverTo)
+          val solver  = solverf.getNewSolver()
+          try {
+            solver.assertCnstr(cnstr)
+            solver.check match {
+              case Some(true) =>
+                excludeProgram(bs)
+                val model = solver.getModel
+                //println("Found counter example: ")
+                //for ((s, v) <- model) {
+                //  println(" "+s.asString+" -> "+v.asString)
+                //}
 
-                  //val evaluator  = new DefaultEvaluator(ctx, prog)
-                  //println(evaluator.eval(cnstr, model))
+                //val evaluator  = new DefaultEvaluator(ctx, prog)
+                //println(evaluator.eval(cnstr, model))
 
-                  Some(p.as.map(a => model.getOrElse(a, simplestValue(a.getType))))
+                Some(p.as.map(a => model.getOrElse(a, simplestValue(a.getType))))
 
-                case Some(false) =>
-                  // UNSAT, valid program
-                  return Left(Stream(Solution(BooleanLiteral(true), Set(), sol, true)))
+              case Some(false) =>
+                // UNSAT, valid program
+                return Left(Stream(Solution(BooleanLiteral(true), Set(), outerSol, true)))
 
-                case None =>
-                  if (useOptTimeout) {
-                    // Interpret timeout in CE search as "the candidate is valid"
-                    sctx.reporter.info("CEGIS could not prove the validity of the resulting expression")
-                    // Optimistic valid solution
-                    return Left(Stream(Solution(BooleanLiteral(true), Set(), sol, false)))
-                  } else {
-                    None
-                  }
-              }
-            } finally {
-              solver.free()
-              solverf.shutdown()
+              case None =>
+                if (useOptTimeout) {
+                  // Interpret timeout in CE search as "the candidate is valid"
+                  sctx.reporter.info("CEGIS could not prove the validity of the resulting expression")
+                  // Optimistic valid solution
+                  return Left(Stream(Solution(BooleanLiteral(true), Set(), outerSol, false)))
+                } else {
+                  None
+                }
             }
+          } finally {
+            solver.free()
+            solverf.shutdown()
+            cTreeFd.fullBody = origImpl
           }
-
-          Right(cexs.flatten)
-        } finally {
-          hctx.ci.ch.impl = None
         }
+
+        Right(cexs.flatten)
       }
 
       var excludedPrograms = ArrayBuffer[Set[Identifier]]()
@@ -442,224 +507,18 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) {
       // Explicitly remove program computed by bValues from the search space
       def excludeProgram(bValues: Set[Identifier]): Unit = {
         val bvs = bValues.filter(isBActive)
-        //println(f" (-)  ${bvs.mkString(", ")}%-40s  ("+getExpr(bvs)+")")
         excludedPrograms += bvs
       }
 
-      /**
-       * Shrinks the non-deterministic program to the provided set of
-       * alternatives only
-       */
-      def shrinkTo(remainingBs: Set[Identifier], finalUnfolding: Boolean): Unit = {
-        //println("Shrinking!")
-
-        val initialBs = remainingBs ++ (if (finalUnfolding) Set() else closedBs.keySet)
-
-        var cParent = Map[Identifier, Identifier]()
-        var cOfB    = Map[Identifier, Identifier]()
-        var underBs = Map[Identifier, Set[Identifier]]()
-
-        for ((cparent, alts) <- cTree;
-             (b, _, cs) <- alts) {
-
-          cOfB += b -> cparent
-
-          for (cchild <- cs) {
-            underBs += cchild -> (underBs.getOrElse(cchild, Set()) + b)
-            cParent += cchild -> cparent
-          }
-        }
-
-        def bParents(b: Identifier): Set[Identifier] = {
-          val parentBs = underBs.getOrElse(cOfB(b), Set())
-
-          Set(b) ++ parentBs.flatMap(bParents)
-        }
-
-        // include parents
-        val keptBs = initialBs.flatMap(bParents)
-
-        //println("Initial Bs: "+initialBs)
-        //println("Keeping Bs: "+keptBs)
-
-        //debugCExpr(cTree, keptBs)
-
-        var newCTree = Map[Identifier, Seq[(Identifier, Expr, Set[Identifier])]]()
-
-        for ((c, alts) <- cTree) yield {
-          newCTree += c -> alts.filter(a => keptBs(a._1))
-        }
-
-        def removeDeadAlts(c: Identifier, deadC: Identifier) {
-          if (newCTree contains c) {
-            val alts = newCTree(c)
-            val newAlts = alts.filterNot(a => a._3 contains deadC)
-
-            if (newAlts.isEmpty) {
-              for (cp <- cParent.get(c)) {
-                removeDeadAlts(cp, c)
-              }
-              newCTree -= c
-            } else {
-              newCTree += c -> newAlts
-            }
-          }
-        }
-
-        //println("BETWEEN")
-        //debugCExpr(newCTree, keptBs)
-
-        for ((c, alts) <- newCTree if alts.isEmpty) {
-          for (cp <- cParent.get(c)) {
-            removeDeadAlts(cp, c)
-          }
-          newCTree -= c
-        }
-
-        var newCDeps = Map[Identifier, Set[Identifier]]()
-
-        for ((c, alts) <- cTree) yield {
-          newCDeps += c -> alts.map(_._3).toSet.flatten
-        }
-
-        cTree        = newCTree
-        cDeps        = newCDeps
-        closedBs     = closedBs.filterKeys(keptBs)
-
-        bs           = cTree.map(_._2.map(_._1)).flatten.toSet
-        bsOrdered    = bs.toSeq.sortBy(_.id)
-
-        excludedPrograms = excludedPrograms.filter(_.forall(bs))
-
-        //debugCExpr(cTree)
-        updateCTree()
-      }
-
-      class CGenerator {
-        private var buffers = Map[T, Stream[Identifier]]()
-        
-        private var slots = Map[T, Int]().withDefaultValue(0)
-
-        private def streamOf(t: T): Stream[Identifier] = {
-          FreshIdentifier("c", t.getType, true) #:: streamOf(t)
-        }
-
-        def reset(): Unit = {
-          slots = Map[T, Int]().withDefaultValue(0)
-        }
-
-        def getNext(t: T) = {
-          if (!(buffers contains t)) {
-            buffers += t -> streamOf(t)
-          }
-
-          val n = slots(t)
-          slots += t -> (n+1)
-
-          buffers(t)(n)
-        }
-      }
-
       def unfold(finalUnfolding: Boolean): Boolean = {
-        var newBs = Set[Identifier]()
-        var unfoldedSomething = false
-
-        def freshB() = {
-          val id = FreshIdentifier("B", BooleanType, true)
-          newBs += id
-          id
-        }
-
-        val unfoldBehind = if (cTree.isEmpty) {
-          p.xs
-        } else {
-          closedBs.flatMap(_._2).toSet
-        }
-
-        closedBs = Map[Identifier, Set[Identifier]]()
-
-        for (c <- unfoldBehind) {
-          var alts = grammar.getProductions(labels(c))
-
-          if (finalUnfolding) {
-            alts = alts.filter(_.subTrees.isEmpty)
-          }
-
-          val cGen = new CGenerator()
-
-          val cTreeInfos = if (alts.nonEmpty) {
-            for (gen <- alts) yield {
-              val b = freshB()
-
-              // Optimize labels
-              cGen.reset()
-
-              val cToLabel = for (t <- gen.subTrees) yield {
-                cGen.getNext(t) -> t
-              }
-
-
-              labels ++= cToLabel
-
-              val cs = cToLabel.map(_._1)
-              val ex = gen.builder(cs.map(_.toVariable))
-
-              if (cs.nonEmpty) {
-                closedBs += b -> cs.toSet
-              }
-
-              //println(" + "+b+" => "+c+" = "+ex)
-
-              unfoldedSomething = true
-
-              (b, ex, cs.toSet)
-            }
-          } else {
-            // Happens in final unfolding when no alts have ground terms
-            val b = freshB()
-            closedBs += b -> Set()
-
-            Seq((b, simplestValue(c.getType), Set[Identifier]()))
-          }
-
-          cTree += c -> cTreeInfos
-          cDeps += c -> cTreeInfos.map(_._3).toSet.flatten
-        }
-
-        sctx.reporter.ifDebug { printer =>
-          printer("Grammar so far:")
-          grammar.printProductions(printer)
-        }
-
-        bs           = bs ++ newBs
-        bsOrdered    = bs.toSeq.sortBy(_.id)
-
-        /**
-         * Close dead-ends
-         *
-         * Find 'c' that have no active alternatives, then close all 'b's that
-         * depend on such "dead" 'c's
-         */
-        var deadCs = Set[Identifier]()
-
-        for ((c, alts) <- cTree) {
-          if (alts.forall{ case (b, _, _) => !isBActive(b) }) {
-            deadCs += c
-          }
-        }
-
-        for ((_, alts) <- cTree; (b, _, cs) <- alts) {
-          if ((cs & deadCs).nonEmpty) {
-            closedBs += (b -> closedBs.getOrElse(b, Set()))
-          }
-        }
-
-        //debugCExpr(cTree)
+        termSize += 1
         updateCTree()
-
-        unfoldedSomething
+        true
       }
 
+      /**
+       * First phase of CEGIS: solve for potential programs (that work on at least one input)
+       */
       def solveForTentativeProgram(): Option[Option[Set[Identifier]]] = {
         val solverf = SolverFactory.default(ctx, programCTree).withTimeout(exSolverTo)
         val solver  = solverf.getNewSolver()
@@ -670,16 +529,14 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) {
         //println(phiFd.fullBody.asString(ctx))
 
         val fixedBs = finiteArray(bsOrdered.map(_.toVariable), None, BooleanType)
-        val cnstrFixed = replaceFromIDs(Map(bArrayId -> fixedBs), cnstr)
 
+        val toFind = and(innerPc, cnstr)
+        //println(" --- Constraints ---")
+        //println(" - "+toFind)
         try {
-          val toFind = and(p.pc, cnstrFixed)
-          //println(" --- Constraints ---")
-          //println(" - "+toFind)
+          solver.assertCnstr(andJoin(bsOrdered.map(b => if (bs(b)) b.toVariable else Not(b.toVariable))))
           solver.assertCnstr(toFind)
 
-          // oneOfBs
-          //println(" -- OneOf:")
           for ((c, alts) <- cTree) {
             val activeBs = alts.map(_._1).filter(isBActive)
 
@@ -739,17 +596,19 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) {
         }
       }
 
+      /**
+       * Second phase of CEGIS: verify a given program by looking for CEX inputs
+       */
       def solveForCounterExample(bs: Set[Identifier]): Option[Option[Seq[Expr]]] = {
         val solverf = SolverFactory.default(ctx, programCTree).withTimeout(cexSolverTo)
         val solver  = solverf.getNewSolver()
         val cnstr = FunctionInvocation(phiFd.typed, phiFd.params.map(_.id.toVariable))
 
-        val fixedBs = finiteArray(bsOrdered.map(b => BooleanLiteral(bs(b))), None, BooleanType)
-        val cnstrFixed = replaceFromIDs(Map(bArrayId -> fixedBs), cnstr)
 
         try {
-          solver.assertCnstr(p.pc)
-          solver.assertCnstr(Not(cnstrFixed))
+          solver.assertCnstr(andJoin(bsOrdered.map(b => if (bs(b)) b.toVariable else Not(b.toVariable))))
+          solver.assertCnstr(innerPc)
+          solver.assertCnstr(Not(cnstr))
 
           solver.check match {
             case Some(true) =>
@@ -781,6 +640,8 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) {
         val sctx = hctx.sctx
 
         val ndProgram = new NonDeterministicProgram(p)
+        ndProgram.init()
+
         var unfolding = 1
         val maxUnfoldings = params.maxUnfoldings
 
@@ -788,19 +649,21 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) {
 
         var baseExampleInputs: ArrayBuffer[Seq[Expr]] = new ArrayBuffer[Seq[Expr]]()
 
+        sctx.reporter.ifDebug { printer =>
+          ndProgram.grammar.printProductions(printer)
+        }
+
         // We populate the list of examples with a predefined one
         sctx.reporter.debug("Acquiring initial list of examples")
 
         baseExampleInputs ++= p.tb.examples.map(_.ins).toSet
 
-        val pc = p.pc
-
-        if (pc == BooleanLiteral(true)) {
+        if (p.pc == BooleanLiteral(true)) {
           baseExampleInputs += p.as.map(a => simplestValue(a.getType))
         } else {
           val solver = sctx.newSolver.setTimeout(exSolverTo)
 
-          solver.assertCnstr(pc)
+          solver.assertCnstr(p.pc)
 
           try {
             solver.check match {
@@ -823,18 +686,19 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) {
 
         sctx.reporter.ifDebug { debug =>
           baseExampleInputs.foreach { in =>
-            debug("  - "+in.mkString(", "))
+            debug("  - "+in.map(_.asString).mkString(", "))
           }
         }
 
+
         /**
          * We generate tests for discarding potential programs
          */
         val inputIterator: Iterator[Seq[Expr]] = if (useVanuatoo) {
-          new VanuatooDataGen(sctx.context, sctx.program).generateFor(p.as, pc, 20, 3000)
+          new VanuatooDataGen(sctx.context, sctx.program).generateFor(p.as, p.pc, 20, 3000)
         } else {
           val evaluator  = new DualEvaluator(sctx.context, sctx.program, CodeGenParams.default)
-          new GrammarDataGen(evaluator, ExpressionGrammars.ValueGrammar).generateFor(p.as, pc, 20, 1000)
+          new GrammarDataGen(evaluator, ExpressionGrammars.ValueGrammar).generateFor(p.as, p.pc, 20, 1000)
         }
 
         val cachedInputIterator = new Iterator[Seq[Expr]] {
@@ -901,7 +765,7 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) {
                   val e = examples.next()
                   if (!ndProgram.testForProgram(bs)(e)) {
                     failedTestsStats(e) += 1
-                    sctx.reporter.debug(f" Program: ${ndProgram.getExpr(bs)}%-80s failed on: ${e.mkString(", ")}")
+                    sctx.reporter.debug(f" Program: ${ndProgram.getExpr(bs).asString}%-80s failed on: ${e.map(_.asString).mkString(", ")}")
                     wrongPrograms += bs
                     prunedPrograms -= bs
 
@@ -964,7 +828,7 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) {
                 if (nPassing < nInitial * shrinkThreshold && useShrink) {
                   // We shrink the program to only use the bs mentionned
                   val bssToKeep = prunedPrograms.foldLeft(Set[Identifier]())(_ ++ _)
-                  ndProgram.shrinkTo(bssToKeep, unfolding == maxUnfoldings)
+                  //ndProgram.shrinkTo(bssToKeep, unfolding == maxUnfoldings)
                 } else {
                   wrongPrograms.foreach {
                     ndProgram.excludeProgram
diff --git a/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala b/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala
index 8ad0dc7f7e63e71b014358457116cd9cabbd890f..3ac83225d73b7dfacaccd183eed03bb37825e818 100644
--- a/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala
+++ b/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala
@@ -8,6 +8,8 @@ import bonsai._
 
 import Helpers._
 
+import leon.utils.SeqUtils.sumTo
+
 import purescala.Expressions.{Or => LeonOr, _}
 import purescala.Common._
 import purescala.Definitions._
@@ -413,6 +415,36 @@ object ExpressionGrammars {
    }
   }
 
+  case class SizedLabel[T <% Typed](underlying: T, size: Int) extends Typed {
+    val getType = underlying.getType
+
+    override def toString = underlying.toString+"|"+size+"|"
+  }
+
+  case class SizeBoundedGrammar[T <% Typed](g: ExpressionGrammar[T]) extends ExpressionGrammar[SizedLabel[T]] {
+    def computeProductions(sl: SizedLabel[T]): Seq[Gen] = {
+      if (sl.size <= 0) {
+        Nil
+      } else if (sl.size == 1) {
+        g.getProductions(sl.underlying).filter(_.subTrees.isEmpty).map {
+          case Generator(subTrees, builder) =>
+            Generator[SizedLabel[T], Expr](Nil, builder)
+        }
+      } else {
+        g.getProductions(sl.underlying).filter(_.subTrees.nonEmpty).flatMap {
+          case Generator(subTrees, builder) =>
+            val sizes = sumTo(sl.size-1, subTrees.size)
+
+            for (ss <- sizes) yield {
+              val subSizedLabels = (subTrees zip ss) map (s => SizedLabel(s._1, s._2))
+
+              Generator[SizedLabel[T], Expr](subSizedLabels, builder)
+            }
+        }
+      }
+    }
+  }
+
   case class BoundedGrammar[T](g: ExpressionGrammar[Label[T]], bound: Int) extends ExpressionGrammar[Label[T]] {
     def computeProductions(l: Label[T]): Seq[Gen] = g.computeProductions(l).flatMap {
       case g: Generator[Label[T], Expr] =>
diff --git a/src/main/scala/leon/utils/SeqUtils.scala b/src/main/scala/leon/utils/SeqUtils.scala
index ff1f55d43d696fe5a5f9e301517716f5bac69641..5a5e2dff3088da991ac99a6b8f1e46f759837629 100644
--- a/src/main/scala/leon/utils/SeqUtils.scala
+++ b/src/main/scala/leon/utils/SeqUtils.scala
@@ -31,4 +31,14 @@ object SeqUtils {
 
     result
   }
+
+  def sumTo(sum: Int, arity: Int): Seq[Seq[Int]] = {
+    if (arity == 1) {
+      Seq(Seq(sum))
+    } else {
+      (1 until sum).flatMap{ n => 
+        sumTo(sum-n, arity-1).map( r => n +: r)
+      }
+    }
+  }
 }