From 4139d9fafa06207e02b7087f282e257af0e6181c Mon Sep 17 00:00:00 2001
From: Etienne Kneuss <etienne.kneuss@epfl.ch>
Date: Mon, 8 Dec 2014 16:23:46 +0100
Subject: [PATCH] Allow per-problem maxunfolding, (re?)introduce depth

---
 src/main/scala/leon/purescala/TreeOps.scala   |  4 ++++
 .../scala/leon/synthesis/rules/Cegis.scala    |  9 +++----
 .../leon/synthesis/rules/CegisLike.scala      | 24 +++++++++++++------
 .../scala/leon/synthesis/rules/Cegless.scala  | 15 ++++++------
 4 files changed, 34 insertions(+), 18 deletions(-)

diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala
index 46c7bab49..b3bb309d7 100644
--- a/src/main/scala/leon/purescala/TreeOps.scala
+++ b/src/main/scala/leon/purescala/TreeOps.scala
@@ -444,6 +444,10 @@ object TreeOps {
     })(expr)
   }
 
+  def depth(e: Expr): Int = {
+    foldRight[Int]({ (e, sub) => 1 + (0 +: sub).max })(e)
+  }
+
   def normalizeExpression(expr: Expr) : Expr = {
     def rec(e: Expr): Option[Expr] = e match {
       case TupleSelect(Let(id, v, b), ts) =>
diff --git a/src/main/scala/leon/synthesis/rules/Cegis.scala b/src/main/scala/leon/synthesis/rules/Cegis.scala
index b033058d8..3f2c75548 100644
--- a/src/main/scala/leon/synthesis/rules/Cegis.scala
+++ b/src/main/scala/leon/synthesis/rules/Cegis.scala
@@ -13,9 +13,10 @@ import utils.Helpers._
 import utils._
 
 case object CEGIS extends CEGISLike[TypeTree]("CEGIS") {
-  def getGrammar(sctx: SynthesisContext, p: Problem) = {
-    ExpressionGrammars.default(sctx, p)
+  def getParams(sctx: SynthesisContext, p: Problem) = {
+    CegisParams(
+      grammar = ExpressionGrammars.default(sctx, p),
+      rootLabel = {(tpe: TypeTree) => tpe }
+    )
   }
-
-  def getRootLabel(tpe: TypeTree): TypeTree = tpe
 }
diff --git a/src/main/scala/leon/synthesis/rules/CegisLike.scala b/src/main/scala/leon/synthesis/rules/CegisLike.scala
index a458f51f1..cec945313 100644
--- a/src/main/scala/leon/synthesis/rules/CegisLike.scala
+++ b/src/main/scala/leon/synthesis/rules/CegisLike.scala
@@ -31,11 +31,13 @@ import utils._
 
 abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) {
 
-  def getGrammar(sctx: SynthesisContext, p: Problem): ExpressionGrammar[T]
+  case class CegisParams(
+    grammar: ExpressionGrammar[T],
+    rootLabel: TypeTree => T,
+    maxUnfoldings: Int = 3
+  );
 
-  def getRootLabel(tpe: TypeTree): T
-
-  val maxUnfoldings = 3
+  def getParams(sctx: SynthesisContext, p: Problem): CegisParams
 
   def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = {
 
@@ -56,10 +58,16 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) {
 
     val interruptManager      = sctx.context.interruptManager
 
+    val params = getParams(sctx, p)
+
+    if (params.maxUnfoldings == 0) {
+      return Nil
+    }
+
     class NonDeterministicProgram(val p: Problem,
                                   val initGuard: Identifier) {
 
-      val grammar = getGrammar(sctx, p)
+      val grammar = params.grammar
 
       // b -> (c, ex) means the clause b => c == ex
       var mappings: Map[Identifier, (Identifier, Expr)] = Map()
@@ -67,7 +75,7 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) {
       // b -> Set(c1, c2) means c1 and c2 are uninterpreted behind b, requires b to be closed
       private var guardedTerms: Map[Identifier, Set[Identifier]] = Map(initGuard -> p.xs.toSet)
 
-      private var labels: Map[Identifier, T] = Map() ++ p.xs.map(x => x -> getRootLabel(x.getType))
+      private var labels: Map[Identifier, T] = Map() ++ p.xs.map(x => x -> params.rootLabel(x.getType))
 
       def isBClosed(b: Identifier) = guardedTerms.contains(b)
 
@@ -417,7 +425,9 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) {
 
         val ndProgram = new NonDeterministicProgram(p, initGuard)
         var unfolding = 1
-        val maxUnfoldings = CEGISLike.this.maxUnfoldings
+        val maxUnfoldings = params.maxUnfoldings
+
+        sctx.reporter.debug(s"maxUnfoldings=$maxUnfoldings")
 
         val exSolverTo  = 2000L
         val cexSolverTo = 2000L
diff --git a/src/main/scala/leon/synthesis/rules/Cegless.scala b/src/main/scala/leon/synthesis/rules/Cegless.scala
index bca7d6787..c4789c8a0 100644
--- a/src/main/scala/leon/synthesis/rules/Cegless.scala
+++ b/src/main/scala/leon/synthesis/rules/Cegless.scala
@@ -5,6 +5,7 @@ package synthesis
 package rules
 
 import purescala.Trees._
+import purescala.TreeOps._
 import purescala.TypeTrees._
 import purescala.Common._
 import purescala.Definitions._
@@ -13,11 +14,9 @@ import utils._
 import utils.ExpressionGrammars._
 
 case object CEGLESS extends CEGISLike[Label[String]]("CEGLESS") {
-  override val maxUnfoldings = 1000;
+  def getParams(sctx: SynthesisContext, p: Problem) = {
 
-  def getGrammar(sctx: SynthesisContext, p: Problem) = {
-
-    val TopLevelAnds(clauses) = p.pc
+    val TopLevelAnds(clauses) = p.ws
 
     val guide = sctx.program.library.guide.get
 
@@ -29,10 +28,12 @@ case object CEGLESS extends CEGISLike[Label[String]]("CEGLESS") {
 
     val guidedGrammar = guides.map(SimilarTo(_, inputs.toSet, Set(sctx.functionContext))).foldLeft[ExpressionGrammar[Label[String]]](Empty())(_ || _)
 
-    guidedGrammar
+    CegisParams(
+      grammar = guidedGrammar,
+      rootLabel = { (tpe: TypeTree) => Label(tpe, "G0") },
+      maxUnfoldings = (0 +: guides.map(depth(_) + 1)).max
+    )
   }
-
-  def getRootLabel(tpe: TypeTree): Label[String] = Label(tpe, "G0")
 }
 
 
-- 
GitLab