From e16ab22ca9337adc6ede78417d3b6942426e366b Mon Sep 17 00:00:00 2001
From: Etienne Kneuss <ekneuss@gmail.com>
Date: Thu, 1 Nov 2012 14:34:47 +0100
Subject: [PATCH] Optimistic Ground Rule with CE guided refinement

---
 src/main/scala/leon/purescala/TreeOps.scala |  5 ++
 src/main/scala/leon/synthesis/Rules.scala   | 57 +++++++++++++++++++++
 2 files changed, 62 insertions(+)

diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala
index 131c556e3..39b0a7700 100644
--- a/src/main/scala/leon/purescala/TreeOps.scala
+++ b/src/main/scala/leon/purescala/TreeOps.scala
@@ -1302,4 +1302,9 @@ object TreeOps {
 
     rec(expr, Nil)
   }
+
+  def valuateWithModel(expr: Expr, vars: Set[Identifier], model: Map[Identifier, Expr]) = {
+    replace(vars.map(id => Variable(id) -> model.getOrElse(id, simplestValue(id.getType))).toMap, expr)
+  }
+
 }
diff --git a/src/main/scala/leon/synthesis/Rules.scala b/src/main/scala/leon/synthesis/Rules.scala
index be07f654c..2bcdc96d3 100644
--- a/src/main/scala/leon/synthesis/Rules.scala
+++ b/src/main/scala/leon/synthesis/Rules.scala
@@ -14,6 +14,7 @@ object Rules {
     new ADTDual(synth),
     new OnePoint(synth),
     new Ground(synth),
+    new OptimisticGround(synth),
     new CaseSplit(synth),
     new UnusedInput(synth),
     new UnconstrainedOutput(synth),
@@ -101,6 +102,62 @@ class Ground(synth: Synthesizer) extends Rule("Ground", synth, 500) {
   }
 }
 
+class OptimisticGround(synth: Synthesizer) extends Rule("Optimistic Ground", synth, 90) {
+  def applyOn(task: Task): RuleResult = {
+    val p = task.problem
+
+    if (!p.as.isEmpty && !p.xs.isEmpty) {
+      val xss = p.xs.toSet
+      val ass = p.as.toSet
+
+      val tpe = TupleType(p.xs.map(_.getType))
+
+      var i = 0;
+      var maxTries = 5;
+
+      var result: Option[RuleResult]   = None
+      var predicates: Seq[Expr]        = Seq()
+
+      while (result.isEmpty && i < maxTries) {
+        val phi = And(p.phi +: predicates)
+        synth.solveSAT(phi) match {
+          case (Some(true), satModel) =>
+            val satXsModel = satModel.filterKeys(xss) 
+
+            val newPhi = valuateWithModel(phi, xss, satModel)
+
+            synth.solveSAT(Not(newPhi)) match {
+              case (Some(true), invalidModel) =>
+                // Found as such as the xs break, refine predicates
+                predicates = valuateWithModel(phi, ass, invalidModel) +: predicates
+
+              case (Some(false), _) =>
+                result = Some(RuleSuccess(Solution(BooleanLiteral(true), newPhi)))
+
+              case _ =>
+                result = Some(RuleInapplicable)
+            }
+
+          case (Some(false), _) =>
+            if (predicates.isEmpty) {
+              result = Some(RuleSuccess(Solution(BooleanLiteral(false), Error(p.phi+" is UNSAT!").setType(tpe))))
+            } else {
+              result = Some(RuleInapplicable)
+            }
+          case _ =>
+            result = Some(RuleInapplicable)
+        }
+
+        i += 1 
+      }
+
+      result.getOrElse(RuleInapplicable)
+    } else {
+      RuleInapplicable
+    }
+  }
+}
+
 class CaseSplit(synth: Synthesizer) extends Rule("Case-Split", synth, 200) {
   def applyOn(task: Task): RuleResult = {
     val p = task.problem
-- 
GitLab