From c5af979bc19761403961cce0f29cd196d7dd2cf9 Mon Sep 17 00:00:00 2001 From: Etienne Kneuss <etienne.kneuss@epfl.ch> Date: Mon, 19 Jan 2015 14:33:57 +0100 Subject: [PATCH] Fix cartesian product --- .../leon/synthesis/rules/CegisLike.scala | 24 ++++-- src/main/scala/leon/utils/StreamUtils.scala | 24 +++--- src/test/scala/leon/test/utils/Streams.scala | 78 +++++++++++++++++++ 3 files changed, 109 insertions(+), 17 deletions(-) create mode 100644 src/test/scala/leon/test/utils/Streams.scala diff --git a/src/main/scala/leon/synthesis/rules/CegisLike.scala b/src/main/scala/leon/synthesis/rules/CegisLike.scala index be0142402..372845ce4 100644 --- a/src/main/scala/leon/synthesis/rules/CegisLike.scala +++ b/src/main/scala/leon/synthesis/rules/CegisLike.scala @@ -132,14 +132,20 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { def allProgramsFor(cs: Set[Identifier]): Stream[Set[Identifier]] = { val streams = for (c <- cs.toSeq) yield { - val subs = for ((b, _, subcs) <- cTree(c) if isBActive(b); - p <- allProgramsFor(subcs)) yield { + val subs = for ((b, _, subcs) <- cTree(c) if isBActive(b)) yield { - p + b + if (subcs.isEmpty) { + Seq(Set(b)) + } else { + for (p <- allProgramsFor(subcs)) yield { + p + b + } + } } - subs.toStream + subs.flatten.toStream } + StreamUtils.cartesianProduct(streams).map { ls => ls.foldLeft(Set[Identifier]())(_ ++ _) } @@ -601,6 +607,7 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { bs = bs ++ newBs bsOrdered = bs.toSeq.sortBy(_.id) + //debugCExpr(cTree) updateCTree() unfoldedSomething @@ -819,10 +826,13 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { val nInitial = prunedPrograms.size sctx.reporter.debug("#Programs: "+nInitial) //sctx.reporter.ifDebug{ printer => - // for (p <- prunedPrograms.take(10)) { - // printer(" - "+ndProgram.getExpr(p)) + // val limit = 100 + + // for (p <- prunedPrograms.take(limit)) { + // val ps = p.toSeq.sortBy(_.id).mkString(", ") + // printer(f" - $ps%-40s - "+ndProgram.getExpr(p)) // } - // if(nPassing > 10) { + // if(nInitial > limit) { // printer(" - ...") // } //} diff --git a/src/main/scala/leon/utils/StreamUtils.scala b/src/main/scala/leon/utils/StreamUtils.scala index cfc8d5745..3af285528 100644 --- a/src/main/scala/leon/utils/StreamUtils.scala +++ b/src/main/scala/leon/utils/StreamUtils.scala @@ -26,15 +26,16 @@ object StreamUtils { if(streams.exists(_.isEmpty)) return Stream.empty - val indices = if(streams.forall(_.hasDefiniteSize)) { - val max = streams.map(_.size).max - diagCount(dimensions).take(max) - } else { - diagCount(dimensions) - } + val indices = diagCount(dimensions) var allReached : Boolean = false - val bounds : Array[Int] = Array.fill(dimensions)(Int.MaxValue) + val bounds : Array[Option[Int]] = for (s <- streams.toArray) yield { + if (s.hasDefiniteSize) { + Some(s.size) + } else { + None + } + } indices.takeWhile(_ => !allReached).flatMap { indexList => var d = 0 @@ -42,7 +43,10 @@ object StreamUtils { var is = indexList var ss = vectorizedStreams.toList - if(indexList.sum >= bounds.max) { + if ((indexList zip bounds).forall { + case (i, Some(b)) => i >= b + case _ => false + }) { allReached = true } @@ -50,7 +54,7 @@ object StreamUtils { while(continue && d < dimensions) { var i = is.head - if(i > bounds(d)) { + if(bounds(d).map(i > _).getOrElse(false)) { continue = false } else try { // TODO can we speed up by caching the random access into @@ -62,7 +66,7 @@ object StreamUtils { d += 1 } catch { case e : IndexOutOfBoundsException => - bounds(d) = i - 1 + bounds(d) = Some(i - 1) continue = false } } diff --git a/src/test/scala/leon/test/utils/Streams.scala b/src/test/scala/leon/test/utils/Streams.scala new file mode 100644 index 000000000..1b6a094a7 --- /dev/null +++ b/src/test/scala/leon/test/utils/Streams.scala @@ -0,0 +1,78 @@ +/* Copyright 2009-2014 EPFL, Lausanne */ + +package leon.test.purescala + +import leon._ +import leon.test._ +import leon.utils.{TemporaryInputPhase, PreprocessingPhase} +import leon.frontends.scalac.ExtractionPhase + +import leon.purescala.Common._ +import leon.purescala.Trees._ +import leon.purescala.Definitions._ +import leon.purescala.TypeTrees._ +import leon.datagen._ +import leon.utils.StreamUtils._ + +import leon.evaluators._ + +import org.scalatest.FunSuite + +class Streams extends LeonTestSuite { + test("Cartesian Product 1") { + val s1 = FreshIdentifier("B", true) #:: + FreshIdentifier("B", true) #:: + FreshIdentifier("B", true) #:: + FreshIdentifier("B", true) #:: Stream.empty; + + val s2 = FreshIdentifier("B", true) #:: + FreshIdentifier("B", true) #:: + FreshIdentifier("B", true) #:: + FreshIdentifier("B", true) #:: Stream.empty; + + val ss = cartesianProduct(List(s1, s2)) + + assert(ss.size === s1.size * s2.size) + + + } + + test("Cartesian Product 2") { + val s1 = FreshIdentifier("B", true) #:: + FreshIdentifier("B", true) #:: + FreshIdentifier("B", true) #:: + FreshIdentifier("B", true) #:: Stream.empty; + + val s2 = FreshIdentifier("B", true) #:: + FreshIdentifier("B", true) #:: + FreshIdentifier("B", true) #:: + FreshIdentifier("B", true) #:: Stream.empty; + + val tmp1 = s1.mkString + val tmp2 = s2.mkString + + val ss = cartesianProduct(List(s1, s2)) + + assert(ss.size === s1.size * s2.size) + } + + + test("Cartesian Product 3") { + val s1 = 1 #:: + 2 #:: + 3 #:: + 4 #:: Stream.empty; + + val s2 = 5 #:: + 6 #:: + 7 #:: + 8 #:: Stream.empty; + + val tmp1 = s1.mkString + val tmp2 = s2.mkString + + val ss = cartesianProduct(List(s1, s2)) + + assert(ss.size === s1.size * s2.size) + } +} -- GitLab