From aab7b7f36f61f37f8cd699cea4f8b0311f276b13 Mon Sep 17 00:00:00 2001 From: Etienne Kneuss <colder@php.net> Date: Tue, 25 Feb 2014 13:29:36 +0100 Subject: [PATCH] EnumeratingSolver / PortfolioSolver Use a datagen-based solver to find simple counter-examples. Note that this solver returns Unknown most of the time, so it is best to combine it with a full-fledged solver. PortfolioSolver allows us to combine solvers and have them run in parallel. The first result (!= Unknown) is used. Solvers can be selected for verification using the --solvers option. --- .../scala/leon/datagen/DataGenerator.scala | 15 +- .../scala/leon/datagen/NaiveDataGen.scala | 1 + .../scala/leon/datagen/VanuatooDataGen.scala | 26 ++-- .../leon/solvers/EnumerationSolver.scala | 73 +++++++++ .../solvers/combinators/PortfolioSolver.scala | 84 +++++++++++ .../scala/leon/synthesis/Synthesizer.scala | 2 +- .../leon/verification/AnalysisPhase.scala | 138 ++++++++++-------- .../verification/VerificationContext.scala | 2 +- .../test/solvers/EnumerationSolverTests.scala | 44 ++++++ 9 files changed, 309 insertions(+), 76 deletions(-) create mode 100644 src/main/scala/leon/solvers/EnumerationSolver.scala create mode 100644 src/main/scala/leon/solvers/combinators/PortfolioSolver.scala create mode 100644 src/test/scala/leon/test/solvers/EnumerationSolverTests.scala diff --git a/src/main/scala/leon/datagen/DataGenerator.scala b/src/main/scala/leon/datagen/DataGenerator.scala index c15f07a2b..810083dd1 100644 --- a/src/main/scala/leon/datagen/DataGenerator.scala +++ b/src/main/scala/leon/datagen/DataGenerator.scala @@ -5,7 +5,20 @@ package datagen import purescala.Trees._ import purescala.Common._ +import utils._ -trait DataGenerator { +import java.util.concurrent.atomic.AtomicBoolean + +trait DataGenerator extends Interruptible { def generateFor(ins: Seq[Identifier], satisfying: Expr, maxValid: Int, maxEnumerated: Int): Iterator[Seq[Expr]]; + + protected val interrupted: AtomicBoolean = new AtomicBoolean(false) + + def interrupt(): Unit = { + interrupted.set(true) + } + + def recoverInterrupt(): Unit = { + interrupted.set(false) + } } diff --git a/src/main/scala/leon/datagen/NaiveDataGen.scala b/src/main/scala/leon/datagen/NaiveDataGen.scala index 5ec8577bd..5bd33cb85 100644 --- a/src/main/scala/leon/datagen/NaiveDataGen.scala +++ b/src/main/scala/leon/datagen/NaiveDataGen.scala @@ -118,6 +118,7 @@ class NaiveDataGen(ctx: LeonContext, p: Program, evaluator: Evaluator, _bounds : naryProduct(ins.map(id => generate(id.getType, bounds))) .take(maxEnumerated) + .takeWhile(s => !interrupted.get) .filter{s => evalFun(s) == sat } .take(maxValid) .iterator diff --git a/src/main/scala/leon/datagen/VanuatooDataGen.scala b/src/main/scala/leon/datagen/VanuatooDataGen.scala index f934ac4bd..e94415131 100644 --- a/src/main/scala/leon/datagen/VanuatooDataGen.scala +++ b/src/main/scala/leon/datagen/VanuatooDataGen.scala @@ -87,11 +87,12 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { unit.jvmClassToLeonClass(cc.getClass.getName) match { case Some(ccd: CaseClassDef) => + val cct = CaseClassType(ccd, ct.tps) val c = ct match { case act : AbstractClassType => - getConstructorFor(CaseClassType(ccd, ct.tps), act) + getConstructorFor(cct, act) case cct : CaseClassType => - getConstructors(CaseClassType(ccd, ct.tps))(0) + getConstructors(cct)(0) } val fields = cc.productElements() @@ -99,7 +100,7 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { val elems = for (i <- 0 until fields.length) yield { if (((r >> i) & 1) == 1) { // has been read - valueToPattern(fields(i), ct.fieldsTypes(i)) + valueToPattern(fields(i), cct.fieldsTypes(i)) } else { (AnyPattern[Expr, TypeTree](), false) } @@ -158,6 +159,9 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { (EvaluationResults.Successful(result), if (!pattern._2) Some(pattern._1) else None) } catch { + case e : ClassCastException => + (EvaluationResults.RuntimeError(e.getMessage), None) + case e : ArithmeticException => (EvaluationResults.RuntimeError(e.getMessage), None) @@ -228,7 +232,7 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { def computeNext(): Option[Seq[Expr]] = { - while(total < maxEnumerated && found < maxValid && it.hasNext) { + while(total < maxEnumerated && found < maxValid && it.hasNext && !interrupted.get) { val model = it.next.asInstanceOf[Tuple] if (model eq null) { @@ -250,10 +254,10 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { } if (!failed) { - println("Got model:") - for ((i, v) <- (ins zip model.exprs)) { - println(" - "+i+" -> "+v) - } + //println("Got model:") + //for ((i, v) <- (ins zip model.exprs)) { + // println(" - "+i+" -> "+v) + //} found += 1 @@ -264,9 +268,9 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { return Some(model.exprs); } - if (total % 1000 == 0) { - println("... "+total+" ...") - } + //if (total % 1000 == 0) { + // println("... "+total+" ...") + //} } } None diff --git a/src/main/scala/leon/solvers/EnumerationSolver.scala b/src/main/scala/leon/solvers/EnumerationSolver.scala new file mode 100644 index 000000000..4f5d71659 --- /dev/null +++ b/src/main/scala/leon/solvers/EnumerationSolver.scala @@ -0,0 +1,73 @@ +/* Copyright 2009-2013 EPFL, Lausanne */ + +package leon +package solvers + +import utils._ +import purescala.Common._ +import purescala.Definitions._ +import purescala.Trees._ +import purescala.Extractors._ +import purescala.TreeOps._ +import purescala.TypeTrees._ + +import datagen._ + + +class EnumerationSolver(val context: LeonContext, val program: Program) extends Solver with Interruptible { + def name = "Enum" + + val maxTried = 10000; + + var datagen: DataGenerator = _ + + var freeVars = List[Identifier]() + var constraints = List[Expr]() + + def assertCnstr(expression: Expr): Unit = { + constraints ::= expression + + val newFreeVars = (variablesOf(expression) -- freeVars).toList + freeVars = freeVars ::: newFreeVars + } + + private var modelMap = Map[Identifier, Expr]() + + def check: Option[Boolean] = { + try { + val muteContext = context.copy(reporter = new DefaultReporter(context.settings)) + datagen = new VanuatooDataGen(muteContext, program) + + modelMap = Map() + + val it = datagen.generateFor(freeVars, And(constraints.reverse), 1, maxTried) + + if (it.hasNext) { + val model = it.next + modelMap = (freeVars zip model).toMap + Some(true) + } else { + None + } + } catch { + case e: codegen.CompilationException => + None + } + } + + def getModel: Map[Identifier, Expr] = { + modelMap + } + + def free() = { + constraints = Nil + } + + def interrupt(): Unit = { + Option(datagen).foreach(_.interrupt) + } + + def recoverInterrupt(): Unit = { + Option(datagen).foreach(_.recoverInterrupt) + } +} diff --git a/src/main/scala/leon/solvers/combinators/PortfolioSolver.scala b/src/main/scala/leon/solvers/combinators/PortfolioSolver.scala new file mode 100644 index 000000000..1c75bf3b8 --- /dev/null +++ b/src/main/scala/leon/solvers/combinators/PortfolioSolver.scala @@ -0,0 +1,84 @@ +/* Copyright 2009-2013 EPFL, Lausanne */ + +package leon +package solvers +package combinators + +import purescala.Common._ +import purescala.Definitions._ +import purescala.Trees._ +import purescala.TreeOps._ +import purescala.TypeTrees._ + +import utils.Interruptible +import scala.concurrent._ +import scala.concurrent.duration._ + +import scala.collection.mutable.{Map=>MutableMap} + +import ExecutionContext.Implicits.global + +class PortfolioSolver(val context: LeonContext, solvers: Seq[SolverFactory[Solver with Interruptible]]) + extends Solver with Interruptible { + + val name = "Pfolio" + + var constraints = List[Expr]() + + def assertCnstr(expression: Expr): Unit = { + constraints ::= expression + } + + private var modelMap = Map[Identifier, Expr]() + private var solversInsts = Seq[Solver with Interruptible]() + + def check: Option[Boolean] = { + modelMap = Map() + + // create fresh solvers + solversInsts = solvers.map(_.getNewSolver) + + // assert + solversInsts.foreach { s => + s.assertCnstr(And(constraints.reverse)) + } + + // solving + val fs = solversInsts.map { s => + Future { + (s, s.check, s.getModel) + } + } + + val result = Future.find(fs)(_._2.isDefined) + + val res = Await.result(result, 10.days) match { + case Some((s, r, m)) => + modelMap = m + solversInsts.foreach(_.interrupt) + r + case None => + None + } + + solversInsts.foreach(_.free) + + res + } + + def getModel: Map[Identifier, Expr] = { + modelMap + } + + def free() = { + constraints = Nil + } + + def interrupt(): Unit = { + solversInsts.foreach(_.interrupt()) + } + + def recoverInterrupt(): Unit = { + solversInsts.foreach(_.recoverInterrupt()) + } +} diff --git a/src/main/scala/leon/synthesis/Synthesizer.scala b/src/main/scala/leon/synthesis/Synthesizer.scala index 766b4a42e..f33071f35 100644 --- a/src/main/scala/leon/synthesis/Synthesizer.scala +++ b/src/main/scala/leon/synthesis/Synthesizer.scala @@ -83,7 +83,7 @@ class Synthesizer(val context : LeonContext, val solverf = SolverFactory(() => (new FairZ3Solver(context, npr) with TimeoutSolver).setTimeout(timeoutMs)) - val vctx = VerificationContext(context, npr, Seq(solverf), context.reporter) + val vctx = VerificationContext(context, npr, solverf, context.reporter) val vcs = generateVerificationConditions(vctx, fds.map(_.id.name)) val vcreport = checkVerificationConditions(vctx, vcs) diff --git a/src/main/scala/leon/verification/AnalysisPhase.scala b/src/main/scala/leon/verification/AnalysisPhase.scala index a0bf5fe60..25ec22efc 100644 --- a/src/main/scala/leon/verification/AnalysisPhase.scala +++ b/src/main/scala/leon/verification/AnalysisPhase.scala @@ -23,6 +23,7 @@ object AnalysisPhase extends LeonPhase[Program,VerificationReport] { override val definedOptions : Set[LeonOptionDef] = Set( LeonValueOptionDef("functions", "--functions=f1:f2", "Limit verification to f1,f2,..."), + LeonValueOptionDef("solvers", "--solvers=s1,s2", "Use solvers s1 and s2 (fairz3,enum)", default = Some("fairz3")), LeonValueOptionDef("timeout", "--timeout=T", "Timeout after T seconds when trying to prove a verification condition.") ) @@ -83,7 +84,7 @@ object AnalysisPhase extends LeonPhase[Program,VerificationReport] { def checkVerificationConditions(vctx: VerificationContext, vcs: Map[FunDef, List[VerificationCondition]]) : VerificationReport = { import vctx.reporter - import vctx.solvers + import vctx.solverFactory import vctx.program val interruptManager = vctx.context.interruptManager @@ -100,60 +101,54 @@ object AnalysisPhase extends LeonPhase[Program,VerificationReport] { reporter.debug("Verification condition (" + vcInfo.kind + ") for ==== " + funDef.id + " ====") reporter.debug(simplifyLets(vc).asString(vctx.context)) - // try all solvers until one returns a meaningful answer - solvers.find(sf => { - val s = sf.getNewSolver - try { - reporter.debug("Trying with solver: " + s.name) - val t1 = System.nanoTime - s.assertCnstr(Not(vc)) - - val satResult = s.check - val counterexample: Map[Identifier, Expr] = if (satResult == Some(true)) s.getModel else Map() - val solverResult = satResult.map(!_) - - val t2 = System.nanoTime - val dt = ((t2 - t1) / 1000000) / 1000.0 - - solverResult match { - case _ if interruptManager.isInterrupted() => - reporter.info("=== CANCELLED ===") - vcInfo.time = Some(dt) - false - - case None => - vcInfo.time = Some(dt) - false - - case Some(true) => - reporter.info("==== VALID ====") - - vcInfo.hasValue = true - vcInfo.value = Some(true) - vcInfo.solvedWith = Some(s) - vcInfo.time = Some(dt) - true - - case Some(false) => - reporter.error("Found counter-example : ") - reporter.error(counterexample.toSeq.sortBy(_._1.name).map(p => p._1 + " -> " + p._2).mkString("\n")) - reporter.error("==== INVALID ====") - vcInfo.hasValue = true - vcInfo.value = Some(false) - vcInfo.solvedWith = Some(s) - vcInfo.counterExample = Some(counterexample) - vcInfo.time = Some(dt) - true - } - } finally { - s.free() - }}) match { - case None => { + val s = solverFactory.getNewSolver + try { + reporter.debug("Trying with solver: " + s.name) + val t1 = System.nanoTime + s.assertCnstr(Not(vc)) + + val satResult = s.check + val counterexample: Map[Identifier, Expr] = if (satResult == Some(true)) s.getModel else Map() + val solverResult = satResult.map(!_) + + val t2 = System.nanoTime + val dt = ((t2 - t1) / 1000000) / 1000.0 + + solverResult match { + case _ if interruptManager.isInterrupted() => + reporter.info("=== CANCELLED ===") + vcInfo.time = Some(dt) + false + + case None => vcInfo.hasValue = true reporter.warning("==== UNKNOWN ====") - } - case _ => + vcInfo.time = Some(dt) + false + + case Some(true) => + reporter.info("==== VALID ====") + + vcInfo.hasValue = true + vcInfo.value = Some(true) + vcInfo.solvedWith = Some(s) + vcInfo.time = Some(dt) + true + + case Some(false) => + reporter.error("Found counter-example : ") + reporter.error(counterexample.toSeq.sortBy(_._1.name).map(p => p._1 + " -> " + p._2).mkString("\n")) + reporter.error("==== INVALID ====") + vcInfo.hasValue = true + vcInfo.value = Some(false) + vcInfo.solvedWith = Some(s) + vcInfo.counterExample = Some(counterexample) + vcInfo.time = Some(dt) + true } + } finally { + s.free() + } } val report = new VerificationReport(vcs) @@ -163,33 +158,52 @@ object AnalysisPhase extends LeonPhase[Program,VerificationReport] { def run(ctx: LeonContext)(program: Program) : VerificationReport = { var functionsToAnalyse = Set[String]() var timeout: Option[Int] = None + var selectedSolvers = Set[String]("fairz3") + + val allSolvers = Map( + "fairz3" -> SolverFactory(() => new FairZ3Solver(ctx, program) with TimeoutSolver), + "enum" -> SolverFactory(() => new EnumerationSolver(ctx, program) with TimeoutSolver) + ) + + val reporter = ctx.reporter for(opt <- ctx.options) opt match { case LeonValueOption("functions", ListValue(fs)) => functionsToAnalyse = Set() ++ fs + case LeonValueOption("solvers", ListValue(ss)) => + val unknownSolvers = ss.toSet -- allSolvers.keySet + if (unknownSolvers.nonEmpty) { + reporter.error("Unknown solver(s): "+unknownSolvers.mkString(", ")+" (Available: "+allSolvers.keys.mkString(", ")+")") + } + selectedSolvers = Set() ++ ss + case v @ LeonValueOption("timeout", _) => timeout = v.asInt(ctx) case _ => } - val reporter = ctx.reporter + // Solvers selection and validation + val solversToUse = allSolvers.filterKeys(selectedSolvers) - val baseFactories = Seq( - SolverFactory(() => new FairZ3Solver(ctx, program) with TimeoutSolver) - ) + val entrySolver = if (solversToUse.isEmpty) { + reporter.fatalError("No solver selected. Aborting") + } else if (solversToUse.size == 1) { + solversToUse.values.head + } else { + SolverFactory( () => new PortfolioSolver(ctx, solversToUse.values.toSeq) with TimeoutSolver) + } - val solverFactories = timeout match { + + val mainSolver = timeout match { case Some(sec) => - baseFactories.map { sf => - new TimeoutSolverFactory(sf, sec*1000L) - } + new TimeoutSolverFactory(entrySolver, sec*1000L) case None => - baseFactories + entrySolver } - val vctx = VerificationContext(ctx, program, solverFactories, reporter) + val vctx = VerificationContext(ctx, program, mainSolver, reporter) reporter.debug("Running verification condition generation...") val vcs = generateVerificationConditions(vctx, functionsToAnalyse) diff --git a/src/main/scala/leon/verification/VerificationContext.scala b/src/main/scala/leon/verification/VerificationContext.scala index 84692c3a7..50d3012ea 100644 --- a/src/main/scala/leon/verification/VerificationContext.scala +++ b/src/main/scala/leon/verification/VerificationContext.scala @@ -11,6 +11,6 @@ import java.util.concurrent.atomic.AtomicBoolean case class VerificationContext ( context: LeonContext, program: Program, - solvers: Seq[SolverFactory[Solver]], + solverFactory: SolverFactory[Solver], reporter: Reporter ) diff --git a/src/test/scala/leon/test/solvers/EnumerationSolverTests.scala b/src/test/scala/leon/test/solvers/EnumerationSolverTests.scala new file mode 100644 index 000000000..44ceb91e1 --- /dev/null +++ b/src/test/scala/leon/test/solvers/EnumerationSolverTests.scala @@ -0,0 +1,44 @@ +/* Copyright 2009-2013 EPFL, Lausanne */ + +package leon.test +package solvers + +import leon._ +import leon.utils.Interruptible +import leon.solvers._ +import leon.solvers.combinators._ +import leon.purescala.Common._ +import leon.purescala.Definitions._ +import leon.purescala.Trees._ +import leon.purescala.TypeTrees._ + +class EnumerationSolverTests extends LeonTestSuite { + private def check(sf: SolverFactory[Solver], e: Expr): Option[Boolean] = { + val s = sf.getNewSolver + s.assertCnstr(e) + s.check + } + + private def getSolver = { + SolverFactory(() => new EnumerationSolver(testContext, Program.empty)) + } + + test("EnumerationSolver 1 (true)") { + val sf = getSolver + assert(check(sf, BooleanLiteral(true)) === Some(true)) + } + + test("EnumerationSolver 2 (x == 1)") { + val sf = getSolver + val x = Variable(FreshIdentifier("x").setType(Int32Type)) + val o = IntLiteral(1) + assert(check(sf, Equals(x, o)) === Some(true)) + } + + test("EnumerationSolver 3 (Limited range for ints)") { + val sf = getSolver + val x = Variable(FreshIdentifier("x").setType(Int32Type)) + val o = IntLiteral(42) + assert(check(sf, Equals(x, o)) === None) + } +} -- GitLab