From 86345e4808fa93b1e7d90b836bdea4e7ea372b2e Mon Sep 17 00:00:00 2001 From: "Emmanouil (Manos) Koukoutos" <emmanouil.koukoutos@epfl.ch> Date: Fri, 28 Nov 2014 16:23:00 +0100 Subject: [PATCH] Passes: examples in specs --- library/lang/package.scala | 7 +++ .../scala/leon/codegen/CodeGeneration.scala | 3 ++ .../leon/evaluators/RecursiveEvaluator.scala | 3 ++ .../leon/evaluators/TracingEvaluator.scala | 6 ++- .../leon/frontends/scalac/ASTExtractors.scala | 22 +++++++++ .../frontends/scalac/CodeExtraction.scala | 16 ++++++- .../scala/leon/purescala/Constructors.scala | 20 +++++---- .../scala/leon/purescala/Extractors.scala | 12 +++++ .../scala/leon/purescala/PrettyPrinter.scala | 14 ++++-- .../leon/purescala/ScopeSimplifier.scala | 4 +- .../leon/purescala/TransformerWithPC.scala | 7 ++- src/main/scala/leon/purescala/TreeOps.scala | 45 +++++++------------ src/main/scala/leon/purescala/Trees.scala | 33 +++++++++----- src/main/scala/leon/refactor/Repairman.scala | 12 +++-- .../leon/solvers/z3/AbstractZ3Solver.scala | 3 ++ 15 files changed, 143 insertions(+), 64 deletions(-) diff --git a/library/lang/package.scala b/library/lang/package.scala index 128f13df4..94578101e 100644 --- a/library/lang/package.scala +++ b/library/lang/package.scala @@ -32,5 +32,12 @@ package object lang { implicit class Gives[A](v : A) { def gives[B](tests : A => B) : B = tests(v) } + + @ignore + implicit class Passes[A,B](io : (A,B)) { + val (in, out) = io + def passes(tests : A => B ) : Boolean = + try { tests(in) == out } catch { case _ : MatchError => true } + } } diff --git a/src/main/scala/leon/codegen/CodeGeneration.scala b/src/main/scala/leon/codegen/CodeGeneration.scala index 04237df5e..0ac7ae01a 100644 --- a/src/main/scala/leon/codegen/CodeGeneration.scala +++ b/src/main/scala/leon/codegen/CodeGeneration.scala @@ -699,6 +699,9 @@ trait CodeGeneration { case This(ct) => ch << ALoad(0) // FIXME what if doInstrument etc + case p : Passes => + mkExpr(matchToIfThenElse(p.asConstraint), ch) + case m : MatchExpr => mkExpr(matchToIfThenElse(m), ch) diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala index 50643a440..dcc8c493f 100644 --- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala +++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala @@ -397,6 +397,9 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int case g : Gives => e(convertHoles(g, ctx, true)) + + case p : Passes => + e(p.asConstraint) case choose: Choose => import purescala.TreeOps.simplestValue diff --git a/src/main/scala/leon/evaluators/TracingEvaluator.scala b/src/main/scala/leon/evaluators/TracingEvaluator.scala index 2dd541997..c95b5f716 100644 --- a/src/main/scala/leon/evaluators/TracingEvaluator.scala +++ b/src/main/scala/leon/evaluators/TracingEvaluator.scala @@ -33,7 +33,11 @@ class TracingEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int = 1000) ex val res = e(b)(rctx.withNewVar(i, first), gctx) (res, first) - case MatchLike(scrut, cases, _) => + case p: Passes => + val r = e(p.asConstraint) + (r, r) + + case MatchExpr(scrut, cases) => val rscrut = e(scrut) val r = cases.toStream.map(c => matchesCase(rscrut, c)).find(_.nonEmpty) match { diff --git a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala index 7a13e372a..a7b5f1e87 100644 --- a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala +++ b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala @@ -178,6 +178,28 @@ trait ASTExtractors { case _ => None } } + + object ExPasses { + def unapply(tree : Apply) : Option[(Tree, Tree, List[CaseDef])] = tree match { + case Apply( + Select( + Apply( + TypeApply( + ExSelected("leon", "lang", "package", "Passes"), + _ :: _ :: Nil + ), + ExpressionExtractors.ExTuple(_, Seq(in,out)) :: Nil + ), + ExNamed("passes") + ), + (Function( + (_ @ ValDef(_, _, _, EmptyTree)) :: Nil, + ExpressionExtractors.ExPatternMatching(_,tests))) :: Nil + ) + => Some((in, out, tests)) + case _ => None + } + } diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index 0cbd20953..0b5bf20a7 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -20,7 +20,7 @@ import purescala.Definitions.{ import purescala.Trees.{Expr => LeonExpr, This => LeonThis, _} import purescala.TypeTrees.{TypeTree => LeonType, _} import purescala.Common._ -import purescala.Extractors.IsTyped +import purescala.Extractors.{IsTyped,UnwrapTuple} import purescala.Constructors._ import purescala.TreeOps._ import purescala.TypeTreeOps._ @@ -1002,7 +1002,19 @@ trait CodeExtraction extends ASTExtractors { rest = None Require(pre, b) - + + case ExPasses(in, out, cases) => + val ine = extractTree(in) + val oute = extractTree(out) + val rc = cases.map(extractMatchCase(_)) + + val UnwrapTuple(ines) = ine + (oute +: ines) foreach { + case Variable(_) => { } + case other => ctx.reporter.fatalError(other.getPos, "Only i/o variables are allowed in i/o examples") + } + passes(ine, oute, rc) + case ExGives(sel, cses) => val rs = extractTree(sel) val rc = cses.map(extractMatchCase(_)) diff --git a/src/main/scala/leon/purescala/Constructors.scala b/src/main/scala/leon/purescala/Constructors.scala index ffcd1b072..bb3b8156f 100644 --- a/src/main/scala/leon/purescala/Constructors.scala +++ b/src/main/scala/leon/purescala/Constructors.scala @@ -34,14 +34,14 @@ object Constructors { } } - def tupleWrap(es: Seq[Expr]): Expr = if (es.size > 1) { - Tuple(es) - } else { - es.head + def tupleWrap(es: Seq[Expr]): Expr = es match { + case Seq() => UnitLiteral() + case Seq(elem) => elem + case more => Tuple(more) } - private def filterCases(scrutinee: Expr, cases: Seq[MatchCase]): Seq[MatchCase] = { - scrutinee.getType match { + private def filterCases(scrutType : TypeTree, cases: Seq[MatchCase]): Seq[MatchCase] = { + scrutType match { case c: CaseClassType => cases.filter(_.pattern match { case CaseClassPattern(_, cct, _) if cct.classDef != c.classDef => false @@ -57,10 +57,14 @@ object Constructors { } def gives(scrutinee : Expr, cases : Seq[MatchCase]) : Gives = - Gives(scrutinee, filterCases(scrutinee, cases)) + Gives(scrutinee, filterCases(scrutinee.getType, cases)) + def passes(in : Expr, out : Expr, cases : Seq[MatchCase]) : Passes = { + Passes(in, out, filterCases(in.getType, cases)) + } + def matchExpr(scrutinee : Expr, cases : Seq[MatchCase]) : MatchExpr = - MatchExpr(scrutinee, filterCases(scrutinee, cases)) + MatchExpr(scrutinee, filterCases(scrutinee.getType, cases)) def and(exprs: Expr*): Expr = { val flat = exprs.flatMap(_ match { diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala index 5146387b9..439d0888c 100644 --- a/src/main/scala/leon/purescala/Extractors.scala +++ b/src/main/scala/leon/purescala/Extractors.scala @@ -231,6 +231,11 @@ object Extractors { (m.scrutinee, m.cases, m match { case _ : MatchExpr => matchExpr case _ : Gives => gives + case _ : Passes => + (s, cases) => { + val Tuple(Seq(in, out)) = s + passes(in,out,cases) + } }) } } @@ -250,4 +255,11 @@ object Extractors { } } + object UnwrapTuple { + def unapply(e : Expr) : Option[Seq[Expr]] = Option(e) map { + case Tuple(subs) => subs + case other => Seq(other) + } + } + } diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index b203c27a1..b8ebf7c95 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -223,10 +223,16 @@ class PrettyPrinter(opts: PrinterOptions, val sb: StringBuffer = new StringBuffe optP { p"""|$s gives { | ${nary(tests, "\n")} - |} - |""" + |}""" } + case p@Passes(in, out, tests) => + optP { + p"""|${p.scrutinee} passes { + | ${nary(tests, "\n")} + |}""" + } + case c @ WithOracle(vars, pred) => p"""|withOracle { (${typed(vars)}) => | $pred @@ -641,7 +647,7 @@ class PrettyPrinter(opts: PrinterOptions, val sb: StringBuffer = new StringBuffe case (_: Require, _) => true case (_: Assert, Some(_: Definition)) => true case (_, Some(_: Definition)) => false - case (_, Some(_: MatchExpr | _: MatchCase | _: Let | _: LetTuple | _: LetDef)) => false + case (_, Some(_: MatchExpr | _: MatchCase | _: Let | _: LetTuple | _: LetDef )) => false case (_, _) => true } @@ -669,7 +675,7 @@ class PrettyPrinter(opts: PrinterOptions, val sb: StringBuffer = new StringBuffe case (BinaryMethodCall(_, _, _), Some(_: FunctionInvocation)) => true case (_, Some(_: FunctionInvocation)) => false case (ie: IfExpr, _) => true - case (me: MatchExpr, _ ) => true + case (me: MatchLike, _ ) => true case (e1: Expr, Some(e2: Expr)) if precedence(e1) > precedence(e2) => false case (_, _) => true } diff --git a/src/main/scala/leon/purescala/ScopeSimplifier.scala b/src/main/scala/leon/purescala/ScopeSimplifier.scala index 7d3024bde..010c70812 100644 --- a/src/main/scala/leon/purescala/ScopeSimplifier.scala +++ b/src/main/scala/leon/purescala/ScopeSimplifier.scala @@ -75,7 +75,7 @@ class ScopeSimplifier extends Transformer { val sb = rec(b, newScope) LetTuple(sis, se, sb) - case MatchLike(scrut, cases, builder) => + case MatchExpr(scrut, cases) => val rs = rec(scrut, scope) def trPattern(p: Pattern, scope: Scope): (Pattern, Scope) = { @@ -112,7 +112,7 @@ class ScopeSimplifier extends Transformer { (newPattern, curScope) } - builder(rs, cases.map { c => + MatchExpr(rs, cases.map { c => val (newP, newScope) = trPattern(c.pattern, scope) c match { diff --git a/src/main/scala/leon/purescala/TransformerWithPC.scala b/src/main/scala/leon/purescala/TransformerWithPC.scala index 493b4639c..858f02a5f 100644 --- a/src/main/scala/leon/purescala/TransformerWithPC.scala +++ b/src/main/scala/leon/purescala/TransformerWithPC.scala @@ -21,12 +21,15 @@ abstract class TransformerWithPC extends Transformer { val sb = rec(b, register(Equals(Variable(i), se), path)) Let(i, se, sb).copiedFrom(e) - case MatchLike(scrut, cases, builder) => + case p: Passes => + rec(p.asConstraint, path) + + case MatchExpr(scrut, cases) => val rs = rec(scrut, path) var soFar = path - builder(rs, cases.map { c => + MatchExpr(rs, cases.map { c => val patternExprPos = conditionForPattern(rs, c.pattern, includeBinders = true) val patternExprNeg = conditionForPattern(rs, c.pattern, includeBinders = false) diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala index bea32ebe2..33cc9699d 100644 --- a/src/main/scala/leon/purescala/TreeOps.scala +++ b/src/main/scala/leon/purescala/TreeOps.scala @@ -852,22 +852,22 @@ object TreeOps { postMap(rewritePM)(expr) } - def matchCasePathConditions(m : MatchLike, pathCond: List[Expr]) : Seq[List[Expr]] = m match { - case MatchLike(scrut, cases, _) => - var pcSoFar = pathCond - for (c <- cases) yield { - - val g = c.optGuard getOrElse BooleanLiteral(true) - val cond = conditionForPattern(scrut, c.pattern, includeBinders = true) - val localCond = pcSoFar :+ cond :+ g - - // These contain no binders defined in this MatchCase - val condSafe = conditionForPattern(scrut, c.pattern) - val gSafe = replaceFromIDs(mapForPattern(scrut, c.pattern),g) - pcSoFar ::= not(and(condSafe, gSafe)) + def matchCasePathConditions(m: MatchExpr, pathCond: List[Expr]) : Seq[List[Expr]] = { + val MatchExpr(scrut, cases) = m + var pcSoFar = pathCond + for (c <- cases) yield { + + val g = c.optGuard getOrElse BooleanLiteral(true) + val cond = conditionForPattern(scrut, c.pattern, includeBinders = true) + val localCond = pcSoFar :+ cond :+ g + + // These contain no binders defined in this MatchCase + val condSafe = conditionForPattern(scrut, c.pattern) + val gSafe = replaceFromIDs(mapForPattern(scrut, c.pattern),g) + pcSoFar ::= not(and(condSafe, gSafe)) - localCond - } + localCond + } } @@ -1550,19 +1550,8 @@ object TreeOps { fdHomo(fd1, fd2) && isHomo(e1, e2)(map + (fd1.id -> fd2.id)) - case (MatchExpr(s1, cs1), MatchExpr(s2, cs2)) => - if (cs1.size == cs2.size) { - isHomo(s1, s2) && casesMatch(cs1,cs2) - } else { - false - } - - case (Gives(s1, cs1), Gives(s2, cs2)) => - if (cs1.size == cs2.size) { - isHomo(s1, s2) && casesMatch(cs1,cs2) - } else { - false - } + case Same(MatchLike(s1, cs1, _), MatchLike(s2, cs2, _)) => + cs1.size == cs2.size && isHomo(s1, s2) && casesMatch(cs1,cs2) case (FunctionInvocation(tfd1, args1), FunctionInvocation(tfd2, args2)) => // TODO: Check type params diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala index bf90a8aeb..70b9c2116 100644 --- a/src/main/scala/leon/purescala/Trees.scala +++ b/src/main/scala/leon/purescala/Trees.scala @@ -46,7 +46,7 @@ object Trees { } case class Choose(vars: List[Identifier], pred: Expr) extends Expr with UnaryExtractable { - assert(!vars.isEmpty) + require(!vars.isEmpty) def getType = if (vars.size > 1) TupleType(vars.map(_.getType)) else vars.head.getType @@ -60,7 +60,7 @@ object Trees { } case class LetTuple(binders: Seq[Identifier], value: Expr, body: Expr) extends Expr { - assert(value.getType.isInstanceOf[TupleType], + require(value.getType.isInstanceOf[TupleType], "The definition value in LetTuple must be of some tuple type; yet we got [%s]. In expr: \n%s".format(value.getType, this)) def getType = body.getType @@ -169,11 +169,11 @@ object Trees { // Index is 1-based, first element of tuple is 1. case class TupleSelect(tuple: Expr, index: Int) extends Expr { - assert(index >= 1) + require(index >= 1) def getType = tuple.getType match { case TupleType(ts) => - assert(index <= ts.size) + require(index <= ts.size) ts(index - 1) case _ => @@ -188,17 +188,30 @@ object Trees { } case class MatchExpr(scrutinee: Expr, cases: Seq[MatchCase]) extends MatchLike { - assert(cases.nonEmpty) + require(cases.nonEmpty) } case class Gives(scrutinee: Expr, cases : Seq[MatchCase]) extends MatchLike { - assert(cases.nonEmpty) + require(cases.nonEmpty) def asIncompleteMatch = { val theHole = SimpleCase(WildcardPattern(None), Hole(this.getType, Seq())) MatchExpr(scrutinee, cases :+ theHole) } + } + + case class Passes(in: Expr, out : Expr, cases : Seq[MatchCase]) extends MatchLike { + require(cases.nonEmpty) + + override def getType = BooleanType + val scrutinee = Tuple(Seq(in, out)) + + def asConstraint = { + val defaultCase = SimpleCase(WildcardPattern(None), out) + Equals(out, MatchExpr(in, cases :+ defaultCase)) + } } + sealed abstract class MatchCase extends Tree { val pattern: Pattern val rhs: Expr @@ -246,7 +259,7 @@ object Trees { case class And(exprs: Seq[Expr]) extends Expr { def getType = BooleanType - assert(exprs.size >= 2) + require(exprs.size >= 2) } object And { @@ -256,7 +269,7 @@ object Trees { case class Or(exprs: Seq[Expr]) extends Expr { def getType = BooleanType - assert(exprs.size >= 2) + require(exprs.size >= 2) } object Or { @@ -456,7 +469,7 @@ object Trees { // Provide an oracle (synthesizable, all-seeing choose) case class WithOracle(oracles: List[Identifier], body: Expr) extends Expr with UnaryExtractable { - assert(!oracles.isEmpty) + require(!oracles.isEmpty) def getType = body.getType @@ -501,7 +514,7 @@ object Trees { val getType = MultisetType(baseType) } case class FiniteMultiset(elements: Seq[Expr]) extends Expr { - assert(elements.size > 0) + require(elements.nonEmpty) def getType = MultisetType(leastUpperBound(elements.map(_.getType)).getOrElse(Untyped)) } case class Multiplicity(element: Expr, multiset: Expr) extends Expr { diff --git a/src/main/scala/leon/refactor/Repairman.scala b/src/main/scala/leon/refactor/Repairman.scala index 1e8a6299a..c9bd7e9fc 100644 --- a/src/main/scala/leon/refactor/Repairman.scala +++ b/src/main/scala/leon/refactor/Repairman.scala @@ -41,17 +41,15 @@ class Repairman(ctx: LeonContext, program: Program, fd: FunDef) { // Compute tests val out = fd.postcondition.map(_._1).getOrElse(FreshIdentifier("res", true).setType(fd.returnType)) - val tfd = program.library.passes.get.typed(Seq(argsWrapped.getType, out.getType)) - val inouts = testBank; - val testsExpr = FiniteMap(inouts.collect { + val testsCases = inouts.collect { case InOutExample(ins, outs) => - tupleWrap(ins) -> tupleWrap(outs) - }.toList).setType(MapType(argsWrapped.getType, out.getType)) + GuardedCase(WildcardPattern(None), Equals(argsWrapped, tupleWrap(ins)), tupleWrap(outs)) + }.toList - val passes = if (testsExpr.singletons.nonEmpty) { - FunctionInvocation(tfd, Seq(argsWrapped, out.toVariable, testsExpr)) + val passes = if (testsCases.nonEmpty) { + Passes(argsWrapped, out.toVariable, testsCases) } else { BooleanLiteral(true) } diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala index 87a1e8e06..ba1945ae4 100644 --- a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala @@ -495,6 +495,9 @@ trait AbstractZ3Solver } def rec(ex: Expr): Z3AST = ex match { + case p @ Passes(_, _, _) => + rec(p.asConstraint) + case me @ MatchExpr(s, cs) => rec(matchToIfThenElse(me)) -- GitLab