diff --git a/src/main/scala/leon/repair/Repairman.scala b/src/main/scala/leon/repair/Repairman.scala index 28819f634f018fc8e83f29de12645232dc8efea1..9dcd782a3323138e5bfe4c68822fba614e2065e7 100644 --- a/src/main/scala/leon/repair/Repairman.scala +++ b/src/main/scala/leon/repair/Repairman.scala @@ -155,7 +155,7 @@ class Repairman(ctx0: LeonContext, initProgram: Program, fd: FunDef, verifTimeou }(DebugSectionReport) if (synth.settings.generateDerivationTrees) { - val dot = new DotGenerator(search.g) + val dot = new DotGenerator(search) dot.writeFile("derivation"+ dotGenIds.nextGlobal + ".dot") } diff --git a/src/main/scala/leon/synthesis/SynthesisPhase.scala b/src/main/scala/leon/synthesis/SynthesisPhase.scala index 4744d8d606264b2c1f503a9511281b374e4ab51a..ac4d30614d8269ce78a92a95e232855eb40d9fbd 100644 --- a/src/main/scala/leon/synthesis/SynthesisPhase.scala +++ b/src/main/scala/leon/synthesis/SynthesisPhase.scala @@ -82,7 +82,7 @@ object SynthesisPhase extends TransformationPhase { try { if (options.generateDerivationTrees) { - val dot = new DotGenerator(search.g) + val dot = new DotGenerator(search) dot.writeFile("derivation"+dotGenIds.nextGlobal+".dot") } diff --git a/src/main/scala/leon/synthesis/graph/DotGenerator.scala b/src/main/scala/leon/synthesis/graph/DotGenerator.scala index 7da38716116f51d89e751a8aa12d709be776e17c..78ef7b371487a6711d3508b9712f7806e9c551e0 100644 --- a/src/main/scala/leon/synthesis/graph/DotGenerator.scala +++ b/src/main/scala/leon/synthesis/graph/DotGenerator.scala @@ -6,7 +6,11 @@ import leon.utils.UniqueCounter import java.io.{File, FileWriter, BufferedWriter} -class DotGenerator(g: Graph) { +class DotGenerator(search: Search) { + + implicit val ctx = search.ctx + + val g = search.g private val idCounter = new UniqueCounter[Unit] idCounter.nextGlobal // Start with 1 @@ -80,12 +84,14 @@ class DotGenerator(g: Graph) { } def nodeDesc(n: Node): String = n match { - case an: AndNode => an.ri.toString - case on: OrNode => on.p.toString + case an: AndNode => an.ri.asString + case on: OrNode => on.p.asString } def drawNode(res: StringBuffer, name: String, n: Node) { + val index = n.parent.map(_.descendants.indexOf(n) + " ").getOrElse("") + def escapeHTML(str: String) = str.replaceAll("&", "&").replaceAll("<", "<").replaceAll(">", ">") val color = if (n.isSolved) { @@ -109,10 +115,10 @@ class DotGenerator(g: Graph) { res append "<TR><TD BORDER=\"0\">"+escapeHTML(n.cost.asString)+"</TD></TR>" } - res append "<TR><TD BORDER=\"1\" BGCOLOR=\""+color+"\">"+escapeHTML(limit(nodeDesc(n)))+"</TD></TR>" + res append "<TR><TD BORDER=\"1\" BGCOLOR=\""+color+"\">"+escapeHTML(limit(index + nodeDesc(n)))+"</TD></TR>" if (n.isSolved) { - res append "<TR><TD BGCOLOR=\""+color+"\">"+escapeHTML(limit(n.generateSolutions().head.toString))+"</TD></TR>" + res append "<TR><TD BGCOLOR=\""+color+"\">"+escapeHTML(limit(n.generateSolutions().head.asString))+"</TD></TR>" } res append "</TABLE>>, shape = \"none\" ];\n" @@ -126,4 +132,4 @@ class DotGenerator(g: Graph) { } } -object dotGenIds extends UniqueCounter[Unit] \ No newline at end of file +object dotGenIds extends UniqueCounter[Unit] diff --git a/src/main/scala/leon/synthesis/graph/Search.scala b/src/main/scala/leon/synthesis/graph/Search.scala index 98554a5ae492972e0b7b3915979d9af829d81555..c630e315d9777110b5dcde7adc42cf6172161af3 100644 --- a/src/main/scala/leon/synthesis/graph/Search.scala +++ b/src/main/scala/leon/synthesis/graph/Search.scala @@ -10,7 +10,7 @@ import scala.collection.mutable.ArrayBuffer import leon.utils.Interruptible import java.util.concurrent.atomic.AtomicBoolean -abstract class Search(ctx: LeonContext, ci: SourceInfo, p: Problem, costModel: CostModel) extends Interruptible { +abstract class Search(val ctx: LeonContext, ci: SourceInfo, p: Problem, costModel: CostModel) extends Interruptible { val g = new Graph(costModel, p) def findNodeToExpandFrom(n: Node): Option[Node] diff --git a/src/main/scala/leon/synthesis/rules/CEGISLike.scala b/src/main/scala/leon/synthesis/rules/CEGISLike.scala index 5c3a1120f429770cb37b550251a50246d8afddb4..4a0a315d73bec6ff090e4bda2665b3195b401f43 100644 --- a/src/main/scala/leon/synthesis/rules/CEGISLike.scala +++ b/src/main/scala/leon/synthesis/rules/CEGISLike.scala @@ -374,7 +374,7 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { solFd.fullBody = Ensuring( FunctionInvocation(cTreeFd.typed, p.as.map(_.toVariable)), - Lambda(p.xs.map(ValDef(_)), p.phi) + Lambda(p.xs.map(ValDef), p.phi) ) @@ -798,8 +798,6 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { try { do { - var skipCESearch = false - // Unfold formula ndProgram.unfold() @@ -879,42 +877,46 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { } } - if (nPassing == 0 || interruptManager.isInterrupted) { - // No test passed, we can skip solver and unfold again, if possible - skipCESearch = true - } else { - var doFilter = true - + // We can skip CE search if - we have excluded all programs or - we do so with validatePrograms + var skipCESearch = nPassing == 0 || interruptManager.isInterrupted || { // If the number of pruned programs is very small, or by far smaller than the number of total programs, // we hypothesize it will be easier to just validate them individually. // Otherwise, we validate a small number of programs just in case we are lucky FIXME is this last clause useful? - val programsToValidate = if (nTotal / nPassing > passingRatio || nPassing < 10) { - prunedPrograms + val (programsToValidate, otherPrograms) = if (nTotal / nPassing > passingRatio || nPassing < 10) { + (prunedPrograms, Nil) } else { - prunedPrograms.take(validateUpTo) + prunedPrograms.splitAt(validateUpTo) } - if (programsToValidate.nonEmpty) { - ndProgram.validatePrograms(programsToValidate) match { - case Left(sols) if sols.nonEmpty => - doFilter = false - result = Some(RuleClosed(sols)) - case Right(cexs) => - baseExampleInputs ++= cexs.map(InExample) - - if (nPassing <= validateUpTo) { - // All programs failed verification, we filter everything out and unfold - doFilter = false - skipCESearch = true + ndProgram.validatePrograms(programsToValidate) match { + case Left(sols) if sols.nonEmpty => + // Found solution! Exit CEGIS + result = Some(RuleClosed(sols)) + true + case Right(cexs) => + // Found some counterexamples + val newCexs = cexs.map(InExample) + baseExampleInputs ++= newCexs + // Retest whether the newly found C-E invalidates some programs + for (p <- otherPrograms) { + // Exclude any programs that fail the new cex's + var valid = true + newCexs.takeWhile(_ => valid).foreach { cex => + if (!ndProgram.testForProgram(p)(cex)) { + ndProgram.excludeProgram(p, true) + valid = false + } } - } + } + // If we excluded all programs, we can skip CE search + programsToValidate.size >= nPassing } + } - if (doFilter) { - sctx.reporter.debug("Excluding "+wrongPrograms.size+" programs") - wrongPrograms.foreach { - ndProgram.excludeProgram(_, true) - } + if (!skipCESearch) { + sctx.reporter.debug("Excluding "+wrongPrograms.size+" programs") + wrongPrograms.foreach { + ndProgram.excludeProgram(_, true) } } @@ -923,57 +925,41 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { timers.loop.start() ndProgram.solveForTentativeProgram() match { case Some(Some(bs)) => - // Should we validate this program with Z3? - - val validateWithZ3 = if (hasInputExamples) { - - if (allInputExamples().forall(ndProgram.testForProgram(bs))) { - // All valid inputs also work with this, we need to - // make sure by validating this candidate with z3 - true - } else { - // One valid input failed with this candidate, we can skip + // No inputs to test or all valid inputs also work with this. + // We need to make sure by validating this candidate with z3 + sctx.reporter.debug("Found tentative model, need to validate!") + ndProgram.solveForCounterExample(bs) match { + case Some(Some(inputsCE)) => + sctx.reporter.debug("Found counter-example:" + inputsCE) + val ce = InExample(inputsCE) + // Found counter example! Exclude this program + baseExampleInputs += ce ndProgram.excludeProgram(bs, false) - false - } - } else { - // No inputs or capability to test, we need to ask Z3 - true - } - sctx.reporter.debug("Found tentative model (Validate="+validateWithZ3+")!") - - if (validateWithZ3) { - ndProgram.solveForCounterExample(bs) match { - case Some(Some(inputsCE)) => - sctx.reporter.debug("Found counter-example:"+inputsCE) - val ce = InExample(inputsCE) - // Found counter example! - baseExampleInputs += ce - - // Retest whether the newly found C-E invalidates all programs - if (prunedPrograms.forall(p => !ndProgram.testForProgram(p)(ce))) { - skipCESearch = true - } else { - ndProgram.excludeProgram(bs, false) - } - - case Some(None) => - // Found no counter example! Program is a valid solution + + // Retest whether the newly found C-E invalidates some programs + prunedPrograms.foreach { p => + if (!ndProgram.testForProgram(p)(ce)) ndProgram.excludeProgram(p, true) + } + + case Some(None) => + // Found no counter example! Program is a valid solution + val expr = ndProgram.getExpr(bs) + result = Some(RuleClosed(Solution(BooleanLiteral(true), Set(), expr))) + + case None => + // We are not sure + sctx.reporter.debug("Unknown") + if (useOptTimeout) { + // Interpret timeout in CE search as "the candidate is valid" + sctx.reporter.info("CEGIS could not prove the validity of the resulting expression") val expr = ndProgram.getExpr(bs) - result = Some(RuleClosed(Solution(BooleanLiteral(true), Set(), expr))) - - case None => - // We are not sure - sctx.reporter.debug("Unknown") - if (useOptTimeout) { - // Interpret timeout in CE search as "the candidate is valid" - sctx.reporter.info("CEGIS could not prove the validity of the resulting expression") - val expr = ndProgram.getExpr(bs) - result = Some(RuleClosed(Solution(BooleanLiteral(true), Set(), expr, isTrusted = false))) - } else { - result = Some(RuleFailed()) - } - } + result = Some(RuleClosed(Solution(BooleanLiteral(true), Set(), expr, isTrusted = false))) + } else { + // Ok, we failed to validate, exclude this program + ndProgram.excludeProgram(bs, false) + // TODO: Make CEGIS fail early when it fails on 1 program? + // result = Some(RuleFailed()) + } } case Some(None) => diff --git a/testcases/synthesis/etienne-thesis/run.sh b/testcases/synthesis/etienne-thesis/run.sh index 81dd480de10e1a05ee6dd9e36cb991fbc92f74a0..924b99cc57386f1dba92bfb97017b41a801cd8ea 100755 --- a/testcases/synthesis/etienne-thesis/run.sh +++ b/testcases/synthesis/etienne-thesis/run.sh @@ -1,7 +1,7 @@ #!/bin/bash function run { - cmd="./leon --debug=report --timeout=30 --synthesis $1" + cmd="./leon --debug=report --timeout=30 --synthesis --cegis:maxsize=5 $1" echo "Running " $cmd echo "------------------------------------------------------------------------------------------------------------------" $cmd;