diff --git a/src/main/scala/leon/purescala/Constructors.scala b/src/main/scala/leon/purescala/Constructors.scala index 98223ecceb30e6dcff1d30ce2a15ec0a2f7d3750..aceebf495ed23694dbbf8d351e132803f2e7c367 100644 --- a/src/main/scala/leon/purescala/Constructors.scala +++ b/src/main/scala/leon/purescala/Constructors.scala @@ -90,7 +90,7 @@ object Constructors { case _ => true }) - case _: TupleType | Int32Type | BooleanType | UnitType | _: AbstractClassType => + case _: TupleType | Int32Type | IntegerType | BooleanType | UnitType | _: AbstractClassType => cases case t => diff --git a/src/main/scala/leon/purescala/ExprOps.scala b/src/main/scala/leon/purescala/ExprOps.scala index 9c72b465c58908cd1bc3b20d910c0638a9fc0600..c8e6886ba3079d40955f5fa62dd9fa8327e5c753 100644 --- a/src/main/scala/leon/purescala/ExprOps.scala +++ b/src/main/scala/leon/purescala/ExprOps.scala @@ -883,7 +883,7 @@ object ExprOps { preMap(rewritePM)(expr) } - def matchCasePathConditions(m: MatchExpr, pathCond: List[Expr]) : Seq[List[Expr]] = { + def matchExprCaseConditions(m: MatchExpr, pathCond: List[Expr]) : Seq[List[Expr]] = { val MatchExpr(scrut, cases) = m var pcSoFar = pathCond for (c <- cases) yield { @@ -901,8 +901,24 @@ object ExprOps { } } + // Condition to pass this match case, expressed w.r.t scrut only + def matchCaseCondition(scrut: Expr, c: MatchCase): Expr = { + + val patternC = conditionForPattern(scrut, c.pattern, includeBinders = false) + + c.optGuard match { + case Some(g) => + // guard might refer to binders + val map = mapForPattern(scrut, c.pattern) + and(patternC, replaceFromIDs(map, g)) + + case None => + patternC + } + } + def passesPathConditions(p : Passes, pathCond: List[Expr]) : Seq[List[Expr]] = { - matchCasePathConditions(MatchExpr(p.in, p.cases), pathCond) + matchExprCaseConditions(MatchExpr(p.in, p.cases), pathCond) } /* diff --git a/src/main/scala/leon/purescala/SimplifierWithPaths.scala b/src/main/scala/leon/purescala/SimplifierWithPaths.scala index 8fe77deee8af24239c464554bdb41ca2178fc780..9e33ba1f64171b1e622ef9b83a90bb78132d73f4 100644 --- a/src/main/scala/leon/purescala/SimplifierWithPaths.scala +++ b/src/main/scala/leon/purescala/SimplifierWithPaths.scala @@ -84,7 +84,7 @@ class SimplifierWithPaths(sf: SolverFactory[Solver]) extends TransformerWithPC { var stillPossible = true - val conds = matchCasePathConditions(me, path) + val conds = matchExprCaseConditions(me, path) val newCases = cases.zip(conds).flatMap { case (cs, cond) => if (stillPossible && sat(and(cond: _*))) { diff --git a/src/main/scala/leon/repair/RepairCostModel.scala b/src/main/scala/leon/repair/RepairCostModel.scala index 0d6837f1c028b0bc220394b19458ccda062a1cd7..5d36fcf93683838a0c9627df1f1407435a8e0427 100644 --- a/src/main/scala/leon/repair/RepairCostModel.scala +++ b/src/main/scala/leon/repair/RepairCostModel.scala @@ -14,11 +14,12 @@ case class RepairCostModel(cm: CostModel) extends WrappedCostModel(cm, "Repair(" val h = cm.andNode(an, subs).minSize Cost(an.ri.rule match { - case GuidedDecomp => 1 - case GuidedCloser => 0 + case Split => 1 + case Verify => 0 + case Focus => -10 case CEGLESS => 0 case TEGLESS => 1 - case _ => h+1 + case _ => h+1 }) } } diff --git a/src/main/scala/leon/repair/Repairman.scala b/src/main/scala/leon/repair/Repairman.scala index 3415365c180432dcabc10ecab10e7e34db53cb5f..20b95a30bb868e93db3311e14c64740ce016afac 100644 --- a/src/main/scala/leon/repair/Repairman.scala +++ b/src/main/scala/leon/repair/Repairman.scala @@ -34,378 +34,155 @@ class Repairman(ctx: LeonContext, initProgram: Program, fd: FunDef, verifTimeout implicit val debugSection = DebugSectionRepair - private object VRes { - trait VerificationResult - case object Valid extends VerificationResult - case class NotValid(passing : Seq[Example], failing : Seq[Example]) extends VerificationResult - } - - def programSize(pgm: Program): Int = { - visibleFunDefsFromMain(pgm).foldLeft(0) { - case (s, f) => - 1 + f.params.size + formulaSize(f.fullBody) + s - } - } - - import VRes._ def repair(): Unit = { val to = new TimeoutFor(ctx.interruptManager) to.interruptAfter(repairTimeoutMs) { reporter.info(ASCIIHelpers.title("1. Discovering tests for "+fd.id)) - val t1 = new Timer().start - discoverTests match { - case Valid => - reporter.info(s"Function ${fd.id} is found valid, no repair needed!") - case NotValid(passingTests, failingTests) => - - reporter.info(f" - Passing: ${passingTests.size}%3d") - reporter.info(f" - Failing: ${failingTests.size}%3d") - reporter.ifDebug { printer => - printer(new ExamplesTable("Tests failing:", failingTests).toString) - printer(new ExamplesTable("Tests passing:", passingTests).toString) - } - - // We exclude redundant failing tests, and only select the minimal tests - val minimalFailingTests = { - type FI = (FunDef, Seq[Expr]) - // We don't want tests whose invocation will call other failing tests. - // This is because they will appear erroneous, - // even though the error comes from the called test - val testEval = new RepairTrackingEvaluator(ctx, program) - failingTests foreach { ts => - testEval.eval(functionInvocation(fd, ts.ins)) - } - val test2Tests : Map[FI, Set[FI]] = testEval.fullCallGraph - /*println("About to print") - for{ - (test, tests) <-test2Tests - if (test._1 == fd) - } { - println(test._2 mkString ", ") - println(new ExamplesTable("", tests.toSeq.filter{_._1 == fd}.map{ x => InExample(x._2)}).toString) - }*/ - def isFailing(fi : FI) = !testEval.fiStatus(fi) && (fi._1 == fd) - val failing = test2Tests filter { case (from, to) => - isFailing(from) && (to forall (!isFailing(_)) ) - } - failing.keySet map { case (_, args) => InExample(args) } - } - - reporter.info(f" - Minimal Failing Set Size: ${minimalFailingTests.size}%3d") - - val initTime = t1.stop - reporter.info("Finished in "+initTime+"ms") - - reporter.ifDebug { printer => - printer(new ExamplesTable("Minimal failing:", minimalFailingTests.toSeq).toString) - } - - reporter.info(ASCIIHelpers.title("2. Locating/Focusing synthesis problem")) - val t2 = new Timer().start - val synth = getSynthesizer(minimalFailingTests.toList) - val ci = synth.ci - val p = synth.problem - - var solutions = List[Solution]() - val focusTime = t2.stop - - reporter.info("Finished in "+focusTime+"ms") - reporter.info(ASCIIHelpers.title("3. Synthesizing")) - reporter.info(p) - - try { - synth.synthesize() match { - case (search, sols) => - for (sol <- sols) { - - // Validate solution if not trusted - if (!sol.isTrusted) { - reporter.info("Found untrusted solution! Verifying...") - val expr = sol.toSimplifiedExpr(ctx, program) - ci.ch.impl = Some(expr) - - getVerificationCounterExamples(ci.fd, program) match { - case NotValid(_, ces) if ces.nonEmpty => - reporter.error("I ended up finding this counter example:\n"+ces.mkString(" | ")) - - case NotValid(_, _) => - solutions ::= sol - reporter.warning("Solution is not trusted!") - - case Valid => - solutions ::= sol - reporter.info("Solution was not trusted but post-validation passed!") - } - } else { - reporter.info("Found trusted solution!") - solutions ::= sol - } - } - - if (synth.settings.generateDerivationTrees) { - val dot = new DotGenerator(search.g) - dot.writeFile("derivation"+DotGenerator.nextId()+".dot") - } - - if (solutions.isEmpty) { - reporter.error(ASCIIHelpers.title("Failed to repair!")) - } else { - - reporter.info(ASCIIHelpers.title("Repair successful:")) - for ((sol, i) <- solutions.reverse.zipWithIndex) { - reporter.info(ASCIIHelpers.subTitle("Solution "+(i+1)+":")) - val expr = sol.toSimplifiedExpr(ctx, program) - reporter.info(ScalaPrinter(expr)) - } - reporter.info(ASCIIHelpers.title("In context:")) - - - for ((sol, i) <- solutions.reverse.zipWithIndex) { - reporter.info(ASCIIHelpers.subTitle("Solution "+(i+1)+":")) - val expr = sol.toSimplifiedExpr(ctx, program) - val nfd = fd.duplicate - - nfd.body = fd.body.map { b => - replace(Map(ci.source -> expr), b) - } - - reporter.info(ScalaPrinter(nfd)) - } - - } - } - } finally { - synth.shutdown() - } - } - } - } - def getSynthesizer(failingTests: List[Example]): Synthesizer = { - - val body = fd.body.get + val tb = discoverTests() - val (newBody, replacedExpr) = focusRepair(program, fd, failingTests) - fd.body = Some(newBody) + if (tb.invalids.nonEmpty) { + reporter.info(f" - Passing: ${tb.valids.size}%3d") + reporter.info(f" - Failing: ${tb.invalids.size}%3d") - val psize = programSize(initProgram) - val size = formulaSize(body) - val focusSize = formulaSize(replacedExpr) + reporter.ifDebug { printer => + printer(tb.asString("Discovered Tests")) + } - reporter.info("Program size : "+psize) - reporter.info("Original body size: "+size) - reporter.info("Focused expr size : "+focusSize) + reporter.info(ASCIIHelpers.title("2. Minimizing tests")) + val tb2 = tb.minimizeInvalids(fd, ctx, program) - val guide = Guide(replacedExpr) + // We exclude redundant failing tests, and only select the minimal tests + reporter.info(f" - Minimal Failing Set Size: ${tb2.invalids.size}%3d") - // Return synthesizer for this choose - val soptions0 = SynthesisPhase.processOptions(ctx) - - val soptions = soptions0.copy( - functionsToIgnore = soptions0.functionsToIgnore + fd, - costModel = RepairCostModel(soptions0.costModel), - rules = (soptions0.rules ++ Seq( - GuidedDecomp, - GuidedCloser, - CEGLESS - //TEGLESS - )) diff Seq(ADTInduction, TEGIS, IntegerInequalities, IntegerEquation) - ) + reporter.ifDebug { printer => + printer(tb2.asString("Minimal Failing Tests")) + } - // extract chooses from fd - val Seq(ci) = ChooseInfo.extractFromFunction(program, fd) + val synth = getSynthesizer(tb2) - val nci = ci.copy(pc = and(ci.pc, guide)) + try { + reporter.info(ASCIIHelpers.title("3. Synthesizing repair")) + val (search, solutions) = synth.validate(synth.synthesize()) match { + case (search, sols) => + (search, sols.collect { case (s, true) => s }) + } - new Synthesizer(ctx, program, nci, soptions) - } + if (synth.settings.generateDerivationTrees) { + val dot = new DotGenerator(search.g) + dot.writeFile("derivation"+DotGenerator.nextId()+".dot") + } - private def focusRepair(program: Program, fd: FunDef, failingTests: List[Example]): (Expr, Expr) = { + if (solutions.isEmpty) { + reporter.error(ASCIIHelpers.title("Failed to repair!")) + } else { - val args = fd.params.map(_.id) + reporter.info(ASCIIHelpers.title("Repair successful:")) + for ((sol, i) <- solutions.zipWithIndex) { + reporter.info(ASCIIHelpers.subTitle("Solution "+(i+1)+":")) + val expr = sol.toSimplifiedExpr(ctx, synth.program) + reporter.info(ScalaPrinter(expr)) + } + reporter.info(ASCIIHelpers.title("In context:")) - val spec = fd.postcondition.getOrElse(Lambda(Seq(ValDef(FreshIdentifier("res", fd.returnType, true))), BooleanLiteral(true))) - val body = fd.body.get + for ((sol, i) <- solutions.zipWithIndex) { + reporter.info(ASCIIHelpers.subTitle("Solution "+(i+1)+":")) + val expr = sol.toSimplifiedExpr(ctx, synth.program) + val nfd = fd.duplicate - val evaluator = new DefaultEvaluator(ctx, program) + nfd.body = Some(expr) - // Check how an expression behaves on tests - // - returns Some(true) if for all tests e evaluates to true - // - returns Some(false) if for all tests e evaluates to false - // - returns None otherwise - def forAllTests(e: Expr, env: Map[Identifier, Expr], evaluator: Evaluator): Option[Boolean] = { - var soFar : Option[Boolean] = None - failingTests.foreach { ex => - val ins = ex.ins - evaluator.eval(e, env ++ (args zip ins)) match { - case EvaluationResults.Successful(BooleanLiteral(b)) => - soFar match { - case None => - soFar = Some(b) - case Some(`b`) => - case _ => - return None + reporter.info(ScalaPrinter(nfd)) } - case e => - return None + } + } finally { + synth.shutdown() } + } else { + reporter.info(s"Could not find a wrong execution.") } - - soFar } + } - def focus(expr: Expr, env: Map[Identifier, Expr])(implicit spec: Expr): (Expr, Expr) = { - val choose = Choose(spec) - - def testCondition(cond: Expr, inExpr: Expr => Expr) = forAllTests( - application(spec, Seq(inExpr(not(cond)))), - env, - new RepairNDEvaluator(ctx,program,fd,cond) - ) - - def condAsSpec(cond: Expr, inExpr: Expr => Expr) = { - val newOut = FreshIdentifier("cond", BooleanType, true) - val newSpec = Lambda(Seq(ValDef(newOut)), application(spec, Seq(inExpr(Variable(newOut))))) - val (b, r) = focus(cond, env)(newSpec) - (inExpr(b), r) + def getSynthesizer(tb: TestBank): Synthesizer = { + val (np, fdMap) = replaceFunDefs(program)({ fd => + if (fd == this.fd) { + Some(fd.duplicate) + } else { + None } - - expr match { - case me @ MatchExpr(scrut, cases) => - var res: Option[(Expr, Expr)] = None - - // in case scrut is an non-variable expr, we simplify to a variable + inject in env - for (c <- cases if res.isEmpty) { - val cond = and(conditionForPattern(scrut, c.pattern, includeBinders = false), - c.optGuard.getOrElse(BooleanLiteral(true))) - val map = mapForPattern(scrut, c.pattern) - - - forAllTests(cond, env ++ map, evaluator) match { - case Some(true) => - val (b, r) = focus(c.rhs, env ++ map) - res = Some((MatchExpr(scrut, cases.map { c2 => - if (c2 eq c) { - c2.copy(rhs = b) - } else { - c2 - } - }), r)) - - case Some(false) => - // continue until next case - case None => - res = Some((choose, expr)) - } - } - - res.getOrElse((choose, expr)) - - case Let(id, value, body) => - val (b, r) = focus(body, env + (id -> value)) - (Let(id, value, b), r) - - case ite @ IfExpr(c, thn, els) => - testCondition(c, IfExpr(_, thn, els)) match { - case Some(true) => - condAsSpec(c, IfExpr(_, thn, els)) - case _ => - // Try to focus on branches - forAllTests(c, env, evaluator) match { - case Some(true) => - val (b, r) = focus(thn, env) - (IfExpr(c, b, els), r) - case Some(false) => - val (b, r) = focus(els, env) - (IfExpr(c, thn, b), r) - case None => - // We cannot focus any further - (choose, expr) - } - } - - case a@And(Seq(ex, exs@_*)) => - testCondition(ex, e => andJoin(e +: exs)) match { - case Some(true) => - // The first is wrong - condAsSpec(ex, e => andJoin(e +: exs)) - case _ => - forAllTests(ex, env, evaluator) match { - case Some(true) => - // First is always true, focus on rest - focus(andJoin(exs), env) - case Some(false) => - // Seems all test break when we evaluate to false, try true??? - (choose, BooleanLiteral(true)) - case None => - // We cannot focus any further - (choose, expr) - } - } + }) - case o@Or(Seq(ex, exs@_*)) => - testCondition(ex, e => orJoin(e +: exs)) match { - case Some(true) => - condAsSpec(ex, e => orJoin(e +: exs)) - case _ => - forAllTests(ex, env, evaluator) match { - case Some(false) => - // First is always false, focus on rest - focus(orJoin(exs), env) - case Some(true) => - // Seems all test break when we evaluate to true, try false??? - (choose, BooleanLiteral(false)) - case None => - // We cannot focus any further - (choose, expr) - } - } - - // Let, LetTuple, Methods, tuples? - - case _ => - (choose, expr) - } - } - - focus(body, Map())(spec) + val nfd = fdMap(fd) + + val origBody = nfd.body.get + + val spec = nfd.postcondition.getOrElse( + Lambda(Seq(ValDef(FreshIdentifier("res", nfd.returnType))), BooleanLiteral(true)) + ) + + val choose = Choose(spec) + choose.impl = Some(origBody) + nfd.body = Some(choose) + + val term = Terminating(nfd.typed, nfd.params.map(_.id.toVariable)) + val guide = Guide(origBody) + val pre = nfd.precondition.getOrElse(BooleanLiteral(true)) + + val ci = ChooseInfo( + nfd, + andJoin(Seq(pre, guide, term)), + origBody, + choose, + tb + ) + + // Return synthesizer for this choose + val so0 = SynthesisPhase.processOptions(ctx) + + val soptions = so0.copy( + functionsToIgnore = so0.functionsToIgnore + fd, + costModel = RepairCostModel(so0.costModel), + rules = (so0.rules ++ Seq( + Split, + Verify, + Focus, + CEGLESS + //TEGLESS + )) diff Seq(ADTInduction, TEGIS, IntegerInequalities, IntegerEquation) + ) + + new Synthesizer(ctx, np, ci, soptions) } - private def getVerificationCounterExamples(fd: FunDef, prog: Program): VerificationResult = { + def getVerificationCExs(fd: FunDef): Seq[Example] = { val timeoutMs = verifTimeoutMs.getOrElse(3000L) - val solverf = SolverFactory.getFromSettings(ctx, prog).withTimeout(timeoutMs) - val vctx = VerificationContext(ctx, prog, solverf, reporter) + val solverf = SolverFactory.getFromSettings(ctx, program).withTimeout(timeoutMs) + val vctx = VerificationContext(ctx, program, solverf, reporter) val vcs = AnalysisPhase.generateVCs(vctx, Some(Seq(fd.id.name))) - try { + try { val report = AnalysisPhase.checkVCs( - vctx, - vcs, + vctx, + vcs, checkInParallel = true, stopAfter = Some({ (vc, vr) => vr.isInvalid }) ) val vrs = report.vrs - if(vrs.forall{ _._2.isValid }) { - Valid - } else { - NotValid(Nil, - vrs.collect { - case (_, VCResult(VCStatus.Invalid(ex), _, _)) => - InExample(fd.params.map{vd => ex(vd.id)}) - } - ) + vrs.collect { case (_, VCResult(VCStatus.Invalid(ex), _, _)) => + InExample(fd.params.map{vd => ex(vd.id)}) } } finally { solverf.shutdown() } } - - private def discoverTests: VerificationResult = { + + def discoverTests(): TestBank = { import bonsai.enumerators._ import utils.ExpressionGrammars.ValueGrammar @@ -439,7 +216,7 @@ class Repairman(ctx: LeonContext, initProgram: Program, fd: FunDef, verifTimeout new InExample(ins) } } - + val generatedTests = inputs .take(maxEnumerated) .filter(filtering) @@ -447,67 +224,29 @@ class Repairman(ctx: LeonContext, initProgram: Program, fd: FunDef, verifTimeout .map(inputsToExample) .toList - val (generatedPassing, generatedFailing) = generatedTests.partition { + val (genPassing, genFailing) = generatedTests.partition { case _: InOutExample => true case _ => false } - // Extract passing/failing from the passes in POST - val ef = new ExamplesFinder(ctx, program) - val (userPassing, userFailing) = ef.extractTests(fd) - val passing = generatedPassing ++ userPassing - - // If we have no ce yet, try to verify, if it fails, we have at least one CE - (generatedFailing ++ userFailing) match { - case Seq() => getVerificationCounterExamples(fd, program) match { - case Valid => Valid - case NotValid(_, ces) => NotValid(passing, ces) - } - case nonEmpty => NotValid(passing, nonEmpty) - } - - } + val genTb = TestBank(genPassing, genFailing).stripOuts - // ununsed for now, but implementation could be useful later - private def disambiguate(p: Problem, sol1: Solution, sol2: Solution): Option[(InOutExample, InOutExample)] = { - val s1 = sol1.toSimplifiedExpr(ctx, program) - val s2 = sol2.toSimplifiedExpr(ctx, program) + // Extract passing/failing from the passes in POST + val userTb = new ExamplesFinder(ctx, program).extractTests(fd) - val e = new DefaultEvaluator(ctx, program) + val allTb = genTb union userTb - if (s1 == s2) { - None + if (allTb.invalids.isEmpty) { + TestBank(allTb.valids, getVerificationCExs(fd)) } else { - val diff = and(p.pc, not(Equals(s1, s2))) - val solverf = SolverFactory.default(ctx, program).withTimeout(1.second) - val solver = solverf.getNewSolver() - - try { - solver.assertCnstr(diff) - solver.check match { - case Some(true) => - val m = solver.getModel - val inputs = p.as.map(id => m.getOrElse(id, simplestValue(id.getType))) - val inputsMap = (p.as zip inputs).toMap - - (e.eval(s1, inputsMap), e.eval(s2, inputsMap)) match { - case (EvaluationResults.Successful(tr1), EvaluationResults.Successful(tr2)) => - val r1 = unwrapTuple(tr1, p.xs.size) - val r2 = unwrapTuple(tr2, p.xs.size) - Some((InOutExample(inputs, r1), InOutExample(inputs, r2))) - case _ => - None - } - case Some(false) => - None - case _ => - // considered as equivalent - None - } - } finally { - solver.free() - solverf.shutdown() - } + allTb + } + } + + def programSize(pgm: Program): Int = { + visibleFunDefsFromMain(pgm).foldLeft(0) { + case (s, f) => + 1 + f.params.size + formulaSize(f.fullBody) + s } } } diff --git a/src/main/scala/leon/repair/rules/Focus.scala b/src/main/scala/leon/repair/rules/Focus.scala new file mode 100644 index 0000000000000000000000000000000000000000..d1ddf7037c72278103657f5aa1cb2d30afcad769 --- /dev/null +++ b/src/main/scala/leon/repair/rules/Focus.scala @@ -0,0 +1,223 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package repair +package rules + +import synthesis._ + +import leon.utils.Simplifiers + +import purescala.ScalaPrinter + +import evaluators._ + +import purescala.Expressions._ +import purescala.Definitions._ +import purescala.Common._ +import purescala.Types._ +import purescala.ExprOps._ +import purescala.Extractors._ +import purescala.Constructors._ +import purescala.Extractors._ + +import Witnesses._ + +import solvers._ + +case object Focus extends Rule("Focus") { + + def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { + if (hctx.searchDepth > 0) { + return Nil + } + + val fd = hctx.ci.fd + val ctx = hctx.sctx.context + val program = hctx.sctx.program + + val evaluator = new DefaultEvaluator(ctx, program) + + // Check how an expression behaves on tests + // - returns Some(true) if for all tests e evaluates to true + // - returns Some(false) if for all tests e evaluates to false + // - returns None otherwise + def forAllTests(tests: Seq[Example])(e: Expr, env: Map[Identifier, Expr], evaluator: Evaluator): Option[Boolean] = { + var soFar: Option[Boolean] = None + + tests.foreach { ex => + evaluator.eval(e, (p.as zip ex.ins).toMap ++ env) match { + case EvaluationResults.Successful(BooleanLiteral(b)) => + soFar match { + case None => + soFar = Some(b) + case Some(`b`) => + /* noop */ + case _ => + return None + } + + case e => + //println("Evaluator said "+e) + return None + } + } + + soFar + } + + val fdSpec = { + val id = FreshIdentifier("res", fd.returnType) + Let(id, fd.body.get, + fd.postcondition.map(l => application(l, Seq(id.toVariable))).getOrElse(BooleanLiteral(true)) + ) + } + + def focus(p: Problem): Traversable[RuleInstantiation] = { + val faTests = forAllTests(p.tb.invalids) _ + + val TopLevelAnds(clauses) = p.ws + + val guides = clauses.collect { + case Guide(expr) => expr + } + + val wss = clauses.filter { + case _: Guide => false + case _ => true + } + + def ws(g: Expr) = andJoin(Guide(g) +: wss) + + def testCondition(cond: Expr) = { + val ndSpec = postMap { + case c if c eq cond => Some(not(cond)) // Use reference equality + case _ => None + }(fdSpec) + + faTests(ndSpec, Map(), new RepairNDEvaluator(ctx, program, fd, cond)) + } + + guides.flatMap { + case IfExpr(c, thn, els) => + testCondition(c) match { + case Some(true) => + val cx = FreshIdentifier("cond", BooleanType) + // Focus on condition + val np = Problem(p.as, ws(c), p.pc, letTuple(p.xs, IfExpr(cx.toVariable, thn, els), p.phi), List(cx), p.tb.stripOuts) + + Some(decomp(List(np), termWrap(IfExpr(_, thn, els)), s"Focus on if-cond '$c'")(p)) + + case _ => + // Try to focus on branches + faTests(c, Map(), evaluator) match { + case Some(true) => + val np = Problem(p.as, ws(thn), and(p.pc, c), p.phi, p.xs, p.tbOps.filterIns(c)) + + Some(decomp(List(np), termWrap(IfExpr(c, _, els), c), s"Focus on if-then")(p)) + case Some(false) => + val np = Problem(p.as, ws(els), and(p.pc, not(c)), p.phi, p.xs, p.tbOps.filterIns(not(c))) + + Some(decomp(List(np), termWrap(IfExpr(c, thn, _), not(c)), s"Focus on if-else")(p)) + case None => + // We cannot focus any further + None + } + } + + case MatchExpr(scrut, cases) => + var res: Option[Traversable[RuleInstantiation]] = None + + var pcSoFar: Seq[Expr] = Nil + + for (c <- cases if res.isEmpty) { + val map = mapForPattern(scrut, c.pattern) + + val thisCond = matchCaseCondition(scrut, c) + val cond = andJoin(pcSoFar :+ thisCond) + pcSoFar = pcSoFar :+ not(thisCond) + + // thisCond here is safe, because we focus we now that all tests have been false so far + faTests(thisCond, map, evaluator) match { + case Some(true) => + + val vars = map.toSeq.map(_._1) + + // Filter tests by the path-condition + val tb2 = p.tbOps.filterIns(cond) + + // Augment test with the additional variables and their valuations + val tbF: (Seq[Expr] => List[Seq[Expr]]) = { (e: Seq[Expr]) => + val emap = (p.as zip e).toMap + + evaluator.eval(tupleWrap(vars.map(map)), emap).result.map { r => + e ++ unwrapTuple(r, vars.size) + }.toList + } + + val tb3 = if (vars.nonEmpty) { + tb2.mapIns(tbF) + } else { + tb2 + } + + val newPc = andJoin(cond +: vars.map { id => equality(id.toVariable, map(id)) }) + + val np = Problem(p.as ++ vars, ws(c.rhs), and(p.pc, newPc), p.phi, p.xs, tb3) + + res = Some( + Some( + decomp(List(np), termWrap(x => MatchExpr(scrut, cases.map { + case `c` => c.copy(rhs = x) + case c2 => c2 + }), cond), s"Focus on match-case '${c.pattern}'")(p) + ) + ) + + case Some(false) => + // continue until next case + case None => + res = Some(Nil) + } + } + + res.getOrElse(Nil) + + + case Let(id, value, body) => + val tbF: (Seq[Expr] => List[Seq[Expr]]) = { (e: Seq[Expr]) => + val map = (p.as zip e).toMap + + evaluator.eval(value, map).result.map { r => + e :+ r + }.toList + } + + val np = Problem(p.as :+ id, ws(body), and(p.pc, equality(id.toVariable, value)), p.phi, p.xs, p.tb.mapIns(tbF)) + + Some(decomp(List(np), termWrap(Let(id, value, _)), s"Focus on let-body")(p)) + + case _ => None + } + } + + def focusRec(is: Traversable[RuleInstantiation]): Traversable[RuleInstantiation] = { + val res = is.flatMap { ri => + ri.apply(hctx) match { + case RuleExpanded(subs) => + subs.flatMap(focus) + case _ => + Nil + } + } + + if (res.isEmpty) { + is + } else { + res + } + } + + focusRec(focus(p)) + } +} diff --git a/src/main/scala/leon/repair/rules/GuidedDecomp.scala b/src/main/scala/leon/repair/rules/GuidedDecomp.scala deleted file mode 100644 index 266372d76816f093083b847de5d99049e938549f..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/repair/rules/GuidedDecomp.scala +++ /dev/null @@ -1,108 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -package leon -package repair -package rules - -import synthesis._ - -import leon.utils.Simplifiers - -import purescala.Expressions._ -import purescala.Definitions._ -import purescala.Common._ -import purescala.Types._ -import purescala.ExprOps._ -import purescala.Extractors._ -import purescala.Constructors._ - -import Witnesses._ - -import solvers._ - -case object GuidedDecomp extends Rule("Guided Decomp") { - def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - if (hctx.searchDepth > 0) { - return Nil - } - - val TopLevelAnds(clauses) = p.ws - - val guides = clauses.collect { - case Guide(expr) => expr - } - - val simplify = Simplifiers.bestEffort(hctx.context, hctx.program)_ - - val alts = guides.collect { - case g @ IfExpr(c, thn, els) => - val sub1 = p.copy(ws = replace(Map(g -> thn), p.ws), pc = and(c, replace(Map(g -> thn), p.pc))) - val sub2 = p.copy(ws = replace(Map(g -> els), p.ws), pc = and(Not(c), replace(Map(g -> els), p.pc))) - - val onSuccess: List[Solution] => Option[Solution] = { - case List(s1, s2) => - Some(Solution(or(s1.pre, s2.pre), s1.defs++s2.defs, IfExpr(c, s1.term, s2.term))) - case _ => - None - } - - Some(decomp(List(sub1, sub2), onSuccess, s"Guided If-Split on '$c'")) - - case m @ MatchExpr(scrut0, _) => - - val scrut = scrut0 match { - case v : Variable => v - case _ => Variable(FreshIdentifier("scrut", scrut0.getType, true)) - } - var scrutCond: Expr = if (scrut == scrut0) BooleanLiteral(true) else Equals(scrut0, scrut) - - val fullMatch = if (isMatchExhaustive(m)) { - m - } else { - m.copy(cases = m.cases :+ MatchCase(WildcardPattern(None), None, Error(m.getType, "unreachable in original program"))) - } - - val cs = fullMatch.cases - - val subs = for ((c, cond) <- cs zip matchCasePathConditions(fullMatch, List(p.pc))) yield { - - val localScrut = c.pattern.binder.map( Variable ) getOrElse scrut - val scrutConstraint = if (localScrut == scrut) BooleanLiteral(true) else Equals(localScrut, scrut) - val substs = patternSubstitutions(localScrut, c.pattern) - - val pc = simplify(and( - scrutCond, - replace(Map(scrut0 -> scrut), replaceSeq(substs,scrutConstraint)), - replace(Map(scrut0 -> scrut), replace(Map(m -> c.rhs), andJoin(cond))) - )) - val ws = replace(Map(m -> c.rhs), p.ws) - val phi = replaceSeq(substs, p.phi) - val free = variablesOf(and(pc, phi)) -- p.xs - val asPrefix = p.as.filter(free) - - Problem(asPrefix ++ (free -- asPrefix), ws, pc, phi, p.xs) - } - - val onSuccess: List[Solution] => Option[Solution] = { subs => - val cases = for ((c, s) <- cs zip subs) yield { - c.copy(rhs = s.term) - } - - Some(Solution( - orJoin(subs.map(_.pre)), - subs.map(_.defs).foldLeft(Set[FunDef]())(_ ++ _), - if (scrut0 != scrut) Let(scrut.id, scrut0, matchExpr(scrut, cases)) - else matchExpr(scrut, cases), - subs.forall(_.isTrusted) - )) - } - - Some(decomp(subs.toList, onSuccess, s"Guided Match-Split on '$scrut0'")) - - case e => - None - } - - alts.flatten - } -} diff --git a/src/main/scala/leon/repair/rules/Split.scala b/src/main/scala/leon/repair/rules/Split.scala new file mode 100644 index 0000000000000000000000000000000000000000..b05f3570e62f8b2f4b67a1407db64f0a56a08232 --- /dev/null +++ b/src/main/scala/leon/repair/rules/Split.scala @@ -0,0 +1,124 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package repair +package rules + +import synthesis._ + +import leon.utils.Simplifiers + +import purescala.Expressions._ +import purescala.Definitions._ +import purescala.Common._ +import purescala.Types._ +import purescala.ExprOps._ +import purescala.Extractors._ +import purescala.Constructors._ + +import evaluators.DefaultEvaluator + +import Witnesses._ + +import solvers._ + +case object Split extends Rule("Split") { + def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { + if (hctx.searchDepth > 0) { + return Nil + } + + val ctx = hctx.sctx.context + val program = hctx.sctx.program + + val evaluator = new DefaultEvaluator(ctx, program) + + val TopLevelAnds(clauses) = p.ws + + val guides = clauses.collect { + case Guide(expr) => expr + } + + val wss = clauses.filter { + case _: Guide => false + case _ => true + } + + def ws(g: Expr) = andJoin(Guide(g) +: wss) + + val simplify = Simplifiers.bestEffort(hctx.context, hctx.program)_ + + val alts = guides.collect { + case g @ IfExpr(c, thn, els) => + val sub1 = p.copy(ws = replace(Map(g -> thn), p.ws), pc = and(c, replace(Map(g -> thn), p.pc)), tb = p.tbOps.filterIns(c)) + val sub2 = p.copy(ws = replace(Map(g -> els), p.ws), pc = and(Not(c), replace(Map(g -> els), p.pc)), tb = p.tbOps.filterIns(Not(c))) + + val onSuccess: List[Solution] => Option[Solution] = { + case List(s1, s2) => + Some(Solution(or(s1.pre, s2.pre), s1.defs++s2.defs, IfExpr(c, s1.term, s2.term))) + case _ => + None + } + + Some(decomp(List(sub1, sub2), onSuccess, s"Guided If-Split on '$c'")) + + case m @ MatchExpr(scrut, cases) => + + var pcSoFar: Seq[Expr] = Nil + + val infos = for (c <- cases) yield { + val map = mapForPattern(scrut, c.pattern) + + val thisCond = matchCaseCondition(scrut, c) // this case alone, without past cases + val cond = andJoin(pcSoFar :+ thisCond) // with previous cases + pcSoFar = pcSoFar :+ not(thisCond) + + val vars = map.toSeq.map(_._1) + + // Filter tests by the path-condition + val tb2 = p.tbOps.filterIns(cond) + + // Augment test with the additional variables and their valuations + val tbF: (Seq[Expr] => List[Seq[Expr]]) = { (e: Seq[Expr]) => + val emap = (p.as zip e).toMap + + evaluator.eval(tupleWrap(vars.map(map)), emap).result.map { r => + e ++ unwrapTuple(r, vars.size) + }.toList + } + + val tb3 = if (vars.nonEmpty) { + tb2.mapIns(tbF) + } else { + tb2 + } + + val newPc = andJoin(cond +: vars.map { id => equality(id.toVariable, map(id)) }) + + (cond, Problem(p.as ++ vars, ws(c.rhs), and(p.pc, newPc), p.phi, p.xs, tb3)) + } + + val onSuccess = { (sols: List[Solution]) => + val term = MatchExpr(scrut, (cases zip sols).map { + case (c, sol) => c.copy(rhs = sol.term) + }) + + val pres = (infos zip sols).collect { + case ((cond, _), sol) if sol.pre != BooleanLiteral(true) => + and(cond, sol.pre) + } + + Some(Solution(andJoin(pres), sols.map(_.defs).reduceLeft(_ ++ _), term, sols.forall(_.isTrusted))) + } + + Some( + decomp(infos.map(_._2).toList, onSuccess, s"Split match on '$scrut'") + ) + + case e => + None + } + + alts.flatten + } +} diff --git a/src/main/scala/leon/repair/rules/GuidedCloser.scala b/src/main/scala/leon/repair/rules/Verify.scala similarity index 92% rename from src/main/scala/leon/repair/rules/GuidedCloser.scala rename to src/main/scala/leon/repair/rules/Verify.scala index aa80a394b45422a8359b57ccb4739e8ab435f448..b3590729e00d4ffa019108bc56a16ca326153c9c 100644 --- a/src/main/scala/leon/repair/rules/GuidedCloser.scala +++ b/src/main/scala/leon/repair/rules/Verify.scala @@ -16,10 +16,10 @@ import Witnesses._ import graph._ -case object GuidedCloser extends NormalizingRule("Guided Closer") { +case object Verify extends NormalizingRule("Verify") { def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { hctx.parentNode match { - case Some(an: AndNode) if an.ri.rule == GuidedDecomp => + case Some(an: AndNode) if an.ri.rule == Split => // We proceed as usual case _ => return Nil diff --git a/src/main/scala/leon/synthesis/ChooseInfo.scala b/src/main/scala/leon/synthesis/ChooseInfo.scala index 9e4b00d8c62c3cf0f932241d8b8c9fcedace6d99..68f8745af7ea61a28b141f226b4a63788d446d6f 100644 --- a/src/main/scala/leon/synthesis/ChooseInfo.scala +++ b/src/main/scala/leon/synthesis/ChooseInfo.scala @@ -12,9 +12,10 @@ import Witnesses._ case class ChooseInfo(fd: FunDef, pc: Expr, source: Expr, - ch: Choose) { + ch: Choose, + tests: TestBank) { - val problem = Problem.fromChoose(ch, pc) + val problem = Problem.fromChooseInfo(this) } object ChooseInfo { @@ -35,7 +36,7 @@ object ChooseInfo { val term = Terminating(fd.typed, fd.params.map(_.id.toVariable)) for ((ch, path) <- new ChooseCollectorWithPaths().traverse(actualBody)) yield { - ChooseInfo(fd, and(path, term), ch, ch) + ChooseInfo(fd, and(path, term), ch, ch, TestBank.empty) } } } diff --git a/src/main/scala/leon/synthesis/ExamplesFinder.scala b/src/main/scala/leon/synthesis/ExamplesFinder.scala index fb34d7d60c491fedb29d7d49cec359e0b9e3164b..4dc281f17f8cc1d023a779dd5f82d44b163d2938 100644 --- a/src/main/scala/leon/synthesis/ExamplesFinder.scala +++ b/src/main/scala/leon/synthesis/ExamplesFinder.scala @@ -20,7 +20,7 @@ class ExamplesFinder(ctx: LeonContext, program: Program) { val reporter = ctx.reporter - def extractTests(fd: FunDef): (Seq[Example], Seq[Example]) = fd.postcondition match { + def extractTests(fd: FunDef): TestBank = fd.postcondition match { case Some(Lambda(Seq(ValDef(id, _)), post)) => // @mk FIXME: make this more general val tests = extractTestsOf(post) @@ -52,10 +52,10 @@ class ExamplesFinder(ctx: LeonContext, program: Program) { } } - examples.partition(isValidTest) - + val (v, iv) = examples.partition(isValidTest) + TestBank(v, iv) case None => - (Nil, Nil) + TestBank(Nil, Nil) } def generateTests(p: Problem): Seq[Example] = { diff --git a/src/main/scala/leon/synthesis/InOutExample.scala b/src/main/scala/leon/synthesis/InOutExample.scala index 4218a6ec7686350bb2a303d4db30c7d1685a9e2d..67adb0787aceb8d1e98aef3ba71d50d76f77d558 100644 --- a/src/main/scala/leon/synthesis/InOutExample.scala +++ b/src/main/scala/leon/synthesis/InOutExample.scala @@ -9,26 +9,3 @@ import leon.utils.ASCIIHelpers._ class Example(val ins: Seq[Expr]) case class InOutExample(is: Seq[Expr], outs: Seq[Expr]) extends Example(is) case class InExample(is: Seq[Expr]) extends Example(is) - -class ExamplesTable(title: String, ts: Seq[Example]) { - override def toString = { - var tt = new Table(title) - - for (t <- ts) { - val os = t match { - case InOutExample(_, outs) => - outs.map(Cell(_)) - case _ => - Seq(Cell("?")) - } - - tt += Row( - t.ins.map(Cell(_)) ++ Seq(Cell("->")) ++ os - ) - } - - tt.render - } - -} - diff --git a/src/main/scala/leon/synthesis/Problem.scala b/src/main/scala/leon/synthesis/Problem.scala index 7a617eef41927a31aaa1b1870d5d03b72113c628..cfaa57e5042ded74ea951cfcff88d02f2b376451 100644 --- a/src/main/scala/leon/synthesis/Problem.scala +++ b/src/main/scala/leon/synthesis/Problem.scala @@ -13,37 +13,48 @@ import Witnesses._ // Defines a synthesis triple of the form: // ⟦ as ⟨ ws && pc | phi ⟩ xs ⟧ -case class Problem(as: List[Identifier], ws: Expr, pc: Expr, phi: Expr, xs: List[Identifier]) { +case class Problem(as: List[Identifier], ws: Expr, pc: Expr, phi: Expr, xs: List[Identifier], tb: TestBank = TestBank.empty) { def inType = tupleTypeWrap(as.map(_.getType)) def outType = tupleTypeWrap(xs.map(_.getType)) override def toString = { val pcws = and(ws, pc) - "⟦ "+as.mkString(";")+", "+(if (pcws != BooleanLiteral(true)) pcws+" ≺ " else "")+" ⟨ "+phi+" ⟩ "+xs.mkString(";")+" ⟧ " + + val tbInfo = "/"+tb.valids.size+","+tb.invalids.size+"/" + + "⟦ "+as.mkString(";")+", "+(if (pcws != BooleanLiteral(true)) pcws+" ≺ " else "")+" ⟨ "+phi+" ⟩ "+xs.mkString(";")+" ⟧ "+tbInfo } + def tbOps(implicit sctx: SearchContext) = ProblemTestBank(this, tb) } object Problem { - def fromChoose(ch: Choose, pc: Expr = BooleanLiteral(true)): Problem = { + def fromChoose(ch: Choose, pc: Expr = BooleanLiteral(true), tb: TestBank = TestBank.empty): Problem = { val xs = { val tps = ch.pred.getType.asInstanceOf[FunctionType].from tps map (FreshIdentifier("x", _, true)) }.toList val phi = application(simplifyLets(ch.pred), xs map { _.toVariable}) - val as = (variablesOf(And(pc, phi)) -- xs).toList + val as = (variablesOf(And(pc, phi)) -- xs).toList.sortBy(_.name) - // FIXME do we need this at all? val TopLevelAnds(clauses) = pc - // @mk FIXME: Is this needed? val (pcs, wss) = clauses.partition { case w : Witness => false case _ => true } - Problem(as, andJoin(wss), andJoin(pcs), phi, xs) + Problem(as, andJoin(wss), andJoin(pcs), phi, xs, tb) + } + + def fromChooseInfo(ci: ChooseInfo): Problem = { + // Same as fromChoose, but we order the input variables by the arguments of + // the functions, so that tests are compatible + val p = fromChoose(ci.ch, ci.pc, ci.tests) + val argsIndex = ci.fd.params.map(_.id).zipWithIndex.toMap.withDefaultValue(100) + + p.copy( as = p.as.sortBy(a => argsIndex(a))) } } diff --git a/src/main/scala/leon/synthesis/Rules.scala b/src/main/scala/leon/synthesis/Rules.scala index 5a8063d5b43aedf7e36f3e614bdcdf852a726eea..20596de878bebb9e58cffe538bd8c20322e37203 100644 --- a/src/main/scala/leon/synthesis/Rules.scala +++ b/src/main/scala/leon/synthesis/Rules.scala @@ -7,6 +7,7 @@ import purescala.Common._ import purescala.Expressions._ import purescala.Types._ import purescala.ExprOps._ +import purescala.Constructors.and import rules._ abstract class Rule(val name: String) extends RuleDSL { @@ -25,6 +26,10 @@ abstract class NormalizingRule(name: String) extends Rule(name) { override val priority = RulePriorityNormalizing } +abstract class PreprocessingRule(name: String) extends Rule(name) { + override val priority = RulePriorityPreprocessing +} + object Rules { def all = List[Rule]( Unification.DecompTrivialClash, @@ -127,9 +132,10 @@ sealed abstract class RulePriority(val v: Int) extends Ordered[RulePriority] { def compare(that: RulePriority) = this.v - that.v } -case object RulePriorityNormalizing extends RulePriority(0) -case object RulePriorityHoles extends RulePriority(1) -case object RulePriorityDefault extends RulePriority(2) +case object RulePriorityPreprocessing extends RulePriority(5) +case object RulePriorityNormalizing extends RulePriority(10) +case object RulePriorityHoles extends RulePriority(15) +case object RulePriorityDefault extends RulePriority(20) /** * Common utilities used by rules @@ -168,4 +174,20 @@ trait RuleDSL { } } + + // pc corresponds to the pc to reach the point where the solution is used. It + // will be used if the sub-solution has a non-true pre. + def termWrap(f: Expr => Expr, pc: Expr = BooleanLiteral(true)): List[Solution] => Option[Solution] = { + (sols: List[Solution]) => sols match { + case List(s) => + val pre = if (s.pre == BooleanLiteral(true)) { + BooleanLiteral(true) + } else { + and(pc, s.pre) + } + + Some(Solution(pre, s.defs, f(s.term), s.isTrusted)) + case _ => None + } + } } diff --git a/src/main/scala/leon/synthesis/TestBank.scala b/src/main/scala/leon/synthesis/TestBank.scala new file mode 100644 index 0000000000000000000000000000000000000000..94dc664b127eb967260a2e9608f102532cad37b1 --- /dev/null +++ b/src/main/scala/leon/synthesis/TestBank.scala @@ -0,0 +1,174 @@ +package leon +package synthesis + +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.Constructors._ +import evaluators._ +import purescala.Common._ +import repair._ +import leon.utils.ASCIIHelpers._ + +case class TestBank(valids: Seq[Example], invalids: Seq[Example]) { + def examples: Seq[Example] = valids ++ invalids + + // Minimize tests of a function so that tests that are invalid because of a + // recursive call are eliminated + def minimizeInvalids(fd: FunDef, ctx: LeonContext, program: Program): TestBank = { + val evaluator = new RepairTrackingEvaluator(ctx, program) + + invalids foreach { ts => + evaluator.eval(functionInvocation(fd, ts.ins)) + } + + val outInfo = (invalids.collect { + case InOutExample(ins, outs) => ins -> outs + }).toMap + + val callGraph = evaluator.fullCallGraph + + def isFailing(fi: (FunDef, Seq[Expr])) = !evaluator.fiStatus(fi) && (fi._1 == fd) + + val failing = callGraph filter { case (from, to) => + isFailing(from) && (to forall (!isFailing(_)) ) + } + + val newInvalids = failing.keySet map { + case (_, args) => + outInfo.get(args) match { + case Some(outs) => + InOutExample(args, outs) + + case None => + InExample(args) + } + } + + TestBank(valids, newInvalids.toSeq) + } + + def union(that: TestBank) = { + TestBank( + (this.valids ++ that.valids).distinct, + (this.invalids ++ that.invalids).distinct + ) + } + + def map(f: Example => List[Example]) = { + TestBank(valids.flatMap(f), invalids.flatMap(f)) + } + + def mapIns(f: Seq[Expr] => List[Seq[Expr]]) = { + map { + case InExample(in) => + f(in).map(InExample(_)) + + case InOutExample(in, out) => + f(in).map(InOutExample(_, out)) + } + } + + def mapOuts(f: Seq[Expr] => List[Seq[Expr]]) = { + map { + case InOutExample(in, out) => + f(out).map(InOutExample(in, _)) + + case e => + List(e) + } + } + + def stripOuts = { + map { + case InOutExample(in, out) => + List(InExample(in)) + case e => + List(e) + } + } + + def asString(title: String): String = { + var tt = new Table(title) + + if (examples.nonEmpty) { + + val ow = examples.map { + case InOutExample(_, out) => out.size + case _ => 0 + }.max + + val iw = examples.map(_.ins.size).max + + def testsRows(section: String, ts: Seq[Example]) { + if (tt.rows.nonEmpty) { + tt += Row(Seq( + Cell(" ", iw + ow + 1) + )) + } + + tt += Row(Seq( + Cell(Console.BOLD+section+Console.RESET+":", iw + ow + 1) + )) + tt += Separator + + for (t <- ts) { + val os = t match { + case InOutExample(_, outs) => + outs.map(Cell(_)) + case _ => + Seq(Cell("?", ow)) + } + + tt += Row( + t.ins.map(Cell(_)) ++ Seq(Cell("->")) ++ os + ) + } + } + if (valids.nonEmpty) { + testsRows("Valid tests", valids) + } + if (invalids.nonEmpty) { + testsRows("Invalid tests", invalids) + } + + tt.render + } else { + "No tests." + } + } + +} + +object TestBank { + def empty = TestBank(Nil, Nil) +} + +case class ProblemTestBank(p: Problem, tb: TestBank)(implicit hctx: SearchContext) { + + def removeOuts(toRemove: Set[Identifier]) = { + val toKeep = p.xs.zipWithIndex.filterNot(x => toRemove(x._1)).map(_._2) + + tb mapOuts { out => List(toKeep.map(out)) } + } + + def removeIns(toRemove: Set[Identifier]) = { + val toKeep = p.as.zipWithIndex.filterNot(a => toRemove(a._1)).map(_._2) + tb mapIns { in => List(toKeep.map(in)) } + } + + def filterIns(expr: Expr) = { + val ev = new DefaultEvaluator(hctx.sctx.context, hctx.sctx.program) + + tb mapIns { in => + val m = (p.as zip in).toMap + + ev.eval(expr, m) match { + case EvaluationResults.Successful(BooleanLiteral(true)) => + List(in) + case _ => + Nil + } + } + } + +} diff --git a/src/main/scala/leon/synthesis/rules/Assert.scala b/src/main/scala/leon/synthesis/rules/Assert.scala index baeaa7052c242d62b5dcd28545fe2728566bf54a..bd070e6d06378ad153e10fa5026bf38ad3c6317d 100644 --- a/src/main/scala/leon/synthesis/rules/Assert.scala +++ b/src/main/scala/leon/synthesis/rules/Assert.scala @@ -19,7 +19,7 @@ case object Assert extends NormalizingRule("Assert") { if (others.isEmpty) { Some(solve(Solution(andJoin(exprsA), Set(), tupleWrap(p.xs.map(id => simplestValue(id.getType)))))) } else { - val sub = p.copy(pc = andJoin(p.pc +: exprsA), phi = andJoin(others)) + val sub = p.copy(pc = andJoin(p.pc +: exprsA), phi = andJoin(others), tb = p.tbOps.filterIns(andJoin(exprsA))) Some(decomp(List(sub), { case (s @ Solution(pre, defs, term)) :: Nil => Some(Solution(andJoin(exprsA :+ pre), defs, term, s.isTrusted)) diff --git a/src/main/scala/leon/synthesis/rules/CEGISLike.scala b/src/main/scala/leon/synthesis/rules/CEGISLike.scala index c97d14d1f0a983e46b8359b9299db076b254881d..c0f44d23d2c02b3d2a26f95af835ffadf4063f1b 100644 --- a/src/main/scala/leon/synthesis/rules/CEGISLike.scala +++ b/src/main/scala/leon/synthesis/rules/CEGISLike.scala @@ -791,8 +791,7 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { // We populate the list of examples with a predefined one sctx.reporter.debug("Acquiring initial list of examples") - val ef = new ExamplesFinder(sctx.context, sctx.program) - baseExampleInputs ++= ef.extractTests(p).map(_.ins).toSet + baseExampleInputs ++= p.tb.examples.map(_.ins).toSet val pc = p.pc diff --git a/src/main/scala/leon/synthesis/rules/OnePoint.scala b/src/main/scala/leon/synthesis/rules/OnePoint.scala index eb4a1a4b22c72b5e9ec6f6ef6af2fe9cb2cac23c..230102b74a9117d2c3d7a6156c6fca4ec24cab0e 100644 --- a/src/main/scala/leon/synthesis/rules/OnePoint.scala +++ b/src/main/scala/leon/synthesis/rules/OnePoint.scala @@ -31,7 +31,7 @@ case object OnePoint extends NormalizingRule("One-point") { val others = exprs.filter(_ != eq) val oxs = p.xs.filter(_ != x) - val newProblem = Problem(p.as, p.ws, p.pc, subst(x -> e, andJoin(others)), oxs) + val newProblem = Problem(p.as, p.ws, p.pc, subst(x -> e, andJoin(others)), oxs, p.tbOps.removeOuts(Set(x))) val onSuccess: List[Solution] => Option[Solution] = { case List(s @ Solution(pre, defs, term)) => diff --git a/src/main/scala/leon/synthesis/rules/UnconstrainedOutput.scala b/src/main/scala/leon/synthesis/rules/UnconstrainedOutput.scala index 8e8f6dc1191095d5ea22730febb84253e884207d..f261f3637a059f3f4efbb0fb7b134a61ff9ffe93 100644 --- a/src/main/scala/leon/synthesis/rules/UnconstrainedOutput.scala +++ b/src/main/scala/leon/synthesis/rules/UnconstrainedOutput.scala @@ -13,7 +13,7 @@ case object UnconstrainedOutput extends NormalizingRule("Unconstr.Output") { val unconstr = p.xs.toSet -- variablesOf(p.phi) if (unconstr.nonEmpty) { - val sub = p.copy(xs = p.xs.filterNot(unconstr)) + val sub = p.copy(xs = p.xs.filterNot(unconstr), tb = p.tbOps.removeOuts(unconstr)) val onSuccess: List[Solution] => Option[Solution] = { case List(s) => diff --git a/src/main/scala/leon/synthesis/rules/UnusedInput.scala b/src/main/scala/leon/synthesis/rules/UnusedInput.scala index 70082316f4718e55ea80be08ede0460b2c49a887..534dbea787e7aec1fed4c38897859cc1c23ac63f 100644 --- a/src/main/scala/leon/synthesis/rules/UnusedInput.scala +++ b/src/main/scala/leon/synthesis/rules/UnusedInput.scala @@ -11,7 +11,7 @@ case object UnusedInput extends NormalizingRule("UnusedInput") { val unused = p.as.toSet -- variablesOf(p.phi) -- variablesOf(p.pc) -- variablesOf(p.ws) if (unused.nonEmpty) { - val sub = p.copy(as = p.as.filterNot(unused)) + val sub = p.copy(as = p.as.filterNot(unused), tb = p.tbOps.removeIns(unused)) List(decomp(List(sub), forward, s"Unused inputs ${p.as.filter(unused).mkString(", ")}")) } else { diff --git a/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala b/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala index ebb4bb8faf22ddd1bf25a5b0ab26cbaea41d2add..8ad0dc7f7e63e71b014358457116cd9cabbd890f 100644 --- a/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala +++ b/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala @@ -149,13 +149,13 @@ object ExpressionGrammars { List( Generator(Nil, { _ => IntLiteral(0) }), Generator(Nil, { _ => IntLiteral(1) }), - Generator(Nil, { _ => IntLiteral(42) }) + Generator(Nil, { _ => IntLiteral(5) }) ) case IntegerType => List( Generator(Nil, { _ => InfiniteIntegerLiteral(0) }), Generator(Nil, { _ => InfiniteIntegerLiteral(1) }), - Generator(Nil, { _ => InfiniteIntegerLiteral(42) }) + Generator(Nil, { _ => InfiniteIntegerLiteral(5) }) ) case tp@TypeParameter(_) => diff --git a/src/main/scala/leon/utils/ASCIIHelpers.scala b/src/main/scala/leon/utils/ASCIIHelpers.scala index 524d1fc04206a093af84bcf7434c353814c0bd76..4115078586416f30ddf653781e6191de9c485152 100644 --- a/src/main/scala/leon/utils/ASCIIHelpers.scala +++ b/src/main/scala/leon/utils/ASCIIHelpers.scala @@ -88,10 +88,14 @@ object ASCIIHelpers { } val size = (i to i+c.spanning-1).map(colSizes).sum + (c.spanning-1) * 2 - if (c.align == Left) { - sb append ("%-"+size+"s").format(c.vString) + if (size >= 0) { + if (c.align == Left) { + sb append ("%-"+size+"s").format(c.vString) + } else { + sb append ("%"+size+"s").format(c.vString) + } } else { - sb append ("%"+size+"s").format(c.vString) + sb append c.vString } i += c.spanning