diff --git a/src/main/scala/leon/synthesis/LikelyEq.scala b/src/main/scala/leon/purescala/LikelyEq.scala similarity index 92% rename from src/main/scala/leon/synthesis/LikelyEq.scala rename to src/main/scala/leon/purescala/LikelyEq.scala index 35d3cdfc8ddf52e959a9467a9b71396bbb976c84..71177603bdf77e5b1f126b187968d14005d8be9b 100644 --- a/src/main/scala/leon/synthesis/LikelyEq.scala +++ b/src/main/scala/leon/purescala/LikelyEq.scala @@ -1,4 +1,4 @@ -package leon.synthesis +package leon.purescala import leon.Evaluator._ import leon.purescala.Trees._ @@ -6,7 +6,10 @@ import leon.purescala.TreeOps.replace import leon.purescala.Common._ /* - * determine if two expressions over arithmetic variables are likely to be the same + * Determine if two expressions over arithmetic variables are likely to be the same + * + * This is a probabilistic based approach, it does not rely on any external solver and can + * only prove the non equality of two expressions. */ object LikelyEq { diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala index 37939efa033bd85ca39e6b02ea14643aba328e7b..08f4c0f060aab8cbb8f7fc6515613ac8537e10ef 100644 --- a/src/main/scala/leon/purescala/TreeOps.scala +++ b/src/main/scala/leon/purescala/TreeOps.scala @@ -1359,4 +1359,74 @@ object TreeOps { replace(vars.map(id => Variable(id) -> valuator(id)).toMap, expr) } + //simple, local simplifications on arithmetic + //you should not assume anything smarter than some constant folding and simple cancelation + //to avoid infinite cycle we only apply simplification that reduce the size of the tree + //The only guarentee from this function is to not augment the size of the expression and to be sound + //(note that an identity function would meet this specification) + def simplifyArithmetic(expr: Expr): Expr = { + def simplify0(expr: Expr): Expr = expr match { + case Plus(IntLiteral(i1), IntLiteral(i2)) => IntLiteral(i1 + i2) + case Plus(IntLiteral(0), e) => e + case Plus(e, IntLiteral(0)) => e + case Plus(e1, UMinus(e2)) => Minus(e1, e2) + + case Minus(e, IntLiteral(0)) => e + case Minus(IntLiteral(0), e) => UMinus(e) + case Minus(IntLiteral(i1), IntLiteral(i2)) => IntLiteral(i1 - i2) + case Minus(e1, UMinus(e2)) => Plus(e1, e2) + case Minus(e1, Minus(UMinus(e2), e3)) => Plus(e1, Plus(e2, e3)) + + case UMinus(IntLiteral(x)) => IntLiteral(-x) + case UMinus(UMinus(x)) => x + case UMinus(Plus(UMinus(e1), e2)) => Plus(e1, UMinus(e2)) + case UMinus(Minus(e1, e2)) => Minus(e2, e1) + + case Times(IntLiteral(i1), IntLiteral(i2)) => IntLiteral(i1 * i2) + case Times(IntLiteral(1), e) => e + case Times(IntLiteral(-1), e) => UMinus(e) + case Times(e, IntLiteral(1)) => e + case Times(IntLiteral(0), _) => IntLiteral(0) + case Times(_, IntLiteral(0)) => IntLiteral(0) + case Times(IntLiteral(i1), Times(IntLiteral(i2), t)) => Times(IntLiteral(i1*i2), t) + case Times(IntLiteral(i1), Times(t, IntLiteral(i2))) => Times(IntLiteral(i1*i2), t) + case Times(IntLiteral(i), UMinus(e)) => Times(IntLiteral(-i), e) + case Times(UMinus(e), IntLiteral(i)) => Times(e, IntLiteral(-i)) + case Times(IntLiteral(i1), Division(e, IntLiteral(i2))) if i2 != 0 && i1 % i2 == 0 => Times(IntLiteral(i1/i2), e) + case Times(IntLiteral(i1), Plus(Division(e1, IntLiteral(i2)), e2)) if i2 != 0 && i1 % i2 == 0 => Times(IntLiteral(i1/i2), Plus(e1, e2)) + + case Division(IntLiteral(i1), IntLiteral(i2)) if i2 != 0 => IntLiteral(i1 / i2) + case Division(e, IntLiteral(1)) => e + + //here we put more expensive rules + case Minus(e1, e2) if e1 == e2 => IntLiteral(0) + case e => e + } + def fix[A](f: (A) => A)(a: A): A = { + val na = f(a) + if(a == na) a else fix(f)(na) + } + val res = fix(simplePostTransform(simplify0))(expr) + res + } + + //Simplify the expression, applying all the simplify for various theories + //Maybe it would be a good design decision to not have any simplify calling + //an underlying solver, to somehow keep it light and become a function we call often + def simplify(expr: Expr): Expr = simplifyArithmetic(expr) + + //If the formula consist of some top level AND, find a top level + //Equals and extract it, return the remaining formula as well + def extractEquals(expr: Expr): (Option[Equals], Expr) = expr match { + case And(es) => + // OK now I'm just messing with you. + val (r, nes) = es.foldLeft[(Option[Equals],Seq[Expr])]((None, Seq())) { + case ((None, nes), eq @ Equals(_,_)) => (Some(eq), nes) + case ((o, nes), e) => (o, e +: nes) + } + (r, And(nes.reverse)) + + case e => (None, e) + } + } diff --git a/src/main/scala/leon/synthesis/ArithmeticNormalization.scala b/src/main/scala/leon/synthesis/ArithmeticNormalization.scala index 1ec71f0fe7d65be64b2ace4aa6208288034eb472..f81311ffdc16388e64d5e7eb2468646f234df728 100644 --- a/src/main/scala/leon/synthesis/ArithmeticNormalization.scala +++ b/src/main/scala/leon/synthesis/ArithmeticNormalization.scala @@ -43,7 +43,7 @@ object ArithmeticNormalization { res(index+1) = coef }} - res(0) = simplify(expandedForm.foldLeft[Expr](IntLiteral(0))(Plus(_, _))) + res(0) = simplifyArithmetic(expandedForm.foldLeft[Expr](IntLiteral(0))(Plus(_, _))) res } @@ -80,66 +80,5 @@ object ArithmeticNormalization { case err => throw NonLinearExpressionException("unexpected in expand: " + err) } - //simple, local simplifications - //you should not assume anything smarter than some constant folding and simple cancelation - //to avoid infinite cycle we only apply simplification that reduce the size of the tree - def simplify(expr: Expr): Expr = { - def simplify0(expr: Expr): Expr = expr match { - case Plus(IntLiteral(i1), IntLiteral(i2)) => IntLiteral(i1 + i2) - case Plus(IntLiteral(0), e) => e - case Plus(e, IntLiteral(0)) => e - case Plus(e1, UMinus(e2)) => Minus(e1, e2) - - case Minus(e, IntLiteral(0)) => e - case Minus(IntLiteral(0), e) => UMinus(e) - case Minus(IntLiteral(i1), IntLiteral(i2)) => IntLiteral(i1 - i2) - case Minus(e1, UMinus(e2)) => Plus(e1, e2) - case Minus(e1, Minus(UMinus(e2), e3)) => Plus(e1, Plus(e2, e3)) - - case UMinus(IntLiteral(x)) => IntLiteral(-x) - case UMinus(UMinus(x)) => x - case UMinus(Plus(UMinus(e1), e2)) => Plus(e1, UMinus(e2)) - case UMinus(Minus(e1, e2)) => Minus(e2, e1) - - case Times(IntLiteral(i1), IntLiteral(i2)) => IntLiteral(i1 * i2) - case Times(IntLiteral(1), e) => e - case Times(IntLiteral(-1), e) => UMinus(e) - case Times(e, IntLiteral(1)) => e - case Times(IntLiteral(0), _) => IntLiteral(0) - case Times(_, IntLiteral(0)) => IntLiteral(0) - case Times(IntLiteral(i1), Times(IntLiteral(i2), t)) => Times(IntLiteral(i1*i2), t) - case Times(IntLiteral(i1), Times(t, IntLiteral(i2))) => Times(IntLiteral(i1*i2), t) - case Times(IntLiteral(i), UMinus(e)) => Times(IntLiteral(-i), e) - case Times(UMinus(e), IntLiteral(i)) => Times(e, IntLiteral(-i)) - case Times(IntLiteral(i1), Division(e, IntLiteral(i2))) if i2 != 0 && i1 % i2 == 0 => Times(IntLiteral(i1/i2), e) - case Times(IntLiteral(i1), Plus(Division(e1, IntLiteral(i2)), e2)) if i2 != 0 && i1 % i2 == 0 => Times(IntLiteral(i1/i2), Plus(e1, e2)) - - case Division(IntLiteral(i1), IntLiteral(i2)) => IntLiteral(i1 / i2) - case Division(e, IntLiteral(1)) => e - - //here we put more expensive rules - case Minus(e1, e2) if e1 == e2 => IntLiteral(0) - case e => e - } - def fix[A](f: (A) => A)(a: A): A = { - val na = f(a) - if(a == na) a else fix(f)(na) - } - val res = fix(simplePostTransform(simplify0))(expr) - res - } - // Assume the formula consist only of top level AND, find a top level - // Equals and extract it, return the remaining formula as well - def extractEquals(expr: Expr): (Option[Equals], Expr) = expr match { - case And(es) => - // OK now I'm just messing with you. - val (r, nes) = es.foldLeft[(Option[Equals],Seq[Expr])]((None, Seq())) { - case ((None, nes), eq @ Equals(_,_)) => (Some(eq), nes) - case ((o, nes), e) => (o, e +: nes) - } - (r, And(nes.reverse)) - - case e => (None, e) - } } diff --git a/src/main/scala/leon/synthesis/rules/IntegerEquation.scala b/src/main/scala/leon/synthesis/rules/IntegerEquation.scala index 6b94025b6b6609a219b80dc59f54d43cbdf16094..07a60c02b34e4f15c3bd3d44d044f7f5ab7cef96 100644 --- a/src/main/scala/leon/synthesis/rules/IntegerEquation.scala +++ b/src/main/scala/leon/synthesis/rules/IntegerEquation.scala @@ -9,7 +9,6 @@ import purescala.TreeOps._ import purescala.TypeTrees._ import purescala.Definitions._ import LinearEquations.elimVariable -import ArithmeticNormalization.simplify class IntegerEquation(synth: Synthesizer) extends Rule("Integer Equation", synth, 300) { def attemptToApplyOn(problem: Problem): RuleResult = { @@ -69,8 +68,8 @@ class IntegerEquation(synth: Synthesizer) extends Rule("Integer Equation", synth case c => c } - val eqSubstMap: Map[Expr, Expr] = neqxs.zip(eqWitness).map{case (id, e) => (Variable(id), simplify(e))}.toMap - val freshFormula0 = simplify(replace(eqSubstMap, And(allOthers))) + val eqSubstMap: Map[Expr, Expr] = neqxs.zip(eqWitness).map{case (id, e) => (Variable(id), simplifyArithmetic(e))}.toMap + val freshFormula0 = simplifyArithmetic(replace(eqSubstMap, And(allOthers))) var freshInputVariables: List[Identifier] = Nil var equivalenceConstraints: Map[Expr, Expr] = Map() @@ -98,7 +97,7 @@ class IntegerEquation(synth: Synthesizer) extends Rule("Integer Equation", synth val id2res: Map[Expr, Expr] = freshsubxs.zip(subproblemxs).map{case (id1, id2) => (Variable(id1), Variable(id2))}.toMap ++ neqxs.map(id => (Variable(id), eqSubstMap(Variable(id)))).toMap - Solution(And(eqPre, freshPre), defs, simplify(simplifyLets(LetTuple(subproblemxs, freshTerm, replace(id2res, Tuple(problem.xs.map(Variable(_)))))))) + Solution(And(eqPre, freshPre), defs, simplifyArithmetic(simplifyLets(LetTuple(subproblemxs, freshTerm, replace(id2res, Tuple(problem.xs.map(Variable(_)))))))) } case _ => Solution.none } diff --git a/src/main/scala/leon/synthesis/rules/IntegerInequalities.scala b/src/main/scala/leon/synthesis/rules/IntegerInequalities.scala index 5f70d9ac5259a497560d24c63211e8cbb064644f..a4ff5fc7a8ddf1675b5d5bc9471a84ca48fae0de 100644 --- a/src/main/scala/leon/synthesis/rules/IntegerInequalities.scala +++ b/src/main/scala/leon/synthesis/rules/IntegerInequalities.scala @@ -9,7 +9,6 @@ import purescala.TreeOps._ import purescala.TypeTrees._ import purescala.Definitions._ import LinearEquations.elimVariable -import ArithmeticNormalization.simplify class IntegerInequalities(synth: Synthesizer) extends Rule("Integer Inequalities", synth, 300) { def attemptToApplyOn(problem: Problem): RuleResult = { @@ -39,7 +38,7 @@ class IntegerInequalities(synth: Synthesizer) extends Rule("Integer Inequalities var lowerBounds: List[(Expr, Int)] = Nil // (t, c) means t <= c*x normalizedLhs.foreach{ case List(t, IntLiteral(i)) => - if(i > 0) upperBounds ::= (simplify(UMinus(t)), i) + if(i > 0) upperBounds ::= (simplifyArithmetic(UMinus(t)), i) else if(i < 0) lowerBounds ::= (simplify(t), -i) else /*if (i == 0)*/ exprNotUsed ::= LessEquals(t, IntLiteral(0)) case err => sys.error("unexpected from normal form: " + err) diff --git a/src/test/scala/leon/test/PureScalaPrograms.scala b/src/test/scala/leon/test/PureScalaPrograms.scala deleted file mode 100644 index 92fbb68ebad3e674fec6c58b735ce7410421ab9f..0000000000000000000000000000000000000000 --- a/src/test/scala/leon/test/PureScalaPrograms.scala +++ /dev/null @@ -1,84 +0,0 @@ -package leon -package test - -import leon.verification.{AnalysisPhase,VerificationReport} - -import org.scalatest.FunSuite - -import java.io.File - -class PureScalaPrograms extends FunSuite { - private var counter : Int = 0 - private def nextInt() : Int = { - counter += 1 - counter - } - private case class Output(report : VerificationReport, reporter : Reporter) - - private def mkPipeline : Pipeline[List[String],VerificationReport] = - leon.plugin.ExtractionPhase andThen leon.verification.AnalysisPhase - - private def mkTest(file : File)(block: Output=>Unit) = { - val fullName = file.getPath() - val start = fullName.indexOf("regression") - val displayName = if(start != -1) { - fullName.substring(start, fullName.length) - } else { - fullName - } - - test("PureScala program %3d: [%s]".format(nextInt(), displayName)) { - assert(file.exists && file.isFile && file.canRead, - "Benchmark [%s] is not a readable file".format(displayName)) - - val ctx = LeonContext( - settings = Settings( - synthesis = false, - xlang = false, - verify = true - ), - files = List(file), - reporter = new SilentReporter - ) - - val pipeline = mkPipeline - - val report = pipeline.run(ctx)("--timeout=2" :: file.getPath :: Nil) - - block(Output(report, ctx.reporter)) - } - } - - private def forEachFileIn(dirName : String)(block : Output=>Unit) { - import scala.collection.JavaConversions._ - - val dir = this.getClass.getClassLoader.getResource(dirName) - - if(dir == null || dir.getProtocol != "file") { - assert(false, "Tests have to be run from within `sbt`, for otherwise " + - "the test files will be harder to access (and we dislike that).") - } - - for(f <- (new File(dir.toURI())).listFiles() if f.getPath().endsWith(".scala")) { - mkTest(f)(block) - } - } - - forEachFileIn("regression/verification/purescala/valid") { output => - val Output(report, reporter) = output - assert(report.totalConditions === report.totalValid, - "All verification conditions should be valid.") - assert(reporter.errorCount === 0) - assert(reporter.warningCount === 0) - } - - forEachFileIn("regression/verification/purescala/invalid") { output => - val Output(report, reporter) = output - assert(report.totalInvalid > 0, - "There should be at least one invalid verification condition.") - assert(report.totalUnknown === 0, - "There should not be unknown verification conditions.") - assert(reporter.errorCount >= report.totalInvalid) - assert(reporter.warningCount === 0) - } -} diff --git a/src/test/scala/leon/test/synthesis/LikelyEqSuite.scala b/src/test/scala/leon/test/purescala/LikelyEqSuite.scala similarity index 95% rename from src/test/scala/leon/test/synthesis/LikelyEqSuite.scala rename to src/test/scala/leon/test/purescala/LikelyEqSuite.scala index 9f911ae2aaa51106cc96a178bac0d3a9fd35e599..98cf7ba423920aae8c23021f3a549a7a03ce859a 100644 --- a/src/test/scala/leon/test/synthesis/LikelyEqSuite.scala +++ b/src/test/scala/leon/test/purescala/LikelyEqSuite.scala @@ -1,4 +1,4 @@ -package leon.test.synthesis +package leon.test.purescala import org.scalatest.FunSuite @@ -6,7 +6,7 @@ import leon.Evaluator import leon.purescala.Trees._ import leon.purescala.Common._ -import leon.synthesis.LikelyEq +import leon.purescala.LikelyEq class LikelyEqSuite extends FunSuite { diff --git a/src/test/scala/leon/test/purescala/TreeOpsTests.scala b/src/test/scala/leon/test/purescala/TreeOpsTests.scala index a729f00d069cae5a8898a573934f2970f9694837..87ae844ec3827ea5838a2a44491a1553c2afff71 100644 --- a/src/test/scala/leon/test/purescala/TreeOpsTests.scala +++ b/src/test/scala/leon/test/purescala/TreeOpsTests.scala @@ -4,6 +4,7 @@ import leon.purescala.Common._ import leon.purescala.Definitions._ import leon.purescala.Trees._ import leon.purescala.TreeOps._ +import leon.purescala.LikelyEq import leon.SilentReporter import org.scalatest.FunSuite @@ -26,4 +27,69 @@ class TreeOpsTests extends FunSuite { assert(true) } + + + def i(x: Int) = IntLiteral(x) + + val xId = FreshIdentifier("x") + val x = Variable(xId) + val yId = FreshIdentifier("y") + val y = Variable(yId) + val xs = Set(xId, yId) + + val aId = FreshIdentifier("a") + val a = Variable(aId) + val bId = FreshIdentifier("b") + val b = Variable(bId) + val as = Set(aId, bId) + + def checkSameExpr(e1: Expr, e2: Expr, vs: Set[Identifier]) { + assert( //this outer assert should not be needed because of the nested one + LikelyEq(e1, e2, vs, BooleanLiteral(true), (e1, e2) => {assert(e1 === e2); true}) + ) + } + + test("simplifyArithmetic") { + val e1 = Plus(IntLiteral(3), IntLiteral(2)) + checkSameExpr(e1, simplify(e1), Set()) + val e2 = Plus(x, Plus(IntLiteral(3), IntLiteral(2))) + checkSameExpr(e2, simplify(e2), Set(xId)) + + val e3 = Minus(IntLiteral(3), IntLiteral(2)) + checkSameExpr(e3, simplify(e3), Set()) + val e4 = Plus(x, Minus(IntLiteral(3), IntLiteral(2))) + checkSameExpr(e4, simplify(e4), Set(xId)) + val e5 = Plus(x, Minus(x, IntLiteral(2))) + checkSameExpr(e5, simplify(e5), Set(xId)) + } + + + test("extractEquals") { + val eq = Equals(a, b) + val lt1 = LessThan(a, b) + val lt2 = LessThan(b, a) + val lt3 = LessThan(x, y) + + val f1 = And(Seq(eq, lt1, lt2, lt3)) + val (eq1, r1) = extractEquals(f1) + assert(eq1 != None) + assert(eq1.get === eq) + assert(extractEquals(r1)._1 === None) + + val f2 = And(Seq(lt1, lt2, eq, lt3)) + val (eq2, r2) = extractEquals(f2) + assert(eq2 != None) + assert(eq2.get === eq) + assert(extractEquals(r2)._1 === None) + + val f3 = And(Seq(lt1, eq, lt2, lt3, eq)) + val (eq3, r3) = extractEquals(f3) + assert(eq3 != None) + assert(eq3.get === eq) + val (eq4, r4) = extractEquals(r3) + assert(eq4 != None) + assert(eq4.get === eq) + assert(extractEquals(r4)._1 === None) + + } } diff --git a/src/test/scala/leon/test/synthesis/ArithmeticNormalizationSuite.scala b/src/test/scala/leon/test/synthesis/ArithmeticNormalizationSuite.scala index 992e668f7c70ce82dd372f5fb0bb36c721fcfcd1..562ea4e0ce9dadfe06c68a70f9c80bcaa05a9c92 100644 --- a/src/test/scala/leon/test/synthesis/ArithmeticNormalizationSuite.scala +++ b/src/test/scala/leon/test/synthesis/ArithmeticNormalizationSuite.scala @@ -5,9 +5,9 @@ import org.scalatest.FunSuite import leon.Evaluator import leon.purescala.Trees._ import leon.purescala.Common._ +import leon.purescala.LikelyEq import leon.synthesis.ArithmeticNormalization._ -import leon.synthesis.LikelyEq class ArithmeticNormalizationSuite extends FunSuite { @@ -25,7 +25,6 @@ class ArithmeticNormalizationSuite extends FunSuite { val b = Variable(bId) val as = Set(aId, bId) - val allMaps: Seq[Map[Identifier, Expr]] = (-20 to 20).flatMap(xVal => (-20 to 20).map(yVal => Map(xId-> i(xVal), yId -> i(yVal)))) def checkSameExpr(e1: Expr, e2: Expr, vs: Set[Identifier]) { assert( //this outer assert should not be needed because of the nested one @@ -80,50 +79,5 @@ class ArithmeticNormalizationSuite extends FunSuite { checkSameExpr(coefToSum(apply(e3, xsOrder), Array(x, y)), e3, xs) } - - - test("simplify") { - val e1 = Plus(IntLiteral(3), IntLiteral(2)) - checkSameExpr(e1, simplify(e1), Set()) - val e2 = Plus(x, Plus(IntLiteral(3), IntLiteral(2))) - checkSameExpr(e2, simplify(e2), Set(xId)) - - val e3 = Minus(IntLiteral(3), IntLiteral(2)) - checkSameExpr(e3, simplify(e3), Set()) - val e4 = Plus(x, Minus(IntLiteral(3), IntLiteral(2))) - checkSameExpr(e4, simplify(e4), Set(xId)) - val e5 = Plus(x, Minus(x, IntLiteral(2))) - checkSameExpr(e5, simplify(e5), Set(xId)) - } - - - test("extractEquals") { - val eq = Equals(a, b) - val lt1 = LessThan(a, b) - val lt2 = LessThan(b, a) - val lt3 = LessThan(x, y) - - val f1 = And(Seq(eq, lt1, lt2, lt3)) - val (eq1, r1) = extractEquals(f1) - assert(eq1 != None) - assert(eq1.get === eq) - assert(extractEquals(r1)._1 === None) - - val f2 = And(Seq(lt1, lt2, eq, lt3)) - val (eq2, r2) = extractEquals(f2) - assert(eq2 != None) - assert(eq2.get === eq) - assert(extractEquals(r2)._1 === None) - - val f3 = And(Seq(lt1, eq, lt2, lt3, eq)) - val (eq3, r3) = extractEquals(f3) - assert(eq3 != None) - assert(eq3.get === eq) - val (eq4, r4) = extractEquals(r3) - assert(eq4 != None) - assert(eq4.get === eq) - assert(extractEquals(r4)._1 === None) - - } } diff --git a/src/test/scala/leon/test/synthesis/LinearEquationsSuite.scala b/src/test/scala/leon/test/synthesis/LinearEquationsSuite.scala index d356e0448e7e5cba3d1414937a8a8059bde3a559..04a01cb7e6fe86036027c0079dc4c478ec310570 100644 --- a/src/test/scala/leon/test/synthesis/LinearEquationsSuite.scala +++ b/src/test/scala/leon/test/synthesis/LinearEquationsSuite.scala @@ -6,10 +6,9 @@ import leon.Evaluator import leon.purescala.Trees._ import leon.purescala.TreeOps._ import leon.purescala.Common._ +import leon.purescala.LikelyEq import leon.synthesis.LinearEquations._ -import leon.synthesis.LikelyEq -import leon.synthesis.ArithmeticNormalization.simplify class LinearEquationsSuite extends FunSuite {