diff --git a/src/main/scala/leon/synthesis/rules/Cegis.scala b/src/main/scala/leon/synthesis/rules/Cegis.scala
index d109c73fb987417b4694094a27f1bb92c99470f6..98ca2a4b21ac6b13ef0c867f66d22fd9dfee7e66 100644
--- a/src/main/scala/leon/synthesis/rules/Cegis.scala
+++ b/src/main/scala/leon/synthesis/rules/Cegis.scala
@@ -25,10 +25,18 @@ import evaluators._
 import datagen._
 import codegen.CodeGenParams
 
-import utils.ExpressionGrammar
+import utils._
 
+case object CEGIS extends CEGISLike("CEGIS") {
+  def getGrammar(sctx: SynthesisContext, p: Problem) = {
+    ExpressionGrammars.default(sctx, p)
+  }
+}
+
+
+abstract class CEGISLike(name: String) extends Rule(name) {
+  def getGrammar(sctx: SynthesisContext, p: Problem): ExpressionGrammar
 
-case object CEGIS extends Rule("CEGIS") {
   def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = {
 
     // CEGIS Flags to activate or deactivate features
@@ -51,7 +59,7 @@ case object CEGIS extends Rule("CEGIS") {
     class NonDeterministicProgram(val p: Problem,
                                   val initGuard: Identifier) {
 
-      val grammar = new ExpressionGrammar(sctx, p)
+      val grammar = getGrammar(sctx, p)
 
       // b -> (c, ex) means the clause b => c == ex
       var mappings: Map[Identifier, (Identifier, Expr)] = Map()
@@ -309,7 +317,7 @@ case object CEGIS extends Rule("CEGIS") {
 
         for ((parentGuard, recIds) <- guardedTerms; recId <- recIds) {
 
-          var alts = grammar.getGenerators(recId.getType)
+          var alts = grammar.getProductions(recId.getType)
           if (finalUnrolling) {
             alts = alts.filter(_.subTrees.isEmpty)
           }
@@ -356,7 +364,7 @@ case object CEGIS extends Rule("CEGIS") {
 
         sctx.reporter.ifDebug { printer =>
           printer("Grammar so far:");
-          grammar.printGrammar(printer)
+          grammar.printProductions(printer)
         }
 
         //program  = And(program :: newClauses)
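
Note: splitting CEGIS into CEGISLike plus an abstract getGrammar lets other rules reuse
the whole CEGIS search loop with a different grammar. A minimal sketch of such a rule,
assuming the ExpressionGrammars combinators introduced later in this patch; the rule
name CEGIS_ON_INPUTS is hypothetical and not part of the change:

    case object CEGIS_ON_INPUTS extends CEGISLike("CEGIS on inputs") {
      def getGrammar(sctx: SynthesisContext, p: Problem) = {
        // Restrict the search space to base literals/constructors and the problem inputs
        ExpressionGrammars.BaseGrammar || ExpressionGrammars.OneOf(p.as.map(_.toVariable))
      }
    }
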
diff --git a/src/main/scala/leon/synthesis/rules/Tegis.scala b/src/main/scala/leon/synthesis/rules/Tegis.scala
index c6f57ed0f9bb7b5383e2d45bd0f8edb4e81347b8..ada8bce10dba8afb7a1c28c8191306c56eda7f70 100644
--- a/src/main/scala/leon/synthesis/rules/Tegis.scala
+++ b/src/main/scala/leon/synthesis/rules/Tegis.scala
@@ -24,7 +24,7 @@ import evaluators._
 import datagen._
 import codegen.CodeGenParams
 
-import utils.ExpressionGrammar
+import utils._
 
 import bonsai._
 import bonsai.enumerators._
@@ -32,7 +32,7 @@ import bonsai.enumerators._
 case object TEGIS extends Rule("TEGIS") {
 
   def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = {
-    val grammar = new ExpressionGrammar(sctx, p)
+    val grammar = ExpressionGrammars.default(sctx, p)
 
     var tests = p.getTests(sctx).map(_.ins).distinct
     if (tests.nonEmpty) {
@@ -46,7 +46,7 @@ case object TEGIS extends Rule("TEGIS") {
 
           val interruptManager      = sctx.context.interruptManager
 
-          val enum = new MemoizedEnumerator[TypeTree, Expr](grammar.getGenerators)
+          val enum = new MemoizedEnumerator[TypeTree, Expr](grammar.getProductions _)
 
           val (targetType, isWrapped) = if (p.xs.size == 1) {
             (p.xs.head.getType, false)
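
Note: TEGIS now feeds grammar.getProductions (explicitly eta-expanded with a trailing
underscore) to the bonsai MemoizedEnumerator. To illustrate what a consumer does with
the resulting Generator values (subTrees plus builder), here is a small bounded
enumerator; it is only a sketch against the API in this patch, not code from the
repository, and assumes the usual leon.purescala types are in scope:

    // Enumerate expressions of type `t` using at most `depth` grammar unfoldings.
    def enumerate(g: ExpressionGrammar, t: TypeTree, depth: Int): Seq[Expr] = {
      if (depth <= 0) Nil
      else g.getProductions(t).flatMap { gen =>
        // Enumerate candidates for every sub-position, then combine them
        // via a cartesian product and rebuild the expression with the builder.
        val subChoices = gen.subTrees.map(st => enumerate(g, st, depth - 1))
        val combos = subChoices.foldRight(Seq(Seq.empty[Expr])) { (cs, acc) =>
          for (c <- cs; rest <- acc) yield c +: rest
        }
        combos.map(gen.builder)
      }
    }
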
diff --git a/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala b/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala
index 84c6487ae73c36462506523482b130bc25cbafc5..a9c606ab7fdb950ac453b8196bd3bf892f1e4394 100644
--- a/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala
+++ b/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala
@@ -18,257 +18,193 @@ import purescala.ScalaPrinter
 
 import scala.collection.mutable.{HashMap => MutableMap}
 
-class ExpressionGrammar(ctx: LeonContext, prog: Program, inputs: Seq[Expr], currentFunction: FunDef, pathCondition: Expr) {
-  def this(sctx: SynthesisContext, p: Problem) = {
-    this(sctx.context, sctx.program, p.as.map(_.toVariable), sctx.functionContext, p.pc)
-  }
-
+abstract class ExpressionGrammar {
   type Gen = Generator[TypeTree, Expr]
 
   private[this] val cache = new MutableMap[TypeTree, Seq[Gen]]()
 
-  def getGenerators(t: TypeTree): Seq[Gen] = {
+  def getProductions(t: TypeTree): Seq[Gen] = {
     cache.getOrElse(t, {
-      val res = computeGenerators(t)
+      val res = computeProductions(t)
       cache += t -> res
       res
     })
   }
 
-  def computeGenerators(t: TypeTree): Seq[Gen] = {
-    computeBaseGenerators(t) ++
-    computeInputGenerators(t) ++
-    computeFcallGenerators(t) ++
-    computeSafeRecCalls(t)
-  }
+  def computeProductions(t: TypeTree): Seq[Gen]
 
-  def computeBaseGenerators(t: TypeTree): Seq[Gen] = t match {
-    case BooleanType =>
-      List(
-        Generator(Nil, { _ => BooleanLiteral(true) }),
-        Generator(Nil, { _ => BooleanLiteral(false) })
-      )
-    case Int32Type =>
-      List(
-        Generator(Nil, { _ => IntLiteral(0) }),
-        Generator(Nil, { _ => IntLiteral(1) }),
-        Generator(List(Int32Type, Int32Type), { case Seq(a,b) => Plus(a, b) }),
-        Generator(List(Int32Type, Int32Type), { case Seq(a,b) => Minus(a, b) }),
-        Generator(List(Int32Type, Int32Type), { case Seq(a,b) => Times(a, b) })
-      )
-    case TupleType(stps) =>
-      List(Generator(stps, { sub => Tuple(sub) }))
-
-    case cct: CaseClassType =>
-      List(
-        Generator(cct.fields.map(_.getType), { case rs => CaseClass(cct, rs)} )
-      )
-
-    case act: AbstractClassType =>
-      act.knownCCDescendents.map { cct =>
-        Generator[TypeTree, Expr](cct.fields.map(_.getType), { case rs => CaseClass(cct, rs)} )
-      }
-
-    case st @ SetType(base) =>
-      List(
-        Generator(List(base),   { case elems     => FiniteSet(elems.toSet).setType(st) }),
-        Generator(List(st, st), { case Seq(a, b) => SetUnion(a, b) }),
-        Generator(List(st, st), { case Seq(a, b) => SetIntersection(a, b) }),
-        Generator(List(st, st), { case Seq(a, b) => SetDifference(a, b) })
-      )
-
-    case _ =>
-      Nil
+  final def ||(that: ExpressionGrammar): ExpressionGrammar = {
+    ExpressionGrammar.Or(Seq(this, that))
   }
 
-  def computeInputGenerators(t: TypeTree): Seq[Gen] = {
-    inputs.collect {
-      case i if isSubtypeOf(i.getType, t) => Generator[TypeTree, Expr](Nil, { _ => i })
+  final def printProductions(printer: String => Unit) {
+    for ((t, gs) <- cache; g <- gs) {
+      val subs = g.subTrees.map { tpe => FreshIdentifier(tpe.toString).setType(tpe).toVariable }
+      val gen = g.builder(subs)
+
+      printer(f"$t%30s ::= "+gen)
     }
   }
+}
 
-  def computeFcallGenerators(t: TypeTree): Seq[Gen] = {
-
-    def getCandidates(fd: FunDef): Seq[TypedFunDef] = {
-      // Prevents recursive calls
-      val cfd = currentFunction
-
-      val isRecursiveCall = (prog.callGraph.transitiveCallers(cfd) + cfd) contains fd
-
-      val isNotSynthesizable = fd.body match {
-        case Some(b) =>
-          !containsChoose(b)
-
-        case None =>
-          false
-      }
+object ExpressionGrammar {
+  case class Or(gs: Seq[ExpressionGrammar]) extends ExpressionGrammar {
+    val subGrammars: Seq[ExpressionGrammar] = gs.flatMap {
+      case o: Or => o.subGrammars
+      case g => Seq(g)
+    }
 
+    def computeProductions(t: TypeTree): Seq[Gen] =
+      subGrammars.flatMap(_.getProductions(t))
+  }
+}
 
-      if (!isRecursiveCall && isNotSynthesizable) {
-        val free = fd.tparams.map(_.tp)
-        canBeSubtypeOf(fd.returnType, free, t) match {
-          case Some(tpsMap) =>
-            val tfd = fd.typed(free.map(tp => tpsMap.getOrElse(tp, tp)))
-
-            if (tpsMap.size < free.size) {
-              /* Some type params remain free, we want to assign them:
-               *
-               * List[T] => Int, for instance, will be found when
-               * requesting Int, but we need to assign T to viable
-               * types. For that we use problem inputs as heuristic,
-               * and look for instantiations of T such that input <?:
-               * List[T].
-               */
-              inputs.map(_.getType).distinct.flatMap { (atpe: TypeTree) =>
-                var finalFree = free.toSet -- tpsMap.keySet
-                var finalMap = tpsMap
-
-                for (ptpe <- tfd.params.map(_.tpe).distinct) {
-                  canBeSubtypeOf(atpe, finalFree.toSeq, ptpe) match {
-                    case Some(ntpsMap) =>
-                      finalFree --= ntpsMap.keySet
-                      finalMap  ++= ntpsMap
-                    case _ =>
-                  }
-                }
-
-                if (finalFree.isEmpty) {
-                  List(fd.typed(free.map(tp => finalMap.getOrElse(tp, tp))))
-                } else {
-                  Nil
-                }
-              }
-            } else {
-              /* All type parameters that used to be free are assigned
-               */
-              List(tfd)
-            }
-          case None =>
-            Nil
+object ExpressionGrammars {
+
+  case object BaseGrammar extends ExpressionGrammar {
+    def computeProductions(t: TypeTree): Seq[Gen] = t match {
+      case BooleanType =>
+        List(
+          Generator(Nil, { _ => BooleanLiteral(true) }),
+          Generator(Nil, { _ => BooleanLiteral(false) })
+        )
+      case Int32Type =>
+        List(
+          Generator(Nil, { _ => IntLiteral(0) }),
+          Generator(Nil, { _ => IntLiteral(1) }),
+          Generator(List(Int32Type, Int32Type), { case Seq(a,b) => Plus(a, b) }),
+          Generator(List(Int32Type, Int32Type), { case Seq(a,b) => Minus(a, b) }),
+          Generator(List(Int32Type, Int32Type), { case Seq(a,b) => Times(a, b) })
+        )
+      case TupleType(stps) =>
+        List(Generator(stps, { sub => Tuple(sub) }))
+
+      case cct: CaseClassType =>
+        List(
+          Generator(cct.fields.map(_.getType), { case rs => CaseClass(cct, rs)} )
+        )
+
+      case act: AbstractClassType =>
+        act.knownCCDescendents.map { cct =>
+          Generator[TypeTree, Expr](cct.fields.map(_.getType), { case rs => CaseClass(cct, rs)} )
         }
-      } else {
-        Nil
-      }
-    }
 
-    val funcs = functionsAvailable(prog).toSeq.flatMap(getCandidates)
+      case st @ SetType(base) =>
+        List(
+          Generator(List(base),   { case elems     => FiniteSet(elems.toSet).setType(st) }),
+          Generator(List(st, st), { case Seq(a, b) => SetUnion(a, b) }),
+          Generator(List(st, st), { case Seq(a, b) => SetIntersection(a, b) }),
+          Generator(List(st, st), { case Seq(a, b) => SetDifference(a, b) })
+        )
 
-    funcs.map{ tfd =>
-      Generator[TypeTree, Expr](tfd.params.map(_.tpe), { sub => FunctionInvocation(tfd, sub) })
+      case _ =>
+        Nil
     }
   }
 
-  def computeSafeRecCalls(t: TypeTree): Seq[Gen] = {
-    val calls = terminatingCalls(prog, t, pathCondition)
-
-    calls.map {
-      case (e, free) =>
-        val freeSeq = free.toSeq
-        Generator[TypeTree, Expr](freeSeq.map(_.getType), { sub =>
-          replaceFromIDs(freeSeq.zip(sub).toMap, e)
-        })
+  case class OneOf(inputs: Seq[Expr]) extends ExpressionGrammar {
+    def computeProductions(t: TypeTree): Seq[Gen] = {
+      inputs.collect {
+        case i if isSubtypeOf(i.getType, t) => Generator[TypeTree, Expr](Nil, { _ => i })
+      }
     }
   }
 
-  def computeSubexpressionGenerators(canPlacehold : Expr => Boolean)(e : Expr) : Seq[Gen] = {
-
-    /** A simple Generator API **/
-
-    def gen(tps : Seq[TypeTree], f : Seq[Expr] => Expr) : Gen = 
-      Generator[TypeTree, Expr](tps,f)
-
-    // A generator that accepts a single type, and always regenerates its input
-    // (simple placeholder of 1 position)
-    def wildcardGen(tp : TypeTree) = gen(Seq(tp), { case Seq(x) => x })
-
-    // A generator that always regenerates its input
-    def const(e: Expr) : Gen = gen(Seq(), _ => e)
- 
-    // Creates a new generator by applying f on the result of g.builder
-    def map(f : Expr => Expr)(g : Gen) : Gen = {
-      gen(g.subTrees, es => f(g.builder(es)) )
-    }
-
-    // Concatenate a sequence of generators into a generator.
-    // The arity of the resulting generator is the total arity of the constituting generators.
-    // builder is the function combining the results of the partial generators
-    def concat(gens : Seq[Gen], builder : Seq[Expr] => Expr ) : Gen = {
-      val types = gens flatMap { _.subTrees }
-      gen(
-        types,
-        exprs => {
-          assert(exprs.length == types.length) // Total arity is arity of subgenerators
-          var remaining = exprs
-          val fromSubGens = for (gen <- gens) yield {
-            val (current, rem) = remaining splitAt gen.arity
-            remaining = rem
-            gen.builder(current)
-          }
-          builder(fromSubGens)
-        }
-      )
-          
-    }
-
-    
-    def rec(e : Expr) : Seq[Gen] = {
-      
-      // Add an additional wildcard generator, if current expression passes the filter
-      def optWild(gens : Seq[Gen]) : Seq[Gen] = 
-        if (canPlacehold(e)) {
-           wildcardGen(e.getType) +: gens
-        }
-        else gens
-
-
-      e match {
-
-        case t : Terminal => 
-          // In case of Terminal, we either return the terminal itself, or the input expression
-          optWild(Seq(const(t)))
-          
-        case UnaryOperator(sub, builder) =>
-          val fromSub = for (subGen <- rec(sub)) yield map(builder)(subGen) 
-          optWild(fromSub)
-
-        case BinaryOperator(e1,e2,builder) =>
-          val fromSub = for {
-            subGen1 <- rec(e1)
-            subGen2 <- rec(e2)
-          } yield concat(Seq(subGen1, subGen2), { case Seq(e1,e2) => builder(e1,e2) })
-
-          optWild(fromSub)
-
-        case NAryOperator(subExpressions, builder) =>
+  case class FunctionCalls(prog: Program, currentFunction: FunDef, types: Seq[TypeTree]) extends ExpressionGrammar {
+    def computeProductions(t: TypeTree): Seq[Gen] = {
+
+      def getCandidates(fd: FunDef): Seq[TypedFunDef] = {
+        // Prevents recursive calls
+        val cfd = currentFunction
+
+        val isRecursiveCall = (prog.callGraph.transitiveCallers(cfd) + cfd) contains fd
+
+        val isNotSynthesizable = fd.body match {
+          case Some(b) =>
+            !containsChoose(b)
+
+          case None =>
+            false
+        }
+
+
+        if (!isRecursiveCall && isNotSynthesizable) {
+          val free = fd.tparams.map(_.tp)
+          canBeSubtypeOf(fd.returnType, free, t) match {
+            case Some(tpsMap) =>
+              val tfd = fd.typed(free.map(tp => tpsMap.getOrElse(tp, tp)))
+
+              if (tpsMap.size < free.size) {
+                /* Some type params remain free; we want to assign them:
+                 *
+                 * List[T] => Int, for instance, will be found when
+                 * requesting Int, but we need to assign T to viable
+                 * types. For that we use the list of input types as a
+                 * heuristic, and look for instantiations of T such that
+                 * input <?: List[T].
+                 */
+                types.distinct.flatMap { (atpe: TypeTree) =>
+                  var finalFree = free.toSet -- tpsMap.keySet
+                  var finalMap = tpsMap
+
+                  for (ptpe <- tfd.params.map(_.tpe).distinct) {
+                    canBeSubtypeOf(atpe, finalFree.toSeq, ptpe) match {
+                      case Some(ntpsMap) =>
+                        finalFree --= ntpsMap.keySet
+                        finalMap  ++= ntpsMap
+                      case _ =>
+                    }
+                  }
+
+                  if (finalFree.isEmpty) {
+                    List(fd.typed(free.map(tp => finalMap.getOrElse(tp, tp))))
+                  } else {
+                    Nil
+                  }
+                }
+              } else {
+                /* All type parameters that used to be free are assigned
+                 */
+                List(tfd)
+              }
+            case None =>
+              Nil
+          }
+        } else {
+          Nil
+        }
+      }
+
+      val funcs = functionsAvailable(prog).toSeq.flatMap(getCandidates)
+
+      funcs.map { tfd =>
+        Generator[TypeTree, Expr](tfd.params.map(_.tpe), { sub => FunctionInvocation(tfd, sub) })
+      }
+    }
+  }
 
-          def combinations[A](seqs : Seq[Seq[A]]) : Seq[Seq[A]] = {
-            if (seqs.isEmpty) Seq(Seq())
-            else for {
-              hd <- seqs.head
-              tl <- combinations(seqs.tail)
-            } yield hd +: tl
-          }
+  case class SafeRecCalls(prog: Program, pc: Expr) extends ExpressionGrammar {
+    def computeProductions(t: TypeTree): Seq[Gen] = {
+      val calls = terminatingCalls(prog, t, pc)
 
-          val combos = combinations(subExpressions map rec) 
-          val fromSub = combos map { concat(_, builder) }
-     
-          optWild(fromSub)
+      calls.map {
+        case (e, free) =>
+          val freeSeq = free.toSeq
+          Generator[TypeTree, Expr](freeSeq.map(_.getType), { sub =>
+            replaceFromIDs(freeSeq.zip(sub).toMap, e)
+          })
       }
     }
-    
-    rec(e)
-
   }
 
-  def computeCompleteSubexpressionGenerators = inputs flatMap computeSubexpressionGenerators{ _ => true}
-
-
-  def printGrammar(printer: String => Unit) {
-    for ((t, gs) <- cache; g <- gs) {
-      val subs = g.subTrees.map { tpe => FreshIdentifier(tpe.toString).setType(tpe).toVariable }
-      val gen = g.builder(subs)
+  def default(prog: Program, inputs: Seq[Expr], currentFunction: FunDef, pc: Expr): ExpressionGrammar = {
+    BaseGrammar ||
+    OneOf(inputs) ||
+    FunctionCalls(prog, currentFunction, inputs.map(_.getType)) ||
+    SafeRecCalls(prog, pc)
+  }
 
-      printer(f"$t%30s ::= "+gen)
-    }
+  def default(sctx: SynthesisContext, p: Problem): ExpressionGrammar = {
+    default(sctx.program, p.as.map(_.toVariable), sctx.functionContext, p.pc)
   }
 }
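
Note on composition: g1 || g2 builds an ExpressionGrammar.Or whose subGrammars field
flattens nested Or nodes, and every grammar (the composite included) memoizes its
productions per type in its own cache. A minimal usage sketch, assuming a
SynthesisContext sctx and a Problem p are in scope as in the synthesis rules above:

    val g = ExpressionGrammars.BaseGrammar ||
            ExpressionGrammars.OneOf(p.as.map(_.toVariable)) ||
            ExpressionGrammars.SafeRecCalls(sctx.program, p.pc)

    // The first request computes and caches the Int32Type productions;
    // later requests for the same type are served from the composite's cache.
    val intProds = g.getProductions(Int32Type)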