Skip to content
Snippets Groups Projects
Commit 5b39b094 authored by Manos Koukoutos's avatar Manos Koukoutos
Browse files

Small CEGIS fixes

parent 3ac4c266
No related branches found
No related tags found
No related merge requests found
...@@ -37,6 +37,8 @@ abstract class CEGISLike(name: String) extends Rule(name) { ...@@ -37,6 +37,8 @@ abstract class CEGISLike(name: String) extends Rule(name) {
def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = {
import hctx.reporter._
val exSolverTo = 2000L val exSolverTo = 2000L
val cexSolverTo = 3000L val cexSolverTo = 3000L
...@@ -184,7 +186,7 @@ abstract class CEGISLike(name: String) extends Rule(name) { ...@@ -184,7 +186,7 @@ abstract class CEGISLike(name: String) extends Rule(name) {
c c
} }
hctx.reporter.ifDebug { printer => ifDebug { printer =>
printer("Grammar so far:") printer("Grammar so far:")
grammar.printProductions(printer) grammar.printProductions(printer)
printer("") printer("")
...@@ -226,7 +228,7 @@ abstract class CEGISLike(name: String) extends Rule(name) { ...@@ -226,7 +228,7 @@ abstract class CEGISLike(name: String) extends Rule(name) {
val c = allProgramsCount() val c = allProgramsCount()
if (c > nProgramsLimit) { if (c > nProgramsLimit) {
hctx.reporter.debug(s"Exceeded program limit: $c > $nProgramsLimit") debug(s"Exceeded program limit: $c > $nProgramsLimit")
return Seq() return Seq()
} }
...@@ -279,7 +281,7 @@ abstract class CEGISLike(name: String) extends Rule(name) { ...@@ -279,7 +281,7 @@ abstract class CEGISLike(name: String) extends Rule(name) {
val outerSolution = { val outerSolution = {
new PartialSolution(hctx.search.strat, true) new PartialSolution(hctx.search.strat, true)
.solutionAround(hctx.currentNode)(FunctionInvocation(cTreeFd.typed, p.as.map(_.toVariable))) .solutionAround(hctx.currentNode)(FunctionInvocation(cTreeFd.typed, p.as.map(_.toVariable)))
.getOrElse(hctx.reporter.fatalError("Unable to get outer solution")) .getOrElse(fatalError("Unable to get outer solution"))
} }
val program0 = addFunDefs(hctx.program, Seq(cTreeFd, phiFd) ++ outerSolution.defs, hctx.functionContext) val program0 = addFunDefs(hctx.program, Seq(cTreeFd, phiFd) ++ outerSolution.defs, hctx.functionContext)
...@@ -476,11 +478,11 @@ abstract class CEGISLike(name: String) extends Rule(name) { ...@@ -476,11 +478,11 @@ abstract class CEGISLike(name: String) extends Rule(name) {
println(err) println(err)
println() println()
}*/ }*/
hctx.reporter.debug("RE testing CE: "+err) debug("RE testing CE: "+err)
Some(false) Some(false)
case EvaluationResults.EvaluatorError(err) => case EvaluationResults.EvaluatorError(err) =>
hctx.reporter.debug("Error testing CE: "+err) debug("Error testing CE: "+err)
None None
} }
...@@ -525,7 +527,7 @@ abstract class CEGISLike(name: String) extends Rule(name) { ...@@ -525,7 +527,7 @@ abstract class CEGISLike(name: String) extends Rule(name) {
val eval = new DefaultEvaluator(hctx, innerProgram) val eval = new DefaultEvaluator(hctx, innerProgram)
if (cexs exists (cex => eval.eval(cnstr, p.as.zip(cex).toMap).result == Some(BooleanLiteral(true)))) { if (cexs exists (cex => eval.eval(cnstr, p.as.zip(cex).toMap).result == Some(BooleanLiteral(true)))) {
hctx.reporter.debug(s"Rejected by CEX: $outerSol") debug(s"Rejected by CEX: $outerSol")
excludeProgram(bs, true) excludeProgram(bs, true)
cTreeFd.fullBody = origImpl cTreeFd.fullBody = origImpl
} else { } else {
...@@ -534,11 +536,11 @@ abstract class CEGISLike(name: String) extends Rule(name) { ...@@ -534,11 +536,11 @@ abstract class CEGISLike(name: String) extends Rule(name) {
val solverf = SolverFactory.getFromSettings(hctx, innerProgram).withTimeout(cexSolverTo) val solverf = SolverFactory.getFromSettings(hctx, innerProgram).withTimeout(cexSolverTo)
val solver = solverf.getNewSolver() val solver = solverf.getNewSolver()
try { try {
hctx.reporter.debug("Sending candidate to solver...") debug("Sending candidate to solver...")
solver.assertCnstr(cnstr) solver.assertCnstr(cnstr)
solver.check match { solver.check match {
case Some(true) => case Some(true) =>
hctx.reporter.debug(s"Proven invalid: $outerSol") debug(s"Proven invalid: $outerSol")
excludeProgram(bs, true) excludeProgram(bs, true)
val model = solver.getModel val model = solver.getModel
//println("Found counter example: ") //println("Found counter example: ")
...@@ -553,13 +555,13 @@ abstract class CEGISLike(name: String) extends Rule(name) { ...@@ -553,13 +555,13 @@ abstract class CEGISLike(name: String) extends Rule(name) {
case Some(false) => case Some(false) =>
// UNSAT, valid program // UNSAT, valid program
hctx.reporter.debug("Found valid program!") debug("Found valid program!")
return Right(Solution(BooleanLiteral(true), Set(), outerSol, true)) return Right(Solution(BooleanLiteral(true), Set(), outerSol, true))
case None => case None =>
debug("Found a non-verifiable solution...")
if (useOptTimeout) { if (useOptTimeout) {
// Optimistic valid solution // Optimistic valid solution
hctx.reporter.debug("Found a non-verifiable solution...")
best = Some(Solution(BooleanLiteral(true), Set(), outerSol, false)) best = Some(Solution(BooleanLiteral(true), Set(), outerSol, false))
} }
} }
...@@ -573,12 +575,16 @@ abstract class CEGISLike(name: String) extends Rule(name) { ...@@ -573,12 +575,16 @@ abstract class CEGISLike(name: String) extends Rule(name) {
best.map{ sol => best.map{ sol =>
// Interpret timeout in CE search as "the candidate is valid" // Interpret timeout in CE search as "the candidate is valid"
hctx.reporter.info("CEGIS could not prove the validity of the resulting expression") info("CEGIS could not prove the validity of the resulting expression")
Right(sol) Right(sol)
}.getOrElse(Left(cexs)) }.getOrElse(Left(cexs))
} }
def allProgramsClosed = prunedPrograms.isEmpty def allProgramsClosed = prunedPrograms.isEmpty
def closeAllPrograms() = {
excludedPrograms ++= prunedPrograms
prunedPrograms = Set()
}
// Explicitly remove program computed by bValues from the search space // Explicitly remove program computed by bValues from the search space
// //
...@@ -673,8 +679,7 @@ abstract class CEGISLike(name: String) extends Rule(name) { ...@@ -673,8 +679,7 @@ abstract class CEGISLike(name: String) extends Rule(name) {
* might timeout instead of returning Some(false). We might still * might timeout instead of returning Some(false). We might still
* benefit from unfolding further * benefit from unfolding further
*/ */
hctx.reporter.debug("Timeout while getting tentative program!") None
Some(None)
} }
} finally { } finally {
timers.tentative.stop() timers.tentative.stop()
...@@ -735,7 +740,7 @@ abstract class CEGISLike(name: String) extends Rule(name) { ...@@ -735,7 +740,7 @@ abstract class CEGISLike(name: String) extends Rule(name) {
implicit val ic = hctx implicit val ic = hctx
hctx.reporter.debug("Acquiring initial list of examples") debug("Acquiring initial list of examples")
// To the list of known examples, we add an additional one produced by the solver // To the list of known examples, we add an additional one produced by the solver
val solverExample = if (p.pc == BooleanLiteral(true)) { val solverExample = if (p.pc == BooleanLiteral(true)) {
...@@ -753,12 +758,12 @@ abstract class CEGISLike(name: String) extends Rule(name) { ...@@ -753,12 +758,12 @@ abstract class CEGISLike(name: String) extends Rule(name) {
List(InExample(p.as.map(a => model.getOrElse(a, simplestValue(a.getType))))) List(InExample(p.as.map(a => model.getOrElse(a, simplestValue(a.getType)))))
case Some(false) => case Some(false) =>
hctx.reporter.debug("Path-condition seems UNSAT") debug("Path-condition seems UNSAT")
return RuleFailed() return RuleFailed()
case None => case None =>
if (!interruptManager.isInterrupted) { if (!interruptManager.isInterrupted) {
hctx.reporter.warning("Solver could not solve path-condition") warning("Solver could not solve path-condition")
} }
Nil Nil
//return RuleFailed() // This is not necessary though, but probably wanted //return RuleFailed() // This is not necessary though, but probably wanted
...@@ -770,7 +775,7 @@ abstract class CEGISLike(name: String) extends Rule(name) { ...@@ -770,7 +775,7 @@ abstract class CEGISLike(name: String) extends Rule(name) {
val baseExampleInputs = p.eb.examples ++ solverExample val baseExampleInputs = p.eb.examples ++ solverExample
hctx.reporter.ifDebug { debug => ifDebug { debug =>
baseExampleInputs.foreach { in => baseExampleInputs.foreach { in =>
debug(" - "+in.asString) debug(" - "+in.asString)
} }
...@@ -816,7 +821,7 @@ abstract class CEGISLike(name: String) extends Rule(name) { ...@@ -816,7 +821,7 @@ abstract class CEGISLike(name: String) extends Rule(name) {
ndProgram.unfold() ndProgram.unfold()
val nInitial = ndProgram.prunedPrograms.size val nInitial = ndProgram.prunedPrograms.size
hctx.reporter.debug("#Programs: "+nInitial) debug(s"#Programs: $nInitial")
def nPassing = ndProgram.prunedPrograms.size def nPassing = ndProgram.prunedPrograms.size
...@@ -843,8 +848,8 @@ abstract class CEGISLike(name: String) extends Rule(name) { ...@@ -843,8 +848,8 @@ abstract class CEGISLike(name: String) extends Rule(name) {
// } // }
//} //}
hctx.reporter.debug("#Tests: "+baseExampleInputs.size) debug(s"#Tests: >= ${gi.bufferedCount}")
hctx.reporter.ifDebug{ printer => ifDebug{ printer =>
for (e <- baseExampleInputs.take(10)) { for (e <- baseExampleInputs.take(10)) {
printer(" - "+e.asString) printer(" - "+e.asString)
} }
...@@ -867,11 +872,11 @@ abstract class CEGISLike(name: String) extends Rule(name) { ...@@ -867,11 +872,11 @@ abstract class CEGISLike(name: String) extends Rule(name) {
// Program fails the test // Program fails the test
stop = true stop = true
failedTestsStats(e) += 1 failedTestsStats(e) += 1
hctx.reporter.debug(f" Program: ${ndProgram.getExpr(bs).asString}%-80s failed on: ${e.asString}") debug(f" Program: ${ndProgram.getExpr(bs).asString}%-80s failed on: ${e.asString}")
ndProgram.excludeProgram(bs, true) ndProgram.excludeProgram(bs, true)
case None => case None =>
// Eval. error -> bad example // Eval. error -> bad example
hctx.reporter.debug(s" Test $e failed, removing...") debug(s" Test $e crashed the evaluator, removing...")
badExamples ::= e badExamples ::= e
} }
} }
...@@ -880,8 +885,8 @@ abstract class CEGISLike(name: String) extends Rule(name) { ...@@ -880,8 +885,8 @@ abstract class CEGISLike(name: String) extends Rule(name) {
timers.filter.stop() timers.filter.stop()
} }
hctx.reporter.debug(s"#Programs passing tests: $nPassing out of $nInitial") debug(s"#Programs passing tests: $nPassing out of $nInitial")
hctx.reporter.ifDebug{ printer => ifDebug{ printer =>
for (p <- ndProgram.prunedPrograms.take(100)) { for (p <- ndProgram.prunedPrograms.take(100)) {
printer(" - "+ndProgram.getExpr(p).asString) printer(" - "+ndProgram.getExpr(p).asString)
} }
...@@ -891,39 +896,39 @@ abstract class CEGISLike(name: String) extends Rule(name) { ...@@ -891,39 +896,39 @@ abstract class CEGISLike(name: String) extends Rule(name) {
} }
// CEGIS Loop at a given unfolding level // CEGIS Loop at a given unfolding level
while (result.isEmpty && !interruptManager.isInterrupted && !ndProgram.allProgramsClosed) { while (result.isEmpty && !interruptManager.isInterrupted && !ndProgram.allProgramsClosed) {
hctx.reporter.debug("Programs left: " + ndProgram.prunedPrograms.size) debug("Programs left: " + ndProgram.prunedPrograms.size)
// Phase 0: If the number of remaining programs is small, validate them individually // Phase 0: If the number of remaining programs is small, validate them individually
if (programsReduced()) { if (programsReduced()) {
timers.validate.start() timers.validate.start()
val programsToValidate = ndProgram.prunedPrograms val programsToValidate = ndProgram.prunedPrograms
hctx.reporter.debug(s"Will send ${programsToValidate.size} program(s) to validate individually") debug(s"Will send ${programsToValidate.size} program(s) to validate individually")
ndProgram.validatePrograms(programsToValidate) match { ndProgram.validatePrograms(programsToValidate) match {
case Right(sol) => case Right(sol) =>
// Found solution! Exit CEGIS // Found solution! Exit CEGIS
result = Some(RuleClosed(sol)) result = Some(RuleClosed(sol))
case Left(cexs) => case Left(cexs) =>
hctx.reporter.debug(s"Found cexs! $cexs") debug(s"Found cexs! $cexs")
// Found some counterexamples // Found some counterexamples
// (bear in mind that these will in fact exclude programs within validatePrograms()) // (bear in mind that these will in fact exclude programs within validatePrograms())
val newCexs = cexs.map(InExample) val newCexs = cexs.map(InExample)
newCexs foreach (failedTestsStats(_) += 1) newCexs foreach (failedTestsStats(_) += 1)
gi ++= newCexs gi ++= newCexs
} }
hctx.reporter.debug(s"#Programs after validating individually: ${ndProgram.prunedPrograms.size}") debug(s"#Programs after validating individually: ${ndProgram.prunedPrograms.size}")
timers.validate.stop() timers.validate.stop()
} }
if (result.isEmpty && !ndProgram.allProgramsClosed) { if (result.isEmpty && !ndProgram.allProgramsClosed) {
// Phase 1: Find a candidate program that works for at least 1 input // Phase 1: Find a candidate program that works for at least 1 input
hctx.reporter.debug("Looking for program that works on at least 1 input...") debug("Looking for program that works on at least 1 input...")
ndProgram.solveForTentativeProgram() match { ndProgram.solveForTentativeProgram() match {
case Some(Some(bs)) => case Some(Some(bs)) =>
hctx.reporter.debug(s"Found tentative model ${ndProgram.getExpr(bs)}, need to validate!") debug(s"Found tentative model ${ndProgram.getExpr(bs)}, need to validate!")
// Phase 2: Validate candidate model // Phase 2: Validate candidate model
ndProgram.solveForCounterExample(bs) match { ndProgram.solveForCounterExample(bs) match {
case Some(Some(inputsCE)) => case Some(Some(inputsCE)) =>
hctx.reporter.debug("Found counter-example:" + inputsCE) debug("Found counter-example:" + inputsCE)
val ce = InExample(inputsCE) val ce = InExample(inputsCE)
// Found counterexample! Exclude this program // Found counterexample! Exclude this program
gi += ce gi += ce
...@@ -935,11 +940,11 @@ abstract class CEGISLike(name: String) extends Rule(name) { ...@@ -935,11 +940,11 @@ abstract class CEGISLike(name: String) extends Rule(name) {
ndProgram.testForProgram(p)(ce) match { ndProgram.testForProgram(p)(ce) match {
case Some(true) => case Some(true) =>
case Some(false) => case Some(false) =>
hctx.reporter.debug(f" Program: ${ndProgram.getExpr(p).asString}%-80s failed on: ${ce.asString}") debug(f" Program: ${ndProgram.getExpr(p).asString}%-80s failed on: ${ce.asString}")
failedTestsStats(ce) += 1 failedTestsStats(ce) += 1
ndProgram.excludeProgram(p, true) ndProgram.excludeProgram(p, true)
case None => case None =>
hctx.reporter.debug(s" Test $ce failed, removing...") debug(s" Test $ce failed, removing...")
gi -= ce gi -= ce
} }
} }
...@@ -951,26 +956,29 @@ abstract class CEGISLike(name: String) extends Rule(name) { ...@@ -951,26 +956,29 @@ abstract class CEGISLike(name: String) extends Rule(name) {
case None => case None =>
// We are not sure // We are not sure
hctx.reporter.debug("Unknown") debug("Unknown")
if (useOptTimeout) { if (useOptTimeout) {
// Interpret timeout in CE search as "the candidate is valid" // Interpret timeout in CE search as "the candidate is valid"
hctx.reporter.info("CEGIS could not prove the validity of the resulting expression") info("CEGIS could not prove the validity of the resulting expression")
val expr = ndProgram.getExpr(bs) val expr = ndProgram.getExpr(bs)
result = Some(RuleClosed(Solution(BooleanLiteral(true), Set(), expr, isTrusted = false))) result = Some(RuleClosed(Solution(BooleanLiteral(true), Set(), expr, isTrusted = false)))
} else { } else {
// Ok, we failed to validate, exclude this program // Ok, we failed to validate, exclude this program
ndProgram.excludeProgram(bs, false) ndProgram.excludeProgram(bs, false)
// TODO: Make CEGIS fail early when it fails on 1 program? // TODO: Make CEGIS fail early when it times out when verifying 1 program?
// result = Some(RuleFailed()) // result = Some(RuleFailed())
} }
} }
case Some(None) => case Some(None) =>
hctx.reporter.debug("There exists no candidate program!") debug("There exists no candidate program!")
ndProgram.prunedPrograms foreach (ndProgram.excludeProgram(_, true)) ndProgram.closeAllPrograms()
case None => case None =>
result = Some(RuleFailed()) debug("Timeout while getting tentative program!")
ndProgram.closeAllPrograms()
// TODO: Make CEGIS fail early when it times out when looking for tentative program?
//result = Some(RuleFailed())
} }
} }
} }
...@@ -982,7 +990,7 @@ abstract class CEGISLike(name: String) extends Rule(name) { ...@@ -982,7 +990,7 @@ abstract class CEGISLike(name: String) extends Rule(name) {
} catch { } catch {
case e: Throwable => case e: Throwable =>
hctx.reporter.warning("CEGIS crashed: "+e.getMessage) warning("CEGIS crashed: "+e.getMessage)
e.printStackTrace() e.printStackTrace()
RuleFailed() RuleFailed()
} }
......
...@@ -19,15 +19,17 @@ class GrowableIterable[T](init: Seq[T], growth: Iterator[T]) extends Iterable[T] ...@@ -19,15 +19,17 @@ class GrowableIterable[T](init: Seq[T], growth: Iterator[T]) extends Iterable[T]
} }
} }
def += (more: T) = buffer += more def += (more: T) = buffer += more
def ++=(more: Seq[T]) = buffer ++= more def ++=(more: Iterable[T]) = buffer ++= more
def -= (less: T) = buffer -= less def -= (less: T) = buffer -= less
def --=(less: Seq[T]) = buffer --= less def --=(less: Iterable[T]) = buffer --= less
def iterator: Iterator[T] = { def iterator: Iterator[T] = {
buffer.iterator ++ cachingIterator buffer.iterator ++ cachingIterator
} }
def bufferedCount = buffer.size
def sortBufferBy[B](f: T => B)(implicit ord: math.Ordering[B]) = { def sortBufferBy[B](f: T => B)(implicit ord: math.Ordering[B]) = {
buffer = buffer.sortBy(f) buffer = buffer.sortBy(f)
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment