From 3114ffa29ac604a0a7dd1e9fdc60061c3b8f8da4 Mon Sep 17 00:00:00 2001
From: Nicolas Voirol <voirol.nicolas@gmail.com>
Date: Mon, 7 Nov 2016 08:33:40 +0100
Subject: [PATCH] Setup fine(r) grained interrupts in unrolling solver

---
 .../solvers/unrolling/FunctionTemplates.scala | 15 +++-
 .../solvers/unrolling/LambdaTemplates.scala   | 82 ++++++++++++-------
 .../unrolling/QuantificationTemplates.scala   | 31 ++++---
 .../inox/solvers/unrolling/Templates.scala    |  2 +
 .../solvers/unrolling/UnrollingSolver.scala   |  3 +-
 .../inox/solvers/z3/NativeZ3Solver.scala      |  1 +
 6 files changed, 93 insertions(+), 41 deletions(-)

diff --git a/src/main/scala/inox/solvers/unrolling/FunctionTemplates.scala b/src/main/scala/inox/solvers/unrolling/FunctionTemplates.scala
index 5c0b037c0..29965f7df 100644
--- a/src/main/scala/inox/solvers/unrolling/FunctionTemplates.scala
+++ b/src/main/scala/inox/solvers/unrolling/FunctionTemplates.scala
@@ -6,6 +6,8 @@ package unrolling
 
 import utils._
 
+import scala.collection.mutable.{Set => MutableSet}
+
 trait FunctionTemplates { self: Templates =>
   import program._
   import program.trees._
@@ -127,10 +129,14 @@ trait FunctionTemplates { self: Templates =>
 
       val newClauses = new scala.collection.mutable.ListBuffer[Encoded]
 
-      val newCallInfos = blockers.flatMap(id => callInfos.get(id).map(id -> _))
+      val thisCallInfos = blockers.flatMap(id => callInfos.get(id).map(id -> _))
       callInfos --= blockers
 
-      for ((blocker, (gen, _, _, calls)) <- newCallInfos; call @ Call(tfd, args) <- calls) {
+      val remainingBlockers = MutableSet.empty ++ blockers
+
+      for ((blocker, (gen, _, _, calls)) <- thisCallInfos if calls.nonEmpty && !interrupted;
+           _ = remainingBlockers -= blocker;
+           call @ Call(tfd, args) <- calls) {
         val newCls = new scala.collection.mutable.ListBuffer[Encoded]
 
         val defBlocker = defBlockers.get(call) match {
@@ -166,6 +172,11 @@ trait FunctionTemplates { self: Templates =>
         newClauses ++= newCls
       }
 
+      for ((b, (gen, origGen, notB, calls)) <- thisCallInfos if remainingBlockers(b)) callInfos.get(b) match {
+        case Some((newGen, _, _, newCalls)) => callInfos += b -> (gen min newGen, origGen, notB, calls ++ newCalls)
+        case None => callInfos += b -> (gen, origGen, notB, calls)
+      }
+
       ctx.reporter.debug(s"   - ${newClauses.size} new clauses")
 
       newClauses.toSeq
diff --git a/src/main/scala/inox/solvers/unrolling/LambdaTemplates.scala b/src/main/scala/inox/solvers/unrolling/LambdaTemplates.scala
index 0f6123363..e82f6c25c 100644
--- a/src/main/scala/inox/solvers/unrolling/LambdaTemplates.scala
+++ b/src/main/scala/inox/solvers/unrolling/LambdaTemplates.scala
@@ -506,50 +506,76 @@ trait LambdaTemplates { self: Templates =>
         (gen, infos)
       })
 
+      val remainingApps = MutableSet.empty ++ apps
+
       blockerToApps --= blockers
       appInfos --= apps
 
-      for ((app, (_, infos)) <- thisAppInfos if infos.nonEmpty) {
+      val newBlockers = (for ((app, (_, infos)) <- thisAppInfos if infos.nonEmpty) yield {
         val nextB = encodeSymbol(Variable(FreshIdentifier("b_lambda", true), BooleanType))
-        val extension = mkOr((infos.map(_.equals).toSeq :+ nextB) : _*)
-        val clause = mkEquals(appBlockers(app), extension)
+        val lastB = appBlockers(app)
 
-        appBlockers += app -> nextB
-        blockerToApps -= appBlockers(app)
         blockerToApps += nextB -> app
+        appBlockers += app -> nextB
 
-        ctx.reporter.debug(" -> extending lambda blocker: " + clause)
-        newClauses += clause
-      }
+        app -> ((lastB, nextB))
+      }).toMap
 
-      for ((app @ (b, _), (gen, infos)) <- thisAppInfos;
-           info @ TemplateAppInfo(tmpl, equals, args) <- infos;
-           template <- tmpl.left) {
-        val newCls = new scala.collection.mutable.ListBuffer[Encoded]
+      for ((app @ (b, _), (gen, infos)) <- thisAppInfos if infos.nonEmpty) {
+        val (lastB, nextB) = newBlockers(app)
+        if (interrupted) {
+          newClauses += mkEquals(lastB, nextB)
+        } else {
+          remainingApps -= app
 
-        val lambdaBlocker = lambdaBlockers.get(info) match {
-          case Some(lambdaBlocker) => lambdaBlocker
+          val extension = mkOr((infos.map(info => info.template match {
+            case Left(template) => mkAnd(template.start, info.equals)
+            case Right(_) => info.equals
+          }).toSeq :+ nextB) : _*)
 
-          case None =>
-            val lambdaBlocker = encodeSymbol(Variable(FreshIdentifier("d", true), BooleanType))
-            lambdaBlockers += info -> lambdaBlocker
+          val clause = mkEquals(lastB, extension)
+          ctx.reporter.debug(" -> extending lambda blocker: " + clause)
+          newClauses += clause
 
-            val instClauses: Clauses = template.instantiate(lambdaBlocker, args)
+          for (info @ TemplateAppInfo(tmpl, equals, args) <- infos; template <- tmpl.left) {
+            val newCls = new scala.collection.mutable.ListBuffer[Encoded]
 
-            newCls ++= instClauses
-            lambdaBlocker
-        }
+            val lambdaBlocker = lambdaBlockers.get(info) match {
+              case Some(lambdaBlocker) => lambdaBlocker
+
+              case None =>
+                val lambdaBlocker = encodeSymbol(Variable(FreshIdentifier("d", true), BooleanType))
+                lambdaBlockers += info -> lambdaBlocker
 
-        val enabler = if (equals == trueT) b else mkAnd(equals, b)
-        registerImplication(b, lambdaBlocker)
-        newCls += mkImplies(enabler, lambdaBlocker)
+                val instClauses: Clauses = template.instantiate(lambdaBlocker, args)
 
-        ctx.reporter.debug("Unrolling behind "+info+" ("+newCls.size+")")
-        for (cl <- newCls) {
-          ctx.reporter.debug("  . "+cl)
+                newCls ++= instClauses
+                lambdaBlocker
+            }
+
+            val enabler = if (equals == trueT) b else mkAnd(equals, b)
+            registerImplication(b, lambdaBlocker)
+            newCls += mkImplies(enabler, lambdaBlocker)
+
+            ctx.reporter.debug("Unrolling behind "+info+" ("+newCls.size+")")
+            for (cl <- newCls) {
+              ctx.reporter.debug("  . "+cl)
+            }
+
+            newClauses ++= newCls
+          }
         }
+      }
 
-        newClauses ++= newCls
+      val remainingInfos = thisAppInfos.filter { case (app, _) => remainingApps(app) }
+      for ((app, (gen, infos)) <- thisAppInfos if remainingApps(app)) appInfos.get(app) match {
+        case Some((newGen, origGen, b, notB, newInfos)) =>
+          appInfos += app -> (gen min newGen, origGen, b, notB, infos ++ newInfos)
+          
+        case None =>
+          val b = appBlockers(app)
+          val notB = mkNot(b)
+          appInfos += app -> (gen, gen, b, notB, infos)
       }
 
       ctx.reporter.debug(s"   - ${newClauses.size} new clauses")
diff --git a/src/main/scala/inox/solvers/unrolling/QuantificationTemplates.scala b/src/main/scala/inox/solvers/unrolling/QuantificationTemplates.scala
index 6eca9bbb0..fdb065d00 100644
--- a/src/main/scala/inox/solvers/unrolling/QuantificationTemplates.scala
+++ b/src/main/scala/inox/solvers/unrolling/QuantificationTemplates.scala
@@ -247,19 +247,21 @@ trait QuantificationTemplates { self: Templates =>
     def promoteBlocker(b: Encoded): Boolean = false
 
     def unroll: Clauses = {
-      val imClauses = new scala.collection.mutable.ListBuffer[Encoded]
-      for (e @ (gen, bs, m) <- ignoredMatchers.toSeq if gen <= currentGeneration) {
-        imClauses ++= instantiateMatcher(bs, m, defer = true)
+      val clauses = new scala.collection.mutable.ListBuffer[Encoded]
+      for (e @ (gen, bs, m) <- ignoredMatchers.toSeq if gen <= currentGeneration && !interrupted) {
+        clauses ++= instantiateMatcher(bs, m, defer = true)
         ignoredMatchers -= e
       }
 
-      ctx.reporter.debug("Unrolling ignored matchers (" + imClauses.size + ")")
-      for (cl <- imClauses) {
+      ctx.reporter.debug("Unrolling ignored matchers (" + clauses.size + ")")
+      for (cl <- clauses) {
         ctx.reporter.debug("  . " + cl)
       }
 
+      if (interrupted) return clauses.toSeq
+
       val suClauses = new scala.collection.mutable.ListBuffer[Encoded]
-      for (q <- quantifications.toSeq if ignoredSubsts.isDefinedAt(q)) {
+      for (q <- quantifications.toSeq if ignoredSubsts.isDefinedAt(q) && !interrupted) {
         val (release, keep) = ignoredSubsts(q).partition(_._1 <= currentGeneration)
         ignoredSubsts += q -> keep
 
@@ -273,8 +275,12 @@ trait QuantificationTemplates { self: Templates =>
         ctx.reporter.debug("  . " + cl)
       }
 
+      clauses ++= suClauses
+
+      if (interrupted) return clauses.toSeq
+
       val grClauses = new scala.collection.mutable.ListBuffer[Encoded]
-      for ((gen, qs) <- ignoredGrounds.toSeq if gen <= currentGeneration; q <- qs) {
+      for ((gen, qs) <- ignoredGrounds.toSeq if gen <= currentGeneration && !interrupted; q <- qs) {
         grClauses ++= q.ensureGrounds
         val remaining = ignoredGrounds.getOrElse(gen, Set.empty) - q
         if (remaining.nonEmpty) {
@@ -289,7 +295,9 @@ trait QuantificationTemplates { self: Templates =>
         ctx.reporter.debug("  . " + cl)
       }
 
-      imClauses.toSeq ++ suClauses ++ grClauses
+      clauses ++= grClauses
+
+      clauses.toSeq
     }
   }
 
@@ -308,6 +316,9 @@ trait QuantificationTemplates { self: Templates =>
 
     if (handledMatchers(relevantBlockers -> matcher)) {
       Seq.empty
+    } else if (interrupted) {
+      ignoredMatchers += ((currentGeneration + 1, blockers, matcher))
+      Seq.empty
     } else {
       ctx.reporter.debug(" -> instantiating matcher " + blockers.mkString("{",",","}") + " ==> " + matcher)
       handledMatchers += relevantBlockers -> matcher
@@ -657,7 +668,7 @@ trait QuantificationTemplates { self: Templates =>
       val instantiation = new scala.collection.mutable.ListBuffer[Encoded]
 
       for (p @ (bs, subst, delay) <- substs if !handledSubsts.get(this).exists(_ contains (bs -> subst))) {
-        if (delay > 0) {
+        if (interrupted || delay > 0) {
           val gen = currentGeneration + delay + (if (getPolarity.isEmpty) 2 else 0)
           ignoredSubsts += this -> (ignoredSubsts.getOrElse(this, Set.empty) + ((gen, bs, subst)))
         } else {
@@ -691,7 +702,7 @@ trait QuantificationTemplates { self: Templates =>
         val sb = bs ++ (if (b == guard) Set.empty else Set(substituter(b)))
         val sm = m.substitute(substituter, msubst)
 
-        if (b != guard) {
+        if (interrupted || b != guard) {
           val gen = currentGeneration + 1
           ignoredMatchers += ((gen, sb, sm))
         } else {
diff --git a/src/main/scala/inox/solvers/unrolling/Templates.scala b/src/main/scala/inox/solvers/unrolling/Templates.scala
index 34406c5a6..d2579e9a0 100644
--- a/src/main/scala/inox/solvers/unrolling/Templates.scala
+++ b/src/main/scala/inox/solvers/unrolling/Templates.scala
@@ -24,6 +24,8 @@ trait Templates extends TemplateGenerator
 
   def asString(e: Encoded): String
 
+  def interrupted: Boolean
+
   def encodeSymbol(v: Variable): Encoded
   def mkEncoder(bindings: Map[Variable, Encoded])(e: Expr): Encoded
   def mkSubstituter(map: Map[Encoded, Encoded]): Encoded => Encoded
diff --git a/src/main/scala/inox/solvers/unrolling/UnrollingSolver.scala b/src/main/scala/inox/solvers/unrolling/UnrollingSolver.scala
index e08ece074..b21b0ce92 100644
--- a/src/main/scala/inox/solvers/unrolling/UnrollingSolver.scala
+++ b/src/main/scala/inox/solvers/unrolling/UnrollingSolver.scala
@@ -601,7 +601,7 @@ trait AbstractUnrollingSolver extends Solver { self =>
   }
 }
 
-trait UnrollingSolver extends AbstractUnrollingSolver {
+trait UnrollingSolver extends AbstractUnrollingSolver { self =>
   import program._
   import program.trees._
   import program.symbols._
@@ -629,6 +629,7 @@ trait UnrollingSolver extends AbstractUnrollingSolver {
     type Encoded = Expr
 
     def asString(expr: Expr): String = expr.asString
+    def interrupted: Boolean = self.interrupted
 
     def encodeSymbol(v: Variable): Expr = v.freshen
     def mkEncoder(bindings: Map[Variable, Expr])(e: Expr): Expr =
diff --git a/src/main/scala/inox/solvers/z3/NativeZ3Solver.scala b/src/main/scala/inox/solvers/z3/NativeZ3Solver.scala
index 87f3f2183..3043b4c89 100644
--- a/src/main/scala/inox/solvers/z3/NativeZ3Solver.scala
+++ b/src/main/scala/inox/solvers/z3/NativeZ3Solver.scala
@@ -38,6 +38,7 @@ trait NativeZ3Solver extends AbstractUnrollingSolver { self =>
     type Encoded = self.Encoded
 
     def asString(ast: Z3AST): String = ast.toString
+    def interrupted: Boolean = self.interrupted
 
     def encodeSymbol(v: Variable): Z3AST = underlying.symbolToFreshZ3Symbol(v)
 
-- 
GitLab