From c9358c44c08a2480a1ae6a95a55a7f7afe95a371 Mon Sep 17 00:00:00 2001
From: Etienne Kneuss <etienne.kneuss@epfl.ch>
Date: Fri, 7 Nov 2014 15:23:28 +0100
Subject: [PATCH] Cost Distributions become Histograms

---
 .../scala/leon/refactor/RepairCostModel.scala |   3 +-
 src/main/scala/leon/refactor/Repairman.scala  |  19 +-
 src/main/scala/leon/synthesis/CostModel.scala |   2 +
 .../scala/leon/synthesis/Distribution.scala   | 179 ------------------
 src/main/scala/leon/synthesis/Histogram.scala | 136 +++++++++++++
 .../leon/synthesis/PartialSolution.scala      |   2 +-
 src/main/scala/leon/synthesis/Problem.scala   |   8 +-
 .../leon/synthesis/graph/DotGenerator.scala   |  14 +-
 .../scala/leon/synthesis/graph/Graph.scala    |  42 ++--
 .../scala/leon/synthesis/graph/Search.scala   |  11 +-
 .../scala/leon/synthesis/rules/Cegis.scala    |  18 +-
 .../leon/synthesis/rules/GuidedCloser.scala   |   8 +-
 .../synthesis/utils/ExpressionGrammar.scala   |   6 +-
 .../scala/leon/utils/InterruptManager.scala   |   2 +-
 14 files changed, 208 insertions(+), 242 deletions(-)
 delete mode 100644 src/main/scala/leon/synthesis/Distribution.scala
 create mode 100644 src/main/scala/leon/synthesis/Histogram.scala

diff --git a/src/main/scala/leon/refactor/RepairCostModel.scala b/src/main/scala/leon/refactor/RepairCostModel.scala
index 40dc2ad43..3cdf9af8d 100644
--- a/src/main/scala/leon/refactor/RepairCostModel.scala
+++ b/src/main/scala/leon/refactor/RepairCostModel.scala
@@ -11,8 +11,9 @@ case class RepairCostModel(cm: CostModel) extends CostModel(cm.name) {
   override def ruleAppCost(app: RuleInstantiation): Cost = {
     app.rule match {
       case rules.GuidedDecomp => 0
+      case rules.GuidedCloser => 0
       case rules.CEGLESS => 0
-      case _ => cm.ruleAppCost(app)
+      case _ => 10+cm.ruleAppCost(app)
     }
   }
   def solutionCost(s: Solution) = cm.solutionCost(s)
diff --git a/src/main/scala/leon/refactor/Repairman.scala b/src/main/scala/leon/refactor/Repairman.scala
index 1c1e17844..e77deaf5c 100644
--- a/src/main/scala/leon/refactor/Repairman.scala
+++ b/src/main/scala/leon/refactor/Repairman.scala
@@ -21,6 +21,7 @@ import verification._
 import synthesis._
 import synthesis.rules._
 import synthesis.heuristics._
+import graph.DotGenerator
 
 class Repairman(ctx: LeonContext, program: Program, fd: FunDef) {
   val reporter = ctx.reporter
@@ -79,7 +80,7 @@ class Repairman(ctx: LeonContext, program: Program, fd: FunDef) {
     // Synthesis from the ground up
     val p = Problem(fd.params.map(_.id).toList, pc, spec, List(out))
     val ch = Choose(List(out), spec)
-    //fd.body = Some(ch)
+    fd.body = Some(ch)
 
     val soptions0 = SynthesisPhase.processOptions(ctx);
 
@@ -100,28 +101,30 @@ class Repairman(ctx: LeonContext, program: Program, fd: FunDef) {
           val expr = sol.toSimplifiedExpr(ctx, program)
 
           val (npr, fds) = synthesizer.solutionToProgram(sol)
-          solutions ::= (sol, expr, fds)
 
           if (!sol.isTrusted) {
-
             getVerificationCounterExamples(fds.head, npr) match {
               case Some(ces) =>
                 testBank ++= ces
                 reporter.info("Failed :(, but I learned: "+ces.mkString("  |  "))
               case None =>
-                reporter.info("ZZUCCESS!")
+                solutions ::= (sol, expr, fds)
+                reporter.info(ASCIIHelpers.title("ZUCCESS!"))
             }
           } else {
-            reporter.info("ZZUCCESS!")
+            solutions ::= (sol, expr, fds)
+            reporter.info(ASCIIHelpers.title("ZUCCESS!"))
           }
         }
 
+        if (soptions.generateDerivationTrees) {
+          val dot = new DotGenerator(search.g)
+          dot.writeFile("derivation"+DotGenerator.nextId()+".dot")
+        }
 
         if (solutions.isEmpty) {
-            reporter.info("Trey aagggain")
-            repair()
+            reporter.info(ASCIIHelpers.title("FAILURZ!"))
         } else {
-
           reporter.info(ASCIIHelpers.title("Solutions"))
           for (((sol, expr, fds), i) <- solutions.zipWithIndex) {
             reporter.info(ASCIIHelpers.subTitle("Solution "+(i+1)+":"))
diff --git a/src/main/scala/leon/synthesis/CostModel.scala b/src/main/scala/leon/synthesis/CostModel.scala
index b4cac6662..cba5671e9 100644
--- a/src/main/scala/leon/synthesis/CostModel.scala
+++ b/src/main/scala/leon/synthesis/CostModel.scala
@@ -3,6 +3,8 @@
 package leon
 package synthesis
 
+import graph._
+
 import purescala.Trees._
 import purescala.TreeOps._
 
diff --git a/src/main/scala/leon/synthesis/Distribution.scala b/src/main/scala/leon/synthesis/Distribution.scala
deleted file mode 100644
index 7c6a07040..000000000
--- a/src/main/scala/leon/synthesis/Distribution.scala
+++ /dev/null
@@ -1,179 +0,0 @@
-/* Copyright 2009-2014 EPFL, Lausanne */
-
-package leon.synthesis
-
-class Distribution(val span: Int, val values: Array[Long], val total: Long) extends Ordered[Distribution] {
-  def and(that: Distribution): Distribution = { val res = (this, that) match {
-    case (d1, d2) if d1.total == 0 =>
-      d1
-
-    case (d1, d2) if d2.total == 0 =>
-      d2
-
-    case (d1: PointDistribution, d2: PointDistribution) =>
-      if (d1.at + d2.at >= span) {
-        Distribution.empty(span)
-      } else {
-        new PointDistribution(span, d1.at+d2.at)
-      }
-
-    case (d: PointDistribution, o) =>
-      val a   = Array.fill(span)(0l)
-
-      val base = d.at
-      var innerTotal = 0l;
-      var i    = d.at;
-      while(i < span) {
-        val v = o.values(i-base)
-        a(i) = v
-        innerTotal += v
-        i += 1
-      }
-
-      if (innerTotal == 0) {
-        Distribution.empty(span)
-      } else {
-        new Distribution(span, a, total)
-      }
-
-    case (o, d: PointDistribution) =>
-      val a   = Array.fill(span)(0l)
-
-      val base = d.at
-      var innerTotal = 0l;
-      var i    = d.at;
-      while(i < span) {
-        val v = o.values(i-base)
-        a(i) = v
-        innerTotal += v
-        i += 1
-      }
-
-      if (innerTotal == 0) {
-        Distribution.empty(span)
-      } else {
-        new Distribution(span, a, total)
-      }
-
-    case (left, right) =>
-      if (left == right) {
-        left
-      } else {
-        val a   = Array.fill(span)(0l)
-        var innerTotal = 0l;
-        var i = 0;
-        while (i < span) {
-          var j = 0;
-          while (j < span) {
-            if (i+j < span) {
-              val lv = left.values(i)
-              val rv = right.values(j)
-
-              a(i+j) += lv*rv
-              innerTotal += lv*rv
-            }
-            j += 1
-          }
-          i += 1
-        }
-
-        if (innerTotal == 0) {
-          Distribution.empty(span)
-        } else {
-          new Distribution(span, a, left.total * right.total)
-        }
-      }
-  }
-    //println("And of "+this+" and "+that+" = "+res)
-    res
-  }
-
-  def or(that: Distribution): Distribution = (this, that) match {
-    case (d1, d2) if d1.total == 0 =>
-      d2
-
-    case (d1, d2) if d2.total == 0 =>
-      d1
-
-    case (d1: PointDistribution, d2: PointDistribution) =>
-      if (d1.at < d2.at) {
-        d1
-      } else {
-        d2
-      }
-
-    case (d1, d2) =>
-      if (d1.weightedSum < d2.weightedSum) {
-      //if (d1.firstNonZero < d2.firstNonZero) {
-        d1
-      } else {
-        d2
-      }
-  }
-
-  lazy val firstNonZero: Int = {
-    if (total == 0) {
-      span
-    } else {
-      var i = 0;
-      var continue = true;
-      while (continue && i < span) {
-        if (values(i) != 0l) {
-          continue = false
-        }
-        i += 1
-      }
-      i
-    }
-  }
-
-  lazy val weightedSum: Double = {
-    var res = 0d;
-    var i = 0;
-    while (i < span) {
-      res += (1d*i*values(i))/total
-      i += 1
-    }
-    res
-  }
-
-  override def toString: String = {
-    "Tot:"+total+"(at "+firstNonZero+")"
-  }
-
-  def compare(that: Distribution) = {
-    this.firstNonZero - that.firstNonZero
-  }
-}
-
-object Distribution {
-  def point(span: Int, at: Int) = {
-    if (span <= at) {
-      empty(span)
-    } else {
-      new PointDistribution(span, at)
-    }
-  }
-
-  def empty(span: Int)               = new Distribution(span, Array[Long](), 0l)
-  def uniform(span: Int, v: Long, total: Int) = {
-    new Distribution(span, Array.fill(span)(v), total)
-  }
-
-  def uniformFrom(span: Int, from: Int, ratio: Double) = {
-    var i = from
-    val a = Array.fill(span)(0l)
-    while(i < span) {
-      a(i) = 1
-      i += 1
-    }
-    
-    new Distribution(span, a, ((span-from).toDouble*(1/ratio)).toInt)
-  }
-}
-
-class PointDistribution(span: Int, val at: Int) extends Distribution(span, new Array[Long](span).updated(at, 1l), 1l) {
-  override lazy val firstNonZero: Int = {
-    at
-  }
-}
diff --git a/src/main/scala/leon/synthesis/Histogram.scala b/src/main/scala/leon/synthesis/Histogram.scala
new file mode 100644
index 000000000..10176e97b
--- /dev/null
+++ b/src/main/scala/leon/synthesis/Histogram.scala
@@ -0,0 +1,136 @@
+/* Copyright 2009-2014 EPFL, Lausanne */
+
+package leon.synthesis
+
+/**
+ * Histogram from 0 to `bound`, each value between 0 and 1
+ * hist(c) = v means we have a `v` likelihood of finding a solution of cost `c`
+ */
+class Histogram(val bound: Int, val values: Array[Double]) extends Ordered[Histogram] {
+  /**
+   */
+  def and(that: Histogram): Histogram = {
+    val a = Array.fill(bound)(0d)
+    var i = 0;
+    while(i < bound) {
+      var j = 0;
+      while(j <= i) {
+
+        val v1 = (this.values(j) * that.values(i-j))
+        val v2 = a(i)
+
+        a(i) = v1+v2 - (v1*v2)
+
+        j += 1
+      }
+      i += 1
+    }
+
+    new Histogram(bound, a)
+  }
+
+  /**
+   * hist1(c) || hist2(c) == hist1(c)+hist2(c) - hist1(c)*hist2(c)
+   */
+  def or(that: Histogram): Histogram = {
+    val a = Array.fill(bound)(0d)
+    var i = 0;
+    while(i < bound) {
+      val v1 = this.values(i)
+      val v2 = that.values(i)
+
+      a(i) = v1+v2 - (v1*v2)
+      i += 1
+    }
+
+    new Histogram(bound, a)
+  }
+
+  lazy val maxInfo = {
+    var max    = 0d;
+    var argMax = -1;
+    var i      = 0;
+    while(i < bound) {
+      if ((argMax < 0) || values(i) > max) {
+        argMax = i;
+        max = values(i)
+      }
+      i += 1;
+    }
+    (max, argMax)
+  }
+
+  def isImpossible  = maxInfo._1 == 0
+
+  /**
+   * Should return v<0 if `this` < `that`, that is, `this` represents better
+   * solutions than `that`.
+   */
+  def compare(that: Histogram) = {
+    val (m1, am1) = this.maxInfo
+    val (m2, am2) = that.maxInfo
+
+    if (m1 == m2) {
+      am1 - am2
+    } else {
+      if (m2 < m1) {
+        -1
+      } else if (m2 == m1) {
+        0
+      } else {
+        +1
+      }
+    }
+  }
+
+  override def toString: String = {
+    var printed = 0
+    val info = for (i <- 0 until bound if values(i) != 0 && printed < 5) yield {
+      f"$i%2d -> ${values(i)}%,3f"
+    }
+    val (m,am) = maxInfo
+
+    "H("+m+"@"+am+": "+info.mkString(", ")+")"
+  }
+
+}
+
+object Histogram {
+  def clampV(v: Double): Double = {
+    if (v < 0) {
+      0d
+    } else if (v > 1) {
+      1d
+    } else {
+      v
+    }
+  }
+
+  def point(bound: Int, at: Int, v: Int) = {
+    if (bound <= at) {
+      empty(bound)
+    } else {
+      new Histogram(bound, Array.fill(bound)(0d).updated(at, clampV(v)))
+    }
+  }
+
+  def empty(bound: Int) = {
+    new Histogram(bound, Array.fill(bound)(0d))
+  }
+
+  def uniform(bound: Int, v: Double) = {
+    uniformFrom(bound, 0, v)
+  }
+
+  def uniformFrom(bound: Int, from: Int, v: Double) = {
+    val vSafe = clampV(v)
+    var i = from
+    val a = Array.fill(bound)(0d)
+    while(i < bound) {
+      a(i) = vSafe
+      i += 1
+    }
+
+    new Histogram(bound, a)
+  }
+}
diff --git a/src/main/scala/leon/synthesis/PartialSolution.scala b/src/main/scala/leon/synthesis/PartialSolution.scala
index 039342d3c..2e33fd105 100644
--- a/src/main/scala/leon/synthesis/PartialSolution.scala
+++ b/src/main/scala/leon/synthesis/PartialSolution.scala
@@ -37,7 +37,7 @@ class PartialSolution(g: Graph, includeUntrusted: Boolean) {
           if (descs.isEmpty) {
             completeProblem(on.p)
           } else {
-            getSolutionFor(descs.minBy(_.costDist))
+            getSolutionFor(descs.minBy(_.histogram))
           }
         } else {
           completeProblem(on.p)
diff --git a/src/main/scala/leon/synthesis/Problem.scala b/src/main/scala/leon/synthesis/Problem.scala
index cc21e87a9..a9d83d80e 100644
--- a/src/main/scala/leon/synthesis/Problem.scala
+++ b/src/main/scala/leon/synthesis/Problem.scala
@@ -20,12 +20,14 @@ case class Problem(as: List[Identifier], pc: Expr, phi: Expr, xs: List[Identifie
 
     val ev = new DefaultEvaluator(sctx.context, sctx.program)
 
+    val safePc = removeWitnesses(sctx.program)(pc)
+
     def isValidExample(ex: Example): Boolean = {
       val (mapping, cond) = ex match {
         case io: InOutExample =>
-          (Map((as zip io.ins) ++ (xs zip io.outs): _*), And(pc, phi))
+          (Map((as zip io.ins) ++ (xs zip io.outs): _*), And(safePc, phi))
         case i =>
-          ((as zip i.ins).toMap, pc)
+          ((as zip i.ins).toMap, safePc)
       }
 
       ev.eval(cond, mapping) match {
@@ -85,7 +87,7 @@ case class Problem(as: List[Identifier], pc: Expr, phi: Expr, xs: List[Identifie
       case FunctionInvocation(tfd, List(in, out, FiniteMap(inouts))) if tfd.id.name == "passes" =>
         val infos = extractIds(Tuple(Seq(in, out)))
         val exs   = inouts.map{ case (i, o) =>
-          val test = Tuple(Seq(i, o)) 
+          val test = Tuple(Seq(i, o))
           val ids = variablesOf(test)
           evaluator.eval(test, ids.map { (i: Identifier) => i -> i.toVariable }.toMap) match {
             case EvaluationResults.Successful(res) => res
diff --git a/src/main/scala/leon/synthesis/graph/DotGenerator.scala b/src/main/scala/leon/synthesis/graph/DotGenerator.scala
index 339deea13..3c74dd772 100644
--- a/src/main/scala/leon/synthesis/graph/DotGenerator.scala
+++ b/src/main/scala/leon/synthesis/graph/DotGenerator.scala
@@ -2,7 +2,7 @@ package leon.synthesis.graph
 
 import java.io.{File, FileWriter, BufferedWriter}
 
-import leon.synthesis.Distribution
+import leon.synthesis.Histogram
 
 class DotGenerator(g: Graph) {
   import g.{Node, AndNode, OrNode, RootNode}
@@ -68,11 +68,11 @@ class DotGenerator(g: Graph) {
     res.toString
   }
 
-  def distrib(d: Distribution): String = {
-    if (d.firstNonZero == g.maxCost) {
-      ">max"
+  def hist(h: Histogram): String = {
+    if (h.isImpossible) {
+      "-/-"
     } else {
-      d.firstNonZero.toString
+      h.maxInfo._1+"@"+h.maxInfo._2
     }
   }
 
@@ -110,9 +110,9 @@ class DotGenerator(g: Graph) {
     //cost
     n match {
       case an: AndNode =>
-        res append "<TR><TD BORDER=\"0\">"+escapeHTML(distrib(n.costDist)+" ("+distrib(an.selfCost))+")</TD></TR>"
+        res append "<TR><TD BORDER=\"0\">"+escapeHTML(hist(n.histogram)+" ("+hist(an.selfHistogram))+")</TD></TR>"
       case on: OrNode =>
-        res append "<TR><TD BORDER=\"0\">"+escapeHTML(distrib(n.costDist))+"</TD></TR>"
+        res append "<TR><TD BORDER=\"0\">"+escapeHTML(hist(n.histogram))+"</TD></TR>"
     }
 
     res append "<TR><TD BORDER=\"1\" BGCOLOR=\""+color+"\">"+escapeHTML(limit(nodeDesc(n)))+"</TD></TR>";
diff --git a/src/main/scala/leon/synthesis/graph/Graph.scala b/src/main/scala/leon/synthesis/graph/Graph.scala
index bea159f42..4ba9f3751 100644
--- a/src/main/scala/leon/synthesis/graph/Graph.scala
+++ b/src/main/scala/leon/synthesis/graph/Graph.scala
@@ -22,13 +22,13 @@ sealed class Graph(problem: Problem, costModel: CostModel) {
     val p: Problem
 
     // costs
-    var costDist: Distribution
-    def onNewDist(desc: Node)
+    var histogram: Histogram
+    def updateHistogram(desc: Node)
 
     var isSolved: Boolean   = false
 
     def isClosed: Boolean = {
-      costDist.total == 0
+      histogram.maxInfo._1 == 0
     }
 
     def onSolved(desc: Node)
@@ -52,8 +52,8 @@ sealed class Graph(problem: Problem, costModel: CostModel) {
 
   class AndNode(parent: Option[Node], val ri: RuleInstantiation) extends Node(parent) {
     val p = ri.problem
-    var selfCost = Distribution.point(maxCost, costModel.ruleAppCost(ri))
-    var costDist: Distribution = Distribution.uniformFrom(maxCost, costModel.ruleAppCost(ri), 0.5)
+    var selfHistogram = Histogram.point(maxCost, costModel.ruleAppCost(ri), 100)
+    var histogram     = Histogram.uniformFrom(maxCost, costModel.ruleAppCost(ri), 50)
 
     override def toString = "\u2227 "+ri;
 
@@ -72,8 +72,8 @@ sealed class Graph(problem: Problem, costModel: CostModel) {
           solutions = Some(sols)
           selectedSolution = 0;
 
-          costDist = sols.foldLeft(Distribution.empty(maxCost)) {
-            (d, sol) => d or Distribution.point(maxCost, costModel.solutionCost(sol))
+          histogram = sols.foldLeft(Histogram.empty(maxCost)) {
+            (d, sol) => d or Histogram.point(maxCost, costModel.solutionCost(sol), 100)
           }
 
           isSolved = sols.nonEmpty
@@ -86,7 +86,7 @@ sealed class Graph(problem: Problem, costModel: CostModel) {
           }
 
           parents.foreach{ p =>
-            p.onNewDist(this)
+            p.updateHistogram(this)
             if (isSolved) {
               p.onSolved(this)
             }
@@ -112,18 +112,18 @@ sealed class Graph(problem: Problem, costModel: CostModel) {
       }
     }
 
-    def onNewDist(desc: Node) = {
+    def updateHistogram(desc: Node) = {
       recomputeCost()
     }
 
     private def recomputeCost() = {
-      val newCostDist = descendents.foldLeft(selfCost){
-        case (c, d)  => c and d.costDist
+      val newHistogram = descendents.foldLeft(selfHistogram){
+        case (c, d)  => c and d.histogram
       }
 
-      if (newCostDist != costDist) {
-        costDist = newCostDist
-        parents.foreach(_.onNewDist(this))
+      if (newHistogram != histogram) {
+        histogram = newHistogram
+        parents.foreach(_.updateHistogram(this))
       }
     }
 
@@ -143,7 +143,7 @@ sealed class Graph(problem: Problem, costModel: CostModel) {
   }
 
   class OrNode(parent: Option[Node], val p: Problem) extends Node(parent) {
-    var costDist: Distribution = Distribution.uniformFrom(maxCost, costModel.problemCost(p), 0.5)
+    var histogram  = Histogram.uniformFrom(maxCost, costModel.problemCost(p), 50)
 
     override def toString = "\u2228 "+p;
 
@@ -171,18 +171,18 @@ sealed class Graph(problem: Problem, costModel: CostModel) {
     }
 
     private def recomputeCost(): Unit = {
-      val newCostDist = descendents.foldLeft(Distribution.empty(maxCost)){
-        case (c, d)  => c or d.costDist
+      val newHistogram = descendents.foldLeft(Histogram.empty(maxCost)){
+        case (c, d)  => c or d.histogram
       }
 
-      if (costDist != newCostDist) {
-        costDist = newCostDist
-        parents.foreach(_.onNewDist(this))
+      if (histogram != newHistogram) {
+        histogram = newHistogram
+        parents.foreach(_.updateHistogram(this))
 
       }
     }
 
-    def onNewDist(desc: Node): Unit = {
+    def updateHistogram(desc: Node): Unit = {
       recomputeCost()
     }
   }
diff --git a/src/main/scala/leon/synthesis/graph/Search.scala b/src/main/scala/leon/synthesis/graph/Search.scala
index 030d14e44..fd639221d 100644
--- a/src/main/scala/leon/synthesis/graph/Search.scala
+++ b/src/main/scala/leon/synthesis/graph/Search.scala
@@ -90,7 +90,7 @@ class SimpleSearch(ctx: LeonContext, p: Problem, costModel: CostModel, bound: Op
 
         case on: g.OrNode =>
           if (on.descendents.nonEmpty) {
-            findIn(on.descendents.minBy(_.costDist))
+            findIn(on.descendents.minBy(_.histogram))
           }
       }
     }
@@ -157,17 +157,18 @@ class ManualSearch(ctx: LeonContext, problem: Problem, costModel: CostModel) ext
     def failed(str: String) = "\u001b[31m" + str + "\u001b[0m"
     def solved(str: String) = "\u001b[32m" + str + "\u001b[0m"
 
-    def displayDist(d: Distribution): String = {
-      f"${d.firstNonZero}%3d"
+    def displayHistogram(h: Histogram): String = {
+      val (max, maxarg) = h.maxInfo
+      f"$max%,2f@$maxarg%2d"
     }
 
     def displayNode(n: Node): String = n match {
       case an: AndNode =>
         val app = an.ri
-        s"(${displayDist(n.costDist)}) $app"
+        s"(${displayHistogram(n.histogram)}) $app"
       case on: OrNode =>
         val p = on.p
-        s"(${displayDist(n.costDist)}) $p"
+        s"(${displayHistogram(n.histogram)}) $p"
     }
 
     def traversePathFrom(n: Node, prefix: List[Int]) {
diff --git a/src/main/scala/leon/synthesis/rules/Cegis.scala b/src/main/scala/leon/synthesis/rules/Cegis.scala
index b5997165b..0b740adf7 100644
--- a/src/main/scala/leon/synthesis/rules/Cegis.scala
+++ b/src/main/scala/leon/synthesis/rules/Cegis.scala
@@ -49,9 +49,9 @@ case object CEGLESS extends CEGISLike("CEGLESS") {
 
     val inputs = p.as.map(_.toVariable)
 
-    val guidedGrammar = guides.map(SimilarTo(_, inputs.toSet)).foldLeft[ExpressionGrammar](Empty)(_ || _)
+    val guidedGrammar = guides.map(SimilarTo(_, inputs.toSet, Set(sctx.functionContext))).foldLeft[ExpressionGrammar](Empty)(_ || _)
 
-    guidedGrammar || OneOf(inputs)
+    guidedGrammar || OneOf(inputs) || SafeRecCalls(sctx.program, p.pc)
   }
 }
 
@@ -170,12 +170,11 @@ abstract class CEGISLike(name: String) extends Rule(name) {
               res == BooleanLiteral(true)
 
             case EvaluationResults.RuntimeError(err) =>
-              //sctx.reporter.error("Error testing CE: "+err)
               false
 
             case EvaluationResults.EvaluatorError(err) =>
               sctx.reporter.error("Error testing CE: "+err)
-              true
+              false
           }
         } else {
           true
@@ -600,15 +599,16 @@ abstract class CEGISLike(name: String) extends Rule(name) {
             // We further filter the set of working programs to remove those that fail on known examples
             if (useCEPruning && hasInputExamples() && ndProgram.canTest()) {
 
-              for (p <- prunedPrograms if !interruptManager.isInterrupted()) {
+              for (bs <- prunedPrograms if !interruptManager.isInterrupted()) {
                 var valid = true
                 val examples = allInputExamples()
                 while(valid && examples.hasNext) {
                   val e = examples.next()
-                  if (!ndProgram.testForProgram(p)(e)) {
+                  if (!ndProgram.testForProgram(bs)(e)) {
                     failedTestsStats(e) += 1
-                    wrongPrograms += p
-                    prunedPrograms -= p
+                    wrongPrograms += bs
+                    prunedPrograms -= bs
+
                     valid = false;
                   }
                 }
@@ -622,7 +622,7 @@ abstract class CEGISLike(name: String) extends Rule(name) {
             val nPassing = prunedPrograms.size
             sctx.reporter.debug("#Programs passing tests: "+nPassing)
 
-            if (nPassing == 0) {
+            if (nPassing == 0 || interruptManager.isInterrupted()) {
               skipCESearch = true;
             } else if (nPassing <= testUpTo) {
               // Immediate Test
diff --git a/src/main/scala/leon/synthesis/rules/GuidedCloser.scala b/src/main/scala/leon/synthesis/rules/GuidedCloser.scala
index 91b1494c7..f46e139a0 100644
--- a/src/main/scala/leon/synthesis/rules/GuidedCloser.scala
+++ b/src/main/scala/leon/synthesis/rules/GuidedCloser.scala
@@ -33,9 +33,7 @@ case object GuidedCloser extends NormalizingRule("Guided Closer") {
 
       val vc = simp(And(p.pc, LetTuple(p.xs, wrappedE, Not(p.phi))))
 
-      //println(vc)
-
-      val solver = sctx.newSolver.setTimeout(1000L)
+      val solver = sctx.newSolver.setTimeout(2000L)
 
       solver.assertCnstr(vc)
       val osol = solver.check match {
@@ -47,8 +45,8 @@ case object GuidedCloser extends NormalizingRule("Guided Closer") {
             printer(vc)
             printer("== Unknown ==")
           }
-          None
-          //Some(Solution(BooleanLiteral(true), Set(), wrappedE, false))
+          //None
+          Some(Solution(BooleanLiteral(true), Set(), wrappedE, false))
 
         case _ =>
           sctx.reporter.ifDebug { printer =>
diff --git a/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala b/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala
index f20fa982b..5ed5990c9 100644
--- a/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala
+++ b/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala
@@ -112,7 +112,7 @@ object ExpressionGrammars {
     }
   }
 
-  case class SimilarTo(e: Expr, exclude: Set[Expr] = Set()) extends ExpressionGrammar {
+  case class SimilarTo(e: Expr, excludeExpr: Set[Expr] = Set(), excludeFCalls: Set[FunDef] = Set()) extends ExpressionGrammar {
     lazy val allSimilar = computeSimilar(e).groupBy(_._1).mapValues(_.map(_._2))
 
     def computeProductions(t: TypeTree): Seq[Gen] = {
@@ -121,7 +121,7 @@ object ExpressionGrammars {
 
     def computeSimilar(e : Expr) : Seq[(TypeTree, Gen)] = {
 
-      var seenSoFar = exclude;
+      var seenSoFar = excludeExpr;
 
       def gen(retType : TypeTree, tps : Seq[TypeTree], f : Seq[Expr] => Expr) : (TypeTree, Gen) =
         (bestRealType(retType), Generator[TypeTree, Expr](tps.map(bestRealType), f))
@@ -142,6 +142,8 @@ object ExpressionGrammars {
           val subs: Seq[(TypeTree, Gen)] = e match {
             case _: Terminal | _: Let | _: LetTuple | _: LetDef | _: MatchExpr =>
               Seq()
+            case FunctionInvocation(TypedFunDef(fd, _), _) if excludeFCalls contains fd =>
+              Seq()
             case UnaryOperator(sub, builder) => Seq(
               gen(tp, List(sub.getType), { case Seq(ex) => builder(ex) } )
             ) ++ rec(sub)
diff --git a/src/main/scala/leon/utils/InterruptManager.scala b/src/main/scala/leon/utils/InterruptManager.scala
index 1aa19f9a4..abe0d927b 100644
--- a/src/main/scala/leon/utils/InterruptManager.scala
+++ b/src/main/scala/leon/utils/InterruptManager.scala
@@ -57,7 +57,7 @@ class InterruptManager(reporter: Reporter) {
       def handle(sig: Signal) {
         Signal.handle(sigINT, oldHandler)
         println
-        reporter.info("Aborting Leon...")
+        reporter.warning("Aborting Leon...")
 
         interrupt()
 
-- 
GitLab