From c2680c6618f5219481e1714d0be540900dc7ab47 Mon Sep 17 00:00:00 2001
From: Etienne Kneuss <ekneuss@gmail.com>
Date: Mon, 14 Jan 2013 18:21:12 +0100
Subject: [PATCH] TimeoutSolver Improvements

Timeouts are now specified in milliseconds instead of seconds.

TimeoutSolvers that hit a timeout no longer makes the wrapped solver
useless for all subsequent invocations.
---
 .../leon/solvers/IncrementalSolver.scala      |  4 ++-
 .../leon/solvers/InterruptibleSolver.scala    |  7 ++++
 src/main/scala/leon/solvers/Solver.scala      |  2 +-
 .../scala/leon/solvers/TimeoutSolver.scala    | 34 +++++++++++++------
 .../leon/verification/AnalysisPhase.scala     |  2 +-
 .../test/solvers/TimeoutSolverTests.scala     |  2 +-
 6 files changed, 37 insertions(+), 14 deletions(-)
 create mode 100644 src/main/scala/leon/solvers/InterruptibleSolver.scala

diff --git a/src/main/scala/leon/solvers/IncrementalSolver.scala b/src/main/scala/leon/solvers/IncrementalSolver.scala
index cab715f4b..84215134e 100644
--- a/src/main/scala/leon/solvers/IncrementalSolver.scala
+++ b/src/main/scala/leon/solvers/IncrementalSolver.scala
@@ -10,7 +10,7 @@ trait IncrementalSolverBuilder {
   def getNewSolver: IncrementalSolver
 }
 
-trait IncrementalSolver {
+trait IncrementalSolver extends InterruptibleSolver {
   // New Solver API
   // Moslty for z3 solvers since z3 4.3
 
@@ -19,6 +19,8 @@ trait IncrementalSolver {
   def assertCnstr(expression: Expr): Unit
 
   def halt(): Unit
+  def init(): Unit = {}
+
   def check: Option[Boolean]
   def checkAssumptions(assumptions: Set[Expr]): Option[Boolean]
   def getModel: Map[Identifier, Expr]
diff --git a/src/main/scala/leon/solvers/InterruptibleSolver.scala b/src/main/scala/leon/solvers/InterruptibleSolver.scala
new file mode 100644
index 000000000..01620279c
--- /dev/null
+++ b/src/main/scala/leon/solvers/InterruptibleSolver.scala
@@ -0,0 +1,7 @@
+package leon
+package solvers
+
+trait InterruptibleSolver {
+  def halt(): Unit
+  def init(): Unit
+}
diff --git a/src/main/scala/leon/solvers/Solver.scala b/src/main/scala/leon/solvers/Solver.scala
index 777e4dc1c..01a1e7336 100644
--- a/src/main/scala/leon/solvers/Solver.scala
+++ b/src/main/scala/leon/solvers/Solver.scala
@@ -6,7 +6,7 @@ import purescala.Definitions._
 import purescala.TreeOps._
 import purescala.Trees._
 
-abstract class Solver(val context : LeonContext) extends IncrementalSolverBuilder with LeonComponent {
+abstract class Solver(val context : LeonContext) extends IncrementalSolverBuilder with InterruptibleSolver with LeonComponent {
   // This can be used by solvers to "see" the programs from which the
   // formulas come. (e.g. to set up some datastructures for the defined
   // ADTs, etc.) 
diff --git a/src/main/scala/leon/solvers/TimeoutSolver.scala b/src/main/scala/leon/solvers/TimeoutSolver.scala
index 17f992a77..22a257031 100644
--- a/src/main/scala/leon/solvers/TimeoutSolver.scala
+++ b/src/main/scala/leon/solvers/TimeoutSolver.scala
@@ -8,12 +8,12 @@ import purescala.TypeTrees._
 
 import scala.sys.error
 
-class TimeoutSolver(solver : Solver with  IncrementalSolverBuilder, timeout : Int) extends Solver(solver.context) with IncrementalSolverBuilder {
+class TimeoutSolver(solver : Solver with  IncrementalSolverBuilder, timeoutMs : Long) extends Solver(solver.context) with IncrementalSolverBuilder {
   // I'm making this an inner class to fight the temptation of using it for anything meaningful.
   // We have Akka, these days, which whould be better in any respect for non-trivial things.
   private class Timer(onTimeout: => Unit) extends Thread {
     private var keepRunning = true
-    private val asMillis : Long = 1000L * timeout
+    private val asMillis : Long = timeoutMs
 
     override def run : Unit = {
       val startTime : Long = System.currentTimeMillis
@@ -23,7 +23,7 @@ class TimeoutSolver(solver : Solver with  IncrementalSolverBuilder, timeout : In
         if(asMillis < (System.currentTimeMillis - startTime)) {
           exceeded = true
         }
-        Thread.sleep(10) 
+        Thread.sleep(10)
       }
       if(exceeded && keepRunning) {
         onTimeout
@@ -35,14 +35,28 @@ class TimeoutSolver(solver : Solver with  IncrementalSolverBuilder, timeout : In
     }
   }
 
-  def withTimeout[T](onTimeout: => Unit)(body: => T): T = {
-    val timer = new Timer(onTimeout)
+  def withTimeout[T](solver: InterruptibleSolver)(body: => T): T = {
+    val timer = new Timer(timeout(solver))
     timer.start
     val res = body
     timer.halt
+    recoverFromTimeout(solver)
     res
   }
 
+  var reachedTimeout = false
+  def timeout(solver: InterruptibleSolver) {
+    solver.halt
+    reachedTimeout = true
+  }
+
+  def recoverFromTimeout(solver: InterruptibleSolver) {
+    if (reachedTimeout) {
+      solver.init
+      reachedTimeout = false
+    }
+  }
+
   val description = solver.description + ", with timeout"
   val name = solver.name + "+to"
 
@@ -51,19 +65,19 @@ class TimeoutSolver(solver : Solver with  IncrementalSolverBuilder, timeout : In
   }
 
   def solve(expression: Expr) : Option[Boolean] = {
-    withTimeout(solver.halt) {
+    withTimeout(solver) {
       solver.solve(expression)
     }
   }
 
   override def solveSAT(expression: Expr): (Option[Boolean], Map[Identifier, Expr]) = {
-    withTimeout(solver.halt) {
+    withTimeout(solver) {
       solver.solveSAT(expression)
     }
   }
 
   override def solveSATWithCores(expression: Expr, assumptions: Set[Expr]): (Option[Boolean], Map[Identifier, Expr], Set[Expr]) = {
-    withTimeout(solver.halt) {
+    withTimeout(solver) {
       solver.solveSATWithCores(expression, assumptions)
     }
   }
@@ -88,13 +102,13 @@ class TimeoutSolver(solver : Solver with  IncrementalSolverBuilder, timeout : In
     }
 
     def check: Option[Boolean] = {
-      withTimeout(solver.halt){
+      withTimeout(solver){
         solver.check
       }
     }
 
     def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = {
-      withTimeout(solver.halt){
+      withTimeout(solver){
         solver.checkAssumptions(assumptions)
       }
     }
diff --git a/src/main/scala/leon/verification/AnalysisPhase.scala b/src/main/scala/leon/verification/AnalysisPhase.scala
index 146f8bb2d..5ac64e8da 100644
--- a/src/main/scala/leon/verification/AnalysisPhase.scala
+++ b/src/main/scala/leon/verification/AnalysisPhase.scala
@@ -42,7 +42,7 @@ object AnalysisPhase extends LeonPhase[Program,VerificationReport] {
 
     val solvers0 : Seq[Solver] = trivialSolver :: fairZ3 :: Nil
     val solvers: Seq[Solver] = timeout match {
-      case Some(t) => solvers0.map(s => new TimeoutSolver(s, t))
+      case Some(t) => solvers0.map(s => new TimeoutSolver(s, 1000L * t))
       case None => solvers0
     }
 
diff --git a/src/test/scala/leon/test/solvers/TimeoutSolverTests.scala b/src/test/scala/leon/test/solvers/TimeoutSolverTests.scala
index e94e071bc..38713552c 100644
--- a/src/test/scala/leon/test/solvers/TimeoutSolverTests.scala
+++ b/src/test/scala/leon/test/solvers/TimeoutSolverTests.scala
@@ -27,7 +27,7 @@ class TimeoutSolverTests extends FunSuite {
   }
 
   private def getTOSolver : Solver = {
-    val s = new TimeoutSolver(new IdioticSolver(LeonContext()), 1)
+    val s = new TimeoutSolver(new IdioticSolver(LeonContext()), 1000L)
     s.setProgram(Program.empty)
     s
   }
-- 
GitLab