diff --git a/library/lang/package.scala b/library/lang/package.scala index 1b26c5006dea499e6d94632be47a20c14c723160..12497cfc1d1f98d4f008426dec00fa91fc5c51a7 100644 --- a/library/lang/package.scala +++ b/library/lang/package.scala @@ -28,12 +28,9 @@ package object lang { @ignore def error[T](reason: java.lang.String): T = sys.error(reason) - @library - def passes[A, B](in: A, out: B)(tests: Map[A,B]): Boolean = { - if (tests contains in) { - tests(in) == out - } else { - true - } + @ignore + implicit class Passes[A](v : A) { + def passes[B](tests : A => B) : B = tests(v) } + } diff --git a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala index 1d89142d3de669bfeeea5bd0a2a3499b4b138464..7edde4b5c312feee59fb74f09a3e579d859a5cc7 100644 --- a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala +++ b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala @@ -157,6 +157,29 @@ trait ASTExtractors { } } + object ExPasses { + def unapply(tree : Apply) : Option[(Tree, List[CaseDef])] = tree match { + case Apply( + TypeApply( + Select( + Apply( + TypeApply( + ExSelected("leon", "lang", "package", "Passes"), + _ :: Nil + ), + body :: Nil + ), + ExNamed("passes") + ), + _ :: Nil + ), + (Function((_ @ ValDef(_, _, _, EmptyTree)) :: Nil, ExpressionExtractors.ExPatternMatching(_,tests))) :: Nil) + => Some((body, tests)) + case _ => None + } + } + + object ExStringLiteral { def unapply(tree: Tree): Option[String] = tree match { diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index 74ce0498cd52674dd9e0840a06b19ffe5598daee..cbaa9f400f7296721c536fdaf4b34ae73b467f36 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -183,7 +183,7 @@ trait CodeExtraction extends ASTExtractors { } def isIgnored(s: Symbol) = { - (annotationsOf(s) contains "ignore") || (s.fullName.toString.endsWith(".main")) + (annotationsOf(s) contains "ignore") || (s.isImplicit) || (s.fullName.toString.endsWith(".main")) } def isExtern(s: Symbol) = { @@ -1001,6 +1001,14 @@ trait CodeExtraction extends ASTExtractors { rest = None Require(pre, b) + + + case passes @ ExPasses(sel, cses) => + val rs = extractTree(sel) + val rc = cses.map(extractMatchCase(_)) + val rt: LeonType = rc.map(_.rhs.getType).reduceLeft(leastUpperBound(_,_).get) + Passes(rs, rc).setType(rt) + case ExArrayLiteral(tpe, args) => FiniteArray(args.map(extractTree)).setType(ArrayType(extractType(tpe)(dctx, current.pos))) diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala index 419170251a5b2cc2a20583e5e6795b806ac7cc36..501f9955174a8e6841de2f9331248ea2059831e4 100644 --- a/src/main/scala/leon/purescala/Extractors.scala +++ b/src/main/scala/leon/purescala/Extractors.scala @@ -139,6 +139,19 @@ object Extractors { MatchExpr(es(0), newcases) })) + case Passes(scrut, tests) => + Some((scrut +: tests.flatMap{ case SimpleCase(_, e) => Seq(e) + case GuardedCase(_, e1, e2) => Seq(e1, e2) } + , { es: Seq[Expr] => + var i = 1; + val newtests = for (test <- tests) yield test match { + case SimpleCase(b, _) => i+=1; SimpleCase(b, es(i-1)) + case GuardedCase(b, _, _) => i+=2; GuardedCase(b, es(i-2), es(i-1)) + } + + Passes(es(0), newtests) + })) + case LetDef(fd, body) => fd.body match { case Some(b) => diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index 4fbf90ba6160273ffa15faddd4b8575c41f9933a..f09e63815207be5ab5a6a7df2e90952c7e8f097d 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -219,6 +219,14 @@ class PrettyPrinter(opts: PrinterOptions, val sb: StringBuffer = new StringBuffe | (${typed(id)}) => $post |}""" + case Passes(s, tests) => + optP { + p"""|$s passes { + | ${nary(tests, "\n")} + |} + |""" + } + case c @ WithOracle(vars, pred) => p"""|withOracle { (${typed(vars)}) => | $pred @@ -389,6 +397,7 @@ class PrettyPrinter(opts: PrinterOptions, val sb: StringBuffer = new StringBuffe |}""" } + case MatchExpr(s, csc) => optP { p"""|$s match { diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala index c88b7c38f7067db3e0ec7cfddce41bc71c948e95..ec3879c1fc912be52eaa5b40b57bd82cf9b61d60 100644 --- a/src/main/scala/leon/purescala/Trees.scala +++ b/src/main/scala/leon/purescala/Trees.scala @@ -42,6 +42,12 @@ object Trees { val fixedType = body.getType } + case class Passes(scrut: Expr, tests : List[MatchCase]) extends Expr with FixedType { + val fixedType = leastUpperBound(tests.map(_.rhs.getType)).getOrElse{ + Untyped + } + } + case class Choose(vars: List[Identifier], pred: Expr) extends Expr with FixedType with UnaryExtractable { assert(!vars.isEmpty) diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala index d4df8948d96e98d44866e90ef8b2af2a6554fb42..623b124457497dd53f1f2e5ab99c2d035e1ce408 100644 --- a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala @@ -496,6 +496,9 @@ trait AbstractZ3Solver def rec(ex: Expr): Z3AST = ex match { case me @ MatchExpr(s, cs) => rec(matchToIfThenElse(me)) + + case Passes(scrut, tests) => + rec(matchToIfThenElse(MatchExpr(scrut, tests))) case tu @ Tuple(args) => typeToSort(tu.getType) // Make sure we generate sort & meta info