From 865736519d61a9b0ce84c8cc23ab7af89972d58f Mon Sep 17 00:00:00 2001
From: Etienne Kneuss <ekneuss@gmail.com>
Date: Tue, 18 Dec 2012 14:24:37 +0100
Subject: [PATCH] Implement the inner case-split heuristic, extend case-split
 to Ors with more than two elements

1) Implement inner-case-split heuristic, that distribute And(..,Or(),..)
in a case-split. It also pushes Not() inside the formula, so
Not(And(a,b)) becomes Or(Not(a), Not(b)) which is then handled by
inner-case-split.

2) Extend regular case-split to work with n-way ors. Or(a, .., m,n) gets
decomposed into a N-alternatives case-split.

Given solutions (Sa, .., Sm, Sn), it recomposes into:
    If(Sa.pre, Sa.term, If(.., If(Sm.pre, Sm.term, Sn.term)))
---
 src/main/scala/leon/purescala/TreeOps.scala   | 15 +-----
 .../scala/leon/synthesis/Heuristics.scala     |  1 +
 .../synthesis/heuristics/InnerCaseSplit.scala | 51 +++++++++++++++++++
 .../leon/synthesis/rules/CaseSplit.scala      | 31 +++++++----
 testcases/synthesis/InnerSplit.scala          |  5 ++
 5 files changed, 80 insertions(+), 23 deletions(-)
 create mode 100644 src/main/scala/leon/synthesis/heuristics/InnerCaseSplit.scala
 create mode 100644 testcases/synthesis/InnerSplit.scala

diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala
index 3409bde31..ff30df5c9 100644
--- a/src/main/scala/leon/purescala/TreeOps.scala
+++ b/src/main/scala/leon/purescala/TreeOps.scala
@@ -918,19 +918,6 @@ object TreeOps {
     genericTransform[Unit]((e,c) => (e, None), newPost, noCombiner)(())(expr)._1
   }
 
-  def toDNF(e: Expr): Expr = {
-    def pre(e: Expr) = e match {
-      case And(Seq(l, Or(Seq(r1, r2)))) =>
-        Or(And(l, r1), And(l, r2))
-      case And(Seq(Or(Seq(l1, l2)), r)) =>
-        Or(And(l1, r), And(l2, r))
-      case _ =>
-        e
-    }
-
-    simplePreTransform(pre)(e)
-  }
-
   def toCNF(e: Expr): Expr = {
     def pre(e: Expr) = e match {
       case Or(Seq(l, And(Seq(r1, r2)))) =>
@@ -980,7 +967,7 @@ object TreeOps {
   def decomposeIfs(e: Expr): Expr = {
     def pre(e: Expr): Expr = e match {
       case IfExpr(cond, then, elze) =>
-        val TopLevelOrs(orcases) = toDNF(cond)
+        val TopLevelOrs(orcases) = cond
 
         if (orcases.exists{ case TopLevelAnds(ands) => ands.exists(_.isInstanceOf[CaseClassInstanceOf]) } ) {
           if (!orcases.tail.isEmpty) {
diff --git a/src/main/scala/leon/synthesis/Heuristics.scala b/src/main/scala/leon/synthesis/Heuristics.scala
index 706ed0f6c..cc313e629 100644
--- a/src/main/scala/leon/synthesis/Heuristics.scala
+++ b/src/main/scala/leon/synthesis/Heuristics.scala
@@ -8,6 +8,7 @@ import heuristics._
 object Heuristics {
   def all = Set[Rule](
     IntInduction,
+    InnerCaseSplit,
     //new OptimisticInjection(_),
     //new SelectiveInlining(_),
     ADTInduction
diff --git a/src/main/scala/leon/synthesis/heuristics/InnerCaseSplit.scala b/src/main/scala/leon/synthesis/heuristics/InnerCaseSplit.scala
new file mode 100644
index 000000000..5b50ac758
--- /dev/null
+++ b/src/main/scala/leon/synthesis/heuristics/InnerCaseSplit.scala
@@ -0,0 +1,51 @@
+package leon
+package synthesis
+package heuristics
+
+import purescala.Trees._
+import purescala.TreeOps._
+import purescala.Extractors._
+
+case object InnerCaseSplit extends Rule("Inner-Case-Split", 200) with Heuristic {
+  def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = {
+    p.phi match {
+      case Or(_) =>
+        // Inapplicable in this case, normal case-split has precedence here.
+        Nil
+      case _ =>
+        var phi = p.phi
+        phi match {
+          case Not(And(es)) =>
+            phi = Or(es.map(Not(_)))
+            
+          case Not(Or(es)) =>
+            phi = And(es.map(Not(_)))
+
+          case _ =>
+        }
+
+        phi match {
+          case Or(os) =>
+            List(rules.CaseSplit.split(os, p))
+
+          case And(as) =>
+            val optapp = for ((a, i) <- as.zipWithIndex) yield {
+              a match {
+                case Or(os) =>
+                  Some(rules.CaseSplit.split(os.map(o => And(as.updated(i, o))), p))
+
+                case _ =>
+                  None
+              }
+            }
+
+            optapp.flatten
+
+          case e =>
+            Nil
+        }
+    }
+  }
+
+}
+
diff --git a/src/main/scala/leon/synthesis/rules/CaseSplit.scala b/src/main/scala/leon/synthesis/rules/CaseSplit.scala
index a2627256b..2375f3375 100644
--- a/src/main/scala/leon/synthesis/rules/CaseSplit.scala
+++ b/src/main/scala/leon/synthesis/rules/CaseSplit.scala
@@ -9,19 +9,32 @@ import purescala.Extractors._
 case object CaseSplit extends Rule("Case-Split", 200) {
   def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = {
     p.phi match {
-      case Or(o1 :: o2 :: _) =>
-        val sub1 = Problem(p.as, p.pc, o1, p.xs)
-        val sub2 = Problem(p.as, p.pc, o2, p.xs)
+      case Or(os) =>
+        List(split(os, p))
+      case _ =>
+        Nil
+    }
+  }
+
+  def split(alts: Seq[Expr], p: Problem): RuleInstantiation = {
+    val subs = alts.map(a => Problem(p.as, p.pc, a, p.xs)).toList
 
-        val onSuccess: List[Solution] => Solution = { 
-          case List(Solution(p1, d1, t1), Solution(p2, d2, t2)) => Solution(Or(p1, p2), d1++d2, IfExpr(p1, t1, t2))
-          case _ => Solution.none
-        }
+    val onSuccess: List[Solution] => Solution = {
+      case sols if sols.size == subs.size =>
+        val pre = Or(sols.map(_.pre))
+        val defs = sols.map(_.defs).reduceLeft(_ ++ _)
+
+        val (prefix, last) = (sols.dropRight(1), sols.last)
+
+        val term = prefix.foldRight(last.term) { (s, t) => IfExpr(s.pre, s.term, t) }
+
+        Solution(pre, defs, term)
 
-        List(RuleInstantiation.immediateDecomp(List(sub1, sub2), onSuccess))
       case _ =>
-        Nil
+        Solution.none
     }
+
+    RuleInstantiation.immediateDecomp(subs, onSuccess)
   }
 }
 
diff --git a/testcases/synthesis/InnerSplit.scala b/testcases/synthesis/InnerSplit.scala
new file mode 100644
index 000000000..b584c45e4
--- /dev/null
+++ b/testcases/synthesis/InnerSplit.scala
@@ -0,0 +1,5 @@
+import leon.Utils._
+
+object Test {
+  def test(x: Int, y: Int) = choose((z: Int) => z >= x && z >= y && (z == x || z == y))
+}
-- 
GitLab