From 97d4c17a7d8e4aaa33566c41fcabb911bcb8623e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9gis=20Blanc?= <regwblanc@gmail.com> Date: Wed, 21 Nov 2012 04:01:18 +0000 Subject: [PATCH] some more refactoring --- .../leon/purescala/TreeNormalizations.scala | 24 +++++++++---------- .../synthesis/rules/IntegerEquation.scala | 2 +- .../purescala/TreeNormalizationsTests.scala | 19 ++++++++------- .../leon/test/purescala/TreeOpsTests.scala | 9 +++---- .../test/synthesis/LinearEquationsSuite.scala | 11 +++++---- 5 files changed, 34 insertions(+), 31 deletions(-) diff --git a/src/main/scala/leon/purescala/TreeNormalizations.scala b/src/main/scala/leon/purescala/TreeNormalizations.scala index 0a805a8af..fb34f35ed 100644 --- a/src/main/scala/leon/purescala/TreeNormalizations.scala +++ b/src/main/scala/leon/purescala/TreeNormalizations.scala @@ -52,17 +52,17 @@ object TreeNormalizations { Times(IntLiteral(totalCoef), Variable(id)) } - var expandedForm: Seq[Expr] = expand(expr) + var exprs: Seq[Expr] = expandedForm(expr) val res: Array[Expr] = new Array(xs.size + 1) xs.zipWithIndex.foreach{case (id, index) => { - val (terms, rests) = expandedForm.partition(containsId(_, id)) - expandedForm = rests + val (terms, rests) = exprs.partition(containsId(_, id)) + exprs = rests val Times(coef, Variable(_)) = group(terms, id) res(index+1) = coef }} - res(0) = simplifyArithmetic(expandedForm.foldLeft[Expr](IntLiteral(0))(Plus(_, _))) + res(0) = simplifyArithmetic(exprs.foldLeft[Expr](IntLiteral(0))(Plus(_, _))) res } @@ -72,16 +72,16 @@ object TreeNormalizations { es1.flatMap(e1 => es2.map(e2 => Times(e1, e2))) } - //expand the expr in a sum of "atoms" + //expand the expr in a sum of "atoms", each atom being a product of literal and variable //do not keep the evaluation order - def expand(expr: Expr): Seq[Expr] = expr match { - case Plus(es1, es2) => expand(es1) ++ expand(es2) - case Minus(e1, e2) => expand(e1) ++ expand(e2).map(Times(IntLiteral(-1), _): Expr) - case UMinus(e) => expand(e).map(Times(IntLiteral(-1), _): Expr) - case Times(es1, es2) => multiply(expand(es1), expand(es2)) - case v@Variable(_) => Seq(v) + def expandedForm(expr: Expr): Seq[Expr] = expr match { + case Plus(es1, es2) => expandedForm(es1) ++ expandedForm(es2) + case Minus(e1, e2) => expandedForm(e1) ++ expandedForm(e2).map(Times(IntLiteral(-1), _): Expr) + case UMinus(e) => expandedForm(e).map(Times(IntLiteral(-1), _): Expr) + case Times(es1, es2) => multiply(expandedForm(es1), expandedForm(es2)) + case v@Variable(_) if v.getType == Int32Type => Seq(v) case n@IntLiteral(_) => Seq(n) - case err => throw NonLinearExpressionException("unexpected in expand: " + err) + case err => throw NonLinearExpressionException("unexpected in expandedForm: " + err) } } diff --git a/src/main/scala/leon/synthesis/rules/IntegerEquation.scala b/src/main/scala/leon/synthesis/rules/IntegerEquation.scala index e96cd54cf..ef830dc0b 100644 --- a/src/main/scala/leon/synthesis/rules/IntegerEquation.scala +++ b/src/main/scala/leon/synthesis/rules/IntegerEquation.scala @@ -12,7 +12,7 @@ import purescala.Definitions._ import LinearEquations.elimVariable class IntegerEquation(synth: Synthesizer) extends Rule("Integer Equation", synth, 300) { - def attemptToApplyOn(problem: Problem): RuleResult = if(problem.xs.isEmpty) RuleInapplicable else { + def attemptToApplyOn(problem: Problem): RuleResult = if(!problem.xs.exists(_.getType == Int32Type)) RuleInapplicable else { val TopLevelAnds(exprs) = problem.phi diff --git a/src/test/scala/leon/test/purescala/TreeNormalizationsTests.scala b/src/test/scala/leon/test/purescala/TreeNormalizationsTests.scala index 18d6c7627..679f2a588 100644 --- a/src/test/scala/leon/test/purescala/TreeNormalizationsTests.scala +++ b/src/test/scala/leon/test/purescala/TreeNormalizationsTests.scala @@ -2,6 +2,7 @@ package leon.test.purescala import leon.purescala.Common._ import leon.purescala.Definitions._ +import leon.purescala.TypeTrees._ import leon.purescala.Trees._ import leon.purescala.TreeOps._ import leon.purescala.TreeNormalizations._ @@ -13,15 +14,15 @@ import org.scalatest.FunSuite class TreeNormalizationsTests extends FunSuite { def i(x: Int) = IntLiteral(x) - val xId = FreshIdentifier("x") + val xId = FreshIdentifier("x").setType(Int32Type) val x = Variable(xId) - val yId = FreshIdentifier("y") + val yId = FreshIdentifier("y").setType(Int32Type) val y = Variable(yId) val xs = Set(xId, yId) - val aId = FreshIdentifier("a") + val aId = FreshIdentifier("a").setType(Int32Type) val a = Variable(aId) - val bId = FreshIdentifier("b") + val bId = FreshIdentifier("b").setType(Int32Type) val b = Variable(bId) val as = Set(aId, bId) @@ -52,18 +53,18 @@ class TreeNormalizationsTests extends FunSuite { checkSameExpr(Times(toSum(lhs2), toSum(rhs2)), toSum(multiply(lhs2, rhs2)), xs) } - test("expand") { + test("expandedForm") { val e1 = Times(Plus(x, i(2)), Plus(y, i(1))) - checkSameExpr(toSum(expand(e1)), e1, xs) + checkSameExpr(toSum(expandedForm(e1)), e1, xs) val e2 = Times(Plus(x, Times(i(2), y)), Plus(Plus(x, y), i(1))) - checkSameExpr(toSum(expand(e2)), e2, xs) + checkSameExpr(toSum(expandedForm(e2)), e2, xs) val e3 = Minus(Plus(x, Times(i(2), y)), Plus(Plus(x, y), i(1))) - checkSameExpr(toSum(expand(e3)), e3, xs) + checkSameExpr(toSum(expandedForm(e3)), e3, xs) val e4 = UMinus(Plus(x, Times(i(2), y))) - checkSameExpr(toSum(expand(e4)), e4, xs) + checkSameExpr(toSum(expandedForm(e4)), e4, xs) } test("linearArithmeticForm") { diff --git a/src/test/scala/leon/test/purescala/TreeOpsTests.scala b/src/test/scala/leon/test/purescala/TreeOpsTests.scala index 763550e44..ae3be5513 100644 --- a/src/test/scala/leon/test/purescala/TreeOpsTests.scala +++ b/src/test/scala/leon/test/purescala/TreeOpsTests.scala @@ -3,6 +3,7 @@ package leon.test.purescala import leon.purescala.Common._ import leon.purescala.Definitions._ import leon.purescala.Trees._ +import leon.purescala.TypeTrees._ import leon.purescala.TreeOps._ import leon.purescala.LikelyEq import leon.SilentReporter @@ -31,15 +32,15 @@ class TreeOpsTests extends FunSuite { def i(x: Int) = IntLiteral(x) - val xId = FreshIdentifier("x") + val xId = FreshIdentifier("x").setType(Int32Type) val x = Variable(xId) - val yId = FreshIdentifier("y") + val yId = FreshIdentifier("y").setType(Int32Type) val y = Variable(yId) val xs = Set(xId, yId) - val aId = FreshIdentifier("a") + val aId = FreshIdentifier("a").setType(Int32Type) val a = Variable(aId) - val bId = FreshIdentifier("b") + val bId = FreshIdentifier("b").setType(Int32Type) val b = Variable(bId) val as = Set(aId, bId) diff --git a/src/test/scala/leon/test/synthesis/LinearEquationsSuite.scala b/src/test/scala/leon/test/synthesis/LinearEquationsSuite.scala index a3fc2606e..961ae949f 100644 --- a/src/test/scala/leon/test/synthesis/LinearEquationsSuite.scala +++ b/src/test/scala/leon/test/synthesis/LinearEquationsSuite.scala @@ -4,6 +4,7 @@ import org.scalatest.FunSuite import leon.Evaluator import leon.purescala.Trees._ +import leon.purescala.TypeTrees._ import leon.purescala.TreeOps._ import leon.purescala.Common._ import leon.purescala.LikelyEq @@ -14,16 +15,16 @@ class LinearEquationsSuite extends FunSuite { def i(x: Int) = IntLiteral(x) - val xId = FreshIdentifier("x") + val xId = FreshIdentifier("x").setType(Int32Type) val x = Variable(xId) - val yId = FreshIdentifier("y") + val yId = FreshIdentifier("y").setType(Int32Type) val y = Variable(yId) - val zId = FreshIdentifier("z") + val zId = FreshIdentifier("z").setType(Int32Type) val z = Variable(zId) - val aId = FreshIdentifier("a") + val aId = FreshIdentifier("a").setType(Int32Type) val a = Variable(aId) - val bId = FreshIdentifier("b") + val bId = FreshIdentifier("b").setType(Int32Type) val b = Variable(bId) def toSum(es: Seq[Expr]) = es.reduceLeft(Plus(_, _)) -- GitLab