From 3f39702f0b906a46c2140580579e5bef765154fe Mon Sep 17 00:00:00 2001
From: Manos Koukoutos <emmanouil.koukoutos@epfl.ch>
Date: Fri, 29 Apr 2016 16:21:14 +0200
Subject: [PATCH] Datagens should respect interrupts

---
 src/main/scala/leon/datagen/GrammarDataGen.scala  |  3 ++-
 src/main/scala/leon/datagen/SolverDataGen.scala   |  4 ++--
 src/main/scala/leon/utils/FreeableIterator.scala  | 15 ++++++++++-----
 .../verification/XLangVerificationSuite.scala     |  1 -
 4 files changed, 14 insertions(+), 9 deletions(-)

diff --git a/src/main/scala/leon/datagen/GrammarDataGen.scala b/src/main/scala/leon/datagen/GrammarDataGen.scala
index 619b76606..29e102281 100644
--- a/src/main/scala/leon/datagen/GrammarDataGen.scala
+++ b/src/main/scala/leon/datagen/GrammarDataGen.scala
@@ -55,7 +55,7 @@ class GrammarDataGen(evaluator: Evaluator, grammar: ExpressionGrammar = ValueGra
 
   def generate(tpe: TypeTree): Iterator[Expr] = {
     val enum = new MemoizedEnumerator[Label, Expr, ProductionRule[Label, Expr]](grammar.getProductions)
-    enum.iterator(Label(tpe)).flatMap(expandGenerics)
+    enum.iterator(Label(tpe)).flatMap(expandGenerics).takeWhile(_ => !interrupted.get)
   }
 
   def generateFor(ins: Seq[Identifier], satisfying: Expr, maxValid: Int, maxEnumerated: Int): Iterator[Seq[Expr]] = {
@@ -82,6 +82,7 @@ class GrammarDataGen(evaluator: Evaluator, grammar: ExpressionGrammar = ValueGra
       detupled.take(maxEnumerated)
               .filter(filterCond)
               .take(maxValid)
+              .takeWhile(_ => !interrupted.get)
     }
   }
 
diff --git a/src/main/scala/leon/datagen/SolverDataGen.scala b/src/main/scala/leon/datagen/SolverDataGen.scala
index 34d94c1a9..37dcfba06 100644
--- a/src/main/scala/leon/datagen/SolverDataGen.scala
+++ b/src/main/scala/leon/datagen/SolverDataGen.scala
@@ -15,7 +15,7 @@ class SolverDataGen(ctx: LeonContext, pgm: Program, sf: SolverFactory[Solver]) e
   implicit val ctx0 = ctx
 
   def generate(tpe: TypeTree): FreeableIterator[Expr] = {
-    generateFor(Seq(FreshIdentifier("tmp", tpe)), BooleanLiteral(true), 20, 20).map(_.head)
+    generateFor(Seq(FreshIdentifier("tmp", tpe)), BooleanLiteral(true), 20, 20).map(_.head).takeWhile(_ => !interrupted.get)
   }
 
   def generateFor(ins: Seq[Identifier], satisfying: Expr, maxValid: Int, maxEnumerated: Int): FreeableIterator[Seq[Expr]] = {
@@ -76,7 +76,7 @@ class SolverDataGen(ctx: LeonContext, pgm: Program, sf: SolverFactory[Solver]) e
 
       val enum = modelEnum.enumVarying(ins, satisfying, sizeOf, 5)
 
-      enum.take(maxValid).map(model => ins.map(model))
+      enum.take(maxValid).map(model => ins.map(model)).takeWhile(_ => !interrupted.get)
     }
   }
 
diff --git a/src/main/scala/leon/utils/FreeableIterator.scala b/src/main/scala/leon/utils/FreeableIterator.scala
index ad2942c93..9b3734270 100644
--- a/src/main/scala/leon/utils/FreeableIterator.scala
+++ b/src/main/scala/leon/utils/FreeableIterator.scala
@@ -4,6 +4,8 @@ package leon
 package utils
 
 abstract class FreeableIterator[T] extends Iterator[T] {
+  orig =>
+
   private[this] var nextElem: Option[T] = None
 
   def hasNext = {
@@ -20,8 +22,6 @@ abstract class FreeableIterator[T] extends Iterator[T] {
   def free()
 
   override def map[B](f: T => B): FreeableIterator[B] = {
-    val orig = this
-
     new FreeableIterator[B] {
       def computeNext() = orig.computeNext.map(f)
       def free() = orig.free()
@@ -29,10 +29,8 @@ abstract class FreeableIterator[T] extends Iterator[T] {
   }
 
   override def take(n: Int): FreeableIterator[T] = {
-    val orig = this
-
     new FreeableIterator[T] {
-      private var c = 0;
+      private var c = 0
 
       def computeNext() = {
         if (c < n) {
@@ -47,6 +45,13 @@ abstract class FreeableIterator[T] extends Iterator[T] {
     }
   }
 
+  override def takeWhile(p: T => Boolean) = {
+    new FreeableIterator[T] {
+      def computeNext(): Option[T] = orig.computeNext.filter(p)
+      def free(): Unit = orig.free()
+    }
+  }
+
   override def toList: List[T] = {
     val res = super.toList
     free()
diff --git a/src/test/scala/leon/regression/verification/XLangVerificationSuite.scala b/src/test/scala/leon/regression/verification/XLangVerificationSuite.scala
index 104389eb7..839e45dba 100644
--- a/src/test/scala/leon/regression/verification/XLangVerificationSuite.scala
+++ b/src/test/scala/leon/regression/verification/XLangVerificationSuite.scala
@@ -2,7 +2,6 @@
 
 package leon.regression.verification
 
-import smtlib.interpreters.Z3Interpreter
 import leon.solvers.SolverFactory
 
 // If you add another regression test, make sure it contains exactly one object, whose name matches the file name.
-- 
GitLab