From f751527f2a080ff0355eab97692437360ea6a1dc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9gis=20Blanc?= <regwblanc@gmail.com>
Date: Wed, 14 Nov 2012 02:15:10 +0000
Subject: [PATCH] integrate equation synthesis as a rule

---
 .../synthesis/ArithmeticNormalization.scala   |  3 +-
 .../leon/synthesis/IntegerSynthesis.scala     | 14 +++--
 .../leon/synthesis/LinearEquations.scala      | 22 +++++---
 src/main/scala/leon/synthesis/Rules.scala     | 54 ++++++++++++++++++-
 .../ArithmeticNormalizationSuite.scala        |  5 ++
 .../test/synthesis/LinearEquationsSuite.scala |  9 ++++
 6 files changed, 95 insertions(+), 12 deletions(-)

diff --git a/src/main/scala/leon/synthesis/ArithmeticNormalization.scala b/src/main/scala/leon/synthesis/ArithmeticNormalization.scala
index aacfa2051..ba22b8f0d 100644
--- a/src/main/scala/leon/synthesis/ArithmeticNormalization.scala
+++ b/src/main/scala/leon/synthesis/ArithmeticNormalization.scala
@@ -71,10 +71,11 @@ object ArithmeticNormalization {
   def expand(expr: Expr): Seq[Expr] = expr match {
     case Plus(es1, es2) => expand(es1) ++ expand(es2)
     case Minus(e1, e2) => expand(e1) ++ expand(e2).map(Times(IntLiteral(-1), _): Expr)
+    case UMinus(e) => expand(e).map(Times(IntLiteral(-1), _): Expr)
     case Times(es1, es2) => multiply(expand(es1), expand(es2))
     case v@Variable(_) => Seq(v)
     case n@IntLiteral(_) => Seq(n)
-    case _ => sys.error("Unexpected")
+    case err => sys.error("Unexpected in expand: " + err)
   }
 
   //simple, local simplifications
diff --git a/src/main/scala/leon/synthesis/IntegerSynthesis.scala b/src/main/scala/leon/synthesis/IntegerSynthesis.scala
index 5ed2abfa3..a1fac69fc 100644
--- a/src/main/scala/leon/synthesis/IntegerSynthesis.scala
+++ b/src/main/scala/leon/synthesis/IntegerSynthesis.scala
@@ -14,13 +14,20 @@ object IntegerSynthesis {
     //  case (None, f) => (f, xs.map(id => (Variable(id))))
     //  case (Some(eq), f) => {
     //    val vars: Set[Identifier] = variablesOf(eq)
-    //    val eqxs: List[Identifier] = xs.toSet.intersect(vars).toList
-    //    val ys: Set[Identifier] = xs.toSet.difference(vars).toList
     //    val eqas: Set[Identifier] = as.intersect(vars)
+
+    //    val eqxs: List[Identifier] = xs.toSet.intersect(vars).toList
+    //    val ys: Set[Identifier] = xs.toSet.diff(vars)
+
     //    val normalizedEq: List[Expr] = ArithmeticNormalization(Minus(eq.left, eq.right), eqxs.toArray).toList
     //    val (eqPre, eqWitness, eqFreshVars) = elimVariable(eqas, normalizedEq)
+
     //    val eqSubstMap: Map[Expr, Expr] = eqxs.zip(eqWitness).map{case (id, e) => (Variable(id), e)}.toMap
     //    val freshFormula = simplify(replace(eqSubstMap, f))
+    //    (eqPre, freshFormula)
+
+
+    //    /*
     //    val (recPre, recSubst) = apply(as, ys ++ eqFreshVars, freshFormula)
 
     //    val freshPre = simplify(replace(
@@ -29,7 +36,8 @@ object IntegerSynthesis {
     //      }.toMap,
     //      eqPre))
 
-    //    (And(freshPre, recPre), 
+    //    (And(freshPre, recPre), recSubst)
+    //    */
 
     //  }
     //}
diff --git a/src/main/scala/leon/synthesis/LinearEquations.scala b/src/main/scala/leon/synthesis/LinearEquations.scala
index 6774c99f9..2f7ed20de 100644
--- a/src/main/scala/leon/synthesis/LinearEquations.scala
+++ b/src/main/scala/leon/synthesis/LinearEquations.scala
@@ -1,6 +1,7 @@
 package leon.synthesis
 
 import leon.purescala.Trees._
+import leon.purescala.TypeTrees._
 import leon.purescala.Common._
 import leon.Evaluator 
 
@@ -9,18 +10,25 @@ object LinearEquations {
   //eliminate one variable from normalizedEquation t + a1*x1 + ... + an*xn = 0
   //return a mapping for each of the n variables in (pre, map, freshVars)
   def elimVariable(as: Set[Identifier], normalizedEquation: List[Expr]): (Expr, List[Expr], List[Identifier]) = {
+    println("elim in normalized: " + normalizedEquation)
     val t: Expr = normalizedEquation.head
     val coefsVars: List[Int] = normalizedEquation.tail.map{case IntLiteral(i) => i}
     val orderedParams: Array[Identifier] = as.toArray
-    val coefsParams: List[Int] = ArithmeticNormalization(t, orderedParams).map{case IntLiteral(i) => i}.toList
-    val d: Int = GCD.gcd((coefsParams.tail ++ coefsVars).toSeq)
-
-    if(d > 1) 
-      elimVariable(as, normalizedEquation.map(Division(_, IntLiteral(d)))) 
-    else {
+    val coefsParams0: List[Int] = ArithmeticNormalization(t, orderedParams).map{case IntLiteral(i) => i}.toList
+    val coefsParams: List[Int] = if(coefsParams0.head == 0) coefsParams0.tail else coefsParams0
+    val d: Int = GCD.gcd((coefsParams ++ coefsVars).toSeq)
+
+    if(coefsVars.size == 1) {
+      val coef = coefsVars.head
+      (Equals(Modulo(t, IntLiteral(coef)), IntLiteral(0)), List(UMinus(Division(t, IntLiteral(coef)))), List())
+    } else if(d > 1) {
+      val newCoefsParams: List[Expr] = coefsParams.map(i => IntLiteral(i/d) : Expr)
+      val newT = newCoefsParams.zip(IntLiteral(1)::orderedParams.map(Variable(_)).toList).foldLeft[Expr](IntLiteral(0))((acc, p) => Plus(acc, Times(p._1, p._2)))
+      elimVariable(as, newT :: normalizedEquation.tail.map{case IntLiteral(i) => IntLiteral(i/d) : Expr})
+    } else {
       val basis: Array[Array[Int]]  = linearSet(as, normalizedEquation.tail.map{case IntLiteral(i) => i}.toArray)
       val (pre, sol) = particularSolution(as, normalizedEquation)
-      val freshVars: Array[Identifier] = basis(0).map(_ => FreshIdentifier("v", true))
+      val freshVars: Array[Identifier] = basis(0).map(_ => FreshIdentifier("v", true).setType(Int32Type))
 
       val tbasis = basis.transpose
       assert(freshVars.size == tbasis.size)
diff --git a/src/main/scala/leon/synthesis/Rules.scala b/src/main/scala/leon/synthesis/Rules.scala
index b79ade73c..519a8b3cc 100644
--- a/src/main/scala/leon/synthesis/Rules.scala
+++ b/src/main/scala/leon/synthesis/Rules.scala
@@ -8,6 +8,8 @@ import purescala.Extractors._
 import purescala.TreeOps._
 import purescala.TypeTrees._
 import purescala.Definitions._
+import LinearEquations.elimVariable
+import ArithmeticNormalization.simplify
 
 object Rules {
   def all = Set[Synthesizer => Rule](
@@ -21,7 +23,8 @@ object Rules {
     new UnconstrainedOutput(_),
     new OptimisticGround(_),
     new CEGIS(_),
-    new Assert(_)
+    new Assert(_),
+    new IntegerEquation(_)
   )
 }
 
@@ -512,3 +515,52 @@ class OptimisticGround(synth: Synthesizer) extends Rule("Optimistic Ground", syn
     }
   }
 }
+
+class IntegerEquation(synth: Synthesizer) extends Rule("Integer Equation", synth, 300) {
+  def applyOn(task: Task): RuleResult = {
+
+    val p = task.problem
+
+    val TopLevelAnds(exprs) = p.phi
+    val xs = p.xs
+    val as = p.as
+    val formula = p.phi
+
+    val (eqs, others) = exprs.partition(_.isInstanceOf[Equals])
+
+    if (!eqs.isEmpty) {
+
+      val (eq@Equals(_,_), rest) = (eqs.head, eqs.tail)
+      val allOthers = rest ++ others
+      
+      val vars: Set[Identifier] = variablesOf(eq)
+      val eqas: Set[Identifier] = as.toSet.intersect(vars)
+
+      val eqxs: List[Identifier] = xs.toSet.intersect(vars).toList
+      val ys: Set[Identifier] = xs.toSet.diff(vars)
+
+      val normalizedEq: List[Expr] = ArithmeticNormalization(Minus(eq.left, eq.right), eqxs.toArray).toList
+      val (eqPre, eqWitness, eqFreshVars) = elimVariable(eqas, normalizedEq)
+
+      val eqSubstMap: Map[Expr, Expr] = eqxs.zip(eqWitness).map{case (id, e) => (Variable(id), simplify(e))}.toMap
+      val freshFormula = simplify(replace(eqSubstMap, And(allOthers)))
+      (eqPre, freshFormula)
+
+      val newProblem = Problem(as, And(eqPre, p.c), freshFormula, eqFreshVars)
+
+      val onSuccess: List[Solution] => Solution = { 
+        case List(Solution(pre, defs, term)) =>
+          if (eqFreshVars.isEmpty) {
+            Solution(pre, defs, replace(eqSubstMap, Tuple(eqxs.map(Variable(_)))))
+          } else {
+            Solution(pre, defs, LetTuple(eqFreshVars, term, replace(eqSubstMap, Tuple(eqxs.map(Variable(_))))))
+          }
+        case _ => Solution.none
+      }
+
+      RuleStep(List(newProblem), onSuccess)
+    } else {
+      RuleInapplicable
+    }
+  }
+}
diff --git a/src/test/scala/leon/test/synthesis/ArithmeticNormalizationSuite.scala b/src/test/scala/leon/test/synthesis/ArithmeticNormalizationSuite.scala
index a2dc9ec68..992e668f7 100644
--- a/src/test/scala/leon/test/synthesis/ArithmeticNormalizationSuite.scala
+++ b/src/test/scala/leon/test/synthesis/ArithmeticNormalizationSuite.scala
@@ -60,6 +60,11 @@ class ArithmeticNormalizationSuite extends FunSuite {
     val e2 = Times(Plus(x, Times(i(2), y)), Plus(Plus(x, y), i(1)))
     checkSameExpr(toSum(expand(e2)), e2, xs)
 
+    val e3 = Minus(Plus(x, Times(i(2), y)), Plus(Plus(x, y), i(1)))
+    checkSameExpr(toSum(expand(e3)), e3, xs)
+
+    val e4 = UMinus(Plus(x, Times(i(2), y)))
+    checkSameExpr(toSum(expand(e4)), e4, xs)
   }
 
   test("apply") {
diff --git a/src/test/scala/leon/test/synthesis/LinearEquationsSuite.scala b/src/test/scala/leon/test/synthesis/LinearEquationsSuite.scala
index fe507bea3..f84c3d3b8 100644
--- a/src/test/scala/leon/test/synthesis/LinearEquationsSuite.scala
+++ b/src/test/scala/leon/test/synthesis/LinearEquationsSuite.scala
@@ -199,6 +199,15 @@ class LinearEquationsSuite extends FunSuite {
     val c2 = List(IntLiteral(1), IntLiteral(-1))
     val (pre2, wit2, f2) = elimVariable(Set(), t2::c2)
 
+
+    val t3 = Minus(Times(IntLiteral(2), a), IntLiteral(3))
+    val c3 = List(IntLiteral(2))
+    val (pre3, wit3, f3) = elimVariable(Set(aId), t3::c3)
+
+    val t4 = Times(IntLiteral(2), a)
+    val c4 = List(IntLiteral(2), IntLiteral(4))
+    val (pre4, wit4, f4) = elimVariable(Set(aId), t4::c4)
+
   }
 
 }
-- 
GitLab