From 217067ea4acc5f82abaa94c03e86f59d0bf86e51 Mon Sep 17 00:00:00 2001
From: Etienne Kneuss <ekneuss@gmail.com>
Date: Thu, 20 Aug 2015 18:02:20 +0200
Subject: [PATCH] Merge example banks correctly

---
 .../scala/leon/datagen/SolverDataGen.scala     |  2 --
 .../scala/leon/synthesis/ExamplesBank.scala    | 18 ++++++++++++++++--
 .../scala/leon/utils/ModelEnumerator.scala     |  2 +-
 3 files changed, 17 insertions(+), 5 deletions(-)

diff --git a/src/main/scala/leon/datagen/SolverDataGen.scala b/src/main/scala/leon/datagen/SolverDataGen.scala
index dfbf50aa6..78b3753cf 100644
--- a/src/main/scala/leon/datagen/SolverDataGen.scala
+++ b/src/main/scala/leon/datagen/SolverDataGen.scala
@@ -84,8 +84,6 @@ class SolverDataGen(ctx: LeonContext, pgm: Program, sff: ((LeonContext, Program)
       val sf = sff(ctx, pgm1)
       val modelEnum = new ModelEnumerator(ctx, pgm1, sf)
 
-      println("Generating for "+ins.map(_.getType.asString)+", satisfying "+satisfying.asString)
-
       val enum = modelEnum.enumVarying(ins, satisfying, sizeOf, 5)
 
       try {
diff --git a/src/main/scala/leon/synthesis/ExamplesBank.scala b/src/main/scala/leon/synthesis/ExamplesBank.scala
index 1f34fe87d..a0ed2936a 100644
--- a/src/main/scala/leon/synthesis/ExamplesBank.scala
+++ b/src/main/scala/leon/synthesis/ExamplesBank.scala
@@ -50,11 +50,25 @@ case class ExamplesBank(valids: Seq[Example], invalids: Seq[Example]) {
 
   def union(that: ExamplesBank) = {
     ExamplesBank(
-      (this.valids union that.valids).distinct,
-      (this.invalids union that.invalids).distinct
+      distinctIns((this.valids union that.valids)),
+      distinctIns((this.invalids union that.invalids))
     )
   }
 
+  private def distinctIns(s: Seq[Example]): Seq[Example] = {
+    val insOuts = (s.collect {
+      case InOutExample(ins, outs) => ins -> outs
+    }).toMap
+
+    s.map(_.ins).distinct.map {
+      case ins =>
+        insOuts.get(ins) match {
+          case Some(outs) => InOutExample(ins, outs)
+          case _ => InExample(ins)
+        }
+    }
+  }
+
   def map(f: Example => List[Example]) = {
     ExamplesBank(valids.flatMap(f), invalids.flatMap(f))
   }
diff --git a/src/main/scala/leon/utils/ModelEnumerator.scala b/src/main/scala/leon/utils/ModelEnumerator.scala
index 81bacf162..7429c247f 100644
--- a/src/main/scala/leon/utils/ModelEnumerator.scala
+++ b/src/main/scala/leon/utils/ModelEnumerator.scala
@@ -95,7 +95,7 @@ class ModelEnumerator(ctx: LeonContext, pgm: Program, sf: SolverFactory[Solver])
   case object Up   extends SearchDirection
   case object Down extends SearchDirection
 
-  private[this] def enumOptimizing(ids: Seq[Identifier], satisfying: Expr, measure: Expr, dir: SearchDirection): FreeableIterator[Map[Identifier, Expr]] = {
+  private[this] def enumOptimizing(ids: Seq[Identifier], satisfying: Expr, measure: Expr, dir: SearchDirection): Iterator[Map[Identifier, Expr]] = {
     assert(measure.getType == IntegerType)
 
     val s = sf.getNewSolver
-- 
GitLab