diff --git a/src/main/scala/leon/synthesis/TestBank.scala b/src/main/scala/leon/synthesis/TestBank.scala index b4067cb2e36cc547152887a001e9c20502676586..b5ee5790c976400f0646cdc38efba981797e2ed7 100644 --- a/src/main/scala/leon/synthesis/TestBank.scala +++ b/src/main/scala/leon/synthesis/TestBank.scala @@ -61,7 +61,7 @@ case class TestBank(valids: Seq[Example], invalids: Seq[Example]) { def mapIns(f: Seq[Expr] => List[Seq[Expr]]) = { map { case InExample(in) => - f(in).map(InExample(_)) + f(in).map(InExample) case InOutExample(in, out) => f(in).map(InOutExample(_, out)) @@ -155,17 +155,19 @@ case class ProblemTestBank(p: Problem, tb: TestBank)(implicit hctx: SearchContex tb mapIns { in => List(toKeep.map(in)) } } - def filterIns(expr: Expr) = { + def filterIns(expr: Expr): TestBank = { val ev = new DefaultEvaluator(hctx.sctx.context, hctx.sctx.program) + filterIns(m => ev.eval(expr, m).result == Some(BooleanLiteral(true))) + } + + def filterIns(pred: Map[Identifier, Expr] => Boolean): TestBank = { tb mapIns { in => val m = (p.as zip in).toMap - - ev.eval(expr, m) match { - case EvaluationResults.Successful(BooleanLiteral(true)) => - List(in) - case _ => - Nil + if(pred(m)) { + List(in) + } else { + Nil } } } diff --git a/src/main/scala/leon/synthesis/rules/EqualitySplit.scala b/src/main/scala/leon/synthesis/rules/EqualitySplit.scala index 784ad96acaca2c59d65fe8c4090028b7e402dbda..44043b0e095b86964af85c5d43a57821e4f872ab 100644 --- a/src/main/scala/leon/synthesis/rules/EqualitySplit.scala +++ b/src/main/scala/leon/synthesis/rules/EqualitySplit.scala @@ -4,6 +4,7 @@ package leon package synthesis package rules +import leon.purescala.Common.Identifier import purescala.Expressions._ import purescala.Constructors._ @@ -40,8 +41,14 @@ case object EqualitySplit extends Rule("Eq. Split") { candidates.flatMap { case List(a1, a2) => - val sub1 = p.copy(pc = and(Equals(Variable(a1), Variable(a2)), p.pc)) - val sub2 = p.copy(pc = and(not(Equals(Variable(a1), Variable(a2))), p.pc)) + val sub1 = p.copy( + pc = and(Equals(Variable(a1), Variable(a2)), p.pc), + tb = p.tbOps.filterIns( (m: Map[Identifier, Expr]) => m(a1) == m(a2)) + ) + val sub2 = p.copy( + pc = and(not(Equals(Variable(a1), Variable(a2))), p.pc), + tb = p.tbOps.filterIns( (m: Map[Identifier, Expr]) => m(a1) != m(a2)) + ) val onSuccess: List[Solution] => Option[Solution] = { case List(s1, s2) =>