From 40eee30e0c787fef3808d3dc3c7e8a2aaa041cf5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9gis=20Blanc?= <regwblanc@gmail.com>
Date: Wed, 21 Nov 2012 03:33:21 +0000
Subject: [PATCH] refactor ArithmeticNormalization into TreeNormalizations

---
 .../TreeNormalizations.scala}                 | 59 +++++++++++--------
 .../leon/synthesis/LinearEquations.scala      |  5 +-
 .../synthesis/rules/IntegerEquation.scala     |  5 +-
 .../synthesis/rules/IntegerInequalities.scala |  3 +-
 .../TreeNormalizationsTests.scala}            | 25 ++++----
 5 files changed, 53 insertions(+), 44 deletions(-)
 rename src/main/scala/leon/{synthesis/ArithmeticNormalization.scala => purescala/TreeNormalizations.scala} (63%)
 rename src/test/scala/leon/test/{synthesis/ArithmeticNormalizationSuite.scala => purescala/TreeNormalizationsTests.scala} (81%)

diff --git a/src/main/scala/leon/synthesis/ArithmeticNormalization.scala b/src/main/scala/leon/purescala/TreeNormalizations.scala
similarity index 63%
rename from src/main/scala/leon/synthesis/ArithmeticNormalization.scala
rename to src/main/scala/leon/purescala/TreeNormalizations.scala
index 6c148e2cc..0a805a8af 100644
--- a/src/main/scala/leon/synthesis/ArithmeticNormalization.scala
+++ b/src/main/scala/leon/purescala/TreeNormalizations.scala
@@ -1,17 +1,40 @@
-package leon.synthesis
+package leon
+package purescala
 
-import leon.purescala.Trees._
-import leon.purescala.TreeOps._
-import leon.purescala.Common._
+object TreeNormalizations {
+  import Common._
+  import TypeTrees._
+  import Definitions._
+  import Trees._
+  import TreeOps._
+  import Extractors._
 
-object ArithmeticNormalization {
+  /* TODO: we should add CNF and DNF at least */
 
   case class NonLinearExpressionException(msg: String) extends Exception
 
   //assume the function is an arithmetic expression, not a relation
   //return a normal form where the [t a1 ... an] where
   //expr = t + a1*x1 + ... + an*xn and xs = [x1 ... xn]
-  def apply(expr: Expr, xs: Array[Identifier]): Array[Expr] = {
+  //do not keep the evaluation order
+  def linearArithmeticForm(expr: Expr, xs: Array[Identifier]): Array[Expr] = {
+
+    //assume the expr is a literal (mult of constants and variables) with degree one
+    def extractCoef(e: Expr): (Expr, Identifier) = {
+      var id: Option[Identifier] = None
+      var coef = 1
+
+      def rec(e: Expr): Unit = e match {
+        case IntLiteral(i) => coef = coef*i
+        case Variable(id2) => if(id.isEmpty) id = Some(id2) else throw NonLinearExpressionException("multiple variable")
+        case Times(e1, e2) => rec(e1); rec(e2)
+      }
+
+      rec(e)
+      assert(!id.isEmpty)
+      (IntLiteral(coef), id.get)
+    }
+
 
     def containsId(e: Expr, id: Identifier): Boolean = e match {
       case Times(e1, e2) => containsId(e1, id) || containsId(e2, id)
@@ -43,29 +66,14 @@ object ArithmeticNormalization {
     res
   }
 
-
-  //assume the expr is a literal (mult of constants and variables) with degree one
-  def extractCoef(e: Expr): (Expr, Identifier) = {
-    var id: Option[Identifier] = None
-    var coef = 1
-
-    def rec(e: Expr): Unit = e match {
-      case IntLiteral(i) => coef = coef*i
-      case Variable(id2) => if(id.isEmpty) id = Some(id2) else throw NonLinearExpressionException("multiple variable")
-      case Times(e1, e2) => rec(e1); rec(e2)
-    }
-
-    rec(e)
-    assert(!id.isEmpty)
-    (IntLiteral(coef), id.get)
-  }
-
-  //multiply two sums together and distribute in a bigger sum
+  //multiply two sums together and distribute in a larger sum
+  //do not keep the evaluation order
   def multiply(es1: Seq[Expr], es2: Seq[Expr]): Seq[Expr] = {
     es1.flatMap(e1 => es2.map(e2 => Times(e1, e2)))
   }
 
-
+  //expand the expr in a sum of "atoms"
+  //do not keep the evaluation order
   def expand(expr: Expr): Seq[Expr] = expr match {
     case Plus(es1, es2) => expand(es1) ++ expand(es2)
     case Minus(e1, e2) => expand(e1) ++ expand(e2).map(Times(IntLiteral(-1), _): Expr)
@@ -76,5 +84,4 @@ object ArithmeticNormalization {
     case err => throw NonLinearExpressionException("unexpected in expand: " + err)
   }
 
-
 }
diff --git a/src/main/scala/leon/synthesis/LinearEquations.scala b/src/main/scala/leon/synthesis/LinearEquations.scala
index 8ff7d093d..9debb3a8d 100644
--- a/src/main/scala/leon/synthesis/LinearEquations.scala
+++ b/src/main/scala/leon/synthesis/LinearEquations.scala
@@ -1,6 +1,7 @@
 package leon.synthesis
 
 import leon.purescala.Trees._
+import leon.purescala.TreeNormalizations.linearArithmeticForm
 import leon.purescala.TypeTrees._
 import leon.purescala.Common._
 import leon.Evaluator 
@@ -16,7 +17,7 @@ object LinearEquations {
     val t: Expr = normalizedEquation.head
     val coefsVars: List[Int] = normalizedEquation.tail.map{case IntLiteral(i) => i}
     val orderedParams: Array[Identifier] = as.toArray
-    val coefsParams: List[Int] = ArithmeticNormalization(t, orderedParams).map{case IntLiteral(i) => i}.toList
+    val coefsParams: List[Int] = linearArithmeticForm(t, orderedParams).map{case IntLiteral(i) => i}.toList
     //val coefsParams: List[Int] = if(coefsParams0.head == 0) coefsParams0.tail else coefsParams0
     val d: Int = gcd((coefsParams ++ coefsVars).toSeq)
 
@@ -83,7 +84,7 @@ object LinearEquations {
     val lhs = equation.left
     val rhs = equation.right
     val orderedXs = xs.toArray
-    val normalized: Array[Expr] = ArithmeticNormalization(Minus(lhs, rhs), orderedXs)
+    val normalized: Array[Expr] = linearArithmeticForm(Minus(lhs, rhs), orderedXs)
     val (pre, sols) = particularSolution(as, normalized.toList)
     (pre, orderedXs.zip(sols).toMap)
   }
diff --git a/src/main/scala/leon/synthesis/rules/IntegerEquation.scala b/src/main/scala/leon/synthesis/rules/IntegerEquation.scala
index 741b31bd4..e96cd54cf 100644
--- a/src/main/scala/leon/synthesis/rules/IntegerEquation.scala
+++ b/src/main/scala/leon/synthesis/rules/IntegerEquation.scala
@@ -6,6 +6,7 @@ import purescala.Common._
 import purescala.Trees._
 import purescala.Extractors._
 import purescala.TreeOps._
+import purescala.TreeNormalizations._
 import purescala.TypeTrees._
 import purescala.Definitions._
 import LinearEquations.elimVariable
@@ -31,9 +32,9 @@ class IntegerEquation(synth: Synthesizer) extends Rule("Integer Equation", synth
       eqxs = problem.xs.toSet.intersect(vars).toList
 
       try {
-        optionNormalizedEq = Some(ArithmeticNormalization(Minus(eq.left, eq.right), eqxs.toArray).toList)
+        optionNormalizedEq = Some(linearArithmeticForm(Minus(eq.left, eq.right), eqxs.toArray).toList)
       } catch {
-        case ArithmeticNormalization.NonLinearExpressionException(_) =>
+        case NonLinearExpressionException(_) =>
           allOthers = allOthers :+ eq
       }
     }
diff --git a/src/main/scala/leon/synthesis/rules/IntegerInequalities.scala b/src/main/scala/leon/synthesis/rules/IntegerInequalities.scala
index 5aa459614..b2c767771 100644
--- a/src/main/scala/leon/synthesis/rules/IntegerInequalities.scala
+++ b/src/main/scala/leon/synthesis/rules/IntegerInequalities.scala
@@ -6,6 +6,7 @@ import purescala.Common._
 import purescala.Trees._
 import purescala.Extractors._
 import purescala.TreeOps._
+import purescala.TreeNormalizations.linearArithmeticForm
 import purescala.TypeTrees._
 import purescala.Definitions._
 import LinearEquations.elimVariable
@@ -35,7 +36,7 @@ class IntegerInequalities(synth: Synthesizer) extends Rule("Integer Inequalities
       val otherVars: List[Identifier] = problem.xs.filterNot(_ == processedVar)
 
       println("lhsSides: " + lhsSides)
-      val normalizedLhs: List[List[Expr]] = lhsSides.map(ArithmeticNormalization(_, Array(processedVar)).toList)
+      val normalizedLhs: List[List[Expr]] = lhsSides.map(linearArithmeticForm(_, Array(processedVar)).toList)
       println("normalized: " + normalizedLhs.mkString("\n"))
       var upperBounds: List[(Expr, Int)] = Nil // (t, c) means c*x <= t
       var lowerBounds: List[(Expr, Int)] = Nil // (t, c) means t <= c*x
diff --git a/src/test/scala/leon/test/synthesis/ArithmeticNormalizationSuite.scala b/src/test/scala/leon/test/purescala/TreeNormalizationsTests.scala
similarity index 81%
rename from src/test/scala/leon/test/synthesis/ArithmeticNormalizationSuite.scala
rename to src/test/scala/leon/test/purescala/TreeNormalizationsTests.scala
index 562ea4e0c..18d6c7627 100644
--- a/src/test/scala/leon/test/synthesis/ArithmeticNormalizationSuite.scala
+++ b/src/test/scala/leon/test/purescala/TreeNormalizationsTests.scala
@@ -1,16 +1,16 @@
-package leon.test.synthesis
+package leon.test.purescala
 
-import org.scalatest.FunSuite
-
-import leon.Evaluator
-import leon.purescala.Trees._
 import leon.purescala.Common._
+import leon.purescala.Definitions._
+import leon.purescala.Trees._
+import leon.purescala.TreeOps._
+import leon.purescala.TreeNormalizations._
 import leon.purescala.LikelyEq
+import leon.SilentReporter
 
-import leon.synthesis.ArithmeticNormalization._
-
-class ArithmeticNormalizationSuite extends FunSuite {
+import org.scalatest.FunSuite
 
+class TreeNormalizationsTests extends FunSuite {
   def i(x: Int) = IntLiteral(x)
 
   val xId = FreshIdentifier("x")
@@ -66,18 +66,17 @@ class ArithmeticNormalizationSuite extends FunSuite {
     checkSameExpr(toSum(expand(e4)), e4, xs)
   }
 
-  test("apply") {
+  test("linearArithmeticForm") {
     val xsOrder = Array(xId, yId)
 
     val e1 = Plus(Times(Plus(x, i(2)), i(3)), Times(i(4), y))
-    checkSameExpr(coefToSum(apply(e1, xsOrder), Array(x, y)), e1, xs)
+    checkSameExpr(coefToSum(linearArithmeticForm(e1, xsOrder), Array(x, y)), e1, xs)
 
     val e2 = Plus(Times(Plus(x, i(2)), i(3)), Plus(Plus(a, Times(i(5), b)), Times(i(4), y)))
-    checkSameExpr(coefToSum(apply(e2, xsOrder), Array(x, y)), e2, xs ++ as)
+    checkSameExpr(coefToSum(linearArithmeticForm(e2, xsOrder), Array(x, y)), e2, xs ++ as)
 
     val e3 = Minus(Plus(x, i(3)), Plus(y, i(2)))
-    checkSameExpr(coefToSum(apply(e3, xsOrder), Array(x, y)), e3, xs)
+    checkSameExpr(coefToSum(linearArithmeticForm(e3, xsOrder), Array(x, y)), e3, xs)
 
   }
-  
 }
-- 
GitLab