From a6724892e9962250f04e982f6db958c8bc3a9852 Mon Sep 17 00:00:00 2001 From: Etienne Kneuss <ekneuss@gmail.com> Date: Wed, 19 Aug 2015 01:13:34 +0200 Subject: [PATCH] Flatten & Simplify Solver API Solvers are no longer distinguished in 20 traits depending on what they implement. It turns out that most leon solvers already implemented everything: 1) Being interrupted 2) Push / Pop 3) checkAssertions/getUnsatCore (a naive implementation of these can be added by mixing NaiveAssumptionSolver in) --- .../solvers/ModelEnumerationSuite.scala | 3 +- .../solvers/TimeoutSolverSuite.scala | 5 +- .../scala/leon/solvers/AssumptionSolver.scala | 11 -- .../leon/solvers/EnumerationSolver.scala | 2 +- .../scala/leon/solvers/GroundSolver.scala | 2 +- .../leon/solvers/IncrementalSolver.scala | 10 -- .../leon/solvers/NaiveAssumptionSolver.scala | 4 +- .../solvers/SimpleAssumptionSolverAPI.scala | 27 ---- .../scala/leon/solvers/SimpleSolverAPI.scala | 23 ++- src/main/scala/leon/solvers/Solver.scala | 10 +- .../scala/leon/solvers/SolverFactory.scala | 4 +- .../solvers/TimeoutAssumptionSolver.scala | 21 --- .../scala/leon/solvers/TimeoutSolver.scala | 14 +- .../leon/solvers/combinators/DNFSolver.scala | 138 ------------------ .../solvers/combinators/PortfolioSolver.scala | 14 +- .../combinators/PortfolioSolverFactory.scala | 2 +- .../solvers/combinators/UnrollingSolver.scala | 2 +- .../leon/solvers/smtlib/SMTLIBSolver.scala | 10 +- .../leon/solvers/z3/AbstractZ3Solver.scala | 6 +- .../{solvers => utils}/ModelEnumerator.scala | 53 +++---- .../leon/test/solvers/SolverPoolSuite.scala | 6 +- 21 files changed, 99 insertions(+), 268 deletions(-) delete mode 100644 src/main/scala/leon/solvers/AssumptionSolver.scala delete mode 100644 src/main/scala/leon/solvers/IncrementalSolver.scala delete mode 100644 src/main/scala/leon/solvers/SimpleAssumptionSolverAPI.scala delete mode 100644 src/main/scala/leon/solvers/TimeoutAssumptionSolver.scala delete mode 100644 src/main/scala/leon/solvers/combinators/DNFSolver.scala rename src/main/scala/leon/{solvers => utils}/ModelEnumerator.scala (74%) diff --git a/src/integration/scala/leon/integration/solvers/ModelEnumerationSuite.scala b/src/integration/scala/leon/integration/solvers/ModelEnumerationSuite.scala index 41a236fa9..e17ce6d71 100644 --- a/src/integration/scala/leon/integration/solvers/ModelEnumerationSuite.scala +++ b/src/integration/scala/leon/integration/solvers/ModelEnumerationSuite.scala @@ -7,6 +7,7 @@ import leon.integration.helpers.ExpressionsDSL import leon.test._ import leon._ import leon.solvers._ +import leon.utils._ import leon.purescala.Definitions._ import leon.purescala.Common._ import leon.evaluators._ @@ -40,7 +41,7 @@ class ModelEnumeratorSuite extends LeonTestSuiteWithProgram with ExpressionsDSL ) def getModelEnum(implicit ctx: LeonContext, pgm: Program) = { - val sf = SolverFactory.default.asInstanceOf[SolverFactory[IncrementalSolver]] + val sf = SolverFactory.default new ModelEnumerator(ctx, pgm, sf) } diff --git a/src/integration/scala/leon/integration/solvers/TimeoutSolverSuite.scala b/src/integration/scala/leon/integration/solvers/TimeoutSolverSuite.scala index 6a44bc17c..cce647cc1 100644 --- a/src/integration/scala/leon/integration/solvers/TimeoutSolverSuite.scala +++ b/src/integration/scala/leon/integration/solvers/TimeoutSolverSuite.scala @@ -12,7 +12,7 @@ import leon.purescala.Expressions._ import leon.purescala.Types._ class TimeoutSolverSuite extends LeonTestSuite { - private class IdioticSolver(val context : LeonContext, val program: Program) extends Solver with Interruptible{ + private class IdioticSolver(val context : LeonContext, val program: Program) extends Solver with NaiveAssumptionSolver { val name = "Idiotic" val description = "Loops" @@ -35,6 +35,9 @@ class TimeoutSolverSuite extends LeonTestSuite { def assertCnstr(e: Expr) = {} + def push() {} + def pop() {} + def free() {} def reset() {} diff --git a/src/main/scala/leon/solvers/AssumptionSolver.scala b/src/main/scala/leon/solvers/AssumptionSolver.scala deleted file mode 100644 index e9f8a9313..000000000 --- a/src/main/scala/leon/solvers/AssumptionSolver.scala +++ /dev/null @@ -1,11 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -package leon -package solvers - -import purescala.Expressions.Expr - -trait AssumptionSolver extends Solver { - def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] - def getUnsatCore: Set[Expr] -} diff --git a/src/main/scala/leon/solvers/EnumerationSolver.scala b/src/main/scala/leon/solvers/EnumerationSolver.scala index e01d8f7ed..3a2db100e 100644 --- a/src/main/scala/leon/solvers/EnumerationSolver.scala +++ b/src/main/scala/leon/solvers/EnumerationSolver.scala @@ -12,7 +12,7 @@ import purescala.ExprOps._ import datagen._ -class EnumerationSolver(val context: LeonContext, val program: Program) extends Solver with Interruptible with IncrementalSolver with NaiveAssumptionSolver { +class EnumerationSolver(val context: LeonContext, val program: Program) extends Solver with NaiveAssumptionSolver { def name = "Enum" val maxTried = 10000 diff --git a/src/main/scala/leon/solvers/GroundSolver.scala b/src/main/scala/leon/solvers/GroundSolver.scala index ad4ab6dc7..29ee75238 100644 --- a/src/main/scala/leon/solvers/GroundSolver.scala +++ b/src/main/scala/leon/solvers/GroundSolver.scala @@ -14,7 +14,7 @@ import utils.Interruptible import utils.IncrementalSeq // This solver only "solves" ground terms by evaluating them -class GroundSolver(val context: LeonContext, val program: Program) extends IncrementalSolver with Interruptible { +class GroundSolver(val context: LeonContext, val program: Program) extends Solver with NaiveAssumptionSolver { context.interruptManager.registerForInterrupts(this) diff --git a/src/main/scala/leon/solvers/IncrementalSolver.scala b/src/main/scala/leon/solvers/IncrementalSolver.scala deleted file mode 100644 index c935558c0..000000000 --- a/src/main/scala/leon/solvers/IncrementalSolver.scala +++ /dev/null @@ -1,10 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -package leon -package solvers - -trait IncrementalSolver extends Solver { - def push(): Unit - def pop(): Unit -} - diff --git a/src/main/scala/leon/solvers/NaiveAssumptionSolver.scala b/src/main/scala/leon/solvers/NaiveAssumptionSolver.scala index 7c1130b3d..da1dfa766 100644 --- a/src/main/scala/leon/solvers/NaiveAssumptionSolver.scala +++ b/src/main/scala/leon/solvers/NaiveAssumptionSolver.scala @@ -6,8 +6,8 @@ package solvers import purescala.Expressions._ import purescala.Constructors._ -trait NaiveAssumptionSolver extends AssumptionSolver { - self: IncrementalSolver => +trait NaiveAssumptionSolver { + self: Solver => var lastBs = Set[Expr]() def checkAssumptions(bs: Set[Expr]): Option[Boolean] = { diff --git a/src/main/scala/leon/solvers/SimpleAssumptionSolverAPI.scala b/src/main/scala/leon/solvers/SimpleAssumptionSolverAPI.scala deleted file mode 100644 index eb92c529c..000000000 --- a/src/main/scala/leon/solvers/SimpleAssumptionSolverAPI.scala +++ /dev/null @@ -1,27 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -package leon -package solvers - -import purescala.Common.Identifier -import purescala.Expressions.Expr - -class SimpleAssumptionSolverAPI(sf: SolverFactory[AssumptionSolver]) extends SimpleSolverAPI(sf) { - - def solveSATWithCores(expression: Expr, assumptions: Set[Expr]): (Option[Boolean], Map[Identifier, Expr], Set[Expr]) = { - val s = sf.getNewSolver() - try { - s.assertCnstr(expression) - s.checkAssumptions(assumptions) match { - case Some(true) => - (Some(true), s.getModel, Set()) - case Some(false) => - (Some(false), Map(), s.getUnsatCore) - case None => - (None, Map(), Set()) - } - } finally { - sf.reclaim(s) - } - } -} diff --git a/src/main/scala/leon/solvers/SimpleSolverAPI.scala b/src/main/scala/leon/solvers/SimpleSolverAPI.scala index 907f903df..37aefc1b5 100644 --- a/src/main/scala/leon/solvers/SimpleSolverAPI.scala +++ b/src/main/scala/leon/solvers/SimpleSolverAPI.scala @@ -33,16 +33,27 @@ class SimpleSolverAPI(sf: SolverFactory[Solver]) { sf.reclaim(s) } } + + def solveSATWithCores(expression: Expr, assumptions: Set[Expr]): (Option[Boolean], Map[Identifier, Expr], Set[Expr]) = { + val s = sf.getNewSolver() + try { + s.assertCnstr(expression) + s.checkAssumptions(assumptions) match { + case Some(true) => + (Some(true), s.getModel, Set()) + case Some(false) => + (Some(false), Map(), s.getUnsatCore) + case None => + (None, Map(), Set()) + } + } finally { + sf.reclaim(s) + } + } } object SimpleSolverAPI { def apply(sf: SolverFactory[Solver]) = { new SimpleSolverAPI(sf) } - - // Wrapping an AssumptionSolver will automatically provide an extended - // interface - def apply(sf: SolverFactory[AssumptionSolver]) = { - new SimpleAssumptionSolverAPI(sf) - } } diff --git a/src/main/scala/leon/solvers/Solver.scala b/src/main/scala/leon/solvers/Solver.scala index 54879dcc4..9f8aaace4 100644 --- a/src/main/scala/leon/solvers/Solver.scala +++ b/src/main/scala/leon/solvers/Solver.scala @@ -3,12 +3,12 @@ package leon package solvers -import utils.DebugSectionSolver +import utils.{DebugSectionSolver, Interruptible} import purescala.Expressions._ import purescala.Common.Identifier import verification.VC -trait Solver { +trait Solver extends Interruptible { def name: String val context: LeonContext @@ -27,6 +27,12 @@ trait Solver { def reset() + def push(): Unit + def pop(): Unit + + def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] + def getUnsatCore: Set[Expr] + implicit val debugSection = DebugSectionSolver private[solvers] def debugS(msg: String) = { diff --git a/src/main/scala/leon/solvers/SolverFactory.scala b/src/main/scala/leon/solvers/SolverFactory.scala index 6e7273a11..27f4973e0 100644 --- a/src/main/scala/leon/solvers/SolverFactory.scala +++ b/src/main/scala/leon/solvers/SolverFactory.scala @@ -41,7 +41,7 @@ object SolverFactory { "ground" -> "Only solves ground verification conditions by evaluating them", "enum" -> "Enumeration-based counter-example-finder" ) - + val availableSolversPretty = "Available: " + solvers.SolverFactory.definedSolvers.toSeq.sortBy(_._1).map { case (name, desc) => f"\n $name%-14s : $desc" @@ -76,7 +76,7 @@ object SolverFactory { ctx.reporter.fatalError("Aborting Leon...") } - def getFromName(ctx: LeonContext, program: Program)(name: String): SolverFactory[IncrementalSolver with TimeoutSolver] = name match { + def getFromName(ctx: LeonContext, program: Program)(name: String): SolverFactory[TimeoutSolver] = name match { case "fairz3" => SolverFactory(() => new FairZ3Solver(ctx, program) with TimeoutSolver) diff --git a/src/main/scala/leon/solvers/TimeoutAssumptionSolver.scala b/src/main/scala/leon/solvers/TimeoutAssumptionSolver.scala deleted file mode 100644 index 0294d8716..000000000 --- a/src/main/scala/leon/solvers/TimeoutAssumptionSolver.scala +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -package leon -package solvers - -import purescala.Expressions.Expr - -trait TimeoutAssumptionSolver extends TimeoutSolver with AssumptionSolver { - - abstract override def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = { - optTimeout match { - case Some(to) => - ti.interruptAfter(to) { - super.checkAssumptions(assumptions) - } - case None => - super.checkAssumptions(assumptions) - } - } -} - diff --git a/src/main/scala/leon/solvers/TimeoutSolver.scala b/src/main/scala/leon/solvers/TimeoutSolver.scala index a7c27f1a0..723b48886 100644 --- a/src/main/scala/leon/solvers/TimeoutSolver.scala +++ b/src/main/scala/leon/solvers/TimeoutSolver.scala @@ -4,10 +4,11 @@ package leon package solvers import utils._ +import purescala.Expressions.Expr import scala.concurrent.duration._ -trait TimeoutSolver extends Solver with Interruptible { +trait TimeoutSolver extends Solver { val ti = new TimeoutFor(this) @@ -34,4 +35,15 @@ trait TimeoutSolver extends Solver with Interruptible { } } + abstract override def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = { + optTimeout match { + case Some(to) => + ti.interruptAfter(to) { + super.checkAssumptions(assumptions) + } + case None => + super.checkAssumptions(assumptions) + } + } + } diff --git a/src/main/scala/leon/solvers/combinators/DNFSolver.scala b/src/main/scala/leon/solvers/combinators/DNFSolver.scala deleted file mode 100644 index ad48ca2cd..000000000 --- a/src/main/scala/leon/solvers/combinators/DNFSolver.scala +++ /dev/null @@ -1,138 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -package leon -package solvers -package combinators - -import purescala.Common._ -import purescala.Constructors._ -import purescala.Expressions._ -import purescala.ExprOps._ - -class DNFSolver(val context: LeonContext, - underlyings: SolverFactory[Solver]) extends Solver { - - def name = "DNF("+underlyings.name+")" - - def free() {} - - private var theConstraint : Option[Expr] = None - private var theModel : Option[Map[Identifier,Expr]] = None - - import context.reporter._ - - def assertCnstr(expression : Expr) { - if(theConstraint.isDefined) { fatalError("Multiple assertCnstr(...).") } - theConstraint = Some(expression) - } - - def reset() = { - throw new CantResetException(this) - } - - def check : Option[Boolean] = theConstraint.map { expr => - - val simpleSolver = SimpleSolverAPI(underlyings) - - var result : Option[Boolean] = None - - debugS("Before NNF:\n" + expr.asString) - - val nnfed = nnf(expr, false) - - debugS("After NNF:\n" + nnfed.asString) - - val dnfed = dnf(nnfed) - - debugS("After DNF:\n" + dnfed.asString) - - val candidates : Seq[Expr] = dnfed match { - case Or(es) => es - case elze => Seq(elze) - } - - debugS("# conjuncts : " + candidates.size) - - var done : Boolean = false - - for(candidate <- candidates if !done) { - simpleSolver.solveSAT(candidate) match { - case (Some(false), _) => - result = Some(false) - - case (Some(true), m) => - result = Some(true) - theModel = Some(m) - done = true - - case (None, m) => - result = None - theModel = Some(m) - done = true - } - } - result - } getOrElse { - Some(true) - } - - def getModel : Map[Identifier,Expr] = { - val vs : Set[Identifier] = theConstraint.map(variablesOf).getOrElse(Set.empty) - theModel.getOrElse(Map.empty).filter(p => vs(p._1)) - } - - private def nnf(expr : Expr, flip : Boolean) : Expr = expr match { - case _ : Let | _ : IfExpr => throw new Exception("Can't NNF *everything*, sorry.") - case Not(Implies(l,r)) => nnf(and(l, not(r)), flip) - case Implies(l, r) => nnf(or(not(l), r), flip) - case Not(Equals(l, r)) => nnf(or(and(l, not(r)), and(not(l), r)), flip) - case Equals(l, r) => nnf(or(and(l, r), and(not(l), not(r))), flip) - case And(es) if flip => orJoin(es.map(e => nnf(e, true))) - case And(es) => andJoin(es.map(e => nnf(e, false))) - case Or(es) if flip => andJoin(es.map(e => nnf(e, true))) - case Or(es) => orJoin(es.map(e => nnf(e, false))) - case Not(e) if flip => nnf(e, false) - case Not(e) => nnf(e, true) - case LessThan(l,r) if flip => GreaterEquals(l,r) - case GreaterThan(l,r) if flip => LessEquals(l,r) - case LessEquals(l,r) if flip => GreaterThan(l,r) - case GreaterEquals(l,r) if flip => LessThan(l,r) - case elze if flip => not(elze) - case elze => elze - } - - private def dnf(expr : Expr) : Expr = expr match { - case And(es) => - val (ors, lits) = es.partition(_.isInstanceOf[Or]) - if(ors.nonEmpty) { - val orHead = ors.head.asInstanceOf[Or] - val orTail = ors.tail - orJoin(orHead.exprs.map(oe => dnf(andJoin(filterObvious(lits ++ (oe +: orTail)))))) - } else { - expr - } - - case Or(es) => - orJoin(es.map(dnf)) - - case _ => expr - } - - private def filterObvious(exprs : Seq[Expr]) : Seq[Expr] = { - var pos : List[Identifier] = Nil - var neg : List[Identifier] = Nil - - for(e <- exprs) e match { - case Variable(id) => pos = id :: pos - case Not(Variable(id)) => neg = id :: neg - case _ => ; - } - - val both : Set[Identifier] = pos.toSet intersect neg.toSet - if(both.nonEmpty) { - Seq(BooleanLiteral(false)) - } else { - exprs - } - } -} diff --git a/src/main/scala/leon/solvers/combinators/PortfolioSolver.scala b/src/main/scala/leon/solvers/combinators/PortfolioSolver.scala index ea8331a4a..9997d5176 100644 --- a/src/main/scala/leon/solvers/combinators/PortfolioSolver.scala +++ b/src/main/scala/leon/solvers/combinators/PortfolioSolver.scala @@ -15,7 +15,7 @@ import scala.concurrent.duration._ import ExecutionContext.Implicits.global class PortfolioSolver[S <: Solver with Interruptible](val context: LeonContext, val solvers: Seq[S]) - extends Solver with Interruptible { + extends Solver with NaiveAssumptionSolver { val name = "Pfolio" @@ -53,7 +53,7 @@ class PortfolioSolver[S <: Solver with Interruptible](val context: LeonContext, val result = Future.find(fs)(_._2.isDefined) - val res = Await.result(result, 10.days) match { + val res = Await.result(result, Duration.Inf) match { case Some((s, r, m)) => modelMap = m resultSolver = s.getResultSolver @@ -66,10 +66,18 @@ class PortfolioSolver[S <: Solver with Interruptible](val context: LeonContext, None } - fs map { Await.ready(_, 10.days) } + fs map { Await.ready(_, Duration.Inf) } res } + def push(): Unit = { + solvers.foreach(_.push()) + } + + def pop(): Unit = { + solvers.foreach(_.pop()) + } + def free() = { solvers.foreach(_.free) modelMap = Map() diff --git a/src/main/scala/leon/solvers/combinators/PortfolioSolverFactory.scala b/src/main/scala/leon/solvers/combinators/PortfolioSolverFactory.scala index 85c862d29..855911455 100644 --- a/src/main/scala/leon/solvers/combinators/PortfolioSolverFactory.scala +++ b/src/main/scala/leon/solvers/combinators/PortfolioSolverFactory.scala @@ -8,7 +8,7 @@ import utils.Interruptible import scala.collection.mutable.Queue import scala.reflect.runtime.universe._ -class PortfolioSolverFactory[S <: Solver with Interruptible](ctx: LeonContext, sfs: Seq[SolverFactory[S]])(implicit tag: TypeTag[S]) extends SolverFactory[PortfolioSolver[S] with TimeoutSolver] { +class PortfolioSolverFactory[S <: Solver](ctx: LeonContext, sfs: Seq[SolverFactory[S]])(implicit tag: TypeTag[S]) extends SolverFactory[PortfolioSolver[S] with TimeoutSolver] { def getNewSolver(): PortfolioSolver[S] with TimeoutSolver = { new PortfolioSolver[S](ctx, sfs.map(_.getNewSolver())) with TimeoutSolver diff --git a/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala b/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala index 5bb65cdb3..f4bd40885 100644 --- a/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala +++ b/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala @@ -16,7 +16,7 @@ import templates._ import utils.Interruptible import evaluators._ -class UnrollingSolver(val context: LeonContext, program: Program, underlying: IncrementalSolver with Interruptible) extends Solver with Interruptible { +class UnrollingSolver(val context: LeonContext, program: Program, underlying: Solver) extends Solver with NaiveAssumptionSolver { val feelingLucky = context.findOptionOrDefault(optFeelingLucky) val useCodeGen = context.findOptionOrDefault(optUseCodeGen) diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala index 942faa108..3fcaacd44 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala @@ -30,9 +30,7 @@ import _root_.smtlib.{Interpreter => SMTInterpreter} abstract class SMTLIBSolver(val context: LeonContext, - val program: Program) - extends IncrementalSolver with Interruptible { - + val program: Program) extends Solver with NaiveAssumptionSolver { /* Solver name */ def targetName: String @@ -42,7 +40,6 @@ abstract class SMTLIBSolver(val context: LeonContext, protected val reporter = context.reporter /* Interface with Interpreter */ - def interpreterOps(ctx: LeonContext): Seq[String] def getNewInterpreter(ctx: LeonContext): SMTInterpreter @@ -51,7 +48,6 @@ abstract class SMTLIBSolver(val context: LeonContext, /* Printing VCs */ - protected lazy val out: Option[java.io.FileWriter] = if (reporter.isDebugEnabled) Some { val file = context.files.headOption.map(_.getName).getOrElse("NA") val n = VCNumbers.getNext(targetName+file) @@ -62,9 +58,9 @@ abstract class SMTLIBSolver(val context: LeonContext, dir.mkdir } - val fileName = s"vcs/$targetName-$file-$n.smt2" + val fileName = s"smt-sessions/$targetName-$file-$n.smt2" - reporter.debug(s"Outputting VC into $fileName" ) + reporter.debug(s"Outputting smt session into $fileName" ) val fw = new java.io.FileWriter(fileName, false) diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala index 550e84f05..b2683eb58 100644 --- a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala @@ -21,11 +21,7 @@ import scala.collection.mutable.{Map => MutableMap} // This is just to factor out the things that are common in "classes that deal // with a Z3 instance" -trait AbstractZ3Solver - extends Solver - with AssumptionSolver - with IncrementalSolver - with Interruptible { +trait AbstractZ3Solver extends Solver { val context : LeonContext val program : Program diff --git a/src/main/scala/leon/solvers/ModelEnumerator.scala b/src/main/scala/leon/utils/ModelEnumerator.scala similarity index 74% rename from src/main/scala/leon/solvers/ModelEnumerator.scala rename to src/main/scala/leon/utils/ModelEnumerator.scala index 3a5fdc5df..e813d8b50 100644 --- a/src/main/scala/leon/solvers/ModelEnumerator.scala +++ b/src/main/scala/leon/utils/ModelEnumerator.scala @@ -1,5 +1,5 @@ package leon -package solvers +package utils import purescala.Definitions._ import purescala.Common._ @@ -8,14 +8,14 @@ import purescala.Constructors._ import purescala.ExprOps._ import purescala.Types._ import evaluators._ +import solvers._ - -class ModelEnumerator(ctx: LeonContext, pgm: Program, sf: SolverFactory[IncrementalSolver]) { +class ModelEnumerator(ctx: LeonContext, pgm: Program, sf: SolverFactory[Solver]) { private[this] var reclaimPool = List[Solver]() private[this] val evaluator = new DefaultEvaluator(ctx, pgm) - def enumSimple(ids: Seq[Identifier], cnstr: Expr): Iterator[Map[Identifier, Expr]] = { - enumVarying0(ids, cnstr, None, -1) + def enumSimple(ids: Seq[Identifier], satisfying: Expr): Iterator[Map[Identifier, Expr]] = { + enumVarying0(ids, satisfying, None, -1) } /** @@ -25,26 +25,26 @@ class ModelEnumerator(ctx: LeonContext, pgm: Program, sf: SolverFactory[Incremen * Note: there is no guarantee that the models enumerated consecutively share the * same `caracteristic`. */ - def enumVarying(ids: Seq[Identifier], cnstr: Expr, caracteristic: Expr, nPerCaracteristic: Int = 1) = { - enumVarying0(ids, cnstr, Some(caracteristic), nPerCaracteristic) + def enumVarying(ids: Seq[Identifier], satisfying: Expr, measure: Expr, nPerMeasure: Int = 1) = { + enumVarying0(ids, satisfying, Some(measure), nPerMeasure) } - private[this] def enumVarying0(ids: Seq[Identifier], cnstr: Expr, caracteristic: Option[Expr], nPerCaracteristic: Int = 1): Iterator[Map[Identifier, Expr]] = { + private[this] def enumVarying0(ids: Seq[Identifier], satisfying: Expr, measure: Option[Expr], nPerMeasure: Int = 1): Iterator[Map[Identifier, Expr]] = { val s = sf.getNewSolver reclaimPool ::= s - s.assertCnstr(cnstr) + s.assertCnstr(satisfying) - val c = caracteristic match { - case Some(car) => - val c = FreshIdentifier("car", car.getType) - s.assertCnstr(Equals(c.toVariable, car)) - c + val m = measure match { + case Some(ms) => + val m = FreshIdentifier("measure", ms.getType) + s.assertCnstr(Equals(m.toVariable, ms)) + m case None => FreshIdentifier("noop", BooleanType) } - var perCarRemaining = Map[Expr, Int]() + var perMeasureRem = Map[Expr, Int]().withDefaultValue(nPerMeasure) new Iterator[Map[Identifier, Expr]] { def hasNext = { @@ -53,27 +53,28 @@ class ModelEnumerator(ctx: LeonContext, pgm: Program, sf: SolverFactory[Incremen def next = { val sm = s.getModel - val m = (ids.map { id => + val model = (ids.map { id => id -> sm.getOrElse(id, simplestValue(id.getType)) }).toMap // Vary the model - s.assertCnstr(not(andJoin(m.toSeq.sortBy(_._1).map { case (k,v) => equality(k.toVariable, v) }))) + s.assertCnstr(not(andJoin(model.toSeq.sortBy(_._1).map { case (k,v) => equality(k.toVariable, v) }))) - caracteristic match { - case Some(car) => - val cValue = evaluator.eval(car, m).result.get + measure match { + case Some(ms) => + val mValue = evaluator.eval(ms, model).result.get - perCarRemaining += (cValue -> (perCarRemaining.getOrElse(cValue, nPerCaracteristic) - 1)) - if (perCarRemaining(cValue) == 0) { - s.assertCnstr(not(equality(c.toVariable, cValue))) + perMeasureRem += (mValue -> (perMeasureRem(mValue) - 1)) + + if (perMeasureRem(mValue) <= 0) { + s.assertCnstr(not(equality(m.toVariable, mValue))) } case None => } - m + model } } } @@ -90,13 +91,13 @@ class ModelEnumerator(ctx: LeonContext, pgm: Program, sf: SolverFactory[Incremen case object Up extends SearchDirection case object Down extends SearchDirection - private[this] def enumOptimizing(ids: Seq[Identifier], cnstr: Expr, measure: Expr, dir: SearchDirection): Iterator[Map[Identifier, Expr]] = { + private[this] def enumOptimizing(ids: Seq[Identifier], satisfying: Expr, measure: Expr, dir: SearchDirection): Iterator[Map[Identifier, Expr]] = { assert(measure.getType == IntegerType) val s = sf.getNewSolver reclaimPool ::= s - s.assertCnstr(cnstr) + s.assertCnstr(satisfying) val mId = FreshIdentifier("measure", measure.getType) s.assertCnstr(Equals(mId.toVariable, measure)) diff --git a/src/test/scala/leon/test/solvers/SolverPoolSuite.scala b/src/test/scala/leon/test/solvers/SolverPoolSuite.scala index 04d4d7723..b733ee193 100644 --- a/src/test/scala/leon/test/solvers/SolverPoolSuite.scala +++ b/src/test/scala/leon/test/solvers/SolverPoolSuite.scala @@ -12,7 +12,7 @@ import leon.purescala.Expressions._ class SolverPoolSuite extends LeonTestSuite { - private class DummySolver(val context : LeonContext, val program: Program) extends Solver { + private class DummySolver(val context : LeonContext, val program: Program) extends Solver with NaiveAssumptionSolver { val name = "Dummy" val description = "dummy" @@ -21,6 +21,10 @@ class SolverPoolSuite extends LeonTestSuite { def free() {} def reset() {} def getModel = ??? + def push() {} + def pop() {} + def interrupt() {} + def recoverInterrupt() {} } def sfactory(implicit ctx: LeonContext): SolverFactory[Solver] = { -- GitLab