From b7493d67222f2227105ee9d4990ba301cd65109d Mon Sep 17 00:00:00 2001
From: Nicolas Voirol <voirol.nicolas@gmail.com>
Date: Mon, 8 Aug 2016 16:49:48 +0200
Subject: [PATCH] Added Choose expression

---
 src/main/scala/inox/ast/DSL.scala                   |  3 +++
 src/main/scala/inox/ast/ExprOps.scala               |  3 ++-
 src/main/scala/inox/ast/Expressions.scala           |  8 ++++++--
 src/main/scala/inox/ast/Extractors.scala            |  2 ++
 src/main/scala/inox/ast/Printers.scala              |  5 ++++-
 src/main/scala/inox/ast/TreeOps.scala               | 13 +++++++++++++
 .../scala/inox/evaluators/RecursiveEvaluator.scala  |  2 ++
 .../scala/inox/evaluators/SolvingEvaluator.scala    | 11 +++++------
 .../inox/solvers/unrolling/TemplateGenerator.scala  |  8 ++++++++
 9 files changed, 45 insertions(+), 10 deletions(-)

diff --git a/src/main/scala/inox/ast/DSL.scala b/src/main/scala/inox/ast/DSL.scala
index 052abc20e..edb08f59b 100644
--- a/src/main/scala/inox/ast/DSL.scala
+++ b/src/main/scala/inox/ast/DSL.scala
@@ -180,6 +180,9 @@ trait DSL {
       body(vd1.toVariable, vd2.toVariable, vd3.toVariable, vd4.toVariable))
   }
 
+  // Choose
+  def choose(res: ValDef)(pred: Variable => Expr) = Choose(res, pred(res.toVariable))
+
   // Block-like
   class BlockSuspension(susp: Expr => Expr) {
     def in(e: Expr) = susp(e)
diff --git a/src/main/scala/inox/ast/ExprOps.scala b/src/main/scala/inox/ast/ExprOps.scala
index 1f85f14f7..a17968b41 100644
--- a/src/main/scala/inox/ast/ExprOps.scala
+++ b/src/main/scala/inox/ast/ExprOps.scala
@@ -53,6 +53,7 @@ trait ExprOps extends GenTreeOps {
           case Let(vd, _, _) => subvs - vd.toVariable
           case Lambda(args, _) => subvs -- args.map(_.toVariable)
           case Forall(args, _) => subvs -- args.map(_.toVariable)
+          case Choose(res, _) => subvs - res.toVariable
           case _ => subvs
         }
     }(expr)
@@ -85,7 +86,7 @@ trait ExprOps extends GenTreeOps {
     * unrolling solver. See implementation for what this means exactly.
     */
   def isSimple(e: Expr): Boolean = !exists {
-    case (_: Assume) | (_: Forall) | (_: Lambda) |
+    case (_: Assume) | (_: Forall) | (_: Lambda) | (_: Choose) |
          (_: FunctionInvocation) | (_: Application) => true
     case _ => false
   } (e)
diff --git a/src/main/scala/inox/ast/Expressions.scala b/src/main/scala/inox/ast/Expressions.scala
index 4aa8a174f..2828e2cfe 100644
--- a/src/main/scala/inox/ast/Expressions.scala
+++ b/src/main/scala/inox/ast/Expressions.scala
@@ -115,12 +115,16 @@ trait Expressions { self: Trees =>
     }
   }
 
-  /* Universal Quantification */
-
+  /** $encodingof `forall(...)` (universal quantification) */
   case class Forall(args: Seq[ValDef], body: Expr) extends Expr with CachingTyped {
     protected def computeType(implicit s: Symbols): Type = body.getType
   }
 
+  /** $encodingof `choose(...)` (returns a value satisfying the provided predicate) */
+  case class Choose(res: ValDef, pred: Expr) extends Expr {
+    def getType(implicit s: Symbols): Type = res.tpe
+  }
+
   /* Control flow */
 
   /** $encodingof  `function(...)` (function invocation) */
diff --git a/src/main/scala/inox/ast/Extractors.scala b/src/main/scala/inox/ast/Extractors.scala
index e59547df2..173e2d5d1 100644
--- a/src/main/scala/inox/ast/Extractors.scala
+++ b/src/main/scala/inox/ast/Extractors.scala
@@ -46,6 +46,8 @@ trait Extractors { self: Trees =>
         Some((Seq(body), (es: Seq[Expr]) => Lambda(args, es.head)))
       case Forall(args, body) =>
         Some((Seq(body), (es: Seq[Expr]) => Forall(args, es.head)))
+      case Choose(res, pred) =>
+        Some((Seq(pred), (es: Seq[Expr]) => Choose(res, es.head)))
 
       /* Binary operators */
       case Equals(t1, t2) =>
diff --git a/src/main/scala/inox/ast/Printers.scala b/src/main/scala/inox/ast/Printers.scala
index de9b43264..9e751eec0 100644
--- a/src/main/scala/inox/ast/Printers.scala
+++ b/src/main/scala/inox/ast/Printers.scala
@@ -96,6 +96,9 @@ trait Printers { self: Trees =>
         case Forall(args, e) =>
           p"\u2200${nary(args)}. $e"
 
+        case Choose(res, pred) =>
+          p"choose(($res) => $pred)"
+
         case e @ CaseClass(cct, args) =>
           p"$cct($args)"
 
@@ -324,7 +327,7 @@ trait Printers { self: Trees =>
       case (pa: PrettyPrintable, _) => pa.printRequiresParentheses(within)
       case (_, None) => false
       case (_, Some(
-        _: Definition | _: Let | _: IfExpr | _ : CaseClass | _ : Lambda | _ : Tuple
+        _: Definition | _: Let | _: IfExpr | _: CaseClass | _: Lambda | _: Choose | _: Tuple
       )) => false
       case (ex: StringConcat, Some(_: StringConcat)) => false
       case (_, Some(_: FunctionInvocation)) => false
diff --git a/src/main/scala/inox/ast/TreeOps.scala b/src/main/scala/inox/ast/TreeOps.scala
index d2da99b30..2e5a44c4d 100644
--- a/src/main/scala/inox/ast/TreeOps.scala
+++ b/src/main/scala/inox/ast/TreeOps.scala
@@ -82,6 +82,15 @@ trait TreeOps { self: Trees =>
           e
         }
 
+      case Choose(res, pred) =>
+        val newRes = transform(res)
+        val newPred = transform(pred)
+        if ((res ne newRes) || (pred ne newPred)) {
+          Choose(newRes, newPred).copiedFrom(e)
+        } else {
+          e
+        }
+
       case Let(vd, expr, body) =>
         val newVd = transform(vd)
         val newExpr = transform(expr)
@@ -210,6 +219,10 @@ trait TreeOps { self: Trees =>
         args foreach (vd => traverse(vd.tpe))
         traverse(body)
 
+      case Choose(res, pred) =>
+        traverse(res.tpe)
+        traverse(pred)
+
       case Let(a, expr, body) =>
         traverse(expr)
         traverse(body)
diff --git a/src/main/scala/inox/evaluators/RecursiveEvaluator.scala b/src/main/scala/inox/evaluators/RecursiveEvaluator.scala
index 66d4d3e1e..983bc3a49 100644
--- a/src/main/scala/inox/evaluators/RecursiveEvaluator.scala
+++ b/src/main/scala/inox/evaluators/RecursiveEvaluator.scala
@@ -500,6 +500,8 @@ trait RecursiveEvaluator
 
     case f: Forall => onForallInvocation(f)
 
+    case c: Choose => onChooseInvocation(c)
+
     case f @ FiniteMap(ss, dflt, vT) =>
       // we use toMap.toSeq to reduce dupplicate keys
       FiniteMap(ss.map{ case (k, v) => (e(k), e(v)) }.toMap.toSeq, e(dflt), vT)
diff --git a/src/main/scala/inox/evaluators/SolvingEvaluator.scala b/src/main/scala/inox/evaluators/SolvingEvaluator.scala
index 6b84839cf..fd18f7060 100644
--- a/src/main/scala/inox/evaluators/SolvingEvaluator.scala
+++ b/src/main/scala/inox/evaluators/SolvingEvaluator.scala
@@ -22,11 +22,10 @@ trait SolvingEvaluator extends Evaluator {
 
   def getSolver(opts: InoxOption[Any]*): SolverFactory { val program: SolvingEvaluator.this.program.type }
 
-  private val specCache: MutableMap[Expr, Expr] = MutableMap.empty
+  private val chooseCache: MutableMap[Choose, Expr] = MutableMap.empty
   private val forallCache: MutableMap[Forall, Expr] = MutableMap.empty
 
-  def onSpecInvocation(specs: Lambda): Expr = specCache.getOrElseUpdate(specs, {
-    val Lambda(Seq(vd), body) = specs
+  def onChooseInvocation(choose: Choose): Expr = chooseCache.getOrElseUpdate(choose, {
     val timer = ctx.timers.evaluators.specs.start()
 
     val sf = getSolver(options.options.collect {
@@ -36,15 +35,15 @@ trait SolvingEvaluator extends Evaluator {
     import SolverResponses._
 
     val api = SimpleSolverAPI(sf)
-    val res = api.solveSAT(body)
+    val res = api.solveSAT(choose.pred)
     timer.stop()
 
     res match {
       case SatWithModel(model) =>
-        valuateWithModel(model)(vd)
+        valuateWithModel(model)(choose.res)
 
       case _ =>
-        throw new RuntimeException("Failed to evaluate specs " + specs.asString)
+        throw new RuntimeException("Failed to evaluate choose " + choose.asString)
     }
   })
 
diff --git a/src/main/scala/inox/solvers/unrolling/TemplateGenerator.scala b/src/main/scala/inox/solvers/unrolling/TemplateGenerator.scala
index 680414018..68e2f8994 100644
--- a/src/main/scala/inox/solvers/unrolling/TemplateGenerator.scala
+++ b/src/main/scala/inox/solvers/unrolling/TemplateGenerator.scala
@@ -143,6 +143,14 @@ trait TemplateGenerator { self: Templates =>
         storeGuarded(pathVar, e)
         rec(pathVar, body, pol)
 
+      case c @ Choose(res, pred) =>
+        val newExpr = res.toVariable.freshen
+        storeExpr(newExpr)
+
+        val p = rec(pathVar, exprOps.replace(Map(res.toVariable -> newExpr), pred), Some(true))
+        storeGuarded(pathVar, p)
+        newExpr
+
       case l @ Let(i, e: Lambda, b) =>
         val re = rec(pathVar, e, None) // guaranteed variable!
         val rb = rec(pathVar, exprOps.replace(Map(i.toVariable -> re), b), pol)
-- 
GitLab