From f374e736205cd5206399d6effcb8229d0cf8b042 Mon Sep 17 00:00:00 2001 From: Etienne Kneuss <colder@php.net> Date: Thu, 5 Sep 2013 16:34:46 +0200 Subject: [PATCH] Refactor Solvers - We now explicitly create them from SolverFactories - SolveSAT/solve/solveWithModel/etc.. is not only available through the SimpleSolverAPI() wrapper.o - Remove mostly unused/useless solvers --- .../scala/leon/codegen/CompilationUnit.scala | 2 +- src/main/scala/leon/purescala/TreeOps.scala | 30 ++-- .../leon/solvers/IncrementalSolver.scala | 82 ---------- .../leon/solvers/InterruptibleSolver.scala | 9 -- .../scala/leon/solvers/RandomSolver.scala | 147 ------------------ .../scala/leon/solvers/SimpleSolverAPI.scala | 40 +++++ src/main/scala/leon/solvers/Solver.scala | 56 ++----- .../scala/leon/solvers/SolverFactory.scala | 52 +++++++ .../scala/leon/solvers/TimeoutSolver.scala | 69 +++----- .../scala/leon/solvers/TrivialSolver.scala | 23 --- .../leon/solvers/z3/AbstractZ3Solver.scala | 26 ++-- ...Solver.scala => FairZ3SolverFactory.scala} | 77 +++------ .../leon/solvers/z3/FunctionTemplate.scala | 4 +- ...ala => UninterpretedZ3SolverFactory.scala} | 62 ++++---- .../scala/leon/synthesis/ParallelSearch.scala | 15 +- .../scala/leon/synthesis/SimpleSearch.scala | 6 +- src/main/scala/leon/synthesis/Solution.scala | 5 +- .../leon/synthesis/SynthesisContext.scala | 10 +- .../scala/leon/synthesis/SynthesisPhase.scala | 3 +- .../scala/leon/synthesis/Synthesizer.scala | 16 +- .../synthesis/heuristics/ADTInduction.scala | 4 +- .../heuristics/ADTLongInduction.scala | 4 +- .../scala/leon/synthesis/rules/ADTSplit.scala | 4 +- .../scala/leon/synthesis/rules/Cegis.scala | 9 +- .../leon/synthesis/rules/EqualitySplit.scala | 4 +- .../scala/leon/synthesis/rules/Ground.scala | 4 +- .../synthesis/rules/InequalitySplit.scala | 4 +- .../synthesis/rules/OptimisticGround.scala | 4 +- .../leon/synthesis/utils/Benchmarks.scala | 29 ++-- src/main/scala/leon/testgen/CallGraph.scala | 11 +- .../scala/leon/testgen/TestGeneration.scala | 64 +------- .../scala/leon/utils/InterruptManager.scala | 10 ++ src/main/scala/leon/utils/Interruptible.scala | 1 + .../leon/verification/AnalysisPhase.scala | 105 ++++++------- .../verification/VerificationCondition.scala | 4 +- .../verification/VerificationContext.scala | 4 +- .../leon/test/purescala/TreeOpsTests.scala | 9 +- .../test/solvers/TimeoutSolverTests.scala | 57 ++++--- .../test/solvers/z3/FairZ3SolverTests.scala | 17 +- .../solvers/z3/FairZ3SolverTestsNewAPI.scala | 16 +- .../z3/UninterpretedZ3SolverTests.scala | 19 ++- .../leon/test/synthesis/SynthesisSuite.scala | 6 +- 42 files changed, 413 insertions(+), 710 deletions(-) delete mode 100644 src/main/scala/leon/solvers/IncrementalSolver.scala delete mode 100644 src/main/scala/leon/solvers/InterruptibleSolver.scala delete mode 100644 src/main/scala/leon/solvers/RandomSolver.scala create mode 100644 src/main/scala/leon/solvers/SimpleSolverAPI.scala create mode 100644 src/main/scala/leon/solvers/SolverFactory.scala delete mode 100644 src/main/scala/leon/solvers/TrivialSolver.scala rename src/main/scala/leon/solvers/z3/{FairZ3Solver.scala => FairZ3SolverFactory.scala} (92%) rename src/main/scala/leon/solvers/z3/{UninterpretedZ3Solver.scala => UninterpretedZ3SolverFactory.scala} (74%) diff --git a/src/main/scala/leon/codegen/CompilationUnit.scala b/src/main/scala/leon/codegen/CompilationUnit.scala index bbb79d1b3..99976d4c7 100644 --- a/src/main/scala/leon/codegen/CompilationUnit.scala +++ b/src/main/scala/leon/codegen/CompilationUnit.scala @@ -169,7 +169,7 @@ class CompilationUnit(val program: Program, val classes: Map[Definition, ClassFi case Int32Type | BooleanType => ch << IRETURN - case UnitType | TupleType(_) | SetType(_) | MapType(_, _) | AbstractClassType(_) | CaseClassType(_) | ArrayType(_) => + case UnitType | _: TupleType | _: SetType | _: MapType | _: AbstractClassType | _: CaseClassType | _: ArrayType => ch << ARETURN case other => diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala index 32e377b4a..3479c8192 100644 --- a/src/main/scala/leon/purescala/TreeOps.scala +++ b/src/main/scala/leon/purescala/TreeOps.scala @@ -3,7 +3,7 @@ package leon package purescala -import leon.solvers.Solver +import leon.solvers._ import scala.collection.concurrent.TrieMap @@ -1133,17 +1133,19 @@ object TreeOps { simplePreTransform(pre)(e) } - def simplifyTautologies(solver : Solver)(expr : Expr) : Expr = { + def simplifyTautologies(sf: SolverFactory[Solver])(expr : Expr) : Expr = { + val solver = SimpleSolverAPI(sf) + def pre(e : Expr) = e match { case LetDef(fd, expr) if fd.hasPrecondition => val pre = fd.precondition.get - solver.solve(pre) match { + solver.solveVALID(pre) match { case Some(true) => fd.precondition = None - case Some(false) => solver.solve(Not(pre)) match { + case Some(false) => solver.solveVALID(Not(pre)) match { case Some(true) => fd.precondition = Some(BooleanLiteral(false)) case _ => @@ -1155,9 +1157,9 @@ object TreeOps { case IfExpr(cond, thenn, elze) => try { - solver.solve(cond) match { + solver.solveVALID(cond) match { case Some(true) => thenn - case Some(false) => solver.solve(Not(cond)) match { + case Some(false) => solver.solveVALID(Not(cond)) match { case Some(true) => elze case _ => e } @@ -1174,8 +1176,8 @@ object TreeOps { simplePreTransform(pre)(expr) } - def simplifyPaths(solver : Solver): Expr => Expr = { - new SimplifierWithPaths(solver).transform _ + def simplifyPaths(sf: SolverFactory[Solver]): Expr => Expr = { + new SimplifierWithPaths(sf).transform _ } trait Transformer { @@ -1267,15 +1269,17 @@ object TreeOps { } } - class SimplifierWithPaths(solver: Solver) extends TransformerWithPC { + class SimplifierWithPaths(sf: SolverFactory[Solver]) extends TransformerWithPC { type C = List[Expr] val initC = Nil + val solver = SimpleSolverAPI(sf) + protected def register(e: Expr, c: C) = e :: c def impliedBy(e : Expr, path : Seq[Expr]) : Boolean = try { - solver.solve(Implies(And(path), e)) match { + solver.solveVALID(Implies(And(path), e)) match { case Some(true) => true case _ => false } @@ -1284,7 +1288,7 @@ object TreeOps { } def contradictedBy(e : Expr, path : Seq[Expr]) : Boolean = try { - solver.solve(Implies(And(path), Not(e))) match { + solver.solveVALID(Implies(And(path), Not(e))) match { case Some(true) => true case _ => false } @@ -1766,7 +1770,7 @@ object TreeOps { case e => (None, e) } - def isInductiveOn(solver: Solver)(expr: Expr, on: Identifier): Boolean = on match { + def isInductiveOn(sf: SolverFactory[Solver])(expr: Expr, on: Identifier): Boolean = on match { case IsTyped(origId, AbstractClassType(cd)) => def isAlternativeRecursive(cd: CaseClassDef): Boolean = { cd.fieldsIds.exists(_.getType == origId.getType) @@ -1789,6 +1793,8 @@ object TreeOps { } }.flatten + val solver = SimpleSolverAPI(sf) + toCheck.forall { cond => solver.solveSAT(cond) match { case (Some(false), _) => diff --git a/src/main/scala/leon/solvers/IncrementalSolver.scala b/src/main/scala/leon/solvers/IncrementalSolver.scala deleted file mode 100644 index 700699e53..000000000 --- a/src/main/scala/leon/solvers/IncrementalSolver.scala +++ /dev/null @@ -1,82 +0,0 @@ -/* Copyright 2009-2013 EPFL, Lausanne */ - -package leon -package solvers - -import purescala.Common._ -import purescala.Definitions._ -import purescala.TreeOps._ -import purescala.Trees._ - -trait IncrementalSolverBuilder { - def getNewSolver: IncrementalSolver -} - -trait IncrementalSolver extends InterruptibleSolver { - // New Solver API - // Moslty for z3 solvers since z3 4.3 - - def push(): Unit - def pop(lvl: Int = 1): Unit - def assertCnstr(expression: Expr): Unit - - def halt(): Unit - def init(): Unit = {} - - def check: Option[Boolean] - def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] - def getModel: Map[Identifier, Expr] - def getUnsatCore: Set[Expr] -} - -trait NaiveIncrementalSolver extends IncrementalSolverBuilder { - def halt(): Unit - def solveSAT(e: Expr): (Option[Boolean], Map[Identifier, Expr]) - - def getNewSolver = new IncrementalSolver { - private var stack = List[List[Expr]]() - - def push() { - stack = Nil :: stack - } - - def pop(lvl: Int = 1) { - stack = stack.drop(lvl) - } - - def halt() { - NaiveIncrementalSolver.this.halt() - } - - def assertCnstr(expression: Expr) { - stack = (expression :: stack.head) :: stack.tail - } - - private def allConstraints() = stack.flatten - - private var unsatCore = Set[Expr]() - - def check: Option[Boolean] = { - solveSAT(And(allConstraints()))._1 - } - - def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = { - solveSAT(And((assumptions ++ allConstraints()).toSeq))._1 match { - case Some(true) => - unsatCore = Set[Expr]() - Some(true) - case r => - unsatCore = assumptions.toSet - r - } - } - - def getModel: Map[Identifier, Expr] = { - Map[Identifier, Expr]() - } - - def getUnsatCore: Set[Expr] = { - unsatCore - } - } -} diff --git a/src/main/scala/leon/solvers/InterruptibleSolver.scala b/src/main/scala/leon/solvers/InterruptibleSolver.scala deleted file mode 100644 index e10e8ded9..000000000 --- a/src/main/scala/leon/solvers/InterruptibleSolver.scala +++ /dev/null @@ -1,9 +0,0 @@ -/* Copyright 2009-2013 EPFL, Lausanne */ - -package leon -package solvers - -trait InterruptibleSolver { - def halt(): Unit - def init(): Unit -} diff --git a/src/main/scala/leon/solvers/RandomSolver.scala b/src/main/scala/leon/solvers/RandomSolver.scala deleted file mode 100644 index 750270625..000000000 --- a/src/main/scala/leon/solvers/RandomSolver.scala +++ /dev/null @@ -1,147 +0,0 @@ -/* Copyright 2009-2013 EPFL, Lausanne */ - -package leon -package solvers - -import purescala.Common._ -import purescala.Definitions._ -import purescala.Trees._ -import purescala.TreeOps._ -import purescala.TypeTrees._ - -import evaluators._ - -import scala.util.Random - -@deprecated("Unused, Untested, Unmaintained", "") -class RandomSolver(context: LeonContext, val nbTrial: Option[Int] = None) extends Solver(context) with NaiveIncrementalSolver { - require(nbTrial.forall(i => i >= 0)) - - private val reporter = context.reporter - private var evaluator : Evaluator = null - - override def setProgram(program : Program) : Unit = { - evaluator = new DefaultEvaluator(context, program) - } - - val name = "QC" - val description = "Solver applying random testing (QuickCheck-like)" - - private val random = new Random() - - private def randomType(): TypeTree = { - random.nextInt(2) match { - case 0 => Int32Type - case 1 => BooleanType - } - } - - private def randomValue(t: TypeTree, size: Int): Expr = t match { - case Int32Type => { - val s = if(size < Int.MaxValue) size + 1 else size - IntLiteral(random.nextInt(s)) - } - case BooleanType => BooleanLiteral(random.nextBoolean()) - case AbstractClassType(acd) => { - val children = acd.knownChildren - if(size <= 0 || random.nextInt(size) == 0) { - val terminalChildren = children.filter{ - case CaseClassDef(_, _, fields) => fields.isEmpty - case _ => false - } - if(terminalChildren.isEmpty) { //Then we need to filter children with no adt as fields - val terminalChildren2 = children.filter{ - case CaseClassDef(_, _, fields) => fields.forall(f => !f.getType.isInstanceOf[AbstractClassType]) - case _ => false - } - CaseClass(terminalChildren2(random.nextInt(terminalChildren2.size)).asInstanceOf[CaseClassDef], Seq()) - } else - CaseClass(terminalChildren(random.nextInt(terminalChildren.size)).asInstanceOf[CaseClassDef], Seq()) - } else { - val nonTerminalChildren = children.filter{ - case CaseClassDef(_, _, fields) => !fields.isEmpty - case _ => false - } - if(nonTerminalChildren.isEmpty) { - randomValue(classDefToClassType(children(random.nextInt(children.size))), size) - } else - randomValue(classDefToClassType( - nonTerminalChildren( - random.nextInt(nonTerminalChildren.size))), size) - } - } - case CaseClassType(cd) => { - val nbFields = cd.fields.size - CaseClass(cd, cd.fields.map(f => randomValue(f.getType, size / nbFields))) - } - case AnyType => randomValue(randomType(), size) - case SetType(base) => FiniteSet(Seq()) - case MultisetType(base) => EmptyMultiset(base) - case Untyped => sys.error("I don't know what to do") - case BottomType => sys.error("I don't know what to do") - case ListType(base) => sys.error("I don't know what to do") - case TupleType(bases) => sys.error("I don't know what to do") - case MapType(from, to) => sys.error("I don't know what to do") - case _ => sys.error("Unexpected type: " + t) - } - - def solve(expression: Expr) : Option[Boolean] = { - val vars = variablesOf(expression) - val nbVars = vars.size - - var stop = false - //bound starts at 1 since it allows to test value like 0, 1, and Leaf of class hierarchy - var bound = 1 - val maxBound = Int.MaxValue - //the threashold depends on the number of variable and the actual range given by the bound - val thresholdStep = nbVars * 4 - var threshold = thresholdStep - - var result: Option[Boolean] = None - var iteration = 0 - while(!forceStop && !stop) { - - nbTrial match { - case Some(n) => stop &&= (iteration < n) - case None => () - } - - if(iteration > threshold && bound != maxBound) { - if(bound * 4 < bound) //this is an overflow - bound = maxBound - else - bound *= 2 //exponential growth - threshold += thresholdStep - } - - val var2val: Map[Identifier, Expr] = Map(vars.map(v => (v, randomValue(v.getType, bound))).toList: _*) - //reporter.info("Trying with: " + var2val) - - val evalResult = evaluator.eval(expression, var2val) - evalResult match { - case EvaluationResults.Successful(BooleanLiteral(true)) => { - //continue trying - } - - case EvaluationResults.Successful(BooleanLiteral(false)) => { - reporter.info("Found counter example to formula: " + var2val) - result = Some(false) - stop = true - } - - case EvaluationResults.RuntimeError(_) => { - reporter.info("Input leads to runtime error: " + var2val) - result = Some(false) - stop = true - } - - // otherwise, simply continue with another assignement - case EvaluationResults.EvaluatorError(_) => ; - } - - iteration += 1 - } - result - } - -} diff --git a/src/main/scala/leon/solvers/SimpleSolverAPI.scala b/src/main/scala/leon/solvers/SimpleSolverAPI.scala new file mode 100644 index 000000000..40c1ac5a6 --- /dev/null +++ b/src/main/scala/leon/solvers/SimpleSolverAPI.scala @@ -0,0 +1,40 @@ +package leon +package solvers + +import purescala.Common._ +import purescala.Trees._ + +case class SimpleSolverAPI(sf: SolverFactory[Solver]) { + def solveVALID(expression: Expr): Option[Boolean] = { + val s = sf.getNewSolver() + s.assertCnstr(Not(expression)) + s.check.map(r => !r) + } + + def solveSAT(expression: Expr): (Option[Boolean], Map[Identifier, Expr]) = { + val s = sf.getNewSolver() + s.assertCnstr(expression) + s.check match { + case Some(true) => + (Some(true), s.getModel) + case Some(false) => + (Some(false), Map()) + case None => + (None, Map()) + } + } + + def solveSATWithCores(expression: Expr, assumptions: Set[Expr]): (Option[Boolean], Map[Identifier, Expr], Set[Expr]) = { + val s = sf.getNewSolver() + s.assertCnstr(expression) + s.checkAssumptions(assumptions) match { + case Some(true) => + (Some(true), s.getModel, Set()) + case Some(false) => + (Some(false), Map(), s.getUnsatCore) + case None => + (None, Map(), Set()) + } + } +} + diff --git a/src/main/scala/leon/solvers/Solver.scala b/src/main/scala/leon/solvers/Solver.scala index 5cc012e74..3fc361a9c 100644 --- a/src/main/scala/leon/solvers/Solver.scala +++ b/src/main/scala/leon/solvers/Solver.scala @@ -3,55 +3,17 @@ package leon package solvers +import utils._ import purescala.Common._ -import purescala.Definitions._ -import purescala.TreeOps._ import purescala.Trees._ -abstract class Solver(val context : LeonContext) extends IncrementalSolverBuilder with InterruptibleSolver with LeonComponent { - // This can be used by solvers to "see" the programs from which the - // formulas come. (e.g. to set up some datastructures for the defined - // ADTs, etc.) - // Ideally, we would pass it at construction time and not change it later. - def setProgram(program: Program) : Unit = {} +trait Solver extends Interruptible { + def push(): Unit + def pop(lvl: Int = 1): Unit + def assertCnstr(expression: Expr): Unit - // Returns Some(true) if valid, Some(false) if invalid, - // None if unknown. - // should halt as soon as possible with any result (Unknown is ok) as soon as forceStop is true - def solve(expression: Expr) : Option[Boolean] - - def solveSAT(expression: Expr): (Option[Boolean], Map[Identifier, Expr]) = { - solve(Not(expression)) match { - case Some(true) => - (Some(false), Map()) - case Some(false) => - (Some(true), Map()) - case None => - (None, Map()) - } - } - - def solveSATWithCores(expression: Expr, assumptions: Set[Expr]): (Option[Boolean], Map[Identifier, Expr], Set[Expr]) = { - solveSAT(And(expression +: assumptions.toSeq)) match { - case (Some(false), _) => - (Some(false), Map(), assumptions) - case (r, m) => - (r, m, Set()) - } - } - - def superseeds : Seq[String] = Nil - - private var _forceStop = false - - def halt() : Unit = { - _forceStop = true - } - - def init() : Unit = { - _forceStop = false - } - - protected def forceStop = _forceStop + def check: Option[Boolean] + def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] + def getModel: Map[Identifier, Expr] + def getUnsatCore: Set[Expr] } - diff --git a/src/main/scala/leon/solvers/SolverFactory.scala b/src/main/scala/leon/solvers/SolverFactory.scala new file mode 100644 index 000000000..6e848e99d --- /dev/null +++ b/src/main/scala/leon/solvers/SolverFactory.scala @@ -0,0 +1,52 @@ +/* Copyright 2009-2013 EPFL, Lausanne */ + +package leon +package solvers + +import utils._ +import purescala.Common._ +import purescala.Definitions._ +import purescala.TreeOps._ +import purescala.Trees._ + +trait SolverFactory[S <: Solver] extends Interruptible with LeonComponent { + val context: LeonContext + val program: Program + + var freed = false + val traceE = new Exception() + + override def finalize() { + if (!freed) { + //println("!! Solver not freed properly prior to GC:") + //traceE.printStackTrace() + free() + } + } + + def free() { + freed = true + } + + var interrupted = false + + override def interrupt() { + interrupted = true + } + + override def recoverInterrupt() { + interrupted = false + } + + def getNewSolver(): S + + def withTimeout(ms: Long): TimeoutSolverFactory[S] = { + this match { + case tsf: TimeoutSolverFactory[S] => + // Unwrap/Rewrap to take only new timeout into account + new TimeoutSolverFactory[S](tsf.sf, ms) + case _ => + new TimeoutSolverFactory[S](this, ms) + } + } +} diff --git a/src/main/scala/leon/solvers/TimeoutSolver.scala b/src/main/scala/leon/solvers/TimeoutSolver.scala index 207b3118d..ba406948d 100644 --- a/src/main/scala/leon/solvers/TimeoutSolver.scala +++ b/src/main/scala/leon/solvers/TimeoutSolver.scala @@ -10,9 +10,13 @@ import purescala.TypeTrees._ import scala.sys.error -class TimeoutSolver(solver : Solver with IncrementalSolverBuilder, timeoutMs : Long) extends Solver(solver.context) with IncrementalSolverBuilder { - // I'm making this an inner class to fight the temptation of using it for anything meaningful. - // We have Akka, these days, which whould be better in any respect for non-trivial things. +class TimeoutSolverFactory[S <: Solver](val sf: SolverFactory[S], val timeoutMs: Long) extends SolverFactory[Solver] { + val description = sf.description + ", with "+timeoutMs+"ms timeout" + val name = sf.name + "+to" + + val context = sf.context + val program = sf.program + private class Timer(onTimeout: => Unit) extends Thread { private var keepRunning = true private val asMillis : Long = timeoutMs @@ -32,60 +36,35 @@ class TimeoutSolver(solver : Solver with IncrementalSolverBuilder, timeoutMs : } } - def halt : Unit = { + def finishedRunning : Unit = { keepRunning = false } } - def withTimeout[T](solver: InterruptibleSolver)(body: => T): T = { + def withTimeout[T](solver: S)(body: => T): T = { val timer = new Timer(timeout(solver)) timer.start val res = body - timer.halt + timer.finishedRunning recoverFromTimeout(solver) res } var reachedTimeout = false - def timeout(solver: InterruptibleSolver) { - solver.halt + def timeout(solver: S) { + solver.interrupt() reachedTimeout = true } - def recoverFromTimeout(solver: InterruptibleSolver) { + def recoverFromTimeout(solver: S) { if (reachedTimeout) { - solver.init + solver.recoverInterrupt() reachedTimeout = false } } - val description = solver.description + ", with timeout" - val name = solver.name + "+to" - - override def setProgram(prog: Program): Unit = { - solver.setProgram(prog) - } - - def solve(expression: Expr) : Option[Boolean] = { - withTimeout(solver) { - solver.solve(expression) - } - } - - override def solveSAT(expression: Expr): (Option[Boolean], Map[Identifier, Expr]) = { - withTimeout(solver) { - solver.solveSAT(expression) - } - } - - override def solveSATWithCores(expression: Expr, assumptions: Set[Expr]): (Option[Boolean], Map[Identifier, Expr], Set[Expr]) = { - withTimeout(solver) { - solver.solveSATWithCores(expression, assumptions) - } - } - - def getNewSolver = new IncrementalSolver { - val solver = TimeoutSolver.this.solver.getNewSolver + def getNewSolver = new Solver { + val solver = sf.getNewSolver def push(): Unit = { solver.push() @@ -99,8 +78,12 @@ class TimeoutSolver(solver : Solver with IncrementalSolverBuilder, timeoutMs : solver.assertCnstr(expression) } - def halt(): Unit = { - solver.halt() + def interrupt() { + solver.interrupt() + } + + def recoverInterrupt() { + solver.recoverInterrupt() } def check: Option[Boolean] = { @@ -123,12 +106,4 @@ class TimeoutSolver(solver : Solver with IncrementalSolverBuilder, timeoutMs : solver.getUnsatCore } } - - override def init() { - solver.init - } - - override def halt() { - solver.halt - } } diff --git a/src/main/scala/leon/solvers/TrivialSolver.scala b/src/main/scala/leon/solvers/TrivialSolver.scala deleted file mode 100644 index c1ae27d26..000000000 --- a/src/main/scala/leon/solvers/TrivialSolver.scala +++ /dev/null @@ -1,23 +0,0 @@ -/* Copyright 2009-2013 EPFL, Lausanne */ - -package leon -package solvers - -import purescala.Common._ -import purescala.Definitions._ -import purescala.Trees._ -import purescala.TypeTrees._ - -class TrivialSolver(context: LeonContext) extends Solver(context) with NaiveIncrementalSolver { - val name = "trivial" - val description = "Solver for syntactically trivial formulas" - - def solve(expression: Expr) : Option[Boolean] = expression match { - case BooleanLiteral(v) => Some(v) - case Not(BooleanLiteral(v)) => Some(!v) - case Or(exs) if exs.contains(BooleanLiteral(true)) => Some(true) - case And(exs) if exs.contains(BooleanLiteral(false)) => Some(false) - case Equals(l,r) if l == r => Some(true) - case _ => None - } -} diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala index 0003f17cf..2f83d9272 100644 --- a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala @@ -6,6 +6,7 @@ package solvers.z3 import leon.utils._ import z3.scala._ +import solvers._ import purescala.Common._ import purescala.Definitions._ import purescala.Trees._ @@ -18,25 +19,32 @@ import scala.collection.mutable.{Set => MutableSet} // This is just to factor out the things that are common in "classes that deal // with a Z3 instance" -trait AbstractZ3Solver extends solvers.IncrementalSolverBuilder with Interruptible { - self: leon.solvers.Solver => - +trait AbstractZ3Solver extends SolverFactory[Solver] { val context : LeonContext + val program : Program + protected[z3] val reporter : Reporter = context.reporter context.interruptManager.registerForInterrupts(this) - def interrupt() { - halt() - } class CantTranslateException(t: Z3AST) extends Exception("Can't translate from Z3 tree: " + t) protected[leon] val z3cfg : Z3Config protected[leon] var z3 : Z3Context = null - protected[leon] var program : Program = null - override def setProgram(prog: Program): Unit = { - program = prog + override def free() { + super.free() + if (z3 ne null) { + z3.delete() + z3 = null; + } + } + + override def interrupt() { + super.interrupt() + if(z3 ne null) { + z3.interrupt + } } protected[leon] def prepareFunctions : Unit diff --git a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala b/src/main/scala/leon/solvers/z3/FairZ3SolverFactory.scala similarity index 92% rename from src/main/scala/leon/solvers/z3/FairZ3Solver.scala rename to src/main/scala/leon/solvers/z3/FairZ3SolverFactory.scala index 2930067d0..ca49ca533 100644 --- a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/FairZ3SolverFactory.scala @@ -23,9 +23,8 @@ import termination._ import scala.collection.mutable.{Map => MutableMap} import scala.collection.mutable.{Set => MutableSet} -class FairZ3Solver(context : LeonContext) - extends Solver(context) - with AbstractZ3Solver +class FairZ3SolverFactory(val context : LeonContext, val program: Program) + extends AbstractZ3Solver with Z3ModelReconstruction with FairZ3Component { @@ -53,25 +52,19 @@ class FairZ3Solver(context : LeonContext) (lucky, check, codegen, evalground, unrollUnsatCores) } - private var evaluator : Evaluator = null - protected[z3] def getEvaluator : Evaluator = evaluator - - private var terminator : TerminationChecker = null - protected[z3] def getTerminator : TerminationChecker = terminator - - override def setProgram(prog : Program) { - super.setProgram(prog) - - evaluator = if(useCodeGen) { + private val evaluator : Evaluator = if(useCodeGen) { // TODO If somehow we could not recompile each time we create a solver, // that would be good? - new CodeGenEvaluator(context, prog) + new CodeGenEvaluator(context, program) } else { - new DefaultEvaluator(context, prog) + new DefaultEvaluator(context, program) } - terminator = new SimpleTerminationChecker(context, prog) - } + protected[z3] def getEvaluator : Evaluator = evaluator + + private val terminator : TerminationChecker = new SimpleTerminationChecker(context, program) + + protected[z3] def getTerminator : TerminationChecker = terminator // This is fixed. protected[leon] val z3cfg = new Z3Config( @@ -109,33 +102,8 @@ class FairZ3Solver(context : LeonContext) } } - override def solve(vc: Expr) = { - val solver = getNewSolver - solver.assertCnstr(Not(vc)) - solver.check.map(!_) - } - - override def solveSAT(vc : Expr) : (Option[Boolean],Map[Identifier,Expr]) = { - val solver = getNewSolver - solver.assertCnstr(vc) - (solver.check, solver.getModel) - } - - override def halt() { - super.halt - if(z3 ne null) { - z3.interrupt - } - } - - override def solveSATWithCores(expression: Expr, assumptions: Set[Expr]): (Option[Boolean], Map[Identifier, Expr], Set[Expr]) = { - val solver = getNewSolver - solver.assertCnstr(expression) - (solver.checkAssumptions(assumptions), solver.getModel, solver.getUnsatCore) - } - private def validateModel(model: Z3Model, formula: Expr, variables: Set[Identifier], silenceErrors: Boolean) : (Boolean, Map[Identifier,Expr]) = { - if(!forceStop) { + if(!interrupted) { val functionsModel: Map[Z3FuncDecl, (Seq[(Seq[Z3AST], Z3AST)], Z3AST)] = model.getModelFuncInterpretations.map(i => (i._1, (i._2, i._3))).toMap val functionsAsMap: Map[Identifier, Expr] = functionsModel.flatMap(p => { @@ -362,7 +330,7 @@ class FairZ3Solver(context : LeonContext) } } - def getNewSolver = new solvers.IncrementalSolver { + def getNewSolver = new Solver { private val evaluator = enclosing.evaluator private val feelingLucky = enclosing.feelingLucky private val checkModels = enclosing.checkModels @@ -390,15 +358,12 @@ class FairZ3Solver(context : LeonContext) frameExpressions = Nil :: frameExpressions } - override def init() { - FairZ3Solver.super.init + override def recoverInterrupt() { + enclosing.recoverInterrupt() } - def halt() { - FairZ3Solver.super.halt - if(z3 ne null) { - z3.interrupt - } + override def interrupt() { + enclosing.interrupt() } def pop(lvl: Int = 1) { @@ -471,7 +436,7 @@ class FairZ3Solver(context : LeonContext) }).toSet } - while(!foundDefinitiveAnswer && !forceStop) { + while(!foundDefinitiveAnswer && !interrupted) { //val blockingSetAsZ3 : Seq[Z3AST] = blockingSet.toSeq.map(toZ3Formula(_).get) // println("Blocking set : " + blockingSet) @@ -564,7 +529,7 @@ class FairZ3Solver(context : LeonContext) //debug("UNSAT BECAUSE: "+solver.getUnsatCore.mkString("\n AND \n")) //debug("UNSAT BECAUSE: "+core.mkString(" AND ")) - if (!forceStop) { + if (!interrupted) { if (this.feelingLucky) { // we need the model to perform the additional test debug(" - Running search without blocked literals (w/ lucky test)") @@ -584,7 +549,7 @@ class FairZ3Solver(context : LeonContext) foundAnswer(Some(false), core = z3CoreToCore(solver.getUnsatCore)) case Some(true) => //debug("SAT WITHOUT Blockers") - if (this.feelingLucky && !forceStop) { + if (this.feelingLucky && !interrupted) { // we might have been lucky :D luckyTime.start val (wereWeLucky, cleanModel) = validateModel(solver.getModel, entireFormula, varsInVC, silenceErrors = true) @@ -600,7 +565,7 @@ class FairZ3Solver(context : LeonContext) } } - if(forceStop) { + if(interrupted) { foundAnswer(None) } @@ -638,7 +603,7 @@ class FairZ3Solver(context : LeonContext) //debug(" !! DONE !! ") - if(forceStop) { + if(interrupted) { None } else { definitiveAnswer diff --git a/src/main/scala/leon/solvers/z3/FunctionTemplate.scala b/src/main/scala/leon/solvers/z3/FunctionTemplate.scala index 5a3f2e2f4..4e97ddada 100644 --- a/src/main/scala/leon/solvers/z3/FunctionTemplate.scala +++ b/src/main/scala/leon/solvers/z3/FunctionTemplate.scala @@ -19,7 +19,7 @@ import scala.collection.mutable.{Set=>MutableSet,Map=>MutableMap} case class Z3FunctionInvocation(funDef: FunDef, args: Seq[Z3AST]) class FunctionTemplate private( - solver: FairZ3Solver, + solver: FairZ3SolverFactory, val funDef : FunDef, activatingBool : Identifier, condVars : Set[Identifier], @@ -152,7 +152,7 @@ class FunctionTemplate private( object FunctionTemplate { val splitAndOrImplies = false - def mkTemplate(solver: FairZ3Solver, funDef: FunDef, isRealFunDef : Boolean = true) : FunctionTemplate = { + def mkTemplate(solver: FairZ3SolverFactory, funDef: FunDef, isRealFunDef : Boolean = true) : FunctionTemplate = { val condVars : MutableSet[Identifier] = MutableSet.empty val exprVars : MutableSet[Identifier] = MutableSet.empty diff --git a/src/main/scala/leon/solvers/z3/UninterpretedZ3Solver.scala b/src/main/scala/leon/solvers/z3/UninterpretedZ3SolverFactory.scala similarity index 74% rename from src/main/scala/leon/solvers/z3/UninterpretedZ3Solver.scala rename to src/main/scala/leon/solvers/z3/UninterpretedZ3SolverFactory.scala index f8d1bd492..6589e3b50 100644 --- a/src/main/scala/leon/solvers/z3/UninterpretedZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/UninterpretedZ3SolverFactory.scala @@ -5,7 +5,7 @@ package solvers.z3 import z3.scala._ -import leon.solvers.Solver +import leon.solvers._ import purescala.Common._ import purescala.Definitions._ @@ -21,7 +21,12 @@ import purescala.TypeTrees._ * - otherwise it returns UNKNOWN * Results should come back very quickly. */ -class UninterpretedZ3Solver(context : LeonContext) extends Solver(context) with AbstractZ3Solver with Z3ModelReconstruction { +class UninterpretedZ3SolverFactory(val context : LeonContext, val program: Program) + extends AbstractZ3Solver + with Z3ModelReconstruction { + + enclosing => + val name = "Z3-u" val description = "Uninterpreted Z3 Solver" @@ -52,34 +57,7 @@ class UninterpretedZ3Solver(context : LeonContext) extends Solver(context) with protected[leon] def functionDeclToDef(decl: Z3FuncDecl) : FunDef = reverseFunctionMap(decl) protected[leon] def isKnownDecl(decl: Z3FuncDecl) : Boolean = reverseFunctionMap.isDefinedAt(decl) - override def solve(expression: Expr) : Option[Boolean] = solveSAT(Not(expression))._1.map(!_) - - // Where the solving occurs - override def solveSAT(expression : Expr) : (Option[Boolean],Map[Identifier,Expr]) = { - val solver = getNewSolver - - val emptyModel = Map.empty[Identifier,Expr] - val unknownResult = (None, emptyModel) - val unsatResult = (Some(false), emptyModel) - - solver.assertCnstr(expression) - - val result = solver.check match { - case Some(false) => unsatResult - case Some(true) => { - if(containsFunctionCalls(expression)) { - unknownResult - } else { - (Some(true), solver.getModel) - } - } - case _ => unknownResult - } - - result - } - - def getNewSolver = new solvers.IncrementalSolver { + def getNewSolver = new Solver { initZ3 val solver = z3.mkSolver @@ -88,8 +66,12 @@ class UninterpretedZ3Solver(context : LeonContext) extends Solver(context) with solver.push } - def halt() { - z3.interrupt + def interrupt() { + enclosing.interrupt() + } + + def recoverInterrupt() { + enclosing.recoverInterrupt() } def pop(lvl: Int = 1) { @@ -97,14 +79,26 @@ class UninterpretedZ3Solver(context : LeonContext) extends Solver(context) with } private var variables = Set[Identifier]() + private var containsFunCalls = false def assertCnstr(expression: Expr) { variables ++= variablesOf(expression) + containsFunCalls ||= containsFunctionCalls(expression) solver.assertCnstr(toZ3Formula(expression).get) } def check: Option[Boolean] = { - solver.check + solver.check match { + case Some(true) => + if (containsFunCalls) { + None + } else { + Some(true) + } + + case r => + r + } } def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = { @@ -123,7 +117,5 @@ class UninterpretedZ3Solver(context : LeonContext) extends Solver(context) with case x => scala.sys.error("Impossible element extracted from core: " + ast + " (as Leon tree : " + x + ")") }).toSet } - - } } diff --git a/src/main/scala/leon/synthesis/ParallelSearch.scala b/src/main/scala/leon/synthesis/ParallelSearch.scala index 78e30e87f..e4bf309d4 100644 --- a/src/main/scala/leon/synthesis/ParallelSearch.scala +++ b/src/main/scala/leon/synthesis/ParallelSearch.scala @@ -5,8 +5,7 @@ package synthesis import synthesis.search._ import akka.actor._ -import solvers.z3.{FairZ3Solver,UninterpretedZ3Solver} -import solvers.TrivialSolver +import solvers.z3._ class ParallelSearch(synth: Synthesizer, problem: Problem, @@ -24,15 +23,13 @@ class ParallelSearch(synth: Synthesizer, private[this] var contexts = List[SynthesisContext]() def initWorkerContext(wr: ActorRef) = { - val solver = new FairZ3Solver(synth.context) - solver.setProgram(synth.program) - solver.initZ3 + val solverf = new FairZ3SolverFactory(synth.context, synth.program) + solverf.initZ3 - val simpleSolver = new UninterpretedZ3Solver(synth.context) - simpleSolver.setProgram(synth.program) - simpleSolver.initZ3 + val fastSolverf = new UninterpretedZ3SolverFactory(synth.context, synth.program) + fastSolverf.initZ3 - val ctx = SynthesisContext.fromSynthesizer(synth).copy(solver = solver, simpleSolver = simpleSolver) + val ctx = SynthesisContext.fromSynthesizer(synth).copy(solverf = solverf, fastSolverf = fastSolverf) synchronized { contexts = ctx :: contexts diff --git a/src/main/scala/leon/synthesis/SimpleSearch.scala b/src/main/scala/leon/synthesis/SimpleSearch.scala index 0a0f96d27..543731c23 100644 --- a/src/main/scala/leon/synthesis/SimpleSearch.scala +++ b/src/main/scala/leon/synthesis/SimpleSearch.scala @@ -182,6 +182,10 @@ class SimpleSearch(synth: Synthesizer, stop() } + def recoverInterrupt() { + shouldStop = false + } + private var shouldStop = false override def stop() { @@ -190,8 +194,6 @@ class SimpleSearch(synth: Synthesizer, } def search(): Option[(Solution, Boolean)] = { - sctx.solver.init() - shouldStop = false while (!g.tree.isSolved && !shouldStop) { diff --git a/src/main/scala/leon/synthesis/Solution.scala b/src/main/scala/leon/synthesis/Solution.scala index e6a5620cf..69d83cd55 100644 --- a/src/main/scala/leon/synthesis/Solution.scala +++ b/src/main/scala/leon/synthesis/Solution.scala @@ -7,7 +7,7 @@ import purescala.Trees._ import purescala.TypeTrees.{TypeTree,TupleType} import purescala.Definitions._ import purescala.TreeOps._ -import solvers.z3.UninterpretedZ3Solver +import solvers.z3._ // Defines a synthesis solution of the form: // ⟨ P | T ⟩ @@ -30,8 +30,7 @@ class Solution(val pre: Expr, val defs: Set[FunDef], val term: Expr) { defs.foldLeft(term){ case (t, fd) => LetDef(fd, t) } def toSimplifiedExpr(ctx: LeonContext, p: Program): Expr = { - val uninterpretedZ3 = new UninterpretedZ3Solver(ctx) - uninterpretedZ3.setProgram(p) + val uninterpretedZ3 = new UninterpretedZ3SolverFactory(ctx, p) val simplifiers = List[Expr => Expr]( simplifyTautologies(uninterpretedZ3)(_), diff --git a/src/main/scala/leon/synthesis/SynthesisContext.scala b/src/main/scala/leon/synthesis/SynthesisContext.scala index 7c1c29b1b..1609064ce 100644 --- a/src/main/scala/leon/synthesis/SynthesisContext.scala +++ b/src/main/scala/leon/synthesis/SynthesisContext.scala @@ -3,7 +3,7 @@ package leon package synthesis -import solvers.Solver +import solvers._ import purescala.Trees._ import purescala.Definitions.{Program, FunDef} import purescala.Common.Identifier @@ -15,8 +15,8 @@ case class SynthesisContext( options: SynthesisOptions, functionContext: Option[FunDef], program: Program, - solver: Solver, - simpleSolver: Solver, + solverf: SolverFactory[Solver], + fastSolverf: SolverFactory[Solver], reporter: Reporter ) @@ -27,8 +27,8 @@ object SynthesisContext { synth.options, synth.functionContext, synth.program, - synth.solver, - synth.simpleSolver, + synth.solverf, + synth.fastSolverf, synth.reporter) } } diff --git a/src/main/scala/leon/synthesis/SynthesisPhase.scala b/src/main/scala/leon/synthesis/SynthesisPhase.scala index 57dca4236..f430ceef6 100644 --- a/src/main/scala/leon/synthesis/SynthesisPhase.scala +++ b/src/main/scala/leon/synthesis/SynthesisPhase.scala @@ -4,8 +4,7 @@ package leon package synthesis import purescala.TreeOps._ -import solvers.TrivialSolver -import solvers.z3.{FairZ3Solver,UninterpretedZ3Solver} +import solvers.z3._ import purescala.Trees._ import purescala.ScalaPrinter diff --git a/src/main/scala/leon/synthesis/Synthesizer.scala b/src/main/scala/leon/synthesis/Synthesizer.scala index e77cf0069..3418c14fe 100644 --- a/src/main/scala/leon/synthesis/Synthesizer.scala +++ b/src/main/scala/leon/synthesis/Synthesizer.scala @@ -8,14 +8,12 @@ import purescala.Definitions.{Program, FunDef} import purescala.TreeOps._ import purescala.Trees._ import purescala.ScalaPrinter + +import solvers._ import solvers.z3._ -import solvers.TimeoutSolver -import solvers.Solver import java.io.File -import collection.mutable.PriorityQueue - import synthesis.search._ class Synthesizer(val context : LeonContext, @@ -26,11 +24,8 @@ class Synthesizer(val context : LeonContext, val rules: Seq[Rule] = options.rules - val solver: FairZ3Solver = new FairZ3Solver(context) - solver.setProgram(program) - - val simpleSolver: Solver = new UninterpretedZ3Solver(context) - simpleSolver.setProgram(program) + val solverf = new FairZ3SolverFactory(context, program) + val fastSolverf = new UninterpretedZ3SolverFactory(context, program) val reporter = context.reporter @@ -81,8 +76,7 @@ class Synthesizer(val context : LeonContext, val (npr, fds) = solutionToProgram(sol) - val tsolver = new TimeoutSolver(new FairZ3Solver(context), timeoutMs) - tsolver.setProgram(npr) + val tsolver = new TimeoutSolverFactory(new FairZ3SolverFactory(context, npr), timeoutMs) val vcs = generateVerificationConditions(reporter, npr, fds.map(_.id.name)) val vctx = VerificationContext(context, Seq(tsolver), context.reporter) diff --git a/src/main/scala/leon/synthesis/heuristics/ADTInduction.scala b/src/main/scala/leon/synthesis/heuristics/ADTInduction.scala index 4b8786987..594eecb76 100644 --- a/src/main/scala/leon/synthesis/heuristics/ADTInduction.scala +++ b/src/main/scala/leon/synthesis/heuristics/ADTInduction.scala @@ -4,7 +4,7 @@ package leon package synthesis package heuristics -import solvers.TimeoutSolver +import solvers._ import purescala.Common._ import purescala.Trees._ import purescala.Extractors._ @@ -14,7 +14,7 @@ import purescala.Definitions._ case object ADTInduction extends Rule("ADT Induction") with Heuristic { def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { - val tsolver = new TimeoutSolver(sctx.solver, 500L) + val tsolver = new TimeoutSolverFactory(sctx.solverf, 500L) val candidates = p.as.collect { case IsTyped(origId, AbstractClassType(cd)) if isInductiveOn(tsolver)(p.pc, origId) => (origId, cd) } diff --git a/src/main/scala/leon/synthesis/heuristics/ADTLongInduction.scala b/src/main/scala/leon/synthesis/heuristics/ADTLongInduction.scala index db5af8c78..6bf7cc92e 100644 --- a/src/main/scala/leon/synthesis/heuristics/ADTLongInduction.scala +++ b/src/main/scala/leon/synthesis/heuristics/ADTLongInduction.scala @@ -4,7 +4,7 @@ package leon package synthesis package heuristics -import solvers.TimeoutSolver +import solvers._ import purescala.Common._ import purescala.Trees._ import purescala.Extractors._ @@ -14,7 +14,7 @@ import purescala.Definitions._ case object ADTLongInduction extends Rule("ADT Long Induction") with Heuristic { def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { - val tsolver = new TimeoutSolver(sctx.solver, 500L) + val tsolver = sctx.solverf.withTimeout(500L) val candidates = p.as.collect { case IsTyped(origId, AbstractClassType(cd)) if isInductiveOn(tsolver)(p.pc, origId) => (origId, cd) } diff --git a/src/main/scala/leon/synthesis/rules/ADTSplit.scala b/src/main/scala/leon/synthesis/rules/ADTSplit.scala index 5203bd60f..18a48edbb 100644 --- a/src/main/scala/leon/synthesis/rules/ADTSplit.scala +++ b/src/main/scala/leon/synthesis/rules/ADTSplit.scala @@ -10,11 +10,11 @@ import purescala.TypeTrees._ import purescala.TreeOps._ import purescala.Extractors._ import purescala.Definitions._ -import solvers.TimeoutSolver +import solvers._ case object ADTSplit extends Rule("ADT Split.") { def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation]= { - val solver = new TimeoutSolver(sctx.solver, 200L) + val solver = SimpleSolverAPI(sctx.solverf.withTimeout(200L)) val candidates = p.as.collect { case IsTyped(id, AbstractClassType(cd)) => diff --git a/src/main/scala/leon/synthesis/rules/Cegis.scala b/src/main/scala/leon/synthesis/rules/Cegis.scala index e5a459fc6..9af12ce3d 100644 --- a/src/main/scala/leon/synthesis/rules/Cegis.scala +++ b/src/main/scala/leon/synthesis/rules/Cegis.scala @@ -4,7 +4,9 @@ package leon package synthesis package rules -import solvers.TimeoutSolver +import solvers._ +import solvers.z3._ + import purescala.Trees._ import purescala.Common._ import purescala.Definitions._ @@ -18,7 +20,6 @@ import scala.collection.mutable.{Map=>MutableMap} import evaluators._ import datagen._ -import solvers.z3.FairZ3Solver case object CEGIS extends Rule("CEGIS") { def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { @@ -450,8 +451,8 @@ case object CEGIS extends Rule("CEGIS") { var unrolings = 0 val maxUnrolings = 3 - val exSolver = new TimeoutSolver(sctx.solver, 3000L) // 3sec - val cexSolver = new TimeoutSolver(sctx.solver, 3000L) // 3sec + val exSolver = sctx.solverf.withTimeout(3000L) // 3sec + val cexSolver = sctx.solverf.withTimeout(3000L) // 3sec var baseExampleInputs: Seq[Seq[Expr]] = Seq() diff --git a/src/main/scala/leon/synthesis/rules/EqualitySplit.scala b/src/main/scala/leon/synthesis/rules/EqualitySplit.scala index 5d643659a..df5d2dd8c 100644 --- a/src/main/scala/leon/synthesis/rules/EqualitySplit.scala +++ b/src/main/scala/leon/synthesis/rules/EqualitySplit.scala @@ -10,9 +10,11 @@ import purescala.TypeTrees._ import purescala.TreeOps._ import purescala.Extractors._ +import solvers._ + case object EqualitySplit extends Rule("Eq. Split") { def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { - val solver = sctx.simpleSolver + val solver = SimpleSolverAPI(sctx.fastSolverf) val candidates = p.as.groupBy(_.getType).mapValues(_.combinations(2).filter { case List(a1, a2) => diff --git a/src/main/scala/leon/synthesis/rules/Ground.scala b/src/main/scala/leon/synthesis/rules/Ground.scala index af1f48622..440f6e78d 100644 --- a/src/main/scala/leon/synthesis/rules/Ground.scala +++ b/src/main/scala/leon/synthesis/rules/Ground.scala @@ -4,7 +4,7 @@ package leon package synthesis package rules -import solvers.TimeoutSolver +import solvers._ import purescala.Trees._ import purescala.TypeTrees._ import purescala.TreeOps._ @@ -14,7 +14,7 @@ case object Ground extends Rule("Ground") { def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { if (p.as.isEmpty) { - val solver = new TimeoutSolver(sctx.solver, 5000L) // We give that 1s + val solver = SimpleSolverAPI(sctx.solverf.withTimeout(5000L)) // We give that 5s val tpe = TupleType(p.xs.map(_.getType)) diff --git a/src/main/scala/leon/synthesis/rules/InequalitySplit.scala b/src/main/scala/leon/synthesis/rules/InequalitySplit.scala index 443040f7f..8fddc9099 100644 --- a/src/main/scala/leon/synthesis/rules/InequalitySplit.scala +++ b/src/main/scala/leon/synthesis/rules/InequalitySplit.scala @@ -11,9 +11,11 @@ import purescala.TypeTrees._ import purescala.TreeOps._ import purescala.Extractors._ +import solvers._ + case object InequalitySplit extends Rule("Ineq. Split.") { def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { - val solver = sctx.simpleSolver + val solver = SimpleSolverAPI(sctx.fastSolverf) val candidates = p.as.filter(_.getType == Int32Type).combinations(2).toList.filter { case List(a1, a2) => diff --git a/src/main/scala/leon/synthesis/rules/OptimisticGround.scala b/src/main/scala/leon/synthesis/rules/OptimisticGround.scala index e0935fefc..a3c9cbf61 100644 --- a/src/main/scala/leon/synthesis/rules/OptimisticGround.scala +++ b/src/main/scala/leon/synthesis/rules/OptimisticGround.scala @@ -9,13 +9,15 @@ import purescala.TypeTrees._ import purescala.TreeOps._ import purescala.Extractors._ +import solvers._ + case object OptimisticGround extends Rule("Optimistic Ground") { def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { if (!p.as.isEmpty && !p.xs.isEmpty) { val res = new RuleInstantiation(p, this, SolutionBuilder.none, this.name) { def apply(sctx: SynthesisContext) = { - val solver = sctx.simpleSolver // Optimistic ground is given a simple solver (uninterpreted) + val solver = SimpleSolverAPI(sctx.fastSolverf) // Optimistic ground is given a simple solver (uninterpreted) val xss = p.xs.toSet val ass = p.as.toSet diff --git a/src/main/scala/leon/synthesis/utils/Benchmarks.scala b/src/main/scala/leon/synthesis/utils/Benchmarks.scala index 24023d0ea..0e6ae51b1 100644 --- a/src/main/scala/leon/synthesis/utils/Benchmarks.scala +++ b/src/main/scala/leon/synthesis/utils/Benchmarks.scala @@ -8,7 +8,7 @@ import leon.purescala.Definitions._ import leon.purescala.Trees._ import leon.purescala.TreeOps._ import leon.solvers.z3._ -import leon.solvers.Solver +import leon.solvers._ import leon.synthesis._ import java.util.Date @@ -54,19 +54,22 @@ object Benchmarks extends App { println("# Using rule: "+rule.name) - val infoSep : String = "╟" + ("┄" * 86) + "╢" - val infoFooter : String = "╚" + ("═" * 86) + "╝" + val infoSep : String = "╟" + ("┄" * 100) + "╢" + val infoFooter : String = "╚" + ("═" * 100) + "╝" val infoHeader : String = " ┌────────────┐\n" + - "╔═╡ Benchmarks ╞" + ("═" * 71) + "╗\n" + - "║ └────────────┘" + (" " * 71) + "║" + "╔═╡ Benchmarks ╞" + ("═" * 85) + "╗\n" + + "║ └────────────┘" + (" " * 85) + "║" + + val runtime = Runtime.getRuntime() def infoLine(file: String, f: String, ts: Long, nAlt: Int, nSuccess: Int, nInnap: Int, nDecomp: Int) : String = { - "║ %-30s %-24s %3d %10s %10s ms ║".format( + "║ %-30s %-24s %3d %10s %10s ms %10d Mb ║".format( file, f, nAlt, nSuccess+"/"+nInnap+"/"+nDecomp, - ts) + ts, + (runtime.totalMemory()-runtime.freeMemory())/(1024*1024)) } println(infoHeader) @@ -77,6 +80,8 @@ object Benchmarks extends App { val ctx = leon.Main.processOptions(others ++ newOptions) for (file <- ctx.files) { + Thread.sleep(10*1000); + val innerCtx = ctx.copy(files = List(file)) val opts = SynthesisOptions() @@ -85,11 +90,9 @@ object Benchmarks extends App { val (program, results) = pipeline.run(innerCtx)(file.getPath :: Nil) - val solver = new FairZ3Solver(ctx) - solver.setProgram(program) + val solverf = new FairZ3SolverFactory(ctx, program) - val simpleSolver = new UninterpretedZ3Solver(ctx) - simpleSolver.setProgram(program) + val fastSolverf = new UninterpretedZ3SolverFactory(ctx, program) for ((f, ps) <- results.toSeq.sortBy(_._1.id.toString); p <- ps) { val sctx = SynthesisContext( @@ -97,8 +100,8 @@ object Benchmarks extends App { options = opts, functionContext = Some(f), program = program, - solver = solver, - simpleSolver = simpleSolver, + solverf = solverf, + fastSolverf = fastSolverf, reporter = ctx.reporter ) diff --git a/src/main/scala/leon/testgen/CallGraph.scala b/src/main/scala/leon/testgen/CallGraph.scala index 62de1f42b..3c22b2334 100644 --- a/src/main/scala/leon/testgen/CallGraph.scala +++ b/src/main/scala/leon/testgen/CallGraph.scala @@ -9,7 +9,9 @@ import leon.purescala.TreeOps._ import leon.purescala.Extractors._ import leon.purescala.TypeTrees._ import leon.purescala.Common._ -import leon.solvers.z3.FairZ3Solver + +import leon.solvers.z3._ +import leon.solvers._ class CallGraph(val program: Program) { @@ -166,7 +168,7 @@ class CallGraph(val program: Program) { fd.annotations.exists(_ == "main") } - def findAllPaths(z3Solver: FairZ3Solver): Set[Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]] = { + def findAllPaths(z3Solver: FairZ3SolverFactory): Set[Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]] = { val waypoints: Set[ProgramPoint] = programPoints.filter{ case ExpressionPoint(Waypoint(_, _), _) => true case _ => false } val sortedWaypoints: Seq[ProgramPoint] = waypoints.toSeq.sortWith((p1, p2) => { val (ExpressionPoint(Waypoint(i1, _), _), ExpressionPoint(Waypoint(i2, _), _)) = (p1, p2) @@ -192,7 +194,7 @@ class CallGraph(val program: Program) { } } - def visitAllWaypoints(waypoints: List[ProgramPoint], z3Solver: FairZ3Solver): Option[Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]] = { + def visitAllWaypoints(waypoints: List[ProgramPoint], z3Solver: FairZ3SolverFactory): Option[Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]] = { def rec(head: ProgramPoint, tail: List[ProgramPoint], path: Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]): Option[Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]] = { tail match { @@ -202,12 +204,11 @@ class CallGraph(val program: Program) { var completePath: Option[Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]] = None allPaths.find(intermediatePath => { val pc = pathConstraint(path ++ intermediatePath) - z3Solver.init() z3Solver.restartZ3 var testcase: Option[Map[Identifier, Expr]] = None - val (solverResult, model) = z3Solver.solveSAT(pc) + val (solverResult, model) = SimpleSolverAPI(z3Solver).solveSAT(pc) solverResult match { case None => { false diff --git a/src/main/scala/leon/testgen/TestGeneration.scala b/src/main/scala/leon/testgen/TestGeneration.scala index d8a716ba5..321388e67 100644 --- a/src/main/scala/leon/testgen/TestGeneration.scala +++ b/src/main/scala/leon/testgen/TestGeneration.scala @@ -10,9 +10,11 @@ import leon.xlang.Trees._ import leon.purescala.TreeOps._ import leon.purescala.TypeTrees._ import leon.purescala.ScalaPrinter -import leon.solvers.z3.FairZ3Solver import leon.Reporter +import leon.solvers._ +import leon.solvers.z3._ + import scala.collection.mutable.{Set => MutableSet} // TODO FIXME if this class is to be resurrected, make it a proper LeonPhase. @@ -23,19 +25,15 @@ class TestGeneration(context : LeonContext) { def shortDescription: String = "test" private val reporter = context.reporter - private val z3Solver = new FairZ3Solver(context) def analyse(program: Program) { - z3Solver.setProgram(program) + val z3Solver = new FairZ3SolverFactory(context, program) reporter.info("Running test generation") val testcases = generateTestCases(program) val topFunDef = program.definedFunctions.find(fd => isMain(fd)).get -//fd.body.exists(body => body match { -// case Waypoint(1, _) => true -// case _ => false -// }) + val testFun = new FunDef(FreshIdentifier("test"), UnitType, Seq()) val funInvocs = testcases.map(testcase => { val params = topFunDef.args @@ -62,6 +60,7 @@ class TestGeneration(context : LeonContext) { } def generatePathConditions(program: Program): Set[Expr] = { + val z3Solver = new FairZ3SolverFactory(context, program) val callGraph = new CallGraph(program) callGraph.writeDotFile("testgen.dot") @@ -76,16 +75,15 @@ class TestGeneration(context : LeonContext) { private def generateTestCases(program: Program): Set[Map[Identifier, Expr]] = { val allPaths = generatePathConditions(program) + val z3Solver = new FairZ3SolverFactory(context, program) allPaths.flatMap(pathCond => { reporter.info("Now considering path condition: " + pathCond) var testcase: Option[Map[Identifier, Expr]] = None - //val z3Solver: FairZ3Solver = loadedSolverExtensions.find(se => se.isInstanceOf[FairZ3Solver]).get.asInstanceOf[FairZ3Solver] - z3Solver.init() z3Solver.restartZ3 - val (solverResult, model) = z3Solver.solveSAT(pathCond) + val (solverResult, model) = SimpleSolverAPI(z3Solver).solveSAT(pathCond) solverResult match { case None => Seq() @@ -100,52 +98,6 @@ class TestGeneration(context : LeonContext) { } }) } - - //private def generatePathConditions(funDef: FunDef): Seq[Expr] = if(!funDef.hasImplementation) Seq() else { - // val body = funDef.body.get - // val cleanBody = hoistIte(expandLets(matchToIfThenElse(body))) - // collectWithPathCondition(cleanBody) - //} - - //private def generateTestCases(funDef: FunDef): Seq[Map[Identifier, Expr]] = { - // val allPaths = generatePathConditions(funDef) - - // allPaths.flatMap(pathCond => { - // reporter.info("Now considering path condition: " + pathCond) - - // var testcase: Option[Map[Identifier, Expr]] = None - // //val z3Solver: FairZ3Solver = loadedSolverExtensions.find(se => se.isInstanceOf[FairZ3Solver]).get.asInstanceOf[FairZ3Solver] - // - // z3Solver.init() - // z3Solver.restartZ3 - // val (solverResult, model) = z3Solver.decideWithModel(pathCond, false) - - // solverResult match { - // case None => Seq() - // case Some(true) => { - // reporter.info("The path is unreachable") - // Seq() - // } - // case Some(false) => { - // reporter.info("The model should be used as the testcase") - // Seq(model) - // } - // } - // }) - //} - - //prec: ite are hoisted and no lets nor match occurs - //private def collectWithPathCondition(expression: Expr): Seq[Expr] = { - // var allPaths: Seq[Expr] = Seq() - - // def rec(expr: Expr, path: List[Expr]): Seq[Expr] = expr match { - // case IfExpr(cond, thenn, elze) => rec(thenn, cond :: path) ++ rec(elze, Not(cond) :: path) - // case _ => Seq(And(path.toSeq)) - // } - - // rec(expression, List()) - //} - } diff --git a/src/main/scala/leon/utils/InterruptManager.scala b/src/main/scala/leon/utils/InterruptManager.scala index 23f2221ee..d69786ca1 100644 --- a/src/main/scala/leon/utils/InterruptManager.scala +++ b/src/main/scala/leon/utils/InterruptManager.scala @@ -27,6 +27,16 @@ class InterruptManager(reporter: Reporter) { } } + def recoverInterrupt() = synchronized { + if (interrupted.get()) { + interrupted.set(false) + + interruptibles.keySet.foreach(_.recoverInterrupt()) + } else { + reporter.warning("Not interrupted!") + } + } + def registerForInterrupts(i: Interruptible) { interruptibles.put(i, true) } diff --git a/src/main/scala/leon/utils/Interruptible.scala b/src/main/scala/leon/utils/Interruptible.scala index 1fe7b7390..b9667d6bd 100644 --- a/src/main/scala/leon/utils/Interruptible.scala +++ b/src/main/scala/leon/utils/Interruptible.scala @@ -3,4 +3,5 @@ package utils trait Interruptible { def interrupt(): Unit + def recoverInterrupt(): Unit } diff --git a/src/main/scala/leon/verification/AnalysisPhase.scala b/src/main/scala/leon/verification/AnalysisPhase.scala index 1113daca8..653d9f021 100644 --- a/src/main/scala/leon/verification/AnalysisPhase.scala +++ b/src/main/scala/leon/verification/AnalysisPhase.scala @@ -9,8 +9,8 @@ import purescala.Trees._ import purescala.TreeOps._ import purescala.TypeTrees._ -import solvers.{Solver,TrivialSolver,TimeoutSolver} -import solvers.z3.FairZ3Solver +import solvers._ +import solvers.z3._ import scala.collection.mutable.{Set => MutableSet} @@ -72,54 +72,44 @@ object AnalysisPhase extends LeonPhase[Program,VerificationReport] { reporter.info(simplifyLets(vc)) // try all solvers until one returns a meaningful answer - var superseeded : Set[String] = Set.empty[String] solvers.find(se => { reporter.info("Trying with solver: " + se.name) - if(superseeded(se.name) || superseeded(se.description)) { - reporter.info("Solver was superseeded. Skipping.") - false - } else { - superseeded = superseeded ++ Set(se.superseeds: _*) - - val t1 = System.nanoTime - se.init() - val (satResult, counterexample) = se.solveSAT(Not(vc)) - val solverResult = satResult.map(!_) - - val t2 = System.nanoTime - val dt = ((t2 - t1) / 1000000) / 1000.0 - - solverResult match { - case _ if interruptManager.isInterrupted() => - reporter.info("=== CANCELLED ===") - vcInfo.time = Some(dt) - false - - case None => - vcInfo.time = Some(dt) - false - - case Some(true) => - reporter.info("==== VALID ====") - - vcInfo.hasValue = true - vcInfo.value = Some(true) - vcInfo.solvedWith = Some(se) - vcInfo.time = Some(dt) - true - - case Some(false) => - reporter.error("Found counter-example : ") - reporter.error(counterexample.toSeq.sortBy(_._1.name).map(p => p._1 + " -> " + p._2).mkString("\n")) - reporter.error("==== INVALID ====") - vcInfo.hasValue = true - vcInfo.value = Some(false) - vcInfo.solvedWith = Some(se) - vcInfo.counterExample = Some(counterexample) - vcInfo.time = Some(dt) - true - - } + val t1 = System.nanoTime + val (satResult, counterexample) = SimpleSolverAPI(se).solveSAT(Not(vc)) + val solverResult = satResult.map(!_) + + val t2 = System.nanoTime + val dt = ((t2 - t1) / 1000000) / 1000.0 + + solverResult match { + case _ if interruptManager.isInterrupted() => + reporter.info("=== CANCELLED ===") + vcInfo.time = Some(dt) + false + + case None => + vcInfo.time = Some(dt) + false + + case Some(true) => + reporter.info("==== VALID ====") + + vcInfo.hasValue = true + vcInfo.value = Some(true) + vcInfo.solvedWith = Some(se) + vcInfo.time = Some(dt) + true + + case Some(false) => + reporter.error("Found counter-example : ") + reporter.error(counterexample.toSeq.sortBy(_._1.name).map(p => p._1 + " -> " + p._2).mkString("\n")) + reporter.error("==== INVALID ====") + vcInfo.hasValue = true + vcInfo.value = Some(false) + vcInfo.solvedWith = Some(se) + vcInfo.counterExample = Some(counterexample) + vcInfo.time = Some(dt) + true } }) match { case None => { @@ -150,20 +140,21 @@ object AnalysisPhase extends LeonPhase[Program,VerificationReport] { val reporter = ctx.reporter - val trivialSolver = new TrivialSolver(ctx) - val fairZ3 = new FairZ3Solver(ctx) + val fairZ3 = new FairZ3SolverFactory(ctx, program) - val solvers0 : Seq[Solver] = trivialSolver :: fairZ3 :: Nil - val solvers: Seq[Solver] = timeout match { - case Some(t) => solvers0.map(s => new TimeoutSolver(s, 1000L * t)) - case None => solvers0 - } + val baseSolvers : Seq[SolverFactory[Solver]] = fairZ3 :: Nil - solvers.foreach(_.setProgram(program)) + val solvers: Seq[SolverFactory[Solver]] = timeout match { + case Some(t) => + baseSolvers.map(_.withTimeout(100L*t)) + + case None => + baseSolvers + } val vctx = VerificationContext(ctx, solvers, reporter) - val report = if(solvers.size > 1) { + val report = if(solvers.size >= 1) { reporter.info("Running verification condition generation...") val vcs = generateVerificationConditions(reporter, program, functionsToAnalyse) checkVerificationConditions(vctx, vcs) diff --git a/src/main/scala/leon/verification/VerificationCondition.scala b/src/main/scala/leon/verification/VerificationCondition.scala index 48b6a2e22..776a4c004 100644 --- a/src/main/scala/leon/verification/VerificationCondition.scala +++ b/src/main/scala/leon/verification/VerificationCondition.scala @@ -6,7 +6,7 @@ import leon.purescala.Trees._ import leon.purescala.Definitions._ import leon.purescala.Common._ -import leon.solvers.Solver +import leon.solvers._ /** This is just to hold some history information. */ class VerificationCondition(val condition: Expr, val funDef: FunDef, val kind: VCKind.Value, val tactic: Tactic, val info: String = "") extends ScalacPositional { @@ -15,7 +15,7 @@ class VerificationCondition(val condition: Expr, val funDef: FunDef, val kind: V // Some(false) = valid var hasValue = false var value : Option[Boolean] = None - var solvedWith : Option[Solver] = None + var solvedWith : Option[SolverFactory[Solver]] = None var time : Option[Double] = None var counterExample : Option[Map[Identifier, Expr]] = None diff --git a/src/main/scala/leon/verification/VerificationContext.scala b/src/main/scala/leon/verification/VerificationContext.scala index f4141f201..ff2f60a4b 100644 --- a/src/main/scala/leon/verification/VerificationContext.scala +++ b/src/main/scala/leon/verification/VerificationContext.scala @@ -3,12 +3,12 @@ package leon package verification -import solvers.Solver +import solvers._ import java.util.concurrent.atomic.AtomicBoolean case class VerificationContext ( context: LeonContext, - solvers: Seq[Solver], + solvers: Seq[SolverFactory[Solver]], reporter: Reporter ) diff --git a/src/test/scala/leon/test/purescala/TreeOpsTests.scala b/src/test/scala/leon/test/purescala/TreeOpsTests.scala index 6372a364a..9878b66db 100644 --- a/src/test/scala/leon/test/purescala/TreeOpsTests.scala +++ b/src/test/scala/leon/test/purescala/TreeOpsTests.scala @@ -11,17 +11,16 @@ import leon.purescala.Trees._ import leon.purescala.TypeTrees._ import leon.purescala.TreeOps._ +import leon.solvers.z3._ + class TreeOpsTests extends LeonTestSuite { test("Path-aware simplifications") { - import leon.solvers.z3.UninterpretedZ3Solver - val solver = new UninterpretedZ3Solver(testContext) - solver.setProgram(Program.empty) - + val solver = new UninterpretedZ3SolverFactory(testContext, Program.empty) // TODO actually testing something here would be better, sorry // PS - assert(true) + assert(true) } diff --git a/src/test/scala/leon/test/solvers/TimeoutSolverTests.scala b/src/test/scala/leon/test/solvers/TimeoutSolverTests.scala index a8f1ddd08..988a514e9 100644 --- a/src/test/scala/leon/test/solvers/TimeoutSolverTests.scala +++ b/src/test/scala/leon/test/solvers/TimeoutSolverTests.scala @@ -11,42 +11,57 @@ import leon.purescala.Trees._ import leon.purescala.TypeTrees._ class TimeoutSolverTests extends LeonTestSuite { - private class IdioticSolver(ctx : LeonContext) extends Solver(ctx) with NaiveIncrementalSolver { + private class IdioticSolver(val context : LeonContext, val program: Program) extends SolverFactory[Solver] { + enclosing => + val name = "Idiotic" - val description = "Loops when it doesn't know" - - def solve(expression : Expr) : Option[Boolean] = expression match { - case BooleanLiteral(true) => Some(true) - case BooleanLiteral(false) => Some(false) - case Equals(x, y) if x == y => Some(true) - case _ => - while(!forceStop) { - Thread.sleep(1) + val description = "Loops" + + def getNewSolver = new Solver { + def check = { + while(!interrupted) { + Thread.sleep(100) } None + } + + def assertCnstr(e: Expr) {} + + def checkAssumptions(assump: Set[Expr]) = ??? + def getModel = ??? + def getUnsatCore = ??? + def push() = ??? + def pop(lvl: Int) = ??? + + def interrupt() = enclosing.interrupt() + def recoverInterrupt() = enclosing.recoverInterrupt() } } - private def getTOSolver : Solver = { - val s = new TimeoutSolver(new IdioticSolver(testContext), 1000L) - s.setProgram(Program.empty) - s + private def getTOSolver : TimeoutSolverFactory[Solver] = { + new IdioticSolver(testContext, Program.empty).withTimeout(1000L) + } + + private def check(sf: TimeoutSolverFactory[Solver], e: Expr): Option[Boolean] = { + val s = sf.getNewSolver + s.assertCnstr(e) + s.check } test("TimeoutSolver 1") { - val s = getTOSolver - assert(s.solve(BooleanLiteral(true)) === Some(true)) - assert(s.solve(BooleanLiteral(false)) === Some(false)) + val sf = getTOSolver + assert(check(sf, BooleanLiteral(true)) === None) + assert(check(sf, BooleanLiteral(false)) === None) val x = Variable(FreshIdentifier("x").setType(Int32Type)) - assert(s.solve(Equals(x, x)) === Some(true)) + assert(check(sf, Equals(x, x)) === None) } test("TimeoutSolver 2") { - val s = getTOSolver + val sf = getTOSolver val x = Variable(FreshIdentifier("x").setType(Int32Type)) val o = IntLiteral(1) - assert(s.solve(Equals(Plus(x, o), Plus(o, x))) === None) - assert(s.solve(Equals(Plus(x, o), x)) === None) + assert(check(sf, Equals(Plus(x, o), Plus(o, x))) === None) + assert(check(sf, Equals(Plus(x, o), x)) === None) } } diff --git a/src/test/scala/leon/test/solvers/z3/FairZ3SolverTests.scala b/src/test/scala/leon/test/solvers/z3/FairZ3SolverTests.scala index 075522597..3dd9579f8 100644 --- a/src/test/scala/leon/test/solvers/z3/FairZ3SolverTests.scala +++ b/src/test/scala/leon/test/solvers/z3/FairZ3SolverTests.scala @@ -9,30 +9,30 @@ import leon.purescala.Trees._ import leon.purescala.TreeOps._ import leon.purescala.TypeTrees._ -import leon.solvers.Solver -import leon.solvers.z3.FairZ3Solver +import leon.solvers._ +import leon.solvers.z3._ class FairZ3SolverTests extends LeonTestSuite { private var testCounter : Int = 0 - private def solverCheck(solver : Solver, expr : Expr, expected : Option[Boolean], msg : String) = { + private def solverCheck(solver : SimpleSolverAPI, expr : Expr, expected : Option[Boolean], msg : String) = { testCounter += 1 test("Solver test #" + testCounter) { - assert(solver.solve(expr) === expected, msg) + assert(solver.solveVALID(expr) === expected, msg) } } - private def assertValid(solver : Solver, expr : Expr) = solverCheck( + private def assertValid(solver : SimpleSolverAPI, expr : Expr) = solverCheck( solver, expr, Some(true), "Solver should prove the formula " + expr + " valid." ) - private def assertInvalid(solver : Solver, expr : Expr) = solverCheck( + private def assertInvalid(solver : SimpleSolverAPI, expr : Expr) = solverCheck( solver, expr, Some(false), "Solver should prove the formula " + expr + " invalid." ) - private def assertUnknown(solver : Solver, expr : Expr) = solverCheck( + private def assertUnknown(solver : SimpleSolverAPI, expr : Expr) = solverCheck( solver, expr, None, "Solver should not be able to decide the formula " + expr + "." ) @@ -53,8 +53,7 @@ class FairZ3SolverTests extends LeonTestSuite { private val y : Expr = Variable(FreshIdentifier("y").setType(Int32Type)) private def f(e : Expr) : Expr = FunctionInvocation(fDef, e :: Nil) - private val solver = new FairZ3Solver(testContext) - solver.setProgram(minimalProgram) + private val solver = SimpleSolverAPI(new FairZ3SolverFactory(testContext, minimalProgram)) private val tautology1 : Expr = BooleanLiteral(true) assertValid(solver, tautology1) diff --git a/src/test/scala/leon/test/solvers/z3/FairZ3SolverTestsNewAPI.scala b/src/test/scala/leon/test/solvers/z3/FairZ3SolverTestsNewAPI.scala index 385784d9a..4d9036d86 100644 --- a/src/test/scala/leon/test/solvers/z3/FairZ3SolverTestsNewAPI.scala +++ b/src/test/scala/leon/test/solvers/z3/FairZ3SolverTestsNewAPI.scala @@ -9,12 +9,12 @@ import leon.purescala.Trees._ import leon.purescala.TreeOps._ import leon.purescala.TypeTrees._ -import leon.solvers.Solver -import leon.solvers.z3.FairZ3Solver +import leon.solvers._ +import leon.solvers.z3._ class FairZ3SolverTestsNewAPI extends LeonTestSuite { private var testCounter : Int = 0 - private def solverCheck(solver : Solver, expr : Expr, expected : Option[Boolean], msg : String) = { + private def solverCheck(solver : SolverFactory[Solver], expr : Expr, expected : Option[Boolean], msg : String) = { testCounter += 1 test("Solver test #" + testCounter) { @@ -26,17 +26,17 @@ class FairZ3SolverTestsNewAPI extends LeonTestSuite { } } - private def assertValid(solver : Solver, expr : Expr) = solverCheck( + private def assertValid(solver : SolverFactory[Solver], expr : Expr) = solverCheck( solver, expr, Some(true), "Solver should prove the formula " + expr + " valid." ) - private def assertInvalid(solver : Solver, expr : Expr) = solverCheck( + private def assertInvalid(solver : SolverFactory[Solver], expr : Expr) = solverCheck( solver, expr, Some(false), "Solver should prove the formula " + expr + " invalid." ) - private def assertUnknown(solver : Solver, expr : Expr) = solverCheck( + private def assertUnknown(solver : SolverFactory[Solver], expr : Expr) = solverCheck( solver, expr, None, "Solver should not be able to decide the formula " + expr + "." ) @@ -57,9 +57,7 @@ class FairZ3SolverTestsNewAPI extends LeonTestSuite { private val y : Expr = Variable(FreshIdentifier("y").setType(Int32Type)) private def f(e : Expr) : Expr = FunctionInvocation(fDef, e :: Nil) - private val solver = new FairZ3Solver(testContext) - solver.setProgram(minimalProgram) - solver.restartZ3 + private val solver = new FairZ3SolverFactory(testContext, minimalProgram) private val tautology1 : Expr = BooleanLiteral(true) assertValid(solver, tautology1) diff --git a/src/test/scala/leon/test/solvers/z3/UninterpretedZ3SolverTests.scala b/src/test/scala/leon/test/solvers/z3/UninterpretedZ3SolverTests.scala index 8f4cd8466..ed3d54198 100644 --- a/src/test/scala/leon/test/solvers/z3/UninterpretedZ3SolverTests.scala +++ b/src/test/scala/leon/test/solvers/z3/UninterpretedZ3SolverTests.scala @@ -9,30 +9,30 @@ import leon.purescala.Trees._ import leon.purescala.TreeOps._ import leon.purescala.TypeTrees._ -import leon.solvers.Solver -import leon.solvers.z3.UninterpretedZ3Solver +import leon.solvers._ +import leon.solvers.z3._ class UninterpretedZ3SolverTests extends LeonTestSuite { private var testCounter : Int = 0 - private def solverCheck(solver : Solver, expr : Expr, expected : Option[Boolean], msg : String) = { + private def solverCheck(solver : SimpleSolverAPI, expr : Expr, expected : Option[Boolean], msg : String) = { testCounter += 1 test("Solver test #" + testCounter) { - assert(solver.solve(expr) === expected, msg) + assert(solver.solveVALID(expr) === expected, msg) } } - private def assertValid(solver : Solver, expr : Expr) = solverCheck( + private def assertValid(solver : SimpleSolverAPI, expr : Expr) = solverCheck( solver, expr, Some(true), "Solver should prove the formula " + expr + " valid." ) - private def assertInvalid(solver : Solver, expr : Expr) = solverCheck( + private def assertInvalid(solver : SimpleSolverAPI, expr : Expr) = solverCheck( solver, expr, Some(false), "Solver should prove the formula " + expr + " invalid." ) - private def assertUnknown(solver : Solver, expr : Expr) = solverCheck( + private def assertUnknown(solver : SimpleSolverAPI, expr : Expr) = solverCheck( solver, expr, None, "Solver should not be able to decide the formula " + expr + "." ) @@ -58,8 +58,7 @@ class UninterpretedZ3SolverTests extends LeonTestSuite { private def f(e : Expr) : Expr = FunctionInvocation(fDef, e :: Nil) private def g(e : Expr) : Expr = FunctionInvocation(gDef, e :: Nil) - private val solver = new UninterpretedZ3Solver(testContext) - solver.setProgram(minimalProgram) + private val solver = SimpleSolverAPI(new UninterpretedZ3SolverFactory(testContext, minimalProgram)) private val tautology1 : Expr = BooleanLiteral(true) assertValid(solver, tautology1) @@ -87,7 +86,7 @@ class UninterpretedZ3SolverTests extends LeonTestSuite { test("Expected crash on undefined functions.") { intercept[Exception] { - solver.solve(Equals(g(x), g(x))) + solver.solveVALID(Equals(g(x), g(x))) } } } diff --git a/src/test/scala/leon/test/synthesis/SynthesisSuite.scala b/src/test/scala/leon/test/synthesis/SynthesisSuite.scala index 41e0780a4..49651699c 100644 --- a/src/test/scala/leon/test/synthesis/SynthesisSuite.scala +++ b/src/test/scala/leon/test/synthesis/SynthesisSuite.scala @@ -37,11 +37,9 @@ class SynthesisSuite extends LeonTestSuite { val (program, results) = pipeline.run(ctx)((content, Nil)) - val solver = new FairZ3Solver(ctx) - solver.setProgram(program) + val solver = new FairZ3SolverFactory(ctx, program) - val simpleSolver = new UninterpretedZ3Solver(ctx) - simpleSolver.setProgram(program) + val simpleSolver = new UninterpretedZ3SolverFactory(ctx, program) for ((f, ps) <- results; p <- ps) { test("Synthesizing %3d: %-20s [%s]".format(nextInt(), f.id.toString, title)) { -- GitLab