From 6fde235e65655f9904eee097d00d65050cf4aa60 Mon Sep 17 00:00:00 2001
From: Etienne Kneuss <ekneuss@gmail.com>
Date: Fri, 9 Nov 2012 16:21:30 +0100
Subject: [PATCH] Add heuristic exploiting an optimistic assumption about
 injective nature of functions.

foo(x) == expr && foo(y) == expr  --> x == y
---
 .../scala/leon/synthesis/Heuristics.scala     | 52 +++++++++++++++++--
 src/main/scala/leon/synthesis/Rules.scala     |  2 +-
 .../scala/leon/synthesis/Synthesizer.scala    |  9 ----
 src/main/scala/leon/synthesis/Task.scala      | 12 +++++
 4 files changed, 62 insertions(+), 13 deletions(-)

diff --git a/src/main/scala/leon/synthesis/Heuristics.scala b/src/main/scala/leon/synthesis/Heuristics.scala
index 6802fe40d..4d32f4738 100644
--- a/src/main/scala/leon/synthesis/Heuristics.scala
+++ b/src/main/scala/leon/synthesis/Heuristics.scala
@@ -11,11 +11,18 @@ import purescala.Definitions._
 object Heuristics {
   def all(synth: Synthesizer) = Set(
     new OptimisticGround(synth),
-    new IntInduction(synth)
+    new IntInduction(synth),
+    new OptimisticInjection(synth)
   )
 }
 
-class OptimisticGround(synth: Synthesizer) extends Rule("Optimistic Ground", synth, 9) {
+trait Heuristic {
+  this: Rule =>
+
+  override def toString = "H: "+name
+}
+
+class OptimisticGround(synth: Synthesizer) extends Rule("Optimistic Ground", synth, 9) with Heuristic {
   def applyOn(task: Task): RuleResult = {
     val p = task.problem
 
@@ -72,7 +79,7 @@ class OptimisticGround(synth: Synthesizer) extends Rule("Optimistic Ground", syn
 }
 
 
-class IntInduction(synth: Synthesizer) extends Rule("Int Induction", synth, 8) {
+class IntInduction(synth: Synthesizer) extends Rule("Int Induction", synth, 8) with Heuristic {
   def applyOn(task: Task): RuleResult = {
     val p = task.problem
 
@@ -116,3 +123,42 @@ class IntInduction(synth: Synthesizer) extends Rule("Int Induction", synth, 8) {
     }
   }
 }
+
+class OptimisticInjection(synth: Synthesizer) extends Rule("Opt. Injection", synth, 5) with Heuristic {
+  def applyOn(task: Task): RuleResult = {
+    val p = task.problem
+
+    val TopLevelAnds(exprs) = p.phi
+
+    val eqfuncalls = exprs.collect{
+      case eq @ Equals(FunctionInvocation(fd, args), e) =>
+        ((fd, e), args, eq : Expr)
+      case eq @ Equals(e, FunctionInvocation(fd, args)) =>
+        ((fd, e), args, eq : Expr)
+    }
+
+    val candidates = eqfuncalls.groupBy(_._1).filter(_._2.size > 1)
+    if (!candidates.isEmpty) {
+
+      var newExprs = exprs
+      for (cands <- candidates.values) {
+        val cand = cands.take(2)
+        val toRemove = cand.map(_._3).toSet
+        val argss    = cand.map(_._2)
+        val args     = argss(0) zip argss(1)
+
+        newExprs ++= args.map{ case (l, r) => Equals(l, r) }
+
+
+
+        newExprs = newExprs.filterNot(toRemove)
+      }
+
+      val sub = p.copy(phi = And(newExprs))
+
+      RuleDecomposed(List(sub), forward)
+    } else {
+      RuleInapplicable
+    }
+  }
+}
diff --git a/src/main/scala/leon/synthesis/Rules.scala b/src/main/scala/leon/synthesis/Rules.scala
index ba7be35c7..ed3836e2e 100644
--- a/src/main/scala/leon/synthesis/Rules.scala
+++ b/src/main/scala/leon/synthesis/Rules.scala
@@ -39,7 +39,7 @@ abstract class Rule(val name: String, val synth: Synthesizer, val priority: Prio
     case _ => Solution.none
   }
 
-  override def toString = name
+  override def toString = "R: "+name
 }
 
 class OnePoint(synth: Synthesizer) extends Rule("One-point", synth, 30) {
diff --git a/src/main/scala/leon/synthesis/Synthesizer.scala b/src/main/scala/leon/synthesis/Synthesizer.scala
index f8f9f5686..b493a50de 100644
--- a/src/main/scala/leon/synthesis/Synthesizer.scala
+++ b/src/main/scala/leon/synthesis/Synthesizer.scala
@@ -54,15 +54,6 @@ class Synthesizer(val r: Reporter,
       // Check if solving this task has the slightest chance of improving the
       // current solution
       if (task.minComplexity < bestSolutionSoFar().complexity) {
-        if (!subProblems.isEmpty) {
-          val prefix = "[%-20s] ".format(Option(task.rule).map(_.name).getOrElse("root"))
-          println(prefix+"Got: "+task.problem)
-          println(prefix+"Decomposed into:")
-          for(p <- subProblems) {
-            println(prefix+" - "+p)
-          }
-        }
-
         for (p <- subProblems; r <- rules) yield {
           workList += new Task(this, task, p, r)
         }
diff --git a/src/main/scala/leon/synthesis/Task.scala b/src/main/scala/leon/synthesis/Task.scala
index 00c8ecbc0..e8eba4143 100644
--- a/src/main/scala/leon/synthesis/Task.scala
+++ b/src/main/scala/leon/synthesis/Task.scala
@@ -58,6 +58,11 @@ class Task(synth: Synthesizer,
         // Solved
         this.solution = Some(solution)
         parent.partlySolvedBy(this, solution)
+
+        val prefix = "[%-20s] ".format(Option(rule).map(_.toString).getOrElse("root"))
+        println(prefix+"Got: "+problem)
+        println(prefix+"Solved with: "+solution)
+
         Nil
 
       case RuleDecomposed(subProblems, onSuccess) =>
@@ -67,6 +72,13 @@ class Task(synth: Synthesizer,
         val simplestSolution = onSuccess(subProblems.map(Solution.basic _))
         minComplexity = new FixedSolComplexity(parent.minComplexity.value + simplestSolution.complexity.value)
 
+        val prefix = "[%-20s] ".format(Option(rule).map(_.toString).getOrElse("root"))
+        println(prefix+"Got: "+problem)
+        println(prefix+"Decomposed into:")
+        for(p <- subProblems) {
+          println(prefix+" - "+p)
+        }
+
         subProblems
 
       case RuleInapplicable =>
-- 
GitLab