From 97467ee2fd689b439487de286fc524342893389c Mon Sep 17 00:00:00 2001 From: Etienne Kneuss <ekneuss@gmail.com> Date: Wed, 14 Nov 2012 20:50:10 +0100 Subject: [PATCH] Fix EQ Split --- src/main/scala/leon/purescala/Trees.scala | 7 ++++ src/main/scala/leon/synthesis/Problem.scala | 2 +- src/main/scala/leon/synthesis/Rules.scala | 41 +++++++++++++++++---- 3 files changed, 41 insertions(+), 9 deletions(-) diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala index edf082730..7746046cd 100644 --- a/src/main/scala/leon/purescala/Trees.scala +++ b/src/main/scala/leon/purescala/Trees.scala @@ -274,6 +274,13 @@ object Trees { class Equals(val left: Expr, val right: Expr) extends Expr with FixedType { val fixedType = BooleanType + + override def equals(that: Any): Boolean = (that != null) && (that match { + case t: Equals => t.left == left && t.right == right + case _ => false + }) + + override def hashCode: Int = left.hashCode+right.hashCode } case class Variable(id: Identifier) extends Expr with Terminal { diff --git a/src/main/scala/leon/synthesis/Problem.scala b/src/main/scala/leon/synthesis/Problem.scala index 52443ab53..2f85a916f 100644 --- a/src/main/scala/leon/synthesis/Problem.scala +++ b/src/main/scala/leon/synthesis/Problem.scala @@ -7,7 +7,7 @@ import leon.purescala.Common._ // Defines a synthesis triple of the form: // ⟦ as ⟨ C | phi ⟩ xs ⟧ case class Problem(as: List[Identifier], c: Expr, phi: Expr, xs: List[Identifier]) { - override def toString = "⟦ "+as.mkString(";")+", "+c+" ==> ⟨ "+phi+" ⟩ "+xs.mkString(";")+" ⟧ " + override def toString = "⟦ "+as.mkString(";")+", "+c+" ᚒ ⟨ "+phi+" ⟩ "+xs.mkString(";")+" ⟧ " val complexity: ProblemComplexity = ProblemComplexity(this) } diff --git a/src/main/scala/leon/synthesis/Rules.scala b/src/main/scala/leon/synthesis/Rules.scala index 216da2afb..78a8546a2 100644 --- a/src/main/scala/leon/synthesis/Rules.scala +++ b/src/main/scala/leon/synthesis/Rules.scala @@ -521,17 +521,42 @@ class EqualitySplit(synth: Synthesizer) extends Rule("Eq. Split.", synth, 90) { def applyOn(task: Task): RuleResult = { val p = task.problem - val asgroups = p.as.groupBy(_.getType).filter(_._2.size == 2).mapValues(_.toList) + val TopLevelAnds(presSeq) = p.c + val pres = presSeq.toSet + + def combinations(a1: Identifier, a2: Identifier): Set[Expr] = { + val v1 = Variable(a1) + val v2 = Variable(a2) + Set( + Equals(v1, v2), + Equals(v2, v1), + Not(Equals(v1, v2)), + Not(Equals(v2, v1)) + ) + } - val extraConds = for (List(a1, a2) <- asgroups.values) yield { - Or(Equals(Variable(a1), Variable(a2)), Not(Equals(Variable(a1), Variable(a2)))) + val candidate = p.as.groupBy(_.getType).map(_._2.toList).find{ + case List(a1, a2) => (pres & combinations(a1, a2)).isEmpty + case _ => false } - if (!extraConds.isEmpty) { - val sub = p.copy(phi = And(And(extraConds.toSeq), p.phi)) - RuleStep(List(sub), forward) - } else { - RuleInapplicable + + candidate match { + case Some(List(a1, a2)) => + + val sub1 = p.copy(c = And(Equals(Variable(a1), Variable(a2)), p.c)) + val sub2 = p.copy(c = And(Not(Equals(Variable(a1), Variable(a2))), p.c)) + + val onSuccess: List[Solution] => Solution = { + case List(s1, s2) => + Solution(Or(s1.pre, s2.pre), s1.defs++s2.defs, IfExpr(Equals(Variable(a1), Variable(a2)), s1.term, s2.term)) + case _ => + Solution.none + } + + RuleStep(List(sub1, sub2), onSuccess) + case _ => + RuleInapplicable } } } -- GitLab