diff --git a/library/annotation/package.scala b/library/annotation/package.scala index f6e084cdac1cdc9dfc6bf20999aeb0c2aaf529dc..2db992498655300eba0ef908800976df1b46eec0 100644 --- a/library/annotation/package.scala +++ b/library/annotation/package.scala @@ -11,6 +11,8 @@ package object annotation { class verified extends StaticAnnotation @ignore class repair extends StaticAnnotation + @ignore + class witness extends StaticAnnotation @ignore class induct extends StaticAnnotation diff --git a/library/lang/synthesis/package.scala b/library/lang/synthesis/package.scala index 9fd244605123a6fcf1dfd4d4bd4b51e8a9fdb383..62e08280c27587e195230f33d852972384b06cb8 100644 --- a/library/lang/synthesis/package.scala +++ b/library/lang/synthesis/package.scala @@ -39,9 +39,11 @@ package object synthesis { def withOracle[A, R](body: Oracle[A] => R): R = noImpl @library + @witness def terminating[T](t: T): Boolean = true @library + @witness def guide[T](e: T): Boolean = true } diff --git a/src/main/scala/leon/repair/Repairman.scala b/src/main/scala/leon/repair/Repairman.scala index 8bab198da5de4fe70ae318687826881452eca7d1..f216eb33d6fd413c49af61fbe4d1a279a3c5f7bb 100644 --- a/src/main/scala/leon/repair/Repairman.scala +++ b/src/main/scala/leon/repair/Repairman.scala @@ -31,29 +31,6 @@ class Repairman(ctx: LeonContext, initProgram: Program, fd: FunDef, verifTimeout implicit val debugSection = DebugSectionRepair def getSynthesizer(tests: List[Example]): Synthesizer = { - // Gather in/out tests - val pre = fd.precondition.getOrElse(BooleanLiteral(true)) - val args = fd.params.map(_.id) - val argsWrapped = tupleWrap(args.map(_.toVariable)) - - // Compute tests - val out = fd.postcondition.map(_._1).getOrElse(FreshIdentifier("res", true).setType(fd.returnType)) - - val testsCases = tests.collect { - case InOutExample(ins, outs) => - val (patt, optGuard) = expressionToPattern(tupleWrap(ins)) - MatchCase(patt, optGuard match { - case BooleanLiteral(true) => None - case guard => Some(guard) - }, tupleWrap(outs)) - }.toList - - val passes = if (testsCases.nonEmpty) { - Passes(argsWrapped, out.toVariable, testsCases) - } else { - BooleanLiteral(true) - } - // Create a fresh function val nid = FreshIdentifier(fd.id.name+"_repair").copiedFrom(fd.id) val nfd = new FunDef(nid, fd.tparams, fd.returnType, fd.params, fd.defType) @@ -66,15 +43,16 @@ class Repairman(ctx: LeonContext, initProgram: Program, fd: FunDef, verifTimeout case _ => None })(b)} - val spec = and( - nfd.postcondition.map(_._2).getOrElse(BooleanLiteral(true)), - passes - ) - program = program.addDefinition(nfd, fd) - val p = focusRepair(tests, nfd, spec, out) + val body = nfd.body.get; + + val (newBody, replacedExpr) = focusRepair(program, nfd, tests) + nfd.body = Some(newBody) + val guide = guideOf(replacedExpr) + + // Return synthesizer for this choose val soptions0 = SynthesisPhase.processOptions(ctx); val soptions = soptions0.copy( @@ -86,6 +64,13 @@ class Repairman(ctx: LeonContext, initProgram: Program, fd: FunDef, verifTimeout )) diff Seq(ADTInduction) ); + // extract chooses from nfd + val Seq(ci) = ChooseInfo.extractFromFunction(ctx, program, nfd, soptions) + + val nci = ci.copy(pc = and(ci.pc, guideOf(replacedExpr))) + + val p = nci.problem + new Synthesizer(ctx, nfd, program, p, soptions) } @@ -94,33 +79,131 @@ class Repairman(ctx: LeonContext, initProgram: Program, fd: FunDef, verifTimeout FunctionInvocation(gfd, Seq(expr)) } - private def focusRepair(tests: List[Example], fd: FunDef, post: Expr, out: Identifier): Problem = { + private def focusRepair(program: Program, fd: FunDef, tests: List[Example]): (Expr, Expr) = { + // Compute tests + val failingTests = tests.collect { + case InExample(ins) => ins + } + reporter.ifDebug { printer => printer("Tests failing are: ") - tests.collect { - case InExample(ins) => - printer(ins.mkString(", ")) + failingTests.foreach { ins => + printer(ins.mkString(", ")) } } - // Compute initial call for terminating argument - val termfd = program.library.terminating.get - val withinCall = FunctionInvocation(fd.typedWithDef, fd.params.map(_.id.toVariable)) - val terminating = FunctionInvocation(termfd.typed(Seq(fd.returnType)), Seq(withinCall)) + val testsCases = tests.collect { + case InOutExample(ins, outs) => + val (patt, optGuard) = expressionToPattern(tupleWrap(ins)) + MatchCase(patt, optGuard match { + case BooleanLiteral(true) => None + case guard => Some(guard) + }, tupleWrap(outs)) + }.toList val pre = fd.precondition.getOrElse(BooleanLiteral(true)) + val args = fd.params.map(_.id) + val argsWrapped = tupleWrap(args.map(_.toVariable)) - val body = fd.body.get + val out = fd.postcondition.map(_._1).getOrElse(FreshIdentifier("res", true).setType(fd.returnType)) - val ws = and( - guideOf(body), - terminating + val passes = if (testsCases.nonEmpty) { + Passes(argsWrapped, out.toVariable, testsCases) + } else { + BooleanLiteral(true) + } + + val spec = and( + fd.postcondition.map(_._2).getOrElse(BooleanLiteral(true)), + passes ) - // Synthesis from the ground up - val p = Problem(fd.params.map(_.id).toList, ws, pre, post, List(out)) + val body = fd.body.get + + val choose = Choose(List(out), spec) + + val evaluator = new DefaultEvaluator(ctx, program) + + // Check how an expression behaves on tests + // - returns Some(true) if for all tests e evaluates to true + // - returns Some(false) if for all tests e evaluates to false + // - returns None otherwise + def forAllTests(e: Expr, env: Map[Identifier, Expr]): Option[Boolean] = { + val results = failingTests.map { ins => + evaluator.eval(e, env ++ (args zip ins)) match { + case EvaluationResults.Successful(BooleanLiteral(true)) => Some(true) + case EvaluationResults.Successful(BooleanLiteral(false)) => Some(false) + case _ => None + } + }.distinct + + if (results.size == 1) { + results.head + } else { + None + } + } + + def focus(expr: Expr, env: Map[Identifier, Expr]): (Expr, Expr) = expr match { + case me @ MatchExpr(scrut, cases) => + var res: Option[(Expr, Expr)] = None + + // in case scrut is an non-variable expr, we simplify to a variable + inject in env + val (sid, nenv) = scrut match { + case Variable(id) => + (id, env) + case expr => + val id = FreshIdentifier("scrut", true).copiedFrom(scrut) + (id, env + (id -> scrut)) + } + + for (c <- cases if res.isEmpty) { + val cond = and(conditionForPattern(sid.toVariable, c.pattern, includeBinders = false), + c.optGuard.getOrElse(BooleanLiteral(true))) + val map = mapForPattern(sid.toVariable, c.pattern) + + + forAllTests(cond, nenv ++ map) match { + case Some(true) => + val (b, r) = focus(c.rhs, nenv ++ map) + res = Some((MatchExpr(scrut, cases.map { c2 => + if (c2 eq c) { + c2.copy(rhs = b) + } else { + c2 + } + }), r)) + + case Some(false) => + // continue until next case + case None => + res = Some((choose, expr)) + } + } + + res.getOrElse((choose, expr)) + + case Let(id, value, body) => + val (b, r) = focus(body, env + (id -> value)) + (Let(id, value, b), r) + + case IfExpr(c, thn, els) => + forAllTests(c, env) match { + case Some(true) => + val (b, r) = focus(thn, env) + (IfExpr(c, b, els), r) + case Some(false) => + val (b, r) = focus(els, env) + (IfExpr(c, thn, b), r) + case None => + (choose, expr) + } + + case _ => + (choose, expr) + } - p + focus(body, Map()) } def getVerificationCounterExamples(fd: FunDef, prog: Program): Option[Seq[InExample]] = { @@ -188,7 +271,7 @@ class Repairman(ctx: LeonContext, initProgram: Program, fd: FunDef, verifTimeout reporter.info(ASCIIHelpers.title("1. Discovering tests for "+fd.id)) val tests = discoverTests - reporter.info(ASCIIHelpers.title("2. Creating synthesis Problem")) + reporter.info(ASCIIHelpers.title("2. Locating/Focusing synthesis problem")) val synth = getSynthesizer(tests) val p = synth.problem diff --git a/src/main/scala/leon/synthesis/ChooseInfo.scala b/src/main/scala/leon/synthesis/ChooseInfo.scala index 69a82a4cfe0dd751aa029eeee2a745e8f69c7f1e..2f31d5129d02c0174ec21887b980999cf7622141 100644 --- a/src/main/scala/leon/synthesis/ChooseInfo.scala +++ b/src/main/scala/leon/synthesis/ChooseInfo.scala @@ -24,21 +24,25 @@ case class ChooseInfo(ctx: LeonContext, object ChooseInfo { def extractFromProgram(ctx: LeonContext, prog: Program, options: SynthesisOptions): List[ChooseInfo] = { - val fterm = prog.library.terminating.getOrElse(ctx.reporter.fatalError("No library ?!?")) - - var results = List[ChooseInfo]() // Look for choose() - for (f <- prog.definedFunctions if f.body.isDefined) { - val actualBody = and(f.precondition.getOrElse(BooleanLiteral(true)), f.body.get) - val withinCall = FunctionInvocation(f.typedWithDef, f.params.map(_.id.toVariable)) - val term = FunctionInvocation(fterm.typed(Seq(f.returnType)), Seq(withinCall)) - - for ((ch, path) <- new ChooseCollectorWithPaths().traverse(actualBody)) { - results = ChooseInfo(ctx, prog, f, and(path, term), ch, ch, options) :: results - } + val results = for (f <- prog.definedFunctions if f.body.isDefined; + ci <- extractFromFunction(ctx, prog, f, options)) yield { + ci } results.sortBy(_.source.getPos) } + + def extractFromFunction(ctx: LeonContext, prog: Program, fd: FunDef, options: SynthesisOptions): Seq[ChooseInfo] = { + val fterm = prog.library.terminating.getOrElse(ctx.reporter.fatalError("No library ?!?")) + + val actualBody = and(fd.precondition.getOrElse(BooleanLiteral(true)), fd.body.get) + val withinCall = FunctionInvocation(fd.typedWithDef, fd.params.map(_.id.toVariable)) + val term = FunctionInvocation(fterm.typed(Seq(fd.returnType)), Seq(withinCall)) + + for ((ch, path) <- new ChooseCollectorWithPaths().traverse(actualBody)) yield { + ChooseInfo(ctx, prog, fd, and(path, term), ch, ch, options) + } + } } diff --git a/src/main/scala/leon/synthesis/Problem.scala b/src/main/scala/leon/synthesis/Problem.scala index ebd01a52ae4e0268ed0079858267a789cad6e0b1..3650a2ab4658b7e90d4a1c4f5070f44bde20f1ee 100644 --- a/src/main/scala/leon/synthesis/Problem.scala +++ b/src/main/scala/leon/synthesis/Problem.scala @@ -4,10 +4,12 @@ package leon package synthesis import leon.purescala.Trees._ +import leon.purescala.Definitions._ import leon.purescala.TreeOps._ import leon.purescala.TypeTrees.TypeTree import leon.purescala.Common._ import leon.purescala.Constructors._ +import leon.purescala.Extractors._ // Defines a synthesis triple of the form: // ⟦ as ⟨ ws && pc | phi ⟩ xs ⟧ @@ -217,6 +219,13 @@ object Problem { val phi = simplifyLets(ch.pred) val as = (variablesOf(And(pc, phi))--xs).toList - Problem(as, BooleanLiteral(true), pc, phi, xs) + val TopLevelAnds(clauses) = pc + + val (pcs, wss) = clauses.partition { + case FunctionInvocation(TypedFunDef(fd, _), _) if fd.annotations("witness") => false + case _ => true + } + + Problem(as, andJoin(wss), andJoin(pcs), phi, xs) } }