From b86116f0c0548a98d3e7816b8640b2b8e03da95f Mon Sep 17 00:00:00 2001 From: Etienne Kneuss <colder@php.net> Date: Wed, 18 Sep 2013 17:04:33 +0200 Subject: [PATCH] Re-introduce type hierarchy for solvers, simplify factories Solvers wrap solvers or factories, depending on the needs. Factories no longer wrap factories, except for the special case of timeoutsolverfactories (it does it in a typesafe way though). Fix TupleRewrite with new posts, fix ScopeSimplified, Fix pretty printer --- .../scala/leon/purescala/ScalaPrinter.scala | 36 +- src/main/scala/leon/purescala/TreeOps.scala | 13 +- src/main/scala/leon/purescala/Trees.scala | 4 + .../scala/leon/solvers/AssumptionSolver.scala | 11 + .../leon/solvers/IncrementalSolver.scala | 10 + .../solvers/SimpleAssumptionSolverAPI.scala | 27 ++ .../scala/leon/solvers/SimpleSolverAPI.scala | 60 +-- src/main/scala/leon/solvers/Solver.scala | 20 +- .../scala/leon/solvers/SolverFactory.scala | 41 +- .../solvers/TimeoutAssumptionSolver.scala | 23 + .../scala/leon/solvers/TimeoutSolver.scala | 71 +++ .../leon/solvers/TimeoutSolverFactory.scala | 13 + .../leon/solvers/combinators/DNFSolver.scala | 141 ++++++ .../combinators/DNFSolverFactory.scala | 173 ------- .../solvers/combinators/RewritingSolver.scala | 36 ++ .../combinators/RewritingSolverFactory.scala | 81 ---- .../combinators/TimeoutSolverFactory.scala | 112 ----- .../solvers/combinators/UnrollingSolver.scala | 155 +++++++ .../combinators/UnrollingSolverFactory.scala | 190 -------- .../leon/solvers/z3/AbstractZ3Solver.scala | 24 +- ...SolverFactory.scala => FairZ3Solver.scala} | 359 +++++++-------- .../leon/solvers/z3/FunctionTemplate.scala | 4 +- ...tory.scala => UninterpretedZ3Solver.scala} | 93 ++-- src/main/scala/leon/synthesis/Solution.scala | 3 +- .../leon/synthesis/SynthesisContext.scala | 13 +- .../scala/leon/synthesis/Synthesizer.scala | 4 +- .../condabd/SynthesizerExamples.scala | 3 +- .../ConditionAbductionSynthesisTwoPhase.scala | 2 +- .../verification/AbstractVerifier.scala | 2 +- .../verification/RelaxedVerifier.scala | 3 +- .../condabd/verification/Verifier.scala | 3 +- .../synthesis/heuristics/ADTInduction.scala | 6 +- .../heuristics/ADTLongInduction.scala | 5 +- .../scala/leon/synthesis/rules/ADTSplit.scala | 4 +- .../scala/leon/synthesis/rules/Cegis.scala | 430 +++++++++--------- .../leon/synthesis/rules/EqualitySplit.scala | 2 - .../scala/leon/synthesis/rules/Ground.scala | 4 +- .../synthesis/rules/InequalitySplit.scala | 2 - .../synthesis/rules/OptimisticGround.scala | 2 - src/main/scala/leon/testgen/CallGraph.scala | 9 +- .../scala/leon/testgen/TestGeneration.scala | 10 +- .../leon/verification/AnalysisPhase.scala | 136 +++--- .../verification/VerificationCondition.scala | 2 +- .../condabd/VariableSolverRefinerTest.scala | 5 - .../leon/test/condabd/VerifierTest.scala | 194 ++++---- .../leon/test/purescala/TreeOpsTests.scala | 4 - .../test/solvers/TimeoutSolverTests.scala | 43 +- .../test/solvers/z3/FairZ3SolverTests.scala | 4 +- .../solvers/z3/FairZ3SolverTestsNewAPI.scala | 48 +- .../z3/UninterpretedZ3SolverTests.scala | 4 +- 50 files changed, 1268 insertions(+), 1376 deletions(-) create mode 100644 src/main/scala/leon/solvers/AssumptionSolver.scala create mode 100644 src/main/scala/leon/solvers/IncrementalSolver.scala create mode 100644 src/main/scala/leon/solvers/SimpleAssumptionSolverAPI.scala create mode 100644 src/main/scala/leon/solvers/TimeoutAssumptionSolver.scala create mode 100644 src/main/scala/leon/solvers/TimeoutSolver.scala create mode 100644 src/main/scala/leon/solvers/TimeoutSolverFactory.scala create mode 100644 src/main/scala/leon/solvers/combinators/DNFSolver.scala delete mode 100644 src/main/scala/leon/solvers/combinators/DNFSolverFactory.scala create mode 100644 src/main/scala/leon/solvers/combinators/RewritingSolver.scala delete mode 100644 src/main/scala/leon/solvers/combinators/RewritingSolverFactory.scala delete mode 100644 src/main/scala/leon/solvers/combinators/TimeoutSolverFactory.scala create mode 100644 src/main/scala/leon/solvers/combinators/UnrollingSolver.scala delete mode 100644 src/main/scala/leon/solvers/combinators/UnrollingSolverFactory.scala rename src/main/scala/leon/solvers/z3/{FairZ3SolverFactory.scala => FairZ3Solver.scala} (58%) rename src/main/scala/leon/solvers/z3/{UninterpretedZ3SolverFactory.scala => UninterpretedZ3Solver.scala} (56%) diff --git a/src/main/scala/leon/purescala/ScalaPrinter.scala b/src/main/scala/leon/purescala/ScalaPrinter.scala index 67241adef..ec69ff46e 100644 --- a/src/main/scala/leon/purescala/ScalaPrinter.scala +++ b/src/main/scala/leon/purescala/ScalaPrinter.scala @@ -19,14 +19,14 @@ class ScalaPrinter(sb: StringBuffer = new StringBuffer) extends PrettyPrinter(sb // EXPRESSIONS // all expressions are printed in-line override def pp(tree: Expr, lvl: Int): Unit = tree match { - case Variable(id) => sb.append(id) + case Variable(id) => sb.append(idToString(id)) case DeBruijnIndex(idx) => sys.error("Not Valid Scala") case LetTuple(ids,d,e) => sb.append("locally {\n") ind(lvl+1) sb.append("val (" ) for (((id, tpe), i) <- ids.map(id => (id, id.getType)).zipWithIndex) { - sb.append(id.toString+": ") + sb.append(idToString(id)+": ") pp(tpe, lvl) if (i != ids.size-1) { sb.append(", ") @@ -85,7 +85,7 @@ class ScalaPrinter(sb: StringBuffer = new StringBuffer) extends PrettyPrinter(sb sb.append("._" + i) case CaseClass(cd, args) => - sb.append(cd.id) + sb.append(idToString(cd.id)) if (cd.isCaseObject) { ppNary(args, "", "", "", lvl) } else { @@ -94,14 +94,14 @@ class ScalaPrinter(sb: StringBuffer = new StringBuffer) extends PrettyPrinter(sb case CaseClassInstanceOf(cd, e) => pp(e, lvl) - sb.append(".isInstanceOf[" + cd.id + "]") + sb.append(".isInstanceOf[" + idToString(cd.id) + "]") case CaseClassSelector(_, cc, id) => pp(cc, lvl) - sb.append("." + id) + sb.append("." + idToString(id)) case FunctionInvocation(fd, args) => - sb.append(fd.id) + sb.append(idToString(fd.id)) ppNary(args, "(", ", ", ")", lvl) case Plus(l,r) => ppBinary(l, r, " + ", lvl) @@ -210,7 +210,7 @@ class ScalaPrinter(sb: StringBuffer = new StringBuffer) extends PrettyPrinter(sb case Choose(ids, pred) => sb.append("(choose { (") for (((id, tpe), i) <- ids.map(id => (id, id.getType)).zipWithIndex) { - sb.append(id.toString+": ") + sb.append(idToString(id)+": ") pp(tpe, lvl) if (i != ids.size-1) { sb.append(", ") @@ -229,7 +229,7 @@ class ScalaPrinter(sb: StringBuffer = new StringBuffer) extends PrettyPrinter(sb //case InstanceOfPattern(Some(id), ctd) => case CaseClassPattern(bndr, ccd, subps) => { bndr.foreach(b => sb.append(b + " @ ")) - sb.append(ccd.id).append("(") + sb.append(idToString(ccd.id)).append("(") var c = 0 val sz = subps.size subps.foreach(sp => { @@ -241,10 +241,10 @@ class ScalaPrinter(sb: StringBuffer = new StringBuffer) extends PrettyPrinter(sb sb.append(")") } case WildcardPattern(None) => sb.append("_") - case WildcardPattern(Some(id)) => sb.append(id) + case WildcardPattern(Some(id)) => sb.append(idToString(id)) case InstanceOfPattern(bndr, ccd) => { bndr.foreach(b => sb.append(b + " : ")) - sb.append(ccd.id) + sb.append(idToString(ccd.id)) } case TuplePattern(bndr, subPatterns) => { bndr.foreach(b => sb.append(b + " @ ")) @@ -333,7 +333,7 @@ class ScalaPrinter(sb: StringBuffer = new StringBuffer) extends PrettyPrinter(sb } sb.append(" => ") pp(tt, lvl) - case c: ClassType => sb.append(c.classDef.id) + case c: ClassType => sb.append(idToString(c.classDef.id)) case _ => sb.append("Type?") } @@ -350,7 +350,7 @@ class ScalaPrinter(sb: StringBuffer = new StringBuffer) extends PrettyPrinter(sb case ObjectDef(id, defs, invs) => { ind(lvl) sb.append("object ") - sb.append(id) + sb.append(idToString(id)) sb.append(" {\n") var c = 0 @@ -372,19 +372,19 @@ class ScalaPrinter(sb: StringBuffer = new StringBuffer) extends PrettyPrinter(sb case AbstractClassDef(id, parent) => ind(lvl) sb.append("sealed abstract class ") - sb.append(id) + sb.append(idToString(id)) parent.foreach(p => sb.append(" extends " + p.id)) case CaseClassDef(id, parent, varDecls) => ind(lvl) sb.append("case class ") - sb.append(id) + sb.append(idToString(id)) sb.append("(") var c = 0 val sz = varDecls.size varDecls.foreach(vd => { - sb.append(vd.id) + sb.append(idToString(vd.id)) sb.append(": ") pp(vd.tpe, lvl) if(c < sz - 1) { @@ -393,20 +393,20 @@ class ScalaPrinter(sb: StringBuffer = new StringBuffer) extends PrettyPrinter(sb c = c + 1 }) sb.append(")") - parent.foreach(p => sb.append(" extends " + p.id)) + parent.foreach(p => sb.append(" extends " + idToString(p.id))) case fd: FunDef => ind(lvl) sb.append("def ") - sb.append(fd.id) + sb.append(idToString(fd.id)) sb.append("(") val sz = fd.args.size var c = 0 fd.args.foreach(arg => { - sb.append(arg.id) + sb.append(idToString(arg.id)) sb.append(" : ") pp(arg.tpe, lvl) diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala index 3479c8192..2ac6b102e 100644 --- a/src/main/scala/leon/purescala/TreeOps.scala +++ b/src/main/scala/leon/purescala/TreeOps.scala @@ -1449,7 +1449,7 @@ object TreeOps { case (id, post) => val nid = genId(id, newScope) val postScope = newScope.register(id -> nid) - (id, rec(post, postScope)) + (nid, rec(post, postScope)) } LetDef(newFd, rec(body, newScope)) @@ -1583,6 +1583,7 @@ object TreeOps { case Untyped | AnyType | BottomType | BooleanType | Int32Type | UnitType => None } + var idMap = Map[Identifier, Identifier]() var funDefMap = Map.empty[FunDef,FunDef] def fd2fd(funDef : FunDef) : FunDef = funDefMap.get(funDef) match { @@ -1598,7 +1599,14 @@ object TreeOps { // These will be taken care of in the recursive traversal. fd.body = funDef.body fd.precondition = funDef.precondition - fd.postcondition = funDef.postcondition + funDef.postcondition match { + case Some((id, post)) => + val freshId = FreshIdentifier(id.name, true).setType(rt) + idMap += id -> freshId + fd.postcondition = Some((freshId, post)) + case None => + fd.postcondition = None + } fd } funDefMap = funDefMap.updated(funDef, newFD) @@ -1607,6 +1615,7 @@ object TreeOps { def pre(e : Expr) : Expr = e match { case Tuple(Seq()) => UnitLiteral + case Variable(id) if idMap contains id => Variable(idMap(id)) case Tuple(Seq(s)) => pre(s) diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala index d7bd3e968..65b25f25e 100644 --- a/src/main/scala/leon/purescala/Trees.scala +++ b/src/main/scala/leon/purescala/Trees.scala @@ -490,9 +490,13 @@ object Trees { leastUpperBound(Seq(set1, set2).map(_.getType)).foreach(setType _) } case class SetMin(set: Expr) extends Expr with FixedType { + typeCheck(set, SetType(Int32Type)) + val fixedType = Int32Type } case class SetMax(set: Expr) extends Expr with FixedType { + typeCheck(set, SetType(Int32Type)) + val fixedType = Int32Type } diff --git a/src/main/scala/leon/solvers/AssumptionSolver.scala b/src/main/scala/leon/solvers/AssumptionSolver.scala new file mode 100644 index 000000000..e906687be --- /dev/null +++ b/src/main/scala/leon/solvers/AssumptionSolver.scala @@ -0,0 +1,11 @@ +/* Copyright 2009-2013 EPFL, Lausanne */ + +package leon +package solvers + +import purescala.Trees.Expr + +trait AssumptionSolver extends Solver { + def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] + def getUnsatCore: Set[Expr] +} diff --git a/src/main/scala/leon/solvers/IncrementalSolver.scala b/src/main/scala/leon/solvers/IncrementalSolver.scala new file mode 100644 index 000000000..79c8d04d4 --- /dev/null +++ b/src/main/scala/leon/solvers/IncrementalSolver.scala @@ -0,0 +1,10 @@ +/* Copyright 2009-2013 EPFL, Lausanne */ + +package leon +package solvers + +trait IncrementalSolver extends Solver { + def push(): Unit + def pop(lvl: Int = 1): Unit +} + diff --git a/src/main/scala/leon/solvers/SimpleAssumptionSolverAPI.scala b/src/main/scala/leon/solvers/SimpleAssumptionSolverAPI.scala new file mode 100644 index 000000000..0ce1eb8d2 --- /dev/null +++ b/src/main/scala/leon/solvers/SimpleAssumptionSolverAPI.scala @@ -0,0 +1,27 @@ +/* Copyright 2009-2013 EPFL, Lausanne */ + +package leon +package solvers + +import purescala.Common._ +import purescala.Trees._ + +class SimpleAssumptionSolverAPI(sf: SolverFactory[AssumptionSolver]) extends SimpleSolverAPI(sf) { + + def solveSATWithCores(expression: Expr, assumptions: Set[Expr]): (Option[Boolean], Map[Identifier, Expr], Set[Expr]) = { + val s = sf.getNewSolver + try { + s.assertCnstr(expression) + s.checkAssumptions(assumptions) match { + case Some(true) => + (Some(true), s.getModel, Set()) + case Some(false) => + (Some(false), Map(), s.getUnsatCore) + case None => + (None, Map(), Set()) + } + } finally { + s.free() + } + } +} diff --git a/src/main/scala/leon/solvers/SimpleSolverAPI.scala b/src/main/scala/leon/solvers/SimpleSolverAPI.scala index 6565a2148..59b20f8f9 100644 --- a/src/main/scala/leon/solvers/SimpleSolverAPI.scala +++ b/src/main/scala/leon/solvers/SimpleSolverAPI.scala @@ -6,41 +6,43 @@ package solvers import purescala.Common._ import purescala.Trees._ -case class SimpleSolverAPI[S <: Solver](sf: SolverFactory[S]) { +class SimpleSolverAPI(sf: SolverFactory[Solver]) { def solveVALID(expression: Expr): Option[Boolean] = { - val s = sf.getNewSolver() - s.assertCnstr(Not(expression)) - s.check.map(r => !r) - } - - def free() { - sf.free() + val s = sf.getNewSolver + try { + s.assertCnstr(Not(expression)) + s.check.map(r => !r) + } finally { + s.free() + } } def solveSAT(expression: Expr): (Option[Boolean], Map[Identifier, Expr]) = { - val s = sf.getNewSolver() - s.assertCnstr(expression) - s.check match { - case Some(true) => - (Some(true), s.getModel) - case Some(false) => - (Some(false), Map()) - case None => - (None, Map()) + val s = sf.getNewSolver + try { + s.assertCnstr(expression) + s.check match { + case Some(true) => + (Some(true), s.getModel) + case Some(false) => + (Some(false), Map()) + case None => + (None, Map()) + } + } finally { + s.free() } } +} + +object SimpleSolverAPI { + def apply(sf: SolverFactory[Solver]) = { + new SimpleSolverAPI(sf) + } - def solveSATWithCores(expression: Expr, assumptions: Set[Expr]): (Option[Boolean], Map[Identifier, Expr], Set[Expr]) = { - val s = sf.getNewSolver() - s.assertCnstr(expression) - s.checkAssumptions(assumptions) match { - case Some(true) => - (Some(true), s.getModel, Set()) - case Some(false) => - (Some(false), Map(), s.getUnsatCore) - case None => - (None, Map(), Set()) - } + // Wrapping an AssumptionSolver will automatically provide an extended + // interface + def apply(sf: SolverFactory[AssumptionSolver]) = { + new SimpleAssumptionSolverAPI(sf) } } - diff --git a/src/main/scala/leon/solvers/Solver.scala b/src/main/scala/leon/solvers/Solver.scala index cb257d236..b4edadedb 100644 --- a/src/main/scala/leon/solvers/Solver.scala +++ b/src/main/scala/leon/solvers/Solver.scala @@ -4,18 +4,24 @@ package leon package solvers import utils._ -import purescala.Common._ -import purescala.Trees._ +import purescala.Trees.Expr +import purescala.Common.Identifier + + +trait Solver { + def name: String + val context: LeonContext -trait Solver extends Interruptible { - def push(): Unit - def pop(lvl: Int = 1): Unit def assertCnstr(expression: Expr): Unit def check: Option[Boolean] - def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] def getModel: Map[Identifier, Expr] - def getUnsatCore: Set[Expr] + + def free() implicit val debugSection = DebugSectionSolver + + private[solvers] def debugS(msg: String) = { + context.reporter.debug("["+name+"] "+msg) + } } diff --git a/src/main/scala/leon/solvers/SolverFactory.scala b/src/main/scala/leon/solvers/SolverFactory.scala index aa343136a..147d681d9 100644 --- a/src/main/scala/leon/solvers/SolverFactory.scala +++ b/src/main/scala/leon/solvers/SolverFactory.scala @@ -3,41 +3,18 @@ package leon package solvers -import solvers.combinators._ - -import utils._ -import purescala.Common._ -import purescala.Definitions._ -import purescala.TreeOps._ -import purescala.Trees._ - -trait SolverFactory[S <: Solver] extends Interruptible with LeonComponent { - val context: LeonContext - val program: Program - - def free() {} - - var interrupted = false - - override def interrupt() { - interrupted = true - } - - override def recoverInterrupt() { - interrupted = false - } +import scala.reflect.runtime.universe._ +abstract class SolverFactory[+S <: Solver : TypeTag] { def getNewSolver(): S - def withTimeout(ms: Long): TimeoutSolverFactory[S] = { - this match { - case tsf: TimeoutSolverFactory[S] => - // Unwrap/Rewrap to take only new timeout into account - new TimeoutSolverFactory[S](tsf.sf, ms) - case _ => - new TimeoutSolverFactory[S](this, ms) + val name = "SFact("+typeOf[S].toString+")" +} + +object SolverFactory { + def apply[S <: Solver : TypeTag](builder: () => S): SolverFactory[S] = { + new SolverFactory[S] { + def getNewSolver() = builder() } } - - implicit val debugSection = DebugSectionSolver } diff --git a/src/main/scala/leon/solvers/TimeoutAssumptionSolver.scala b/src/main/scala/leon/solvers/TimeoutAssumptionSolver.scala new file mode 100644 index 000000000..ae59400da --- /dev/null +++ b/src/main/scala/leon/solvers/TimeoutAssumptionSolver.scala @@ -0,0 +1,23 @@ +/* Copyright 2009-2013 EPFL, Lausanne */ + +package leon +package solvers + +import purescala.Trees.Expr + +trait TimeoutAssumptionSolver extends TimeoutSolver with AssumptionSolver { + + protected def innerCheckAssumptions(assumptions: Set[Expr]): Option[Boolean] + + override def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = { + optTimeout match { + case Some(to) => + interruptAfter(to) { + innerCheckAssumptions(assumptions) + } + case None => + innerCheckAssumptions(assumptions) + } + } +} + diff --git a/src/main/scala/leon/solvers/TimeoutSolver.scala b/src/main/scala/leon/solvers/TimeoutSolver.scala new file mode 100644 index 000000000..ab98755c2 --- /dev/null +++ b/src/main/scala/leon/solvers/TimeoutSolver.scala @@ -0,0 +1,71 @@ +/* Copyright 2009-2013 EPFL, Lausanne */ + +package leon +package solvers + +import utils._ + +trait TimeoutSolver extends Solver with Interruptible { + + private class Countdown(timeout: Long, onTimeout: => Unit) extends Thread { + private var keepRunning = true + override def run : Unit = { + val startTime : Long = System.currentTimeMillis + var exceeded : Boolean = false + + while(!exceeded && keepRunning) { + if(timeout < (System.currentTimeMillis - startTime)) { + exceeded = true + } + Thread.sleep(10) + } + + if(exceeded && keepRunning) { + onTimeout + } + } + + def finishedRunning : Unit = { + keepRunning = false + } + } + + protected var optTimeout: Option[Long] = None; + protected def interruptAfter[T](timeout: Long)(body: => T): T = { + var reachedTimeout = false + + val timer = new Countdown(timeout, { + interrupt() + reachedTimeout = true + }) + + timer.start + val res = body + timer.finishedRunning + + if (reachedTimeout) { + recoverInterrupt() + } + + res + } + + def setTimeout(timeout: Long): this.type = { + optTimeout = Some(timeout) + this + } + + protected def innerCheck: Option[Boolean] + + override def check: Option[Boolean] = { + optTimeout match { + case Some(to) => + interruptAfter(to) { + innerCheck + } + case None => + innerCheck + } + } + +} diff --git a/src/main/scala/leon/solvers/TimeoutSolverFactory.scala b/src/main/scala/leon/solvers/TimeoutSolverFactory.scala new file mode 100644 index 000000000..02bf21c8a --- /dev/null +++ b/src/main/scala/leon/solvers/TimeoutSolverFactory.scala @@ -0,0 +1,13 @@ +/* Copyright 2009-2013 EPFL, Lausanne */ + +package leon +package solvers + +import scala.reflect.runtime.universe._ + + +class TimeoutSolverFactory[+S <: TimeoutSolver : TypeTag](sf: SolverFactory[S], to: Long) extends SolverFactory[S] { + override def getNewSolver() = sf.getNewSolver().setTimeout(to) + + override val name = "SFact("+typeOf[S].toString+") with t.o" +} diff --git a/src/main/scala/leon/solvers/combinators/DNFSolver.scala b/src/main/scala/leon/solvers/combinators/DNFSolver.scala new file mode 100644 index 000000000..507106bd1 --- /dev/null +++ b/src/main/scala/leon/solvers/combinators/DNFSolver.scala @@ -0,0 +1,141 @@ +/* Copyright 2009-2013 EPFL, Lausanne */ + +package leon +package solvers +package combinators + +import purescala.Common._ +import purescala.Definitions._ +import purescala.Trees._ +import purescala.TreeOps._ +import purescala.TypeTrees._ + +class DNFSolver(val context: LeonContext, + underlyings: SolverFactory[Solver]) extends Solver { + + def name = "DNF("+underlyings.name+")" + + def free {} + + private var theConstraint : Option[Expr] = None + private var theModel : Option[Map[Identifier,Expr]] = None + + import context.reporter._ + + def assertCnstr(expression : Expr) { + if(!theConstraint.isEmpty) { fatalError("Multiple assertCnstr(...).") } + theConstraint = Some(expression) + } + + def check : Option[Boolean] = theConstraint.map { expr => + + val simpleSolver = SimpleSolverAPI(underlyings) + + var result : Option[Boolean] = None + + debugS("Before NNF:\n" + expr) + + val nnfed = nnf(expr, false) + + debugS("After NNF:\n" + nnfed) + + val dnfed = dnf(nnfed) + + debugS("After DNF:\n" + dnfed) + + val candidates : Seq[Expr] = dnfed match { + case Or(es) => es + case elze => Seq(elze) + } + + debugS("# conjuncts : " + candidates.size) + + var done : Boolean = false + + for(candidate <- candidates if !done) { + simpleSolver.solveSAT(candidate) match { + case (Some(false), _) => + result = Some(false) + + case (Some(true), m) => + result = Some(true) + theModel = Some(m) + done = true + + case (None, m) => + result = None + theModel = Some(m) + done = true + } + } + result + } getOrElse { + Some(true) + } + + def getModel : Map[Identifier,Expr] = { + val vs : Set[Identifier] = theConstraint.map(variablesOf(_)).getOrElse(Set.empty) + theModel.getOrElse(Map.empty).filter(p => vs(p._1)) + } + + private def nnf(expr : Expr, flip : Boolean) : Expr = expr match { + case _ : Let | _ : IfExpr => throw new Exception("Can't NNF *everything*, sorry.") + case Not(Implies(l,r)) => nnf(And(l, Not(r)), flip) + case Implies(l, r) => nnf(Or(Not(l), r), flip) + case Not(Iff(l, r)) => nnf(Or(And(l, Not(r)), And(Not(l), r)), flip) + case Iff(l, r) => nnf(Or(And(l, r), And(Not(l), Not(r))), flip) + case And(es) if flip => Or(es.map(e => nnf(e, true))) + case And(es) => And(es.map(e => nnf(e, false))) + case Or(es) if flip => And(es.map(e => nnf(e, true))) + case Or(es) => Or(es.map(e => nnf(e, false))) + case Not(e) if flip => nnf(e, false) + case Not(e) => nnf(e, true) + case LessThan(l,r) if flip => GreaterEquals(l,r) + case GreaterThan(l,r) if flip => LessEquals(l,r) + case LessEquals(l,r) if flip => GreaterThan(l,r) + case GreaterEquals(l,r) if flip => LessThan(l,r) + case elze if flip => Not(elze) + case elze => elze + } + + // fun pushC (And(p,Or(q,r))) = Or(pushC(And(p,q)),pushC(And(p,r))) + // | pushC (And(Or(q,r),p)) = Or(pushC(And(p,q)),pushC(And(p,r))) + // | pushC (And(p,q)) = And(pushC(p),pushC(q)) + // | pushC (Literal(l)) = Literal(l) + // | pushC (Or(p,q)) = Or(pushC(p),pushC(q)) + + private def dnf(expr : Expr) : Expr = expr match { + case And(es) => + val (ors, lits) = es.partition(_.isInstanceOf[Or]) + if(!ors.isEmpty) { + val orHead = ors.head.asInstanceOf[Or] + val orTail = ors.tail + Or(orHead.exprs.map(oe => dnf(And(filterObvious(lits ++ (oe +: orTail)))))) + } else { + expr + } + + case Or(es) => + Or(es.map(dnf(_))) + + case _ => expr + } + + private def filterObvious(exprs : Seq[Expr]) : Seq[Expr] = { + var pos : List[Identifier] = Nil + var neg : List[Identifier] = Nil + + for(e <- exprs) e match { + case Variable(id) => pos = id :: pos + case Not(Variable(id)) => neg = id :: neg + case _ => ; + } + + val both : Set[Identifier] = pos.toSet intersect neg.toSet + if(!both.isEmpty) { + Seq(BooleanLiteral(false)) + } else { + exprs + } + } +} diff --git a/src/main/scala/leon/solvers/combinators/DNFSolverFactory.scala b/src/main/scala/leon/solvers/combinators/DNFSolverFactory.scala deleted file mode 100644 index 6ef2752d9..000000000 --- a/src/main/scala/leon/solvers/combinators/DNFSolverFactory.scala +++ /dev/null @@ -1,173 +0,0 @@ -/* Copyright 2009-2013 EPFL, Lausanne */ - -package leon -package solvers -package combinators - -import purescala.Common._ -import purescala.Definitions._ -import purescala.Trees._ -import purescala.TypeTrees._ -import purescala.TreeOps._ - -import scala.collection.mutable.{Map=>MutableMap} - -class DNFSolverFactory[S <: Solver](val sf : SolverFactory[S]) extends SolverFactory[Solver] { - val description = "DNF around a base solver" - val name = sf.name + "!" - - val context = sf.context - val program = sf.program - - private val thisFactory = this - - override def free() { - sf.free() - } - - override def recoverInterrupt() { - sf.recoverInterrupt() - } - - def getNewSolver() : Solver = { - new Solver { - private var theConstraint : Option[Expr] = None - private var theModel : Option[Map[Identifier,Expr]] = None - - private def fail(because : String) : Nothing = { throw new Exception("Not supported in DNFSolvers : " + because) } - - def push() : Unit = fail("push()") - def pop(lvl : Int = 1) : Unit = fail("pop(lvl)") - - def assertCnstr(expression : Expr) { - if(!theConstraint.isEmpty) { fail("Multiple assertCnstr(...).") } - theConstraint = Some(expression) - } - - def interrupt() { fail("interrupt()") } - - def recoverInterrupt() { fail("recoverInterrupt()") } - - def check : Option[Boolean] = theConstraint.map { expr => - import context.reporter - - val simpleSolver = SimpleSolverAPI(sf) - - var result : Option[Boolean] = None - - def info(msg : String) { reporter.info("In " + thisFactory.name + ": " + msg) } - - // info("Before NNF:\n" + expr) - - val nnfed = nnf(expr, false) - - // info("After NNF:\n" + nnfed) - - val dnfed = dnf(nnfed) - - // info("After DNF:\n" + dnfed) - - val candidates : Seq[Expr] = dnfed match { - case Or(es) => es - case elze => Seq(elze) - } - - info("# conjuncts : " + candidates.size) - - var done : Boolean = false - - for(candidate <- candidates if !done) { - simpleSolver.solveSAT(candidate) match { - case (Some(false), _) => - result = Some(false) - - case (Some(true), m) => - result = Some(true) - theModel = Some(m) - done = true - - case (None, m) => - result = None - theModel = Some(m) - done = true - } - } - result - } getOrElse { - Some(true) - } - - def checkAssumptions(assumptions : Set[Expr]) : Option[Boolean] = { - fail("checkAssumptions(assumptions)") - } - - def getModel : Map[Identifier,Expr] = { - val vs : Set[Identifier] = theConstraint.map(variablesOf(_)).getOrElse(Set.empty) - theModel.getOrElse(Map.empty).filter(p => vs(p._1)) - } - - def getUnsatCore : Set[Expr] = { fail("getUnsatCore") } - } - } - - private def nnf(expr : Expr, flip : Boolean) : Expr = expr match { - case _ : Let | _ : IfExpr => throw new Exception("Can't NNF *everything*, sorry.") - case Not(Implies(l,r)) => nnf(And(l, Not(r)), flip) - case Implies(l, r) => nnf(Or(Not(l), r), flip) - case Not(Iff(l, r)) => nnf(Or(And(l, Not(r)), And(Not(l), r)), flip) - case Iff(l, r) => nnf(Or(And(l, r), And(Not(l), Not(r))), flip) - case And(es) if flip => Or(es.map(e => nnf(e, true))) - case And(es) => And(es.map(e => nnf(e, false))) - case Or(es) if flip => And(es.map(e => nnf(e, true))) - case Or(es) => Or(es.map(e => nnf(e, false))) - case Not(e) if flip => nnf(e, false) - case Not(e) => nnf(e, true) - case LessThan(l,r) if flip => GreaterEquals(l,r) - case GreaterThan(l,r) if flip => LessEquals(l,r) - case LessEquals(l,r) if flip => GreaterThan(l,r) - case GreaterEquals(l,r) if flip => LessThan(l,r) - case elze if flip => Not(elze) - case elze => elze - } - - // fun pushC (And(p,Or(q,r))) = Or(pushC(And(p,q)),pushC(And(p,r))) - // | pushC (And(Or(q,r),p)) = Or(pushC(And(p,q)),pushC(And(p,r))) - // | pushC (And(p,q)) = And(pushC(p),pushC(q)) - // | pushC (Literal(l)) = Literal(l) - // | pushC (Or(p,q)) = Or(pushC(p),pushC(q)) - - private def dnf(expr : Expr) : Expr = expr match { - case And(es) => - val (ors, lits) = es.partition(_.isInstanceOf[Or]) - if(!ors.isEmpty) { - val orHead = ors.head.asInstanceOf[Or] - val orTail = ors.tail - Or(orHead.exprs.map(oe => dnf(And(filterObvious(lits ++ (oe +: orTail)))))) - } else { - expr - } - - case Or(es) => - Or(es.map(dnf(_))) - - case _ => expr - } - - private def filterObvious(exprs : Seq[Expr]) : Seq[Expr] = { - var pos : List[Identifier] = Nil - var neg : List[Identifier] = Nil - - for(e <- exprs) e match { - case Variable(id) => pos = id :: pos - case Not(Variable(id)) => neg = id :: neg - case _ => ; - } - - val both : Set[Identifier] = pos.toSet intersect neg.toSet - if(!both.isEmpty) { - Seq(BooleanLiteral(false)) - } else { - exprs - } - } -} diff --git a/src/main/scala/leon/solvers/combinators/RewritingSolver.scala b/src/main/scala/leon/solvers/combinators/RewritingSolver.scala new file mode 100644 index 000000000..44e442b26 --- /dev/null +++ b/src/main/scala/leon/solvers/combinators/RewritingSolver.scala @@ -0,0 +1,36 @@ +/* Copyright 2009-2013 EPFL, Lausanne */ + +package leon +package solvers +package combinators + +import purescala.Common._ +import purescala.Definitions._ +import purescala.Trees._ +import purescala.TreeOps._ +import purescala.TypeTrees._ + +abstract class RewritingSolver[+S <: Solver, T](underlying: S) { + val context = underlying.context + + /** The type T is used to encode any meta information useful, for instance, to reconstruct + * models. */ + def rewriteCnstr(expression : Expr) : (Expr,T) + + def reconstructModel(model : Map[Identifier,Expr], meta : T) : Map[Identifier,Expr] + + private var storedMeta : List[T] = Nil + + def assertCnstr(expression : Expr) { + val (rewritten, meta) = rewriteCnstr(expression) + storedMeta = meta :: storedMeta + underlying.assertCnstr(rewritten) + } + + def getModel : Map[Identifier,Expr] = { + storedMeta match { + case Nil => underlying.getModel + case m :: _ => reconstructModel(underlying.getModel, m) + } + } +} diff --git a/src/main/scala/leon/solvers/combinators/RewritingSolverFactory.scala b/src/main/scala/leon/solvers/combinators/RewritingSolverFactory.scala deleted file mode 100644 index 2a8fbb37a..000000000 --- a/src/main/scala/leon/solvers/combinators/RewritingSolverFactory.scala +++ /dev/null @@ -1,81 +0,0 @@ -/* Copyright 2009-2013 EPFL, Lausanne */ - -package leon -package solvers -package combinators - -import purescala.Common._ -import purescala.Definitions._ -import purescala.Trees._ -import purescala.TypeTrees._ - -/** This is for solvers that operate by rewriting formulas into equisatisfiable ones. - * They are essentially defined by two methods, one for preprocessing of the expressions, - * and one for reconstructing the models. */ -abstract class RewritingSolverFactory[S <: Solver,T](val sf : SolverFactory[S]) extends SolverFactory[Solver] { - val context = sf.context - val program = sf.program - - override def free() { - sf.free() - } - - override def recoverInterrupt() { - sf.recoverInterrupt() - } - - /** The type T is used to encode any meta information useful, for instance, to reconstruct - * models. */ - def rewriteCnstr(expression : Expr) : (Expr,T) - - def reconstructModel(model : Map[Identifier,Expr], meta : T) : Map[Identifier,Expr] - - def getNewSolver() : Solver = { - new Solver { - val underlying : Solver = sf.getNewSolver() - - private def fail(because : String) : Nothing = { - throw new Exception("Not supported in RewritingSolvers : " + because) - } - - def push() : Unit = fail("push()") - def pop(lvl : Int = 1) : Unit = fail("pop(lvl)") - - private var storedMeta : List[T] = Nil - - def assertCnstr(expression : Expr) { - context.reporter.info("Asked to solve this in BAPA<:\n" + expression) - val (rewritten, meta) = rewriteCnstr(expression) - storedMeta = meta :: storedMeta - underlying.assertCnstr(rewritten) - } - - def interrupt() { - underlying.interrupt() - } - - def recoverInterrupt() { - underlying.recoverInterrupt() - } - - def check : Option[Boolean] = { - underlying.check - } - - def checkAssumptions(assumptions : Set[Expr]) : Option[Boolean] = { - fail("checkAssumptions(assumptions)") - } - - def getModel : Map[Identifier,Expr] = { - storedMeta match { - case Nil => fail("reconstructing model without meta-information.") - case m :: _ => reconstructModel(underlying.getModel, m) - } - } - - def getUnsatCore : Set[Expr] = { - fail("getUnsatCore") - } - } - } -} diff --git a/src/main/scala/leon/solvers/combinators/TimeoutSolverFactory.scala b/src/main/scala/leon/solvers/combinators/TimeoutSolverFactory.scala deleted file mode 100644 index dc670bfbe..000000000 --- a/src/main/scala/leon/solvers/combinators/TimeoutSolverFactory.scala +++ /dev/null @@ -1,112 +0,0 @@ -/* Copyright 2009-2013 EPFL, Lausanne */ - -package leon -package solvers -package combinators - -import purescala.Common._ -import purescala.Definitions._ -import purescala.Trees._ -import purescala.TypeTrees._ - -class TimeoutSolverFactory[S <: Solver](val sf: SolverFactory[S], val timeoutMs: Long) extends SolverFactory[Solver] { - val description = sf.description + ", with "+timeoutMs+"ms timeout" - val name = sf.name + "+to" - - val context = sf.context - val program = sf.program - - private class Countdown(onTimeout: => Unit) extends Thread { - private var keepRunning = true - private val asMillis : Long = timeoutMs - - override def run : Unit = { - val startTime : Long = System.currentTimeMillis - var exceeded : Boolean = false - - while(!exceeded && keepRunning) { - if(asMillis < (System.currentTimeMillis - startTime)) { - exceeded = true - } - Thread.sleep(10) - } - if(exceeded && keepRunning) { - onTimeout - } - } - - def finishedRunning : Unit = { - keepRunning = false - } - } - - override def free() { - sf.free() - } - - def withTimeout[T](solver: S)(body: => T): T = { - val timer = new Countdown(timeout(solver)) - timer.start - val res = body - timer.finishedRunning - recoverFromTimeout(solver) - res - } - - var reachedTimeout = false - def timeout(solver: S) { - solver.interrupt() - reachedTimeout = true - } - - def recoverFromTimeout(solver: S) { - if (reachedTimeout) { - solver.recoverInterrupt() - reachedTimeout = false - } - } - - def getNewSolver = new Solver { - val solver = sf.getNewSolver - - def push(): Unit = { - solver.push() - } - - def pop(lvl: Int = 1): Unit = { - solver.pop(lvl) - } - - def assertCnstr(expression: Expr): Unit = { - solver.assertCnstr(expression) - } - - def interrupt() { - solver.interrupt() - } - - def recoverInterrupt() { - solver.recoverInterrupt() - } - - def check: Option[Boolean] = { - withTimeout(solver){ - solver.check - } - } - - def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = { - withTimeout(solver){ - solver.checkAssumptions(assumptions) - } - } - - def getModel: Map[Identifier, Expr] = { - solver.getModel - } - - def getUnsatCore: Set[Expr] = { - solver.getUnsatCore - } - } -} diff --git a/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala b/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala new file mode 100644 index 000000000..ad4eb78bc --- /dev/null +++ b/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala @@ -0,0 +1,155 @@ +/* Copyright 2009-2013 EPFL, Lausanne */ + +package leon +package solvers +package combinators + +import purescala.Common._ +import purescala.Definitions._ +import purescala.Trees._ +import purescala.TreeOps._ +import purescala.TypeTrees._ + +import scala.collection.mutable.{Map=>MutableMap} + +class UnrollingSolver(val context: LeonContext, + underlyings: SolverFactory[Solver], + maxUnrollings: Int = 3) extends Solver { + + private var theConstraint : Option[Expr] = None + private var theModel : Option[Map[Identifier,Expr]] = None + + def name = "Unr("+underlyings.name+")" + + def free {} + + import context.reporter._ + + def assertCnstr(expression : Expr) { + if(!theConstraint.isEmpty) { + fatalError("Multiple assertCnstr(...).") + } + theConstraint = Some(expression) + } + + def check : Option[Boolean] = theConstraint.map { expr => + val simpleSolver = SimpleSolverAPI(underlyings) + + debugS("Check called on " + expr + "...") + + val template = getTemplate(expr) + + val aVar : Identifier = template.activatingBool + var allClauses : Seq[Expr] = Nil + var allBlockers : Map[Identifier,Set[FunctionInvocation]] = Map.empty + + def fullOpenExpr : Expr = { + // And(Variable(aVar), And(allClauses.reverse)) + // Let's help the poor underlying guy a little bit... + // Note that I keep aVar around, because it may negate one of the blockers, and we can't miss that! + And(Variable(aVar), replace(Map(Variable(aVar) -> BooleanLiteral(true)), And(allClauses.reverse))) + } + + def fullClosedExpr : Expr = { + val blockedVars : Seq[Expr] = allBlockers.toSeq.map(p => Variable(p._1)) + + And( + replace(blockedVars.map(v => (v -> BooleanLiteral(false))).toMap, fullOpenExpr), + And(blockedVars.map(Not(_))) + ) + } + + def unrollOneStep() { + val blockersBefore = allBlockers + + var newClauses : List[Seq[Expr]] = Nil + var newBlockers : Map[Identifier,Set[FunctionInvocation]] = Map.empty + + for(blocker <- allBlockers.keySet; FunctionInvocation(funDef, args) <- allBlockers(blocker)) { + val (nc, nb) = getTemplate(funDef).instantiate(blocker, args) + newClauses = nc :: newClauses + newBlockers = newBlockers ++ nb + } + + allClauses = newClauses.flatten ++ allClauses + allBlockers = newBlockers + } + + val (nc, nb) = template.instantiate(aVar, template.funDef.args.map(a => Variable(a.id))) + + allClauses = nc.reverse + allBlockers = nb + + var unrollingCount : Int = 0 + var done : Boolean = false + var result : Option[Boolean] = None + + // We're now past the initial step. + while(!done && unrollingCount < maxUnrollings) { + debugS("At lvl : " + unrollingCount) + val closed : Expr = fullClosedExpr + + debugS("Going for SAT with this:\n" + closed) + + simpleSolver.solveSAT(closed) match { + case (Some(false), _) => + val open = fullOpenExpr + debugS("Was UNSAT... Going for UNSAT with this:\n" + open) + simpleSolver.solveSAT(open) match { + case (Some(false), _) => + debugS("Was UNSAT... Done !") + done = true + result = Some(false) + + case _ => + debugS("Was SAT or UNKNOWN. Let's unroll !") + unrollingCount += 1 + unrollOneStep() + } + + case (Some(true), model) => + debugS("WAS SAT ! We're DONE !") + done = true + result = Some(true) + theModel = Some(model) + + case (None, model) => + debugS("WAS UNKNOWN ! We're DONE !") + done = true + result = Some(true) + theModel = Some(model) + } + } + result + + } getOrElse { + Some(true) + } + + def getModel : Map[Identifier,Expr] = { + val vs : Set[Identifier] = theConstraint.map(variablesOf(_)).getOrElse(Set.empty) + theModel.getOrElse(Map.empty).filter(p => vs(p._1)) + } + + private val funDefTemplateCache : MutableMap[FunDef, FunctionTemplate] = MutableMap.empty + private val exprTemplateCache : MutableMap[Expr, FunctionTemplate] = MutableMap.empty + + private def getTemplate(funDef: FunDef): FunctionTemplate = { + funDefTemplateCache.getOrElse(funDef, { + val res = FunctionTemplate.mkTemplate(funDef, true) + funDefTemplateCache += funDef -> res + res + }) + } + + private def getTemplate(body: Expr): FunctionTemplate = { + exprTemplateCache.getOrElse(body, { + val fakeFunDef = new FunDef(FreshIdentifier("fake", true), body.getType, variablesOf(body).toSeq.map(id => VarDecl(id, id.getType))) + fakeFunDef.body = Some(body) + + val res = FunctionTemplate.mkTemplate(fakeFunDef, false) + exprTemplateCache += body -> res + res + }) + } +} diff --git a/src/main/scala/leon/solvers/combinators/UnrollingSolverFactory.scala b/src/main/scala/leon/solvers/combinators/UnrollingSolverFactory.scala deleted file mode 100644 index b2f8862f6..000000000 --- a/src/main/scala/leon/solvers/combinators/UnrollingSolverFactory.scala +++ /dev/null @@ -1,190 +0,0 @@ -/* Copyright 2009-2013 EPFL, Lausanne */ - -package leon -package solvers -package combinators - -import purescala.Common._ -import purescala.Definitions._ -import purescala.Trees._ -import purescala.TypeTrees._ -import purescala.TreeOps._ - -import scala.collection.mutable.{Map=>MutableMap} - -class UnrollingSolverFactory[S <: Solver](val sf : SolverFactory[S]) extends SolverFactory[Solver] { - val description = "Unrolling loop around a base solver." - val name = sf.name + "*" - - val context = sf.context - val program = sf.program - - // Yes, a hardcoded constant. Sue me. - val MAXUNROLLINGS : Int = 3 - - private val thisFactory = this - - override def free() { - sf.free() - } - - override def recoverInterrupt() { - sf.recoverInterrupt() - } - - def getNewSolver() : Solver = { - new Solver { - private var theConstraint : Option[Expr] = None - private var theModel : Option[Map[Identifier,Expr]] = None - - private def fail(because : String) : Nothing = { - throw new Exception("Not supported in UnrollingSolvers : " + because) - } - - def push() : Unit = fail("push()") - def pop(lvl : Int = 1) : Unit = fail("pop(lvl)") - - def assertCnstr(expression : Expr) { - if(!theConstraint.isEmpty) { - fail("Multiple assertCnstr(...).") - } - theConstraint = Some(expression) - } - - def interrupt() { fail("interrupt()") } - - def recoverInterrupt() { fail("recoverInterrupt()") } - - def check : Option[Boolean] = theConstraint.map { expr => - import context.reporter - - val simpleSolver = SimpleSolverAPI(sf) - - def info(msg : String) { reporter.info("In " + thisFactory.name + ": " + msg) } - - info("Check called on " + expr + "...") - - val template = getTemplate(expr) - - val aVar : Identifier = template.activatingBool - var allClauses : Seq[Expr] = Nil - var allBlockers : Map[Identifier,Set[FunctionInvocation]] = Map.empty - - def fullOpenExpr : Expr = { - // And(Variable(aVar), And(allClauses.reverse)) - // Let's help the poor underlying guy a little bit... - // Note that I keep aVar around, because it may negate one of the blockers, and we can't miss that! - And(Variable(aVar), replace(Map(Variable(aVar) -> BooleanLiteral(true)), And(allClauses.reverse))) - } - - def fullClosedExpr : Expr = { - val blockedVars : Seq[Expr] = allBlockers.toSeq.map(p => Variable(p._1)) - - And( - replace(blockedVars.map(v => (v -> BooleanLiteral(false))).toMap, fullOpenExpr), - And(blockedVars.map(Not(_))) - ) - } - - def unrollOneStep() { - val blockersBefore = allBlockers - - var newClauses : List[Seq[Expr]] = Nil - var newBlockers : Map[Identifier,Set[FunctionInvocation]] = Map.empty - - for(blocker <- allBlockers.keySet; FunctionInvocation(funDef, args) <- allBlockers(blocker)) { - val (nc, nb) = getTemplate(funDef).instantiate(blocker, args) - newClauses = nc :: newClauses - newBlockers = newBlockers ++ nb - } - - allClauses = newClauses.flatten ++ allClauses - allBlockers = newBlockers - } - - val (nc, nb) = template.instantiate(aVar, template.funDef.args.map(a => Variable(a.id))) - - allClauses = nc.reverse - allBlockers = nb - - var unrollingCount : Int = 0 - var done : Boolean = false - var result : Option[Boolean] = None - - // We're now past the initial step. - while(!done && unrollingCount < MAXUNROLLINGS) { - info("At lvl : " + unrollingCount) - val closed : Expr = fullClosedExpr - - info("Going for SAT with this:\n" + closed) - - simpleSolver.solveSAT(closed) match { - case (Some(false), _) => - val open = fullOpenExpr - info("Was UNSAT... Going for UNSAT with this:\n" + open) - simpleSolver.solveSAT(open) match { - case (Some(false), _) => - info("Was UNSAT... Done !") - done = true - result = Some(false) - - case _ => - info("Was SAT or UNKNOWN. Let's unroll !") - unrollingCount += 1 - unrollOneStep() - } - - case (Some(true), model) => - info("WAS SAT ! We're DONE !") - done = true - result = Some(true) - theModel = Some(model) - - case (None, model) => - info("WAS UNKNOWN ! We're DONE !") - done = true - result = Some(true) - theModel = Some(model) - } - } - result - - } getOrElse { - Some(true) - } - - def checkAssumptions(assumptions : Set[Expr]) : Option[Boolean] = { - fail("checkAssumptions(assumptions)") - } - - def getModel : Map[Identifier,Expr] = { - val vs : Set[Identifier] = theConstraint.map(variablesOf(_)).getOrElse(Set.empty) - theModel.getOrElse(Map.empty).filter(p => vs(p._1)) - } - - def getUnsatCore : Set[Expr] = { fail("getUnsatCore") } - } - } - - private val funDefTemplateCache : MutableMap[FunDef, FunctionTemplate] = MutableMap.empty - private val exprTemplateCache : MutableMap[Expr, FunctionTemplate] = MutableMap.empty - - private def getTemplate(funDef: FunDef): FunctionTemplate = { - funDefTemplateCache.getOrElse(funDef, { - val res = FunctionTemplate.mkTemplate(funDef, true) - funDefTemplateCache += funDef -> res - res - }) - } - - private def getTemplate(body: Expr): FunctionTemplate = { - exprTemplateCache.getOrElse(body, { - val fakeFunDef = new FunDef(FreshIdentifier("fake", true), body.getType, variablesOf(body).toSeq.map(id => VarDecl(id, id.getType))) - fakeFunDef.body = Some(body) - - val res = FunctionTemplate.mkTemplate(fakeFunDef, false) - exprTemplateCache += body -> res - res - }) - } -} diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala index 3bc561015..62ee88109 100644 --- a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala @@ -19,7 +19,12 @@ import scala.collection.mutable.{Set => MutableSet} // This is just to factor out the things that are common in "classes that deal // with a Z3 instance" -trait AbstractZ3Solver extends SolverFactory[Solver] { +trait AbstractZ3Solver + extends Solver + with TimeoutAssumptionSolver + with AssumptionSolver + with IncrementalSolver { + val context : LeonContext val program : Program @@ -45,20 +50,25 @@ trait AbstractZ3Solver extends SolverFactory[Solver] { override def free() { freed = true - super.free() if (z3 ne null) { z3.delete() z3 = null; } } + protected[z3] var interrupted = false; + override def interrupt() { - super.interrupt() + interrupted = true if(z3 ne null) { z3.interrupt } } + override def recoverInterrupt() { + interrupted = false + } + protected[leon] def prepareFunctions : Unit protected[leon] def functionDefToDecl(funDef: FunDef) : Z3FuncDecl protected[leon] def functionDeclToDef(decl: Z3FuncDecl) : FunDef @@ -455,13 +465,7 @@ trait AbstractZ3Solver extends SolverFactory[Solver] { case IntLiteral(v) => z3.mkInt(v, intSort) case BooleanLiteral(v) => if (v) z3.mkTrue() else z3.mkFalse() case UnitLiteral => unitValue - case Equals(l, r) => { - //if(l.getType != r.getType) - // println("Warning : wrong types in equality for " + l + " == " + r) - z3.mkEq(rec( l ), rec( r ) ) - } - - //case Equals(l, r) => z3.mkEq(rec(l), rec(r)) + case Equals(l, r) => z3.mkEq(rec( l ), rec( r ) ) case Plus(l, r) => z3.mkAdd(rec(l), rec(r)) case Minus(l, r) => z3.mkSub(rec(l), rec(r)) case Times(l, r) => z3.mkMul(rec(l), rec(r)) diff --git a/src/main/scala/leon/solvers/z3/FairZ3SolverFactory.scala b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala similarity index 58% rename from src/main/scala/leon/solvers/z3/FairZ3SolverFactory.scala rename to src/main/scala/leon/solvers/z3/FairZ3Solver.scala index 7f50b3e1c..63abcb42a 100644 --- a/src/main/scala/leon/solvers/z3/FairZ3SolverFactory.scala +++ b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala @@ -7,7 +7,7 @@ import leon.utils._ import z3.scala._ -import leon.solvers.Solver +import leon.solvers.{Solver, IncrementalSolver} import purescala.Common._ import purescala.Definitions._ @@ -23,7 +23,7 @@ import termination._ import scala.collection.mutable.{Map => MutableMap} import scala.collection.mutable.{Set => MutableSet} -class FairZ3SolverFactory(val context : LeonContext, val program: Program) +class FairZ3Solver(val context : LeonContext, val program: Program) extends AbstractZ3Solver with Z3ModelReconstruction with FairZ3Component { @@ -327,263 +327,240 @@ class FairZ3SolverFactory(val context : LeonContext, val program: Program) } } - def getNewSolver = new Solver { - private val evaluator = enclosing.evaluator - private val feelingLucky = enclosing.feelingLucky - private val checkModels = enclosing.checkModels - private val useCodeGen = enclosing.useCodeGen + initZ3 - initZ3 + val solver = z3.mkSolver - val solver = z3.mkSolver - - for(funDef <- program.definedFunctions) { - if (funDef.annotations.contains("axiomatize") && !axiomatizedFunctions(funDef)) { - reporter.warning("Function " + funDef.id + " was marked for axiomatization but could not be handled.") - } + for(funDef <- program.definedFunctions) { + if (funDef.annotations.contains("axiomatize") && !axiomatizedFunctions(funDef)) { + reporter.warning("Function " + funDef.id + " was marked for axiomatization but could not be handled.") } + } - private var varsInVC = Set[Identifier]() - - private var frameExpressions = List[List[Expr]](Nil) - - val unrollingBank = new UnrollingBank() + private var varsInVC = Set[Identifier]() - def push() { - solver.push() - unrollingBank.push() - frameExpressions = Nil :: frameExpressions - } + private var frameExpressions = List[List[Expr]](Nil) - override def recoverInterrupt() { - enclosing.recoverInterrupt() - } + val unrollingBank = new UnrollingBank() - override def interrupt() { - enclosing.interrupt() - } + def push() { + solver.push() + unrollingBank.push() + frameExpressions = Nil :: frameExpressions + } - def pop(lvl: Int = 1) { - solver.pop(lvl) - unrollingBank.pop(lvl) - frameExpressions = frameExpressions.drop(lvl) - } + def pop(lvl: Int = 1) { + solver.pop(lvl) + unrollingBank.pop(lvl) + frameExpressions = frameExpressions.drop(lvl) + } - def check: Option[Boolean] = { - fairCheck(Set()) - } + def innerCheck: Option[Boolean] = { + fairCheck(Set()) + } - def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = { - fairCheck(assumptions) - } + def innerCheckAssumptions(assumptions: Set[Expr]): Option[Boolean] = { + fairCheck(assumptions) + } - var foundDefinitiveAnswer = false - var definitiveAnswer : Option[Boolean] = None - var definitiveModel : Map[Identifier,Expr] = Map.empty - var definitiveCore : Set[Expr] = Set.empty + var foundDefinitiveAnswer = false + var definitiveAnswer : Option[Boolean] = None + var definitiveModel : Map[Identifier,Expr] = Map.empty + var definitiveCore : Set[Expr] = Set.empty - def assertCnstr(expression: Expr) { - varsInVC ++= variablesOf(expression) + def assertCnstr(expression: Expr) { + varsInVC ++= variablesOf(expression) - frameExpressions = (expression :: frameExpressions.head) :: frameExpressions.tail + frameExpressions = (expression :: frameExpressions.head) :: frameExpressions.tail - val newClauses = unrollingBank.scanForNewTemplates(expression) + val newClauses = unrollingBank.scanForNewTemplates(expression) - for (cl <- newClauses) { - solver.assertCnstr(cl) - } + for (cl <- newClauses) { + solver.assertCnstr(cl) } + } - def getModel = { - definitiveModel - } + def getModel = { + definitiveModel + } - def getUnsatCore = { - definitiveCore - } + def getUnsatCore = { + definitiveCore + } - def fairCheck(assumptions: Set[Expr]): Option[Boolean] = { - foundDefinitiveAnswer = false + def fairCheck(assumptions: Set[Expr]): Option[Boolean] = { + foundDefinitiveAnswer = false - def entireFormula = And(assumptions.toSeq ++ frameExpressions.flatten) + def entireFormula = And(assumptions.toSeq ++ frameExpressions.flatten) - def foundAnswer(answer : Option[Boolean], model : Map[Identifier,Expr] = Map.empty, core: Set[Expr] = Set.empty) : Unit = { - foundDefinitiveAnswer = true - definitiveAnswer = answer - definitiveModel = model - definitiveCore = core - } + def foundAnswer(answer : Option[Boolean], model : Map[Identifier,Expr] = Map.empty, core: Set[Expr] = Set.empty) : Unit = { + foundDefinitiveAnswer = true + definitiveAnswer = answer + definitiveModel = model + definitiveCore = core + } - // these are the optional sequence of assumption literals - val assumptionsAsZ3: Seq[Z3AST] = assumptions.flatMap(toZ3Formula(_)).toSeq - val assumptionsAsZ3Set: Set[Z3AST] = assumptionsAsZ3.toSet + // these are the optional sequence of assumption literals + val assumptionsAsZ3: Seq[Z3AST] = assumptions.flatMap(toZ3Formula(_)).toSeq + val assumptionsAsZ3Set: Set[Z3AST] = assumptionsAsZ3.toSet - def z3CoreToCore(core: Seq[Z3AST]): Set[Expr] = { - core.filter(assumptionsAsZ3Set).map(ast => fromZ3Formula(null, ast, None) match { - case n @ Not(Variable(_)) => n - case v @ Variable(_) => v - case x => scala.sys.error("Impossible element extracted from core: " + ast + " (as Leon tree : " + x + ")") - }).toSet - } + def z3CoreToCore(core: Seq[Z3AST]): Set[Expr] = { + core.filter(assumptionsAsZ3Set).map(ast => fromZ3Formula(null, ast, None) match { + case n @ Not(Variable(_)) => n + case v @ Variable(_) => v + case x => scala.sys.error("Impossible element extracted from core: " + ast + " (as Leon tree : " + x + ")") + }).toSet + } - while(!foundDefinitiveAnswer && !interrupted) { + while(!foundDefinitiveAnswer && !interrupted) { - //val blockingSetAsZ3 : Seq[Z3AST] = blockingSet.toSeq.map(toZ3Formula(_).get) - // println("Blocking set : " + blockingSet) + //val blockingSetAsZ3 : Seq[Z3AST] = blockingSet.toSeq.map(toZ3Formula(_).get) + // println("Blocking set : " + blockingSet) - reporter.debug(" - Running Z3 search...") + reporter.debug(" - Running Z3 search...") - // reporter.debug("Searching in:\n"+solver.getAssertions.toSeq.mkString("\nAND\n")) - // reporter.debug("Unroll. Assumptions:\n"+unrollingBank.z3CurrentZ3Blockers.mkString(" && ")) - // reporter.debug("Userland Assumptions:\n"+assumptionsAsZ3.mkString(" && ")) + // reporter.debug("Searching in:\n"+solver.getAssertions.toSeq.mkString("\nAND\n")) + // reporter.debug("Unroll. Assumptions:\n"+unrollingBank.z3CurrentZ3Blockers.mkString(" && ")) + // reporter.debug("Userland Assumptions:\n"+assumptionsAsZ3.mkString(" && ")) - solver.push() // FIXME: remove when z3 bug is fixed - val res = solver.checkAssumptions((assumptionsAsZ3 ++ unrollingBank.z3CurrentZ3Blockers) :_*) - solver.pop() // FIXME: remove when z3 bug is fixed + solver.push() // FIXME: remove when z3 bug is fixed + val res = solver.checkAssumptions((assumptionsAsZ3 ++ unrollingBank.z3CurrentZ3Blockers) :_*) + solver.pop() // FIXME: remove when z3 bug is fixed - reporter.debug(" - Finished search with blocked literals") + reporter.debug(" - Finished search with blocked literals") - res match { - case None => - // reporter.warning("Z3 doesn't know because: " + z3.getSearchFailure.message) - reporter.warning("Z3 doesn't know because ??") - foundAnswer(None) + res match { + case None => + // reporter.warning("Z3 doesn't know because: " + z3.getSearchFailure.message) + reporter.warning("Z3 doesn't know because ??") + foundAnswer(None) - case Some(true) => // SAT + case Some(true) => // SAT - val z3model = solver.getModel + val z3model = solver.getModel - if (this.checkModels) { - val (isValid, model) = validateModel(z3model, entireFormula, varsInVC, silenceErrors = false) + if (this.checkModels) { + val (isValid, model) = validateModel(z3model, entireFormula, varsInVC, silenceErrors = false) - if (isValid) { - foundAnswer(Some(true), model) - } else { - reporter.error("Something went wrong. The model should have been valid, yet we got this : ") - reporter.error(model) - foundAnswer(None, model) - } + if (isValid) { + foundAnswer(Some(true), model) } else { - val model = modelToMap(z3model, varsInVC) + reporter.error("Something went wrong. The model should have been valid, yet we got this : ") + reporter.error(model) + foundAnswer(None, model) + } + } else { + val model = modelToMap(z3model, varsInVC) - //lazy val modelAsString = model.toList.map(p => p._1 + " -> " + p._2).mkString("\n") - //reporter.debug("- Found a model:") - //reporter.debug(modelAsString) + //lazy val modelAsString = model.toList.map(p => p._1 + " -> " + p._2).mkString("\n") + //reporter.debug("- Found a model:") + //reporter.debug(modelAsString) - foundAnswer(Some(true), model) - } + foundAnswer(Some(true), model) + } - case Some(false) if !unrollingBank.canUnroll => + case Some(false) if !unrollingBank.canUnroll => - val core = z3CoreToCore(solver.getUnsatCore) + val core = z3CoreToCore(solver.getUnsatCore) - foundAnswer(Some(false), core = core) + foundAnswer(Some(false), core = core) - // This branch is both for with and without unsat cores. The - // distinction is made inside. - case Some(false) => + // This branch is both for with and without unsat cores. The + // distinction is made inside. + case Some(false) => - val z3Core = solver.getUnsatCore + val z3Core = solver.getUnsatCore - def coreElemToBlocker(c: Z3AST): (Z3AST, Boolean) = { - z3.getASTKind(c) match { - case Z3AppAST(decl, args) => - z3.getDeclKind(decl) match { - case Z3DeclKind.OpNot => - (args(0), true) - case Z3DeclKind.OpUninterpreted => - (c, false) - } + def coreElemToBlocker(c: Z3AST): (Z3AST, Boolean) = { + z3.getASTKind(c) match { + case Z3AppAST(decl, args) => + z3.getDeclKind(decl) match { + case Z3DeclKind.OpNot => + (args(0), true) + case Z3DeclKind.OpUninterpreted => + (c, false) + } - case ast => - (c, false) - } + case ast => + (c, false) } + } - if (unrollUnsatCores) { - unrollingBank.decreaseAllGenerations() - - for (c <- solver.getUnsatCore) { - val (z3ast, pol) = coreElemToBlocker(c) - assert(pol == true) + if (unrollUnsatCores) { + unrollingBank.decreaseAllGenerations() - unrollingBank.promoteBlocker(z3ast) - } + for (c <- solver.getUnsatCore) { + val (z3ast, pol) = coreElemToBlocker(c) + assert(pol == true) + unrollingBank.promoteBlocker(z3ast) } - //debug("UNSAT BECAUSE: "+solver.getUnsatCore.mkString("\n AND \n")) - //debug("UNSAT BECAUSE: "+core.mkString(" AND ")) + } - if (!interrupted) { - if (this.feelingLucky) { - // we need the model to perform the additional test - reporter.debug(" - Running search without blocked literals (w/ lucky test)") - } else { - reporter.debug(" - Running search without blocked literals (w/o lucky test)") - } + //debug("UNSAT BECAUSE: "+solver.getUnsatCore.mkString("\n AND \n")) + //debug("UNSAT BECAUSE: "+core.mkString(" AND ")) + + if (!interrupted) { + if (this.feelingLucky) { + // we need the model to perform the additional test + reporter.debug(" - Running search without blocked literals (w/ lucky test)") + } else { + reporter.debug(" - Running search without blocked literals (w/o lucky test)") + } - solver.push() // FIXME: remove when z3 bug is fixed - val res2 = solver.checkAssumptions(assumptionsAsZ3 : _*) - solver.pop() // FIXME: remove when z3 bug is fixed - - res2 match { - case Some(false) => - //reporter.debug("UNSAT WITHOUT Blockers") - foundAnswer(Some(false), core = z3CoreToCore(solver.getUnsatCore)) - case Some(true) => - //reporter.debug("SAT WITHOUT Blockers") - if (this.feelingLucky && !interrupted) { - // we might have been lucky :D - val (wereWeLucky, cleanModel) = validateModel(solver.getModel, entireFormula, varsInVC, silenceErrors = true) - - if(wereWeLucky) { - foundAnswer(Some(true), cleanModel) - } + solver.push() // FIXME: remove when z3 bug is fixed + val res2 = solver.checkAssumptions(assumptionsAsZ3 : _*) + solver.pop() // FIXME: remove when z3 bug is fixed + + res2 match { + case Some(false) => + //reporter.debug("UNSAT WITHOUT Blockers") + foundAnswer(Some(false), core = z3CoreToCore(solver.getUnsatCore)) + case Some(true) => + //reporter.debug("SAT WITHOUT Blockers") + if (this.feelingLucky && !interrupted) { + // we might have been lucky :D + val (wereWeLucky, cleanModel) = validateModel(solver.getModel, entireFormula, varsInVC, silenceErrors = true) + + if(wereWeLucky) { + foundAnswer(Some(true), cleanModel) } + } - case None => - foundAnswer(None) - } + case None => + foundAnswer(None) } + } - if(interrupted) { - foundAnswer(None) - } + if(interrupted) { + foundAnswer(None) + } - if(!foundDefinitiveAnswer) { - reporter.debug("- We need to keep going.") + if(!foundDefinitiveAnswer) { + reporter.debug("- We need to keep going.") - val toRelease = unrollingBank.getZ3BlockersToUnlock + val toRelease = unrollingBank.getZ3BlockersToUnlock - reporter.debug(" - more unrollings") + reporter.debug(" - more unrollings") - for(id <- toRelease) { - val newClauses = unrollingBank.unlock(id) + for(id <- toRelease) { + val newClauses = unrollingBank.unlock(id) - for(ncl <- newClauses) { - solver.assertCnstr(ncl) - } + for(ncl <- newClauses) { + solver.assertCnstr(ncl) } - - reporter.debug(" - finished unrolling") } - } - } - //reporter.debug(" !! DONE !! ") - - if(interrupted) { - None - } else { - definitiveAnswer + reporter.debug(" - finished unrolling") + } } } - if (program == null) { - reporter.error("Z3 Solver was not initialized with a PureScala Program.") + if(interrupted) { None + } else { + definitiveAnswer } } - } diff --git a/src/main/scala/leon/solvers/z3/FunctionTemplate.scala b/src/main/scala/leon/solvers/z3/FunctionTemplate.scala index 4e97ddada..5a3f2e2f4 100644 --- a/src/main/scala/leon/solvers/z3/FunctionTemplate.scala +++ b/src/main/scala/leon/solvers/z3/FunctionTemplate.scala @@ -19,7 +19,7 @@ import scala.collection.mutable.{Set=>MutableSet,Map=>MutableMap} case class Z3FunctionInvocation(funDef: FunDef, args: Seq[Z3AST]) class FunctionTemplate private( - solver: FairZ3SolverFactory, + solver: FairZ3Solver, val funDef : FunDef, activatingBool : Identifier, condVars : Set[Identifier], @@ -152,7 +152,7 @@ class FunctionTemplate private( object FunctionTemplate { val splitAndOrImplies = false - def mkTemplate(solver: FairZ3SolverFactory, funDef: FunDef, isRealFunDef : Boolean = true) : FunctionTemplate = { + def mkTemplate(solver: FairZ3Solver, funDef: FunDef, isRealFunDef : Boolean = true) : FunctionTemplate = { val condVars : MutableSet[Identifier] = MutableSet.empty val exprVars : MutableSet[Identifier] = MutableSet.empty diff --git a/src/main/scala/leon/solvers/z3/UninterpretedZ3SolverFactory.scala b/src/main/scala/leon/solvers/z3/UninterpretedZ3Solver.scala similarity index 56% rename from src/main/scala/leon/solvers/z3/UninterpretedZ3SolverFactory.scala rename to src/main/scala/leon/solvers/z3/UninterpretedZ3Solver.scala index 6589e3b50..78b75ee5e 100644 --- a/src/main/scala/leon/solvers/z3/UninterpretedZ3SolverFactory.scala +++ b/src/main/scala/leon/solvers/z3/UninterpretedZ3Solver.scala @@ -21,12 +21,10 @@ import purescala.TypeTrees._ * - otherwise it returns UNKNOWN * Results should come back very quickly. */ -class UninterpretedZ3SolverFactory(val context : LeonContext, val program: Program) +class UninterpretedZ3Solver(val context : LeonContext, val program: Program) extends AbstractZ3Solver with Z3ModelReconstruction { - enclosing => - val name = "Z3-u" val description = "Uninterpreted Z3 Solver" @@ -57,65 +55,56 @@ class UninterpretedZ3SolverFactory(val context : LeonContext, val program: Progr protected[leon] def functionDeclToDef(decl: Z3FuncDecl) : FunDef = reverseFunctionMap(decl) protected[leon] def isKnownDecl(decl: Z3FuncDecl) : Boolean = reverseFunctionMap.isDefinedAt(decl) - def getNewSolver = new Solver { - initZ3 - - val solver = z3.mkSolver + initZ3 - def push() { - solver.push - } + val solver = z3.mkSolver - def interrupt() { - enclosing.interrupt() - } + def push() { + solver.push + } - def recoverInterrupt() { - enclosing.recoverInterrupt() - } - def pop(lvl: Int = 1) { - solver.pop(lvl) - } + def pop(lvl: Int = 1) { + solver.pop(lvl) + } - private var variables = Set[Identifier]() - private var containsFunCalls = false + private var variables = Set[Identifier]() + private var containsFunCalls = false - def assertCnstr(expression: Expr) { - variables ++= variablesOf(expression) - containsFunCalls ||= containsFunctionCalls(expression) - solver.assertCnstr(toZ3Formula(expression).get) - } + def assertCnstr(expression: Expr) { + variables ++= variablesOf(expression) + containsFunCalls ||= containsFunctionCalls(expression) + solver.assertCnstr(toZ3Formula(expression).get) + } - def check: Option[Boolean] = { - solver.check match { - case Some(true) => - if (containsFunCalls) { - None - } else { - Some(true) - } - - case r => - r - } + def innerCheck: Option[Boolean] = { + solver.check match { + case Some(true) => + if (containsFunCalls) { + None + } else { + Some(true) + } + + case r => + r } + } - def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = { - variables ++= assumptions.flatMap(variablesOf(_)) - solver.checkAssumptions(assumptions.toSeq.map(toZ3Formula(_).get) : _*) - } + def innerCheckAssumptions(assumptions: Set[Expr]): Option[Boolean] = { + variables ++= assumptions.flatMap(variablesOf(_)) + solver.checkAssumptions(assumptions.toSeq.map(toZ3Formula(_).get) : _*) + } - def getModel = { - modelToMap(solver.getModel, variables) - } + def getModel = { + modelToMap(solver.getModel, variables) + } - def getUnsatCore = { - solver.getUnsatCore.map(ast => fromZ3Formula(null, ast, None) match { - case n @ Not(Variable(_)) => n - case v @ Variable(_) => v - case x => scala.sys.error("Impossible element extracted from core: " + ast + " (as Leon tree : " + x + ")") - }).toSet - } + def getUnsatCore = { + solver.getUnsatCore.map(ast => fromZ3Formula(null, ast, None) match { + case n @ Not(Variable(_)) => n + case v @ Variable(_) => v + case x => scala.sys.error("Impossible element extracted from core: " + ast + " (as Leon tree : " + x + ")") + }).toSet } } diff --git a/src/main/scala/leon/synthesis/Solution.scala b/src/main/scala/leon/synthesis/Solution.scala index 69d83cd55..782a72049 100644 --- a/src/main/scala/leon/synthesis/Solution.scala +++ b/src/main/scala/leon/synthesis/Solution.scala @@ -8,6 +8,7 @@ import purescala.TypeTrees.{TypeTree,TupleType} import purescala.Definitions._ import purescala.TreeOps._ import solvers.z3._ +import solvers._ // Defines a synthesis solution of the form: // ⟨ P | T ⟩ @@ -30,7 +31,7 @@ class Solution(val pre: Expr, val defs: Set[FunDef], val term: Expr) { defs.foldLeft(term){ case (t, fd) => LetDef(fd, t) } def toSimplifiedExpr(ctx: LeonContext, p: Program): Expr = { - val uninterpretedZ3 = new UninterpretedZ3SolverFactory(ctx, p) + val uninterpretedZ3 = SolverFactory(() => new UninterpretedZ3Solver(ctx, p)) val simplifiers = List[Expr => Expr]( simplifyTautologies(uninterpretedZ3)(_), diff --git a/src/main/scala/leon/synthesis/SynthesisContext.scala b/src/main/scala/leon/synthesis/SynthesisContext.scala index 4269f4175..adf89f1cc 100644 --- a/src/main/scala/leon/synthesis/SynthesisContext.scala +++ b/src/main/scala/leon/synthesis/SynthesisContext.scala @@ -20,17 +20,22 @@ case class SynthesisContext( reporter: Reporter ) { - def solverFactory: SolverFactory[Solver] = { - new FairZ3SolverFactory(context, program) + def newSolver: SynthesisContext.SynthesisSolver = { + new FairZ3Solver(context, program) } - def fastSolverFactory: SolverFactory[Solver] = { - new UninterpretedZ3SolverFactory(context, program) + def newFastSolver: SynthesisContext.SynthesisSolver = { + new UninterpretedZ3Solver(context, program) } + val solverFactory = SolverFactory(() => newSolver) + val fastSolverFactory = SolverFactory(() => newFastSolver) + } object SynthesisContext { + type SynthesisSolver = TimeoutAssumptionSolver with IncrementalSolver + def fromSynthesizer(synth: Synthesizer) = { SynthesisContext( synth.context, diff --git a/src/main/scala/leon/synthesis/Synthesizer.scala b/src/main/scala/leon/synthesis/Synthesizer.scala index b138fcc8a..78e5f20a2 100644 --- a/src/main/scala/leon/synthesis/Synthesizer.scala +++ b/src/main/scala/leon/synthesis/Synthesizer.scala @@ -81,10 +81,10 @@ class Synthesizer(val context : LeonContext, val (npr, fds) = solutionToProgram(sol) - val tsolver = new TimeoutSolverFactory(new FairZ3SolverFactory(context, npr), timeoutMs) + val solverf = SolverFactory(() => new FairZ3Solver(context, npr).setTimeout(timeoutMs)) val vcs = generateVerificationConditions(reporter, npr, fds.map(_.id.name)) - val vctx = VerificationContext(context, Seq(tsolver), context.reporter) + val vctx = VerificationContext(context, Seq(solverf), context.reporter) val vcreport = checkVerificationConditions(vctx, vcs) if (vcreport.totalValid == vcreport.totalConditions) { diff --git a/src/main/scala/leon/synthesis/condabd/SynthesizerExamples.scala b/src/main/scala/leon/synthesis/condabd/SynthesizerExamples.scala index beea692e6..dc557e56f 100755 --- a/src/main/scala/leon/synthesis/condabd/SynthesizerExamples.scala +++ b/src/main/scala/leon/synthesis/condabd/SynthesizerExamples.scala @@ -37,13 +37,14 @@ import verification._ import SynthesisInfo._ import SynthesisInfo.Action._ +import SynthesisContext.SynthesisSolver // enable postfix operations import scala.language.postfixOps class SynthesizerForRuleExamples( // some synthesis instance information - val mainSolver: SolverFactory[Solver], + val mainSolver: SolverFactory[SynthesisSolver], val program: Program, val desiredType: LeonType, val holeFunDef: FunDef, diff --git a/src/main/scala/leon/synthesis/condabd/rules/ConditionAbductionSynthesisTwoPhase.scala b/src/main/scala/leon/synthesis/condabd/rules/ConditionAbductionSynthesisTwoPhase.scala index 49cacd1f3..e6a486fa7 100755 --- a/src/main/scala/leon/synthesis/condabd/rules/ConditionAbductionSynthesisTwoPhase.scala +++ b/src/main/scala/leon/synthesis/condabd/rules/ConditionAbductionSynthesisTwoPhase.scala @@ -24,7 +24,7 @@ case object ConditionAbductionSynthesisTwoPhase extends Rule("Condition abductio List(new RuleInstantiation(p, this, SolutionBuilder.none, "Condition abduction") { def apply(sctx: SynthesisContext): RuleApplicationResult = { try { - val solver = sctx.solverFactory.withTimeout(500L) + val solver = SolverFactory(() => sctx.newSolver.setTimeout(500L)) val program = sctx.program val reporter = sctx.reporter diff --git a/src/main/scala/leon/synthesis/condabd/verification/AbstractVerifier.scala b/src/main/scala/leon/synthesis/condabd/verification/AbstractVerifier.scala index b9922e255..cf0625a5a 100644 --- a/src/main/scala/leon/synthesis/condabd/verification/AbstractVerifier.scala +++ b/src/main/scala/leon/synthesis/condabd/verification/AbstractVerifier.scala @@ -13,7 +13,7 @@ import purescala.Definitions._ import _root_.insynth.util.logging._ -abstract class AbstractVerifier(solverf: SolverFactory[Solver], p: Problem, synthInfo: SynthesisInfo) +abstract class AbstractVerifier(solverf: SolverFactory[Solver with IncrementalSolver with TimeoutSolver], p: Problem, synthInfo: SynthesisInfo) extends HasLogger { val solver = solverf.getNewSolver diff --git a/src/main/scala/leon/synthesis/condabd/verification/RelaxedVerifier.scala b/src/main/scala/leon/synthesis/condabd/verification/RelaxedVerifier.scala index 2dfa87843..2b3820915 100644 --- a/src/main/scala/leon/synthesis/condabd/verification/RelaxedVerifier.scala +++ b/src/main/scala/leon/synthesis/condabd/verification/RelaxedVerifier.scala @@ -10,10 +10,11 @@ import purescala.Extractors._ import purescala.TreeOps._ import purescala.TypeTrees._ import purescala.Definitions._ +import SynthesisContext.SynthesisSolver import _root_.insynth.util.logging._ -class RelaxedVerifier(solverf: SolverFactory[Solver], p: Problem, synthInfo: SynthesisInfo = new SynthesisInfo) +class RelaxedVerifier(solverf: SolverFactory[SynthesisSolver], p: Problem, synthInfo: SynthesisInfo = new SynthesisInfo) extends AbstractVerifier(solverf, p, synthInfo) with HasLogger { var _isTimeoutUsed = false diff --git a/src/main/scala/leon/synthesis/condabd/verification/Verifier.scala b/src/main/scala/leon/synthesis/condabd/verification/Verifier.scala index c1c74a38a..54ff32154 100644 --- a/src/main/scala/leon/synthesis/condabd/verification/Verifier.scala +++ b/src/main/scala/leon/synthesis/condabd/verification/Verifier.scala @@ -10,10 +10,11 @@ import purescala.Extractors._ import purescala.TreeOps._ import purescala.TypeTrees._ import purescala.Definitions._ +import SynthesisContext.SynthesisSolver import _root_.insynth.util.logging._ -class Verifier(solverf: SolverFactory[Solver], p: Problem, synthInfo: SynthesisInfo = new SynthesisInfo) +class Verifier(solverf: SolverFactory[SynthesisSolver], p: Problem, synthInfo: SynthesisInfo = new SynthesisInfo) extends AbstractVerifier(solverf, p, synthInfo) with HasLogger { import SynthesisInfo.Action._ diff --git a/src/main/scala/leon/synthesis/heuristics/ADTInduction.scala b/src/main/scala/leon/synthesis/heuristics/ADTInduction.scala index 35cad0762..724191dd0 100644 --- a/src/main/scala/leon/synthesis/heuristics/ADTInduction.scala +++ b/src/main/scala/leon/synthesis/heuristics/ADTInduction.scala @@ -14,14 +14,10 @@ import purescala.Definitions._ case object ADTInduction extends Rule("ADT Induction") with Heuristic { def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { - val tsolver = sctx.solverFactory.withTimeout(500L) - val candidates = p.as.collect { - case IsTyped(origId, AbstractClassType(cd)) if isInductiveOn(tsolver)(p.pc, origId) => (origId, cd) + case IsTyped(origId, AbstractClassType(cd)) if isInductiveOn(sctx.solverFactory)(p.pc, origId) => (origId, cd) } - tsolver.free() - val instances = for (candidate <- candidates) yield { val (origId, cd) = candidate val oas = p.as.filterNot(_ == origId) diff --git a/src/main/scala/leon/synthesis/heuristics/ADTLongInduction.scala b/src/main/scala/leon/synthesis/heuristics/ADTLongInduction.scala index e11f09448..0305450d8 100644 --- a/src/main/scala/leon/synthesis/heuristics/ADTLongInduction.scala +++ b/src/main/scala/leon/synthesis/heuristics/ADTLongInduction.scala @@ -14,13 +14,10 @@ import purescala.Definitions._ case object ADTLongInduction extends Rule("ADT Long Induction") with Heuristic { def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { - val tsolver = sctx.solverFactory.withTimeout(500L) val candidates = p.as.collect { - case IsTyped(origId, AbstractClassType(cd)) if isInductiveOn(tsolver)(p.pc, origId) => (origId, cd) + case IsTyped(origId, AbstractClassType(cd)) if isInductiveOn(sctx.solverFactory)(p.pc, origId) => (origId, cd) } - tsolver.free() - val instances = for (candidate <- candidates) yield { val (origId, cd) = candidate val oas = p.as.filterNot(_ == origId) diff --git a/src/main/scala/leon/synthesis/rules/ADTSplit.scala b/src/main/scala/leon/synthesis/rules/ADTSplit.scala index 261b37143..93f584ccd 100644 --- a/src/main/scala/leon/synthesis/rules/ADTSplit.scala +++ b/src/main/scala/leon/synthesis/rules/ADTSplit.scala @@ -14,7 +14,7 @@ import solvers._ case object ADTSplit extends Rule("ADT Split.") { def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation]= { - val solver = SimpleSolverAPI(sctx.solverFactory.withTimeout(200L)) + val solver = SimpleSolverAPI(new TimeoutSolverFactory(sctx.solverFactory, 200L)) val candidates = p.as.collect { case IsTyped(id, AbstractClassType(cd)) => @@ -46,8 +46,6 @@ case object ADTSplit extends Rule("ADT Split.") { } } - solver.free() - candidates.collect{ _ match { case Some((id, cases)) => val oas = p.as.filter(_ != id) diff --git a/src/main/scala/leon/synthesis/rules/Cegis.scala b/src/main/scala/leon/synthesis/rules/Cegis.scala index c1c53fc00..b62c45e97 100644 --- a/src/main/scala/leon/synthesis/rules/Cegis.scala +++ b/src/main/scala/leon/synthesis/rules/Cegis.scala @@ -441,20 +441,20 @@ case object CEGIS extends Rule("CEGIS") { var unrolings = 0 val maxUnrolings = 3 - val exSolver = sctx.solverFactory.withTimeout(3000L) // 3sec - val cexSolver = sctx.solverFactory.withTimeout(3000L) // 3sec + val exSolverTo = 3000L + val cexSolverTo = 3000L - try { - var baseExampleInputs: Seq[Seq[Expr]] = Seq() + var baseExampleInputs: Seq[Seq[Expr]] = Seq() - // We populate the list of examples with a predefined one - if (p.pc == BooleanLiteral(true)) { - baseExampleInputs = p.as.map(a => simplestValue(a.getType)) +: baseExampleInputs - } else { - val solver = exSolver.getNewSolver + // We populate the list of examples with a predefined one + if (p.pc == BooleanLiteral(true)) { + baseExampleInputs = p.as.map(a => simplestValue(a.getType)) +: baseExampleInputs + } else { + val solver = sctx.newSolver.setTimeout(exSolverTo) - solver.assertCnstr(p.pc) + solver.assertCnstr(p.pc) + try { solver.check match { case Some(true) => val model = solver.getModel @@ -467,36 +467,39 @@ case object CEGIS extends Rule("CEGIS") { sctx.reporter.warning("Solver could not solve path-condition") return RuleApplicationImpossible // This is not necessary though, but probably wanted } - - } - - val inputIterator: Iterator[Seq[Expr]] = if (useVanuatoo) { - new VanuatooDataGen(sctx.context, sctx.program).generateFor(p.as, p.pc, 20, 3000) - } else { - new NaiveDataGen(sctx.context, sctx.program, evaluator).generateFor(p.as, p.pc, 20, 1000) + } finally { + solver.free() } + } - val cachedInputIterator = new Iterator[Seq[Expr]] { - def next() = { - val i = inputIterator.next() - baseExampleInputs = i +: baseExampleInputs - i - } + val inputIterator: Iterator[Seq[Expr]] = if (useVanuatoo) { + new VanuatooDataGen(sctx.context, sctx.program).generateFor(p.as, p.pc, 20, 3000) + } else { + new NaiveDataGen(sctx.context, sctx.program, evaluator).generateFor(p.as, p.pc, 20, 1000) + } - def hasNext() = inputIterator.hasNext + val cachedInputIterator = new Iterator[Seq[Expr]] { + def next() = { + val i = inputIterator.next() + baseExampleInputs = i +: baseExampleInputs + i } - def hasInputExamples() = baseExampleInputs.size > 0 || cachedInputIterator.hasNext + def hasNext() = inputIterator.hasNext + } + + def hasInputExamples() = baseExampleInputs.size > 0 || cachedInputIterator.hasNext - def allInputExamples() = baseExampleInputs.iterator ++ cachedInputIterator + def allInputExamples() = baseExampleInputs.iterator ++ cachedInputIterator - def checkForPrograms(programs: Set[Set[Identifier]]): RuleApplicationResult = { - for (prog <- programs) { - val expr = ndProgram.determinize(prog) - val res = Equals(Tuple(p.xs.map(Variable(_))), expr) - val solver3 = cexSolver.getNewSolver - solver3.assertCnstr(And(p.pc :: res :: Not(p.phi) :: Nil)) + def checkForPrograms(programs: Set[Set[Identifier]]): RuleApplicationResult = { + for (prog <- programs) { + val expr = ndProgram.determinize(prog) + val res = Equals(Tuple(p.xs.map(Variable(_))), expr) + val solver3 = sctx.newSolver.setTimeout(cexSolverTo) + solver3.assertCnstr(And(p.pc :: res :: Not(p.phi) :: Nil)) + try { solver3.check match { case Some(false) => return RuleSuccess(Solution(BooleanLiteral(true), Set(), expr), isTrusted = true) @@ -505,249 +508,254 @@ case object CEGIS extends Rule("CEGIS") { case Some(true) => // invalid program, we skip } + } finally { + solver3.free() } - - RuleApplicationImpossible } - // Keep track of collected cores to filter programs to test - var collectedCores = Set[Set[Identifier]]() - - val initExClause = And(p.pc :: p.phi :: Variable(initGuard) :: Nil) - val initCExClause = And(p.pc :: Not(p.phi) :: Variable(initGuard) :: Nil) + RuleApplicationImpossible + } - // solver1 is used for the initial SAT queries - var solver1 = exSolver.getNewSolver - solver1.assertCnstr(initExClause) + // Keep track of collected cores to filter programs to test + var collectedCores = Set[Set[Identifier]]() - // solver2 is used for validating a candidate program, or finding new inputs - var solver2 = cexSolver.getNewSolver - solver2.assertCnstr(initCExClause) + val initExClause = And(p.pc :: p.phi :: Variable(initGuard) :: Nil) + val initCExClause = And(p.pc :: Not(p.phi) :: Variable(initGuard) :: Nil) - var didFilterAlready = false + // solver1 is used for the initial SAT queries + var solver1 = sctx.newSolver.setTimeout(exSolverTo) + solver1.assertCnstr(initExClause) - val tpe = TupleType(p.xs.map(_.getType)) + // solver2 is used for validating a candidate program, or finding new inputs + var solver2 = sctx.newSolver.setTimeout(cexSolverTo) + solver2.assertCnstr(initCExClause) - try { - do { - var needMoreUnrolling = false + var didFilterAlready = false - var bssAssumptions = Set[Identifier]() + val tpe = TupleType(p.xs.map(_.getType)) - if (!didFilterAlready) { - val (clauses, closedBs) = ndProgram.unroll + try { + do { + var needMoreUnrolling = false - bssAssumptions = closedBs + var bssAssumptions = Set[Identifier]() - sctx.reporter.ifDebug { debug => - debug("UNROLLING: ") - for (c <- clauses) { - debug(" - " + c) - } - debug("CLOSED Bs "+closedBs) - } + if (!didFilterAlready) { + val (clauses, closedBs) = ndProgram.unroll - val clause = And(clauses) + bssAssumptions = closedBs - solver1.assertCnstr(clause) - solver2.assertCnstr(clause) + sctx.reporter.ifDebug { debug => + debug("UNROLLING: ") + for (c <- clauses) { + debug(" - " + c) + } + debug("CLOSED Bs "+closedBs) } - // Compute all programs that have not been excluded yet - var prunedPrograms: Set[Set[Identifier]] = if (useCEPruning) { - ndProgram.allPrograms.filterNot(p => collectedCores.exists(c => c.subsetOf(p))) - } else { - Set() - } + val clause = And(clauses) - val allPrograms = prunedPrograms.size + solver1.assertCnstr(clause) + solver2.assertCnstr(clause) + } - sctx.reporter.debug("#Programs: "+prunedPrograms.size) + // Compute all programs that have not been excluded yet + var prunedPrograms: Set[Set[Identifier]] = if (useCEPruning) { + ndProgram.allPrograms.filterNot(p => collectedCores.exists(c => c.subsetOf(p))) + } else { + Set() + } - // We further filter the set of working programs to remove those that fail on known examples - if (useCEPruning && hasInputExamples() && ndProgram.canTest()) { + val allPrograms = prunedPrograms.size - for (p <- prunedPrograms) { - if (!allInputExamples().forall(ndProgram.testForProgram(p))) { - // This program failed on at least one example - solver1.assertCnstr(Not(And(p.map(Variable(_)).toSeq))) - prunedPrograms -= p - } - } + sctx.reporter.debug("#Programs: "+prunedPrograms.size) - if (prunedPrograms.isEmpty) { - needMoreUnrolling = true - } + // We further filter the set of working programs to remove those that fail on known examples + if (useCEPruning && hasInputExamples() && ndProgram.canTest()) { - //println("Passing tests: "+prunedPrograms.size) + for (p <- prunedPrograms) { + if (!allInputExamples().forall(ndProgram.testForProgram(p))) { + // This program failed on at least one example + solver1.assertCnstr(Not(And(p.map(Variable(_)).toSeq))) + prunedPrograms -= p + } } - val nPassing = prunedPrograms.size - - sctx.reporter.debug("#Programs passing tests: "+nPassing) - - if (nPassing == 0) { - needMoreUnrolling = true; - } else if (nPassing <= testUpTo) { - // Immediate Test - result = Some(checkForPrograms(prunedPrograms)) - } else if (((nPassing < allPrograms*filterThreshold) || didFilterAlready) && useBssFiltering) { - // We filter the Bss so that the formula we give to z3 is much smalled - val bssToKeep = prunedPrograms.foldLeft(Set[Identifier]())(_ ++ _) - - // Cannot unroll normally after having filtered, so we need to - // repeat the filtering procedure at next unrolling. - didFilterAlready = true - - // Freshening solvers - solver1 = exSolver.getNewSolver - solver1.assertCnstr(initExClause) - solver2 = cexSolver.getNewSolver - solver2.assertCnstr(initCExClause) - - val clauses = ndProgram.filterFor(bssToKeep) - val clause = And(clauses) - - solver1.assertCnstr(clause) - solver2.assertCnstr(clause) + if (prunedPrograms.isEmpty) { + needMoreUnrolling = true } - val bss = ndProgram.bss + //println("Passing tests: "+prunedPrograms.size) + } - while (result.isEmpty && !needMoreUnrolling && !interruptManager.isInterrupted()) { + val nPassing = prunedPrograms.size + + sctx.reporter.debug("#Programs passing tests: "+nPassing) + + if (nPassing == 0) { + needMoreUnrolling = true; + } else if (nPassing <= testUpTo) { + // Immediate Test + result = Some(checkForPrograms(prunedPrograms)) + } else if (((nPassing < allPrograms*filterThreshold) || didFilterAlready) && useBssFiltering) { + // We filter the Bss so that the formula we give to z3 is much smalled + val bssToKeep = prunedPrograms.foldLeft(Set[Identifier]())(_ ++ _) + + // Cannot unroll normally after having filtered, so we need to + // repeat the filtering procedure at next unrolling. + didFilterAlready = true + + // Freshening solvers + solver1.free() + solver1 = sctx.newSolver.setTimeout(exSolverTo) + solver1.assertCnstr(initExClause) + + solver2.free() + solver2 = sctx.newSolver.setTimeout(cexSolverTo) + solver2.assertCnstr(initCExClause) + + val clauses = ndProgram.filterFor(bssToKeep) + val clause = And(clauses) + + solver1.assertCnstr(clause) + solver2.assertCnstr(clause) + } - solver1.checkAssumptions(bssAssumptions.map(id => Not(Variable(id)))) match { - case Some(true) => - val satModel = solver1.getModel + val bss = ndProgram.bss - val bssAssumptions: Set[Expr] = bss.map(b => satModel(b) match { - case BooleanLiteral(true) => Variable(b) - case BooleanLiteral(false) => Not(Variable(b)) - }) + while (result.isEmpty && !needMoreUnrolling && !interruptManager.isInterrupted()) { - val validateWithZ3 = if (useCETests && hasInputExamples() && ndProgram.canTest()) { + solver1.checkAssumptions(bssAssumptions.map(id => Not(Variable(id)))) match { + case Some(true) => + val satModel = solver1.getModel - val p = bssAssumptions.collect { case Variable(b) => b } + val bssAssumptions: Set[Expr] = bss.map(b => satModel(b) match { + case BooleanLiteral(true) => Variable(b) + case BooleanLiteral(false) => Not(Variable(b)) + }) - if (allInputExamples().forall(ndProgram.testForProgram(p))) { - // All valid inputs also work with this, we need to - // make sure by validating this candidate with z3 - true - } else { - // One valid input failed with this candidate, we can skip - solver1.assertCnstr(Not(And(p.map(Variable(_)).toSeq))) - false - } - } else { - // No inputs or capability to test, we need to ask Z3 + val validateWithZ3 = if (useCETests && hasInputExamples() && ndProgram.canTest()) { + + val p = bssAssumptions.collect { case Variable(b) => b } + + if (allInputExamples().forall(ndProgram.testForProgram(p))) { + // All valid inputs also work with this, we need to + // make sure by validating this candidate with z3 true + } else { + // One valid input failed with this candidate, we can skip + solver1.assertCnstr(Not(And(p.map(Variable(_)).toSeq))) + false } + } else { + // No inputs or capability to test, we need to ask Z3 + true + } - if (validateWithZ3) { - solver2.checkAssumptions(bssAssumptions) match { - case Some(true) => - val invalidModel = solver2.getModel + if (validateWithZ3) { + solver2.checkAssumptions(bssAssumptions) match { + case Some(true) => + val invalidModel = solver2.getModel - val fixedAss = And(ass.collect { - case a if invalidModel contains a => Equals(Variable(a), invalidModel(a)) - }.toSeq) + val fixedAss = And(ass.collect { + case a if invalidModel contains a => Equals(Variable(a), invalidModel(a)) + }.toSeq) - val newCE = p.as.map(valuateWithModel(invalidModel)) + val newCE = p.as.map(valuateWithModel(invalidModel)) - baseExampleInputs = newCE +: baseExampleInputs + baseExampleInputs = newCE +: baseExampleInputs - // Retest whether the newly found C-E invalidates all programs - if (useCEPruning && ndProgram.canTest) { - if (prunedPrograms.forall(p => !ndProgram.testForProgram(p)(newCE))) { - needMoreUnrolling = true - } + // Retest whether the newly found C-E invalidates all programs + if (useCEPruning && ndProgram.canTest) { + if (prunedPrograms.forall(p => !ndProgram.testForProgram(p)(newCE))) { + needMoreUnrolling = true } + } - val unsatCore = if (useUnsatCores) { - solver1.push() - solver1.assertCnstr(fixedAss) + val unsatCore = if (useUnsatCores) { + solver1.push() + solver1.assertCnstr(fixedAss) - val core = solver1.checkAssumptions(bssAssumptions) match { - case Some(false) => - // Core might be empty if unrolling level is - // insufficient, it becomes unsat no matter what - // the assumptions are. - solver1.getUnsatCore + val core = solver1.checkAssumptions(bssAssumptions) match { + case Some(false) => + // Core might be empty if unrolling level is + // insufficient, it becomes unsat no matter what + // the assumptions are. + solver1.getUnsatCore - case Some(true) => - // Can't be! - bssAssumptions + case Some(true) => + // Can't be! + bssAssumptions - case None => - return RuleApplicationImpossible - } + case None => + return RuleApplicationImpossible + } - solver1.pop() + solver1.pop() - collectedCores += core.collect{ case Variable(id) => id } + collectedCores += core.collect{ case Variable(id) => id } - core - } else { - bssAssumptions - } + core + } else { + bssAssumptions + } - if (unsatCore.isEmpty) { - needMoreUnrolling = true - } else { - solver1.assertCnstr(Not(And(unsatCore.toSeq))) - } + if (unsatCore.isEmpty) { + needMoreUnrolling = true + } else { + solver1.assertCnstr(Not(And(unsatCore.toSeq))) + } - case Some(false) => + case Some(false) => - val expr = ndProgram.determinize(satModel.filter(_._2 == BooleanLiteral(true)).keySet) + val expr = ndProgram.determinize(satModel.filter(_._2 == BooleanLiteral(true)).keySet) - result = Some(RuleSuccess(Solution(BooleanLiteral(true), Set(), expr))) + result = Some(RuleSuccess(Solution(BooleanLiteral(true), Set(), expr))) - case _ => - if (useOptTimeout) { - // Interpret timeout in CE search as "the candidate is valid" - sctx.reporter.info("CEGIS could not prove the validity of the resulting expression") - val expr = ndProgram.determinize(satModel.filter(_._2 == BooleanLiteral(true)).keySet) - result = Some(RuleSuccess(Solution(BooleanLiteral(true), Set(), expr), isTrusted = false)) - } else { - return RuleApplicationImpossible - } - } + case _ => + if (useOptTimeout) { + // Interpret timeout in CE search as "the candidate is valid" + sctx.reporter.info("CEGIS could not prove the validity of the resulting expression") + val expr = ndProgram.determinize(satModel.filter(_._2 == BooleanLiteral(true)).keySet) + result = Some(RuleSuccess(Solution(BooleanLiteral(true), Set(), expr), isTrusted = false)) + } else { + return RuleApplicationImpossible + } } + } - case Some(false) => - if (useUninterpretedProbe) { - solver1.check match { - case Some(false) => - // Unsat even without blockers (under which fcalls are then uninterpreted) - return RuleApplicationImpossible + case Some(false) => + if (useUninterpretedProbe) { + solver1.check match { + case Some(false) => + // Unsat even without blockers (under which fcalls are then uninterpreted) + return RuleApplicationImpossible - case _ => - } + case _ => } + } - needMoreUnrolling = true + needMoreUnrolling = true - case _ => - // Last chance, we test first few programs - return checkForPrograms(prunedPrograms.take(testUpTo)) - } + case _ => + // Last chance, we test first few programs + return checkForPrograms(prunedPrograms.take(testUpTo)) } + } - unrolings += 1 - } while(unrolings < maxUnrolings && result.isEmpty && !interruptManager.isInterrupted()) + unrolings += 1 + } while(unrolings < maxUnrolings && result.isEmpty && !interruptManager.isInterrupted()) - result.getOrElse(RuleApplicationImpossible) + result.getOrElse(RuleApplicationImpossible) - } catch { - case e: Throwable => - sctx.reporter.warning("CEGIS crashed: "+e.getMessage) - RuleApplicationImpossible - } + } catch { + case e: Throwable => + sctx.reporter.warning("CEGIS crashed: "+e.getMessage) + RuleApplicationImpossible } finally { - exSolver.free() - cexSolver.free() + solver1.free() + solver2.free() } } }) diff --git a/src/main/scala/leon/synthesis/rules/EqualitySplit.scala b/src/main/scala/leon/synthesis/rules/EqualitySplit.scala index 26867fd76..6e867ce1d 100644 --- a/src/main/scala/leon/synthesis/rules/EqualitySplit.scala +++ b/src/main/scala/leon/synthesis/rules/EqualitySplit.scala @@ -40,8 +40,6 @@ case object EqualitySplit extends Rule("Eq. Split") { case _ => false }).values.flatten - solver.free() - candidates.flatMap(_ match { case List(a1, a2) => diff --git a/src/main/scala/leon/synthesis/rules/Ground.scala b/src/main/scala/leon/synthesis/rules/Ground.scala index 31bfcefad..da7c0d651 100644 --- a/src/main/scala/leon/synthesis/rules/Ground.scala +++ b/src/main/scala/leon/synthesis/rules/Ground.scala @@ -14,7 +14,7 @@ case object Ground extends Rule("Ground") { def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { if (p.as.isEmpty) { - val solver = SimpleSolverAPI(sctx.solverFactory.withTimeout(5000L)) // We give that 5s + val solver = SimpleSolverAPI(new TimeoutSolverFactory(sctx.solverFactory, 5000L)) val tpe = TupleType(p.xs.map(_.getType)) @@ -29,8 +29,6 @@ case object Ground extends Rule("Ground") { None } - solver.free() - result } else { None diff --git a/src/main/scala/leon/synthesis/rules/InequalitySplit.scala b/src/main/scala/leon/synthesis/rules/InequalitySplit.scala index a2a295216..b64dc5335 100644 --- a/src/main/scala/leon/synthesis/rules/InequalitySplit.scala +++ b/src/main/scala/leon/synthesis/rules/InequalitySplit.scala @@ -52,8 +52,6 @@ case object InequalitySplit extends Rule("Ineq. Split.") { case _ => false } - solver.free() - candidates.flatMap(_ match { case List(a1, a2) => diff --git a/src/main/scala/leon/synthesis/rules/OptimisticGround.scala b/src/main/scala/leon/synthesis/rules/OptimisticGround.scala index 6085007e3..3e2fcfcca 100644 --- a/src/main/scala/leon/synthesis/rules/OptimisticGround.scala +++ b/src/main/scala/leon/synthesis/rules/OptimisticGround.scala @@ -69,8 +69,6 @@ case object OptimisticGround extends Rule("Optimistic Ground") { i += 1 } - solver.free() - result.getOrElse(RuleApplicationImpossible) } } diff --git a/src/main/scala/leon/testgen/CallGraph.scala b/src/main/scala/leon/testgen/CallGraph.scala index 3c22b2334..a8e4f367e 100644 --- a/src/main/scala/leon/testgen/CallGraph.scala +++ b/src/main/scala/leon/testgen/CallGraph.scala @@ -168,7 +168,7 @@ class CallGraph(val program: Program) { fd.annotations.exists(_ == "main") } - def findAllPaths(z3Solver: FairZ3SolverFactory): Set[Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]] = { + def findAllPaths(z3Solverf: SolverFactory[Solver]): Set[Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]] = { val waypoints: Set[ProgramPoint] = programPoints.filter{ case ExpressionPoint(Waypoint(_, _), _) => true case _ => false } val sortedWaypoints: Seq[ProgramPoint] = waypoints.toSeq.sortWith((p1, p2) => { val (ExpressionPoint(Waypoint(i1, _), _), ExpressionPoint(Waypoint(i2, _), _)) = (p1, p2) @@ -183,7 +183,7 @@ class CallGraph(val program: Program) { if(sortedWaypoints.size == 0) { findSimplePaths(mainPoint.get) } else { - visitAllWaypoints(mainPoint.get :: sortedWaypoints.toList, z3Solver) match { + visitAllWaypoints(mainPoint.get :: sortedWaypoints.toList, z3Solverf) match { case None => Set() case Some(p) => Set(p) } @@ -194,7 +194,7 @@ class CallGraph(val program: Program) { } } - def visitAllWaypoints(waypoints: List[ProgramPoint], z3Solver: FairZ3SolverFactory): Option[Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]] = { + def visitAllWaypoints(waypoints: List[ProgramPoint], z3Solverf: SolverFactory[Solver]): Option[Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]] = { def rec(head: ProgramPoint, tail: List[ProgramPoint], path: Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]): Option[Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]] = { tail match { @@ -204,11 +204,10 @@ class CallGraph(val program: Program) { var completePath: Option[Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]] = None allPaths.find(intermediatePath => { val pc = pathConstraint(path ++ intermediatePath) - z3Solver.restartZ3 var testcase: Option[Map[Identifier, Expr]] = None - val (solverResult, model) = SimpleSolverAPI(z3Solver).solveSAT(pc) + val (solverResult, model) = SimpleSolverAPI(z3Solverf).solveSAT(pc) solverResult match { case None => { false diff --git a/src/main/scala/leon/testgen/TestGeneration.scala b/src/main/scala/leon/testgen/TestGeneration.scala index 321388e67..e08e7b172 100644 --- a/src/main/scala/leon/testgen/TestGeneration.scala +++ b/src/main/scala/leon/testgen/TestGeneration.scala @@ -27,7 +27,6 @@ class TestGeneration(context : LeonContext) { private val reporter = context.reporter def analyse(program: Program) { - val z3Solver = new FairZ3SolverFactory(context, program) reporter.info("Running test generation") val testcases = generateTestCases(program) @@ -60,11 +59,11 @@ class TestGeneration(context : LeonContext) { } def generatePathConditions(program: Program): Set[Expr] = { - val z3Solver = new FairZ3SolverFactory(context, program) + val z3Solverf = SolverFactory( () => new FairZ3Solver(context, program)) val callGraph = new CallGraph(program) callGraph.writeDotFile("testgen.dot") - val constraints = callGraph.findAllPaths(z3Solver).map(path => { + val constraints = callGraph.findAllPaths(z3Solverf).map(path => { println("Path is: " + path) val cnstr = callGraph.pathConstraint(path) println("constraint is: " + cnstr) @@ -75,15 +74,14 @@ class TestGeneration(context : LeonContext) { private def generateTestCases(program: Program): Set[Map[Identifier, Expr]] = { val allPaths = generatePathConditions(program) - val z3Solver = new FairZ3SolverFactory(context, program) + val z3Solverf = SolverFactory( () => new FairZ3Solver(context, program)) allPaths.flatMap(pathCond => { reporter.info("Now considering path condition: " + pathCond) var testcase: Option[Map[Identifier, Expr]] = None - z3Solver.restartZ3 - val (solverResult, model) = SimpleSolverAPI(z3Solver).solveSAT(pathCond) + val (solverResult, model) = SimpleSolverAPI(z3Solverf).solveSAT(pathCond) solverResult match { case None => Seq() diff --git a/src/main/scala/leon/verification/AnalysisPhase.scala b/src/main/scala/leon/verification/AnalysisPhase.scala index 3a9d16e32..76a7ea3c2 100644 --- a/src/main/scala/leon/verification/AnalysisPhase.scala +++ b/src/main/scala/leon/verification/AnalysisPhase.scala @@ -11,7 +11,6 @@ import purescala.TypeTrees._ import solvers._ import solvers.z3._ -import solvers.bapaminmax._ import solvers.combinators._ import scala.collection.mutable.{Set => MutableSet} @@ -70,62 +69,65 @@ object AnalysisPhase extends LeonPhase[Program,VerificationReport] { for((funDef, vcs) <- vcs.toSeq.sortWith((a,b) => a._1 < b._1); vcInfo <- vcs if !interruptManager.isInterrupted()) { val funDef = vcInfo.funDef val vc = vcInfo.condition - - val time0 : Long = System.currentTimeMillis - val time1 = System.currentTimeMillis - reporter.info("Now considering '" + vcInfo.kind + "' VC for " + funDef.id + "...") reporter.debug("Verification condition (" + vcInfo.kind + ") for ==== " + funDef.id + " ====") reporter.debug(simplifyLets(vc)) // try all solvers until one returns a meaningful answer - solvers.find(se => { - reporter.debug("Trying with solver: " + se.name) - val t1 = System.nanoTime - val (satResult, counterexample) = SimpleSolverAPI(se).solveSAT(Not(vc)) - val solverResult = satResult.map(!_) - - val t2 = System.nanoTime - val dt = ((t2 - t1) / 1000000) / 1000.0 - - solverResult match { - case _ if interruptManager.isInterrupted() => - reporter.info("=== CANCELLED ===") - vcInfo.time = Some(dt) - false - - case None => - vcInfo.time = Some(dt) - false - - case Some(true) => - reporter.info("==== VALID ====") - - vcInfo.hasValue = true - vcInfo.value = Some(true) - vcInfo.solvedWith = Some(se) - vcInfo.time = Some(dt) - true - - case Some(false) => - reporter.error("Found counter-example : ") - reporter.error(counterexample.toSeq.sortBy(_._1.name).map(p => p._1 + " -> " + p._2).mkString("\n")) - reporter.error("==== INVALID ====") + solvers.find(sf => { + val s = sf.getNewSolver + try { + reporter.debug("Trying with solver: " + s.name) + val t1 = System.nanoTime + s.assertCnstr(Not(vc)) + + val satResult = s.check + val counterexample: Map[Identifier, Expr] = if (satResult == Some(true)) s.getModel else Map() + val solverResult = satResult.map(!_) + + val t2 = System.nanoTime + val dt = ((t2 - t1) / 1000000) / 1000.0 + + solverResult match { + case _ if interruptManager.isInterrupted() => + reporter.info("=== CANCELLED ===") + vcInfo.time = Some(dt) + false + + case None => + vcInfo.time = Some(dt) + false + + case Some(true) => + reporter.info("==== VALID ====") + + vcInfo.hasValue = true + vcInfo.value = Some(true) + vcInfo.solvedWith = Some(s) + vcInfo.time = Some(dt) + true + + case Some(false) => + reporter.error("Found counter-example : ") + reporter.error(counterexample.toSeq.sortBy(_._1.name).map(p => p._1 + " -> " + p._2).mkString("\n")) + reporter.error("==== INVALID ====") + vcInfo.hasValue = true + vcInfo.value = Some(false) + vcInfo.solvedWith = Some(s) + vcInfo.counterExample = Some(counterexample) + vcInfo.time = Some(dt) + true + } + } finally { + s.free() + }}) match { + case None => { vcInfo.hasValue = true - vcInfo.value = Some(false) - vcInfo.solvedWith = Some(se) - vcInfo.counterExample = Some(counterexample) - vcInfo.time = Some(dt) - true + reporter.warning("==== UNKNOWN ====") + } + case _ => } - }) match { - case None => { - vcInfo.hasValue = true - reporter.warning("==== UNKNOWN ====") - } - case _ => - } } val report = new VerificationReport(vcs) @@ -148,33 +150,23 @@ object AnalysisPhase extends LeonPhase[Program,VerificationReport] { val reporter = ctx.reporter - lazy val fairZ3 = new FairZ3SolverFactory(ctx, program) - - val baseSolvers : Seq[SolverFactory[Solver]] = { - fairZ3 :: Nil - } - - val solvers: Seq[SolverFactory[Solver]] = timeout match { - case Some(t) => - baseSolvers.map(_.withTimeout(100L*t)) + val baseFactories = Seq( + SolverFactory(() => new FairZ3Solver(ctx, program)) + ) + val solverFactories = timeout match { + case Some(sec) => + baseFactories.map { sf => + new TimeoutSolverFactory(sf, sec*1000L) + } case None => - baseSolvers + baseFactories } - val vctx = VerificationContext(ctx, solvers, reporter) + val vctx = VerificationContext(ctx, solverFactories, reporter) - val report = if(solvers.size >= 1) { - reporter.debug("Running verification condition generation...") - val vcs = generateVerificationConditions(reporter, program, functionsToAnalyse) - checkVerificationConditions(vctx, vcs) - } else { - reporter.warning("No solver specified. Cannot test verification conditions.") - VerificationReport.emptyReport - } - - solvers.foreach(_.free()) - - report + reporter.debug("Running verification condition generation...") + val vcs = generateVerificationConditions(reporter, program, functionsToAnalyse) + checkVerificationConditions(vctx, vcs) } } diff --git a/src/main/scala/leon/verification/VerificationCondition.scala b/src/main/scala/leon/verification/VerificationCondition.scala index 776a4c004..e82d61cf1 100644 --- a/src/main/scala/leon/verification/VerificationCondition.scala +++ b/src/main/scala/leon/verification/VerificationCondition.scala @@ -15,7 +15,7 @@ class VerificationCondition(val condition: Expr, val funDef: FunDef, val kind: V // Some(false) = valid var hasValue = false var value : Option[Boolean] = None - var solvedWith : Option[SolverFactory[Solver]] = None + var solvedWith : Option[Solver] = None var time : Option[Double] = None var counterExample : Option[Map[Identifier, Expr]] = None diff --git a/src/test/scala/leon/test/condabd/VariableSolverRefinerTest.scala b/src/test/scala/leon/test/condabd/VariableSolverRefinerTest.scala index dc1d945ca..5fe37005b 100644 --- a/src/test/scala/leon/test/condabd/VariableSolverRefinerTest.scala +++ b/src/test/scala/leon/test/condabd/VariableSolverRefinerTest.scala @@ -86,8 +86,6 @@ class VariableSolverRefinerTest extends FunSpec with GivenWhenThen { ) (res2._1, res2._2) } - - solver.free() } } @@ -144,7 +142,6 @@ class VariableSolverRefinerTest extends FunSpec with GivenWhenThen { ) (res2._1, res2._2) } - solver.free() } } @@ -204,8 +201,6 @@ class VariableSolverRefinerTest extends FunSpec with GivenWhenThen { ) (res2._1, res2._2) } - - solver.free() } } diff --git a/src/test/scala/leon/test/condabd/VerifierTest.scala b/src/test/scala/leon/test/condabd/VerifierTest.scala index de88e347c..b5d5c3fda 100644 --- a/src/test/scala/leon/test/condabd/VerifierTest.scala +++ b/src/test/scala/leon/test/condabd/VerifierTest.scala @@ -23,133 +23,135 @@ class VerifierTest extends FunSpec { import Utils._ import Scaffold._ - val lesynthTestDir = "testcases/condabd/test/lesynth" - + val lesynthTestDir = "testcases/condabd/test/lesynth" + def getPostconditionFunction(problem: Problem) = { (list: Iterable[Identifier]) => { (problem.phi /: list) { case ((res, newId)) => - (res /: problem.as.find(_.name == newId.name)) { - case ((res, oldId)) => - TreeOps.replace(Map(Variable(oldId) -> Variable(newId)), res) - } + (res /: problem.as.find(_.name == newId.name)) { + case ((res, oldId)) => + TreeOps.replace(Map(Variable(oldId) -> Variable(newId)), res) + } } } } - - describe("Concrete verifier: ") { + + describe("Concrete verifier: ") { val testCaseFileName = lesynthTestDir + "/ListConcatVerifierTest.scala" val problems = forFile(testCaseFileName) assert(problems.size == 1) - for ((sctx, funDef, problem) <- problems) { - - val timeoutSolver = sctx.solverFactory.withTimeout(2000L) - - val getNewPostcondition = getPostconditionFunction(problem) - - describe("A Verifier") { - it("should verify first correct concat body") { - val newFunDef = getFunDefByName(sctx.program, "goodConcat1") - funDef.body = newFunDef.body + for ((sctx, funDef, problem) <- problems) { + + val timeoutSolver = SolverFactory(() => sctx.newSolver.setTimeout(2000L)) + + + val getNewPostcondition = getPostconditionFunction(problem) + + describe("A Verifier") { + it("should verify first correct concat body") { + val newFunDef = getFunDefByName(sctx.program, "goodConcat1") + funDef.body = newFunDef.body - expectResult(1) { problem.xs.size } - funDef.postcondition = Some((problem.xs.head, getNewPostcondition(newFunDef.args.map(_.id)))) - funDef.precondition = Some(BooleanLiteral(true)) - - val verifier = new Verifier(timeoutSolver, problem) - - assert( verifier.analyzeFunction(funDef)._1 ) - } - - it("should verify 2nd correct concat body") { - val newFunDef = getFunDefByName(sctx.program, "goodConcat2") - funDef.body = newFunDef.body + expectResult(1) { problem.xs.size } + funDef.postcondition = Some((problem.xs.head, getNewPostcondition(newFunDef.args.map(_.id)))) + funDef.precondition = Some(BooleanLiteral(true)) + + val verifier = new Verifier(timeoutSolver, problem) + + assert( verifier.analyzeFunction(funDef)._1 ) + verifier.solver.free() + } + + it("should verify 2nd correct concat body") { + val newFunDef = getFunDefByName(sctx.program, "goodConcat2") + funDef.body = newFunDef.body - expectResult(1) { problem.xs.size } - funDef.postcondition = Some((problem.xs.head, getNewPostcondition(newFunDef.args.map(_.id)))) - funDef.precondition = Some(BooleanLiteral(true)) - - val verifier = new Verifier(timeoutSolver, problem) - - assert( verifier.analyzeFunction(funDef)._1 ) - } - - it("should not verify an incorrect concat body") { - val newFunDef = getFunDefByName(sctx.program, "badConcat1") - funDef.body = newFunDef.body + expectResult(1) { problem.xs.size } + funDef.postcondition = Some((problem.xs.head, getNewPostcondition(newFunDef.args.map(_.id)))) + funDef.precondition = Some(BooleanLiteral(true)) + + val verifier = new Verifier(timeoutSolver, problem) + + assert( verifier.analyzeFunction(funDef)._1 ) + verifier.solver.free() + } + + it("should not verify an incorrect concat body") { + val newFunDef = getFunDefByName(sctx.program, "badConcat1") + funDef.body = newFunDef.body - expectResult(1) { problem.xs.size } - funDef.postcondition = Some((problem.xs.head, getNewPostcondition(newFunDef.args.map(_.id)))) - funDef.precondition = Some(BooleanLiteral(true)) - - val verifier = new Verifier(timeoutSolver, problem) - - assert( ! verifier.analyzeFunction(funDef)._1 ) - } - } - - timeoutSolver.free - } - } + expectResult(1) { problem.xs.size } + funDef.postcondition = Some((problem.xs.head, getNewPostcondition(newFunDef.args.map(_.id)))) + funDef.precondition = Some(BooleanLiteral(true)) + + val verifier = new Verifier(timeoutSolver, problem) + + assert( ! verifier.analyzeFunction(funDef)._1 ) + verifier.solver.free() + } + } + } + } def getPreconditionFunction(problem: Problem) = { (list: Iterable[Identifier]) => { (problem.pc /: list) { case ((res, newId)) => - (res /: problem.as.find(_.name == newId.name)) { - case ((res, oldId)) => - TreeOps.replace(Map(Variable(oldId) -> Variable(newId)), res) - } + (res /: problem.as.find(_.name == newId.name)) { + case ((res, oldId)) => + TreeOps.replace(Map(Variable(oldId) -> Variable(newId)), res) + } } } } - describe("Relaxed verifier: ") { + describe("Relaxed verifier: ") { val testCaseFileName = lesynthTestDir + "/BinarySearchTree.scala" val problems = forFile(testCaseFileName) assert(problems.size == 1) - for ((sctx, funDef, problem) <- problems) { - - val timeoutSolver = sctx.solverFactory.withTimeout(1000L) - - val getNewPostcondition = getPostconditionFunction(problem) - val getNewPrecondition = getPreconditionFunction(problem) - - describe("A RelaxedVerifier on BST") { - it("should verify a correct member body") { - val newFunDef = getFunDefByName(sctx.program, "goodMember") - funDef.body = newFunDef.body + for ((sctx, funDef, problem) <- problems) { + + val timeoutSolver = SolverFactory(() => sctx.newSolver.setTimeout(1000L)) + + val getNewPostcondition = getPostconditionFunction(problem) + val getNewPrecondition = getPreconditionFunction(problem) + + describe("A RelaxedVerifier on BST") { + it("should verify a correct member body") { + val newFunDef = getFunDefByName(sctx.program, "goodMember") + funDef.body = newFunDef.body - expectResult(1) { problem.xs.size } - funDef.postcondition = Some((problem.xs.head, getNewPostcondition(newFunDef.args.map(_.id)))) - funDef.precondition = Some(getNewPrecondition(newFunDef.args.map(_.id))) - - val verifier = new RelaxedVerifier(timeoutSolver, problem) - - assert( verifier.analyzeFunction(funDef)._1 ) - } - - it("should not verify an incorrect member body") { - val newFunDef = getFunDefByName(sctx.program, "badMember") - funDef.body = newFunDef.body + expectResult(1) { problem.xs.size } + funDef.postcondition = Some((problem.xs.head, getNewPostcondition(newFunDef.args.map(_.id)))) + funDef.precondition = Some(getNewPrecondition(newFunDef.args.map(_.id))) + + val verifier = new RelaxedVerifier(timeoutSolver, problem) + + assert( verifier.analyzeFunction(funDef)._1 ) + verifier.solver.free() + } + + it("should not verify an incorrect member body") { + val newFunDef = getFunDefByName(sctx.program, "badMember") + funDef.body = newFunDef.body - expectResult(1) { problem.xs.size } - funDef.postcondition = Some((problem.xs.head, getNewPostcondition(newFunDef.args.map(_.id)))) - funDef.precondition = Some(getNewPrecondition(newFunDef.args.map(_.id))) - - val verifier = new Verifier(timeoutSolver, problem) - - assert( verifier.analyzeFunction(funDef)._1 ) - } - } - - timeoutSolver.free - } + expectResult(1) { problem.xs.size } + funDef.postcondition = Some((problem.xs.head, getNewPostcondition(newFunDef.args.map(_.id)))) + funDef.precondition = Some(getNewPrecondition(newFunDef.args.map(_.id))) + + val verifier = new Verifier(timeoutSolver, problem) + + assert( verifier.analyzeFunction(funDef)._1 ) + verifier.solver.free() + } + } + } - } + } } diff --git a/src/test/scala/leon/test/purescala/TreeOpsTests.scala b/src/test/scala/leon/test/purescala/TreeOpsTests.scala index a1d804171..f261d246b 100644 --- a/src/test/scala/leon/test/purescala/TreeOpsTests.scala +++ b/src/test/scala/leon/test/purescala/TreeOpsTests.scala @@ -16,13 +16,9 @@ import leon.solvers.z3._ class TreeOpsTests extends LeonTestSuite { test("Path-aware simplifications") { - val solver = new UninterpretedZ3SolverFactory(testContext, Program.empty) - // TODO actually testing something here would be better, sorry // PS - solver.free() - assert(true) } diff --git a/src/test/scala/leon/test/solvers/TimeoutSolverTests.scala b/src/test/scala/leon/test/solvers/TimeoutSolverTests.scala index 9745d3a12..edd96fdc1 100644 --- a/src/test/scala/leon/test/solvers/TimeoutSolverTests.scala +++ b/src/test/scala/leon/test/solvers/TimeoutSolverTests.scala @@ -12,38 +12,39 @@ import leon.purescala.Trees._ import leon.purescala.TypeTrees._ class TimeoutSolverTests extends LeonTestSuite { - private class IdioticSolver(val context : LeonContext, val program: Program) extends SolverFactory[Solver] { - enclosing => - + private class IdioticSolver(val context : LeonContext, val program: Program) extends Solver with TimeoutSolver { val name = "Idiotic" val description = "Loops" - def getNewSolver = new Solver { - def check = { - while(!interrupted) { - Thread.sleep(100) - } - None - } + var interrupted = false - def assertCnstr(e: Expr) {} + def innerCheck = { + while(!interrupted) { + Thread.sleep(100) + } + None + } - def checkAssumptions(assump: Set[Expr]) = ??? - def getModel = ??? - def getUnsatCore = ??? - def push() = ??? - def pop(lvl: Int) = ??? + def recoverInterrupt() { + interrupted = false + } - def interrupt() = enclosing.interrupt() - def recoverInterrupt() = enclosing.recoverInterrupt() + def interrupt() { + interrupted = true } + + def assertCnstr(e: Expr) = {} + + def free() {} + + def getModel = ??? } - private def getTOSolver : TimeoutSolverFactory[Solver] = { - new IdioticSolver(testContext, Program.empty).withTimeout(1000L) + private def getTOSolver : SolverFactory[Solver] = { + SolverFactory(() => new IdioticSolver(testContext, Program.empty).setTimeout(1000L)) } - private def check(sf: TimeoutSolverFactory[Solver], e: Expr): Option[Boolean] = { + private def check(sf: SolverFactory[Solver], e: Expr): Option[Boolean] = { val s = sf.getNewSolver s.assertCnstr(e) s.check diff --git a/src/test/scala/leon/test/solvers/z3/FairZ3SolverTests.scala b/src/test/scala/leon/test/solvers/z3/FairZ3SolverTests.scala index cdc28175e..affcad2d2 100644 --- a/src/test/scala/leon/test/solvers/z3/FairZ3SolverTests.scala +++ b/src/test/scala/leon/test/solvers/z3/FairZ3SolverTests.scala @@ -53,7 +53,7 @@ class FairZ3SolverTests extends LeonTestSuite { private val y : Expr = Variable(FreshIdentifier("y").setType(Int32Type)) private def f(e : Expr) : Expr = FunctionInvocation(fDef, e :: Nil) - private val solver = SimpleSolverAPI(new FairZ3SolverFactory(testContext, minimalProgram)) + private val solver = SimpleSolverAPI(SolverFactory(() => new FairZ3Solver(testContext, minimalProgram))) private val tautology1 : Expr = BooleanLiteral(true) assertValid(solver, tautology1) @@ -102,6 +102,4 @@ class FairZ3SolverTests extends LeonTestSuite { assert(core === Set(b2)) } } - - solver.free() } diff --git a/src/test/scala/leon/test/solvers/z3/FairZ3SolverTestsNewAPI.scala b/src/test/scala/leon/test/solvers/z3/FairZ3SolverTestsNewAPI.scala index e44e3697f..4d9f46185 100644 --- a/src/test/scala/leon/test/solvers/z3/FairZ3SolverTestsNewAPI.scala +++ b/src/test/scala/leon/test/solvers/z3/FairZ3SolverTestsNewAPI.scala @@ -20,9 +20,13 @@ class FairZ3SolverTestsNewAPI extends LeonTestSuite { test("Solver test #" + testCounter) { val sub = solver.getNewSolver - sub.assertCnstr(Not(expr)) + try { + sub.assertCnstr(Not(expr)) - assert(sub.check === expected.map(!_), msg) + assert(sub.check === expected.map(!_), msg) + } finally { + sub.free() + } } } @@ -57,7 +61,7 @@ class FairZ3SolverTestsNewAPI extends LeonTestSuite { private val y : Expr = Variable(FreshIdentifier("y").setType(Int32Type)) private def f(e : Expr) : Expr = FunctionInvocation(fDef, e :: Nil) - private val solver = new FairZ3SolverFactory(testContext, minimalProgram) + private val solver = SolverFactory(() => new FairZ3Solver(testContext, minimalProgram)) private val tautology1 : Expr = BooleanLiteral(true) assertValid(solver, tautology1) @@ -91,28 +95,38 @@ class FairZ3SolverTestsNewAPI extends LeonTestSuite { locally { val sub = solver.getNewSolver - sub.assertCnstr(f) - assert(sub.check === Some(true)) + try { + sub.assertCnstr(f) + assert(sub.check === Some(true)) + } finally { + sub.free() + } } locally { val sub = solver.getNewSolver - sub.assertCnstr(f) - val result = sub.checkAssumptions(Set(b1)) - - assert(result === Some(true)) - assert(sub.getUnsatCore.isEmpty) + try { + sub.assertCnstr(f) + val result = sub.checkAssumptions(Set(b1)) + + assert(result === Some(true)) + assert(sub.getUnsatCore.isEmpty) + } finally { + sub.free() + } } locally { val sub = solver.getNewSolver - sub.assertCnstr(f) - - val result = sub.checkAssumptions(Set(b1, b2)) - assert(result === Some(false)) - assert(sub.getUnsatCore === Set(b2)) + try { + sub.assertCnstr(f) + + val result = sub.checkAssumptions(Set(b1, b2)) + assert(result === Some(false)) + assert(sub.getUnsatCore === Set(b2)) + } finally { + sub.free() + } } } - - solver.free() } diff --git a/src/test/scala/leon/test/solvers/z3/UninterpretedZ3SolverTests.scala b/src/test/scala/leon/test/solvers/z3/UninterpretedZ3SolverTests.scala index 050c5a8ef..84c7589e8 100644 --- a/src/test/scala/leon/test/solvers/z3/UninterpretedZ3SolverTests.scala +++ b/src/test/scala/leon/test/solvers/z3/UninterpretedZ3SolverTests.scala @@ -58,7 +58,7 @@ class UninterpretedZ3SolverTests extends LeonTestSuite { private def f(e : Expr) : Expr = FunctionInvocation(fDef, e :: Nil) private def g(e : Expr) : Expr = FunctionInvocation(gDef, e :: Nil) - private val solver = SimpleSolverAPI(new UninterpretedZ3SolverFactory(testContext, minimalProgram)) + private val solver = SimpleSolverAPI(SolverFactory(() => new UninterpretedZ3Solver(testContext, minimalProgram))) private val tautology1 : Expr = BooleanLiteral(true) assertValid(solver, tautology1) @@ -89,6 +89,4 @@ class UninterpretedZ3SolverTests extends LeonTestSuite { solver.solveVALID(Equals(g(x), g(x))) } } - - solver.free() } -- GitLab