diff --git a/src/main/scala/leon/purescala/SimplifierWithPaths.scala b/src/main/scala/leon/purescala/SimplifierWithPaths.scala index 406e4c46b546c7406c7a09531e65928297cdab63..481cee54b5717a4ec836486300bbd6817c88614f 100644 --- a/src/main/scala/leon/purescala/SimplifierWithPaths.scala +++ b/src/main/scala/leon/purescala/SimplifierWithPaths.scala @@ -36,6 +36,24 @@ class SimplifierWithPaths(sf: SolverFactory[Solver]) extends TransformerWithPC { case _ : Exception => false } + def valid(e : Expr) : Boolean = try { + solver.solveVALID(e) match { + case Some(true) => true + case _ => false + } + } catch { + case _ : Exception => false + } + + def sat(e : Expr) : Boolean = try { + solver.solveSAT(e) match { + case (Some(false),_) => false + case _ => true + } + } catch { + case _ : Exception => true + } + protected override def rec(e: Expr, path: C) = e match { case IfExpr(cond, thenn, elze) => super.rec(e, path) match { @@ -60,40 +78,45 @@ class SimplifierWithPaths(sf: SolverFactory[Solver]) extends TransformerWithPC { BooleanLiteral(false).copiedFrom(e) } - case MatchExpr(scrut, cases) => + case me@MatchExpr(scrut, cases) => val rs = rec(scrut, path) var stillPossible = true + var pcSoFar = path - if (cases.exists(_.hasGuard)) { - // unsupported for now - e - } else { - val newCases = cases.flatMap { c => - val patternExpr = conditionForPattern(rs, c.pattern, includeBinders = true) - - if (stillPossible && !contradictedBy(patternExpr, path)) { - - if (impliedBy(patternExpr, path)) { - stillPossible = false - } - - c match { - case SimpleCase(p, rhs) => - Some(SimpleCase(p, rec(rhs, patternExpr +: path)).copiedFrom(c)) - case GuardedCase(_, _, _) => - sys.error("woot.") - } - } else { - None + val conds = matchCasePathConditions(me, path) + + val newCases = cases.zip(conds).flatMap { case (cs, cond) => + if (stillPossible && sat(And(cond))) { + + if (valid(And(cond))) { + stillPossible = false } - } - if (newCases.nonEmpty) { - MatchExpr(rs, newCases).copiedFrom(e) + + Some((cs match { + case SimpleCase(p, rhs) => + SimpleCase(p, rec(rhs, cond)) + case GuardedCase(p, g, rhs) => + // FIXME: This is quite a dirty hack. We just know matchCasePathConditions + // returns the current guard as the last element. + // We don't include it in the path condition when we recurse into itself. + val condWithoutGuard = try { cond.init } catch { case _ : UnsupportedOperationException => List() } + val newGuard = rec(g, condWithoutGuard) + if (valid(newGuard)) + SimpleCase(p, rec(rhs,cond)) + else + GuardedCase(p, newGuard, rec(rhs, cond)) + }).copiedFrom(cs)) } else { - Error("Unreachable code").copiedFrom(e) + None } } + newCases match { + case List() => Error("Unreachable code").copiedFrom(e) + case List(theCase) => + replaceFromIDs(mapForPattern(scrut, theCase.pattern), theCase.rhs) + case _ => MatchExpr(rs, newCases).copiedFrom(e) + } case Or(es) => var soFar = path diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala index 2e9b0a4204dbf609dc1a80a9f700370a9716bfee..85c68609566568096344eee1c8bcc2bbd75f10d6 100644 --- a/src/main/scala/leon/purescala/TreeOps.scala +++ b/src/main/scala/leon/purescala/TreeOps.scala @@ -851,6 +851,25 @@ object TreeOps { postMap(rewritePM)(expr) } + def matchCasePathConditions(m : MatchExpr, pathCond: List[Expr]) : Seq[List[Expr]] = m match { + case MatchExpr(scrut, cases) => + var pcSoFar = pathCond + for (c <- cases) yield { + + val g = c.theGuard getOrElse BooleanLiteral(true) + val cond = conditionForPattern(scrut, c.pattern, includeBinders = true) + val localCond = pcSoFar :+ cond :+ g + + // These contain no binders defined in this MatchCase + val condSafe = conditionForPattern(scrut, c.pattern) + val gSafe = replaceFromIDs(mapForPattern(scrut, c.pattern),g) + pcSoFar ::= Not(And(condSafe,gSafe)) + + localCond + } + } + + /** * Rewrites all map accesses with additional error conditions. */ diff --git a/src/test/resources/regression/transformations/SimplifyPaths.scala b/src/test/resources/regression/transformations/SimplifyPaths.scala new file mode 100644 index 0000000000000000000000000000000000000000..245e06f9435a5f7fe630d2fcba4f119d148c3da1 --- /dev/null +++ b/src/test/resources/regression/transformations/SimplifyPaths.scala @@ -0,0 +1,84 @@ +/* Copyright 2009-2014 EPFL, Lausanne */ + +import leon.annotation._ +import leon.lang._ +import leon.collection._ + +object Transform { + + + def input01(a: Int): Int = { + if(true) 1 else 2 + } + + def output01(a: Int): Int = { + 1 + } + + def input02(a: Boolean, b : Boolean, c : Boolean): Boolean = { + (a && !b) == !(!a || b) match { + case x@true => x && c + case false => !c + } + } + + def output02(a: Boolean, b : Boolean, c : Boolean): Boolean = { + c + } + + def input03(a: List[Int]): Int = { + a match { + case Nil() => 0 + //case n if n.isEmpty => 1 + case Cons(x, y) => 1 + case Cons(x, Cons(y,z)) => 2 + } + } + + def output03(a: List[Int]): Int = { + a match { + case Nil() => 0 + case Cons(x,y) => 1 + } + } + + def input04(a: Int): Int = { + a match { + case 0 => 0 + case x if x >= 42 => x+42 + case y if y > 42 => y-42 + case z if z > 0 => z +100 + case w if w < 0 => w -100 + case other => 1000 + } + } + def output04(a: Int): Int = { + a match { + case 0 => 0 + case x if x >= 42 => x + 42 + case z if z > 0 => z+100 + case w => w-100 + } + } + + def input05(a : Int) : Int = { + a match { + case x@_ => x + } + } + + def output05(a : Int) : Int = { + a + } + + def input06(a : (Int, Int)) : Int = { + a match { + case (x,y) => x + } + } + + def output06(a : (Int,Int)) : Int = { + a._1 + } + +} diff --git a/src/test/scala/leon/test/purescala/TransformationTests.scala b/src/test/scala/leon/test/purescala/TransformationTests.scala index 4575865d74661439fbaf0aac04b7202096d5f89c..141259f6125d25c8a0aa9e3899e1ea5f973b476b 100644 --- a/src/test/scala/leon/test/purescala/TransformationTests.scala +++ b/src/test/scala/leon/test/purescala/TransformationTests.scala @@ -14,18 +14,31 @@ import leon.purescala.Trees._ import leon.purescala.TreeOps._ import leon.purescala.TypeTrees._ +import leon.solvers.z3.UninterpretedZ3Solver +import leon.solvers._ + class TransformationTests extends LeonTestSuite { val pipeline = ExtractionPhase andThen PreprocessingPhase + + val simpPaths = (p: Program, e : Expr) => { + val uninterpretedZ3 = SolverFactory(() => new UninterpretedZ3Solver(testContext, p)) + simplifyPaths(uninterpretedZ3)(e) + } filesInResourceDir("regression/transformations").foreach { file => // Configure which file corresponds to which transformation: - val (title: String, transformer: (Expr => Expr)) = file.getName match { + val (title: String, transformer: ((Program,Expr) => Expr)) = file.getName match { case "SimplifyLets.scala" => ( "Simplifying Lets", - simplifyLets _ + (_:Program, e : Expr) => simplifyLets(e) + ) + case "SimplifyPaths.scala" => + ( + "Simplifying paths", + simpPaths ) case "Match.scala" => ( @@ -43,7 +56,6 @@ class TransformationTests extends LeonTestSuite { val prog = pipeline.run(ctx)(file.getPath :: Nil) - // Proceed with the actual tests val inputs = prog.definedFunctions.collect{ case fd if fd.id.name.startsWith("input") => @@ -60,7 +72,7 @@ class TransformationTests extends LeonTestSuite { val in = fdin.body.get outputs.get(n) match { case Some(fdexp) => - val out = transformer(in) + val out = transformer(prog, in) val exp = fdexp.body.get val map = (fdin.params.map(_.id) zip fdexp.params.map(_.id)).toMap