diff --git a/library/lang/synthesis/package.scala b/library/lang/synthesis/package.scala index 1b37fc4afc10613c34f29345d30396d1ea1629ce..df19bcf8e280aed3e81c117c91623235aa7d769b 100644 --- a/library/lang/synthesis/package.scala +++ b/library/lang/synthesis/package.scala @@ -35,6 +35,9 @@ package object synthesis { def withOracle[A, R](body: Oracle[A] => R): R = noImpl @library - def terminating[T](t: T) = true + def terminating[T](t: T): Boolean = true + + @library + def guide[T](e: T): Boolean = true } diff --git a/src/main/scala/leon/refactor/Repairman.scala b/src/main/scala/leon/refactor/Repairman.scala index 287c01db46daedc4e7859c6a975b8a4891aeacc7..78666580284d319bc8915d3d37cfa0796433b9e2 100644 --- a/src/main/scala/leon/refactor/Repairman.scala +++ b/src/main/scala/leon/refactor/Repairman.scala @@ -8,6 +8,7 @@ import purescala.Definitions._ import purescala.Trees._ import purescala.TreeOps._ import purescala.TypeTrees._ +import purescala.DefOps._ import purescala.Constructors._ import purescala.ScalaPrinter import evaluators._ @@ -32,6 +33,7 @@ class Repairman(ctx: LeonContext, program: Program, fd: FunDef) { val args = fd.params.map(_.id) val argsWrapped = tupleWrap(args.map(_.toVariable)) + // Compute tests val out = fd.postcondition.map(_._1).getOrElse(FreshIdentifier("res", true).setType(fd.returnType)) val tfd = program.library.passes.get.typed(Seq(argsWrapped.getType, out.getType)) @@ -45,10 +47,23 @@ class Repairman(ctx: LeonContext, program: Program, fd: FunDef) { val passes = FunctionInvocation(tfd, Seq(argsWrapped, out.toVariable, testsExpr)) - val spec = And(fd.postcondition.map(_._2).getOrElse(BooleanLiteral(true)), passes) + // Compute guide implementation + val gexpr = fd.body.get + val gfd = program.library.guide.get.typed(Seq(gexpr.getType)) + val guide = FunctionInvocation(gfd, Seq(gexpr)) + + val spec = And( + fd.postcondition.map(_._2).getOrElse(BooleanLiteral(true)), + passes + ) + + val pc = And( + pre, + guide + ) // Synthesis from the ground up - val p = Problem(fd.params.map(_.id).toList, pre, spec, List(out)) + val p = Problem(fd.params.map(_.id).toList, pc, spec, List(out)) val soptions = SynthesisPhase.processOptions(ctx); @@ -64,40 +79,12 @@ class Repairman(ctx: LeonContext, program: Program, fd: FunDef) { if (!sol.isTrusted) { - val timeoutMs = 3000l - val solverf = SolverFactory(() => (new FairZ3Solver(ctx, npr) with TimeoutSolver).setTimeout(timeoutMs)) - val vctx = VerificationContext(ctx, npr, solverf, reporter) - val nfd = fds.head - val vcs = AnalysisPhase.generateVerificationConditions(vctx, Some(List(nfd.id.name))) - - AnalysisPhase.checkVerificationConditions(vctx, vcs) - - var unknown = false; - var ces = List[Seq[Expr]]() - - for (vc <- vcs.getOrElse(nfd, List())) { - if (vc.value == Some(false)) { - vc.counterExample match { - case Some(m) => - ces = nfd.params.map(vd => m(vd.id)) :: ces; - - case _ => - } - } else if (vc.value == None) { - unknown = true; - } - } - - - if (ces.isEmpty) { - if (!unknown) { + getVerificationCounterExamples(fds.head, npr) match { + case Some(ces) => + testBank ++= ces + reporter.info("Failed :(, but I learned: "+ces.mkString(" | ")) + case None => reporter.info("ZZUCCESS!") - } else { - reporter.info("ZZUCCESS (maybe)!") - } - } else { - reporter.info("Failed :(, but I learned: "+ces.map(_.mkString(",")).mkString(" | ")) - testBank ++= ces.map(InExample(_)) } } else { reporter.info("ZZUCCESS!") @@ -136,6 +123,36 @@ class Repairman(ctx: LeonContext, program: Program, fd: FunDef) { } + def getVerificationCounterExamples(fd: FunDef, prog: Program): Option[Seq[InExample]] = { + val timeoutMs = 3000l + val solverf = SolverFactory(() => (new FairZ3Solver(ctx, prog) with TimeoutSolver).setTimeout(timeoutMs)) + val vctx = VerificationContext(ctx, prog, solverf, reporter) + val vcs = AnalysisPhase.generateVerificationConditions(vctx, Some(List(fd.id.name))) + + AnalysisPhase.checkVerificationConditions(vctx, vcs) + + var invalid = false; + var ces = List[Seq[Expr]]() + + for (vc <- vcs.getOrElse(fd, List())) { + if (vc.value == Some(false)) { + invalid = true; + + vc.counterExample match { + case Some(m) => + ces = fd.params.map(vd => m(vd.id)) :: ces; + + case _ => + } + } + } + if (invalid) { + Some(ces.map(InExample(_))) + } else { + None + } + } + def disambiguate(p: Problem, sol1: Solution, sol2: Solution): Option[(InOutExample, InOutExample)] = { val s1 = sol1.toSimplifiedExpr(ctx, program) val s2 = sol2.toSimplifiedExpr(ctx, program) @@ -199,6 +216,13 @@ class Repairman(ctx: LeonContext, program: Program, fd: FunDef) { new InExample(i) } } + + // Try to verify, if it fails, we have at least one CE + getVerificationCounterExamples(fd, program) match { + case Some(ces) => + testBank ++= ces + case _ => + } } discoverTests() diff --git a/src/main/scala/leon/synthesis/Rules.scala b/src/main/scala/leon/synthesis/Rules.scala index 431eeba3b32c6a9a74ae8a1a56ad25df75ffde2b..fae464a790c904b5e14d50f98b12a466bebbf22d 100644 --- a/src/main/scala/leon/synthesis/Rules.scala +++ b/src/main/scala/leon/synthesis/Rules.scala @@ -26,6 +26,8 @@ object Rules { InequalitySplit, CEGIS, TEGIS, + GuidedDecomp, + GuidedCloser, rules.Assert, DetupleOutput, DetupleInput, diff --git a/src/main/scala/leon/synthesis/rules/GuidedCloser.scala b/src/main/scala/leon/synthesis/rules/GuidedCloser.scala new file mode 100644 index 0000000000000000000000000000000000000000..097e04a27d6469e0b443eaf109def4464f6a0117 --- /dev/null +++ b/src/main/scala/leon/synthesis/rules/GuidedCloser.scala @@ -0,0 +1,54 @@ +/* Copyright 2009-2014 EPFL, Lausanne */ + +package leon +package synthesis +package rules + +import purescala.Trees._ +import purescala.Definitions._ +import purescala.Common._ +import purescala.TypeTrees._ +import purescala.TreeOps._ +import purescala.Extractors._ +import purescala.Constructors._ + +import solvers._ + +case object GuidedCloser extends NormalizingRule("Guided Closer") { + def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { + val TopLevelAnds(clauses) = p.pc + + val guide = sctx.program.library.guide.get + + val guides = clauses.collect { + case FunctionInvocation(TypedFunDef(`guide`, _), Seq(expr)) => expr + } + + val alts = guides.flatMap { e => + // Tentative solution using e + val wrappedE = if (p.xs.size == 1) Tuple(Seq(e)) else e + + val vc = And(p.pc, LetTuple(p.xs, wrappedE, Not(p.phi))) + + val solver = sctx.newSolver.setTimeout(1000L) + + solver.assertCnstr(vc) + val osol = solver.check match { + case Some(false) => + Some(Solution(BooleanLiteral(true), Set(), wrappedE, true)) + + case None => + Some(Solution(BooleanLiteral(true), Set(), wrappedE, false)) + + case _ => + None + } + + osol.map { s => + RuleInstantiation.immediateSuccess(p, this, s) + } + } + + alts + } +} diff --git a/src/main/scala/leon/synthesis/rules/GuidedDecomp.scala b/src/main/scala/leon/synthesis/rules/GuidedDecomp.scala new file mode 100644 index 0000000000000000000000000000000000000000..0fb044ca5da0ac91f34dc8cbaec11da1687bdfc3 --- /dev/null +++ b/src/main/scala/leon/synthesis/rules/GuidedDecomp.scala @@ -0,0 +1,47 @@ +/* Copyright 2009-2014 EPFL, Lausanne */ + +package leon +package synthesis +package rules + +import purescala.Trees._ +import purescala.Definitions._ +import purescala.Common._ +import purescala.TypeTrees._ +import purescala.TreeOps._ +import purescala.Extractors._ +import purescala.Constructors._ + +import solvers._ + +case object GuidedDecomp extends Rule("Guided Decomp") { + def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { + val TopLevelAnds(clauses) = p.pc + + val guide = sctx.program.library.guide.get + + val guides = clauses.collect { + case FunctionInvocation(TypedFunDef(`guide`, _), Seq(expr)) => expr + } + + val alts = guides.collect { + case g @ IfExpr(c, thn, els) => + val sub1 = p.copy(pc = And(c, replace(Map(g -> thn), p.pc))) + val sub2 = p.copy(pc = And(Not(c), replace(Map(g -> els), p.pc))) + + val onSuccess: List[Solution] => Option[Solution] = { + case List(s1, s2) => + Some(Solution(Or(s1.pre, s2.pre), s1.defs++s2.defs, IfExpr(c, s1.term, s2.term))) + case _ => + None + } + + Some(RuleInstantiation.immediateDecomp(p, this, List(sub1, sub2), onSuccess, "Guided If-Split on '"+c+"'")) + + case e => + None + } + + alts.flatten + } +} diff --git a/src/main/scala/leon/utils/Library.scala b/src/main/scala/leon/utils/Library.scala index 4899368aa217eb0e7f8c34372d9a8ef471cc6838..5b5621b142de012ed1d5b58918086f88f0537461 100644 --- a/src/main/scala/leon/utils/Library.scala +++ b/src/main/scala/leon/utils/Library.scala @@ -21,6 +21,10 @@ case class Library(pgm: Program) { case fd: FunDef => fd } + lazy val guide = lookup("leon.lang.synthesis.guide") collect { + case (fd: FunDef) => fd + } + def lookup(name: String): Option[Definition] = { searchByFullName(name, pgm) }