Skip to content
Snippets Groups Projects
Commit 3914a9a0 authored by Manos Koukoutos's avatar Manos Koukoutos
Browse files

CEGIS: Skip CE search if all programs have been excluded. Also some fixes

parent 82407106
Branches
Tags
No related merge requests found
...@@ -102,11 +102,11 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { ...@@ -102,11 +102,11 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) {
* b3 => c6 == H(c4, c5) * b3 => c6 == H(c4, c5)
* *
* c1 -> Seq( * c1 -> Seq(
* (b1, F(c2, c3), Set(c2, c3)) * (b1, F(_, _), Seq(c2, c3))
* (b2, G(c4, c5), Set(c4, c5)) * (b2, G(_, _), Seq(c4, c5))
* ) * )
* c6 -> Seq( * c6 -> Seq(
* (b3, H(c7, c8), Set(c7, c8)) * (b3, H(_, _), Seq(c7, c8))
* ) * )
*/ */
private var cTree: Map[Identifier, Seq[(Identifier, Seq[Expr] => Expr, Seq[Identifier])]] = Map() private var cTree: Map[Identifier, Seq[(Identifier, Seq[Expr] => Expr, Seq[Identifier])]] = Map()
...@@ -198,6 +198,7 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { ...@@ -198,6 +198,7 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) {
} }
bsOrdered = bs.toSeq.sorted bsOrdered = bs.toSeq.sorted
excludedPrograms = ArrayBuffer()
setCExpr(computeCExpr()) setCExpr(computeCExpr())
ctx.timers.synthesis.cegis.updateCTree.stop() ctx.timers.synthesis.cegis.updateCTree.stop()
...@@ -250,44 +251,7 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { ...@@ -250,44 +251,7 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) {
SeqUtils.cartesianProduct(seqs).map(_.flatten.toSet) SeqUtils.cartesianProduct(seqs).map(_.flatten.toSet)
} }
def redundant(e: Expr): Boolean = { allProgramsFor(Seq(rootC))
val (op1, op2) = e match {
case Minus(o1, o2) => (o1, o2)
case Modulo(o1, o2) => (o1, o2)
case Division(o1, o2) => (o1, o2)
case BVMinus(o1, o2) => (o1, o2)
case BVRemainder(o1, o2) => (o1, o2)
case BVDivision(o1, o2) => (o1, o2)
case And(Seq(Not(o1), Not(o2))) => (o1, o2)
case And(Seq(Not(o1), o2)) => (o1, o2)
case And(Seq(o1, Not(o2))) => (o1, o2)
case And(Seq(o1, o2)) => (o1, o2)
case Or(Seq(Not(o1), Not(o2))) => (o1, o2)
case Or(Seq(Not(o1), o2)) => (o1, o2)
case Or(Seq(o1, Not(o2))) => (o1, o2)
case Or(Seq(o1, o2)) => (o1, o2)
case SetUnion(o1, o2) => (o1, o2)
case SetIntersection(o1, o2) => (o1, o2)
case SetDifference(o1, o2) => (o1, o2)
case Equals(Not(o1), Not(o2)) => (o1, o2)
case Equals(Not(o1), o2) => (o1, o2)
case Equals(o1, Not(o2)) => (o1, o2)
case Equals(o1, o2) => (o1, o2)
case _ => return false
}
op1 == op2
}
allProgramsFor(Seq(rootC))/* filterNot { bs =>
val res = params.optimizations && exists(redundant)(getExpr(bs))
if (!res) excludeProgram(bs, false)
res
}*/
} }
private def debugCTree(cTree: Map[Identifier, Seq[(Identifier, Seq[Expr] => Expr, Seq[Identifier])]], private def debugCTree(cTree: Map[Identifier, Seq[(Identifier, Seq[Expr] => Expr, Seq[Identifier])]],
...@@ -377,8 +341,7 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { ...@@ -377,8 +341,7 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) {
Lambda(p.xs.map(ValDef), p.phi) Lambda(p.xs.map(ValDef), p.phi)
) )
phiFd.body = Some(
phiFd.body = Some(
letTuple(p.xs, letTuple(p.xs,
FunctionInvocation(cTreeFd.typed, p.as.map(_.toVariable)), FunctionInvocation(cTreeFd.typed, p.as.map(_.toVariable)),
p.phi) p.phi)
...@@ -566,6 +529,8 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { ...@@ -566,6 +529,8 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) {
var excludedPrograms = ArrayBuffer[Set[Identifier]]() var excludedPrograms = ArrayBuffer[Set[Identifier]]()
def allProgramsClosed = allProgramsCount() <= excludedPrograms.size
// Explicitly remove program computed by bValues from the search space // Explicitly remove program computed by bValues from the search space
// //
// If the bValues comes from models, we make sure the bValues we exclude // If the bValues comes from models, we make sure the bValues we exclude
...@@ -825,18 +790,12 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { ...@@ -825,18 +790,12 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) {
if (hasInputExamples) { if (hasInputExamples) {
timers.filter.start() timers.filter.start()
for (bs <- prunedPrograms if !interruptManager.isInterrupted) { for (bs <- prunedPrograms if !interruptManager.isInterrupted) {
var valid = true
val examples = allInputExamples() val examples = allInputExamples()
while(valid && examples.hasNext) { examples.find(e => !ndProgram.testForProgram(bs)(e)).foreach { e =>
val e = examples.next() failedTestsStats(e) += 1
if (!ndProgram.testForProgram(bs)(e)) { sctx.reporter.debug(f" Program: ${ndProgram.getExpr(bs).asString}%-80s failed on: ${e.asString}")
failedTestsStats(e) += 1 wrongPrograms += bs
sctx.reporter.debug(f" Program: ${ndProgram.getExpr(bs).asString}%-80s failed on: ${e.asString}") prunedPrograms -= bs
wrongPrograms += bs
prunedPrograms -= bs
valid = false
}
} }
if (wrongPrograms.size+1 % 1000 == 0) { if (wrongPrograms.size+1 % 1000 == 0) {
...@@ -850,14 +809,6 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { ...@@ -850,14 +809,6 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) {
val nTotal = ndProgram.allProgramsCount() val nTotal = ndProgram.allProgramsCount()
//println(s"Iotal: $nTotal, passing: $nPassing") //println(s"Iotal: $nTotal, passing: $nPassing")
/*locally {
val progs = ndProgram.allPrograms() map ndProgram.getExpr
val ground = progs count isGround
println("Programs")
progs take 100 foreach println
println(s"$ground ground out of $nTotal")
}*/
sctx.reporter.debug(s"#Programs passing tests: $nPassing out of $nTotal") sctx.reporter.debug(s"#Programs passing tests: $nPassing out of $nTotal")
sctx.reporter.ifDebug{ printer => sctx.reporter.ifDebug{ printer =>
for (p <- prunedPrograms.take(100)) { for (p <- prunedPrograms.take(100)) {
...@@ -898,14 +849,11 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { ...@@ -898,14 +849,11 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) {
val newCexs = cexs.map(InExample) val newCexs = cexs.map(InExample)
baseExampleInputs ++= newCexs baseExampleInputs ++= newCexs
// Retest whether the newly found C-E invalidates some programs // Retest whether the newly found C-E invalidates some programs
for (p <- otherPrograms) { for (p <- otherPrograms if !interruptManager.isInterrupted) {
// Exclude any programs that fail the new cex's // Exclude any programs that fail at least one new cex
var valid = true newCexs.find { cex => !ndProgram.testForProgram(p)(cex) }.foreach { cex =>
newCexs.takeWhile(_ => valid).foreach { cex => failedTestsStats(cex) += 1
if (!ndProgram.testForProgram(p)(cex)) { ndProgram.excludeProgram(p, true)
ndProgram.excludeProgram(p, true)
valid = false
}
} }
} }
// If we excluded all programs, we can skip CE search // If we excluded all programs, we can skip CE search
...@@ -921,7 +869,7 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { ...@@ -921,7 +869,7 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) {
} }
// CEGIS Loop at a given unfolding level // CEGIS Loop at a given unfolding level
while (result.isEmpty && !skipCESearch && !interruptManager.isInterrupted) { while (result.isEmpty && !skipCESearch && !interruptManager.isInterrupted && !ndProgram.allProgramsClosed) {
timers.loop.start() timers.loop.start()
ndProgram.solveForTentativeProgram() match { ndProgram.solveForTentativeProgram() match {
case Some(Some(bs)) => case Some(Some(bs)) =>
...@@ -968,12 +916,14 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { ...@@ -968,12 +916,14 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) {
case None => case None =>
result = Some(RuleFailed()) result = Some(RuleFailed())
} }
timers.loop.stop() timers.loop.stop()
} }
unfolding += 1 unfolding += 1
} while(unfolding <= maxSize && result.isEmpty && !interruptManager.isInterrupted) } while(unfolding <= maxSize && result.isEmpty && !interruptManager.isInterrupted)
if (interruptManager.isInterrupted) interruptManager.recoverInterrupt()
result.getOrElse(RuleFailed()) result.getOrElse(RuleFailed())
} catch { } catch {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment