diff --git a/src/main/scala/leon/synthesis/rules/CegisLike.scala b/src/main/scala/leon/synthesis/rules/CegisLike.scala index be01424025cc154e461b44d03f0629c1b3621ab3..372845ce470b8cb1e563672d5d3d9522b3f88017 100644 --- a/src/main/scala/leon/synthesis/rules/CegisLike.scala +++ b/src/main/scala/leon/synthesis/rules/CegisLike.scala @@ -132,14 +132,20 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { def allProgramsFor(cs: Set[Identifier]): Stream[Set[Identifier]] = { val streams = for (c <- cs.toSeq) yield { - val subs = for ((b, _, subcs) <- cTree(c) if isBActive(b); - p <- allProgramsFor(subcs)) yield { + val subs = for ((b, _, subcs) <- cTree(c) if isBActive(b)) yield { - p + b + if (subcs.isEmpty) { + Seq(Set(b)) + } else { + for (p <- allProgramsFor(subcs)) yield { + p + b + } + } } - subs.toStream + subs.flatten.toStream } + StreamUtils.cartesianProduct(streams).map { ls => ls.foldLeft(Set[Identifier]())(_ ++ _) } @@ -601,6 +607,7 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { bs = bs ++ newBs bsOrdered = bs.toSeq.sortBy(_.id) + //debugCExpr(cTree) updateCTree() unfoldedSomething @@ -819,10 +826,13 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { val nInitial = prunedPrograms.size sctx.reporter.debug("#Programs: "+nInitial) //sctx.reporter.ifDebug{ printer => - // for (p <- prunedPrograms.take(10)) { - // printer(" - "+ndProgram.getExpr(p)) + // val limit = 100 + + // for (p <- prunedPrograms.take(limit)) { + // val ps = p.toSeq.sortBy(_.id).mkString(", ") + // printer(f" - $ps%-40s - "+ndProgram.getExpr(p)) // } - // if(nPassing > 10) { + // if(nInitial > limit) { // printer(" - ...") // } //} diff --git a/src/main/scala/leon/utils/StreamUtils.scala b/src/main/scala/leon/utils/StreamUtils.scala index cfc8d574546b76cdae92660be56063ffcaf58636..3af285528e0243cb6858336e3048a48e6c5cf98f 100644 --- a/src/main/scala/leon/utils/StreamUtils.scala +++ b/src/main/scala/leon/utils/StreamUtils.scala @@ -26,15 +26,16 @@ object StreamUtils { if(streams.exists(_.isEmpty)) return Stream.empty - val indices = if(streams.forall(_.hasDefiniteSize)) { - val max = streams.map(_.size).max - diagCount(dimensions).take(max) - } else { - diagCount(dimensions) - } + val indices = diagCount(dimensions) var allReached : Boolean = false - val bounds : Array[Int] = Array.fill(dimensions)(Int.MaxValue) + val bounds : Array[Option[Int]] = for (s <- streams.toArray) yield { + if (s.hasDefiniteSize) { + Some(s.size) + } else { + None + } + } indices.takeWhile(_ => !allReached).flatMap { indexList => var d = 0 @@ -42,7 +43,10 @@ object StreamUtils { var is = indexList var ss = vectorizedStreams.toList - if(indexList.sum >= bounds.max) { + if ((indexList zip bounds).forall { + case (i, Some(b)) => i >= b + case _ => false + }) { allReached = true } @@ -50,7 +54,7 @@ object StreamUtils { while(continue && d < dimensions) { var i = is.head - if(i > bounds(d)) { + if(bounds(d).map(i > _).getOrElse(false)) { continue = false } else try { // TODO can we speed up by caching the random access into @@ -62,7 +66,7 @@ object StreamUtils { d += 1 } catch { case e : IndexOutOfBoundsException => - bounds(d) = i - 1 + bounds(d) = Some(i - 1) continue = false } } diff --git a/src/test/scala/leon/test/utils/Streams.scala b/src/test/scala/leon/test/utils/Streams.scala new file mode 100644 index 0000000000000000000000000000000000000000..1b6a094a758fe1a631704c1483524ce9998271a1 --- /dev/null +++ b/src/test/scala/leon/test/utils/Streams.scala @@ -0,0 +1,78 @@ +/* Copyright 2009-2014 EPFL, Lausanne */ + +package leon.test.purescala + +import leon._ +import leon.test._ +import leon.utils.{TemporaryInputPhase, PreprocessingPhase} +import leon.frontends.scalac.ExtractionPhase + +import leon.purescala.Common._ +import leon.purescala.Trees._ +import leon.purescala.Definitions._ +import leon.purescala.TypeTrees._ +import leon.datagen._ +import leon.utils.StreamUtils._ + +import leon.evaluators._ + +import org.scalatest.FunSuite + +class Streams extends LeonTestSuite { + test("Cartesian Product 1") { + val s1 = FreshIdentifier("B", true) #:: + FreshIdentifier("B", true) #:: + FreshIdentifier("B", true) #:: + FreshIdentifier("B", true) #:: Stream.empty; + + val s2 = FreshIdentifier("B", true) #:: + FreshIdentifier("B", true) #:: + FreshIdentifier("B", true) #:: + FreshIdentifier("B", true) #:: Stream.empty; + + val ss = cartesianProduct(List(s1, s2)) + + assert(ss.size === s1.size * s2.size) + + + } + + test("Cartesian Product 2") { + val s1 = FreshIdentifier("B", true) #:: + FreshIdentifier("B", true) #:: + FreshIdentifier("B", true) #:: + FreshIdentifier("B", true) #:: Stream.empty; + + val s2 = FreshIdentifier("B", true) #:: + FreshIdentifier("B", true) #:: + FreshIdentifier("B", true) #:: + FreshIdentifier("B", true) #:: Stream.empty; + + val tmp1 = s1.mkString + val tmp2 = s2.mkString + + val ss = cartesianProduct(List(s1, s2)) + + assert(ss.size === s1.size * s2.size) + } + + + test("Cartesian Product 3") { + val s1 = 1 #:: + 2 #:: + 3 #:: + 4 #:: Stream.empty; + + val s2 = 5 #:: + 6 #:: + 7 #:: + 8 #:: Stream.empty; + + val tmp1 = s1.mkString + val tmp2 = s2.mkString + + val ss = cartesianProduct(List(s1, s2)) + + assert(ss.size === s1.size * s2.size) + } +}