From 217067ea4acc5f82abaa94c03e86f59d0bf86e51 Mon Sep 17 00:00:00 2001 From: Etienne Kneuss <ekneuss@gmail.com> Date: Thu, 20 Aug 2015 18:02:20 +0200 Subject: [PATCH] Merge example banks correctly --- .../scala/leon/datagen/SolverDataGen.scala | 2 -- .../scala/leon/synthesis/ExamplesBank.scala | 18 ++++++++++++++++-- .../scala/leon/utils/ModelEnumerator.scala | 2 +- 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/src/main/scala/leon/datagen/SolverDataGen.scala b/src/main/scala/leon/datagen/SolverDataGen.scala index dfbf50aa6..78b3753cf 100644 --- a/src/main/scala/leon/datagen/SolverDataGen.scala +++ b/src/main/scala/leon/datagen/SolverDataGen.scala @@ -84,8 +84,6 @@ class SolverDataGen(ctx: LeonContext, pgm: Program, sff: ((LeonContext, Program) val sf = sff(ctx, pgm1) val modelEnum = new ModelEnumerator(ctx, pgm1, sf) - println("Generating for "+ins.map(_.getType.asString)+", satisfying "+satisfying.asString) - val enum = modelEnum.enumVarying(ins, satisfying, sizeOf, 5) try { diff --git a/src/main/scala/leon/synthesis/ExamplesBank.scala b/src/main/scala/leon/synthesis/ExamplesBank.scala index 1f34fe87d..a0ed2936a 100644 --- a/src/main/scala/leon/synthesis/ExamplesBank.scala +++ b/src/main/scala/leon/synthesis/ExamplesBank.scala @@ -50,11 +50,25 @@ case class ExamplesBank(valids: Seq[Example], invalids: Seq[Example]) { def union(that: ExamplesBank) = { ExamplesBank( - (this.valids union that.valids).distinct, - (this.invalids union that.invalids).distinct + distinctIns((this.valids union that.valids)), + distinctIns((this.invalids union that.invalids)) ) } + private def distinctIns(s: Seq[Example]): Seq[Example] = { + val insOuts = (s.collect { + case InOutExample(ins, outs) => ins -> outs + }).toMap + + s.map(_.ins).distinct.map { + case ins => + insOuts.get(ins) match { + case Some(outs) => InOutExample(ins, outs) + case _ => InExample(ins) + } + } + } + def map(f: Example => List[Example]) = { ExamplesBank(valids.flatMap(f), invalids.flatMap(f)) } diff --git a/src/main/scala/leon/utils/ModelEnumerator.scala b/src/main/scala/leon/utils/ModelEnumerator.scala index 81bacf162..7429c247f 100644 --- a/src/main/scala/leon/utils/ModelEnumerator.scala +++ b/src/main/scala/leon/utils/ModelEnumerator.scala @@ -95,7 +95,7 @@ class ModelEnumerator(ctx: LeonContext, pgm: Program, sf: SolverFactory[Solver]) case object Up extends SearchDirection case object Down extends SearchDirection - private[this] def enumOptimizing(ids: Seq[Identifier], satisfying: Expr, measure: Expr, dir: SearchDirection): FreeableIterator[Map[Identifier, Expr]] = { + private[this] def enumOptimizing(ids: Seq[Identifier], satisfying: Expr, measure: Expr, dir: SearchDirection): Iterator[Map[Identifier, Expr]] = { assert(measure.getType == IntegerType) val s = sf.getNewSolver -- GitLab