From f5fb158f1ed74cb89adc79d7bb19f743de0c587f Mon Sep 17 00:00:00 2001
From: Etienne Kneuss <ekneuss@gmail.com>
Date: Sun, 20 Jan 2013 19:38:08 +0100
Subject: [PATCH] Implement the concept of Normalizing rules

Normalizing rules are rules that:
1) always help synthesis
2) are commutative
3) should be applied as early as possible

Here we apply normalizing rules explicitly before all other rules, and
in a deterministic order. This should dramatically reduce the search
space in cases where such rules apply.

Note that rules that are said to be normalizing should never fail once
instantiated.
---
 .../scala/leon/synthesis/Heuristics.scala     |  2 +-
 .../scala/leon/synthesis/ParallelSearch.scala | 22 +++++++++++++------
 src/main/scala/leon/synthesis/Rules.scala     |  9 +++++++-
 .../scala/leon/synthesis/SimpleSearch.scala   | 20 ++++++++++++-----
 .../scala/leon/synthesis/Synthesizer.scala    |  2 +-
 .../scala/leon/synthesis/rules/Assert.scala   |  2 +-
 .../scala/leon/synthesis/rules/OnePoint.scala |  2 +-
 .../synthesis/rules/UnconstrainedOutput.scala |  2 +-
 .../leon/synthesis/rules/UnusedInput.scala    |  2 +-
 9 files changed, 44 insertions(+), 19 deletions(-)

diff --git a/src/main/scala/leon/synthesis/Heuristics.scala b/src/main/scala/leon/synthesis/Heuristics.scala
index 95ac39a13..cb18ef781 100644
--- a/src/main/scala/leon/synthesis/Heuristics.scala
+++ b/src/main/scala/leon/synthesis/Heuristics.scala
@@ -7,7 +7,7 @@ import purescala.TypeTrees.TupleType
 import heuristics._
 
 object Heuristics {
-  def all = Set[Rule](
+  def all = List[Rule](
     IntInduction,
     InnerCaseSplit,
     //new OptimisticInjection(_),
diff --git a/src/main/scala/leon/synthesis/ParallelSearch.scala b/src/main/scala/leon/synthesis/ParallelSearch.scala
index fb3f824c8..efe05182e 100644
--- a/src/main/scala/leon/synthesis/ParallelSearch.scala
+++ b/src/main/scala/leon/synthesis/ParallelSearch.scala
@@ -8,7 +8,7 @@ import solvers.TrivialSolver
 
 class ParallelSearch(synth: Synthesizer,
                      problem: Problem,
-                     rules: Set[Rule],
+                     rules: Seq[Rule],
                      costModel: CostModel,
                      nWorkers: Int) extends AndOrGraphParallelSearch[SynthesisContext, TaskRunRule, TaskTryRules, Solution](new AndOrGraph(TaskTryRules(problem), SearchCostModel(costModel)), nWorkers) {
 
@@ -69,14 +69,22 @@ class ParallelSearch(synth: Synthesizer,
   }
 
   def expandOrTask(ref: ActorRef, sctx: SynthesisContext)(t: TaskTryRules) = {
-    val sub = rules.flatMap { r => 
-      r.instantiateOn(sctx, t.p).map(TaskRunRule(_))
-    }
+    val (normRules, otherRules) = rules.partition(_.isInstanceOf[NormalizingRule])
+
+    val normApplications = normRules.flatMap(_.instantiateOn(sctx, t.p))
 
-    if (!sub.isEmpty) {
-      Expanded(sub.toList)
+    if (!normApplications.isEmpty) {
+      Expanded(List(TaskRunRule(normApplications.head)))
     } else {
-      ExpandFailure()
+      val sub = otherRules.flatMap { r =>
+        r.instantiateOn(sctx, t.p).map(TaskRunRule(_))
+      }
+
+      if (!sub.isEmpty) {
+        Expanded(sub.toList)
+      } else {
+        ExpandFailure()
+      }
     }
   }
 }
diff --git a/src/main/scala/leon/synthesis/Rules.scala b/src/main/scala/leon/synthesis/Rules.scala
index 27eba7ac6..7926efd69 100644
--- a/src/main/scala/leon/synthesis/Rules.scala
+++ b/src/main/scala/leon/synthesis/Rules.scala
@@ -8,7 +8,7 @@ import purescala.TreeOps._
 import rules._
 
 object Rules {
-  def all = Set[Rule](
+  def all = List[Rule](
     Unification.DecompTrivialClash,
     Unification.OccursCheck, // probably useless
     Disunification.Decomp,
@@ -89,3 +89,10 @@ abstract class Rule(val name: String) {
 
   override def toString = "R: "+name
 }
+
+// Note: Rules that extend NormalizingRule should all be commutative, The will
+// be applied before others in a deterministic order and their application
+// should never fail!
+abstract class NormalizingRule(name: String) extends Rule(name) {
+  override def toString = "N: "+name
+}
diff --git a/src/main/scala/leon/synthesis/SimpleSearch.scala b/src/main/scala/leon/synthesis/SimpleSearch.scala
index 2b329ef70..32258cedd 100644
--- a/src/main/scala/leon/synthesis/SimpleSearch.scala
+++ b/src/main/scala/leon/synthesis/SimpleSearch.scala
@@ -32,7 +32,7 @@ case class SearchCostModel(cm: CostModel) extends AOCostModel[TaskRunRule, TaskT
 
 class SimpleSearch(synth: Synthesizer,
                    problem: Problem,
-                   rules: Set[Rule],
+                   rules: Seq[Rule],
                    costModel: CostModel) extends AndOrGraphSearch[TaskRunRule, TaskTryRules, Solution](new AndOrGraph(TaskTryRules(problem), SearchCostModel(costModel))) {
 
   import synth.reporter._
@@ -64,12 +64,22 @@ class SimpleSearch(synth: Synthesizer,
   }
 
   def expandOrTask(t: TaskTryRules): ExpandResult[TaskRunRule] = {
-    val sub = rules.flatMap ( r => r.instantiateOn(sctx, t.p).map(TaskRunRule(_)) )
+    val (normRules, otherRules) = rules.partition(_.isInstanceOf[NormalizingRule])
 
-    if (!sub.isEmpty) {
-      Expanded(sub.toList)
+    val normApplications = normRules.flatMap(_.instantiateOn(sctx, t.p))
+
+    if (!normApplications.isEmpty) {
+      Expanded(List(TaskRunRule(normApplications.head)))
     } else {
-      ExpandFailure()
+      val sub = otherRules.flatMap { r =>
+        r.instantiateOn(sctx, t.p).map(TaskRunRule(_))
+      }
+
+      if (!sub.isEmpty) {
+        Expanded(sub.toList)
+      } else {
+        ExpandFailure()
+      }
     }
   }
 
diff --git a/src/main/scala/leon/synthesis/Synthesizer.scala b/src/main/scala/leon/synthesis/Synthesizer.scala
index 0a402e5db..624b6a178 100644
--- a/src/main/scala/leon/synthesis/Synthesizer.scala
+++ b/src/main/scala/leon/synthesis/Synthesizer.scala
@@ -22,7 +22,7 @@ class Synthesizer(val context : LeonContext,
                   val solver: Solver,
                   val program: Program,
                   val problem: Problem,
-                  val rules: Set[Rule],
+                  val rules: Seq[Rule],
                   val options: SynthesizerOptions) {
 
   protected[synthesis] val reporter = context.reporter
diff --git a/src/main/scala/leon/synthesis/rules/Assert.scala b/src/main/scala/leon/synthesis/rules/Assert.scala
index 898f1ea05..cab8ea72c 100644
--- a/src/main/scala/leon/synthesis/rules/Assert.scala
+++ b/src/main/scala/leon/synthesis/rules/Assert.scala
@@ -6,7 +6,7 @@ import purescala.Trees._
 import purescala.TreeOps._
 import purescala.Extractors._
 
-case object Assert extends Rule("Assert") {
+case object Assert extends NormalizingRule("Assert") {
   def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = {
     p.phi match {
       case TopLevelAnds(exprs) =>
diff --git a/src/main/scala/leon/synthesis/rules/OnePoint.scala b/src/main/scala/leon/synthesis/rules/OnePoint.scala
index 42c931f59..d245e7bb2 100644
--- a/src/main/scala/leon/synthesis/rules/OnePoint.scala
+++ b/src/main/scala/leon/synthesis/rules/OnePoint.scala
@@ -6,7 +6,7 @@ import purescala.Trees._
 import purescala.TreeOps._
 import purescala.Extractors._
 
-case object OnePoint extends Rule("One-point") {
+case object OnePoint extends NormalizingRule("One-point") {
   def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = {
     val TopLevelAnds(exprs) = p.phi
 
diff --git a/src/main/scala/leon/synthesis/rules/UnconstrainedOutput.scala b/src/main/scala/leon/synthesis/rules/UnconstrainedOutput.scala
index c5ddd02ef..2ae957da3 100644
--- a/src/main/scala/leon/synthesis/rules/UnconstrainedOutput.scala
+++ b/src/main/scala/leon/synthesis/rules/UnconstrainedOutput.scala
@@ -6,7 +6,7 @@ import purescala.Trees._
 import purescala.TreeOps._
 import purescala.Extractors._
 
-case object UnconstrainedOutput extends Rule("Unconstr.Output") {
+case object UnconstrainedOutput extends NormalizingRule("Unconstr.Output") {
   def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = {
     val unconstr = p.xs.toSet -- variablesOf(p.phi)
 
diff --git a/src/main/scala/leon/synthesis/rules/UnusedInput.scala b/src/main/scala/leon/synthesis/rules/UnusedInput.scala
index b44ad3411..b97ff9033 100644
--- a/src/main/scala/leon/synthesis/rules/UnusedInput.scala
+++ b/src/main/scala/leon/synthesis/rules/UnusedInput.scala
@@ -6,7 +6,7 @@ import purescala.Trees._
 import purescala.TreeOps._
 import purescala.Extractors._
 
-case object UnusedInput extends Rule("UnusedInput") {
+case object UnusedInput extends NormalizingRule("UnusedInput") {
   def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = {
     val unused = p.as.toSet -- variablesOf(p.phi) -- variablesOf(p.pc)
 
-- 
GitLab