diff --git a/library/lang/package.scala b/library/lang/package.scala index 128f13df4c06a6bd0e303eba09881bac05ac2bc8..94578101eca94a11c4e9a2eb26ada6e3dfbf2cbf 100644 --- a/library/lang/package.scala +++ b/library/lang/package.scala @@ -32,5 +32,12 @@ package object lang { implicit class Gives[A](v : A) { def gives[B](tests : A => B) : B = tests(v) } + + @ignore + implicit class Passes[A,B](io : (A,B)) { + val (in, out) = io + def passes(tests : A => B ) : Boolean = + try { tests(in) == out } catch { case _ : MatchError => true } + } } diff --git a/src/main/scala/leon/codegen/CodeGeneration.scala b/src/main/scala/leon/codegen/CodeGeneration.scala index 04237df5eaca5b075078f54ab44f37e5f056e36e..0ac7ae01aad3ef7913a80d324f2b90285d18a6ba 100644 --- a/src/main/scala/leon/codegen/CodeGeneration.scala +++ b/src/main/scala/leon/codegen/CodeGeneration.scala @@ -699,6 +699,9 @@ trait CodeGeneration { case This(ct) => ch << ALoad(0) // FIXME what if doInstrument etc + case p : Passes => + mkExpr(matchToIfThenElse(p.asConstraint), ch) + case m : MatchExpr => mkExpr(matchToIfThenElse(m), ch) diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala index 50643a44088908b5d4587a895829dfd1616caa6c..dcc8c493f067cb9a3018abb75784dec1b04519c4 100644 --- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala +++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala @@ -397,6 +397,9 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int case g : Gives => e(convertHoles(g, ctx, true)) + + case p : Passes => + e(p.asConstraint) case choose: Choose => import purescala.TreeOps.simplestValue diff --git a/src/main/scala/leon/evaluators/TracingEvaluator.scala b/src/main/scala/leon/evaluators/TracingEvaluator.scala index 2dd541997e916419679068e307a83845ace660d1..c95b5f716e52246fe638f41f5c9c8de5195d24d1 100644 --- a/src/main/scala/leon/evaluators/TracingEvaluator.scala +++ b/src/main/scala/leon/evaluators/TracingEvaluator.scala @@ -33,7 +33,11 @@ class TracingEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int = 1000) ex val res = e(b)(rctx.withNewVar(i, first), gctx) (res, first) - case MatchLike(scrut, cases, _) => + case p: Passes => + val r = e(p.asConstraint) + (r, r) + + case MatchExpr(scrut, cases) => val rscrut = e(scrut) val r = cases.toStream.map(c => matchesCase(rscrut, c)).find(_.nonEmpty) match { diff --git a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala index 7a13e372afad7b9975cef15503c8a223b37d529c..a7b5f1e87e37c71d8041c27a70dbda3e197506c2 100644 --- a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala +++ b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala @@ -178,6 +178,28 @@ trait ASTExtractors { case _ => None } } + + object ExPasses { + def unapply(tree : Apply) : Option[(Tree, Tree, List[CaseDef])] = tree match { + case Apply( + Select( + Apply( + TypeApply( + ExSelected("leon", "lang", "package", "Passes"), + _ :: _ :: Nil + ), + ExpressionExtractors.ExTuple(_, Seq(in,out)) :: Nil + ), + ExNamed("passes") + ), + (Function( + (_ @ ValDef(_, _, _, EmptyTree)) :: Nil, + ExpressionExtractors.ExPatternMatching(_,tests))) :: Nil + ) + => Some((in, out, tests)) + case _ => None + } + } diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index 0cbd2095336c3ae817c0e50ea57ffd3f7acda464..0b5bf20a7316bf9597a98fd32eb3d1f4bf653243 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -20,7 +20,7 @@ import purescala.Definitions.{ import purescala.Trees.{Expr => LeonExpr, This => LeonThis, _} import purescala.TypeTrees.{TypeTree => LeonType, _} import purescala.Common._ -import purescala.Extractors.IsTyped +import purescala.Extractors.{IsTyped,UnwrapTuple} import purescala.Constructors._ import purescala.TreeOps._ import purescala.TypeTreeOps._ @@ -1002,7 +1002,19 @@ trait CodeExtraction extends ASTExtractors { rest = None Require(pre, b) - + + case ExPasses(in, out, cases) => + val ine = extractTree(in) + val oute = extractTree(out) + val rc = cases.map(extractMatchCase(_)) + + val UnwrapTuple(ines) = ine + (oute +: ines) foreach { + case Variable(_) => { } + case other => ctx.reporter.fatalError(other.getPos, "Only i/o variables are allowed in i/o examples") + } + passes(ine, oute, rc) + case ExGives(sel, cses) => val rs = extractTree(sel) val rc = cses.map(extractMatchCase(_)) diff --git a/src/main/scala/leon/purescala/Constructors.scala b/src/main/scala/leon/purescala/Constructors.scala index ffcd1b072dd1f9eda62c5945c571379d88b288e0..bb3b8156f2105940dce7a4914dee25f235594cea 100644 --- a/src/main/scala/leon/purescala/Constructors.scala +++ b/src/main/scala/leon/purescala/Constructors.scala @@ -34,14 +34,14 @@ object Constructors { } } - def tupleWrap(es: Seq[Expr]): Expr = if (es.size > 1) { - Tuple(es) - } else { - es.head + def tupleWrap(es: Seq[Expr]): Expr = es match { + case Seq() => UnitLiteral() + case Seq(elem) => elem + case more => Tuple(more) } - private def filterCases(scrutinee: Expr, cases: Seq[MatchCase]): Seq[MatchCase] = { - scrutinee.getType match { + private def filterCases(scrutType : TypeTree, cases: Seq[MatchCase]): Seq[MatchCase] = { + scrutType match { case c: CaseClassType => cases.filter(_.pattern match { case CaseClassPattern(_, cct, _) if cct.classDef != c.classDef => false @@ -57,10 +57,14 @@ object Constructors { } def gives(scrutinee : Expr, cases : Seq[MatchCase]) : Gives = - Gives(scrutinee, filterCases(scrutinee, cases)) + Gives(scrutinee, filterCases(scrutinee.getType, cases)) + def passes(in : Expr, out : Expr, cases : Seq[MatchCase]) : Passes = { + Passes(in, out, filterCases(in.getType, cases)) + } + def matchExpr(scrutinee : Expr, cases : Seq[MatchCase]) : MatchExpr = - MatchExpr(scrutinee, filterCases(scrutinee, cases)) + MatchExpr(scrutinee, filterCases(scrutinee.getType, cases)) def and(exprs: Expr*): Expr = { val flat = exprs.flatMap(_ match { diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala index 5146387b9109d9749df6d1f18a1cbc01c9efd607..439d0888c1eeed5a7264818255a84e8e8a6fa73e 100644 --- a/src/main/scala/leon/purescala/Extractors.scala +++ b/src/main/scala/leon/purescala/Extractors.scala @@ -231,6 +231,11 @@ object Extractors { (m.scrutinee, m.cases, m match { case _ : MatchExpr => matchExpr case _ : Gives => gives + case _ : Passes => + (s, cases) => { + val Tuple(Seq(in, out)) = s + passes(in,out,cases) + } }) } } @@ -250,4 +255,11 @@ object Extractors { } } + object UnwrapTuple { + def unapply(e : Expr) : Option[Seq[Expr]] = Option(e) map { + case Tuple(subs) => subs + case other => Seq(other) + } + } + } diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index b203c27a1f2ef3604aca7888000bf958d01d6c05..b8ebf7c95acf921f0663df88364d01eee839219d 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -223,10 +223,16 @@ class PrettyPrinter(opts: PrinterOptions, val sb: StringBuffer = new StringBuffe optP { p"""|$s gives { | ${nary(tests, "\n")} - |} - |""" + |}""" } + case p@Passes(in, out, tests) => + optP { + p"""|${p.scrutinee} passes { + | ${nary(tests, "\n")} + |}""" + } + case c @ WithOracle(vars, pred) => p"""|withOracle { (${typed(vars)}) => | $pred @@ -641,7 +647,7 @@ class PrettyPrinter(opts: PrinterOptions, val sb: StringBuffer = new StringBuffe case (_: Require, _) => true case (_: Assert, Some(_: Definition)) => true case (_, Some(_: Definition)) => false - case (_, Some(_: MatchExpr | _: MatchCase | _: Let | _: LetTuple | _: LetDef)) => false + case (_, Some(_: MatchExpr | _: MatchCase | _: Let | _: LetTuple | _: LetDef )) => false case (_, _) => true } @@ -669,7 +675,7 @@ class PrettyPrinter(opts: PrinterOptions, val sb: StringBuffer = new StringBuffe case (BinaryMethodCall(_, _, _), Some(_: FunctionInvocation)) => true case (_, Some(_: FunctionInvocation)) => false case (ie: IfExpr, _) => true - case (me: MatchExpr, _ ) => true + case (me: MatchLike, _ ) => true case (e1: Expr, Some(e2: Expr)) if precedence(e1) > precedence(e2) => false case (_, _) => true } diff --git a/src/main/scala/leon/purescala/ScopeSimplifier.scala b/src/main/scala/leon/purescala/ScopeSimplifier.scala index 7d3024bde0a1c50c72745b5149dda24d7db6ee76..010c708125ddb9a3360f0c03a15f0c2fe605f256 100644 --- a/src/main/scala/leon/purescala/ScopeSimplifier.scala +++ b/src/main/scala/leon/purescala/ScopeSimplifier.scala @@ -75,7 +75,7 @@ class ScopeSimplifier extends Transformer { val sb = rec(b, newScope) LetTuple(sis, se, sb) - case MatchLike(scrut, cases, builder) => + case MatchExpr(scrut, cases) => val rs = rec(scrut, scope) def trPattern(p: Pattern, scope: Scope): (Pattern, Scope) = { @@ -112,7 +112,7 @@ class ScopeSimplifier extends Transformer { (newPattern, curScope) } - builder(rs, cases.map { c => + MatchExpr(rs, cases.map { c => val (newP, newScope) = trPattern(c.pattern, scope) c match { diff --git a/src/main/scala/leon/purescala/TransformerWithPC.scala b/src/main/scala/leon/purescala/TransformerWithPC.scala index 493b4639c7f71dc0121b105987a7066f605c68e4..858f02a5f5eda3f2c4c34ec39bd2e80336d04520 100644 --- a/src/main/scala/leon/purescala/TransformerWithPC.scala +++ b/src/main/scala/leon/purescala/TransformerWithPC.scala @@ -21,12 +21,15 @@ abstract class TransformerWithPC extends Transformer { val sb = rec(b, register(Equals(Variable(i), se), path)) Let(i, se, sb).copiedFrom(e) - case MatchLike(scrut, cases, builder) => + case p: Passes => + rec(p.asConstraint, path) + + case MatchExpr(scrut, cases) => val rs = rec(scrut, path) var soFar = path - builder(rs, cases.map { c => + MatchExpr(rs, cases.map { c => val patternExprPos = conditionForPattern(rs, c.pattern, includeBinders = true) val patternExprNeg = conditionForPattern(rs, c.pattern, includeBinders = false) diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala index bea32ebe21801b65152a767f5c43fd1cd069a906..33cc9699d590c9e9d1cbe5b62b4e4abc8cd6fdec 100644 --- a/src/main/scala/leon/purescala/TreeOps.scala +++ b/src/main/scala/leon/purescala/TreeOps.scala @@ -852,22 +852,22 @@ object TreeOps { postMap(rewritePM)(expr) } - def matchCasePathConditions(m : MatchLike, pathCond: List[Expr]) : Seq[List[Expr]] = m match { - case MatchLike(scrut, cases, _) => - var pcSoFar = pathCond - for (c <- cases) yield { - - val g = c.optGuard getOrElse BooleanLiteral(true) - val cond = conditionForPattern(scrut, c.pattern, includeBinders = true) - val localCond = pcSoFar :+ cond :+ g - - // These contain no binders defined in this MatchCase - val condSafe = conditionForPattern(scrut, c.pattern) - val gSafe = replaceFromIDs(mapForPattern(scrut, c.pattern),g) - pcSoFar ::= not(and(condSafe, gSafe)) + def matchCasePathConditions(m: MatchExpr, pathCond: List[Expr]) : Seq[List[Expr]] = { + val MatchExpr(scrut, cases) = m + var pcSoFar = pathCond + for (c <- cases) yield { + + val g = c.optGuard getOrElse BooleanLiteral(true) + val cond = conditionForPattern(scrut, c.pattern, includeBinders = true) + val localCond = pcSoFar :+ cond :+ g + + // These contain no binders defined in this MatchCase + val condSafe = conditionForPattern(scrut, c.pattern) + val gSafe = replaceFromIDs(mapForPattern(scrut, c.pattern),g) + pcSoFar ::= not(and(condSafe, gSafe)) - localCond - } + localCond + } } @@ -1550,19 +1550,8 @@ object TreeOps { fdHomo(fd1, fd2) && isHomo(e1, e2)(map + (fd1.id -> fd2.id)) - case (MatchExpr(s1, cs1), MatchExpr(s2, cs2)) => - if (cs1.size == cs2.size) { - isHomo(s1, s2) && casesMatch(cs1,cs2) - } else { - false - } - - case (Gives(s1, cs1), Gives(s2, cs2)) => - if (cs1.size == cs2.size) { - isHomo(s1, s2) && casesMatch(cs1,cs2) - } else { - false - } + case Same(MatchLike(s1, cs1, _), MatchLike(s2, cs2, _)) => + cs1.size == cs2.size && isHomo(s1, s2) && casesMatch(cs1,cs2) case (FunctionInvocation(tfd1, args1), FunctionInvocation(tfd2, args2)) => // TODO: Check type params diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala index bf90a8aeb5b25e80e4fc22b9051ab08993e075e8..70b9c2116c68aa0b1717a6cc3cc5f873389982f3 100644 --- a/src/main/scala/leon/purescala/Trees.scala +++ b/src/main/scala/leon/purescala/Trees.scala @@ -46,7 +46,7 @@ object Trees { } case class Choose(vars: List[Identifier], pred: Expr) extends Expr with UnaryExtractable { - assert(!vars.isEmpty) + require(!vars.isEmpty) def getType = if (vars.size > 1) TupleType(vars.map(_.getType)) else vars.head.getType @@ -60,7 +60,7 @@ object Trees { } case class LetTuple(binders: Seq[Identifier], value: Expr, body: Expr) extends Expr { - assert(value.getType.isInstanceOf[TupleType], + require(value.getType.isInstanceOf[TupleType], "The definition value in LetTuple must be of some tuple type; yet we got [%s]. In expr: \n%s".format(value.getType, this)) def getType = body.getType @@ -169,11 +169,11 @@ object Trees { // Index is 1-based, first element of tuple is 1. case class TupleSelect(tuple: Expr, index: Int) extends Expr { - assert(index >= 1) + require(index >= 1) def getType = tuple.getType match { case TupleType(ts) => - assert(index <= ts.size) + require(index <= ts.size) ts(index - 1) case _ => @@ -188,17 +188,30 @@ object Trees { } case class MatchExpr(scrutinee: Expr, cases: Seq[MatchCase]) extends MatchLike { - assert(cases.nonEmpty) + require(cases.nonEmpty) } case class Gives(scrutinee: Expr, cases : Seq[MatchCase]) extends MatchLike { - assert(cases.nonEmpty) + require(cases.nonEmpty) def asIncompleteMatch = { val theHole = SimpleCase(WildcardPattern(None), Hole(this.getType, Seq())) MatchExpr(scrutinee, cases :+ theHole) } + } + + case class Passes(in: Expr, out : Expr, cases : Seq[MatchCase]) extends MatchLike { + require(cases.nonEmpty) + + override def getType = BooleanType + val scrutinee = Tuple(Seq(in, out)) + + def asConstraint = { + val defaultCase = SimpleCase(WildcardPattern(None), out) + Equals(out, MatchExpr(in, cases :+ defaultCase)) + } } + sealed abstract class MatchCase extends Tree { val pattern: Pattern val rhs: Expr @@ -246,7 +259,7 @@ object Trees { case class And(exprs: Seq[Expr]) extends Expr { def getType = BooleanType - assert(exprs.size >= 2) + require(exprs.size >= 2) } object And { @@ -256,7 +269,7 @@ object Trees { case class Or(exprs: Seq[Expr]) extends Expr { def getType = BooleanType - assert(exprs.size >= 2) + require(exprs.size >= 2) } object Or { @@ -456,7 +469,7 @@ object Trees { // Provide an oracle (synthesizable, all-seeing choose) case class WithOracle(oracles: List[Identifier], body: Expr) extends Expr with UnaryExtractable { - assert(!oracles.isEmpty) + require(!oracles.isEmpty) def getType = body.getType @@ -501,7 +514,7 @@ object Trees { val getType = MultisetType(baseType) } case class FiniteMultiset(elements: Seq[Expr]) extends Expr { - assert(elements.size > 0) + require(elements.nonEmpty) def getType = MultisetType(leastUpperBound(elements.map(_.getType)).getOrElse(Untyped)) } case class Multiplicity(element: Expr, multiset: Expr) extends Expr { diff --git a/src/main/scala/leon/refactor/Repairman.scala b/src/main/scala/leon/refactor/Repairman.scala index 1e8a6299a784f51df4baca1412bcf8d6bf51a7db..c9bd7e9fcc131f4d13bf76b2c0d89c1cbacf3363 100644 --- a/src/main/scala/leon/refactor/Repairman.scala +++ b/src/main/scala/leon/refactor/Repairman.scala @@ -41,17 +41,15 @@ class Repairman(ctx: LeonContext, program: Program, fd: FunDef) { // Compute tests val out = fd.postcondition.map(_._1).getOrElse(FreshIdentifier("res", true).setType(fd.returnType)) - val tfd = program.library.passes.get.typed(Seq(argsWrapped.getType, out.getType)) - val inouts = testBank; - val testsExpr = FiniteMap(inouts.collect { + val testsCases = inouts.collect { case InOutExample(ins, outs) => - tupleWrap(ins) -> tupleWrap(outs) - }.toList).setType(MapType(argsWrapped.getType, out.getType)) + GuardedCase(WildcardPattern(None), Equals(argsWrapped, tupleWrap(ins)), tupleWrap(outs)) + }.toList - val passes = if (testsExpr.singletons.nonEmpty) { - FunctionInvocation(tfd, Seq(argsWrapped, out.toVariable, testsExpr)) + val passes = if (testsCases.nonEmpty) { + Passes(argsWrapped, out.toVariable, testsCases) } else { BooleanLiteral(true) } diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala index 87a1e8e0629524a6229c8b63c420a0907d8e5963..ba1945ae43cae18e8f9f331c39bf95270679b017 100644 --- a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala @@ -495,6 +495,9 @@ trait AbstractZ3Solver } def rec(ex: Expr): Z3AST = ex match { + case p @ Passes(_, _, _) => + rec(p.asConstraint) + case me @ MatchExpr(s, cs) => rec(matchToIfThenElse(me))