From b53f54890e8fd422676c26aec239ad91e82a0bca Mon Sep 17 00:00:00 2001
From: Etienne Kneuss <colder@php.net>
Date: Thu, 4 Apr 2013 15:21:02 +0200
Subject: [PATCH] Improved, cached, datagen

---
 src/main/scala/leon/purescala/DataGen.scala | 53 +++++++++++++++++----
 1 file changed, 44 insertions(+), 9 deletions(-)

diff --git a/src/main/scala/leon/purescala/DataGen.scala b/src/main/scala/leon/purescala/DataGen.scala
index 53a2aa7a8..0ace125ba 100644
--- a/src/main/scala/leon/purescala/DataGen.scala
+++ b/src/main/scala/leon/purescala/DataGen.scala
@@ -16,21 +16,55 @@ import scala.collection.mutable.{Map=>MutableMap}
   * e.g. by passing trees representing variables for the "bounds". */
 object DataGen {
   private val defaultBounds : Map[TypeTree,Seq[Expr]] = Map(
-    Int32Type -> Seq(IntLiteral(0), IntLiteral(1), IntLiteral(-1), IntLiteral(2))
+    Int32Type -> Seq(IntLiteral(0), IntLiteral(1), IntLiteral(-1))
   )
 
   private val boolStream : Stream[Expr] =
     Stream.cons(BooleanLiteral(true), Stream.cons(BooleanLiteral(false), Stream.empty))
 
+  class VectorizedStream[T](initial : Stream[T]) {
+    private def mkException(i : Int) = new IndexOutOfBoundsException("Can't access VectorizedStream at : " + i)
+    private def streamHeadIndex : Int = indexed.size
+    private var stream  : Stream[T] = initial
+    private var indexed : Vector[T] = Vector.empty
+
+    def apply(index : Int) : T = {
+      if(index < streamHeadIndex) {
+        indexed(index)
+      } else {
+        val diff = index - streamHeadIndex // diff >= 0
+        var i = 0
+        while(i < diff) {
+          if(stream.isEmpty) throw mkException(index)
+          indexed = indexed :+ stream.head
+          stream  = stream.tail
+          i += 1
+        }
+        // The trick is *not* to read past the desired element. Leave it in the
+        // stream, or it will force the *following* one...
+        stream.headOption.getOrElse { throw mkException(index) }
+      }
+    }
+  }
+
   private def intStream : Stream[Expr] = Stream.cons(IntLiteral(0), intStream0(1))
   private def intStream0(n : Int) : Stream[Expr] = Stream.cons(IntLiteral(n), intStream0(if(n > 0) -n else -(n-1)))
 
-  private def natStream : Stream[Expr] = natStream0(0)
+  def natStream : Stream[Expr] = natStream0(0)
   private def natStream0(n : Int) : Stream[Expr] = Stream.cons(IntLiteral(n), natStream0(n+1))
 
-  // TODO can we cache something, maybe? It seems like every type should correspond to a unique stream?
-  // We should make sure the cache depends on the bounds (i.e. is not reused for different bounds.)
-  def generate(tpe : TypeTree, bounds : Map[TypeTree,Seq[Expr]] = defaultBounds) : Stream[Expr] = bounds.get(tpe).map(_.toStream).getOrElse {
+  private val streamCache : MutableMap[TypeTree,Stream[Expr]] = MutableMap.empty
+
+  def generate(tpe : TypeTree, bounds : Map[TypeTree,Seq[Expr]] = defaultBounds) : Stream[Expr] = {
+    streamCache.getOrElse(tpe, {
+      val s = generate0(tpe, bounds)
+      streamCache(tpe) = s
+      s
+    })
+  }
+
+  // TODO We should make sure the cache depends on the bounds (i.e. is not reused for different bounds.)
+  private def generate0(tpe : TypeTree, bounds : Map[TypeTree,Seq[Expr]]) : Stream[Expr] = bounds.get(tpe).map(_.toStream).getOrElse {
     tpe match {
       case BooleanType =>
         boolStream
@@ -75,15 +109,15 @@ object DataGen {
     }
   }
 
-  def findModels(expr : Expr, evaluator : Evaluator, maxModels : Int, maxTries : Int, bounds : Map[TypeTree,Seq[Expr]] = defaultBounds) : Stream[Map[Identifier,Expr]] = {
-    val freeVars : Seq[Identifier] = variablesOf(expr).toSeq
+  def findModels(expr : Expr, evaluator : Evaluator, maxModels : Int, maxTries : Int, bounds : Map[TypeTree,Seq[Expr]] = defaultBounds, forcedFreeVars: Option[Seq[Identifier]] = None) : Stream[Map[Identifier,Expr]] = {
+    val freeVars : Seq[Identifier] = forcedFreeVars.getOrElse(variablesOf(expr).toSeq)
 
     evaluator.compile(expr, freeVars).map { evalFun =>
       val sat = EvaluationResults.Successful(BooleanLiteral(true))
 
       naryProduct(freeVars.map(id => generate(id.getType, bounds)))
         .take(maxTries)
-        .filter(s => evalFun(s) == sat)
+        .filter{s => evalFun(s) == sat }
         .take(maxModels)
         .map(s => freeVars.zip(s).toMap)
 
@@ -111,6 +145,7 @@ object DataGen {
   // Takes a series of streams and enumerates their cartesian product.
   private def naryProduct[T](streams : Seq[Stream[T]]) : Stream[List[T]] = {
     val dimensions = streams.size
+    val vectorizedStreams = streams.map(new VectorizedStream(_))
 
     if(dimensions == 0)
       return Stream.cons(Nil, Stream.empty)
@@ -132,7 +167,7 @@ object DataGen {
       var d = 0
       var continue = true
       var is = indexList
-      var ss = streams.toList
+      var ss = vectorizedStreams.toList
 
       if(indexList.sum >= bounds.max) {
         allReached = true
-- 
GitLab