-
Etienne Kneuss authoredEtienne Kneuss authored
Cegis.scala 22.85 KiB
package leon
package synthesis
package rules
import solvers.TimeoutSolver
import purescala.Trees._
import purescala.Common._
import purescala.Definitions._
import purescala.TypeTrees._
import purescala.TreeOps._
import purescala.Extractors._
import purescala.ScalaPrinter
import evaluators._
import solvers.z3.FairZ3Solver
case object CEGIS extends Rule("CEGIS") {
def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = {
// CEGIS Flags to actiave or de-activate features
val useCEAsserts = false
val useUninterpretedProbe = false
val useUnsatCores = true
val useFunGenerators = sctx.options.cegisGenerateFunCalls
val useBPaths = sctx.options.cegisUseBPaths
val useCETests = sctx.options.cegisUseCETests
val useCEPruning = sctx.options.cegisUseCEPruning
val evaluator = new CodeGenEvaluator(sctx.context, sctx.program)
case class Generator(tpe: TypeTree, altBuilder: () => List[(Expr, Set[Identifier])]);
var generators = Map[TypeTree, Generator]()
def getGenerator(t: TypeTree): Generator = generators.get(t) match {
case Some(g) => g
case None =>
val alternatives: () => List[(Expr, Set[Identifier])] = t match {
case BooleanType =>
{ () => List((BooleanLiteral(true), Set()), (BooleanLiteral(false), Set())) }
case Int32Type =>
{ () => List((IntLiteral(0), Set()), (IntLiteral(1), Set())) }
case TupleType(tps) =>
{ () =>
val ids = tps.map(t => FreshIdentifier("t", true).setType(t))
List((Tuple(ids.map(Variable(_))), ids.toSet))
}
case CaseClassType(cd) =>
{ () =>
val ids = cd.fieldsIds.map(i => FreshIdentifier("c", true).setType(i.getType))
List((CaseClass(cd, ids.map(Variable(_))), ids.toSet))
}
case AbstractClassType(cd) =>
{ () =>
val alts: Seq[(Expr, Set[Identifier])] = cd.knownDescendents.flatMap(i => i match {
case acd: AbstractClassDef =>
sctx.reporter.error("Unnexpected abstract class in descendants!")
None
case cd: CaseClassDef =>
val ids = cd.fieldsIds.map(i => FreshIdentifier("c", true).setType(i.getType))
Some((CaseClass(cd, ids.map(Variable(_))), ids.toSet))
})
alts.toList
}
case _ =>
sctx.reporter.error("Can't construct generator. Unsupported type: "+t+"["+t.getClass+"]");
{ () => Nil }
}
val g = Generator(t, alternatives)
generators += t -> g
g
}
def inputAlternatives(t: TypeTree): List[(Expr, Set[Identifier])] = {
p.as.filter(a => isSubtypeOf(a.getType, t)).map(id => (Variable(id) : Expr, Set[Identifier]()))
}
def funcAlternatives(t: TypeTree): List[(Expr, Set[Identifier])] = {
if (useFunGenerators) {
def isCandidate(fd: FunDef): Boolean = {
// Prevents recursive calls
val isRecursiveCall = sctx.functionContext match {
case Some(cfd) =>
(sctx.program.transitiveCallers(cfd) + cfd) contains fd
case None =>
false
}
val isNotSynthesizable = fd.body match {
case Some(b) =>
collectChooses(b).isEmpty
case None =>
false
}
isSubtypeOf(fd.returnType, t) && !isRecursiveCall && isNotSynthesizable
}
sctx.program.definedFunctions.filter(isCandidate).map{ fd =>
val ids = fd.args.map(vd => FreshIdentifier("c", true).setType(vd.getType))
(FunctionInvocation(fd, ids.map(Variable(_))), ids.toSet)
}.toList
} else {
Nil
}
}
class NonDeterministicProgram(val p: Problem,
val initGuard: Identifier) {
//var program: Expr = BooleanLiteral(true)
// b -> (c, ex) means the clause b => c == ex
var mappings: Map[Identifier, (Identifier, Expr)] = Map()
// b -> Set(c1, c2) means c1 and c2 are uninterpreted behing b, requires b to be closed
private var guardedTerms: Map[Identifier, Set[Identifier]] = Map(initGuard -> p.xs.toSet)
def isBClosed(b: Identifier) = guardedTerms.contains(b)
// b -> Map(c1 -> Set(b2, b3), c2 -> Set(b4, b5)) means b protects c1 (with sub alternatives b2/b3), and c2 (with sub b4/b5)
private var bTree = Map[Identifier, Map[Identifier, Set[Identifier]]]( initGuard -> p.xs.map(_ -> Set[Identifier]()).toMap)
// Returns all possible assignments to Bs in order to enumerate all possible programs
def allPrograms(): Set[Set[Identifier]] = {
def allChildPaths(b: Identifier): Stream[Set[Identifier]] = {
if (isBClosed(b)) {
Stream.empty
} else {
bTree.get(b) match {
case Some(cToBs) =>
val streams = cToBs.values.map { children =>
children.toStream.flatMap(c => allChildPaths(c).map(l => l + b))
}
streams.reduceLeft{ (s1: Stream[Set[Identifier]], s2: Stream[Set[Identifier]]) => for(p1 <- s1; p2 <- s2) yield { p1 ++ p2 } }
case None =>
Stream.cons(Set(b), Stream.empty)
}
}
}
allChildPaths(initGuard).toSet
}
/*
* Compilation/Execution of programs
*/
// b1 => c == F(c2, c3) OR b2 => c == F(c4, c5) is represented here as c -> Set(c2, c3, c4, c5)
private var cChildren: Map[Identifier, Set[Identifier]] = Map().withDefaultValue(Set())
private var triedCompilation = false
private var progEvaluator: Option[(Seq[Expr], Seq[Expr]) => EvaluationResult] = None
def canTest() = {
if (!triedCompilation) {
progEvaluator = compile()
}
progEvaluator.isDefined
}
private var bssOrdered: Seq[Identifier] = Seq()
def testForProgram(bss: Set[Identifier])(ins: Seq[Expr]): Boolean = {
if (canTest()) {
val bssValues : Seq[Expr] = bssOrdered.map(i => BooleanLiteral(bss(i)))
val evalResult = progEvaluator.get.apply(bssValues, ins)
evalResult match {
case EvaluationSuccessful(res) =>
res == BooleanLiteral(true)
case EvaluationError(err) =>
sctx.reporter.error("Error testing CE: "+err)
true
case EvaluationFailure(err) =>
sctx.reporter.error("Error testing CE: "+err)
true
}
} else {
true
}
}
def compile(): Option[(Seq[Expr], Seq[Expr]) => EvaluationResult] = {
var unreachableCs: Set[Identifier] = guardedTerms.flatMap(_._2).toSet
val cToExprs = mappings.groupBy(_._2._1).map {
case (c, maps) =>
// We only keep cases within the current unrolling closedBs
val cases = maps.flatMap{ case (b, (_, ex)) => if (isBClosed(b)) None else Some(b -> ex) }
// We compute the IF expression corresponding to each c
val ifExpr = if (cases.isEmpty) {
// This can happen with ADTs with only cases with arguments
Error("No valid clause available").setType(c.getType)
} else {
cases.tail.foldLeft(cases.head._2) {
case (elze, (b, then)) => IfExpr(Variable(b), then, elze)
}
}
c -> ifExpr
} toMap
// Map each x generated by the program to fresh xs passed as argument
val newXs = p.xs.map(x => x -> FreshIdentifier(x.name, true).setType(x.getType))
val baseExpr = p.phi
bssOrdered = bss.toSeq.sortBy(_.id)
var res = baseExpr
def composeWith(c: Identifier) {
cToExprs.get(c) match {
case Some(value) =>
res = Let(c, cToExprs(c), res)
case None =>
res = Let(c, Error("No value available").setType(c.getType), res)
}
for (dep <- cChildren(c) if !unreachableCs(dep)) {
composeWith(dep)
}
}
for (c <- p.xs) {
composeWith(c)
}
val simplerRes = simplifyLets(res)
// println("COMPILATION RESULT: ")
// println(ScalaPrinter(simplerRes))
// println("BSS: "+bssOrdered)
// println("FREE: "+variablesOf(simplerRes))
def compileWithArray(): Option[(Seq[Expr], Seq[Expr]) => EvaluationResult] = {
val ba = FreshIdentifier("bssArray").setType(ArrayType(BooleanType))
val bav = Variable(ba)
val substMap : Map[Expr,Expr] = (bssOrdered.zipWithIndex.map {
case (b,i) => Variable(b) -> ArraySelect(bav, IntLiteral(i)).setType(BooleanType)
}).toMap
val forArray = replace(substMap, simplerRes)
// println("FORARRAY RESULT: ")
// println(ScalaPrinter(forArray))
// println("FREE: "+variablesOf(simplerRes))
// We trust arrays to be fast...
// val simple = evaluator.compile(simplerRes, bssOrdered ++ p.as)
val eval = evaluator.compile(forArray, ba +: p.as)
eval.map{e => { case (bss, ins) =>
e(FiniteArray(bss).setType(ArrayType(BooleanType)) +: ins)
}}
}
def compileWithArgs(): Option[(Seq[Expr], Seq[Expr]) => EvaluationResult] = {
val eval = evaluator.compile(simplerRes, bssOrdered ++ p.as)
eval.map{e => { case (bss, ins) =>
e(bss ++ ins)
}}
}
triedCompilation = true
val localVariables = bss.size + cToExprs.size + p.as.size
if (localVariables < 128) {
compileWithArgs().orElse(compileWithArray())
} else {
compileWithArray()
}
}
def determinize(bss: Set[Identifier]): Expr = {
val cClauses = mappings.filterKeys(bss).map(_._2).toMap
def getCValue(c: Identifier): Expr = {
val map = for (dep <- cChildren(c) if cClauses contains dep) yield {
dep -> getCValue(dep)
}
substAll(map.toMap, cClauses(c))
}
Tuple(p.xs.map(c => getCValue(c))).setType(TupleType(p.xs.map(_.getType)))
}
def unroll: (List[Expr], Set[Identifier]) = {
var newClauses = List[Expr]()
var newGuardedTerms = Map[Identifier, Set[Identifier]]()
var newMappings = Map[Identifier, (Identifier, Expr)]()
for ((parentGuard, recIds) <- guardedTerms; recId <- recIds) {
val gen = getGenerator(recId.getType)
val alts = gen.altBuilder() ::: inputAlternatives(recId.getType) ::: funcAlternatives(recId.getType)
val altsWithBranches = alts.map(alt => FreshIdentifier("B", true).setType(BooleanType) -> alt)
val bvs = altsWithBranches.map(alt => Variable(alt._1))
val failedPath = Not(Variable(parentGuard))
val distinct = bvs.combinations(2).collect {
case List(a, b) =>
Or(Not(a) :: Not(b) :: Nil)
}
val pre = And(Seq(Or(failedPath :: bvs), Implies(failedPath, And(bvs.map(Not(_))))) ++ distinct)
val cases = for((bid, (ex, rec)) <- altsWithBranches.toList) yield { // b1 => E(gen1, gen2) [b1 -> {gen1, gen2}]
if (!rec.isEmpty) {
newGuardedTerms += bid -> rec
cChildren += recId -> (cChildren(recId) ++ rec)
}
newMappings += bid -> (recId -> ex)
Implies(Variable(bid), Equals(Variable(recId), ex))
}
val newBIds = altsWithBranches.map(_._1).toSet
bTree += parentGuard -> (bTree.getOrElse(parentGuard, Map()) + (recId -> newBIds))
newClauses = newClauses ::: pre :: cases
}
//program = And(program :: newClauses)
mappings = mappings ++ newMappings
guardedTerms = newGuardedTerms
// Finally, we reset the state of the evalautor
triedCompilation = false
progEvaluator = None
(newClauses, newGuardedTerms.keySet)
}
def bss = mappings.keySet
def css : Set[Identifier] = mappings.values.map(_._1).toSet ++ guardedTerms.flatMap(_._2)
}
val TopLevelAnds(ands) = p.phi
val xsSet = p.xs.toSet
val (exprsA, others) = ands.partition(e => (variablesOf(e) & xsSet).isEmpty)
if (exprsA.isEmpty) {
val res = new RuleInstantiation(p, this, SolutionBuilder.none) {
def apply(sctx: SynthesisContext): RuleApplicationResult = {
var result: Option[RuleApplicationResult] = None
var ass = p.as.toSet
var xss = p.xs.toSet
val initGuard = FreshIdentifier("START", true).setType(BooleanType)
val ndProgram = new NonDeterministicProgram(p, initGuard)
var unrolings = 0
val maxUnrolings = 3
val mainSolver = new TimeoutSolver(sctx.solver, 2000L) // 2sec
var exampleInputs = Set[Seq[Expr]]()
// We populate the list of examples with a predefined one
if (p.pc == BooleanLiteral(true)) {
exampleInputs += p.as.map(a => simplestValue(a.getType))
} else {
val solver = mainSolver.getNewSolver
solver.assertCnstr(p.pc)
solver.check match {
case Some(true) =>
val model = solver.getModel
exampleInputs += p.as.map(a => model.getOrElse(a, simplestValue(a.getType)))
case Some(false) =>
return RuleApplicationImpossible
case None =>
sctx.reporter.warning("Solver could not solve path-condition")
return RuleApplicationImpossible // This is not necessary though, but probably wanted
}
}
// Keep track of collected cores to filter programs to test
var collectedCores = Set[Set[Identifier]]()
// solver1 is used for the initial SAT queries
var solver1 = mainSolver.getNewSolver
solver1.assertCnstr(And(p.pc :: p.phi :: Variable(initGuard) :: Nil))
// solver2 is used for validating a candidate program, or finding new inputs
val solver2 = mainSolver.getNewSolver
solver2.assertCnstr(And(p.pc :: Not(p.phi) :: Variable(initGuard) :: Nil))
var allClauses = List[Expr]()
try {
do {
var needMoreUnrolling = false
// Compute all programs that have not been excluded yet
var allPrograms: Set[Set[Identifier]] = if (useCEPruning) {
ndProgram.allPrograms.filterNot(p => collectedCores.exists(c => c.subsetOf(p)))
} else {
Set()
}
//println("Programs: "+allPrograms.size)
//println("CEs: "+exampleInputs.size)
// We further filter the set of working programs to remove those that fail on known examples
if (useCEPruning && !exampleInputs.isEmpty && ndProgram.canTest()) {
//for (ce <- exampleInputs) {
// println("CE: "+ce)
//}
for (p <- allPrograms) {
if (!exampleInputs.forall(ndProgram.testForProgram(p))) {
// This program failed on at least one example
solver1.assertCnstr(Not(And(p.map(Variable(_)).toSeq)))
allPrograms -= p
}
}
if (allPrograms.isEmpty) {
needMoreUnrolling = true
}
//println("Passing tests: "+allPrograms.size)
}
//allPrograms.foreach { p =>
// println("PATH: "+p)
// println("CLAUSES: "+p.flatMap( b => ndProgram.mappings.get(b).map{ case (c, ex) => c+" = "+ex}).mkString(" && "))
//}
val (clauses, closedBs) = ndProgram.unroll
//println("UNROLLING: ")
//for (c <- clauses) {
// println(" - " + c)
//}
//println("CLOSED Bs "+closedBs)
val clause = And(clauses)
allClauses = clause :: allClauses
solver1.assertCnstr(clause)
solver2.assertCnstr(clause)
val tpe = TupleType(p.xs.map(_.getType))
val bss = ndProgram.bss
if (clauses.isEmpty) {
needMoreUnrolling = true
}
while (result.isEmpty && !needMoreUnrolling && !sctx.shouldStop.get) {
solver1.checkAssumptions(closedBs.map(id => Not(Variable(id)))) match {
case Some(true) =>
val satModel = solver1.getModel
val bssAssumptions: Set[Expr] = bss.map(b => satModel(b) match {
case BooleanLiteral(true) => Variable(b)
case BooleanLiteral(false) => Not(Variable(b))
})
//println("CEGIS OUT!")
//println("Found solution: "+bssAssumptions)
//bssAssumptions.collect { case Variable(b) => ndProgram.mappings(b) }.foreach {
// case (c, ex) =>
// println(". "+c+" = "+ex)
//}
val validateWithZ3 = if (useCETests && !exampleInputs.isEmpty && ndProgram.canTest()) {
val p = bssAssumptions.collect { case Variable(b) => b }
if (exampleInputs.forall(ndProgram.testForProgram(p))) {
// All valid inputs also work with this, we need to
// make sure by validating this candidate with z3
true
} else {
// One valid input failed with this candidate, we can skip
solver1.assertCnstr(Not(And(p.map(Variable(_)).toSeq)))
false
}
} else {
// No inputs or capability to test, we need to ask Z3
true
}
if (validateWithZ3) {
solver2.checkAssumptions(bssAssumptions) match {
case Some(true) =>
//println("#"*80)
val invalidModel = solver2.getModel
val fixedAss = And(ass.collect {
case a if invalidModel contains a => Equals(Variable(a), invalidModel(a))
}.toSeq)
val newCE = p.as.map(valuateWithModel(invalidModel))
exampleInputs += newCE
//println("Found counter example: "+fixedAss)
// Retest whether the newly found C-E invalidates all programs
if (useCEPruning && ndProgram.canTest) {
if (allPrograms.forall(p => !ndProgram.testForProgram(p)(newCE))) {
// println("I found a killer example!")
needMoreUnrolling = true
}
}
val unsatCore = if (useUnsatCores) {
solver1.push()
solver1.assertCnstr(fixedAss)
val core = solver1.checkAssumptions(bssAssumptions) match {
case Some(false) =>
// Core might be empty if unrolling level is
// insufficient, it becomes unsat no matter what
// the assumptions are.
solver1.getUnsatCore
case Some(true) =>
// Can't be!
bssAssumptions
case None =>
return RuleApplicationImpossible
}
solver1.pop()
collectedCores += core.collect{ case Variable(id) => id }
core
} else {
bssAssumptions
}
if (unsatCore.isEmpty) {
needMoreUnrolling = true
} else {
//if (useCEAsserts) {
// val freshCss = ndProgram.css.map(c => c -> Variable(FreshIdentifier(c.name, true).setType(c.getType))).toMap
// val ceIn = ass.collect {
// case id if invalidModel contains id => id -> invalidModel(id)
// }
// val ceMap = (freshCss ++ ceIn)
// val counterexample = substAll(ceMap, And(Seq(ndProgram.program, p.phi)))
// //val And(ands) = counterexample
// //println("CE:")
// //for (a <- ands) {
// // println(" - "+a)
// //}
// solver1.assertCnstr(counterexample)
//}
solver1.assertCnstr(Not(And(unsatCore.toSeq)))
}
case Some(false) =>
val expr = ndProgram.determinize(satModel.filter(_._2 == BooleanLiteral(true)).keySet)
result = Some(RuleSuccess(Solution(BooleanLiteral(true), Set(), expr)))
case _ =>
return RuleApplicationImpossible
}
}
case Some(false) =>
//println("%%%% UNSAT")
if (useUninterpretedProbe) {
solver1.check match {
case Some(false) =>
// Unsat even without blockers (under which fcalls are then uninterpreted)
return RuleApplicationImpossible
case _ =>
}
}
needMoreUnrolling = true
case _ =>
//println("%%%% WOOPS")
return RuleApplicationImpossible
}
}
unrolings += 1
} while(unrolings < maxUnrolings && result.isEmpty && !sctx.shouldStop.get)
result.getOrElse(RuleApplicationImpossible)
} catch {
case e: Throwable =>
sctx.reporter.warning("CEGIS crashed: "+e.getMessage)
e.printStackTrace
RuleApplicationImpossible
}
}
}
List(res)
} else {
Nil
}
}
}