Skip to content
Snippets Groups Projects
Commit b53f5489 authored by Etienne Kneuss's avatar Etienne Kneuss Committed by Philippe Suter
Browse files

Improved, cached, datagen

parent 8b8addf3
No related branches found
No related tags found
No related merge requests found
......@@ -16,21 +16,55 @@ import scala.collection.mutable.{Map=>MutableMap}
* e.g. by passing trees representing variables for the "bounds". */
object DataGen {
private val defaultBounds : Map[TypeTree,Seq[Expr]] = Map(
Int32Type -> Seq(IntLiteral(0), IntLiteral(1), IntLiteral(-1), IntLiteral(2))
Int32Type -> Seq(IntLiteral(0), IntLiteral(1), IntLiteral(-1))
)
private val boolStream : Stream[Expr] =
Stream.cons(BooleanLiteral(true), Stream.cons(BooleanLiteral(false), Stream.empty))
class VectorizedStream[T](initial : Stream[T]) {
private def mkException(i : Int) = new IndexOutOfBoundsException("Can't access VectorizedStream at : " + i)
private def streamHeadIndex : Int = indexed.size
private var stream : Stream[T] = initial
private var indexed : Vector[T] = Vector.empty
def apply(index : Int) : T = {
if(index < streamHeadIndex) {
indexed(index)
} else {
val diff = index - streamHeadIndex // diff >= 0
var i = 0
while(i < diff) {
if(stream.isEmpty) throw mkException(index)
indexed = indexed :+ stream.head
stream = stream.tail
i += 1
}
// The trick is *not* to read past the desired element. Leave it in the
// stream, or it will force the *following* one...
stream.headOption.getOrElse { throw mkException(index) }
}
}
}
private def intStream : Stream[Expr] = Stream.cons(IntLiteral(0), intStream0(1))
private def intStream0(n : Int) : Stream[Expr] = Stream.cons(IntLiteral(n), intStream0(if(n > 0) -n else -(n-1)))
private def natStream : Stream[Expr] = natStream0(0)
def natStream : Stream[Expr] = natStream0(0)
private def natStream0(n : Int) : Stream[Expr] = Stream.cons(IntLiteral(n), natStream0(n+1))
// TODO can we cache something, maybe? It seems like every type should correspond to a unique stream?
// We should make sure the cache depends on the bounds (i.e. is not reused for different bounds.)
def generate(tpe : TypeTree, bounds : Map[TypeTree,Seq[Expr]] = defaultBounds) : Stream[Expr] = bounds.get(tpe).map(_.toStream).getOrElse {
private val streamCache : MutableMap[TypeTree,Stream[Expr]] = MutableMap.empty
def generate(tpe : TypeTree, bounds : Map[TypeTree,Seq[Expr]] = defaultBounds) : Stream[Expr] = {
streamCache.getOrElse(tpe, {
val s = generate0(tpe, bounds)
streamCache(tpe) = s
s
})
}
// TODO We should make sure the cache depends on the bounds (i.e. is not reused for different bounds.)
private def generate0(tpe : TypeTree, bounds : Map[TypeTree,Seq[Expr]]) : Stream[Expr] = bounds.get(tpe).map(_.toStream).getOrElse {
tpe match {
case BooleanType =>
boolStream
......@@ -75,15 +109,15 @@ object DataGen {
}
}
def findModels(expr : Expr, evaluator : Evaluator, maxModels : Int, maxTries : Int, bounds : Map[TypeTree,Seq[Expr]] = defaultBounds) : Stream[Map[Identifier,Expr]] = {
val freeVars : Seq[Identifier] = variablesOf(expr).toSeq
def findModels(expr : Expr, evaluator : Evaluator, maxModels : Int, maxTries : Int, bounds : Map[TypeTree,Seq[Expr]] = defaultBounds, forcedFreeVars: Option[Seq[Identifier]] = None) : Stream[Map[Identifier,Expr]] = {
val freeVars : Seq[Identifier] = forcedFreeVars.getOrElse(variablesOf(expr).toSeq)
evaluator.compile(expr, freeVars).map { evalFun =>
val sat = EvaluationResults.Successful(BooleanLiteral(true))
naryProduct(freeVars.map(id => generate(id.getType, bounds)))
.take(maxTries)
.filter(s => evalFun(s) == sat)
.filter{s => evalFun(s) == sat }
.take(maxModels)
.map(s => freeVars.zip(s).toMap)
......@@ -111,6 +145,7 @@ object DataGen {
// Takes a series of streams and enumerates their cartesian product.
private def naryProduct[T](streams : Seq[Stream[T]]) : Stream[List[T]] = {
val dimensions = streams.size
val vectorizedStreams = streams.map(new VectorizedStream(_))
if(dimensions == 0)
return Stream.cons(Nil, Stream.empty)
......@@ -132,7 +167,7 @@ object DataGen {
var d = 0
var continue = true
var is = indexList
var ss = streams.toList
var ss = vectorizedStreams.toList
if(indexList.sum >= bounds.max) {
allReached = true
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment