Skip to content
Snippets Groups Projects
Commit b53f5489 authored by Etienne Kneuss's avatar Etienne Kneuss Committed by Philippe Suter
Browse files

Improved, cached, datagen

parent 8b8addf3
No related branches found
No related tags found
No related merge requests found
...@@ -16,21 +16,55 @@ import scala.collection.mutable.{Map=>MutableMap} ...@@ -16,21 +16,55 @@ import scala.collection.mutable.{Map=>MutableMap}
* e.g. by passing trees representing variables for the "bounds". */ * e.g. by passing trees representing variables for the "bounds". */
object DataGen { object DataGen {
private val defaultBounds : Map[TypeTree,Seq[Expr]] = Map( private val defaultBounds : Map[TypeTree,Seq[Expr]] = Map(
Int32Type -> Seq(IntLiteral(0), IntLiteral(1), IntLiteral(-1), IntLiteral(2)) Int32Type -> Seq(IntLiteral(0), IntLiteral(1), IntLiteral(-1))
) )
private val boolStream : Stream[Expr] = private val boolStream : Stream[Expr] =
Stream.cons(BooleanLiteral(true), Stream.cons(BooleanLiteral(false), Stream.empty)) Stream.cons(BooleanLiteral(true), Stream.cons(BooleanLiteral(false), Stream.empty))
class VectorizedStream[T](initial : Stream[T]) {
private def mkException(i : Int) = new IndexOutOfBoundsException("Can't access VectorizedStream at : " + i)
private def streamHeadIndex : Int = indexed.size
private var stream : Stream[T] = initial
private var indexed : Vector[T] = Vector.empty
def apply(index : Int) : T = {
if(index < streamHeadIndex) {
indexed(index)
} else {
val diff = index - streamHeadIndex // diff >= 0
var i = 0
while(i < diff) {
if(stream.isEmpty) throw mkException(index)
indexed = indexed :+ stream.head
stream = stream.tail
i += 1
}
// The trick is *not* to read past the desired element. Leave it in the
// stream, or it will force the *following* one...
stream.headOption.getOrElse { throw mkException(index) }
}
}
}
private def intStream : Stream[Expr] = Stream.cons(IntLiteral(0), intStream0(1)) private def intStream : Stream[Expr] = Stream.cons(IntLiteral(0), intStream0(1))
private def intStream0(n : Int) : Stream[Expr] = Stream.cons(IntLiteral(n), intStream0(if(n > 0) -n else -(n-1))) private def intStream0(n : Int) : Stream[Expr] = Stream.cons(IntLiteral(n), intStream0(if(n > 0) -n else -(n-1)))
private def natStream : Stream[Expr] = natStream0(0) def natStream : Stream[Expr] = natStream0(0)
private def natStream0(n : Int) : Stream[Expr] = Stream.cons(IntLiteral(n), natStream0(n+1)) private def natStream0(n : Int) : Stream[Expr] = Stream.cons(IntLiteral(n), natStream0(n+1))
// TODO can we cache something, maybe? It seems like every type should correspond to a unique stream? private val streamCache : MutableMap[TypeTree,Stream[Expr]] = MutableMap.empty
// We should make sure the cache depends on the bounds (i.e. is not reused for different bounds.)
def generate(tpe : TypeTree, bounds : Map[TypeTree,Seq[Expr]] = defaultBounds) : Stream[Expr] = bounds.get(tpe).map(_.toStream).getOrElse { def generate(tpe : TypeTree, bounds : Map[TypeTree,Seq[Expr]] = defaultBounds) : Stream[Expr] = {
streamCache.getOrElse(tpe, {
val s = generate0(tpe, bounds)
streamCache(tpe) = s
s
})
}
// TODO We should make sure the cache depends on the bounds (i.e. is not reused for different bounds.)
private def generate0(tpe : TypeTree, bounds : Map[TypeTree,Seq[Expr]]) : Stream[Expr] = bounds.get(tpe).map(_.toStream).getOrElse {
tpe match { tpe match {
case BooleanType => case BooleanType =>
boolStream boolStream
...@@ -75,15 +109,15 @@ object DataGen { ...@@ -75,15 +109,15 @@ object DataGen {
} }
} }
def findModels(expr : Expr, evaluator : Evaluator, maxModels : Int, maxTries : Int, bounds : Map[TypeTree,Seq[Expr]] = defaultBounds) : Stream[Map[Identifier,Expr]] = { def findModels(expr : Expr, evaluator : Evaluator, maxModels : Int, maxTries : Int, bounds : Map[TypeTree,Seq[Expr]] = defaultBounds, forcedFreeVars: Option[Seq[Identifier]] = None) : Stream[Map[Identifier,Expr]] = {
val freeVars : Seq[Identifier] = variablesOf(expr).toSeq val freeVars : Seq[Identifier] = forcedFreeVars.getOrElse(variablesOf(expr).toSeq)
evaluator.compile(expr, freeVars).map { evalFun => evaluator.compile(expr, freeVars).map { evalFun =>
val sat = EvaluationResults.Successful(BooleanLiteral(true)) val sat = EvaluationResults.Successful(BooleanLiteral(true))
naryProduct(freeVars.map(id => generate(id.getType, bounds))) naryProduct(freeVars.map(id => generate(id.getType, bounds)))
.take(maxTries) .take(maxTries)
.filter(s => evalFun(s) == sat) .filter{s => evalFun(s) == sat }
.take(maxModels) .take(maxModels)
.map(s => freeVars.zip(s).toMap) .map(s => freeVars.zip(s).toMap)
...@@ -111,6 +145,7 @@ object DataGen { ...@@ -111,6 +145,7 @@ object DataGen {
// Takes a series of streams and enumerates their cartesian product. // Takes a series of streams and enumerates their cartesian product.
private def naryProduct[T](streams : Seq[Stream[T]]) : Stream[List[T]] = { private def naryProduct[T](streams : Seq[Stream[T]]) : Stream[List[T]] = {
val dimensions = streams.size val dimensions = streams.size
val vectorizedStreams = streams.map(new VectorizedStream(_))
if(dimensions == 0) if(dimensions == 0)
return Stream.cons(Nil, Stream.empty) return Stream.cons(Nil, Stream.empty)
...@@ -132,7 +167,7 @@ object DataGen { ...@@ -132,7 +167,7 @@ object DataGen {
var d = 0 var d = 0
var continue = true var continue = true
var is = indexList var is = indexList
var ss = streams.toList var ss = vectorizedStreams.toList
if(indexList.sum >= bounds.max) { if(indexList.sum >= bounds.max) {
allReached = true allReached = true
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment