From 97467ee2fd689b439487de286fc524342893389c Mon Sep 17 00:00:00 2001
From: Etienne Kneuss <ekneuss@gmail.com>
Date: Wed, 14 Nov 2012 20:50:10 +0100
Subject: [PATCH] Fix EQ Split

---
 src/main/scala/leon/purescala/Trees.scala   |  7 ++++
 src/main/scala/leon/synthesis/Problem.scala |  2 +-
 src/main/scala/leon/synthesis/Rules.scala   | 41 +++++++++++++++++----
 3 files changed, 41 insertions(+), 9 deletions(-)

diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala
index edf082730..7746046cd 100644
--- a/src/main/scala/leon/purescala/Trees.scala
+++ b/src/main/scala/leon/purescala/Trees.scala
@@ -274,6 +274,13 @@ object Trees {
 
   class Equals(val left: Expr, val right: Expr) extends Expr with FixedType {
     val fixedType = BooleanType
+
+    override def equals(that: Any): Boolean = (that != null) && (that match {
+      case t: Equals => t.left == left && t.right == right
+      case _ => false
+    })
+
+    override def hashCode: Int = left.hashCode+right.hashCode
   }
   
   case class Variable(id: Identifier) extends Expr with Terminal {
diff --git a/src/main/scala/leon/synthesis/Problem.scala b/src/main/scala/leon/synthesis/Problem.scala
index 52443ab53..2f85a916f 100644
--- a/src/main/scala/leon/synthesis/Problem.scala
+++ b/src/main/scala/leon/synthesis/Problem.scala
@@ -7,7 +7,7 @@ import leon.purescala.Common._
 // Defines a synthesis triple of the form:
 // ⟦ as ⟨ C | phi ⟩ xs ⟧
 case class Problem(as: List[Identifier], c: Expr, phi: Expr, xs: List[Identifier]) {
-  override def toString = "⟦ "+as.mkString(";")+", "+c+" ==> ⟨ "+phi+" ⟩ "+xs.mkString(";")+" ⟧ "
+  override def toString = "⟦ "+as.mkString(";")+", "+c+" ᚒ  ⟨ "+phi+" ⟩ "+xs.mkString(";")+" ⟧ "
 
   val complexity: ProblemComplexity = ProblemComplexity(this)
 }
diff --git a/src/main/scala/leon/synthesis/Rules.scala b/src/main/scala/leon/synthesis/Rules.scala
index 216da2afb..78a8546a2 100644
--- a/src/main/scala/leon/synthesis/Rules.scala
+++ b/src/main/scala/leon/synthesis/Rules.scala
@@ -521,17 +521,42 @@ class EqualitySplit(synth: Synthesizer) extends Rule("Eq. Split.", synth, 90) {
   def applyOn(task: Task): RuleResult = {
     val p = task.problem
 
-    val asgroups = p.as.groupBy(_.getType).filter(_._2.size == 2).mapValues(_.toList)
+    val TopLevelAnds(presSeq) = p.c
+    val pres = presSeq.toSet
+
+    def combinations(a1: Identifier, a2: Identifier): Set[Expr] = {
+      val v1 = Variable(a1)
+      val v2 = Variable(a2)
+      Set(
+        Equals(v1, v2),
+        Equals(v2, v1),
+        Not(Equals(v1, v2)),
+        Not(Equals(v2, v1))
+      )
+    }
 
-    val extraConds = for (List(a1, a2) <- asgroups.values) yield {
-      Or(Equals(Variable(a1), Variable(a2)), Not(Equals(Variable(a1), Variable(a2))))
+    val candidate = p.as.groupBy(_.getType).map(_._2.toList).find{
+      case List(a1, a2) => (pres & combinations(a1, a2)).isEmpty
+      case _ => false
     }
 
-    if (!extraConds.isEmpty) {
-      val sub = p.copy(phi = And(And(extraConds.toSeq), p.phi))
-      RuleStep(List(sub), forward)
-    } else {
-      RuleInapplicable
+
+    candidate match {
+      case Some(List(a1, a2)) =>
+
+        val sub1 = p.copy(c = And(Equals(Variable(a1), Variable(a2)), p.c))
+        val sub2 = p.copy(c = And(Not(Equals(Variable(a1), Variable(a2))), p.c))
+
+        val onSuccess: List[Solution] => Solution = { 
+          case List(s1, s2) =>
+            Solution(Or(s1.pre, s2.pre), s1.defs++s2.defs, IfExpr(Equals(Variable(a1), Variable(a2)), s1.term, s2.term))
+          case _ =>
+            Solution.none
+        }
+
+        RuleStep(List(sub1, sub2), onSuccess)
+      case _ =>
+        RuleInapplicable
     }
   }
 }
-- 
GitLab