Skip to content
Snippets Groups Projects
Commit 97d4c17a authored by Régis Blanc's avatar Régis Blanc
Browse files

some more refactoring

parent 1e135d1e
No related branches found
No related tags found
No related merge requests found
......@@ -52,17 +52,17 @@ object TreeNormalizations {
Times(IntLiteral(totalCoef), Variable(id))
}
var expandedForm: Seq[Expr] = expand(expr)
var exprs: Seq[Expr] = expandedForm(expr)
val res: Array[Expr] = new Array(xs.size + 1)
xs.zipWithIndex.foreach{case (id, index) => {
val (terms, rests) = expandedForm.partition(containsId(_, id))
expandedForm = rests
val (terms, rests) = exprs.partition(containsId(_, id))
exprs = rests
val Times(coef, Variable(_)) = group(terms, id)
res(index+1) = coef
}}
res(0) = simplifyArithmetic(expandedForm.foldLeft[Expr](IntLiteral(0))(Plus(_, _)))
res(0) = simplifyArithmetic(exprs.foldLeft[Expr](IntLiteral(0))(Plus(_, _)))
res
}
......@@ -72,16 +72,16 @@ object TreeNormalizations {
es1.flatMap(e1 => es2.map(e2 => Times(e1, e2)))
}
//expand the expr in a sum of "atoms"
//expand the expr in a sum of "atoms", each atom being a product of literal and variable
//do not keep the evaluation order
def expand(expr: Expr): Seq[Expr] = expr match {
case Plus(es1, es2) => expand(es1) ++ expand(es2)
case Minus(e1, e2) => expand(e1) ++ expand(e2).map(Times(IntLiteral(-1), _): Expr)
case UMinus(e) => expand(e).map(Times(IntLiteral(-1), _): Expr)
case Times(es1, es2) => multiply(expand(es1), expand(es2))
case v@Variable(_) => Seq(v)
def expandedForm(expr: Expr): Seq[Expr] = expr match {
case Plus(es1, es2) => expandedForm(es1) ++ expandedForm(es2)
case Minus(e1, e2) => expandedForm(e1) ++ expandedForm(e2).map(Times(IntLiteral(-1), _): Expr)
case UMinus(e) => expandedForm(e).map(Times(IntLiteral(-1), _): Expr)
case Times(es1, es2) => multiply(expandedForm(es1), expandedForm(es2))
case v@Variable(_) if v.getType == Int32Type => Seq(v)
case n@IntLiteral(_) => Seq(n)
case err => throw NonLinearExpressionException("unexpected in expand: " + err)
case err => throw NonLinearExpressionException("unexpected in expandedForm: " + err)
}
}
......@@ -12,7 +12,7 @@ import purescala.Definitions._
import LinearEquations.elimVariable
class IntegerEquation(synth: Synthesizer) extends Rule("Integer Equation", synth, 300) {
def attemptToApplyOn(problem: Problem): RuleResult = if(problem.xs.isEmpty) RuleInapplicable else {
def attemptToApplyOn(problem: Problem): RuleResult = if(!problem.xs.exists(_.getType == Int32Type)) RuleInapplicable else {
val TopLevelAnds(exprs) = problem.phi
......
......@@ -2,6 +2,7 @@ package leon.test.purescala
import leon.purescala.Common._
import leon.purescala.Definitions._
import leon.purescala.TypeTrees._
import leon.purescala.Trees._
import leon.purescala.TreeOps._
import leon.purescala.TreeNormalizations._
......@@ -13,15 +14,15 @@ import org.scalatest.FunSuite
class TreeNormalizationsTests extends FunSuite {
def i(x: Int) = IntLiteral(x)
val xId = FreshIdentifier("x")
val xId = FreshIdentifier("x").setType(Int32Type)
val x = Variable(xId)
val yId = FreshIdentifier("y")
val yId = FreshIdentifier("y").setType(Int32Type)
val y = Variable(yId)
val xs = Set(xId, yId)
val aId = FreshIdentifier("a")
val aId = FreshIdentifier("a").setType(Int32Type)
val a = Variable(aId)
val bId = FreshIdentifier("b")
val bId = FreshIdentifier("b").setType(Int32Type)
val b = Variable(bId)
val as = Set(aId, bId)
......@@ -52,18 +53,18 @@ class TreeNormalizationsTests extends FunSuite {
checkSameExpr(Times(toSum(lhs2), toSum(rhs2)), toSum(multiply(lhs2, rhs2)), xs)
}
test("expand") {
test("expandedForm") {
val e1 = Times(Plus(x, i(2)), Plus(y, i(1)))
checkSameExpr(toSum(expand(e1)), e1, xs)
checkSameExpr(toSum(expandedForm(e1)), e1, xs)
val e2 = Times(Plus(x, Times(i(2), y)), Plus(Plus(x, y), i(1)))
checkSameExpr(toSum(expand(e2)), e2, xs)
checkSameExpr(toSum(expandedForm(e2)), e2, xs)
val e3 = Minus(Plus(x, Times(i(2), y)), Plus(Plus(x, y), i(1)))
checkSameExpr(toSum(expand(e3)), e3, xs)
checkSameExpr(toSum(expandedForm(e3)), e3, xs)
val e4 = UMinus(Plus(x, Times(i(2), y)))
checkSameExpr(toSum(expand(e4)), e4, xs)
checkSameExpr(toSum(expandedForm(e4)), e4, xs)
}
test("linearArithmeticForm") {
......
......@@ -3,6 +3,7 @@ package leon.test.purescala
import leon.purescala.Common._
import leon.purescala.Definitions._
import leon.purescala.Trees._
import leon.purescala.TypeTrees._
import leon.purescala.TreeOps._
import leon.purescala.LikelyEq
import leon.SilentReporter
......@@ -31,15 +32,15 @@ class TreeOpsTests extends FunSuite {
def i(x: Int) = IntLiteral(x)
val xId = FreshIdentifier("x")
val xId = FreshIdentifier("x").setType(Int32Type)
val x = Variable(xId)
val yId = FreshIdentifier("y")
val yId = FreshIdentifier("y").setType(Int32Type)
val y = Variable(yId)
val xs = Set(xId, yId)
val aId = FreshIdentifier("a")
val aId = FreshIdentifier("a").setType(Int32Type)
val a = Variable(aId)
val bId = FreshIdentifier("b")
val bId = FreshIdentifier("b").setType(Int32Type)
val b = Variable(bId)
val as = Set(aId, bId)
......
......@@ -4,6 +4,7 @@ import org.scalatest.FunSuite
import leon.Evaluator
import leon.purescala.Trees._
import leon.purescala.TypeTrees._
import leon.purescala.TreeOps._
import leon.purescala.Common._
import leon.purescala.LikelyEq
......@@ -14,16 +15,16 @@ class LinearEquationsSuite extends FunSuite {
def i(x: Int) = IntLiteral(x)
val xId = FreshIdentifier("x")
val xId = FreshIdentifier("x").setType(Int32Type)
val x = Variable(xId)
val yId = FreshIdentifier("y")
val yId = FreshIdentifier("y").setType(Int32Type)
val y = Variable(yId)
val zId = FreshIdentifier("z")
val zId = FreshIdentifier("z").setType(Int32Type)
val z = Variable(zId)
val aId = FreshIdentifier("a")
val aId = FreshIdentifier("a").setType(Int32Type)
val a = Variable(aId)
val bId = FreshIdentifier("b")
val bId = FreshIdentifier("b").setType(Int32Type)
val b = Variable(bId)
def toSum(es: Seq[Expr]) = es.reduceLeft(Plus(_, _))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment