From ac4bd7cfd8984ee09433aea29d7b0e03ba3d4a15 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Mika=C3=ABl=20Mayer?= <a-mikmay@microsoft.com>
Date: Fri, 4 Dec 2015 18:14:15 +0100
Subject: [PATCH] Examples can now be given in opt guards of case x if x == ...
 => BUG FIX: Reconstructing of passes was reversed.

---
 .../scala/leon/purescala/Constructors.scala   |   2 +-
 .../scala/leon/purescala/Extractors.scala     |   2 +-
 .../scala/leon/synthesis/ExamplesFinder.scala | 107 ++++++++++--------
 3 files changed, 59 insertions(+), 52 deletions(-)

diff --git a/src/main/scala/leon/purescala/Constructors.scala b/src/main/scala/leon/purescala/Constructors.scala
index 7625eddf8..4404276cb 100644
--- a/src/main/scala/leon/purescala/Constructors.scala
+++ b/src/main/scala/leon/purescala/Constructors.scala
@@ -144,7 +144,7 @@ object Constructors {
 
     resType match {
       case Some(tpe) =>
-        casesFiltered.filter(c => isSubtypeOf(c.rhs.getType, tpe) || isSubtypeOf(tpe, c.rhs.getType))
+        casesFiltered.filter(c => isSubtypeOf(c.rhs.getType, tpe) || isSubtypeOf(tpe, c.rhs.getType) || c.optGuard.nonEmpty)
       case None =>
         casesFiltered
     }
diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala
index 4673ea6da..3a3bd7c7b 100644
--- a/src/main/scala/leon/purescala/Extractors.scala
+++ b/src/main/scala/leon/purescala/Extractors.scala
@@ -229,7 +229,7 @@ object Extractors {
             var i = 0
             val newcases = for (caze <- cases) yield caze match {
               case SimpleCase(b, _) => i += 1; SimpleCase(b, es(i - 1))
-              case GuardedCase(b, _, _) => i += 2; GuardedCase(b, es(i - 1), es(i - 2))
+              case GuardedCase(b, _, _) => i += 2; GuardedCase(b, es(i - 2), es(i - 1))
             }
 
             passes(in, out, newcases)
diff --git a/src/main/scala/leon/synthesis/ExamplesFinder.scala b/src/main/scala/leon/synthesis/ExamplesFinder.scala
index 45f399b1b..744029ee3 100644
--- a/src/main/scala/leon/synthesis/ExamplesFinder.scala
+++ b/src/main/scala/leon/synthesis/ExamplesFinder.scala
@@ -171,48 +171,56 @@ class ExamplesFinder(ctx0: LeonContext, program: Program) {
         // The input contains no free vars. Trivially return input-output pair
         Seq((pattExpr, doSubstitute(ieMap,cs.rhs)))
       } else {
-        // If the input contains free variables, it does not provide concrete examples. 
-        // We will instantiate them according to a simple grammar to get them.
-        val enum = new MemoizedEnumerator[TypeTree, Expr](ValueGrammar.getProductions)
-        val values = enum.iterator(tupleTypeWrap(freeVars.map{ _.getType }))
-        val instantiations = values.map {
-          v => freeVars.zip(unwrapTuple(v, freeVars.size)).toMap
-        }
-        
-        def filterGuard(e: Expr, mapping: Map[Identifier, Expr]): Boolean = cs.optGuard match {
-          case Some(guard) => 
-            // in -> e should be enough. We shouldn't find any subexpressions of in.
-            evaluator.eval(replace(Map(in -> e), guard), mapping) match {
-              case EvaluationResults.Successful(BooleanLiteral(true)) => true
-              case _ => false
-            }
+        // Extract test cases such as    case x if x == s =>
+        ((pattExpr, ieMap, cs.optGuard) match {
+          case (Variable(id), Seq(), Some(Equals(Variable(id2), s))) if id == id2 =>
+            Some((Seq((s, doSubstitute(ieMap, cs.rhs)))))
+          case (Variable(id), Seq(), Some(Equals(s, Variable(id2)))) if id == id2 =>
+            Some((Seq((s, doSubstitute(ieMap, cs.rhs)))))
+          case (a, b, c) =>
+            None
+        }) getOrElse {
+          // If the input contains free variables, it does not provide concrete examples. 
+          // We will instantiate them according to a simple grammar to get them.
+          val enum = new MemoizedEnumerator[TypeTree, Expr](ValueGrammar.getProductions)
+          val values = enum.iterator(tupleTypeWrap(freeVars.map { _.getType }))
+          val instantiations = values.map {
+            v => freeVars.zip(unwrapTuple(v, freeVars.size)).toMap
+          }
 
-          case None =>
-            true
-        }
+          def filterGuard(e: Expr, mapping: Map[Identifier, Expr]): Boolean = cs.optGuard match {
+            case Some(guard) =>
+              // in -> e should be enough. We shouldn't find any subexpressions of in.
+              evaluator.eval(replace(Map(in -> e), guard), mapping) match {
+                case EvaluationResults.Successful(BooleanLiteral(true)) => true
+                case _ => false
+              }
+
+            case None =>
+              true
+          }
 
-        (for {
-          inst <- instantiations.toSeq
-          inR = replaceFromIDs(inst, pattExpr)
-          outR = replaceFromIDs(inst, doSubstitute(ieMap, cs.rhs))
-          if filterGuard(inR, inst)
-        } yield (inR,outR) ).take(examplesPerCase)
+          (for {
+            inst <- instantiations.toSeq
+            inR = replaceFromIDs(inst, pattExpr)
+            outR = replaceFromIDs(inst, doSubstitute(ieMap, cs.rhs))
+            if filterGuard(inR, inst)
+          } yield (inR, outR)).take(examplesPerCase)
+        }
       }
     }
   }
 
-  /** 
-   * Check if two tests are compatible.
-   * Compatible should evaluate to the same value for the same identifier
-   */
+  /** Check if two tests are compatible.
+    * Compatible should evaluate to the same value for the same identifier
+    */
   private def isCompatible(m1: Map[Identifier, Expr], m2: Map[Identifier, Expr]) = {
     val ks = m1.keySet & m2.keySet
     ks.nonEmpty && ks.map(m1) == ks.map(m2)
   }
 
-  /** 
-   * Merge tests t1 and t2 if they are compatible. Return m1 if not.
-   */
+  /** Merge tests t1 and t2 if they are compatible. Return m1 if not.
+    */
   private def mergeTest(m1: Map[Identifier, Expr], m2: Map[Identifier, Expr]) = {
     if (!isCompatible(m1, m2)) {
       m1
@@ -221,12 +229,12 @@ class ExamplesFinder(ctx0: LeonContext, program: Program) {
     }
   }
 
-  /**
-   * we now need to consolidate different clusters of compatible tests together
-   * t1: a->1, c->3
-   * t2: a->1, b->4
-   *   => a->1, b->4, c->3
-   */
+  /** we now need to consolidate different clusters of compatible tests together
+    * t1: a->1, c->3
+    * t1: a->1, c->3
+    * t2: a->1, b->4
+    *   => a->1, b->4, c->3
+    */
   private def consolidateTests(ts: Set[Map[Identifier, Expr]]): Set[Map[Identifier, Expr]] = {
 
     var consolidated = Set[Map[Identifier, Expr]]()
@@ -240,19 +248,18 @@ class ExamplesFinder(ctx0: LeonContext, program: Program) {
     consolidated
   }
 
-  /**
-   * Extract ids in ins/outs args, and compute corresponding extractors for values map
-   *
-   * Examples:
-   * (a,b) =>
-   *     a -> _.1
-   *     b -> _.2
-   *
-   * Cons(a, Cons(b, c)) =>
-   *     a -> _.head
-   *     b -> _.tail.head
-   *     c -> _.tail.tail
-   */
+  /** Extract ids in ins/outs args, and compute corresponding extractors for values map
+    *
+    * Examples:
+    * (a,b) =>
+    *     a -> _.1
+    *     b -> _.2
+    *
+    * Cons(a, Cons(b, c)) =>
+    *     a -> _.head
+    *     b -> _.tail.head
+    *     c -> _.tail.tail
+    */
   private def extractIds(e: Expr): Seq[(Identifier, PartialFunction[Expr, Expr])] = e match {
     case Variable(id) =>
       List((id, { case e => e }))
-- 
GitLab