diff --git a/src/main/scala/leon/purescala/ScalaPrinter.scala b/src/main/scala/leon/purescala/ScalaPrinter.scala index 67241adeffd6adebee90741c3db3835983cb1f00..ec69ff46e3c2f646fb298040be7a64d6d53ccea7 100644 --- a/src/main/scala/leon/purescala/ScalaPrinter.scala +++ b/src/main/scala/leon/purescala/ScalaPrinter.scala @@ -19,14 +19,14 @@ class ScalaPrinter(sb: StringBuffer = new StringBuffer) extends PrettyPrinter(sb // EXPRESSIONS // all expressions are printed in-line override def pp(tree: Expr, lvl: Int): Unit = tree match { - case Variable(id) => sb.append(id) + case Variable(id) => sb.append(idToString(id)) case DeBruijnIndex(idx) => sys.error("Not Valid Scala") case LetTuple(ids,d,e) => sb.append("locally {\n") ind(lvl+1) sb.append("val (" ) for (((id, tpe), i) <- ids.map(id => (id, id.getType)).zipWithIndex) { - sb.append(id.toString+": ") + sb.append(idToString(id)+": ") pp(tpe, lvl) if (i != ids.size-1) { sb.append(", ") @@ -85,7 +85,7 @@ class ScalaPrinter(sb: StringBuffer = new StringBuffer) extends PrettyPrinter(sb sb.append("._" + i) case CaseClass(cd, args) => - sb.append(cd.id) + sb.append(idToString(cd.id)) if (cd.isCaseObject) { ppNary(args, "", "", "", lvl) } else { @@ -94,14 +94,14 @@ class ScalaPrinter(sb: StringBuffer = new StringBuffer) extends PrettyPrinter(sb case CaseClassInstanceOf(cd, e) => pp(e, lvl) - sb.append(".isInstanceOf[" + cd.id + "]") + sb.append(".isInstanceOf[" + idToString(cd.id) + "]") case CaseClassSelector(_, cc, id) => pp(cc, lvl) - sb.append("." + id) + sb.append("." + idToString(id)) case FunctionInvocation(fd, args) => - sb.append(fd.id) + sb.append(idToString(fd.id)) ppNary(args, "(", ", ", ")", lvl) case Plus(l,r) => ppBinary(l, r, " + ", lvl) @@ -210,7 +210,7 @@ class ScalaPrinter(sb: StringBuffer = new StringBuffer) extends PrettyPrinter(sb case Choose(ids, pred) => sb.append("(choose { (") for (((id, tpe), i) <- ids.map(id => (id, id.getType)).zipWithIndex) { - sb.append(id.toString+": ") + sb.append(idToString(id)+": ") pp(tpe, lvl) if (i != ids.size-1) { sb.append(", ") @@ -229,7 +229,7 @@ class ScalaPrinter(sb: StringBuffer = new StringBuffer) extends PrettyPrinter(sb //case InstanceOfPattern(Some(id), ctd) => case CaseClassPattern(bndr, ccd, subps) => { bndr.foreach(b => sb.append(b + " @ ")) - sb.append(ccd.id).append("(") + sb.append(idToString(ccd.id)).append("(") var c = 0 val sz = subps.size subps.foreach(sp => { @@ -241,10 +241,10 @@ class ScalaPrinter(sb: StringBuffer = new StringBuffer) extends PrettyPrinter(sb sb.append(")") } case WildcardPattern(None) => sb.append("_") - case WildcardPattern(Some(id)) => sb.append(id) + case WildcardPattern(Some(id)) => sb.append(idToString(id)) case InstanceOfPattern(bndr, ccd) => { bndr.foreach(b => sb.append(b + " : ")) - sb.append(ccd.id) + sb.append(idToString(ccd.id)) } case TuplePattern(bndr, subPatterns) => { bndr.foreach(b => sb.append(b + " @ ")) @@ -333,7 +333,7 @@ class ScalaPrinter(sb: StringBuffer = new StringBuffer) extends PrettyPrinter(sb } sb.append(" => ") pp(tt, lvl) - case c: ClassType => sb.append(c.classDef.id) + case c: ClassType => sb.append(idToString(c.classDef.id)) case _ => sb.append("Type?") } @@ -350,7 +350,7 @@ class ScalaPrinter(sb: StringBuffer = new StringBuffer) extends PrettyPrinter(sb case ObjectDef(id, defs, invs) => { ind(lvl) sb.append("object ") - sb.append(id) + sb.append(idToString(id)) sb.append(" {\n") var c = 0 @@ -372,19 +372,19 @@ class ScalaPrinter(sb: StringBuffer = new StringBuffer) extends PrettyPrinter(sb case AbstractClassDef(id, parent) => ind(lvl) sb.append("sealed abstract class ") - sb.append(id) + sb.append(idToString(id)) parent.foreach(p => sb.append(" extends " + p.id)) case CaseClassDef(id, parent, varDecls) => ind(lvl) sb.append("case class ") - sb.append(id) + sb.append(idToString(id)) sb.append("(") var c = 0 val sz = varDecls.size varDecls.foreach(vd => { - sb.append(vd.id) + sb.append(idToString(vd.id)) sb.append(": ") pp(vd.tpe, lvl) if(c < sz - 1) { @@ -393,20 +393,20 @@ class ScalaPrinter(sb: StringBuffer = new StringBuffer) extends PrettyPrinter(sb c = c + 1 }) sb.append(")") - parent.foreach(p => sb.append(" extends " + p.id)) + parent.foreach(p => sb.append(" extends " + idToString(p.id))) case fd: FunDef => ind(lvl) sb.append("def ") - sb.append(fd.id) + sb.append(idToString(fd.id)) sb.append("(") val sz = fd.args.size var c = 0 fd.args.foreach(arg => { - sb.append(arg.id) + sb.append(idToString(arg.id)) sb.append(" : ") pp(arg.tpe, lvl) diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala index 3479c8192576e12dc7de3036f30dcf7fb6c66793..2ac6b102e03b7b7d1284fd38356f394df579f358 100644 --- a/src/main/scala/leon/purescala/TreeOps.scala +++ b/src/main/scala/leon/purescala/TreeOps.scala @@ -1449,7 +1449,7 @@ object TreeOps { case (id, post) => val nid = genId(id, newScope) val postScope = newScope.register(id -> nid) - (id, rec(post, postScope)) + (nid, rec(post, postScope)) } LetDef(newFd, rec(body, newScope)) @@ -1583,6 +1583,7 @@ object TreeOps { case Untyped | AnyType | BottomType | BooleanType | Int32Type | UnitType => None } + var idMap = Map[Identifier, Identifier]() var funDefMap = Map.empty[FunDef,FunDef] def fd2fd(funDef : FunDef) : FunDef = funDefMap.get(funDef) match { @@ -1598,7 +1599,14 @@ object TreeOps { // These will be taken care of in the recursive traversal. fd.body = funDef.body fd.precondition = funDef.precondition - fd.postcondition = funDef.postcondition + funDef.postcondition match { + case Some((id, post)) => + val freshId = FreshIdentifier(id.name, true).setType(rt) + idMap += id -> freshId + fd.postcondition = Some((freshId, post)) + case None => + fd.postcondition = None + } fd } funDefMap = funDefMap.updated(funDef, newFD) @@ -1607,6 +1615,7 @@ object TreeOps { def pre(e : Expr) : Expr = e match { case Tuple(Seq()) => UnitLiteral + case Variable(id) if idMap contains id => Variable(idMap(id)) case Tuple(Seq(s)) => pre(s) diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala index d7bd3e9689bc856486d9aa975a6653c183c88021..65b25f25e23343a6f6b1ec319a159bd217379d01 100644 --- a/src/main/scala/leon/purescala/Trees.scala +++ b/src/main/scala/leon/purescala/Trees.scala @@ -490,9 +490,13 @@ object Trees { leastUpperBound(Seq(set1, set2).map(_.getType)).foreach(setType _) } case class SetMin(set: Expr) extends Expr with FixedType { + typeCheck(set, SetType(Int32Type)) + val fixedType = Int32Type } case class SetMax(set: Expr) extends Expr with FixedType { + typeCheck(set, SetType(Int32Type)) + val fixedType = Int32Type } diff --git a/src/main/scala/leon/solvers/AssumptionSolver.scala b/src/main/scala/leon/solvers/AssumptionSolver.scala new file mode 100644 index 0000000000000000000000000000000000000000..e906687bee4329f82ff1ecac50d8c4bfeb41ff23 --- /dev/null +++ b/src/main/scala/leon/solvers/AssumptionSolver.scala @@ -0,0 +1,11 @@ +/* Copyright 2009-2013 EPFL, Lausanne */ + +package leon +package solvers + +import purescala.Trees.Expr + +trait AssumptionSolver extends Solver { + def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] + def getUnsatCore: Set[Expr] +} diff --git a/src/main/scala/leon/solvers/IncrementalSolver.scala b/src/main/scala/leon/solvers/IncrementalSolver.scala new file mode 100644 index 0000000000000000000000000000000000000000..79c8d04d46697318ac5d3db380f482fdcffde0f5 --- /dev/null +++ b/src/main/scala/leon/solvers/IncrementalSolver.scala @@ -0,0 +1,10 @@ +/* Copyright 2009-2013 EPFL, Lausanne */ + +package leon +package solvers + +trait IncrementalSolver extends Solver { + def push(): Unit + def pop(lvl: Int = 1): Unit +} + diff --git a/src/main/scala/leon/solvers/SimpleAssumptionSolverAPI.scala b/src/main/scala/leon/solvers/SimpleAssumptionSolverAPI.scala new file mode 100644 index 0000000000000000000000000000000000000000..0ce1eb8d292b0df54fa813c235b108e7c0222fe9 --- /dev/null +++ b/src/main/scala/leon/solvers/SimpleAssumptionSolverAPI.scala @@ -0,0 +1,27 @@ +/* Copyright 2009-2013 EPFL, Lausanne */ + +package leon +package solvers + +import purescala.Common._ +import purescala.Trees._ + +class SimpleAssumptionSolverAPI(sf: SolverFactory[AssumptionSolver]) extends SimpleSolverAPI(sf) { + + def solveSATWithCores(expression: Expr, assumptions: Set[Expr]): (Option[Boolean], Map[Identifier, Expr], Set[Expr]) = { + val s = sf.getNewSolver + try { + s.assertCnstr(expression) + s.checkAssumptions(assumptions) match { + case Some(true) => + (Some(true), s.getModel, Set()) + case Some(false) => + (Some(false), Map(), s.getUnsatCore) + case None => + (None, Map(), Set()) + } + } finally { + s.free() + } + } +} diff --git a/src/main/scala/leon/solvers/SimpleSolverAPI.scala b/src/main/scala/leon/solvers/SimpleSolverAPI.scala index 6565a2148c7959ae2a19c402ad18c5aef93dcb14..59b20f8f9011657f9cd62a8be504d0acfea63b9b 100644 --- a/src/main/scala/leon/solvers/SimpleSolverAPI.scala +++ b/src/main/scala/leon/solvers/SimpleSolverAPI.scala @@ -6,41 +6,43 @@ package solvers import purescala.Common._ import purescala.Trees._ -case class SimpleSolverAPI[S <: Solver](sf: SolverFactory[S]) { +class SimpleSolverAPI(sf: SolverFactory[Solver]) { def solveVALID(expression: Expr): Option[Boolean] = { - val s = sf.getNewSolver() - s.assertCnstr(Not(expression)) - s.check.map(r => !r) - } - - def free() { - sf.free() + val s = sf.getNewSolver + try { + s.assertCnstr(Not(expression)) + s.check.map(r => !r) + } finally { + s.free() + } } def solveSAT(expression: Expr): (Option[Boolean], Map[Identifier, Expr]) = { - val s = sf.getNewSolver() - s.assertCnstr(expression) - s.check match { - case Some(true) => - (Some(true), s.getModel) - case Some(false) => - (Some(false), Map()) - case None => - (None, Map()) + val s = sf.getNewSolver + try { + s.assertCnstr(expression) + s.check match { + case Some(true) => + (Some(true), s.getModel) + case Some(false) => + (Some(false), Map()) + case None => + (None, Map()) + } + } finally { + s.free() } } +} + +object SimpleSolverAPI { + def apply(sf: SolverFactory[Solver]) = { + new SimpleSolverAPI(sf) + } - def solveSATWithCores(expression: Expr, assumptions: Set[Expr]): (Option[Boolean], Map[Identifier, Expr], Set[Expr]) = { - val s = sf.getNewSolver() - s.assertCnstr(expression) - s.checkAssumptions(assumptions) match { - case Some(true) => - (Some(true), s.getModel, Set()) - case Some(false) => - (Some(false), Map(), s.getUnsatCore) - case None => - (None, Map(), Set()) - } + // Wrapping an AssumptionSolver will automatically provide an extended + // interface + def apply(sf: SolverFactory[AssumptionSolver]) = { + new SimpleAssumptionSolverAPI(sf) } } - diff --git a/src/main/scala/leon/solvers/Solver.scala b/src/main/scala/leon/solvers/Solver.scala index cb257d2369d532836e9b77a25da23432af453ef3..b4edadedbd96833932dd3ee8c79b4f4487dea3b3 100644 --- a/src/main/scala/leon/solvers/Solver.scala +++ b/src/main/scala/leon/solvers/Solver.scala @@ -4,18 +4,24 @@ package leon package solvers import utils._ -import purescala.Common._ -import purescala.Trees._ +import purescala.Trees.Expr +import purescala.Common.Identifier + + +trait Solver { + def name: String + val context: LeonContext -trait Solver extends Interruptible { - def push(): Unit - def pop(lvl: Int = 1): Unit def assertCnstr(expression: Expr): Unit def check: Option[Boolean] - def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] def getModel: Map[Identifier, Expr] - def getUnsatCore: Set[Expr] + + def free() implicit val debugSection = DebugSectionSolver + + private[solvers] def debugS(msg: String) = { + context.reporter.debug("["+name+"] "+msg) + } } diff --git a/src/main/scala/leon/solvers/SolverFactory.scala b/src/main/scala/leon/solvers/SolverFactory.scala index aa343136abe8927b2077c7b220c164ee4fd5a8f8..147d681d97fd90ea17790d785a42d94f8d6bc1f6 100644 --- a/src/main/scala/leon/solvers/SolverFactory.scala +++ b/src/main/scala/leon/solvers/SolverFactory.scala @@ -3,41 +3,18 @@ package leon package solvers -import solvers.combinators._ - -import utils._ -import purescala.Common._ -import purescala.Definitions._ -import purescala.TreeOps._ -import purescala.Trees._ - -trait SolverFactory[S <: Solver] extends Interruptible with LeonComponent { - val context: LeonContext - val program: Program - - def free() {} - - var interrupted = false - - override def interrupt() { - interrupted = true - } - - override def recoverInterrupt() { - interrupted = false - } +import scala.reflect.runtime.universe._ +abstract class SolverFactory[+S <: Solver : TypeTag] { def getNewSolver(): S - def withTimeout(ms: Long): TimeoutSolverFactory[S] = { - this match { - case tsf: TimeoutSolverFactory[S] => - // Unwrap/Rewrap to take only new timeout into account - new TimeoutSolverFactory[S](tsf.sf, ms) - case _ => - new TimeoutSolverFactory[S](this, ms) + val name = "SFact("+typeOf[S].toString+")" +} + +object SolverFactory { + def apply[S <: Solver : TypeTag](builder: () => S): SolverFactory[S] = { + new SolverFactory[S] { + def getNewSolver() = builder() } } - - implicit val debugSection = DebugSectionSolver } diff --git a/src/main/scala/leon/solvers/TimeoutAssumptionSolver.scala b/src/main/scala/leon/solvers/TimeoutAssumptionSolver.scala new file mode 100644 index 0000000000000000000000000000000000000000..ae59400daba40d71cd5cea4d2aa448cc8ab4974f --- /dev/null +++ b/src/main/scala/leon/solvers/TimeoutAssumptionSolver.scala @@ -0,0 +1,23 @@ +/* Copyright 2009-2013 EPFL, Lausanne */ + +package leon +package solvers + +import purescala.Trees.Expr + +trait TimeoutAssumptionSolver extends TimeoutSolver with AssumptionSolver { + + protected def innerCheckAssumptions(assumptions: Set[Expr]): Option[Boolean] + + override def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = { + optTimeout match { + case Some(to) => + interruptAfter(to) { + innerCheckAssumptions(assumptions) + } + case None => + innerCheckAssumptions(assumptions) + } + } +} + diff --git a/src/main/scala/leon/solvers/TimeoutSolver.scala b/src/main/scala/leon/solvers/TimeoutSolver.scala new file mode 100644 index 0000000000000000000000000000000000000000..ab98755c24aaf59204edc4f9dd64efec9eb102d1 --- /dev/null +++ b/src/main/scala/leon/solvers/TimeoutSolver.scala @@ -0,0 +1,71 @@ +/* Copyright 2009-2013 EPFL, Lausanne */ + +package leon +package solvers + +import utils._ + +trait TimeoutSolver extends Solver with Interruptible { + + private class Countdown(timeout: Long, onTimeout: => Unit) extends Thread { + private var keepRunning = true + override def run : Unit = { + val startTime : Long = System.currentTimeMillis + var exceeded : Boolean = false + + while(!exceeded && keepRunning) { + if(timeout < (System.currentTimeMillis - startTime)) { + exceeded = true + } + Thread.sleep(10) + } + + if(exceeded && keepRunning) { + onTimeout + } + } + + def finishedRunning : Unit = { + keepRunning = false + } + } + + protected var optTimeout: Option[Long] = None; + protected def interruptAfter[T](timeout: Long)(body: => T): T = { + var reachedTimeout = false + + val timer = new Countdown(timeout, { + interrupt() + reachedTimeout = true + }) + + timer.start + val res = body + timer.finishedRunning + + if (reachedTimeout) { + recoverInterrupt() + } + + res + } + + def setTimeout(timeout: Long): this.type = { + optTimeout = Some(timeout) + this + } + + protected def innerCheck: Option[Boolean] + + override def check: Option[Boolean] = { + optTimeout match { + case Some(to) => + interruptAfter(to) { + innerCheck + } + case None => + innerCheck + } + } + +} diff --git a/src/main/scala/leon/solvers/TimeoutSolverFactory.scala b/src/main/scala/leon/solvers/TimeoutSolverFactory.scala new file mode 100644 index 0000000000000000000000000000000000000000..02bf21c8a41b6c7db50e6fe041206c50478cfd4f --- /dev/null +++ b/src/main/scala/leon/solvers/TimeoutSolverFactory.scala @@ -0,0 +1,13 @@ +/* Copyright 2009-2013 EPFL, Lausanne */ + +package leon +package solvers + +import scala.reflect.runtime.universe._ + + +class TimeoutSolverFactory[+S <: TimeoutSolver : TypeTag](sf: SolverFactory[S], to: Long) extends SolverFactory[S] { + override def getNewSolver() = sf.getNewSolver().setTimeout(to) + + override val name = "SFact("+typeOf[S].toString+") with t.o" +} diff --git a/src/main/scala/leon/solvers/combinators/DNFSolver.scala b/src/main/scala/leon/solvers/combinators/DNFSolver.scala new file mode 100644 index 0000000000000000000000000000000000000000..507106bd174a276e4bf9780d0bb37105382b722b --- /dev/null +++ b/src/main/scala/leon/solvers/combinators/DNFSolver.scala @@ -0,0 +1,141 @@ +/* Copyright 2009-2013 EPFL, Lausanne */ + +package leon +package solvers +package combinators + +import purescala.Common._ +import purescala.Definitions._ +import purescala.Trees._ +import purescala.TreeOps._ +import purescala.TypeTrees._ + +class DNFSolver(val context: LeonContext, + underlyings: SolverFactory[Solver]) extends Solver { + + def name = "DNF("+underlyings.name+")" + + def free {} + + private var theConstraint : Option[Expr] = None + private var theModel : Option[Map[Identifier,Expr]] = None + + import context.reporter._ + + def assertCnstr(expression : Expr) { + if(!theConstraint.isEmpty) { fatalError("Multiple assertCnstr(...).") } + theConstraint = Some(expression) + } + + def check : Option[Boolean] = theConstraint.map { expr => + + val simpleSolver = SimpleSolverAPI(underlyings) + + var result : Option[Boolean] = None + + debugS("Before NNF:\n" + expr) + + val nnfed = nnf(expr, false) + + debugS("After NNF:\n" + nnfed) + + val dnfed = dnf(nnfed) + + debugS("After DNF:\n" + dnfed) + + val candidates : Seq[Expr] = dnfed match { + case Or(es) => es + case elze => Seq(elze) + } + + debugS("# conjuncts : " + candidates.size) + + var done : Boolean = false + + for(candidate <- candidates if !done) { + simpleSolver.solveSAT(candidate) match { + case (Some(false), _) => + result = Some(false) + + case (Some(true), m) => + result = Some(true) + theModel = Some(m) + done = true + + case (None, m) => + result = None + theModel = Some(m) + done = true + } + } + result + } getOrElse { + Some(true) + } + + def getModel : Map[Identifier,Expr] = { + val vs : Set[Identifier] = theConstraint.map(variablesOf(_)).getOrElse(Set.empty) + theModel.getOrElse(Map.empty).filter(p => vs(p._1)) + } + + private def nnf(expr : Expr, flip : Boolean) : Expr = expr match { + case _ : Let | _ : IfExpr => throw new Exception("Can't NNF *everything*, sorry.") + case Not(Implies(l,r)) => nnf(And(l, Not(r)), flip) + case Implies(l, r) => nnf(Or(Not(l), r), flip) + case Not(Iff(l, r)) => nnf(Or(And(l, Not(r)), And(Not(l), r)), flip) + case Iff(l, r) => nnf(Or(And(l, r), And(Not(l), Not(r))), flip) + case And(es) if flip => Or(es.map(e => nnf(e, true))) + case And(es) => And(es.map(e => nnf(e, false))) + case Or(es) if flip => And(es.map(e => nnf(e, true))) + case Or(es) => Or(es.map(e => nnf(e, false))) + case Not(e) if flip => nnf(e, false) + case Not(e) => nnf(e, true) + case LessThan(l,r) if flip => GreaterEquals(l,r) + case GreaterThan(l,r) if flip => LessEquals(l,r) + case LessEquals(l,r) if flip => GreaterThan(l,r) + case GreaterEquals(l,r) if flip => LessThan(l,r) + case elze if flip => Not(elze) + case elze => elze + } + + // fun pushC (And(p,Or(q,r))) = Or(pushC(And(p,q)),pushC(And(p,r))) + // | pushC (And(Or(q,r),p)) = Or(pushC(And(p,q)),pushC(And(p,r))) + // | pushC (And(p,q)) = And(pushC(p),pushC(q)) + // | pushC (Literal(l)) = Literal(l) + // | pushC (Or(p,q)) = Or(pushC(p),pushC(q)) + + private def dnf(expr : Expr) : Expr = expr match { + case And(es) => + val (ors, lits) = es.partition(_.isInstanceOf[Or]) + if(!ors.isEmpty) { + val orHead = ors.head.asInstanceOf[Or] + val orTail = ors.tail + Or(orHead.exprs.map(oe => dnf(And(filterObvious(lits ++ (oe +: orTail)))))) + } else { + expr + } + + case Or(es) => + Or(es.map(dnf(_))) + + case _ => expr + } + + private def filterObvious(exprs : Seq[Expr]) : Seq[Expr] = { + var pos : List[Identifier] = Nil + var neg : List[Identifier] = Nil + + for(e <- exprs) e match { + case Variable(id) => pos = id :: pos + case Not(Variable(id)) => neg = id :: neg + case _ => ; + } + + val both : Set[Identifier] = pos.toSet intersect neg.toSet + if(!both.isEmpty) { + Seq(BooleanLiteral(false)) + } else { + exprs + } + } +} diff --git a/src/main/scala/leon/solvers/combinators/DNFSolverFactory.scala b/src/main/scala/leon/solvers/combinators/DNFSolverFactory.scala deleted file mode 100644 index 6ef2752d900bdc78922a6ffcad83c48581c35ca3..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/solvers/combinators/DNFSolverFactory.scala +++ /dev/null @@ -1,173 +0,0 @@ -/* Copyright 2009-2013 EPFL, Lausanne */ - -package leon -package solvers -package combinators - -import purescala.Common._ -import purescala.Definitions._ -import purescala.Trees._ -import purescala.TypeTrees._ -import purescala.TreeOps._ - -import scala.collection.mutable.{Map=>MutableMap} - -class DNFSolverFactory[S <: Solver](val sf : SolverFactory[S]) extends SolverFactory[Solver] { - val description = "DNF around a base solver" - val name = sf.name + "!" - - val context = sf.context - val program = sf.program - - private val thisFactory = this - - override def free() { - sf.free() - } - - override def recoverInterrupt() { - sf.recoverInterrupt() - } - - def getNewSolver() : Solver = { - new Solver { - private var theConstraint : Option[Expr] = None - private var theModel : Option[Map[Identifier,Expr]] = None - - private def fail(because : String) : Nothing = { throw new Exception("Not supported in DNFSolvers : " + because) } - - def push() : Unit = fail("push()") - def pop(lvl : Int = 1) : Unit = fail("pop(lvl)") - - def assertCnstr(expression : Expr) { - if(!theConstraint.isEmpty) { fail("Multiple assertCnstr(...).") } - theConstraint = Some(expression) - } - - def interrupt() { fail("interrupt()") } - - def recoverInterrupt() { fail("recoverInterrupt()") } - - def check : Option[Boolean] = theConstraint.map { expr => - import context.reporter - - val simpleSolver = SimpleSolverAPI(sf) - - var result : Option[Boolean] = None - - def info(msg : String) { reporter.info("In " + thisFactory.name + ": " + msg) } - - // info("Before NNF:\n" + expr) - - val nnfed = nnf(expr, false) - - // info("After NNF:\n" + nnfed) - - val dnfed = dnf(nnfed) - - // info("After DNF:\n" + dnfed) - - val candidates : Seq[Expr] = dnfed match { - case Or(es) => es - case elze => Seq(elze) - } - - info("# conjuncts : " + candidates.size) - - var done : Boolean = false - - for(candidate <- candidates if !done) { - simpleSolver.solveSAT(candidate) match { - case (Some(false), _) => - result = Some(false) - - case (Some(true), m) => - result = Some(true) - theModel = Some(m) - done = true - - case (None, m) => - result = None - theModel = Some(m) - done = true - } - } - result - } getOrElse { - Some(true) - } - - def checkAssumptions(assumptions : Set[Expr]) : Option[Boolean] = { - fail("checkAssumptions(assumptions)") - } - - def getModel : Map[Identifier,Expr] = { - val vs : Set[Identifier] = theConstraint.map(variablesOf(_)).getOrElse(Set.empty) - theModel.getOrElse(Map.empty).filter(p => vs(p._1)) - } - - def getUnsatCore : Set[Expr] = { fail("getUnsatCore") } - } - } - - private def nnf(expr : Expr, flip : Boolean) : Expr = expr match { - case _ : Let | _ : IfExpr => throw new Exception("Can't NNF *everything*, sorry.") - case Not(Implies(l,r)) => nnf(And(l, Not(r)), flip) - case Implies(l, r) => nnf(Or(Not(l), r), flip) - case Not(Iff(l, r)) => nnf(Or(And(l, Not(r)), And(Not(l), r)), flip) - case Iff(l, r) => nnf(Or(And(l, r), And(Not(l), Not(r))), flip) - case And(es) if flip => Or(es.map(e => nnf(e, true))) - case And(es) => And(es.map(e => nnf(e, false))) - case Or(es) if flip => And(es.map(e => nnf(e, true))) - case Or(es) => Or(es.map(e => nnf(e, false))) - case Not(e) if flip => nnf(e, false) - case Not(e) => nnf(e, true) - case LessThan(l,r) if flip => GreaterEquals(l,r) - case GreaterThan(l,r) if flip => LessEquals(l,r) - case LessEquals(l,r) if flip => GreaterThan(l,r) - case GreaterEquals(l,r) if flip => LessThan(l,r) - case elze if flip => Not(elze) - case elze => elze - } - - // fun pushC (And(p,Or(q,r))) = Or(pushC(And(p,q)),pushC(And(p,r))) - // | pushC (And(Or(q,r),p)) = Or(pushC(And(p,q)),pushC(And(p,r))) - // | pushC (And(p,q)) = And(pushC(p),pushC(q)) - // | pushC (Literal(l)) = Literal(l) - // | pushC (Or(p,q)) = Or(pushC(p),pushC(q)) - - private def dnf(expr : Expr) : Expr = expr match { - case And(es) => - val (ors, lits) = es.partition(_.isInstanceOf[Or]) - if(!ors.isEmpty) { - val orHead = ors.head.asInstanceOf[Or] - val orTail = ors.tail - Or(orHead.exprs.map(oe => dnf(And(filterObvious(lits ++ (oe +: orTail)))))) - } else { - expr - } - - case Or(es) => - Or(es.map(dnf(_))) - - case _ => expr - } - - private def filterObvious(exprs : Seq[Expr]) : Seq[Expr] = { - var pos : List[Identifier] = Nil - var neg : List[Identifier] = Nil - - for(e <- exprs) e match { - case Variable(id) => pos = id :: pos - case Not(Variable(id)) => neg = id :: neg - case _ => ; - } - - val both : Set[Identifier] = pos.toSet intersect neg.toSet - if(!both.isEmpty) { - Seq(BooleanLiteral(false)) - } else { - exprs - } - } -} diff --git a/src/main/scala/leon/solvers/combinators/RewritingSolver.scala b/src/main/scala/leon/solvers/combinators/RewritingSolver.scala new file mode 100644 index 0000000000000000000000000000000000000000..44e442b26739e7484629874c4ffc656152b39e06 --- /dev/null +++ b/src/main/scala/leon/solvers/combinators/RewritingSolver.scala @@ -0,0 +1,36 @@ +/* Copyright 2009-2013 EPFL, Lausanne */ + +package leon +package solvers +package combinators + +import purescala.Common._ +import purescala.Definitions._ +import purescala.Trees._ +import purescala.TreeOps._ +import purescala.TypeTrees._ + +abstract class RewritingSolver[+S <: Solver, T](underlying: S) { + val context = underlying.context + + /** The type T is used to encode any meta information useful, for instance, to reconstruct + * models. */ + def rewriteCnstr(expression : Expr) : (Expr,T) + + def reconstructModel(model : Map[Identifier,Expr], meta : T) : Map[Identifier,Expr] + + private var storedMeta : List[T] = Nil + + def assertCnstr(expression : Expr) { + val (rewritten, meta) = rewriteCnstr(expression) + storedMeta = meta :: storedMeta + underlying.assertCnstr(rewritten) + } + + def getModel : Map[Identifier,Expr] = { + storedMeta match { + case Nil => underlying.getModel + case m :: _ => reconstructModel(underlying.getModel, m) + } + } +} diff --git a/src/main/scala/leon/solvers/combinators/RewritingSolverFactory.scala b/src/main/scala/leon/solvers/combinators/RewritingSolverFactory.scala deleted file mode 100644 index 2a8fbb37a2301dcb9bef9518ecb7d4b5c87f961b..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/solvers/combinators/RewritingSolverFactory.scala +++ /dev/null @@ -1,81 +0,0 @@ -/* Copyright 2009-2013 EPFL, Lausanne */ - -package leon -package solvers -package combinators - -import purescala.Common._ -import purescala.Definitions._ -import purescala.Trees._ -import purescala.TypeTrees._ - -/** This is for solvers that operate by rewriting formulas into equisatisfiable ones. - * They are essentially defined by two methods, one for preprocessing of the expressions, - * and one for reconstructing the models. */ -abstract class RewritingSolverFactory[S <: Solver,T](val sf : SolverFactory[S]) extends SolverFactory[Solver] { - val context = sf.context - val program = sf.program - - override def free() { - sf.free() - } - - override def recoverInterrupt() { - sf.recoverInterrupt() - } - - /** The type T is used to encode any meta information useful, for instance, to reconstruct - * models. */ - def rewriteCnstr(expression : Expr) : (Expr,T) - - def reconstructModel(model : Map[Identifier,Expr], meta : T) : Map[Identifier,Expr] - - def getNewSolver() : Solver = { - new Solver { - val underlying : Solver = sf.getNewSolver() - - private def fail(because : String) : Nothing = { - throw new Exception("Not supported in RewritingSolvers : " + because) - } - - def push() : Unit = fail("push()") - def pop(lvl : Int = 1) : Unit = fail("pop(lvl)") - - private var storedMeta : List[T] = Nil - - def assertCnstr(expression : Expr) { - context.reporter.info("Asked to solve this in BAPA<:\n" + expression) - val (rewritten, meta) = rewriteCnstr(expression) - storedMeta = meta :: storedMeta - underlying.assertCnstr(rewritten) - } - - def interrupt() { - underlying.interrupt() - } - - def recoverInterrupt() { - underlying.recoverInterrupt() - } - - def check : Option[Boolean] = { - underlying.check - } - - def checkAssumptions(assumptions : Set[Expr]) : Option[Boolean] = { - fail("checkAssumptions(assumptions)") - } - - def getModel : Map[Identifier,Expr] = { - storedMeta match { - case Nil => fail("reconstructing model without meta-information.") - case m :: _ => reconstructModel(underlying.getModel, m) - } - } - - def getUnsatCore : Set[Expr] = { - fail("getUnsatCore") - } - } - } -} diff --git a/src/main/scala/leon/solvers/combinators/TimeoutSolverFactory.scala b/src/main/scala/leon/solvers/combinators/TimeoutSolverFactory.scala deleted file mode 100644 index dc670bfbe578c467786acd6f9675a708af10608b..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/solvers/combinators/TimeoutSolverFactory.scala +++ /dev/null @@ -1,112 +0,0 @@ -/* Copyright 2009-2013 EPFL, Lausanne */ - -package leon -package solvers -package combinators - -import purescala.Common._ -import purescala.Definitions._ -import purescala.Trees._ -import purescala.TypeTrees._ - -class TimeoutSolverFactory[S <: Solver](val sf: SolverFactory[S], val timeoutMs: Long) extends SolverFactory[Solver] { - val description = sf.description + ", with "+timeoutMs+"ms timeout" - val name = sf.name + "+to" - - val context = sf.context - val program = sf.program - - private class Countdown(onTimeout: => Unit) extends Thread { - private var keepRunning = true - private val asMillis : Long = timeoutMs - - override def run : Unit = { - val startTime : Long = System.currentTimeMillis - var exceeded : Boolean = false - - while(!exceeded && keepRunning) { - if(asMillis < (System.currentTimeMillis - startTime)) { - exceeded = true - } - Thread.sleep(10) - } - if(exceeded && keepRunning) { - onTimeout - } - } - - def finishedRunning : Unit = { - keepRunning = false - } - } - - override def free() { - sf.free() - } - - def withTimeout[T](solver: S)(body: => T): T = { - val timer = new Countdown(timeout(solver)) - timer.start - val res = body - timer.finishedRunning - recoverFromTimeout(solver) - res - } - - var reachedTimeout = false - def timeout(solver: S) { - solver.interrupt() - reachedTimeout = true - } - - def recoverFromTimeout(solver: S) { - if (reachedTimeout) { - solver.recoverInterrupt() - reachedTimeout = false - } - } - - def getNewSolver = new Solver { - val solver = sf.getNewSolver - - def push(): Unit = { - solver.push() - } - - def pop(lvl: Int = 1): Unit = { - solver.pop(lvl) - } - - def assertCnstr(expression: Expr): Unit = { - solver.assertCnstr(expression) - } - - def interrupt() { - solver.interrupt() - } - - def recoverInterrupt() { - solver.recoverInterrupt() - } - - def check: Option[Boolean] = { - withTimeout(solver){ - solver.check - } - } - - def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = { - withTimeout(solver){ - solver.checkAssumptions(assumptions) - } - } - - def getModel: Map[Identifier, Expr] = { - solver.getModel - } - - def getUnsatCore: Set[Expr] = { - solver.getUnsatCore - } - } -} diff --git a/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala b/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala new file mode 100644 index 0000000000000000000000000000000000000000..ad4eb78bcef56fb46aa3e8f4ec2f2a225a1f2785 --- /dev/null +++ b/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala @@ -0,0 +1,155 @@ +/* Copyright 2009-2013 EPFL, Lausanne */ + +package leon +package solvers +package combinators + +import purescala.Common._ +import purescala.Definitions._ +import purescala.Trees._ +import purescala.TreeOps._ +import purescala.TypeTrees._ + +import scala.collection.mutable.{Map=>MutableMap} + +class UnrollingSolver(val context: LeonContext, + underlyings: SolverFactory[Solver], + maxUnrollings: Int = 3) extends Solver { + + private var theConstraint : Option[Expr] = None + private var theModel : Option[Map[Identifier,Expr]] = None + + def name = "Unr("+underlyings.name+")" + + def free {} + + import context.reporter._ + + def assertCnstr(expression : Expr) { + if(!theConstraint.isEmpty) { + fatalError("Multiple assertCnstr(...).") + } + theConstraint = Some(expression) + } + + def check : Option[Boolean] = theConstraint.map { expr => + val simpleSolver = SimpleSolverAPI(underlyings) + + debugS("Check called on " + expr + "...") + + val template = getTemplate(expr) + + val aVar : Identifier = template.activatingBool + var allClauses : Seq[Expr] = Nil + var allBlockers : Map[Identifier,Set[FunctionInvocation]] = Map.empty + + def fullOpenExpr : Expr = { + // And(Variable(aVar), And(allClauses.reverse)) + // Let's help the poor underlying guy a little bit... + // Note that I keep aVar around, because it may negate one of the blockers, and we can't miss that! + And(Variable(aVar), replace(Map(Variable(aVar) -> BooleanLiteral(true)), And(allClauses.reverse))) + } + + def fullClosedExpr : Expr = { + val blockedVars : Seq[Expr] = allBlockers.toSeq.map(p => Variable(p._1)) + + And( + replace(blockedVars.map(v => (v -> BooleanLiteral(false))).toMap, fullOpenExpr), + And(blockedVars.map(Not(_))) + ) + } + + def unrollOneStep() { + val blockersBefore = allBlockers + + var newClauses : List[Seq[Expr]] = Nil + var newBlockers : Map[Identifier,Set[FunctionInvocation]] = Map.empty + + for(blocker <- allBlockers.keySet; FunctionInvocation(funDef, args) <- allBlockers(blocker)) { + val (nc, nb) = getTemplate(funDef).instantiate(blocker, args) + newClauses = nc :: newClauses + newBlockers = newBlockers ++ nb + } + + allClauses = newClauses.flatten ++ allClauses + allBlockers = newBlockers + } + + val (nc, nb) = template.instantiate(aVar, template.funDef.args.map(a => Variable(a.id))) + + allClauses = nc.reverse + allBlockers = nb + + var unrollingCount : Int = 0 + var done : Boolean = false + var result : Option[Boolean] = None + + // We're now past the initial step. + while(!done && unrollingCount < maxUnrollings) { + debugS("At lvl : " + unrollingCount) + val closed : Expr = fullClosedExpr + + debugS("Going for SAT with this:\n" + closed) + + simpleSolver.solveSAT(closed) match { + case (Some(false), _) => + val open = fullOpenExpr + debugS("Was UNSAT... Going for UNSAT with this:\n" + open) + simpleSolver.solveSAT(open) match { + case (Some(false), _) => + debugS("Was UNSAT... Done !") + done = true + result = Some(false) + + case _ => + debugS("Was SAT or UNKNOWN. Let's unroll !") + unrollingCount += 1 + unrollOneStep() + } + + case (Some(true), model) => + debugS("WAS SAT ! We're DONE !") + done = true + result = Some(true) + theModel = Some(model) + + case (None, model) => + debugS("WAS UNKNOWN ! We're DONE !") + done = true + result = Some(true) + theModel = Some(model) + } + } + result + + } getOrElse { + Some(true) + } + + def getModel : Map[Identifier,Expr] = { + val vs : Set[Identifier] = theConstraint.map(variablesOf(_)).getOrElse(Set.empty) + theModel.getOrElse(Map.empty).filter(p => vs(p._1)) + } + + private val funDefTemplateCache : MutableMap[FunDef, FunctionTemplate] = MutableMap.empty + private val exprTemplateCache : MutableMap[Expr, FunctionTemplate] = MutableMap.empty + + private def getTemplate(funDef: FunDef): FunctionTemplate = { + funDefTemplateCache.getOrElse(funDef, { + val res = FunctionTemplate.mkTemplate(funDef, true) + funDefTemplateCache += funDef -> res + res + }) + } + + private def getTemplate(body: Expr): FunctionTemplate = { + exprTemplateCache.getOrElse(body, { + val fakeFunDef = new FunDef(FreshIdentifier("fake", true), body.getType, variablesOf(body).toSeq.map(id => VarDecl(id, id.getType))) + fakeFunDef.body = Some(body) + + val res = FunctionTemplate.mkTemplate(fakeFunDef, false) + exprTemplateCache += body -> res + res + }) + } +} diff --git a/src/main/scala/leon/solvers/combinators/UnrollingSolverFactory.scala b/src/main/scala/leon/solvers/combinators/UnrollingSolverFactory.scala deleted file mode 100644 index b2f8862f6a15d0cea2237fe9276c99730da2f7f5..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/solvers/combinators/UnrollingSolverFactory.scala +++ /dev/null @@ -1,190 +0,0 @@ -/* Copyright 2009-2013 EPFL, Lausanne */ - -package leon -package solvers -package combinators - -import purescala.Common._ -import purescala.Definitions._ -import purescala.Trees._ -import purescala.TypeTrees._ -import purescala.TreeOps._ - -import scala.collection.mutable.{Map=>MutableMap} - -class UnrollingSolverFactory[S <: Solver](val sf : SolverFactory[S]) extends SolverFactory[Solver] { - val description = "Unrolling loop around a base solver." - val name = sf.name + "*" - - val context = sf.context - val program = sf.program - - // Yes, a hardcoded constant. Sue me. - val MAXUNROLLINGS : Int = 3 - - private val thisFactory = this - - override def free() { - sf.free() - } - - override def recoverInterrupt() { - sf.recoverInterrupt() - } - - def getNewSolver() : Solver = { - new Solver { - private var theConstraint : Option[Expr] = None - private var theModel : Option[Map[Identifier,Expr]] = None - - private def fail(because : String) : Nothing = { - throw new Exception("Not supported in UnrollingSolvers : " + because) - } - - def push() : Unit = fail("push()") - def pop(lvl : Int = 1) : Unit = fail("pop(lvl)") - - def assertCnstr(expression : Expr) { - if(!theConstraint.isEmpty) { - fail("Multiple assertCnstr(...).") - } - theConstraint = Some(expression) - } - - def interrupt() { fail("interrupt()") } - - def recoverInterrupt() { fail("recoverInterrupt()") } - - def check : Option[Boolean] = theConstraint.map { expr => - import context.reporter - - val simpleSolver = SimpleSolverAPI(sf) - - def info(msg : String) { reporter.info("In " + thisFactory.name + ": " + msg) } - - info("Check called on " + expr + "...") - - val template = getTemplate(expr) - - val aVar : Identifier = template.activatingBool - var allClauses : Seq[Expr] = Nil - var allBlockers : Map[Identifier,Set[FunctionInvocation]] = Map.empty - - def fullOpenExpr : Expr = { - // And(Variable(aVar), And(allClauses.reverse)) - // Let's help the poor underlying guy a little bit... - // Note that I keep aVar around, because it may negate one of the blockers, and we can't miss that! - And(Variable(aVar), replace(Map(Variable(aVar) -> BooleanLiteral(true)), And(allClauses.reverse))) - } - - def fullClosedExpr : Expr = { - val blockedVars : Seq[Expr] = allBlockers.toSeq.map(p => Variable(p._1)) - - And( - replace(blockedVars.map(v => (v -> BooleanLiteral(false))).toMap, fullOpenExpr), - And(blockedVars.map(Not(_))) - ) - } - - def unrollOneStep() { - val blockersBefore = allBlockers - - var newClauses : List[Seq[Expr]] = Nil - var newBlockers : Map[Identifier,Set[FunctionInvocation]] = Map.empty - - for(blocker <- allBlockers.keySet; FunctionInvocation(funDef, args) <- allBlockers(blocker)) { - val (nc, nb) = getTemplate(funDef).instantiate(blocker, args) - newClauses = nc :: newClauses - newBlockers = newBlockers ++ nb - } - - allClauses = newClauses.flatten ++ allClauses - allBlockers = newBlockers - } - - val (nc, nb) = template.instantiate(aVar, template.funDef.args.map(a => Variable(a.id))) - - allClauses = nc.reverse - allBlockers = nb - - var unrollingCount : Int = 0 - var done : Boolean = false - var result : Option[Boolean] = None - - // We're now past the initial step. - while(!done && unrollingCount < MAXUNROLLINGS) { - info("At lvl : " + unrollingCount) - val closed : Expr = fullClosedExpr - - info("Going for SAT with this:\n" + closed) - - simpleSolver.solveSAT(closed) match { - case (Some(false), _) => - val open = fullOpenExpr - info("Was UNSAT... Going for UNSAT with this:\n" + open) - simpleSolver.solveSAT(open) match { - case (Some(false), _) => - info("Was UNSAT... Done !") - done = true - result = Some(false) - - case _ => - info("Was SAT or UNKNOWN. Let's unroll !") - unrollingCount += 1 - unrollOneStep() - } - - case (Some(true), model) => - info("WAS SAT ! We're DONE !") - done = true - result = Some(true) - theModel = Some(model) - - case (None, model) => - info("WAS UNKNOWN ! We're DONE !") - done = true - result = Some(true) - theModel = Some(model) - } - } - result - - } getOrElse { - Some(true) - } - - def checkAssumptions(assumptions : Set[Expr]) : Option[Boolean] = { - fail("checkAssumptions(assumptions)") - } - - def getModel : Map[Identifier,Expr] = { - val vs : Set[Identifier] = theConstraint.map(variablesOf(_)).getOrElse(Set.empty) - theModel.getOrElse(Map.empty).filter(p => vs(p._1)) - } - - def getUnsatCore : Set[Expr] = { fail("getUnsatCore") } - } - } - - private val funDefTemplateCache : MutableMap[FunDef, FunctionTemplate] = MutableMap.empty - private val exprTemplateCache : MutableMap[Expr, FunctionTemplate] = MutableMap.empty - - private def getTemplate(funDef: FunDef): FunctionTemplate = { - funDefTemplateCache.getOrElse(funDef, { - val res = FunctionTemplate.mkTemplate(funDef, true) - funDefTemplateCache += funDef -> res - res - }) - } - - private def getTemplate(body: Expr): FunctionTemplate = { - exprTemplateCache.getOrElse(body, { - val fakeFunDef = new FunDef(FreshIdentifier("fake", true), body.getType, variablesOf(body).toSeq.map(id => VarDecl(id, id.getType))) - fakeFunDef.body = Some(body) - - val res = FunctionTemplate.mkTemplate(fakeFunDef, false) - exprTemplateCache += body -> res - res - }) - } -} diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala index 3bc56101590feed24655d46d55296605d7a8b3c2..62ee881093fcf5965ebed21bccf000334aeb0510 100644 --- a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala @@ -19,7 +19,12 @@ import scala.collection.mutable.{Set => MutableSet} // This is just to factor out the things that are common in "classes that deal // with a Z3 instance" -trait AbstractZ3Solver extends SolverFactory[Solver] { +trait AbstractZ3Solver + extends Solver + with TimeoutAssumptionSolver + with AssumptionSolver + with IncrementalSolver { + val context : LeonContext val program : Program @@ -45,20 +50,25 @@ trait AbstractZ3Solver extends SolverFactory[Solver] { override def free() { freed = true - super.free() if (z3 ne null) { z3.delete() z3 = null; } } + protected[z3] var interrupted = false; + override def interrupt() { - super.interrupt() + interrupted = true if(z3 ne null) { z3.interrupt } } + override def recoverInterrupt() { + interrupted = false + } + protected[leon] def prepareFunctions : Unit protected[leon] def functionDefToDecl(funDef: FunDef) : Z3FuncDecl protected[leon] def functionDeclToDef(decl: Z3FuncDecl) : FunDef @@ -455,13 +465,7 @@ trait AbstractZ3Solver extends SolverFactory[Solver] { case IntLiteral(v) => z3.mkInt(v, intSort) case BooleanLiteral(v) => if (v) z3.mkTrue() else z3.mkFalse() case UnitLiteral => unitValue - case Equals(l, r) => { - //if(l.getType != r.getType) - // println("Warning : wrong types in equality for " + l + " == " + r) - z3.mkEq(rec( l ), rec( r ) ) - } - - //case Equals(l, r) => z3.mkEq(rec(l), rec(r)) + case Equals(l, r) => z3.mkEq(rec( l ), rec( r ) ) case Plus(l, r) => z3.mkAdd(rec(l), rec(r)) case Minus(l, r) => z3.mkSub(rec(l), rec(r)) case Times(l, r) => z3.mkMul(rec(l), rec(r)) diff --git a/src/main/scala/leon/solvers/z3/FairZ3SolverFactory.scala b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala similarity index 58% rename from src/main/scala/leon/solvers/z3/FairZ3SolverFactory.scala rename to src/main/scala/leon/solvers/z3/FairZ3Solver.scala index 7f50b3e1ced575831178936d18a94b9dfac58faf..63abcb42aee7a04212fcb97e5dea7175038d715e 100644 --- a/src/main/scala/leon/solvers/z3/FairZ3SolverFactory.scala +++ b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala @@ -7,7 +7,7 @@ import leon.utils._ import z3.scala._ -import leon.solvers.Solver +import leon.solvers.{Solver, IncrementalSolver} import purescala.Common._ import purescala.Definitions._ @@ -23,7 +23,7 @@ import termination._ import scala.collection.mutable.{Map => MutableMap} import scala.collection.mutable.{Set => MutableSet} -class FairZ3SolverFactory(val context : LeonContext, val program: Program) +class FairZ3Solver(val context : LeonContext, val program: Program) extends AbstractZ3Solver with Z3ModelReconstruction with FairZ3Component { @@ -327,263 +327,240 @@ class FairZ3SolverFactory(val context : LeonContext, val program: Program) } } - def getNewSolver = new Solver { - private val evaluator = enclosing.evaluator - private val feelingLucky = enclosing.feelingLucky - private val checkModels = enclosing.checkModels - private val useCodeGen = enclosing.useCodeGen + initZ3 - initZ3 + val solver = z3.mkSolver - val solver = z3.mkSolver - - for(funDef <- program.definedFunctions) { - if (funDef.annotations.contains("axiomatize") && !axiomatizedFunctions(funDef)) { - reporter.warning("Function " + funDef.id + " was marked for axiomatization but could not be handled.") - } + for(funDef <- program.definedFunctions) { + if (funDef.annotations.contains("axiomatize") && !axiomatizedFunctions(funDef)) { + reporter.warning("Function " + funDef.id + " was marked for axiomatization but could not be handled.") } + } - private var varsInVC = Set[Identifier]() - - private var frameExpressions = List[List[Expr]](Nil) - - val unrollingBank = new UnrollingBank() + private var varsInVC = Set[Identifier]() - def push() { - solver.push() - unrollingBank.push() - frameExpressions = Nil :: frameExpressions - } + private var frameExpressions = List[List[Expr]](Nil) - override def recoverInterrupt() { - enclosing.recoverInterrupt() - } + val unrollingBank = new UnrollingBank() - override def interrupt() { - enclosing.interrupt() - } + def push() { + solver.push() + unrollingBank.push() + frameExpressions = Nil :: frameExpressions + } - def pop(lvl: Int = 1) { - solver.pop(lvl) - unrollingBank.pop(lvl) - frameExpressions = frameExpressions.drop(lvl) - } + def pop(lvl: Int = 1) { + solver.pop(lvl) + unrollingBank.pop(lvl) + frameExpressions = frameExpressions.drop(lvl) + } - def check: Option[Boolean] = { - fairCheck(Set()) - } + def innerCheck: Option[Boolean] = { + fairCheck(Set()) + } - def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = { - fairCheck(assumptions) - } + def innerCheckAssumptions(assumptions: Set[Expr]): Option[Boolean] = { + fairCheck(assumptions) + } - var foundDefinitiveAnswer = false - var definitiveAnswer : Option[Boolean] = None - var definitiveModel : Map[Identifier,Expr] = Map.empty - var definitiveCore : Set[Expr] = Set.empty + var foundDefinitiveAnswer = false + var definitiveAnswer : Option[Boolean] = None + var definitiveModel : Map[Identifier,Expr] = Map.empty + var definitiveCore : Set[Expr] = Set.empty - def assertCnstr(expression: Expr) { - varsInVC ++= variablesOf(expression) + def assertCnstr(expression: Expr) { + varsInVC ++= variablesOf(expression) - frameExpressions = (expression :: frameExpressions.head) :: frameExpressions.tail + frameExpressions = (expression :: frameExpressions.head) :: frameExpressions.tail - val newClauses = unrollingBank.scanForNewTemplates(expression) + val newClauses = unrollingBank.scanForNewTemplates(expression) - for (cl <- newClauses) { - solver.assertCnstr(cl) - } + for (cl <- newClauses) { + solver.assertCnstr(cl) } + } - def getModel = { - definitiveModel - } + def getModel = { + definitiveModel + } - def getUnsatCore = { - definitiveCore - } + def getUnsatCore = { + definitiveCore + } - def fairCheck(assumptions: Set[Expr]): Option[Boolean] = { - foundDefinitiveAnswer = false + def fairCheck(assumptions: Set[Expr]): Option[Boolean] = { + foundDefinitiveAnswer = false - def entireFormula = And(assumptions.toSeq ++ frameExpressions.flatten) + def entireFormula = And(assumptions.toSeq ++ frameExpressions.flatten) - def foundAnswer(answer : Option[Boolean], model : Map[Identifier,Expr] = Map.empty, core: Set[Expr] = Set.empty) : Unit = { - foundDefinitiveAnswer = true - definitiveAnswer = answer - definitiveModel = model - definitiveCore = core - } + def foundAnswer(answer : Option[Boolean], model : Map[Identifier,Expr] = Map.empty, core: Set[Expr] = Set.empty) : Unit = { + foundDefinitiveAnswer = true + definitiveAnswer = answer + definitiveModel = model + definitiveCore = core + } - // these are the optional sequence of assumption literals - val assumptionsAsZ3: Seq[Z3AST] = assumptions.flatMap(toZ3Formula(_)).toSeq - val assumptionsAsZ3Set: Set[Z3AST] = assumptionsAsZ3.toSet + // these are the optional sequence of assumption literals + val assumptionsAsZ3: Seq[Z3AST] = assumptions.flatMap(toZ3Formula(_)).toSeq + val assumptionsAsZ3Set: Set[Z3AST] = assumptionsAsZ3.toSet - def z3CoreToCore(core: Seq[Z3AST]): Set[Expr] = { - core.filter(assumptionsAsZ3Set).map(ast => fromZ3Formula(null, ast, None) match { - case n @ Not(Variable(_)) => n - case v @ Variable(_) => v - case x => scala.sys.error("Impossible element extracted from core: " + ast + " (as Leon tree : " + x + ")") - }).toSet - } + def z3CoreToCore(core: Seq[Z3AST]): Set[Expr] = { + core.filter(assumptionsAsZ3Set).map(ast => fromZ3Formula(null, ast, None) match { + case n @ Not(Variable(_)) => n + case v @ Variable(_) => v + case x => scala.sys.error("Impossible element extracted from core: " + ast + " (as Leon tree : " + x + ")") + }).toSet + } - while(!foundDefinitiveAnswer && !interrupted) { + while(!foundDefinitiveAnswer && !interrupted) { - //val blockingSetAsZ3 : Seq[Z3AST] = blockingSet.toSeq.map(toZ3Formula(_).get) - // println("Blocking set : " + blockingSet) + //val blockingSetAsZ3 : Seq[Z3AST] = blockingSet.toSeq.map(toZ3Formula(_).get) + // println("Blocking set : " + blockingSet) - reporter.debug(" - Running Z3 search...") + reporter.debug(" - Running Z3 search...") - // reporter.debug("Searching in:\n"+solver.getAssertions.toSeq.mkString("\nAND\n")) - // reporter.debug("Unroll. Assumptions:\n"+unrollingBank.z3CurrentZ3Blockers.mkString(" && ")) - // reporter.debug("Userland Assumptions:\n"+assumptionsAsZ3.mkString(" && ")) + // reporter.debug("Searching in:\n"+solver.getAssertions.toSeq.mkString("\nAND\n")) + // reporter.debug("Unroll. Assumptions:\n"+unrollingBank.z3CurrentZ3Blockers.mkString(" && ")) + // reporter.debug("Userland Assumptions:\n"+assumptionsAsZ3.mkString(" && ")) - solver.push() // FIXME: remove when z3 bug is fixed - val res = solver.checkAssumptions((assumptionsAsZ3 ++ unrollingBank.z3CurrentZ3Blockers) :_*) - solver.pop() // FIXME: remove when z3 bug is fixed + solver.push() // FIXME: remove when z3 bug is fixed + val res = solver.checkAssumptions((assumptionsAsZ3 ++ unrollingBank.z3CurrentZ3Blockers) :_*) + solver.pop() // FIXME: remove when z3 bug is fixed - reporter.debug(" - Finished search with blocked literals") + reporter.debug(" - Finished search with blocked literals") - res match { - case None => - // reporter.warning("Z3 doesn't know because: " + z3.getSearchFailure.message) - reporter.warning("Z3 doesn't know because ??") - foundAnswer(None) + res match { + case None => + // reporter.warning("Z3 doesn't know because: " + z3.getSearchFailure.message) + reporter.warning("Z3 doesn't know because ??") + foundAnswer(None) - case Some(true) => // SAT + case Some(true) => // SAT - val z3model = solver.getModel + val z3model = solver.getModel - if (this.checkModels) { - val (isValid, model) = validateModel(z3model, entireFormula, varsInVC, silenceErrors = false) + if (this.checkModels) { + val (isValid, model) = validateModel(z3model, entireFormula, varsInVC, silenceErrors = false) - if (isValid) { - foundAnswer(Some(true), model) - } else { - reporter.error("Something went wrong. The model should have been valid, yet we got this : ") - reporter.error(model) - foundAnswer(None, model) - } + if (isValid) { + foundAnswer(Some(true), model) } else { - val model = modelToMap(z3model, varsInVC) + reporter.error("Something went wrong. The model should have been valid, yet we got this : ") + reporter.error(model) + foundAnswer(None, model) + } + } else { + val model = modelToMap(z3model, varsInVC) - //lazy val modelAsString = model.toList.map(p => p._1 + " -> " + p._2).mkString("\n") - //reporter.debug("- Found a model:") - //reporter.debug(modelAsString) + //lazy val modelAsString = model.toList.map(p => p._1 + " -> " + p._2).mkString("\n") + //reporter.debug("- Found a model:") + //reporter.debug(modelAsString) - foundAnswer(Some(true), model) - } + foundAnswer(Some(true), model) + } - case Some(false) if !unrollingBank.canUnroll => + case Some(false) if !unrollingBank.canUnroll => - val core = z3CoreToCore(solver.getUnsatCore) + val core = z3CoreToCore(solver.getUnsatCore) - foundAnswer(Some(false), core = core) + foundAnswer(Some(false), core = core) - // This branch is both for with and without unsat cores. The - // distinction is made inside. - case Some(false) => + // This branch is both for with and without unsat cores. The + // distinction is made inside. + case Some(false) => - val z3Core = solver.getUnsatCore + val z3Core = solver.getUnsatCore - def coreElemToBlocker(c: Z3AST): (Z3AST, Boolean) = { - z3.getASTKind(c) match { - case Z3AppAST(decl, args) => - z3.getDeclKind(decl) match { - case Z3DeclKind.OpNot => - (args(0), true) - case Z3DeclKind.OpUninterpreted => - (c, false) - } + def coreElemToBlocker(c: Z3AST): (Z3AST, Boolean) = { + z3.getASTKind(c) match { + case Z3AppAST(decl, args) => + z3.getDeclKind(decl) match { + case Z3DeclKind.OpNot => + (args(0), true) + case Z3DeclKind.OpUninterpreted => + (c, false) + } - case ast => - (c, false) - } + case ast => + (c, false) } + } - if (unrollUnsatCores) { - unrollingBank.decreaseAllGenerations() - - for (c <- solver.getUnsatCore) { - val (z3ast, pol) = coreElemToBlocker(c) - assert(pol == true) + if (unrollUnsatCores) { + unrollingBank.decreaseAllGenerations() - unrollingBank.promoteBlocker(z3ast) - } + for (c <- solver.getUnsatCore) { + val (z3ast, pol) = coreElemToBlocker(c) + assert(pol == true) + unrollingBank.promoteBlocker(z3ast) } - //debug("UNSAT BECAUSE: "+solver.getUnsatCore.mkString("\n AND \n")) - //debug("UNSAT BECAUSE: "+core.mkString(" AND ")) + } - if (!interrupted) { - if (this.feelingLucky) { - // we need the model to perform the additional test - reporter.debug(" - Running search without blocked literals (w/ lucky test)") - } else { - reporter.debug(" - Running search without blocked literals (w/o lucky test)") - } + //debug("UNSAT BECAUSE: "+solver.getUnsatCore.mkString("\n AND \n")) + //debug("UNSAT BECAUSE: "+core.mkString(" AND ")) + + if (!interrupted) { + if (this.feelingLucky) { + // we need the model to perform the additional test + reporter.debug(" - Running search without blocked literals (w/ lucky test)") + } else { + reporter.debug(" - Running search without blocked literals (w/o lucky test)") + } - solver.push() // FIXME: remove when z3 bug is fixed - val res2 = solver.checkAssumptions(assumptionsAsZ3 : _*) - solver.pop() // FIXME: remove when z3 bug is fixed - - res2 match { - case Some(false) => - //reporter.debug("UNSAT WITHOUT Blockers") - foundAnswer(Some(false), core = z3CoreToCore(solver.getUnsatCore)) - case Some(true) => - //reporter.debug("SAT WITHOUT Blockers") - if (this.feelingLucky && !interrupted) { - // we might have been lucky :D - val (wereWeLucky, cleanModel) = validateModel(solver.getModel, entireFormula, varsInVC, silenceErrors = true) - - if(wereWeLucky) { - foundAnswer(Some(true), cleanModel) - } + solver.push() // FIXME: remove when z3 bug is fixed + val res2 = solver.checkAssumptions(assumptionsAsZ3 : _*) + solver.pop() // FIXME: remove when z3 bug is fixed + + res2 match { + case Some(false) => + //reporter.debug("UNSAT WITHOUT Blockers") + foundAnswer(Some(false), core = z3CoreToCore(solver.getUnsatCore)) + case Some(true) => + //reporter.debug("SAT WITHOUT Blockers") + if (this.feelingLucky && !interrupted) { + // we might have been lucky :D + val (wereWeLucky, cleanModel) = validateModel(solver.getModel, entireFormula, varsInVC, silenceErrors = true) + + if(wereWeLucky) { + foundAnswer(Some(true), cleanModel) } + } - case None => - foundAnswer(None) - } + case None => + foundAnswer(None) } + } - if(interrupted) { - foundAnswer(None) - } + if(interrupted) { + foundAnswer(None) + } - if(!foundDefinitiveAnswer) { - reporter.debug("- We need to keep going.") + if(!foundDefinitiveAnswer) { + reporter.debug("- We need to keep going.") - val toRelease = unrollingBank.getZ3BlockersToUnlock + val toRelease = unrollingBank.getZ3BlockersToUnlock - reporter.debug(" - more unrollings") + reporter.debug(" - more unrollings") - for(id <- toRelease) { - val newClauses = unrollingBank.unlock(id) + for(id <- toRelease) { + val newClauses = unrollingBank.unlock(id) - for(ncl <- newClauses) { - solver.assertCnstr(ncl) - } + for(ncl <- newClauses) { + solver.assertCnstr(ncl) } - - reporter.debug(" - finished unrolling") } - } - } - //reporter.debug(" !! DONE !! ") - - if(interrupted) { - None - } else { - definitiveAnswer + reporter.debug(" - finished unrolling") + } } } - if (program == null) { - reporter.error("Z3 Solver was not initialized with a PureScala Program.") + if(interrupted) { None + } else { + definitiveAnswer } } - } diff --git a/src/main/scala/leon/solvers/z3/FunctionTemplate.scala b/src/main/scala/leon/solvers/z3/FunctionTemplate.scala index 4e97ddadadaa2ff4dc8dd856b3135457ca700710..5a3f2e2f4700dd0f118184377c63fdfa0c3313f0 100644 --- a/src/main/scala/leon/solvers/z3/FunctionTemplate.scala +++ b/src/main/scala/leon/solvers/z3/FunctionTemplate.scala @@ -19,7 +19,7 @@ import scala.collection.mutable.{Set=>MutableSet,Map=>MutableMap} case class Z3FunctionInvocation(funDef: FunDef, args: Seq[Z3AST]) class FunctionTemplate private( - solver: FairZ3SolverFactory, + solver: FairZ3Solver, val funDef : FunDef, activatingBool : Identifier, condVars : Set[Identifier], @@ -152,7 +152,7 @@ class FunctionTemplate private( object FunctionTemplate { val splitAndOrImplies = false - def mkTemplate(solver: FairZ3SolverFactory, funDef: FunDef, isRealFunDef : Boolean = true) : FunctionTemplate = { + def mkTemplate(solver: FairZ3Solver, funDef: FunDef, isRealFunDef : Boolean = true) : FunctionTemplate = { val condVars : MutableSet[Identifier] = MutableSet.empty val exprVars : MutableSet[Identifier] = MutableSet.empty diff --git a/src/main/scala/leon/solvers/z3/UninterpretedZ3SolverFactory.scala b/src/main/scala/leon/solvers/z3/UninterpretedZ3Solver.scala similarity index 56% rename from src/main/scala/leon/solvers/z3/UninterpretedZ3SolverFactory.scala rename to src/main/scala/leon/solvers/z3/UninterpretedZ3Solver.scala index 6589e3b504cbd6a304cb89a20c22b8bb5390a785..78b75ee5eefdf13907ea1e7d068c9341bc0744fc 100644 --- a/src/main/scala/leon/solvers/z3/UninterpretedZ3SolverFactory.scala +++ b/src/main/scala/leon/solvers/z3/UninterpretedZ3Solver.scala @@ -21,12 +21,10 @@ import purescala.TypeTrees._ * - otherwise it returns UNKNOWN * Results should come back very quickly. */ -class UninterpretedZ3SolverFactory(val context : LeonContext, val program: Program) +class UninterpretedZ3Solver(val context : LeonContext, val program: Program) extends AbstractZ3Solver with Z3ModelReconstruction { - enclosing => - val name = "Z3-u" val description = "Uninterpreted Z3 Solver" @@ -57,65 +55,56 @@ class UninterpretedZ3SolverFactory(val context : LeonContext, val program: Progr protected[leon] def functionDeclToDef(decl: Z3FuncDecl) : FunDef = reverseFunctionMap(decl) protected[leon] def isKnownDecl(decl: Z3FuncDecl) : Boolean = reverseFunctionMap.isDefinedAt(decl) - def getNewSolver = new Solver { - initZ3 - - val solver = z3.mkSolver + initZ3 - def push() { - solver.push - } + val solver = z3.mkSolver - def interrupt() { - enclosing.interrupt() - } + def push() { + solver.push + } - def recoverInterrupt() { - enclosing.recoverInterrupt() - } - def pop(lvl: Int = 1) { - solver.pop(lvl) - } + def pop(lvl: Int = 1) { + solver.pop(lvl) + } - private var variables = Set[Identifier]() - private var containsFunCalls = false + private var variables = Set[Identifier]() + private var containsFunCalls = false - def assertCnstr(expression: Expr) { - variables ++= variablesOf(expression) - containsFunCalls ||= containsFunctionCalls(expression) - solver.assertCnstr(toZ3Formula(expression).get) - } + def assertCnstr(expression: Expr) { + variables ++= variablesOf(expression) + containsFunCalls ||= containsFunctionCalls(expression) + solver.assertCnstr(toZ3Formula(expression).get) + } - def check: Option[Boolean] = { - solver.check match { - case Some(true) => - if (containsFunCalls) { - None - } else { - Some(true) - } - - case r => - r - } + def innerCheck: Option[Boolean] = { + solver.check match { + case Some(true) => + if (containsFunCalls) { + None + } else { + Some(true) + } + + case r => + r } + } - def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = { - variables ++= assumptions.flatMap(variablesOf(_)) - solver.checkAssumptions(assumptions.toSeq.map(toZ3Formula(_).get) : _*) - } + def innerCheckAssumptions(assumptions: Set[Expr]): Option[Boolean] = { + variables ++= assumptions.flatMap(variablesOf(_)) + solver.checkAssumptions(assumptions.toSeq.map(toZ3Formula(_).get) : _*) + } - def getModel = { - modelToMap(solver.getModel, variables) - } + def getModel = { + modelToMap(solver.getModel, variables) + } - def getUnsatCore = { - solver.getUnsatCore.map(ast => fromZ3Formula(null, ast, None) match { - case n @ Not(Variable(_)) => n - case v @ Variable(_) => v - case x => scala.sys.error("Impossible element extracted from core: " + ast + " (as Leon tree : " + x + ")") - }).toSet - } + def getUnsatCore = { + solver.getUnsatCore.map(ast => fromZ3Formula(null, ast, None) match { + case n @ Not(Variable(_)) => n + case v @ Variable(_) => v + case x => scala.sys.error("Impossible element extracted from core: " + ast + " (as Leon tree : " + x + ")") + }).toSet } } diff --git a/src/main/scala/leon/synthesis/Solution.scala b/src/main/scala/leon/synthesis/Solution.scala index 69d83cd556ff49daca856bd573d01b82ca59afdf..782a72049d84cfc0c8c09e8ab477e983434fd1ac 100644 --- a/src/main/scala/leon/synthesis/Solution.scala +++ b/src/main/scala/leon/synthesis/Solution.scala @@ -8,6 +8,7 @@ import purescala.TypeTrees.{TypeTree,TupleType} import purescala.Definitions._ import purescala.TreeOps._ import solvers.z3._ +import solvers._ // Defines a synthesis solution of the form: // ⟨ P | T ⟩ @@ -30,7 +31,7 @@ class Solution(val pre: Expr, val defs: Set[FunDef], val term: Expr) { defs.foldLeft(term){ case (t, fd) => LetDef(fd, t) } def toSimplifiedExpr(ctx: LeonContext, p: Program): Expr = { - val uninterpretedZ3 = new UninterpretedZ3SolverFactory(ctx, p) + val uninterpretedZ3 = SolverFactory(() => new UninterpretedZ3Solver(ctx, p)) val simplifiers = List[Expr => Expr]( simplifyTautologies(uninterpretedZ3)(_), diff --git a/src/main/scala/leon/synthesis/SynthesisContext.scala b/src/main/scala/leon/synthesis/SynthesisContext.scala index 4269f4175cbda5a8d54f721d952cb76308a4faf5..adf89f1cc3d441af80e10755adf744e0d3fc6a7e 100644 --- a/src/main/scala/leon/synthesis/SynthesisContext.scala +++ b/src/main/scala/leon/synthesis/SynthesisContext.scala @@ -20,17 +20,22 @@ case class SynthesisContext( reporter: Reporter ) { - def solverFactory: SolverFactory[Solver] = { - new FairZ3SolverFactory(context, program) + def newSolver: SynthesisContext.SynthesisSolver = { + new FairZ3Solver(context, program) } - def fastSolverFactory: SolverFactory[Solver] = { - new UninterpretedZ3SolverFactory(context, program) + def newFastSolver: SynthesisContext.SynthesisSolver = { + new UninterpretedZ3Solver(context, program) } + val solverFactory = SolverFactory(() => newSolver) + val fastSolverFactory = SolverFactory(() => newFastSolver) + } object SynthesisContext { + type SynthesisSolver = TimeoutAssumptionSolver with IncrementalSolver + def fromSynthesizer(synth: Synthesizer) = { SynthesisContext( synth.context, diff --git a/src/main/scala/leon/synthesis/Synthesizer.scala b/src/main/scala/leon/synthesis/Synthesizer.scala index b138fcc8aac6b1bbcdc56a934aedbba721a6772b..78e5f20a251ab3a4ed98248e7f772ebddf3a9379 100644 --- a/src/main/scala/leon/synthesis/Synthesizer.scala +++ b/src/main/scala/leon/synthesis/Synthesizer.scala @@ -81,10 +81,10 @@ class Synthesizer(val context : LeonContext, val (npr, fds) = solutionToProgram(sol) - val tsolver = new TimeoutSolverFactory(new FairZ3SolverFactory(context, npr), timeoutMs) + val solverf = SolverFactory(() => new FairZ3Solver(context, npr).setTimeout(timeoutMs)) val vcs = generateVerificationConditions(reporter, npr, fds.map(_.id.name)) - val vctx = VerificationContext(context, Seq(tsolver), context.reporter) + val vctx = VerificationContext(context, Seq(solverf), context.reporter) val vcreport = checkVerificationConditions(vctx, vcs) if (vcreport.totalValid == vcreport.totalConditions) { diff --git a/src/main/scala/leon/synthesis/condabd/SynthesizerExamples.scala b/src/main/scala/leon/synthesis/condabd/SynthesizerExamples.scala index beea692e621f44d5f553052786ccaead7e1a2f0f..dc557e56fdea53937638f4a2a9e420223fd2a87c 100755 --- a/src/main/scala/leon/synthesis/condabd/SynthesizerExamples.scala +++ b/src/main/scala/leon/synthesis/condabd/SynthesizerExamples.scala @@ -37,13 +37,14 @@ import verification._ import SynthesisInfo._ import SynthesisInfo.Action._ +import SynthesisContext.SynthesisSolver // enable postfix operations import scala.language.postfixOps class SynthesizerForRuleExamples( // some synthesis instance information - val mainSolver: SolverFactory[Solver], + val mainSolver: SolverFactory[SynthesisSolver], val program: Program, val desiredType: LeonType, val holeFunDef: FunDef, diff --git a/src/main/scala/leon/synthesis/condabd/rules/ConditionAbductionSynthesisTwoPhase.scala b/src/main/scala/leon/synthesis/condabd/rules/ConditionAbductionSynthesisTwoPhase.scala index 49cacd1f311df3aee16e1b5e422f964f9b1624e8..e6a486fa70cc9d8446d0d1500be4820485c23ea8 100755 --- a/src/main/scala/leon/synthesis/condabd/rules/ConditionAbductionSynthesisTwoPhase.scala +++ b/src/main/scala/leon/synthesis/condabd/rules/ConditionAbductionSynthesisTwoPhase.scala @@ -24,7 +24,7 @@ case object ConditionAbductionSynthesisTwoPhase extends Rule("Condition abductio List(new RuleInstantiation(p, this, SolutionBuilder.none, "Condition abduction") { def apply(sctx: SynthesisContext): RuleApplicationResult = { try { - val solver = sctx.solverFactory.withTimeout(500L) + val solver = SolverFactory(() => sctx.newSolver.setTimeout(500L)) val program = sctx.program val reporter = sctx.reporter diff --git a/src/main/scala/leon/synthesis/condabd/verification/AbstractVerifier.scala b/src/main/scala/leon/synthesis/condabd/verification/AbstractVerifier.scala index b9922e25505f024ebdce30d4f0457505ce72e6e0..cf0625a5ab1e57fffb5cb62668e6def3cce4f8cf 100644 --- a/src/main/scala/leon/synthesis/condabd/verification/AbstractVerifier.scala +++ b/src/main/scala/leon/synthesis/condabd/verification/AbstractVerifier.scala @@ -13,7 +13,7 @@ import purescala.Definitions._ import _root_.insynth.util.logging._ -abstract class AbstractVerifier(solverf: SolverFactory[Solver], p: Problem, synthInfo: SynthesisInfo) +abstract class AbstractVerifier(solverf: SolverFactory[Solver with IncrementalSolver with TimeoutSolver], p: Problem, synthInfo: SynthesisInfo) extends HasLogger { val solver = solverf.getNewSolver diff --git a/src/main/scala/leon/synthesis/condabd/verification/RelaxedVerifier.scala b/src/main/scala/leon/synthesis/condabd/verification/RelaxedVerifier.scala index 2dfa8784349995a2f2dfba43ef1993ee683dc75e..2b382091591790e554047ffa815807ee7619fd72 100644 --- a/src/main/scala/leon/synthesis/condabd/verification/RelaxedVerifier.scala +++ b/src/main/scala/leon/synthesis/condabd/verification/RelaxedVerifier.scala @@ -10,10 +10,11 @@ import purescala.Extractors._ import purescala.TreeOps._ import purescala.TypeTrees._ import purescala.Definitions._ +import SynthesisContext.SynthesisSolver import _root_.insynth.util.logging._ -class RelaxedVerifier(solverf: SolverFactory[Solver], p: Problem, synthInfo: SynthesisInfo = new SynthesisInfo) +class RelaxedVerifier(solverf: SolverFactory[SynthesisSolver], p: Problem, synthInfo: SynthesisInfo = new SynthesisInfo) extends AbstractVerifier(solverf, p, synthInfo) with HasLogger { var _isTimeoutUsed = false diff --git a/src/main/scala/leon/synthesis/condabd/verification/Verifier.scala b/src/main/scala/leon/synthesis/condabd/verification/Verifier.scala index c1c74a38a86a1fd6ccd6619422de6606ec07a5a1..54ff321540453aaf32ef03b305740c833c3a9120 100644 --- a/src/main/scala/leon/synthesis/condabd/verification/Verifier.scala +++ b/src/main/scala/leon/synthesis/condabd/verification/Verifier.scala @@ -10,10 +10,11 @@ import purescala.Extractors._ import purescala.TreeOps._ import purescala.TypeTrees._ import purescala.Definitions._ +import SynthesisContext.SynthesisSolver import _root_.insynth.util.logging._ -class Verifier(solverf: SolverFactory[Solver], p: Problem, synthInfo: SynthesisInfo = new SynthesisInfo) +class Verifier(solverf: SolverFactory[SynthesisSolver], p: Problem, synthInfo: SynthesisInfo = new SynthesisInfo) extends AbstractVerifier(solverf, p, synthInfo) with HasLogger { import SynthesisInfo.Action._ diff --git a/src/main/scala/leon/synthesis/heuristics/ADTInduction.scala b/src/main/scala/leon/synthesis/heuristics/ADTInduction.scala index 35cad0762f3da847e253ed41e71bdd6b27937759..724191dd0e45624b7f2a18aa4e4bb5eb4f0b9d90 100644 --- a/src/main/scala/leon/synthesis/heuristics/ADTInduction.scala +++ b/src/main/scala/leon/synthesis/heuristics/ADTInduction.scala @@ -14,14 +14,10 @@ import purescala.Definitions._ case object ADTInduction extends Rule("ADT Induction") with Heuristic { def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { - val tsolver = sctx.solverFactory.withTimeout(500L) - val candidates = p.as.collect { - case IsTyped(origId, AbstractClassType(cd)) if isInductiveOn(tsolver)(p.pc, origId) => (origId, cd) + case IsTyped(origId, AbstractClassType(cd)) if isInductiveOn(sctx.solverFactory)(p.pc, origId) => (origId, cd) } - tsolver.free() - val instances = for (candidate <- candidates) yield { val (origId, cd) = candidate val oas = p.as.filterNot(_ == origId) diff --git a/src/main/scala/leon/synthesis/heuristics/ADTLongInduction.scala b/src/main/scala/leon/synthesis/heuristics/ADTLongInduction.scala index e11f09448fa31606b2cd4030d7eef35e12a073d4..0305450d8d53a53d5e26fb1bfd4f013f1074af79 100644 --- a/src/main/scala/leon/synthesis/heuristics/ADTLongInduction.scala +++ b/src/main/scala/leon/synthesis/heuristics/ADTLongInduction.scala @@ -14,13 +14,10 @@ import purescala.Definitions._ case object ADTLongInduction extends Rule("ADT Long Induction") with Heuristic { def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { - val tsolver = sctx.solverFactory.withTimeout(500L) val candidates = p.as.collect { - case IsTyped(origId, AbstractClassType(cd)) if isInductiveOn(tsolver)(p.pc, origId) => (origId, cd) + case IsTyped(origId, AbstractClassType(cd)) if isInductiveOn(sctx.solverFactory)(p.pc, origId) => (origId, cd) } - tsolver.free() - val instances = for (candidate <- candidates) yield { val (origId, cd) = candidate val oas = p.as.filterNot(_ == origId) diff --git a/src/main/scala/leon/synthesis/rules/ADTSplit.scala b/src/main/scala/leon/synthesis/rules/ADTSplit.scala index 261b371433db98fca18afe3892b44be94c528774..93f584ccda6279c97a9c27586bfd86aa98ebbe45 100644 --- a/src/main/scala/leon/synthesis/rules/ADTSplit.scala +++ b/src/main/scala/leon/synthesis/rules/ADTSplit.scala @@ -14,7 +14,7 @@ import solvers._ case object ADTSplit extends Rule("ADT Split.") { def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation]= { - val solver = SimpleSolverAPI(sctx.solverFactory.withTimeout(200L)) + val solver = SimpleSolverAPI(new TimeoutSolverFactory(sctx.solverFactory, 200L)) val candidates = p.as.collect { case IsTyped(id, AbstractClassType(cd)) => @@ -46,8 +46,6 @@ case object ADTSplit extends Rule("ADT Split.") { } } - solver.free() - candidates.collect{ _ match { case Some((id, cases)) => val oas = p.as.filter(_ != id) diff --git a/src/main/scala/leon/synthesis/rules/Cegis.scala b/src/main/scala/leon/synthesis/rules/Cegis.scala index c1c53fc00f9558c64b2aa46d01c9fd455a7fd0e4..b62c45e9788cb2831b0f3e37110fc8efa193bf6c 100644 --- a/src/main/scala/leon/synthesis/rules/Cegis.scala +++ b/src/main/scala/leon/synthesis/rules/Cegis.scala @@ -441,20 +441,20 @@ case object CEGIS extends Rule("CEGIS") { var unrolings = 0 val maxUnrolings = 3 - val exSolver = sctx.solverFactory.withTimeout(3000L) // 3sec - val cexSolver = sctx.solverFactory.withTimeout(3000L) // 3sec + val exSolverTo = 3000L + val cexSolverTo = 3000L - try { - var baseExampleInputs: Seq[Seq[Expr]] = Seq() + var baseExampleInputs: Seq[Seq[Expr]] = Seq() - // We populate the list of examples with a predefined one - if (p.pc == BooleanLiteral(true)) { - baseExampleInputs = p.as.map(a => simplestValue(a.getType)) +: baseExampleInputs - } else { - val solver = exSolver.getNewSolver + // We populate the list of examples with a predefined one + if (p.pc == BooleanLiteral(true)) { + baseExampleInputs = p.as.map(a => simplestValue(a.getType)) +: baseExampleInputs + } else { + val solver = sctx.newSolver.setTimeout(exSolverTo) - solver.assertCnstr(p.pc) + solver.assertCnstr(p.pc) + try { solver.check match { case Some(true) => val model = solver.getModel @@ -467,36 +467,39 @@ case object CEGIS extends Rule("CEGIS") { sctx.reporter.warning("Solver could not solve path-condition") return RuleApplicationImpossible // This is not necessary though, but probably wanted } - - } - - val inputIterator: Iterator[Seq[Expr]] = if (useVanuatoo) { - new VanuatooDataGen(sctx.context, sctx.program).generateFor(p.as, p.pc, 20, 3000) - } else { - new NaiveDataGen(sctx.context, sctx.program, evaluator).generateFor(p.as, p.pc, 20, 1000) + } finally { + solver.free() } + } - val cachedInputIterator = new Iterator[Seq[Expr]] { - def next() = { - val i = inputIterator.next() - baseExampleInputs = i +: baseExampleInputs - i - } + val inputIterator: Iterator[Seq[Expr]] = if (useVanuatoo) { + new VanuatooDataGen(sctx.context, sctx.program).generateFor(p.as, p.pc, 20, 3000) + } else { + new NaiveDataGen(sctx.context, sctx.program, evaluator).generateFor(p.as, p.pc, 20, 1000) + } - def hasNext() = inputIterator.hasNext + val cachedInputIterator = new Iterator[Seq[Expr]] { + def next() = { + val i = inputIterator.next() + baseExampleInputs = i +: baseExampleInputs + i } - def hasInputExamples() = baseExampleInputs.size > 0 || cachedInputIterator.hasNext + def hasNext() = inputIterator.hasNext + } + + def hasInputExamples() = baseExampleInputs.size > 0 || cachedInputIterator.hasNext - def allInputExamples() = baseExampleInputs.iterator ++ cachedInputIterator + def allInputExamples() = baseExampleInputs.iterator ++ cachedInputIterator - def checkForPrograms(programs: Set[Set[Identifier]]): RuleApplicationResult = { - for (prog <- programs) { - val expr = ndProgram.determinize(prog) - val res = Equals(Tuple(p.xs.map(Variable(_))), expr) - val solver3 = cexSolver.getNewSolver - solver3.assertCnstr(And(p.pc :: res :: Not(p.phi) :: Nil)) + def checkForPrograms(programs: Set[Set[Identifier]]): RuleApplicationResult = { + for (prog <- programs) { + val expr = ndProgram.determinize(prog) + val res = Equals(Tuple(p.xs.map(Variable(_))), expr) + val solver3 = sctx.newSolver.setTimeout(cexSolverTo) + solver3.assertCnstr(And(p.pc :: res :: Not(p.phi) :: Nil)) + try { solver3.check match { case Some(false) => return RuleSuccess(Solution(BooleanLiteral(true), Set(), expr), isTrusted = true) @@ -505,249 +508,254 @@ case object CEGIS extends Rule("CEGIS") { case Some(true) => // invalid program, we skip } + } finally { + solver3.free() } - - RuleApplicationImpossible } - // Keep track of collected cores to filter programs to test - var collectedCores = Set[Set[Identifier]]() - - val initExClause = And(p.pc :: p.phi :: Variable(initGuard) :: Nil) - val initCExClause = And(p.pc :: Not(p.phi) :: Variable(initGuard) :: Nil) + RuleApplicationImpossible + } - // solver1 is used for the initial SAT queries - var solver1 = exSolver.getNewSolver - solver1.assertCnstr(initExClause) + // Keep track of collected cores to filter programs to test + var collectedCores = Set[Set[Identifier]]() - // solver2 is used for validating a candidate program, or finding new inputs - var solver2 = cexSolver.getNewSolver - solver2.assertCnstr(initCExClause) + val initExClause = And(p.pc :: p.phi :: Variable(initGuard) :: Nil) + val initCExClause = And(p.pc :: Not(p.phi) :: Variable(initGuard) :: Nil) - var didFilterAlready = false + // solver1 is used for the initial SAT queries + var solver1 = sctx.newSolver.setTimeout(exSolverTo) + solver1.assertCnstr(initExClause) - val tpe = TupleType(p.xs.map(_.getType)) + // solver2 is used for validating a candidate program, or finding new inputs + var solver2 = sctx.newSolver.setTimeout(cexSolverTo) + solver2.assertCnstr(initCExClause) - try { - do { - var needMoreUnrolling = false + var didFilterAlready = false - var bssAssumptions = Set[Identifier]() + val tpe = TupleType(p.xs.map(_.getType)) - if (!didFilterAlready) { - val (clauses, closedBs) = ndProgram.unroll + try { + do { + var needMoreUnrolling = false - bssAssumptions = closedBs + var bssAssumptions = Set[Identifier]() - sctx.reporter.ifDebug { debug => - debug("UNROLLING: ") - for (c <- clauses) { - debug(" - " + c) - } - debug("CLOSED Bs "+closedBs) - } + if (!didFilterAlready) { + val (clauses, closedBs) = ndProgram.unroll - val clause = And(clauses) + bssAssumptions = closedBs - solver1.assertCnstr(clause) - solver2.assertCnstr(clause) + sctx.reporter.ifDebug { debug => + debug("UNROLLING: ") + for (c <- clauses) { + debug(" - " + c) + } + debug("CLOSED Bs "+closedBs) } - // Compute all programs that have not been excluded yet - var prunedPrograms: Set[Set[Identifier]] = if (useCEPruning) { - ndProgram.allPrograms.filterNot(p => collectedCores.exists(c => c.subsetOf(p))) - } else { - Set() - } + val clause = And(clauses) - val allPrograms = prunedPrograms.size + solver1.assertCnstr(clause) + solver2.assertCnstr(clause) + } - sctx.reporter.debug("#Programs: "+prunedPrograms.size) + // Compute all programs that have not been excluded yet + var prunedPrograms: Set[Set[Identifier]] = if (useCEPruning) { + ndProgram.allPrograms.filterNot(p => collectedCores.exists(c => c.subsetOf(p))) + } else { + Set() + } - // We further filter the set of working programs to remove those that fail on known examples - if (useCEPruning && hasInputExamples() && ndProgram.canTest()) { + val allPrograms = prunedPrograms.size - for (p <- prunedPrograms) { - if (!allInputExamples().forall(ndProgram.testForProgram(p))) { - // This program failed on at least one example - solver1.assertCnstr(Not(And(p.map(Variable(_)).toSeq))) - prunedPrograms -= p - } - } + sctx.reporter.debug("#Programs: "+prunedPrograms.size) - if (prunedPrograms.isEmpty) { - needMoreUnrolling = true - } + // We further filter the set of working programs to remove those that fail on known examples + if (useCEPruning && hasInputExamples() && ndProgram.canTest()) { - //println("Passing tests: "+prunedPrograms.size) + for (p <- prunedPrograms) { + if (!allInputExamples().forall(ndProgram.testForProgram(p))) { + // This program failed on at least one example + solver1.assertCnstr(Not(And(p.map(Variable(_)).toSeq))) + prunedPrograms -= p + } } - val nPassing = prunedPrograms.size - - sctx.reporter.debug("#Programs passing tests: "+nPassing) - - if (nPassing == 0) { - needMoreUnrolling = true; - } else if (nPassing <= testUpTo) { - // Immediate Test - result = Some(checkForPrograms(prunedPrograms)) - } else if (((nPassing < allPrograms*filterThreshold) || didFilterAlready) && useBssFiltering) { - // We filter the Bss so that the formula we give to z3 is much smalled - val bssToKeep = prunedPrograms.foldLeft(Set[Identifier]())(_ ++ _) - - // Cannot unroll normally after having filtered, so we need to - // repeat the filtering procedure at next unrolling. - didFilterAlready = true - - // Freshening solvers - solver1 = exSolver.getNewSolver - solver1.assertCnstr(initExClause) - solver2 = cexSolver.getNewSolver - solver2.assertCnstr(initCExClause) - - val clauses = ndProgram.filterFor(bssToKeep) - val clause = And(clauses) - - solver1.assertCnstr(clause) - solver2.assertCnstr(clause) + if (prunedPrograms.isEmpty) { + needMoreUnrolling = true } - val bss = ndProgram.bss + //println("Passing tests: "+prunedPrograms.size) + } - while (result.isEmpty && !needMoreUnrolling && !interruptManager.isInterrupted()) { + val nPassing = prunedPrograms.size + + sctx.reporter.debug("#Programs passing tests: "+nPassing) + + if (nPassing == 0) { + needMoreUnrolling = true; + } else if (nPassing <= testUpTo) { + // Immediate Test + result = Some(checkForPrograms(prunedPrograms)) + } else if (((nPassing < allPrograms*filterThreshold) || didFilterAlready) && useBssFiltering) { + // We filter the Bss so that the formula we give to z3 is much smalled + val bssToKeep = prunedPrograms.foldLeft(Set[Identifier]())(_ ++ _) + + // Cannot unroll normally after having filtered, so we need to + // repeat the filtering procedure at next unrolling. + didFilterAlready = true + + // Freshening solvers + solver1.free() + solver1 = sctx.newSolver.setTimeout(exSolverTo) + solver1.assertCnstr(initExClause) + + solver2.free() + solver2 = sctx.newSolver.setTimeout(cexSolverTo) + solver2.assertCnstr(initCExClause) + + val clauses = ndProgram.filterFor(bssToKeep) + val clause = And(clauses) + + solver1.assertCnstr(clause) + solver2.assertCnstr(clause) + } - solver1.checkAssumptions(bssAssumptions.map(id => Not(Variable(id)))) match { - case Some(true) => - val satModel = solver1.getModel + val bss = ndProgram.bss - val bssAssumptions: Set[Expr] = bss.map(b => satModel(b) match { - case BooleanLiteral(true) => Variable(b) - case BooleanLiteral(false) => Not(Variable(b)) - }) + while (result.isEmpty && !needMoreUnrolling && !interruptManager.isInterrupted()) { - val validateWithZ3 = if (useCETests && hasInputExamples() && ndProgram.canTest()) { + solver1.checkAssumptions(bssAssumptions.map(id => Not(Variable(id)))) match { + case Some(true) => + val satModel = solver1.getModel - val p = bssAssumptions.collect { case Variable(b) => b } + val bssAssumptions: Set[Expr] = bss.map(b => satModel(b) match { + case BooleanLiteral(true) => Variable(b) + case BooleanLiteral(false) => Not(Variable(b)) + }) - if (allInputExamples().forall(ndProgram.testForProgram(p))) { - // All valid inputs also work with this, we need to - // make sure by validating this candidate with z3 - true - } else { - // One valid input failed with this candidate, we can skip - solver1.assertCnstr(Not(And(p.map(Variable(_)).toSeq))) - false - } - } else { - // No inputs or capability to test, we need to ask Z3 + val validateWithZ3 = if (useCETests && hasInputExamples() && ndProgram.canTest()) { + + val p = bssAssumptions.collect { case Variable(b) => b } + + if (allInputExamples().forall(ndProgram.testForProgram(p))) { + // All valid inputs also work with this, we need to + // make sure by validating this candidate with z3 true + } else { + // One valid input failed with this candidate, we can skip + solver1.assertCnstr(Not(And(p.map(Variable(_)).toSeq))) + false } + } else { + // No inputs or capability to test, we need to ask Z3 + true + } - if (validateWithZ3) { - solver2.checkAssumptions(bssAssumptions) match { - case Some(true) => - val invalidModel = solver2.getModel + if (validateWithZ3) { + solver2.checkAssumptions(bssAssumptions) match { + case Some(true) => + val invalidModel = solver2.getModel - val fixedAss = And(ass.collect { - case a if invalidModel contains a => Equals(Variable(a), invalidModel(a)) - }.toSeq) + val fixedAss = And(ass.collect { + case a if invalidModel contains a => Equals(Variable(a), invalidModel(a)) + }.toSeq) - val newCE = p.as.map(valuateWithModel(invalidModel)) + val newCE = p.as.map(valuateWithModel(invalidModel)) - baseExampleInputs = newCE +: baseExampleInputs + baseExampleInputs = newCE +: baseExampleInputs - // Retest whether the newly found C-E invalidates all programs - if (useCEPruning && ndProgram.canTest) { - if (prunedPrograms.forall(p => !ndProgram.testForProgram(p)(newCE))) { - needMoreUnrolling = true - } + // Retest whether the newly found C-E invalidates all programs + if (useCEPruning && ndProgram.canTest) { + if (prunedPrograms.forall(p => !ndProgram.testForProgram(p)(newCE))) { + needMoreUnrolling = true } + } - val unsatCore = if (useUnsatCores) { - solver1.push() - solver1.assertCnstr(fixedAss) + val unsatCore = if (useUnsatCores) { + solver1.push() + solver1.assertCnstr(fixedAss) - val core = solver1.checkAssumptions(bssAssumptions) match { - case Some(false) => - // Core might be empty if unrolling level is - // insufficient, it becomes unsat no matter what - // the assumptions are. - solver1.getUnsatCore + val core = solver1.checkAssumptions(bssAssumptions) match { + case Some(false) => + // Core might be empty if unrolling level is + // insufficient, it becomes unsat no matter what + // the assumptions are. + solver1.getUnsatCore - case Some(true) => - // Can't be! - bssAssumptions + case Some(true) => + // Can't be! + bssAssumptions - case None => - return RuleApplicationImpossible - } + case None => + return RuleApplicationImpossible + } - solver1.pop() + solver1.pop() - collectedCores += core.collect{ case Variable(id) => id } + collectedCores += core.collect{ case Variable(id) => id } - core - } else { - bssAssumptions - } + core + } else { + bssAssumptions + } - if (unsatCore.isEmpty) { - needMoreUnrolling = true - } else { - solver1.assertCnstr(Not(And(unsatCore.toSeq))) - } + if (unsatCore.isEmpty) { + needMoreUnrolling = true + } else { + solver1.assertCnstr(Not(And(unsatCore.toSeq))) + } - case Some(false) => + case Some(false) => - val expr = ndProgram.determinize(satModel.filter(_._2 == BooleanLiteral(true)).keySet) + val expr = ndProgram.determinize(satModel.filter(_._2 == BooleanLiteral(true)).keySet) - result = Some(RuleSuccess(Solution(BooleanLiteral(true), Set(), expr))) + result = Some(RuleSuccess(Solution(BooleanLiteral(true), Set(), expr))) - case _ => - if (useOptTimeout) { - // Interpret timeout in CE search as "the candidate is valid" - sctx.reporter.info("CEGIS could not prove the validity of the resulting expression") - val expr = ndProgram.determinize(satModel.filter(_._2 == BooleanLiteral(true)).keySet) - result = Some(RuleSuccess(Solution(BooleanLiteral(true), Set(), expr), isTrusted = false)) - } else { - return RuleApplicationImpossible - } - } + case _ => + if (useOptTimeout) { + // Interpret timeout in CE search as "the candidate is valid" + sctx.reporter.info("CEGIS could not prove the validity of the resulting expression") + val expr = ndProgram.determinize(satModel.filter(_._2 == BooleanLiteral(true)).keySet) + result = Some(RuleSuccess(Solution(BooleanLiteral(true), Set(), expr), isTrusted = false)) + } else { + return RuleApplicationImpossible + } } + } - case Some(false) => - if (useUninterpretedProbe) { - solver1.check match { - case Some(false) => - // Unsat even without blockers (under which fcalls are then uninterpreted) - return RuleApplicationImpossible + case Some(false) => + if (useUninterpretedProbe) { + solver1.check match { + case Some(false) => + // Unsat even without blockers (under which fcalls are then uninterpreted) + return RuleApplicationImpossible - case _ => - } + case _ => } + } - needMoreUnrolling = true + needMoreUnrolling = true - case _ => - // Last chance, we test first few programs - return checkForPrograms(prunedPrograms.take(testUpTo)) - } + case _ => + // Last chance, we test first few programs + return checkForPrograms(prunedPrograms.take(testUpTo)) } + } - unrolings += 1 - } while(unrolings < maxUnrolings && result.isEmpty && !interruptManager.isInterrupted()) + unrolings += 1 + } while(unrolings < maxUnrolings && result.isEmpty && !interruptManager.isInterrupted()) - result.getOrElse(RuleApplicationImpossible) + result.getOrElse(RuleApplicationImpossible) - } catch { - case e: Throwable => - sctx.reporter.warning("CEGIS crashed: "+e.getMessage) - RuleApplicationImpossible - } + } catch { + case e: Throwable => + sctx.reporter.warning("CEGIS crashed: "+e.getMessage) + RuleApplicationImpossible } finally { - exSolver.free() - cexSolver.free() + solver1.free() + solver2.free() } } }) diff --git a/src/main/scala/leon/synthesis/rules/EqualitySplit.scala b/src/main/scala/leon/synthesis/rules/EqualitySplit.scala index 26867fd76f9162569cbb7f96f3f98efbd2b45e47..6e867ce1df2e34ec0c16df55f6867ff207c16fa6 100644 --- a/src/main/scala/leon/synthesis/rules/EqualitySplit.scala +++ b/src/main/scala/leon/synthesis/rules/EqualitySplit.scala @@ -40,8 +40,6 @@ case object EqualitySplit extends Rule("Eq. Split") { case _ => false }).values.flatten - solver.free() - candidates.flatMap(_ match { case List(a1, a2) => diff --git a/src/main/scala/leon/synthesis/rules/Ground.scala b/src/main/scala/leon/synthesis/rules/Ground.scala index 31bfcefad9dad1ec9cd59b4acaff2fffaa80f05e..da7c0d6514b3094b9460cbc81604f331c0426eff 100644 --- a/src/main/scala/leon/synthesis/rules/Ground.scala +++ b/src/main/scala/leon/synthesis/rules/Ground.scala @@ -14,7 +14,7 @@ case object Ground extends Rule("Ground") { def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { if (p.as.isEmpty) { - val solver = SimpleSolverAPI(sctx.solverFactory.withTimeout(5000L)) // We give that 5s + val solver = SimpleSolverAPI(new TimeoutSolverFactory(sctx.solverFactory, 5000L)) val tpe = TupleType(p.xs.map(_.getType)) @@ -29,8 +29,6 @@ case object Ground extends Rule("Ground") { None } - solver.free() - result } else { None diff --git a/src/main/scala/leon/synthesis/rules/InequalitySplit.scala b/src/main/scala/leon/synthesis/rules/InequalitySplit.scala index a2a2952165755d2bcd0013331a105e0667d5b9f7..b64dc5335789bb1c9789479871537b6d35bdf940 100644 --- a/src/main/scala/leon/synthesis/rules/InequalitySplit.scala +++ b/src/main/scala/leon/synthesis/rules/InequalitySplit.scala @@ -52,8 +52,6 @@ case object InequalitySplit extends Rule("Ineq. Split.") { case _ => false } - solver.free() - candidates.flatMap(_ match { case List(a1, a2) => diff --git a/src/main/scala/leon/synthesis/rules/OptimisticGround.scala b/src/main/scala/leon/synthesis/rules/OptimisticGround.scala index 6085007e313b78afd8774245556d9ac1d2c52792..3e2fcfcca8c9dfd24fd0a88cc0d942723e74504e 100644 --- a/src/main/scala/leon/synthesis/rules/OptimisticGround.scala +++ b/src/main/scala/leon/synthesis/rules/OptimisticGround.scala @@ -69,8 +69,6 @@ case object OptimisticGround extends Rule("Optimistic Ground") { i += 1 } - solver.free() - result.getOrElse(RuleApplicationImpossible) } } diff --git a/src/main/scala/leon/testgen/CallGraph.scala b/src/main/scala/leon/testgen/CallGraph.scala index 3c22b2334ae7f773b4df3f5c36f93b422002062f..a8e4f367e26989e7c2351f3a9caeedd8d949d71a 100644 --- a/src/main/scala/leon/testgen/CallGraph.scala +++ b/src/main/scala/leon/testgen/CallGraph.scala @@ -168,7 +168,7 @@ class CallGraph(val program: Program) { fd.annotations.exists(_ == "main") } - def findAllPaths(z3Solver: FairZ3SolverFactory): Set[Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]] = { + def findAllPaths(z3Solverf: SolverFactory[Solver]): Set[Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]] = { val waypoints: Set[ProgramPoint] = programPoints.filter{ case ExpressionPoint(Waypoint(_, _), _) => true case _ => false } val sortedWaypoints: Seq[ProgramPoint] = waypoints.toSeq.sortWith((p1, p2) => { val (ExpressionPoint(Waypoint(i1, _), _), ExpressionPoint(Waypoint(i2, _), _)) = (p1, p2) @@ -183,7 +183,7 @@ class CallGraph(val program: Program) { if(sortedWaypoints.size == 0) { findSimplePaths(mainPoint.get) } else { - visitAllWaypoints(mainPoint.get :: sortedWaypoints.toList, z3Solver) match { + visitAllWaypoints(mainPoint.get :: sortedWaypoints.toList, z3Solverf) match { case None => Set() case Some(p) => Set(p) } @@ -194,7 +194,7 @@ class CallGraph(val program: Program) { } } - def visitAllWaypoints(waypoints: List[ProgramPoint], z3Solver: FairZ3SolverFactory): Option[Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]] = { + def visitAllWaypoints(waypoints: List[ProgramPoint], z3Solverf: SolverFactory[Solver]): Option[Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]] = { def rec(head: ProgramPoint, tail: List[ProgramPoint], path: Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]): Option[Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]] = { tail match { @@ -204,11 +204,10 @@ class CallGraph(val program: Program) { var completePath: Option[Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]] = None allPaths.find(intermediatePath => { val pc = pathConstraint(path ++ intermediatePath) - z3Solver.restartZ3 var testcase: Option[Map[Identifier, Expr]] = None - val (solverResult, model) = SimpleSolverAPI(z3Solver).solveSAT(pc) + val (solverResult, model) = SimpleSolverAPI(z3Solverf).solveSAT(pc) solverResult match { case None => { false diff --git a/src/main/scala/leon/testgen/TestGeneration.scala b/src/main/scala/leon/testgen/TestGeneration.scala index 321388e67743d4b8af33fc92bd1448bb3cff58c6..e08e7b1721e714b093079910e3ea0704cadb5245 100644 --- a/src/main/scala/leon/testgen/TestGeneration.scala +++ b/src/main/scala/leon/testgen/TestGeneration.scala @@ -27,7 +27,6 @@ class TestGeneration(context : LeonContext) { private val reporter = context.reporter def analyse(program: Program) { - val z3Solver = new FairZ3SolverFactory(context, program) reporter.info("Running test generation") val testcases = generateTestCases(program) @@ -60,11 +59,11 @@ class TestGeneration(context : LeonContext) { } def generatePathConditions(program: Program): Set[Expr] = { - val z3Solver = new FairZ3SolverFactory(context, program) + val z3Solverf = SolverFactory( () => new FairZ3Solver(context, program)) val callGraph = new CallGraph(program) callGraph.writeDotFile("testgen.dot") - val constraints = callGraph.findAllPaths(z3Solver).map(path => { + val constraints = callGraph.findAllPaths(z3Solverf).map(path => { println("Path is: " + path) val cnstr = callGraph.pathConstraint(path) println("constraint is: " + cnstr) @@ -75,15 +74,14 @@ class TestGeneration(context : LeonContext) { private def generateTestCases(program: Program): Set[Map[Identifier, Expr]] = { val allPaths = generatePathConditions(program) - val z3Solver = new FairZ3SolverFactory(context, program) + val z3Solverf = SolverFactory( () => new FairZ3Solver(context, program)) allPaths.flatMap(pathCond => { reporter.info("Now considering path condition: " + pathCond) var testcase: Option[Map[Identifier, Expr]] = None - z3Solver.restartZ3 - val (solverResult, model) = SimpleSolverAPI(z3Solver).solveSAT(pathCond) + val (solverResult, model) = SimpleSolverAPI(z3Solverf).solveSAT(pathCond) solverResult match { case None => Seq() diff --git a/src/main/scala/leon/verification/AnalysisPhase.scala b/src/main/scala/leon/verification/AnalysisPhase.scala index 3a9d16e32aee4c0faca94c1592a2ff790fe862ec..76a7ea3c2de1c33fc526217cec8110f02e2de2f0 100644 --- a/src/main/scala/leon/verification/AnalysisPhase.scala +++ b/src/main/scala/leon/verification/AnalysisPhase.scala @@ -11,7 +11,6 @@ import purescala.TypeTrees._ import solvers._ import solvers.z3._ -import solvers.bapaminmax._ import solvers.combinators._ import scala.collection.mutable.{Set => MutableSet} @@ -70,62 +69,65 @@ object AnalysisPhase extends LeonPhase[Program,VerificationReport] { for((funDef, vcs) <- vcs.toSeq.sortWith((a,b) => a._1 < b._1); vcInfo <- vcs if !interruptManager.isInterrupted()) { val funDef = vcInfo.funDef val vc = vcInfo.condition - - val time0 : Long = System.currentTimeMillis - val time1 = System.currentTimeMillis - reporter.info("Now considering '" + vcInfo.kind + "' VC for " + funDef.id + "...") reporter.debug("Verification condition (" + vcInfo.kind + ") for ==== " + funDef.id + " ====") reporter.debug(simplifyLets(vc)) // try all solvers until one returns a meaningful answer - solvers.find(se => { - reporter.debug("Trying with solver: " + se.name) - val t1 = System.nanoTime - val (satResult, counterexample) = SimpleSolverAPI(se).solveSAT(Not(vc)) - val solverResult = satResult.map(!_) - - val t2 = System.nanoTime - val dt = ((t2 - t1) / 1000000) / 1000.0 - - solverResult match { - case _ if interruptManager.isInterrupted() => - reporter.info("=== CANCELLED ===") - vcInfo.time = Some(dt) - false - - case None => - vcInfo.time = Some(dt) - false - - case Some(true) => - reporter.info("==== VALID ====") - - vcInfo.hasValue = true - vcInfo.value = Some(true) - vcInfo.solvedWith = Some(se) - vcInfo.time = Some(dt) - true - - case Some(false) => - reporter.error("Found counter-example : ") - reporter.error(counterexample.toSeq.sortBy(_._1.name).map(p => p._1 + " -> " + p._2).mkString("\n")) - reporter.error("==== INVALID ====") + solvers.find(sf => { + val s = sf.getNewSolver + try { + reporter.debug("Trying with solver: " + s.name) + val t1 = System.nanoTime + s.assertCnstr(Not(vc)) + + val satResult = s.check + val counterexample: Map[Identifier, Expr] = if (satResult == Some(true)) s.getModel else Map() + val solverResult = satResult.map(!_) + + val t2 = System.nanoTime + val dt = ((t2 - t1) / 1000000) / 1000.0 + + solverResult match { + case _ if interruptManager.isInterrupted() => + reporter.info("=== CANCELLED ===") + vcInfo.time = Some(dt) + false + + case None => + vcInfo.time = Some(dt) + false + + case Some(true) => + reporter.info("==== VALID ====") + + vcInfo.hasValue = true + vcInfo.value = Some(true) + vcInfo.solvedWith = Some(s) + vcInfo.time = Some(dt) + true + + case Some(false) => + reporter.error("Found counter-example : ") + reporter.error(counterexample.toSeq.sortBy(_._1.name).map(p => p._1 + " -> " + p._2).mkString("\n")) + reporter.error("==== INVALID ====") + vcInfo.hasValue = true + vcInfo.value = Some(false) + vcInfo.solvedWith = Some(s) + vcInfo.counterExample = Some(counterexample) + vcInfo.time = Some(dt) + true + } + } finally { + s.free() + }}) match { + case None => { vcInfo.hasValue = true - vcInfo.value = Some(false) - vcInfo.solvedWith = Some(se) - vcInfo.counterExample = Some(counterexample) - vcInfo.time = Some(dt) - true + reporter.warning("==== UNKNOWN ====") + } + case _ => } - }) match { - case None => { - vcInfo.hasValue = true - reporter.warning("==== UNKNOWN ====") - } - case _ => - } } val report = new VerificationReport(vcs) @@ -148,33 +150,23 @@ object AnalysisPhase extends LeonPhase[Program,VerificationReport] { val reporter = ctx.reporter - lazy val fairZ3 = new FairZ3SolverFactory(ctx, program) - - val baseSolvers : Seq[SolverFactory[Solver]] = { - fairZ3 :: Nil - } - - val solvers: Seq[SolverFactory[Solver]] = timeout match { - case Some(t) => - baseSolvers.map(_.withTimeout(100L*t)) + val baseFactories = Seq( + SolverFactory(() => new FairZ3Solver(ctx, program)) + ) + val solverFactories = timeout match { + case Some(sec) => + baseFactories.map { sf => + new TimeoutSolverFactory(sf, sec*1000L) + } case None => - baseSolvers + baseFactories } - val vctx = VerificationContext(ctx, solvers, reporter) + val vctx = VerificationContext(ctx, solverFactories, reporter) - val report = if(solvers.size >= 1) { - reporter.debug("Running verification condition generation...") - val vcs = generateVerificationConditions(reporter, program, functionsToAnalyse) - checkVerificationConditions(vctx, vcs) - } else { - reporter.warning("No solver specified. Cannot test verification conditions.") - VerificationReport.emptyReport - } - - solvers.foreach(_.free()) - - report + reporter.debug("Running verification condition generation...") + val vcs = generateVerificationConditions(reporter, program, functionsToAnalyse) + checkVerificationConditions(vctx, vcs) } } diff --git a/src/main/scala/leon/verification/VerificationCondition.scala b/src/main/scala/leon/verification/VerificationCondition.scala index 776a4c004724c2506bc8d7bb56467308f6922bb3..e82d61cf123a9f08f09adce8a2bac57478634c06 100644 --- a/src/main/scala/leon/verification/VerificationCondition.scala +++ b/src/main/scala/leon/verification/VerificationCondition.scala @@ -15,7 +15,7 @@ class VerificationCondition(val condition: Expr, val funDef: FunDef, val kind: V // Some(false) = valid var hasValue = false var value : Option[Boolean] = None - var solvedWith : Option[SolverFactory[Solver]] = None + var solvedWith : Option[Solver] = None var time : Option[Double] = None var counterExample : Option[Map[Identifier, Expr]] = None diff --git a/src/test/scala/leon/test/condabd/VariableSolverRefinerTest.scala b/src/test/scala/leon/test/condabd/VariableSolverRefinerTest.scala index dc1d945ca04ebbabe9914fe30415764a82aba6a5..5fe37005b69230ef3dd3a03d09be6f34a35367c6 100644 --- a/src/test/scala/leon/test/condabd/VariableSolverRefinerTest.scala +++ b/src/test/scala/leon/test/condabd/VariableSolverRefinerTest.scala @@ -86,8 +86,6 @@ class VariableSolverRefinerTest extends FunSpec with GivenWhenThen { ) (res2._1, res2._2) } - - solver.free() } } @@ -144,7 +142,6 @@ class VariableSolverRefinerTest extends FunSpec with GivenWhenThen { ) (res2._1, res2._2) } - solver.free() } } @@ -204,8 +201,6 @@ class VariableSolverRefinerTest extends FunSpec with GivenWhenThen { ) (res2._1, res2._2) } - - solver.free() } } diff --git a/src/test/scala/leon/test/condabd/VerifierTest.scala b/src/test/scala/leon/test/condabd/VerifierTest.scala index de88e347cfaad3b941bc4b9bc38fdd1ffc6f3c12..b5d5c3fdac7759f5c53bf5c2347f11675939451c 100644 --- a/src/test/scala/leon/test/condabd/VerifierTest.scala +++ b/src/test/scala/leon/test/condabd/VerifierTest.scala @@ -23,133 +23,135 @@ class VerifierTest extends FunSpec { import Utils._ import Scaffold._ - val lesynthTestDir = "testcases/condabd/test/lesynth" - + val lesynthTestDir = "testcases/condabd/test/lesynth" + def getPostconditionFunction(problem: Problem) = { (list: Iterable[Identifier]) => { (problem.phi /: list) { case ((res, newId)) => - (res /: problem.as.find(_.name == newId.name)) { - case ((res, oldId)) => - TreeOps.replace(Map(Variable(oldId) -> Variable(newId)), res) - } + (res /: problem.as.find(_.name == newId.name)) { + case ((res, oldId)) => + TreeOps.replace(Map(Variable(oldId) -> Variable(newId)), res) + } } } } - - describe("Concrete verifier: ") { + + describe("Concrete verifier: ") { val testCaseFileName = lesynthTestDir + "/ListConcatVerifierTest.scala" val problems = forFile(testCaseFileName) assert(problems.size == 1) - for ((sctx, funDef, problem) <- problems) { - - val timeoutSolver = sctx.solverFactory.withTimeout(2000L) - - val getNewPostcondition = getPostconditionFunction(problem) - - describe("A Verifier") { - it("should verify first correct concat body") { - val newFunDef = getFunDefByName(sctx.program, "goodConcat1") - funDef.body = newFunDef.body + for ((sctx, funDef, problem) <- problems) { + + val timeoutSolver = SolverFactory(() => sctx.newSolver.setTimeout(2000L)) + + + val getNewPostcondition = getPostconditionFunction(problem) + + describe("A Verifier") { + it("should verify first correct concat body") { + val newFunDef = getFunDefByName(sctx.program, "goodConcat1") + funDef.body = newFunDef.body - expectResult(1) { problem.xs.size } - funDef.postcondition = Some((problem.xs.head, getNewPostcondition(newFunDef.args.map(_.id)))) - funDef.precondition = Some(BooleanLiteral(true)) - - val verifier = new Verifier(timeoutSolver, problem) - - assert( verifier.analyzeFunction(funDef)._1 ) - } - - it("should verify 2nd correct concat body") { - val newFunDef = getFunDefByName(sctx.program, "goodConcat2") - funDef.body = newFunDef.body + expectResult(1) { problem.xs.size } + funDef.postcondition = Some((problem.xs.head, getNewPostcondition(newFunDef.args.map(_.id)))) + funDef.precondition = Some(BooleanLiteral(true)) + + val verifier = new Verifier(timeoutSolver, problem) + + assert( verifier.analyzeFunction(funDef)._1 ) + verifier.solver.free() + } + + it("should verify 2nd correct concat body") { + val newFunDef = getFunDefByName(sctx.program, "goodConcat2") + funDef.body = newFunDef.body - expectResult(1) { problem.xs.size } - funDef.postcondition = Some((problem.xs.head, getNewPostcondition(newFunDef.args.map(_.id)))) - funDef.precondition = Some(BooleanLiteral(true)) - - val verifier = new Verifier(timeoutSolver, problem) - - assert( verifier.analyzeFunction(funDef)._1 ) - } - - it("should not verify an incorrect concat body") { - val newFunDef = getFunDefByName(sctx.program, "badConcat1") - funDef.body = newFunDef.body + expectResult(1) { problem.xs.size } + funDef.postcondition = Some((problem.xs.head, getNewPostcondition(newFunDef.args.map(_.id)))) + funDef.precondition = Some(BooleanLiteral(true)) + + val verifier = new Verifier(timeoutSolver, problem) + + assert( verifier.analyzeFunction(funDef)._1 ) + verifier.solver.free() + } + + it("should not verify an incorrect concat body") { + val newFunDef = getFunDefByName(sctx.program, "badConcat1") + funDef.body = newFunDef.body - expectResult(1) { problem.xs.size } - funDef.postcondition = Some((problem.xs.head, getNewPostcondition(newFunDef.args.map(_.id)))) - funDef.precondition = Some(BooleanLiteral(true)) - - val verifier = new Verifier(timeoutSolver, problem) - - assert( ! verifier.analyzeFunction(funDef)._1 ) - } - } - - timeoutSolver.free - } - } + expectResult(1) { problem.xs.size } + funDef.postcondition = Some((problem.xs.head, getNewPostcondition(newFunDef.args.map(_.id)))) + funDef.precondition = Some(BooleanLiteral(true)) + + val verifier = new Verifier(timeoutSolver, problem) + + assert( ! verifier.analyzeFunction(funDef)._1 ) + verifier.solver.free() + } + } + } + } def getPreconditionFunction(problem: Problem) = { (list: Iterable[Identifier]) => { (problem.pc /: list) { case ((res, newId)) => - (res /: problem.as.find(_.name == newId.name)) { - case ((res, oldId)) => - TreeOps.replace(Map(Variable(oldId) -> Variable(newId)), res) - } + (res /: problem.as.find(_.name == newId.name)) { + case ((res, oldId)) => + TreeOps.replace(Map(Variable(oldId) -> Variable(newId)), res) + } } } } - describe("Relaxed verifier: ") { + describe("Relaxed verifier: ") { val testCaseFileName = lesynthTestDir + "/BinarySearchTree.scala" val problems = forFile(testCaseFileName) assert(problems.size == 1) - for ((sctx, funDef, problem) <- problems) { - - val timeoutSolver = sctx.solverFactory.withTimeout(1000L) - - val getNewPostcondition = getPostconditionFunction(problem) - val getNewPrecondition = getPreconditionFunction(problem) - - describe("A RelaxedVerifier on BST") { - it("should verify a correct member body") { - val newFunDef = getFunDefByName(sctx.program, "goodMember") - funDef.body = newFunDef.body + for ((sctx, funDef, problem) <- problems) { + + val timeoutSolver = SolverFactory(() => sctx.newSolver.setTimeout(1000L)) + + val getNewPostcondition = getPostconditionFunction(problem) + val getNewPrecondition = getPreconditionFunction(problem) + + describe("A RelaxedVerifier on BST") { + it("should verify a correct member body") { + val newFunDef = getFunDefByName(sctx.program, "goodMember") + funDef.body = newFunDef.body - expectResult(1) { problem.xs.size } - funDef.postcondition = Some((problem.xs.head, getNewPostcondition(newFunDef.args.map(_.id)))) - funDef.precondition = Some(getNewPrecondition(newFunDef.args.map(_.id))) - - val verifier = new RelaxedVerifier(timeoutSolver, problem) - - assert( verifier.analyzeFunction(funDef)._1 ) - } - - it("should not verify an incorrect member body") { - val newFunDef = getFunDefByName(sctx.program, "badMember") - funDef.body = newFunDef.body + expectResult(1) { problem.xs.size } + funDef.postcondition = Some((problem.xs.head, getNewPostcondition(newFunDef.args.map(_.id)))) + funDef.precondition = Some(getNewPrecondition(newFunDef.args.map(_.id))) + + val verifier = new RelaxedVerifier(timeoutSolver, problem) + + assert( verifier.analyzeFunction(funDef)._1 ) + verifier.solver.free() + } + + it("should not verify an incorrect member body") { + val newFunDef = getFunDefByName(sctx.program, "badMember") + funDef.body = newFunDef.body - expectResult(1) { problem.xs.size } - funDef.postcondition = Some((problem.xs.head, getNewPostcondition(newFunDef.args.map(_.id)))) - funDef.precondition = Some(getNewPrecondition(newFunDef.args.map(_.id))) - - val verifier = new Verifier(timeoutSolver, problem) - - assert( verifier.analyzeFunction(funDef)._1 ) - } - } - - timeoutSolver.free - } + expectResult(1) { problem.xs.size } + funDef.postcondition = Some((problem.xs.head, getNewPostcondition(newFunDef.args.map(_.id)))) + funDef.precondition = Some(getNewPrecondition(newFunDef.args.map(_.id))) + + val verifier = new Verifier(timeoutSolver, problem) + + assert( verifier.analyzeFunction(funDef)._1 ) + verifier.solver.free() + } + } + } - } + } } diff --git a/src/test/scala/leon/test/purescala/TreeOpsTests.scala b/src/test/scala/leon/test/purescala/TreeOpsTests.scala index a1d804171d499f8ca59a2d24a23686802f68d7ec..f261d246b52e286038ff68ff6a285f5fb7598d27 100644 --- a/src/test/scala/leon/test/purescala/TreeOpsTests.scala +++ b/src/test/scala/leon/test/purescala/TreeOpsTests.scala @@ -16,13 +16,9 @@ import leon.solvers.z3._ class TreeOpsTests extends LeonTestSuite { test("Path-aware simplifications") { - val solver = new UninterpretedZ3SolverFactory(testContext, Program.empty) - // TODO actually testing something here would be better, sorry // PS - solver.free() - assert(true) } diff --git a/src/test/scala/leon/test/solvers/TimeoutSolverTests.scala b/src/test/scala/leon/test/solvers/TimeoutSolverTests.scala index 9745d3a123815d9d28eea90bdc444e9fe67b3208..edd96fdc19763e4d33f4eb1943fe06f753ab3950 100644 --- a/src/test/scala/leon/test/solvers/TimeoutSolverTests.scala +++ b/src/test/scala/leon/test/solvers/TimeoutSolverTests.scala @@ -12,38 +12,39 @@ import leon.purescala.Trees._ import leon.purescala.TypeTrees._ class TimeoutSolverTests extends LeonTestSuite { - private class IdioticSolver(val context : LeonContext, val program: Program) extends SolverFactory[Solver] { - enclosing => - + private class IdioticSolver(val context : LeonContext, val program: Program) extends Solver with TimeoutSolver { val name = "Idiotic" val description = "Loops" - def getNewSolver = new Solver { - def check = { - while(!interrupted) { - Thread.sleep(100) - } - None - } + var interrupted = false - def assertCnstr(e: Expr) {} + def innerCheck = { + while(!interrupted) { + Thread.sleep(100) + } + None + } - def checkAssumptions(assump: Set[Expr]) = ??? - def getModel = ??? - def getUnsatCore = ??? - def push() = ??? - def pop(lvl: Int) = ??? + def recoverInterrupt() { + interrupted = false + } - def interrupt() = enclosing.interrupt() - def recoverInterrupt() = enclosing.recoverInterrupt() + def interrupt() { + interrupted = true } + + def assertCnstr(e: Expr) = {} + + def free() {} + + def getModel = ??? } - private def getTOSolver : TimeoutSolverFactory[Solver] = { - new IdioticSolver(testContext, Program.empty).withTimeout(1000L) + private def getTOSolver : SolverFactory[Solver] = { + SolverFactory(() => new IdioticSolver(testContext, Program.empty).setTimeout(1000L)) } - private def check(sf: TimeoutSolverFactory[Solver], e: Expr): Option[Boolean] = { + private def check(sf: SolverFactory[Solver], e: Expr): Option[Boolean] = { val s = sf.getNewSolver s.assertCnstr(e) s.check diff --git a/src/test/scala/leon/test/solvers/z3/FairZ3SolverTests.scala b/src/test/scala/leon/test/solvers/z3/FairZ3SolverTests.scala index cdc28175ee16514a64f8b6c1c78dd86f2aa3ad7f..affcad2d2e83d7237fe6606492387e99626fbc94 100644 --- a/src/test/scala/leon/test/solvers/z3/FairZ3SolverTests.scala +++ b/src/test/scala/leon/test/solvers/z3/FairZ3SolverTests.scala @@ -53,7 +53,7 @@ class FairZ3SolverTests extends LeonTestSuite { private val y : Expr = Variable(FreshIdentifier("y").setType(Int32Type)) private def f(e : Expr) : Expr = FunctionInvocation(fDef, e :: Nil) - private val solver = SimpleSolverAPI(new FairZ3SolverFactory(testContext, minimalProgram)) + private val solver = SimpleSolverAPI(SolverFactory(() => new FairZ3Solver(testContext, minimalProgram))) private val tautology1 : Expr = BooleanLiteral(true) assertValid(solver, tautology1) @@ -102,6 +102,4 @@ class FairZ3SolverTests extends LeonTestSuite { assert(core === Set(b2)) } } - - solver.free() } diff --git a/src/test/scala/leon/test/solvers/z3/FairZ3SolverTestsNewAPI.scala b/src/test/scala/leon/test/solvers/z3/FairZ3SolverTestsNewAPI.scala index e44e3697f05b8079ffd0bea263c077bc0c5bd84d..4d9f46185f99da00b460f46e01a366ede24194c9 100644 --- a/src/test/scala/leon/test/solvers/z3/FairZ3SolverTestsNewAPI.scala +++ b/src/test/scala/leon/test/solvers/z3/FairZ3SolverTestsNewAPI.scala @@ -20,9 +20,13 @@ class FairZ3SolverTestsNewAPI extends LeonTestSuite { test("Solver test #" + testCounter) { val sub = solver.getNewSolver - sub.assertCnstr(Not(expr)) + try { + sub.assertCnstr(Not(expr)) - assert(sub.check === expected.map(!_), msg) + assert(sub.check === expected.map(!_), msg) + } finally { + sub.free() + } } } @@ -57,7 +61,7 @@ class FairZ3SolverTestsNewAPI extends LeonTestSuite { private val y : Expr = Variable(FreshIdentifier("y").setType(Int32Type)) private def f(e : Expr) : Expr = FunctionInvocation(fDef, e :: Nil) - private val solver = new FairZ3SolverFactory(testContext, minimalProgram) + private val solver = SolverFactory(() => new FairZ3Solver(testContext, minimalProgram)) private val tautology1 : Expr = BooleanLiteral(true) assertValid(solver, tautology1) @@ -91,28 +95,38 @@ class FairZ3SolverTestsNewAPI extends LeonTestSuite { locally { val sub = solver.getNewSolver - sub.assertCnstr(f) - assert(sub.check === Some(true)) + try { + sub.assertCnstr(f) + assert(sub.check === Some(true)) + } finally { + sub.free() + } } locally { val sub = solver.getNewSolver - sub.assertCnstr(f) - val result = sub.checkAssumptions(Set(b1)) - - assert(result === Some(true)) - assert(sub.getUnsatCore.isEmpty) + try { + sub.assertCnstr(f) + val result = sub.checkAssumptions(Set(b1)) + + assert(result === Some(true)) + assert(sub.getUnsatCore.isEmpty) + } finally { + sub.free() + } } locally { val sub = solver.getNewSolver - sub.assertCnstr(f) - - val result = sub.checkAssumptions(Set(b1, b2)) - assert(result === Some(false)) - assert(sub.getUnsatCore === Set(b2)) + try { + sub.assertCnstr(f) + + val result = sub.checkAssumptions(Set(b1, b2)) + assert(result === Some(false)) + assert(sub.getUnsatCore === Set(b2)) + } finally { + sub.free() + } } } - - solver.free() } diff --git a/src/test/scala/leon/test/solvers/z3/UninterpretedZ3SolverTests.scala b/src/test/scala/leon/test/solvers/z3/UninterpretedZ3SolverTests.scala index 050c5a8efc63ec582baf7d1216615cdfd4e84889..84c7589e8f0272058164e550ee126705298413c8 100644 --- a/src/test/scala/leon/test/solvers/z3/UninterpretedZ3SolverTests.scala +++ b/src/test/scala/leon/test/solvers/z3/UninterpretedZ3SolverTests.scala @@ -58,7 +58,7 @@ class UninterpretedZ3SolverTests extends LeonTestSuite { private def f(e : Expr) : Expr = FunctionInvocation(fDef, e :: Nil) private def g(e : Expr) : Expr = FunctionInvocation(gDef, e :: Nil) - private val solver = SimpleSolverAPI(new UninterpretedZ3SolverFactory(testContext, minimalProgram)) + private val solver = SimpleSolverAPI(SolverFactory(() => new UninterpretedZ3Solver(testContext, minimalProgram))) private val tautology1 : Expr = BooleanLiteral(true) assertValid(solver, tautology1) @@ -89,6 +89,4 @@ class UninterpretedZ3SolverTests extends LeonTestSuite { solver.solveVALID(Equals(g(x), g(x))) } } - - solver.free() }