diff --git a/library/lang/package.scala b/library/lang/package.scala index 12497cfc1d1f98d4f008426dec00fa91fc5c51a7..128f13df4c06a6bd0e303eba09881bac05ac2bc8 100644 --- a/library/lang/package.scala +++ b/library/lang/package.scala @@ -29,8 +29,8 @@ package object lang { def error[T](reason: java.lang.String): T = sys.error(reason) @ignore - implicit class Passes[A](v : A) { - def passes[B](tests : A => B) : B = tests(v) + implicit class Gives[A](v : A) { + def gives[B](tests : A => B) : B = tests(v) } } diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala index 8abed0de05d65641950383323658ab37aaae195e..50643a44088908b5d4587a895829dfd1616caa6c 100644 --- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala +++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala @@ -14,6 +14,7 @@ import solvers.TimeoutSolver import xlang.Trees._ import solvers.SolverFactory +import synthesis.ConvertHoles.convertHoles abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int) extends Evaluator(ctx, prog) { val name = "evaluator" @@ -107,8 +108,15 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int case Assert(cond, oerr, body) => e(IfExpr(Not(cond), Error(expr.getType, oerr.getOrElse("Assertion failed @"+expr.getPos)), body)) - case Ensuring(body, id, post) => - e(Let(id, body, Assert(post, Some("Ensuring failed"), Variable(id)))) + case en@Ensuring(body, id, post) => + if ( exists{ + case Hole(_,_) => true + case Gives(_,_) => true + case _ => false + }(en)) + e(convertHoles(en, ctx, true)) + else + e(Let(id, body, Assert(post, Some("Ensuring failed"), Variable(id)))) case Error(tpe, desc) => throw RuntimeError("Error reached in evaluation: " + desc) @@ -387,6 +395,9 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int case rh: RepairHole => simplestValue(rh.getType) // It will be wrong, we don't care + case g : Gives => + e(convertHoles(g, ctx, true)) + case choose: Choose => import purescala.TreeOps.simplestValue diff --git a/src/main/scala/leon/evaluators/TracingEvaluator.scala b/src/main/scala/leon/evaluators/TracingEvaluator.scala index 4647f3a036fd6ad448b1ba223fe67fe56a8e7862..2dd541997e916419679068e307a83845ace660d1 100644 --- a/src/main/scala/leon/evaluators/TracingEvaluator.scala +++ b/src/main/scala/leon/evaluators/TracingEvaluator.scala @@ -5,6 +5,7 @@ package evaluators import purescala.Common._ import purescala.Trees._ +import purescala.Extractors._ import purescala.Definitions._ import purescala.TreeOps._ import purescala.TypeTrees._ @@ -32,7 +33,7 @@ class TracingEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int = 1000) ex val res = e(b)(rctx.withNewVar(i, first), gctx) (res, first) - case MatchExpr(scrut, cases) => + case MatchLike(scrut, cases, _) => val rscrut = e(scrut) val r = cases.toStream.map(c => matchesCase(rscrut, c)).find(_.nonEmpty) match { diff --git a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala index 7edde4b5c312feee59fb74f09a3e579d859a5cc7..7a13e372afad7b9975cef15503c8a223b37d529c 100644 --- a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala +++ b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala @@ -157,19 +157,19 @@ trait ASTExtractors { } } - object ExPasses { + object ExGives { def unapply(tree : Apply) : Option[(Tree, List[CaseDef])] = tree match { case Apply( TypeApply( Select( Apply( TypeApply( - ExSelected("leon", "lang", "package", "Passes"), + ExSelected("leon", "lang", "package", "Gives"), _ :: Nil ), body :: Nil ), - ExNamed("passes") + ExNamed("gives") ), _ :: Nil ), diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index 763986c39c386e51d557014dfefe1f0903e451f2..0cbd2095336c3ae817c0e50ea57ffd3f7acda464 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -1003,12 +1003,10 @@ trait CodeExtraction extends ASTExtractors { Require(pre, b) - - case passes @ ExPasses(sel, cses) => + case ExGives(sel, cses) => val rs = extractTree(sel) val rc = cses.map(extractMatchCase(_)) - Passes(rs, rc) - + gives(rs, rc) case ExArrayLiteral(tpe, args) => FiniteArray(args.map(extractTree)).setType(ArrayType(extractType(tpe)(dctx, current.pos))) diff --git a/src/main/scala/leon/purescala/Constructors.scala b/src/main/scala/leon/purescala/Constructors.scala index 2b5c1d0ced3ad7ea5a08ebf5a8402310d0ad2caa..ffcd1b072dd1f9eda62c5945c571379d88b288e0 100644 --- a/src/main/scala/leon/purescala/Constructors.scala +++ b/src/main/scala/leon/purescala/Constructors.scala @@ -40,24 +40,28 @@ object Constructors { es.head } - def matchExpr(scrutinee: Expr, cases: Seq[MatchCase]): MatchExpr = { + private def filterCases(scrutinee: Expr, cases: Seq[MatchCase]): Seq[MatchCase] = { scrutinee.getType match { case c: CaseClassType => - new MatchExpr(scrutinee, - cases.filter(_.pattern match { - case CaseClassPattern(_, cct, _) if cct.classDef != c.classDef => false - case _ => true - }) - ) + cases.filter(_.pattern match { + case CaseClassPattern(_, cct, _) if cct.classDef != c.classDef => false + case _ => true + }) case _: TupleType | Int32Type | BooleanType | UnitType | _: AbstractClassType => - new MatchExpr(scrutinee, cases) + cases case t => scala.sys.error("Constructing match expression on non-supported type: "+t) } } + def gives(scrutinee : Expr, cases : Seq[MatchCase]) : Gives = + Gives(scrutinee, filterCases(scrutinee, cases)) + + def matchExpr(scrutinee : Expr, cases : Seq[MatchCase]) : MatchExpr = + MatchExpr(scrutinee, filterCases(scrutinee, cases)) + def and(exprs: Expr*): Expr = { val flat = exprs.flatMap(_ match { case And(es) => es diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala index 6cf4e1bbe3cf193e21276981e89bf64046c17b3c..5146387b9109d9749df6d1f18a1cbc01c9efd607 100644 --- a/src/main/scala/leon/purescala/Extractors.scala +++ b/src/main/scala/leon/purescala/Extractors.scala @@ -126,31 +126,21 @@ object Extractors { })) case Tuple(args) => Some((args, Tuple)) case IfExpr(cond, thenn, elze) => Some((Seq(cond, thenn, elze), (as: Seq[Expr]) => IfExpr(as(0), as(1), as(2)))) - case MatchExpr(scrut, cases) => - Some((scrut +: cases.flatMap{ case SimpleCase(_, e) => Seq(e) - case GuardedCase(_, e1, e2) => Seq(e1, e2) } - , { es: Seq[Expr] => - var i = 1; - val newcases = for (caze <- cases) yield caze match { - case SimpleCase(b, _) => i+=1; SimpleCase(b, es(i-1)) - case GuardedCase(b, _, _) => i+=2; GuardedCase(b, es(i-2), es(i-1)) - } - - matchExpr(es(0), newcases) - })) - case Passes(scrut, tests) => - Some((scrut +: tests.flatMap{ case SimpleCase(_, e) => Seq(e) - case GuardedCase(_, e1, e2) => Seq(e1, e2) } - , { es: Seq[Expr] => - var i = 1; - val newtests = for (test <- tests) yield test match { - case SimpleCase(b, _) => i+=1; SimpleCase(b, es(i-1)) - case GuardedCase(b, _, _) => i+=2; GuardedCase(b, es(i-2), es(i-1)) - } - - Passes(es(0), newtests) - })) + case MatchLike(scrut, cases, builder) => Some(( + scrut +: cases.flatMap { + case SimpleCase(_, e) => Seq(e) + case GuardedCase(_, e1, e2) => Seq(e1, e2) + }, + (es: Seq[Expr]) => { + var i = 1 + val newcases = for (caze <- cases) yield caze match { + case SimpleCase(b, _) => i+=1; SimpleCase(b, es(i-1)) + case GuardedCase(b, _, _) => i+=2; GuardedCase(b, es(i-2), es(i-1)) + } + builder(es(0), newcases) + } + )) case LetDef(fd, body) => fd.body match { case Some(b) => @@ -235,18 +225,29 @@ object Extractors { def unapply[T <: Typed](e: T): Option[(T, TypeTree)] = Some((e, e.getType)) } + object MatchLike { + def unapply(m : MatchLike) : Option[(Expr, Seq[MatchCase], (Expr, Seq[MatchCase]) => MatchLike)] = { + Option(m) map { m => + (m.scrutinee, m.cases, m match { + case _ : MatchExpr => matchExpr + case _ : Gives => gives + }) + } + } + } + object Pattern { def unapply(p : Pattern) : Option[( Option[Identifier], Seq[Pattern], (Option[Identifier], Seq[Pattern]) => Pattern - )] = Some(p match { + )] = Option(p) map { case InstanceOfPattern(b, ct) => (b, Seq(), (b, _) => InstanceOfPattern(b,ct)) case WildcardPattern(b) => (b, Seq(), (b, _) => WildcardPattern(b)) case CaseClassPattern(b, ct, subs) => (b, subs, (b, sp) => CaseClassPattern(b, ct, sp)) case TuplePattern(b,subs) => (b, subs, (b, sp) => TuplePattern(b, sp)) case LiteralPattern(b, l) => (b, Seq(), (b, _) => LiteralPattern(b, l)) - }) + } } } diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index fe01a181b1467baa492ee9149521b5f31de6c13e..b203c27a1f2ef3604aca7888000bf958d01d6c05 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -219,9 +219,9 @@ class PrettyPrinter(opts: PrinterOptions, val sb: StringBuffer = new StringBuffe | (${typed(id)}) => $post |}""" - case Passes(s, tests) => + case Gives(s, tests) => optP { - p"""|$s passes { + p"""|$s gives { | ${nary(tests, "\n")} |} |""" diff --git a/src/main/scala/leon/purescala/ScopeSimplifier.scala b/src/main/scala/leon/purescala/ScopeSimplifier.scala index 9ffe994c347588dda502f0379dd420d0e0c9bd26..7d3024bde0a1c50c72745b5149dda24d7db6ee76 100644 --- a/src/main/scala/leon/purescala/ScopeSimplifier.scala +++ b/src/main/scala/leon/purescala/ScopeSimplifier.scala @@ -75,7 +75,7 @@ class ScopeSimplifier extends Transformer { val sb = rec(b, newScope) LetTuple(sis, se, sb) - case MatchExpr(scrut, cases) => + case MatchLike(scrut, cases, builder) => val rs = rec(scrut, scope) def trPattern(p: Pattern, scope: Scope): (Pattern, Scope) = { @@ -112,7 +112,7 @@ class ScopeSimplifier extends Transformer { (newPattern, curScope) } - matchExpr(rs, cases.map { c => + builder(rs, cases.map { c => val (newP, newScope) = trPattern(c.pattern, scope) c match { diff --git a/src/main/scala/leon/purescala/TransformerWithPC.scala b/src/main/scala/leon/purescala/TransformerWithPC.scala index 372274569ce349491b5865d1316f938d7cc0f3a8..493b4639c7f71dc0121b105987a7066f605c68e4 100644 --- a/src/main/scala/leon/purescala/TransformerWithPC.scala +++ b/src/main/scala/leon/purescala/TransformerWithPC.scala @@ -21,12 +21,12 @@ abstract class TransformerWithPC extends Transformer { val sb = rec(b, register(Equals(Variable(i), se), path)) Let(i, se, sb).copiedFrom(e) - case MatchExpr(scrut, cases) => + case MatchLike(scrut, cases, builder) => val rs = rec(scrut, path) var soFar = path - matchExpr(rs, cases.map { c => + builder(rs, cases.map { c => val patternExprPos = conditionForPattern(rs, c.pattern, includeBinders = true) val patternExprNeg = conditionForPattern(rs, c.pattern, includeBinders = false) diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala index a6dd02d2d0e1f1bcdadee8de98e87a1ada5a4ebe..bea32ebe21801b65152a767f5c43fd1cd069a906 100644 --- a/src/main/scala/leon/purescala/TreeOps.scala +++ b/src/main/scala/leon/purescala/TreeOps.scala @@ -354,7 +354,7 @@ object TreeOps { case LetDef(fd,_) => subvs -- fd.params.map(_.id) -- fd.postcondition.map(_._1) case Let(i,_,_) => subvs - i case Choose(is,_) => subvs -- is - case MatchExpr(_, cses) => subvs -- (cses.map(_.pattern.binders).foldLeft(Set[Identifier]())((a, b) => a ++ b)) + case MatchLike(_, cses, _) => subvs -- (cses.map(_.pattern.binders).foldLeft(Set[Identifier]())((a, b) => a ++ b)) case Lambda(args, body) => subvs -- args.map(_.id) case Forall(args, body) => subvs -- args.map(_.id) case _ => subvs @@ -427,8 +427,8 @@ object TreeOps { postMap({ - case m @ MatchExpr(s, cses) => - Some(matchExpr(s, cses.map(freshenCase(_))).copiedFrom(m)) + case m @ MatchLike(s, cses, builder) => + Some(builder(s, cses.map(freshenCase(_))).copiedFrom(m)) case l @ Let(i,e,b) => val newID = FreshIdentifier(i.name, true).copiedFrom(i) @@ -608,7 +608,7 @@ object TreeOps { case v @ Variable(id) if s.isDefinedAt(id) => rec(s(id), s) case l @ Let(i,e,b) => rec(b, s + (i -> rec(e, s))) case i @ IfExpr(t1,t2,t3) => IfExpr(rec(t1, s),rec(t2, s),rec(t3, s)) - case m @ MatchExpr(scrut,cses) => matchExpr(rec(scrut, s), cses.map(inCase(_, s))).setPos(m) + case m @ MatchLike(scrut,cses,builder) => builder(rec(scrut, s), cses.map(inCase(_, s))).setPos(m) case n @ NAryOperator(args, recons) => { var change = false val rargs = args.map(a => { @@ -852,8 +852,8 @@ object TreeOps { postMap(rewritePM)(expr) } - def matchCasePathConditions(m : MatchExpr, pathCond: List[Expr]) : Seq[List[Expr]] = m match { - case MatchExpr(scrut, cases) => + def matchCasePathConditions(m : MatchLike, pathCond: List[Expr]) : Seq[List[Expr]] = m match { + case MatchLike(scrut, cases, _) => var pcSoFar = pathCond for (c <- cases) yield { @@ -1288,6 +1288,7 @@ object TreeOps { case Choose(_, _) => return false case Hole(_, _) => return false case RepairHole(_, _) => return false + case Gives(_,_) => return false case _ => }(e) true @@ -1471,6 +1472,61 @@ object TreeOps { } def isHomo(t1: Expr, t2: Expr)(implicit map: Map[Identifier,Identifier]): Boolean = { + + def casesMatch(cs1 : Seq[MatchCase], cs2 : Seq[MatchCase]) : Boolean = { + def patternHomo(p1: Pattern, p2: Pattern): (Boolean, Map[Identifier, Identifier]) = (p1, p2) match { + case (InstanceOfPattern(ob1, cd1), InstanceOfPattern(ob2, cd2)) => + (ob1.size == ob2.size && cd1 == cd2, Map((ob1 zip ob2).toSeq : _*)) + + case (WildcardPattern(ob1), WildcardPattern(ob2)) => + (ob1.size == ob2.size, Map((ob1 zip ob2).toSeq : _*)) + + case (CaseClassPattern(ob1, ccd1, subs1), CaseClassPattern(ob2, ccd2, subs2)) => + val m = Map[Identifier, Identifier]() ++ (ob1 zip ob2) + + if (ob1.size == ob2.size && ccd1 == ccd2 && subs1.size == subs2.size) { + (subs1 zip subs2).map { case (p1, p2) => patternHomo(p1, p2) }.foldLeft((true, m)) { + case ((b1, m1), (b2,m2)) => (b1 && b2, m1 ++ m2) + } + } else { + (false, Map()) + } + + case (TuplePattern(ob1, subs1), TuplePattern(ob2, subs2)) => + val m = Map[Identifier, Identifier]() ++ (ob1 zip ob2) + + if (ob1.size == ob2.size && subs1.size == subs2.size) { + (subs1 zip subs2).map { case (p1, p2) => patternHomo(p1, p2) }.foldLeft((true, m)) { + case ((b1, m1), (b2,m2)) => (b1 && b2, m1 ++ m2) + } + } else { + (false, Map()) + } + + case (LiteralPattern(ob1, lit1), LiteralPattern(ob2,lit2)) => + (ob1.size == ob2.size && lit1 == lit2, (ob1 zip ob2).toMap) + + case _ => + (false, Map()) + } + + (cs1 zip cs2).forall { + case (SimpleCase(p1, e1), SimpleCase(p2, e2)) => + val (h, nm) = patternHomo(p1, p2) + + h && isHomo(e1, e2)(map ++ nm) + + case (GuardedCase(p1, g1, e1), GuardedCase(p2, g2, e2)) => + val (h, nm) = patternHomo(p1, p2) + + h && isHomo(g1, g2)(map ++ nm) && isHomo(e1, e2)(map ++ nm) + + case _ => + false + } + + } + val res = (t1, t2) match { case (Variable(i1), Variable(i2)) => idHomo(i1, i2) @@ -1496,60 +1552,14 @@ object TreeOps { case (MatchExpr(s1, cs1), MatchExpr(s2, cs2)) => if (cs1.size == cs2.size) { - val scrutMatch = isHomo(s1, s2) - - def patternHomo(p1: Pattern, p2: Pattern): (Boolean, Map[Identifier, Identifier]) = (p1, p2) match { - case (InstanceOfPattern(ob1, cd1), InstanceOfPattern(ob2, cd2)) => - (ob1.size == ob2.size && cd1 == cd2, Map((ob1 zip ob2).toSeq : _*)) - - case (WildcardPattern(ob1), WildcardPattern(ob2)) => - (ob1.size == ob2.size, Map((ob1 zip ob2).toSeq : _*)) - - case (CaseClassPattern(ob1, ccd1, subs1), CaseClassPattern(ob2, ccd2, subs2)) => - val m = Map[Identifier, Identifier]() ++ (ob1 zip ob2) - - if (ob1.size == ob2.size && ccd1 == ccd2 && subs1.size == subs2.size) { - (subs1 zip subs2).map { case (p1, p2) => patternHomo(p1, p2) }.foldLeft((true, m)) { - case ((b1, m1), (b2,m2)) => (b1 && b2, m1 ++ m2) - } - } else { - (false, Map()) - } - - case (TuplePattern(ob1, subs1), TuplePattern(ob2, subs2)) => - val m = Map[Identifier, Identifier]() ++ (ob1 zip ob2) - - if (ob1.size == ob2.size && subs1.size == subs2.size) { - (subs1 zip subs2).map { case (p1, p2) => patternHomo(p1, p2) }.foldLeft((true, m)) { - case ((b1, m1), (b2,m2)) => (b1 && b2, m1 ++ m2) - } - } else { - (false, Map()) - } - - case (LiteralPattern(ob1, lit1), LiteralPattern(ob2,lit2)) => - (ob1.size == ob2.size && lit1 == lit2, (ob1 zip ob2).toMap) - - case _ => - (false, Map()) - } - - val casesMatch = (cs1 zip cs2).forall { - case (SimpleCase(p1, e1), SimpleCase(p2, e2)) => - val (h, nm) = patternHomo(p1, p2) - - h && isHomo(e1, e2)(map ++ nm) - - case (GuardedCase(p1, g1, e1), GuardedCase(p2, g2, e2)) => - val (h, nm) = patternHomo(p1, p2) - - h && isHomo(g1, g2)(map ++ nm) && isHomo(e1, e2)(map ++ nm) - - case _ => - false - } - - scrutMatch && casesMatch + isHomo(s1, s2) && casesMatch(cs1,cs2) + } else { + false + } + + case (Gives(s1, cs1), Gives(s2, cs2)) => + if (cs1.size == cs2.size) { + isHomo(s1, s2) && casesMatch(cs1,cs2) } else { false } @@ -1849,6 +1859,8 @@ object TreeOps { case _ => None } + def breakDownSpecs(e : Expr) = (preconditionOf(e), withoutSpec(e), postconditionOf(e)) + def preTraversalWithParent(f: (Expr, Option[Tree]) => Unit, initParent: Option[Tree] = None)(e: Expr): Unit = { val rec = preTraversalWithParent(f, Some(e)) _ diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala index 94a79ca5773610ce8023f59b7f7555a56ddfd426..bf90a8aeb5b25e80e4fc22b9051ab08993e075e8 100644 --- a/src/main/scala/leon/purescala/Trees.scala +++ b/src/main/scala/leon/purescala/Trees.scala @@ -45,10 +45,6 @@ object Trees { def getType = body.getType } - case class Passes(scrut: Expr, tests : List[MatchCase]) extends Expr { - def getType = leastUpperBound(tests.map(_.rhs.getType)).getOrElse(Untyped) - } - case class Choose(vars: List[Identifier], pred: Expr) extends Expr with UnaryExtractable { assert(!vars.isEmpty) @@ -185,12 +181,24 @@ object Trees { } } - case class MatchExpr(scrutinee: Expr, cases: Seq[MatchCase]) extends Expr { - assert(cases.nonEmpty) - + abstract sealed class MatchLike extends Expr { + val scrutinee : Expr + val cases : Seq[MatchCase] def getType = leastUpperBound(cases.map(_.rhs.getType)).getOrElse(Untyped) } + case class MatchExpr(scrutinee: Expr, cases: Seq[MatchCase]) extends MatchLike { + assert(cases.nonEmpty) + } + + case class Gives(scrutinee: Expr, cases : Seq[MatchCase]) extends MatchLike { + assert(cases.nonEmpty) + def asIncompleteMatch = { + val theHole = SimpleCase(WildcardPattern(None), Hole(this.getType, Seq())) + MatchExpr(scrutinee, cases :+ theHole) + } + } + sealed abstract class MatchCase extends Tree { val pattern: Pattern val rhs: Expr diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala index 1ee056af437a7a0ab03c6e9abf07ac0001a90a89..87a1e8e0629524a6229c8b63c420a0907d8e5963 100644 --- a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala @@ -498,7 +498,7 @@ trait AbstractZ3Solver case me @ MatchExpr(s, cs) => rec(matchToIfThenElse(me)) - case Passes(scrut, tests) => + case Gives(scrut, tests) => rec(matchToIfThenElse(matchExpr(scrut, tests))) case tu @ Tuple(args) => diff --git a/src/main/scala/leon/synthesis/ConvertHoles.scala b/src/main/scala/leon/synthesis/ConvertHoles.scala index 55e3e685d804c250909b4c35ec0370e349f75903..375b6dfe17a929ac4d6378f58d09360f3345526e 100644 --- a/src/main/scala/leon/synthesis/ConvertHoles.scala +++ b/src/main/scala/leon/synthesis/ConvertHoles.scala @@ -36,15 +36,35 @@ object ConvertHoles extends LeonPhase[Program, Program] { * } * */ - def run(ctx: LeonContext)(pgm: Program): Program = { - pgm.definedFunctions.foreach(fd => { - if (fd.hasBody) { + def convertHoles(e : Expr, ctx : LeonContext, treatGives : Boolean = false) : Expr = { + val (pre, body, post) = breakDownSpecs(e) + + // Ensure that holes are not found in pre and/or post conditions + pre.foreach { + preTraversal{ + case h : Hole => + ctx.reporter.error("Holes are not supported in preconditions. @"+ h.getPos) + case _ => + } + } + + post.foreach { case (id, post) => + preTraversal{ + case h : Hole => + ctx.reporter.error("Holes are not supported in postconditions. @"+ h.getPos) + case _ => + }(post) + } + body match { + case Some(body) => var holes = List[Identifier]() - val newBody = preMap { - case h @ Hole(tpe, es) => + val withoutHoles = preMap { + case p : Gives if treatGives => + Some(p.asIncompleteMatch) + case h : Hole => val (expr, ids) = toExpr(h) holes ++= ids @@ -52,43 +72,33 @@ object ConvertHoles extends LeonPhase[Program, Program] { Some(expr) case _ => None - }(fd.body.get) + }(body) - if (holes.nonEmpty) { + val asChoose = if (holes.nonEmpty) { val cids = holes.map(_.freshen) - val pred = fd.postcondition match { + val pred = post match { case Some((id, post)) => - replaceFromIDs((holes zip cids.map(_.toVariable)).toMap, Let(id, newBody, post)) + replaceFromIDs((holes zip cids.map(_.toVariable)).toMap, Let(id, withoutHoles, post)) case None => BooleanLiteral(true) } - val withChoose = letTuple(holes, tupleChoose(Choose(cids, pred)), newBody) + letTuple(holes, tupleChoose(Choose(cids, pred)), withoutHoles) - fd.body = Some(withChoose) } + else withoutHoles - } - - // Ensure that holes are not found in pre and/or post conditions - fd.precondition.foreach { - preTraversal{ - case _: Hole => - ctx.reporter.error("Holes are not supported in preconditions. (function "+fd.id.asString(ctx)+")") - case _ => - } - } + withPostcondition(withPrecondition(asChoose, pre), post) + + case None => e + } - fd.postcondition.foreach { case (id, post) => - preTraversal{ - case _: Hole => - ctx.reporter.error("Holes are not supported in postconditions. (function "+fd.id.asString(ctx)+")") - case _ => - }(post) - } - }) + } + + def run(ctx: LeonContext)(pgm: Program): Program = { + pgm.definedFunctions.foreach(fd => fd.fullBody = convertHoles(fd.fullBody,ctx) ) pgm } diff --git a/src/test/resources/regression/verification/purescala/invalid/Asserts1.scala b/src/test/resources/regression/verification/purescala/invalid/Asserts1.scala index 2643eff3e10a772e4bf475475b8eaa88bef2d168..4860603fcfc6f84c332c8cbfc553b2e28233e18b 100644 --- a/src/test/resources/regression/verification/purescala/invalid/Asserts1.scala +++ b/src/test/resources/regression/verification/purescala/invalid/Asserts1.scala @@ -6,27 +6,19 @@ object Operators { def foo(a: Int): Int = { require(a > 0) - - { - val b = a - assert(b > 0, "Hey now") - b + bar(1) - } ensuring { _ < 2 } - - } ensuring { - _ > a + val b = a + assert(b > 0, "Hey now") + b + bar(1) + } ensuring { res => + res > a && res < 2 } def bar(a: Int): Int = { require(a > 0) - - { - val b = a - assert(b > 0, "Hey now") - b + 2 - } ensuring { _ > 2 } - - } ensuring { - _ > a + val b = a + assert(b > 0, "Hey now") + b + 2 + } ensuring { res => + res > a && res > 2 } }