From 184b4790032e7e3273ae71bef109cef0ff711db5 Mon Sep 17 00:00:00 2001
From: Etienne Kneuss <ekneuss@gmail.com>
Date: Fri, 14 Aug 2015 00:34:45 +0200
Subject: [PATCH] Introduce model minimization/maximization

- Measure is provided by an user-defined expression of BigInt type
  (e.g. List.size)

- Search for min/max lazily enumerates intermediate models. Users of the
  enumerator may wish to skip them using `.last`, and/or bound the
  search using `.take(N)`.

- Discovery of upper&lower bounds is done with pseudo-exponential
  progression from initial model. Bisection method is then used to
  zero-in on the min/max.
---
 .../test/solvers/ModelEnumerationSuite.scala  |  52 +++++++-
 .../scala/leon/solvers/ModelEnumerator.scala  | 118 +++++++++++++++++-
 .../scala/leon/solvers/SolverFactory.scala    |   2 +-
 3 files changed, 167 insertions(+), 5 deletions(-)

diff --git a/src/integration/scala/leon/test/solvers/ModelEnumerationSuite.scala b/src/integration/scala/leon/test/solvers/ModelEnumerationSuite.scala
index 9d1b16019..2025437ed 100644
--- a/src/integration/scala/leon/test/solvers/ModelEnumerationSuite.scala
+++ b/src/integration/scala/leon/test/solvers/ModelEnumerationSuite.scala
@@ -38,7 +38,8 @@ class ModelEnumeratorSuite extends LeonTestSuiteWithProgram with helpers.Express
   )
 
   def getModelEnum(implicit ctx: LeonContext, pgm: Program) = {
-    new ModelEnumerator(ctx, pgm, SolverFactory.getFromSettings)
+    val sf = SolverFactory.default.asInstanceOf[SolverFactory[IncrementalSolver]]
+    new ModelEnumerator(ctx, pgm, sf)
   }
 
   test("Simple model enumeration 1") { implicit fix =>
@@ -149,4 +150,53 @@ class ModelEnumeratorSuite extends LeonTestSuiteWithProgram with helpers.Express
     }
   }
 
+  test("Maximizing size") { implicit fix =>
+    val tpe = classDef("List1.List").typed
+    val l   = FreshIdentifier("l", tpe)
+
+    val cnstr = LessThan(fcall("List1.size")(l.toVariable), bi(5))
+
+    val car   = fcall("List1.size")(l.toVariable)
+
+    val evaluator = new DefaultEvaluator(fix._1, fix._2)
+    val me = getModelEnum
+
+    try {
+      val models1 = me.enumMaximizing(Seq(l), cnstr, car).take(5).toList
+
+      assert(models1.size < 5, "It took less than 5 models to reach max")
+      assert(evaluator.eval(car, models1.last).result === Some(bi(4)), "Max should be 4")
+
+      val models2 = me.enumMaximizing(Seq(l), BooleanLiteral(true), car).take(4).toList
+
+      assert(models2.size == 4, "Unbounded search yields models")
+      // in 4 steps, it should reach lists of size > 10
+      assert(evaluator.eval(GreaterThan(car, bi(10)), models2.last).result === Some(T), "Progression should be efficient")
+    } finally {
+      me.shutdown()
+    }
+  }
+
+  test("Minimizing size") { implicit fix =>
+    val tpe = classDef("List1.List").typed
+    val l   = FreshIdentifier("l", tpe)
+
+    val cnstr = LessThan(fcall("List1.size")(l.toVariable), bi(5))
+
+    val car   = fcall("List1.size")(l.toVariable)
+
+    val evaluator = new DefaultEvaluator(fix._1, fix._2)
+    val me = getModelEnum
+
+    try {
+      val models1 = me.enumMinimizing(Seq(l), cnstr, car).take(5).toList
+
+      assert(models1.size < 5, "It took less than 5 models to reach min")
+      assert(evaluator.eval(car, models1.last).result === Some(bi(0)), "Min should be 0")
+
+    } finally {
+      me.shutdown()
+    }
+  }
+
 }
diff --git a/src/main/scala/leon/solvers/ModelEnumerator.scala b/src/main/scala/leon/solvers/ModelEnumerator.scala
index b22dc9210..3a5fdc5df 100644
--- a/src/main/scala/leon/solvers/ModelEnumerator.scala
+++ b/src/main/scala/leon/solvers/ModelEnumerator.scala
@@ -10,7 +10,7 @@ import purescala.Types._
 import evaluators._
 
 
-class ModelEnumerator(ctx: LeonContext, pgm: Program, sf: SolverFactory[Solver]) {
+class ModelEnumerator(ctx: LeonContext, pgm: Program, sf: SolverFactory[IncrementalSolver]) {
   private[this] var reclaimPool = List[Solver]()
   private[this] val evaluator = new DefaultEvaluator(ctx, pgm)
 
@@ -29,7 +29,7 @@ class ModelEnumerator(ctx: LeonContext, pgm: Program, sf: SolverFactory[Solver])
     enumVarying0(ids, cnstr, Some(caracteristic), nPerCaracteristic)
   }
 
-  def enumVarying0(ids: Seq[Identifier], cnstr: Expr, caracteristic: Option[Expr], nPerCaracteristic: Int = 1): Iterator[Map[Identifier, Expr]] = {
+  private[this] def enumVarying0(ids: Seq[Identifier], cnstr: Expr, caracteristic: Option[Expr], nPerCaracteristic: Int = 1): Iterator[Map[Identifier, Expr]] = {
     val s = sf.getNewSolver
     reclaimPool ::= s
 
@@ -78,8 +78,120 @@ class ModelEnumerator(ctx: LeonContext, pgm: Program, sf: SolverFactory[Solver])
     }
   }
 
+  def enumMinimizing(ids: Seq[Identifier], cnstr: Expr, measure: Expr) = {
+    enumOptimizing(ids, cnstr, measure, Down)
+  }
+
+  def enumMaximizing(ids: Seq[Identifier], cnstr: Expr, measure: Expr) = {
+    enumOptimizing(ids, cnstr, measure, Up)
+  }
+
+  abstract class SearchDirection
+  case object Up   extends SearchDirection
+  case object Down extends SearchDirection
+
+  private[this] def enumOptimizing(ids: Seq[Identifier], cnstr: Expr, measure: Expr, dir: SearchDirection): Iterator[Map[Identifier, Expr]] = {
+    assert(measure.getType == IntegerType)
+
+    val s = sf.getNewSolver
+    reclaimPool ::= s
+
+    s.assertCnstr(cnstr)
+
+    val mId = FreshIdentifier("measure", measure.getType)
+    s.assertCnstr(Equals(mId.toVariable, measure))
+
+    // Search Range
+    var ub: Option[BigInt] = None
+    var lb: Option[BigInt] = None
+
+    def rangeEmpty() = (lb, ub) match {
+      case (Some(l), Some(u)) => u-l <= 1
+      case _ => false
+    }
+
+    def getPivot(): Option[BigInt] = (lb, ub, dir) match {
+      // Bisection Method
+      case (Some(l), Some(u), _) => Some(l + (u-l)/2)
+      // No bound yet, let the solver find at least one bound
+      case (None, None, _)       => None
+
+      // Increase lower bound
+      case (Some(l), None, Up)   => Some(l + l.abs + 1)
+      // Decrease upper bound
+      case (None, Some(u), Down) => Some(u - u.abs - 1)
+
+      // This shouldn't happen
+      case _ => None
+    }
+
+    def getNext(): Stream[Map[Identifier, Expr]] = {
+      if (rangeEmpty()) {
+        Stream.empty
+      } else {
+        // Assert a new pivot point
+        val thisTry = getPivot().map { t =>
+          s.push()
+          dir match {
+            case Up =>
+              s.assertCnstr(GreaterThan(mId.toVariable, InfiniteIntegerLiteral(t)))
+            case Down =>
+              s.assertCnstr(LessThan(mId.toVariable, InfiniteIntegerLiteral(t)))
+          }
+          t
+        }
+
+        s.check match {
+          case Some(true) =>
+            val sm = s.getModel
+            val m = (ids.map { id =>
+              id -> sm.getOrElse(id, simplestValue(id.getType))
+            }).toMap
+
+            evaluator.eval(measure, m).result match {
+              case Some(InfiniteIntegerLiteral(measureVal)) =>
+                // Positive result
+                dir match {
+                  case Up   => lb = Some(measureVal)
+                  case Down => ub = Some(measureVal)
+                }
+
+                Stream.cons(m, getNext())
+
+              case _ =>
+                ctx.reporter.warning("Evaluator failed to evaluate measure!")
+                Stream.empty
+            }
+
+
+          case Some(false) =>
+            // Negative result
+            thisTry match {
+              case Some(t) =>
+                s.pop()
+
+                dir match {
+                  case Up   => ub = Some(t)
+                  case Down => lb = Some(t)
+                }
+                getNext()
+
+              case None =>
+                Stream.empty
+            }
+
+          case None =>
+            Stream.empty
+        }
+      }
+    }
+
+    getNext().iterator
+  }
+
+
   def shutdown() = {
-    reclaimPool.foreach{sf.reclaim(_)}
+    reclaimPool.foreach(sf.reclaim)
   }
 
 }
diff --git a/src/main/scala/leon/solvers/SolverFactory.scala b/src/main/scala/leon/solvers/SolverFactory.scala
index 90fc54110..56d0ae733 100644
--- a/src/main/scala/leon/solvers/SolverFactory.scala
+++ b/src/main/scala/leon/solvers/SolverFactory.scala
@@ -142,7 +142,7 @@ object SolverFactory {
   }
 
   // Full featured solver used by default
-  def default(ctx: LeonContext, program: Program): SolverFactory[TimeoutSolver] = {
+  def default(implicit ctx: LeonContext, program: Program): SolverFactory[TimeoutSolver] = {
     getFromName(ctx, program)("fairz3")
   }
 
-- 
GitLab