From 40eee30e0c787fef3808d3dc3c7e8a2aaa041cf5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9gis=20Blanc?= <regwblanc@gmail.com> Date: Wed, 21 Nov 2012 03:33:21 +0000 Subject: [PATCH] refactor ArithmeticNormalization into TreeNormalizations --- .../TreeNormalizations.scala} | 59 +++++++++++-------- .../leon/synthesis/LinearEquations.scala | 5 +- .../synthesis/rules/IntegerEquation.scala | 5 +- .../synthesis/rules/IntegerInequalities.scala | 3 +- .../TreeNormalizationsTests.scala} | 25 ++++---- 5 files changed, 53 insertions(+), 44 deletions(-) rename src/main/scala/leon/{synthesis/ArithmeticNormalization.scala => purescala/TreeNormalizations.scala} (63%) rename src/test/scala/leon/test/{synthesis/ArithmeticNormalizationSuite.scala => purescala/TreeNormalizationsTests.scala} (81%) diff --git a/src/main/scala/leon/synthesis/ArithmeticNormalization.scala b/src/main/scala/leon/purescala/TreeNormalizations.scala similarity index 63% rename from src/main/scala/leon/synthesis/ArithmeticNormalization.scala rename to src/main/scala/leon/purescala/TreeNormalizations.scala index 6c148e2cc..0a805a8af 100644 --- a/src/main/scala/leon/synthesis/ArithmeticNormalization.scala +++ b/src/main/scala/leon/purescala/TreeNormalizations.scala @@ -1,17 +1,40 @@ -package leon.synthesis +package leon +package purescala -import leon.purescala.Trees._ -import leon.purescala.TreeOps._ -import leon.purescala.Common._ +object TreeNormalizations { + import Common._ + import TypeTrees._ + import Definitions._ + import Trees._ + import TreeOps._ + import Extractors._ -object ArithmeticNormalization { + /* TODO: we should add CNF and DNF at least */ case class NonLinearExpressionException(msg: String) extends Exception //assume the function is an arithmetic expression, not a relation //return a normal form where the [t a1 ... an] where //expr = t + a1*x1 + ... + an*xn and xs = [x1 ... xn] - def apply(expr: Expr, xs: Array[Identifier]): Array[Expr] = { + //do not keep the evaluation order + def linearArithmeticForm(expr: Expr, xs: Array[Identifier]): Array[Expr] = { + + //assume the expr is a literal (mult of constants and variables) with degree one + def extractCoef(e: Expr): (Expr, Identifier) = { + var id: Option[Identifier] = None + var coef = 1 + + def rec(e: Expr): Unit = e match { + case IntLiteral(i) => coef = coef*i + case Variable(id2) => if(id.isEmpty) id = Some(id2) else throw NonLinearExpressionException("multiple variable") + case Times(e1, e2) => rec(e1); rec(e2) + } + + rec(e) + assert(!id.isEmpty) + (IntLiteral(coef), id.get) + } + def containsId(e: Expr, id: Identifier): Boolean = e match { case Times(e1, e2) => containsId(e1, id) || containsId(e2, id) @@ -43,29 +66,14 @@ object ArithmeticNormalization { res } - - //assume the expr is a literal (mult of constants and variables) with degree one - def extractCoef(e: Expr): (Expr, Identifier) = { - var id: Option[Identifier] = None - var coef = 1 - - def rec(e: Expr): Unit = e match { - case IntLiteral(i) => coef = coef*i - case Variable(id2) => if(id.isEmpty) id = Some(id2) else throw NonLinearExpressionException("multiple variable") - case Times(e1, e2) => rec(e1); rec(e2) - } - - rec(e) - assert(!id.isEmpty) - (IntLiteral(coef), id.get) - } - - //multiply two sums together and distribute in a bigger sum + //multiply two sums together and distribute in a larger sum + //do not keep the evaluation order def multiply(es1: Seq[Expr], es2: Seq[Expr]): Seq[Expr] = { es1.flatMap(e1 => es2.map(e2 => Times(e1, e2))) } - + //expand the expr in a sum of "atoms" + //do not keep the evaluation order def expand(expr: Expr): Seq[Expr] = expr match { case Plus(es1, es2) => expand(es1) ++ expand(es2) case Minus(e1, e2) => expand(e1) ++ expand(e2).map(Times(IntLiteral(-1), _): Expr) @@ -76,5 +84,4 @@ object ArithmeticNormalization { case err => throw NonLinearExpressionException("unexpected in expand: " + err) } - } diff --git a/src/main/scala/leon/synthesis/LinearEquations.scala b/src/main/scala/leon/synthesis/LinearEquations.scala index 8ff7d093d..9debb3a8d 100644 --- a/src/main/scala/leon/synthesis/LinearEquations.scala +++ b/src/main/scala/leon/synthesis/LinearEquations.scala @@ -1,6 +1,7 @@ package leon.synthesis import leon.purescala.Trees._ +import leon.purescala.TreeNormalizations.linearArithmeticForm import leon.purescala.TypeTrees._ import leon.purescala.Common._ import leon.Evaluator @@ -16,7 +17,7 @@ object LinearEquations { val t: Expr = normalizedEquation.head val coefsVars: List[Int] = normalizedEquation.tail.map{case IntLiteral(i) => i} val orderedParams: Array[Identifier] = as.toArray - val coefsParams: List[Int] = ArithmeticNormalization(t, orderedParams).map{case IntLiteral(i) => i}.toList + val coefsParams: List[Int] = linearArithmeticForm(t, orderedParams).map{case IntLiteral(i) => i}.toList //val coefsParams: List[Int] = if(coefsParams0.head == 0) coefsParams0.tail else coefsParams0 val d: Int = gcd((coefsParams ++ coefsVars).toSeq) @@ -83,7 +84,7 @@ object LinearEquations { val lhs = equation.left val rhs = equation.right val orderedXs = xs.toArray - val normalized: Array[Expr] = ArithmeticNormalization(Minus(lhs, rhs), orderedXs) + val normalized: Array[Expr] = linearArithmeticForm(Minus(lhs, rhs), orderedXs) val (pre, sols) = particularSolution(as, normalized.toList) (pre, orderedXs.zip(sols).toMap) } diff --git a/src/main/scala/leon/synthesis/rules/IntegerEquation.scala b/src/main/scala/leon/synthesis/rules/IntegerEquation.scala index 741b31bd4..e96cd54cf 100644 --- a/src/main/scala/leon/synthesis/rules/IntegerEquation.scala +++ b/src/main/scala/leon/synthesis/rules/IntegerEquation.scala @@ -6,6 +6,7 @@ import purescala.Common._ import purescala.Trees._ import purescala.Extractors._ import purescala.TreeOps._ +import purescala.TreeNormalizations._ import purescala.TypeTrees._ import purescala.Definitions._ import LinearEquations.elimVariable @@ -31,9 +32,9 @@ class IntegerEquation(synth: Synthesizer) extends Rule("Integer Equation", synth eqxs = problem.xs.toSet.intersect(vars).toList try { - optionNormalizedEq = Some(ArithmeticNormalization(Minus(eq.left, eq.right), eqxs.toArray).toList) + optionNormalizedEq = Some(linearArithmeticForm(Minus(eq.left, eq.right), eqxs.toArray).toList) } catch { - case ArithmeticNormalization.NonLinearExpressionException(_) => + case NonLinearExpressionException(_) => allOthers = allOthers :+ eq } } diff --git a/src/main/scala/leon/synthesis/rules/IntegerInequalities.scala b/src/main/scala/leon/synthesis/rules/IntegerInequalities.scala index 5aa459614..b2c767771 100644 --- a/src/main/scala/leon/synthesis/rules/IntegerInequalities.scala +++ b/src/main/scala/leon/synthesis/rules/IntegerInequalities.scala @@ -6,6 +6,7 @@ import purescala.Common._ import purescala.Trees._ import purescala.Extractors._ import purescala.TreeOps._ +import purescala.TreeNormalizations.linearArithmeticForm import purescala.TypeTrees._ import purescala.Definitions._ import LinearEquations.elimVariable @@ -35,7 +36,7 @@ class IntegerInequalities(synth: Synthesizer) extends Rule("Integer Inequalities val otherVars: List[Identifier] = problem.xs.filterNot(_ == processedVar) println("lhsSides: " + lhsSides) - val normalizedLhs: List[List[Expr]] = lhsSides.map(ArithmeticNormalization(_, Array(processedVar)).toList) + val normalizedLhs: List[List[Expr]] = lhsSides.map(linearArithmeticForm(_, Array(processedVar)).toList) println("normalized: " + normalizedLhs.mkString("\n")) var upperBounds: List[(Expr, Int)] = Nil // (t, c) means c*x <= t var lowerBounds: List[(Expr, Int)] = Nil // (t, c) means t <= c*x diff --git a/src/test/scala/leon/test/synthesis/ArithmeticNormalizationSuite.scala b/src/test/scala/leon/test/purescala/TreeNormalizationsTests.scala similarity index 81% rename from src/test/scala/leon/test/synthesis/ArithmeticNormalizationSuite.scala rename to src/test/scala/leon/test/purescala/TreeNormalizationsTests.scala index 562ea4e0c..18d6c7627 100644 --- a/src/test/scala/leon/test/synthesis/ArithmeticNormalizationSuite.scala +++ b/src/test/scala/leon/test/purescala/TreeNormalizationsTests.scala @@ -1,16 +1,16 @@ -package leon.test.synthesis +package leon.test.purescala -import org.scalatest.FunSuite - -import leon.Evaluator -import leon.purescala.Trees._ import leon.purescala.Common._ +import leon.purescala.Definitions._ +import leon.purescala.Trees._ +import leon.purescala.TreeOps._ +import leon.purescala.TreeNormalizations._ import leon.purescala.LikelyEq +import leon.SilentReporter -import leon.synthesis.ArithmeticNormalization._ - -class ArithmeticNormalizationSuite extends FunSuite { +import org.scalatest.FunSuite +class TreeNormalizationsTests extends FunSuite { def i(x: Int) = IntLiteral(x) val xId = FreshIdentifier("x") @@ -66,18 +66,17 @@ class ArithmeticNormalizationSuite extends FunSuite { checkSameExpr(toSum(expand(e4)), e4, xs) } - test("apply") { + test("linearArithmeticForm") { val xsOrder = Array(xId, yId) val e1 = Plus(Times(Plus(x, i(2)), i(3)), Times(i(4), y)) - checkSameExpr(coefToSum(apply(e1, xsOrder), Array(x, y)), e1, xs) + checkSameExpr(coefToSum(linearArithmeticForm(e1, xsOrder), Array(x, y)), e1, xs) val e2 = Plus(Times(Plus(x, i(2)), i(3)), Plus(Plus(a, Times(i(5), b)), Times(i(4), y))) - checkSameExpr(coefToSum(apply(e2, xsOrder), Array(x, y)), e2, xs ++ as) + checkSameExpr(coefToSum(linearArithmeticForm(e2, xsOrder), Array(x, y)), e2, xs ++ as) val e3 = Minus(Plus(x, i(3)), Plus(y, i(2))) - checkSameExpr(coefToSum(apply(e3, xsOrder), Array(x, y)), e3, xs) + checkSameExpr(coefToSum(linearArithmeticForm(e3, xsOrder), Array(x, y)), e3, xs) } - } -- GitLab