diff --git a/src/main/scala/leon/evaluators/CollectingEvaluator.scala b/src/main/scala/leon/evaluators/CollectingEvaluator.scala new file mode 100644 index 0000000000000000000000000000000000000000..0804d1ab2cd71e5c0b389c5e282847fb6f8275aa --- /dev/null +++ b/src/main/scala/leon/evaluators/CollectingEvaluator.scala @@ -0,0 +1,39 @@ +package leon.evaluators + +import scala.collection.immutable.Map +import leon.purescala.Common._ +import leon.purescala.Trees._ +import leon.purescala.Definitions._ +import leon.LeonContext + +abstract class CollectingEvaluator(ctx: LeonContext, prog: Program) extends RecursiveEvaluator(ctx, prog, 50000) { + type RC = DefaultRecContext + type GC = CollectingGlobalContext + type ES = Seq[Expr] + + def initRC(mappings: Map[Identifier, Expr]) = DefaultRecContext(mappings) + def initGC = new CollectingGlobalContext() + + class CollectingGlobalContext extends GlobalContext { + var collected : Set[Seq[Expr]] = Set() + def collect(es : ES) = collected += es + } + case class DefaultRecContext(mappings: Map[Identifier, Expr]) extends RecContext { + def withVars(news: Map[Identifier, Expr]) = copy(news) + } + + // A function that returns a Seq[Expr] + // This expressions will be evaluated in the current context and then collected in the global environment + def collecting(e : Expr) : Option[ES] + + override def e(expr: Expr)(implicit rctx: RC, gctx: GC): Expr = { + for { + es <- collecting(expr) + evaled = es map e + } gctx.collect(evaled) + super.e(expr) + } + + def collected : Set[ES] = lastGC map { _.collected } getOrElse Set() + +} diff --git a/src/main/scala/leon/repair/Repairman.scala b/src/main/scala/leon/repair/Repairman.scala index a6ebf87a3206de2c003f5aa859ebf021ad38cb2c..309090f384d5c3c8e6afb7e6f87ef7c3b7f945a4 100644 --- a/src/main/scala/leon/repair/Repairman.scala +++ b/src/main/scala/leon/repair/Repairman.scala @@ -21,6 +21,7 @@ import synthesis._ import synthesis.rules._ import synthesis.heuristics._ import graph.DotGenerator +import leon.utils.ASCIIHelpers.title class Repairman(ctx: LeonContext, initProgram: Program, fd: FunDef, verifTimeoutMs: Option[Long]) { val reporter = ctx.reporter @@ -138,13 +139,20 @@ class Repairman(ctx: LeonContext, initProgram: Program, fd: FunDef, verifTimeout } private def focusRepair(program: Program, fd: FunDef, passingTests: List[Example], failingTests: List[Example]): (Expr, Expr) = { + reporter.ifDebug { printer => - printer("Tests failing are: ") + printer(title("Tests failing are: ")) failingTests.foreach { ex => printer(ex.ins.mkString(", ")) } } - + reporter.ifDebug { printer => + printer(title("Tests passing are: ")) + passingTests.foreach { ex => + printer(ex.ins.mkString(", ")) + } + } + val pre = fd.precondition.getOrElse(BooleanLiteral(true)) val args = fd.params.map(_.id) @@ -160,19 +168,71 @@ class Repairman(ctx: LeonContext, initProgram: Program, fd: FunDef, verifTimeout val evaluator = new DefaultEvaluator(ctx, program) + val minimalFailingTests = { + // We don't want tests whose invocation will call other failing tests. + // This is because they will appear erroneous, + // even though the error comes from the called test + val testEval : CollectingEvaluator = new CollectingEvaluator(ctx, program){ + def collecting(e : Expr) : Option[Seq[Expr]] = e match { + case fi@FunctionInvocation(TypedFunDef(`fd`, _), args) => + Some(args) + case _ => None + } + } + + val passingTs = for (test <- passingTests) yield InExample(test.ins) + val failingTs = for (test <- failingTests) yield InExample(test.ins) + + val test2Tests : Map[InExample, Set[InExample]] = (failingTs ++ passingTs).map{ ts => + testEval.eval(body, args.zip(ts.ins).toMap) + (ts, testEval.collected map (InExample(_))) + }.toMap + + val recursiveTests : Set[InExample] = test2Tests.values.toSet.flatten -- (failingTs ++ passingTs) + + val testsTransitive : Map[InExample,Set[InExample]] = + leon.utils.GraphOps.transitiveClosure[InExample]( + test2Tests ++ recursiveTests.map ((_,Set[InExample]())) + ) + + val knownWithResults : Map[InExample, Boolean] = (failingTs.map((_, false)).toMap) ++ (passingTs.map((_,true))) + + val recWithResults : Map[InExample, Boolean] = recursiveTests.map { ex => + (ex, evaluator.eval(spec, (args zip ex.ins).toMap + (out -> body)) match { + case EvaluationResults.Successful(BooleanLiteral(true)) => true + case _ => false + }) + }.toMap + + val allWithResults = knownWithResults ++ recWithResults + + + testsTransitive.collect { + case (rest, called) if !allWithResults(rest) && (called forall allWithResults) => + rest + }.toSet + } + + reporter.ifDebug { printer => + printer(title("MinimalTests are: ")) + minimalFailingTests.foreach { ex => + printer(ex.ins.mkString(", ")) + } + } + // Check how an expression behaves on tests // - returns Some(true) if for all tests e evaluates to true // - returns Some(false) if for all tests e evaluates to false // - returns None otherwise def forAllTests(e: Expr, env: Map[Identifier, Expr]): Option[Boolean] = { - val results = failingTests.map { ex => + val results = minimalFailingTests.map { ex => val ins = ex.ins evaluator.eval(e, env ++ (args zip ins)) match { case EvaluationResults.Successful(BooleanLiteral(true)) => Some(true) case EvaluationResults.Successful(BooleanLiteral(false)) => Some(false) case _ => None } - }.distinct + } if (results.size == 1) { results.head @@ -183,35 +243,43 @@ class Repairman(ctx: LeonContext, initProgram: Program, fd: FunDef, verifTimeout def focus(expr: Expr, env: Map[Identifier, Expr]): (Expr, Expr) = expr match { case me @ MatchExpr(scrut, cases) => - var res: Option[(Expr, Expr)] = None - + // in case scrut is an non-variable expr, we simplify to a variable + inject in env - for (c <- cases if res.isEmpty) { + val perCase = for (c <- cases) yield { val cond = and(conditionForPattern(scrut, c.pattern, includeBinders = false), c.optGuard.getOrElse(BooleanLiteral(true))) val map = mapForPattern(scrut, c.pattern) forAllTests(cond, env ++ map) match { - case Some(true) => - val (b, r) = focus(c.rhs, env ++ map) - res = Some((MatchExpr(scrut, cases.map { c2 => - if (c2 eq c) { - c2.copy(rhs = b) - } else { - c2 - } - }), r)) - case Some(false) => - // continue until next case - case None => - res = Some((choose, expr)) + // We know this case is correct + None + + case Some(true) | None => + // Either incorrect case or unknown, treat it as incorrect + val (b, r) = focus(c.rhs, env ++ map) + Some((c.copy(rhs = b), r)) } } - - res.getOrElse((choose, expr)) - + + perCase count { _.isDefined } match { + // No wrong cases + case 0 => ctx.reporter.internalError("No erroneous case found!") + // 1 wrong case + case 1 => + val e = perCase.collect{ case Some((b,r)) => r }.head + val newCases = cases zip perCase map { + case (cs, None) => + cs + case (_ , Some((b,r))) => + b + } + (MatchExpr(scrut, newCases), e) + // More wrong cases, return a choose on the top-level + case _ => (choose, me) + } + case Let(id, value, body) => val (b, r) = focus(body, env + (id -> value)) (Let(id, value, b), r) diff --git a/src/main/scala/leon/synthesis/rules/CegisLike.scala b/src/main/scala/leon/synthesis/rules/CegisLike.scala index a76666c3dd9a950517c6447118faa72471876f4a..2df78babbe7081a3c397431d972755e8f180633f 100644 --- a/src/main/scala/leon/synthesis/rules/CegisLike.scala +++ b/src/main/scala/leon/synthesis/rules/CegisLike.scala @@ -49,7 +49,7 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { val useCETests = sctx.settings.cegisUseCETests val useCEPruning = sctx.settings.cegisUseCEPruning - // Limits the number of programs CEGIS will specifically test for instead of reasonning symbolically + // Limits the number of programs CEGIS will specifically test for instead of reasoning symbolically val testUpTo = 5 val useBssFiltering = sctx.settings.cegisUseBssFiltering val filterThreshold = 1.0/2 diff --git a/src/main/scala/leon/utils/GraphOps.scala b/src/main/scala/leon/utils/GraphOps.scala index 5dba04ece0f906450cf0108f92b7e9c032526d84..505829e57ab8d07357ca59aa76e8450cd86b5f59 100644 --- a/src/main/scala/leon/utils/GraphOps.scala +++ b/src/main/scala/leon/utils/GraphOps.scala @@ -22,6 +22,23 @@ object GraphOps { } tSort(toPreds, Seq()) } + + def transitiveClosure[A](graph: Map[A,Set[A]]) : Map[A,Set[A]] = { + def step(graph : Map[A, Set[A]]) : Map[A,Set[A]] = graph map { + case (k, vs) => (k, vs ++ (vs flatMap { v => + graph.get(v).getOrElse(Set()) + })) + } + leon.purescala.TreeOps.fixpoint(step, -1)(graph) + } + + def sources[A](graph : Map[A,Set[A]]) = { + val notSources = graph.values.toSet.flatten + graph.keySet -- notSources + } + + def sinks[A](graph : Map[A,Set[A]]) = + graph.collect{ case (v, out) if out.isEmpty => v }.toSet /** * Returns the set of reachable nodes from a given node,