Skip to content
Snippets Groups Projects
Commit c903fa2a authored by Etienne Kneuss's avatar Etienne Kneuss Committed by Philippe Suter
Browse files

Modify TimeoutSolver to wrap Incremental solvers as well

parent 7b9cd608
Branches
Tags
No related merge requests found
......@@ -18,6 +18,7 @@ trait IncrementalSolver {
def pop(lvl: Int = 1): Unit
def assertCnstr(expression: Expr): Unit
def halt(): Unit
def check: Option[Boolean]
def checkAssumptions(assumptions: Set[Expr]): Option[Boolean]
def getModel: Map[Identifier, Expr]
......@@ -25,6 +26,7 @@ trait IncrementalSolver {
}
trait NaiveIncrementalSolver extends IncrementalSolverBuilder {
def halt(): Unit
def solveSAT(e: Expr): (Option[Boolean], Map[Identifier, Expr])
def getNewSolver = new IncrementalSolver {
......@@ -38,6 +40,10 @@ trait NaiveIncrementalSolver extends IncrementalSolverBuilder {
stack = stack.drop(lvl)
}
def halt() {
NaiveIncrementalSolver.this.halt()
}
def assertCnstr(expression: Expr) {
stack = (expression :: stack.head) :: stack.tail
}
......
......@@ -8,17 +8,17 @@ import purescala.TypeTrees._
import scala.sys.error
class TimeoutSolver(solver : Solver, timeout : Int) extends Solver(solver.context) with NaiveIncrementalSolver {
class TimeoutSolver(solver : Solver with IncrementalSolverBuilder, timeout : Int) extends Solver(solver.context) with IncrementalSolverBuilder {
// I'm making this an inner class to fight the temptation of using it for anything meaningful.
// We have Akka, these days, which whould be better in any respect for non-trivial things.
private class Timer(callback : () => Unit, maxSecs : Int) extends Thread {
private class Timer(onTimeout: => Unit) extends Thread {
private var keepRunning = true
private val asMillis : Long = 1000L * maxSecs
private val asMillis : Long = 1000L * timeout
override def run : Unit = {
val startTime : Long = System.currentTimeMillis
var exceeded : Boolean = false
while(!exceeded && keepRunning) {
if(asMillis < (System.currentTimeMillis - startTime)) {
exceeded = true
......@@ -26,15 +26,23 @@ class TimeoutSolver(solver : Solver, timeout : Int) extends Solver(solver.contex
Thread.sleep(10)
}
if(exceeded && keepRunning) {
callback()
onTimeout
}
}
def halt : Unit = {
keepRunning = false
}
}
def withTimeout[T](onTimeout: => Unit)(body: => T): T = {
val timer = new Timer(onTimeout)
timer.start
val res = body
timer.halt
res
}
val description = solver.description + ", with timeout"
val name = solver.name + "+to"
......@@ -43,16 +51,67 @@ class TimeoutSolver(solver : Solver, timeout : Int) extends Solver(solver.contex
}
def solve(expression: Expr) : Option[Boolean] = {
val timer = new Timer(() => solver.halt, timeout)
timer.start
val res = solver.solve(expression)
timer.halt
res
withTimeout(solver.halt) {
solver.solve(expression)
}
}
override def solveSAT(expression: Expr): (Option[Boolean], Map[Identifier, Expr]) = {
withTimeout(solver.halt) {
solver.solveSAT(expression)
}
}
override def solveSATWithCores(expression: Expr, assumptions: Set[Expr]): (Option[Boolean], Map[Identifier, Expr], Set[Expr]) = {
withTimeout(solver.halt) {
solver.solveSATWithCores(expression, assumptions)
}
}
def getNewSolver = new IncrementalSolver {
val solver = TimeoutSolver.this.solver.getNewSolver
def push(): Unit = {
solver.push()
}
def pop(lvl: Int = 1): Unit = {
solver.pop(lvl)
}
def assertCnstr(expression: Expr): Unit = {
solver.assertCnstr(expression)
}
def halt(): Unit = {
solver.halt()
}
def check: Option[Boolean] = {
withTimeout(solver.halt){
solver.check
}
}
def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = {
withTimeout(solver.halt){
solver.checkAssumptions(assumptions)
}
}
def getModel: Map[Identifier, Expr] = {
solver.getModel
}
def getUnsatCore: Set[Expr] = {
solver.getUnsatCore
}
}
override def init() {
solver.init
}
override def halt() {
solver.halt
}
......
......@@ -95,7 +95,7 @@ class FairZ3Solver(context : LeonContext)
override def halt() {
super.halt
if(z3 ne null) {
z3.softCheckCancel
z3.interrupt
}
}
......@@ -257,6 +257,10 @@ class FairZ3Solver(context : LeonContext)
frameExpressions = Nil :: frameExpressions
}
def halt() {
z3.interrupt
}
def pop(lvl: Int = 1) {
// We make sure we discard the expressions guarded by this frame
solver.assertCnstr(z3.mkNot(frameGuards.head))
......
......@@ -86,6 +86,10 @@ class UninterpretedZ3Solver(context : LeonContext) extends Solver(context) with
solver.push
}
def halt() {
z3.interrupt
}
def pop(lvl: Int = 1) {
solver.pop(lvl)
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment