diff --git a/src/main/scala/leon/synthesis/ArithmeticNormalization.scala b/src/main/scala/leon/purescala/TreeNormalizations.scala similarity index 63% rename from src/main/scala/leon/synthesis/ArithmeticNormalization.scala rename to src/main/scala/leon/purescala/TreeNormalizations.scala index 6c148e2ccfd5066d6d99a9035ad392aecbb475ef..0a805a8aff024713966aad9e02ec5f508297b4a2 100644 --- a/src/main/scala/leon/synthesis/ArithmeticNormalization.scala +++ b/src/main/scala/leon/purescala/TreeNormalizations.scala @@ -1,17 +1,40 @@ -package leon.synthesis +package leon +package purescala -import leon.purescala.Trees._ -import leon.purescala.TreeOps._ -import leon.purescala.Common._ +object TreeNormalizations { + import Common._ + import TypeTrees._ + import Definitions._ + import Trees._ + import TreeOps._ + import Extractors._ -object ArithmeticNormalization { + /* TODO: we should add CNF and DNF at least */ case class NonLinearExpressionException(msg: String) extends Exception //assume the function is an arithmetic expression, not a relation //return a normal form where the [t a1 ... an] where //expr = t + a1*x1 + ... + an*xn and xs = [x1 ... xn] - def apply(expr: Expr, xs: Array[Identifier]): Array[Expr] = { + //do not keep the evaluation order + def linearArithmeticForm(expr: Expr, xs: Array[Identifier]): Array[Expr] = { + + //assume the expr is a literal (mult of constants and variables) with degree one + def extractCoef(e: Expr): (Expr, Identifier) = { + var id: Option[Identifier] = None + var coef = 1 + + def rec(e: Expr): Unit = e match { + case IntLiteral(i) => coef = coef*i + case Variable(id2) => if(id.isEmpty) id = Some(id2) else throw NonLinearExpressionException("multiple variable") + case Times(e1, e2) => rec(e1); rec(e2) + } + + rec(e) + assert(!id.isEmpty) + (IntLiteral(coef), id.get) + } + def containsId(e: Expr, id: Identifier): Boolean = e match { case Times(e1, e2) => containsId(e1, id) || containsId(e2, id) @@ -43,29 +66,14 @@ object ArithmeticNormalization { res } - - //assume the expr is a literal (mult of constants and variables) with degree one - def extractCoef(e: Expr): (Expr, Identifier) = { - var id: Option[Identifier] = None - var coef = 1 - - def rec(e: Expr): Unit = e match { - case IntLiteral(i) => coef = coef*i - case Variable(id2) => if(id.isEmpty) id = Some(id2) else throw NonLinearExpressionException("multiple variable") - case Times(e1, e2) => rec(e1); rec(e2) - } - - rec(e) - assert(!id.isEmpty) - (IntLiteral(coef), id.get) - } - - //multiply two sums together and distribute in a bigger sum + //multiply two sums together and distribute in a larger sum + //do not keep the evaluation order def multiply(es1: Seq[Expr], es2: Seq[Expr]): Seq[Expr] = { es1.flatMap(e1 => es2.map(e2 => Times(e1, e2))) } - + //expand the expr in a sum of "atoms" + //do not keep the evaluation order def expand(expr: Expr): Seq[Expr] = expr match { case Plus(es1, es2) => expand(es1) ++ expand(es2) case Minus(e1, e2) => expand(e1) ++ expand(e2).map(Times(IntLiteral(-1), _): Expr) @@ -76,5 +84,4 @@ object ArithmeticNormalization { case err => throw NonLinearExpressionException("unexpected in expand: " + err) } - } diff --git a/src/main/scala/leon/synthesis/LinearEquations.scala b/src/main/scala/leon/synthesis/LinearEquations.scala index 8ff7d093d093ead26e7a75eaee6e9793a1405daa..9debb3a8d340dcc8b9ba594cd242b8c434d122fc 100644 --- a/src/main/scala/leon/synthesis/LinearEquations.scala +++ b/src/main/scala/leon/synthesis/LinearEquations.scala @@ -1,6 +1,7 @@ package leon.synthesis import leon.purescala.Trees._ +import leon.purescala.TreeNormalizations.linearArithmeticForm import leon.purescala.TypeTrees._ import leon.purescala.Common._ import leon.Evaluator @@ -16,7 +17,7 @@ object LinearEquations { val t: Expr = normalizedEquation.head val coefsVars: List[Int] = normalizedEquation.tail.map{case IntLiteral(i) => i} val orderedParams: Array[Identifier] = as.toArray - val coefsParams: List[Int] = ArithmeticNormalization(t, orderedParams).map{case IntLiteral(i) => i}.toList + val coefsParams: List[Int] = linearArithmeticForm(t, orderedParams).map{case IntLiteral(i) => i}.toList //val coefsParams: List[Int] = if(coefsParams0.head == 0) coefsParams0.tail else coefsParams0 val d: Int = gcd((coefsParams ++ coefsVars).toSeq) @@ -83,7 +84,7 @@ object LinearEquations { val lhs = equation.left val rhs = equation.right val orderedXs = xs.toArray - val normalized: Array[Expr] = ArithmeticNormalization(Minus(lhs, rhs), orderedXs) + val normalized: Array[Expr] = linearArithmeticForm(Minus(lhs, rhs), orderedXs) val (pre, sols) = particularSolution(as, normalized.toList) (pre, orderedXs.zip(sols).toMap) } diff --git a/src/main/scala/leon/synthesis/rules/IntegerEquation.scala b/src/main/scala/leon/synthesis/rules/IntegerEquation.scala index 741b31bd458858948cde07aa2f96c662568d6673..e96cd54cf0125d1522bbca2b7d6e03bdced4f529 100644 --- a/src/main/scala/leon/synthesis/rules/IntegerEquation.scala +++ b/src/main/scala/leon/synthesis/rules/IntegerEquation.scala @@ -6,6 +6,7 @@ import purescala.Common._ import purescala.Trees._ import purescala.Extractors._ import purescala.TreeOps._ +import purescala.TreeNormalizations._ import purescala.TypeTrees._ import purescala.Definitions._ import LinearEquations.elimVariable @@ -31,9 +32,9 @@ class IntegerEquation(synth: Synthesizer) extends Rule("Integer Equation", synth eqxs = problem.xs.toSet.intersect(vars).toList try { - optionNormalizedEq = Some(ArithmeticNormalization(Minus(eq.left, eq.right), eqxs.toArray).toList) + optionNormalizedEq = Some(linearArithmeticForm(Minus(eq.left, eq.right), eqxs.toArray).toList) } catch { - case ArithmeticNormalization.NonLinearExpressionException(_) => + case NonLinearExpressionException(_) => allOthers = allOthers :+ eq } } diff --git a/src/main/scala/leon/synthesis/rules/IntegerInequalities.scala b/src/main/scala/leon/synthesis/rules/IntegerInequalities.scala index 5aa459614154cd4ec4e2daa4c5f6fbdb5ca47eec..b2c767771bd75ec3b634ff34107ecaf27f4081da 100644 --- a/src/main/scala/leon/synthesis/rules/IntegerInequalities.scala +++ b/src/main/scala/leon/synthesis/rules/IntegerInequalities.scala @@ -6,6 +6,7 @@ import purescala.Common._ import purescala.Trees._ import purescala.Extractors._ import purescala.TreeOps._ +import purescala.TreeNormalizations.linearArithmeticForm import purescala.TypeTrees._ import purescala.Definitions._ import LinearEquations.elimVariable @@ -35,7 +36,7 @@ class IntegerInequalities(synth: Synthesizer) extends Rule("Integer Inequalities val otherVars: List[Identifier] = problem.xs.filterNot(_ == processedVar) println("lhsSides: " + lhsSides) - val normalizedLhs: List[List[Expr]] = lhsSides.map(ArithmeticNormalization(_, Array(processedVar)).toList) + val normalizedLhs: List[List[Expr]] = lhsSides.map(linearArithmeticForm(_, Array(processedVar)).toList) println("normalized: " + normalizedLhs.mkString("\n")) var upperBounds: List[(Expr, Int)] = Nil // (t, c) means c*x <= t var lowerBounds: List[(Expr, Int)] = Nil // (t, c) means t <= c*x diff --git a/src/test/scala/leon/test/synthesis/ArithmeticNormalizationSuite.scala b/src/test/scala/leon/test/purescala/TreeNormalizationsTests.scala similarity index 81% rename from src/test/scala/leon/test/synthesis/ArithmeticNormalizationSuite.scala rename to src/test/scala/leon/test/purescala/TreeNormalizationsTests.scala index 562ea4e0ce9dadfe06c68a70f9c80bcaa05a9c92..18d6c76274af9768153c01222abd78052d0d494c 100644 --- a/src/test/scala/leon/test/synthesis/ArithmeticNormalizationSuite.scala +++ b/src/test/scala/leon/test/purescala/TreeNormalizationsTests.scala @@ -1,16 +1,16 @@ -package leon.test.synthesis +package leon.test.purescala -import org.scalatest.FunSuite - -import leon.Evaluator -import leon.purescala.Trees._ import leon.purescala.Common._ +import leon.purescala.Definitions._ +import leon.purescala.Trees._ +import leon.purescala.TreeOps._ +import leon.purescala.TreeNormalizations._ import leon.purescala.LikelyEq +import leon.SilentReporter -import leon.synthesis.ArithmeticNormalization._ - -class ArithmeticNormalizationSuite extends FunSuite { +import org.scalatest.FunSuite +class TreeNormalizationsTests extends FunSuite { def i(x: Int) = IntLiteral(x) val xId = FreshIdentifier("x") @@ -66,18 +66,17 @@ class ArithmeticNormalizationSuite extends FunSuite { checkSameExpr(toSum(expand(e4)), e4, xs) } - test("apply") { + test("linearArithmeticForm") { val xsOrder = Array(xId, yId) val e1 = Plus(Times(Plus(x, i(2)), i(3)), Times(i(4), y)) - checkSameExpr(coefToSum(apply(e1, xsOrder), Array(x, y)), e1, xs) + checkSameExpr(coefToSum(linearArithmeticForm(e1, xsOrder), Array(x, y)), e1, xs) val e2 = Plus(Times(Plus(x, i(2)), i(3)), Plus(Plus(a, Times(i(5), b)), Times(i(4), y))) - checkSameExpr(coefToSum(apply(e2, xsOrder), Array(x, y)), e2, xs ++ as) + checkSameExpr(coefToSum(linearArithmeticForm(e2, xsOrder), Array(x, y)), e2, xs ++ as) val e3 = Minus(Plus(x, i(3)), Plus(y, i(2))) - checkSameExpr(coefToSum(apply(e3, xsOrder), Array(x, y)), e3, xs) + checkSameExpr(coefToSum(linearArithmeticForm(e3, xsOrder), Array(x, y)), e3, xs) } - }