From 4128518c4aba17913addbecf37680c845d78b03a Mon Sep 17 00:00:00 2001
From: Manos Koukoutos <emmanouil.koukoutos@epfl.ch>
Date: Wed, 1 Jul 2015 15:24:15 +0200
Subject: [PATCH] Filter tests in EqSplit

---
 src/main/scala/leon/synthesis/TestBank.scala   | 18 ++++++++++--------
 .../leon/synthesis/rules/EqualitySplit.scala   | 11 +++++++++--
 2 files changed, 19 insertions(+), 10 deletions(-)

diff --git a/src/main/scala/leon/synthesis/TestBank.scala b/src/main/scala/leon/synthesis/TestBank.scala
index b4067cb2e..b5ee5790c 100644
--- a/src/main/scala/leon/synthesis/TestBank.scala
+++ b/src/main/scala/leon/synthesis/TestBank.scala
@@ -61,7 +61,7 @@ case class TestBank(valids: Seq[Example], invalids: Seq[Example]) {
   def mapIns(f: Seq[Expr] => List[Seq[Expr]]) = {
     map {
       case InExample(in) =>
-        f(in).map(InExample(_))
+        f(in).map(InExample)
 
       case InOutExample(in, out) =>
         f(in).map(InOutExample(_, out))
@@ -155,17 +155,19 @@ case class ProblemTestBank(p: Problem, tb: TestBank)(implicit hctx: SearchContex
     tb mapIns { in => List(toKeep.map(in)) }
   }
 
-  def filterIns(expr: Expr) = {
+  def filterIns(expr: Expr): TestBank = {
     val ev = new DefaultEvaluator(hctx.sctx.context, hctx.sctx.program)
 
+    filterIns(m => ev.eval(expr, m).result == Some(BooleanLiteral(true)))
+  }
+
+  def filterIns(pred: Map[Identifier, Expr] => Boolean): TestBank = {
     tb mapIns { in =>
       val m = (p.as zip in).toMap
-
-      ev.eval(expr, m) match {
-        case EvaluationResults.Successful(BooleanLiteral(true)) =>
-          List(in)
-        case _ =>
-          Nil
+      if(pred(m)) {
+        List(in)
+      } else {
+        Nil
       }
     }
   }
diff --git a/src/main/scala/leon/synthesis/rules/EqualitySplit.scala b/src/main/scala/leon/synthesis/rules/EqualitySplit.scala
index 784ad96ac..44043b0e0 100644
--- a/src/main/scala/leon/synthesis/rules/EqualitySplit.scala
+++ b/src/main/scala/leon/synthesis/rules/EqualitySplit.scala
@@ -4,6 +4,7 @@ package leon
 package synthesis
 package rules
 
+import leon.purescala.Common.Identifier
 import purescala.Expressions._
 import purescala.Constructors._
 
@@ -40,8 +41,14 @@ case object EqualitySplit extends Rule("Eq. Split") {
     candidates.flatMap {
       case List(a1, a2) =>
 
-        val sub1 = p.copy(pc = and(Equals(Variable(a1), Variable(a2)), p.pc))
-        val sub2 = p.copy(pc = and(not(Equals(Variable(a1), Variable(a2))), p.pc))
+        val sub1 = p.copy(
+          pc = and(Equals(Variable(a1), Variable(a2)), p.pc),
+          tb = p.tbOps.filterIns( (m: Map[Identifier, Expr]) => m(a1) == m(a2))
+        )
+        val sub2 = p.copy(
+          pc = and(not(Equals(Variable(a1), Variable(a2))), p.pc),
+          tb = p.tbOps.filterIns( (m: Map[Identifier, Expr]) => m(a1) != m(a2))
+        )
 
         val onSuccess: List[Solution] => Option[Solution] = {
           case List(s1, s2) =>
-- 
GitLab