Skip to content
Snippets Groups Projects
Commit aab7b7f3 authored by Etienne Kneuss's avatar Etienne Kneuss
Browse files

EnumeratingSolver / PortfolioSolver

Use a datagen-based solver to find simple counter-examples. Note that
this solver returns Unknown most of the time, so it is best to combine
it with a full-fledged solver.

PortfolioSolver allows us to combine solvers and have them run in
parallel. The first result (!= Unknown) is used. Solvers can be selected
for verification using the --solvers option.
parent 89810c50
No related branches found
No related tags found
No related merge requests found
......@@ -5,7 +5,20 @@ package datagen
import purescala.Trees._
import purescala.Common._
import utils._
trait DataGenerator {
import java.util.concurrent.atomic.AtomicBoolean
trait DataGenerator extends Interruptible {
def generateFor(ins: Seq[Identifier], satisfying: Expr, maxValid: Int, maxEnumerated: Int): Iterator[Seq[Expr]];
protected val interrupted: AtomicBoolean = new AtomicBoolean(false)
def interrupt(): Unit = {
interrupted.set(true)
}
def recoverInterrupt(): Unit = {
interrupted.set(false)
}
}
......@@ -118,6 +118,7 @@ class NaiveDataGen(ctx: LeonContext, p: Program, evaluator: Evaluator, _bounds :
naryProduct(ins.map(id => generate(id.getType, bounds)))
.take(maxEnumerated)
.takeWhile(s => !interrupted.get)
.filter{s => evalFun(s) == sat }
.take(maxValid)
.iterator
......
......@@ -87,11 +87,12 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator {
unit.jvmClassToLeonClass(cc.getClass.getName) match {
case Some(ccd: CaseClassDef) =>
val cct = CaseClassType(ccd, ct.tps)
val c = ct match {
case act : AbstractClassType =>
getConstructorFor(CaseClassType(ccd, ct.tps), act)
getConstructorFor(cct, act)
case cct : CaseClassType =>
getConstructors(CaseClassType(ccd, ct.tps))(0)
getConstructors(cct)(0)
}
val fields = cc.productElements()
......@@ -99,7 +100,7 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator {
val elems = for (i <- 0 until fields.length) yield {
if (((r >> i) & 1) == 1) {
// has been read
valueToPattern(fields(i), ct.fieldsTypes(i))
valueToPattern(fields(i), cct.fieldsTypes(i))
} else {
(AnyPattern[Expr, TypeTree](), false)
}
......@@ -158,6 +159,9 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator {
(EvaluationResults.Successful(result), if (!pattern._2) Some(pattern._1) else None)
} catch {
case e : ClassCastException =>
(EvaluationResults.RuntimeError(e.getMessage), None)
case e : ArithmeticException =>
(EvaluationResults.RuntimeError(e.getMessage), None)
......@@ -228,7 +232,7 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator {
def computeNext(): Option[Seq[Expr]] = {
while(total < maxEnumerated && found < maxValid && it.hasNext) {
while(total < maxEnumerated && found < maxValid && it.hasNext && !interrupted.get) {
val model = it.next.asInstanceOf[Tuple]
if (model eq null) {
......@@ -250,10 +254,10 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator {
}
if (!failed) {
println("Got model:")
for ((i, v) <- (ins zip model.exprs)) {
println(" - "+i+" -> "+v)
}
//println("Got model:")
//for ((i, v) <- (ins zip model.exprs)) {
// println(" - "+i+" -> "+v)
//}
found += 1
......@@ -264,9 +268,9 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator {
return Some(model.exprs);
}
if (total % 1000 == 0) {
println("... "+total+" ...")
}
//if (total % 1000 == 0) {
// println("... "+total+" ...")
//}
}
}
None
......
/* Copyright 2009-2013 EPFL, Lausanne */
package leon
package solvers
import utils._
import purescala.Common._
import purescala.Definitions._
import purescala.Trees._
import purescala.Extractors._
import purescala.TreeOps._
import purescala.TypeTrees._
import datagen._
class EnumerationSolver(val context: LeonContext, val program: Program) extends Solver with Interruptible {
def name = "Enum"
val maxTried = 10000;
var datagen: DataGenerator = _
var freeVars = List[Identifier]()
var constraints = List[Expr]()
def assertCnstr(expression: Expr): Unit = {
constraints ::= expression
val newFreeVars = (variablesOf(expression) -- freeVars).toList
freeVars = freeVars ::: newFreeVars
}
private var modelMap = Map[Identifier, Expr]()
def check: Option[Boolean] = {
try {
val muteContext = context.copy(reporter = new DefaultReporter(context.settings))
datagen = new VanuatooDataGen(muteContext, program)
modelMap = Map()
val it = datagen.generateFor(freeVars, And(constraints.reverse), 1, maxTried)
if (it.hasNext) {
val model = it.next
modelMap = (freeVars zip model).toMap
Some(true)
} else {
None
}
} catch {
case e: codegen.CompilationException =>
None
}
}
def getModel: Map[Identifier, Expr] = {
modelMap
}
def free() = {
constraints = Nil
}
def interrupt(): Unit = {
Option(datagen).foreach(_.interrupt)
}
def recoverInterrupt(): Unit = {
Option(datagen).foreach(_.recoverInterrupt)
}
}
/* Copyright 2009-2013 EPFL, Lausanne */
package leon
package solvers
package combinators
import purescala.Common._
import purescala.Definitions._
import purescala.Trees._
import purescala.TreeOps._
import purescala.TypeTrees._
import utils.Interruptible
import scala.concurrent._
import scala.concurrent.duration._
import scala.collection.mutable.{Map=>MutableMap}
import ExecutionContext.Implicits.global
class PortfolioSolver(val context: LeonContext, solvers: Seq[SolverFactory[Solver with Interruptible]])
extends Solver with Interruptible {
val name = "Pfolio"
var constraints = List[Expr]()
def assertCnstr(expression: Expr): Unit = {
constraints ::= expression
}
private var modelMap = Map[Identifier, Expr]()
private var solversInsts = Seq[Solver with Interruptible]()
def check: Option[Boolean] = {
modelMap = Map()
// create fresh solvers
solversInsts = solvers.map(_.getNewSolver)
// assert
solversInsts.foreach { s =>
s.assertCnstr(And(constraints.reverse))
}
// solving
val fs = solversInsts.map { s =>
Future {
(s, s.check, s.getModel)
}
}
val result = Future.find(fs)(_._2.isDefined)
val res = Await.result(result, 10.days) match {
case Some((s, r, m)) =>
modelMap = m
solversInsts.foreach(_.interrupt)
r
case None =>
None
}
solversInsts.foreach(_.free)
res
}
def getModel: Map[Identifier, Expr] = {
modelMap
}
def free() = {
constraints = Nil
}
def interrupt(): Unit = {
solversInsts.foreach(_.interrupt())
}
def recoverInterrupt(): Unit = {
solversInsts.foreach(_.recoverInterrupt())
}
}
......@@ -83,7 +83,7 @@ class Synthesizer(val context : LeonContext,
val solverf = SolverFactory(() => (new FairZ3Solver(context, npr) with TimeoutSolver).setTimeout(timeoutMs))
val vctx = VerificationContext(context, npr, Seq(solverf), context.reporter)
val vctx = VerificationContext(context, npr, solverf, context.reporter)
val vcs = generateVerificationConditions(vctx, fds.map(_.id.name))
val vcreport = checkVerificationConditions(vctx, vcs)
......
......@@ -23,6 +23,7 @@ object AnalysisPhase extends LeonPhase[Program,VerificationReport] {
override val definedOptions : Set[LeonOptionDef] = Set(
LeonValueOptionDef("functions", "--functions=f1:f2", "Limit verification to f1,f2,..."),
LeonValueOptionDef("solvers", "--solvers=s1,s2", "Use solvers s1 and s2 (fairz3,enum)", default = Some("fairz3")),
LeonValueOptionDef("timeout", "--timeout=T", "Timeout after T seconds when trying to prove a verification condition.")
)
......@@ -83,7 +84,7 @@ object AnalysisPhase extends LeonPhase[Program,VerificationReport] {
def checkVerificationConditions(vctx: VerificationContext, vcs: Map[FunDef, List[VerificationCondition]]) : VerificationReport = {
import vctx.reporter
import vctx.solvers
import vctx.solverFactory
import vctx.program
val interruptManager = vctx.context.interruptManager
......@@ -100,60 +101,54 @@ object AnalysisPhase extends LeonPhase[Program,VerificationReport] {
reporter.debug("Verification condition (" + vcInfo.kind + ") for ==== " + funDef.id + " ====")
reporter.debug(simplifyLets(vc).asString(vctx.context))
// try all solvers until one returns a meaningful answer
solvers.find(sf => {
val s = sf.getNewSolver
try {
reporter.debug("Trying with solver: " + s.name)
val t1 = System.nanoTime
s.assertCnstr(Not(vc))
val satResult = s.check
val counterexample: Map[Identifier, Expr] = if (satResult == Some(true)) s.getModel else Map()
val solverResult = satResult.map(!_)
val t2 = System.nanoTime
val dt = ((t2 - t1) / 1000000) / 1000.0
solverResult match {
case _ if interruptManager.isInterrupted() =>
reporter.info("=== CANCELLED ===")
vcInfo.time = Some(dt)
false
case None =>
vcInfo.time = Some(dt)
false
case Some(true) =>
reporter.info("==== VALID ====")
vcInfo.hasValue = true
vcInfo.value = Some(true)
vcInfo.solvedWith = Some(s)
vcInfo.time = Some(dt)
true
case Some(false) =>
reporter.error("Found counter-example : ")
reporter.error(counterexample.toSeq.sortBy(_._1.name).map(p => p._1 + " -> " + p._2).mkString("\n"))
reporter.error("==== INVALID ====")
vcInfo.hasValue = true
vcInfo.value = Some(false)
vcInfo.solvedWith = Some(s)
vcInfo.counterExample = Some(counterexample)
vcInfo.time = Some(dt)
true
}
} finally {
s.free()
}}) match {
case None => {
val s = solverFactory.getNewSolver
try {
reporter.debug("Trying with solver: " + s.name)
val t1 = System.nanoTime
s.assertCnstr(Not(vc))
val satResult = s.check
val counterexample: Map[Identifier, Expr] = if (satResult == Some(true)) s.getModel else Map()
val solverResult = satResult.map(!_)
val t2 = System.nanoTime
val dt = ((t2 - t1) / 1000000) / 1000.0
solverResult match {
case _ if interruptManager.isInterrupted() =>
reporter.info("=== CANCELLED ===")
vcInfo.time = Some(dt)
false
case None =>
vcInfo.hasValue = true
reporter.warning("==== UNKNOWN ====")
}
case _ =>
vcInfo.time = Some(dt)
false
case Some(true) =>
reporter.info("==== VALID ====")
vcInfo.hasValue = true
vcInfo.value = Some(true)
vcInfo.solvedWith = Some(s)
vcInfo.time = Some(dt)
true
case Some(false) =>
reporter.error("Found counter-example : ")
reporter.error(counterexample.toSeq.sortBy(_._1.name).map(p => p._1 + " -> " + p._2).mkString("\n"))
reporter.error("==== INVALID ====")
vcInfo.hasValue = true
vcInfo.value = Some(false)
vcInfo.solvedWith = Some(s)
vcInfo.counterExample = Some(counterexample)
vcInfo.time = Some(dt)
true
}
} finally {
s.free()
}
}
val report = new VerificationReport(vcs)
......@@ -163,33 +158,52 @@ object AnalysisPhase extends LeonPhase[Program,VerificationReport] {
def run(ctx: LeonContext)(program: Program) : VerificationReport = {
var functionsToAnalyse = Set[String]()
var timeout: Option[Int] = None
var selectedSolvers = Set[String]("fairz3")
val allSolvers = Map(
"fairz3" -> SolverFactory(() => new FairZ3Solver(ctx, program) with TimeoutSolver),
"enum" -> SolverFactory(() => new EnumerationSolver(ctx, program) with TimeoutSolver)
)
val reporter = ctx.reporter
for(opt <- ctx.options) opt match {
case LeonValueOption("functions", ListValue(fs)) =>
functionsToAnalyse = Set() ++ fs
case LeonValueOption("solvers", ListValue(ss)) =>
val unknownSolvers = ss.toSet -- allSolvers.keySet
if (unknownSolvers.nonEmpty) {
reporter.error("Unknown solver(s): "+unknownSolvers.mkString(", ")+" (Available: "+allSolvers.keys.mkString(", ")+")")
}
selectedSolvers = Set() ++ ss
case v @ LeonValueOption("timeout", _) =>
timeout = v.asInt(ctx)
case _ =>
}
val reporter = ctx.reporter
// Solvers selection and validation
val solversToUse = allSolvers.filterKeys(selectedSolvers)
val baseFactories = Seq(
SolverFactory(() => new FairZ3Solver(ctx, program) with TimeoutSolver)
)
val entrySolver = if (solversToUse.isEmpty) {
reporter.fatalError("No solver selected. Aborting")
} else if (solversToUse.size == 1) {
solversToUse.values.head
} else {
SolverFactory( () => new PortfolioSolver(ctx, solversToUse.values.toSeq) with TimeoutSolver)
}
val solverFactories = timeout match {
val mainSolver = timeout match {
case Some(sec) =>
baseFactories.map { sf =>
new TimeoutSolverFactory(sf, sec*1000L)
}
new TimeoutSolverFactory(entrySolver, sec*1000L)
case None =>
baseFactories
entrySolver
}
val vctx = VerificationContext(ctx, program, solverFactories, reporter)
val vctx = VerificationContext(ctx, program, mainSolver, reporter)
reporter.debug("Running verification condition generation...")
val vcs = generateVerificationConditions(vctx, functionsToAnalyse)
......
......@@ -11,6 +11,6 @@ import java.util.concurrent.atomic.AtomicBoolean
case class VerificationContext (
context: LeonContext,
program: Program,
solvers: Seq[SolverFactory[Solver]],
solverFactory: SolverFactory[Solver],
reporter: Reporter
)
/* Copyright 2009-2013 EPFL, Lausanne */
package leon.test
package solvers
import leon._
import leon.utils.Interruptible
import leon.solvers._
import leon.solvers.combinators._
import leon.purescala.Common._
import leon.purescala.Definitions._
import leon.purescala.Trees._
import leon.purescala.TypeTrees._
class EnumerationSolverTests extends LeonTestSuite {
private def check(sf: SolverFactory[Solver], e: Expr): Option[Boolean] = {
val s = sf.getNewSolver
s.assertCnstr(e)
s.check
}
private def getSolver = {
SolverFactory(() => new EnumerationSolver(testContext, Program.empty))
}
test("EnumerationSolver 1 (true)") {
val sf = getSolver
assert(check(sf, BooleanLiteral(true)) === Some(true))
}
test("EnumerationSolver 2 (x == 1)") {
val sf = getSolver
val x = Variable(FreshIdentifier("x").setType(Int32Type))
val o = IntLiteral(1)
assert(check(sf, Equals(x, o)) === Some(true))
}
test("EnumerationSolver 3 (Limited range for ints)") {
val sf = getSolver
val x = Variable(FreshIdentifier("x").setType(Int32Type))
val o = IntLiteral(42)
assert(check(sf, Equals(x, o)) === None)
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment