From b53f54890e8fd422676c26aec239ad91e82a0bca Mon Sep 17 00:00:00 2001 From: Etienne Kneuss <colder@php.net> Date: Thu, 4 Apr 2013 15:21:02 +0200 Subject: [PATCH] Improved, cached, datagen --- src/main/scala/leon/purescala/DataGen.scala | 53 +++++++++++++++++---- 1 file changed, 44 insertions(+), 9 deletions(-) diff --git a/src/main/scala/leon/purescala/DataGen.scala b/src/main/scala/leon/purescala/DataGen.scala index 53a2aa7a8..0ace125ba 100644 --- a/src/main/scala/leon/purescala/DataGen.scala +++ b/src/main/scala/leon/purescala/DataGen.scala @@ -16,21 +16,55 @@ import scala.collection.mutable.{Map=>MutableMap} * e.g. by passing trees representing variables for the "bounds". */ object DataGen { private val defaultBounds : Map[TypeTree,Seq[Expr]] = Map( - Int32Type -> Seq(IntLiteral(0), IntLiteral(1), IntLiteral(-1), IntLiteral(2)) + Int32Type -> Seq(IntLiteral(0), IntLiteral(1), IntLiteral(-1)) ) private val boolStream : Stream[Expr] = Stream.cons(BooleanLiteral(true), Stream.cons(BooleanLiteral(false), Stream.empty)) + class VectorizedStream[T](initial : Stream[T]) { + private def mkException(i : Int) = new IndexOutOfBoundsException("Can't access VectorizedStream at : " + i) + private def streamHeadIndex : Int = indexed.size + private var stream : Stream[T] = initial + private var indexed : Vector[T] = Vector.empty + + def apply(index : Int) : T = { + if(index < streamHeadIndex) { + indexed(index) + } else { + val diff = index - streamHeadIndex // diff >= 0 + var i = 0 + while(i < diff) { + if(stream.isEmpty) throw mkException(index) + indexed = indexed :+ stream.head + stream = stream.tail + i += 1 + } + // The trick is *not* to read past the desired element. Leave it in the + // stream, or it will force the *following* one... + stream.headOption.getOrElse { throw mkException(index) } + } + } + } + private def intStream : Stream[Expr] = Stream.cons(IntLiteral(0), intStream0(1)) private def intStream0(n : Int) : Stream[Expr] = Stream.cons(IntLiteral(n), intStream0(if(n > 0) -n else -(n-1))) - private def natStream : Stream[Expr] = natStream0(0) + def natStream : Stream[Expr] = natStream0(0) private def natStream0(n : Int) : Stream[Expr] = Stream.cons(IntLiteral(n), natStream0(n+1)) - // TODO can we cache something, maybe? It seems like every type should correspond to a unique stream? - // We should make sure the cache depends on the bounds (i.e. is not reused for different bounds.) - def generate(tpe : TypeTree, bounds : Map[TypeTree,Seq[Expr]] = defaultBounds) : Stream[Expr] = bounds.get(tpe).map(_.toStream).getOrElse { + private val streamCache : MutableMap[TypeTree,Stream[Expr]] = MutableMap.empty + + def generate(tpe : TypeTree, bounds : Map[TypeTree,Seq[Expr]] = defaultBounds) : Stream[Expr] = { + streamCache.getOrElse(tpe, { + val s = generate0(tpe, bounds) + streamCache(tpe) = s + s + }) + } + + // TODO We should make sure the cache depends on the bounds (i.e. is not reused for different bounds.) + private def generate0(tpe : TypeTree, bounds : Map[TypeTree,Seq[Expr]]) : Stream[Expr] = bounds.get(tpe).map(_.toStream).getOrElse { tpe match { case BooleanType => boolStream @@ -75,15 +109,15 @@ object DataGen { } } - def findModels(expr : Expr, evaluator : Evaluator, maxModels : Int, maxTries : Int, bounds : Map[TypeTree,Seq[Expr]] = defaultBounds) : Stream[Map[Identifier,Expr]] = { - val freeVars : Seq[Identifier] = variablesOf(expr).toSeq + def findModels(expr : Expr, evaluator : Evaluator, maxModels : Int, maxTries : Int, bounds : Map[TypeTree,Seq[Expr]] = defaultBounds, forcedFreeVars: Option[Seq[Identifier]] = None) : Stream[Map[Identifier,Expr]] = { + val freeVars : Seq[Identifier] = forcedFreeVars.getOrElse(variablesOf(expr).toSeq) evaluator.compile(expr, freeVars).map { evalFun => val sat = EvaluationResults.Successful(BooleanLiteral(true)) naryProduct(freeVars.map(id => generate(id.getType, bounds))) .take(maxTries) - .filter(s => evalFun(s) == sat) + .filter{s => evalFun(s) == sat } .take(maxModels) .map(s => freeVars.zip(s).toMap) @@ -111,6 +145,7 @@ object DataGen { // Takes a series of streams and enumerates their cartesian product. private def naryProduct[T](streams : Seq[Stream[T]]) : Stream[List[T]] = { val dimensions = streams.size + val vectorizedStreams = streams.map(new VectorizedStream(_)) if(dimensions == 0) return Stream.cons(Nil, Stream.empty) @@ -132,7 +167,7 @@ object DataGen { var d = 0 var continue = true var is = indexList - var ss = streams.toList + var ss = vectorizedStreams.toList if(indexList.sum >= bounds.max) { allReached = true -- GitLab