From 345c72bac9dfa869399dd4137d603bd88702da44 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Mika=C3=ABl=20Mayer?= <mikael.mayer@epfl.ch>
Date: Thu, 14 Apr 2016 13:00:25 +0200
Subject: [PATCH] Added MatchExpr coverage test. Replaced Seq[Identifier] with
 Option[Seq[Identifier]] to take into account transformed expressions without
 new identifiers.

---
 .../disambiguation/InputCoverage.scala        | 122 +++++++++-----
 .../solvers/InputCoverageSuite.scala          | 159 +++++++++++++-----
 2 files changed, 191 insertions(+), 90 deletions(-)

diff --git a/src/main/scala/leon/synthesis/disambiguation/InputCoverage.scala b/src/main/scala/leon/synthesis/disambiguation/InputCoverage.scala
index f2b0946c6..fcbc0aa22 100644
--- a/src/main/scala/leon/synthesis/disambiguation/InputCoverage.scala
+++ b/src/main/scala/leon/synthesis/disambiguation/InputCoverage.scala
@@ -32,77 +32,103 @@ class InputCoverage(fd: FunDef, fds: Set[FunDef])(implicit c: LeonContext, p: Pr
   
   /** If the sub-branches contain identifiers, it returns them unchanged.
       Else it creates a new boolean indicating this branch. */
-  def wrapBranch(e: (Expr, Seq[Identifier])): (Expr, Seq[Identifier]) = {
-    if(e._2.isEmpty) {
-      val b = FreshIdentifier("l" + e._1.getPos.line + "c" + e._1.getPos.col, BooleanType)
-      (tupleWrap(Seq(e._1, Variable(b))), Seq(b))
-    } else e // No need to introduce a new boolean since if one of the child booleans is true, then this IfExpr has been called.
+  def wrapBranch(e: (Expr, Option[Seq[Identifier]])): (Expr, Option[Seq[Identifier]]) = e._2 match {
+    case None =>
+      val b = FreshIdentifier("l" + Math.abs(e._1.getPos.line) + "c" + Math.abs(e._1.getPos.col), BooleanType)
+      (tupleWrap(Seq(e._1, Variable(b))), Some(Seq(b)))
+    case Some(Seq()) =>
+      val b = FreshIdentifier("l" + Math.abs(e._1.getPos.line) + "c" + Math.abs(e._1.getPos.col), BooleanType)
+      
+      def putInLastBody(e: Expr): Expr = e match {
+        case Tuple(Seq(v, prev_b)) => Tuple(Seq(v, or(prev_b, b.toVariable))).copiedFrom(e)
+        case LetTuple(binders, value, body) => letTuple(binders, value, putInLastBody(body)).copiedFrom(e)
+        case MatchExpr(scrut, Seq(MatchCase(TuplePattern(optId, binders), None, rhs))) => 
+          MatchExpr(scrut, Seq(MatchCase(TuplePattern(optId, binders), None, putInLastBody(rhs)))).copiedFrom(e)
+        case _ => throw new Exception(s"Unexpected branching case: $e")
+        
+      }
+      (putInLastBody(e._1), Some(Seq(b)))
+    case _ =>
+      // No need to introduce a new boolean since if one of the child booleans is true, then this IfExpr has been called.
+      e
   }
   
   def hasConditionals(e: Expr) = {
     ExprOps.exists{ case i:IfExpr => true case m: MatchExpr => true case f: FunctionInvocation => true case _ => false}(e)
   }
   
+  def merge(a: Option[Seq[Identifier]], b: Option[Seq[Identifier]]) = {
+    (a, b) match {
+      case (None, None) => None
+      case (a, None) => a
+      case (None, b) => b
+      case (Some(a), Some(b)) => Some(a ++ b)
+    }
+  }
+  
   /** For each branch in the expression, adds a boolean variable such that the new type of the expression is (previousType, Boolean)
    *  If no variable is output, then the type of the expression is not changed.
-    * Returns the list of boolean variables which appear in the expression */
+    * If the expression is augmented with a boolean, returns the list of boolean variables which appear in the expression */
   // All functions now return a boolean along with their original return type.
-  def markBranches(e: Expr): (Expr, Seq[Identifier]) =
-    if(!hasConditionals(e)) (e, Seq()) else e match {
+  def markBranches(e: Expr): (Expr, Option[Seq[Identifier]]) =
+    if(!hasConditionals(e)) (e, None) else e match {
     case IfExpr(cond, thenn, elze) =>
       val (c1, cv1) = markBranches(cond)
       val (t1, tv1) = wrapBranch(markBranches(thenn))
       val (e1, ev1) = wrapBranch(markBranches(elze))
-      if(cv1.isEmpty) {
-        (IfExpr(c1, t1, e1).copiedFrom(e), tv1 ++ ev1)
-      } else {
-        val arg_id = FreshIdentifier("arg", BooleanType)
-        val arg_b = FreshIdentifier("b", BooleanType)
-        (letTuple(Seq(arg_id, arg_b), c1, IfExpr(Variable(arg_id), t1, e1).copiedFrom(e)), cv1 ++ tv1 ++ ev1)
+      cv1 match {
+        case None =>
+          (IfExpr(c1, t1, e1).copiedFrom(e), merge(tv1, ev1))
+        case cv1 =>
+          val arg_id = FreshIdentifier("arg", BooleanType)
+          val arg_b = FreshIdentifier("bc", BooleanType)
+          (letTuple(Seq(arg_id, arg_b), c1, IfExpr(Variable(arg_id), t1, e1).copiedFrom(e)).copiedFrom(e), merge(merge(cv1, tv1), ev1))
       }
     case MatchExpr(scrut, cases) =>
       val (c1, cv1) = markBranches(scrut)
-      val (new_cases, variables) = (cases map { case MatchCase(pattern, opt, rhs) =>
+      val (new_cases, variables) = (cases map { case m@MatchCase(pattern, opt, rhs) =>
         val (rhs_new, ids) = wrapBranch(markBranches(rhs))
-        (MatchCase(pattern, opt, rhs_new), ids)
+        (MatchCase(pattern, opt, rhs_new).copiedFrom(m), ids)
       }).unzip // TODO: Check for unapply with function pattern ?
-      (MatchExpr(c1, new_cases).copiedFrom(e), variables.flatten)
+      (MatchExpr(c1, new_cases).copiedFrom(e), variables.fold(None)(merge))
     case Operator(lhsrhs, builder) =>
       // The exprBuilder adds variable definitions needed to compute the arguments.
-      val (exprBuilder, children, ids) = (((e: Expr) => e, List[Expr](), ListBuffer[Identifier]()) /: lhsrhs) {
-        case ((exprBuilder, children, ids), arg) =>
+      val (exprBuilder, children, tmpIds, ids) = (((e: Expr) => e, ListBuffer[Expr](), ListBuffer[Identifier](), None: Option[Seq[Identifier]]) /: lhsrhs) {
+        case ((exprBuilder, children, tmpIds, ids), arg) =>
           val (arg1, argv1) = markBranches(arg)
-          if(argv1.nonEmpty) {
+          if(argv1.nonEmpty || isNewFunCall(arg1)) {
             val arg_id = FreshIdentifier("arg", arg.getType)
-            val arg_b = FreshIdentifier("b", BooleanType)
-            val f = (body: Expr) => letTuple(Seq(arg_id, arg_b), arg1, body)
-            (exprBuilder andThen f, Variable(arg_id)::children, ids ++= argv1)
+            val arg_b = FreshIdentifier("ba", BooleanType)
+            val f = (body: Expr) => letTuple(Seq(arg_id, arg_b), arg1, body).copiedFrom(body)
+            (exprBuilder andThen f, children += Variable(arg_id), tmpIds += arg_b, merge(ids, argv1))
           } else {
-            (exprBuilder, arg::children, ids)
+            (exprBuilder, children += arg, tmpIds, ids)
           }
       }
       e match {
-        case FunctionInvocation(TypedFunDef(fd, targs), args) if fds(fd) =>
+        case FunctionInvocation(tfd@TypedFunDef(fd, targs), args) if fds(fd) =>
           val new_fd = wrapFunDef(fd)
           // Is different since functions will return a boolean as well.
-          val res_id = FreshIdentifier("res", fd.returnType)
-          val res_b = FreshIdentifier("b", BooleanType)
-          if(ids.isEmpty) {
-            val funCall = FunctionInvocation(TypedFunDef(new_fd, targs), children).copiedFrom(e)
-            (exprBuilder(funCall), Seq(res_b))
-          } else {
-            val finalIds = (ids :+ res_b)
-            val finalExpr = 
-              tupleWrap(Seq(Variable(res_id), or(finalIds.map(Variable(_)): _*)))
-            val funCall = letTuple(Seq(res_id, res_b), FunctionInvocation(TypedFunDef(new_fd, targs), children).copiedFrom(e), finalExpr)
-            (exprBuilder(funCall), finalIds)
+          tmpIds match {
+            case Seq() =>
+              val funCall = FunctionInvocation(TypedFunDef(new_fd, targs).copiedFrom(tfd), children).copiedFrom(e)
+              (exprBuilder(funCall), if(new_fd != fd) merge(Some(Seq()), ids) else ids)
+            case idvars =>
+              val res_id = FreshIdentifier("res", fd.returnType)
+              val res_b = FreshIdentifier("bb", BooleanType)
+              val finalIds = idvars :+ res_b
+              val finalExpr = 
+                tupleWrap(Seq(Variable(res_id), or(finalIds.map(Variable(_)): _*))).copiedFrom(e)
+              val funCall = letTuple(Seq(res_id, res_b), FunctionInvocation(TypedFunDef(new_fd, targs), children).copiedFrom(e), finalExpr).copiedFrom(e)
+              (exprBuilder(funCall), Some(finalIds))
           }
         case _ =>
-          if(ids.isEmpty) {
-            (e, Seq.empty)
-          } else {
-            val finalExpr = tupleWrap(Seq(builder(children).copiedFrom(e), or(ids.map(Variable): _*)))
-            (exprBuilder(finalExpr), ids)
+          tmpIds match {
+            case Seq() =>
+              (e, ids)
+            case idvars =>
+              val finalExpr = tupleWrap(Seq(builder(children).copiedFrom(e), or(idvars.map(Variable): _*))).copiedFrom(e)
+              (exprBuilder(finalExpr), ids)
           }
       }
   }
@@ -111,13 +137,21 @@ class InputCoverage(fd: FunDef, fds: Set[FunDef])(implicit c: LeonContext, p: Pr
   
   def wrapFunDef(fd: FunDef): FunDef = {
     if(!(cache contains fd)) {
-      val new_fd = fd.duplicate(returnType = TupleType(Seq(fd.returnType, BooleanType)))
-      new_fd.body = None
-      cache += fd -> new_fd
+      cache += fd -> (if(fds(fd)) {
+        val new_fd = fd.duplicate(returnType = TupleType(Seq(fd.returnType, BooleanType)))
+        new_fd.body = None
+        new_fd
+      } else fd)
     }
     cache(fd)
   }
   
+  def isNewFunCall(e: Expr): Boolean = e match {
+    case FunctionInvocation(TypedFunDef(fd, targs), args) =>
+      cache.values.exists { f => f == fd }
+    case _ => false
+  }
+  
   /** The number of expressions is the same as the number of arguments. */
   def result(): Stream[Seq[Expr]] = {
     /* Algorithm:
diff --git a/src/test/scala/leon/integration/solvers/InputCoverageSuite.scala b/src/test/scala/leon/integration/solvers/InputCoverageSuite.scala
index 7bc13e757..11f57f929 100644
--- a/src/test/scala/leon/integration/solvers/InputCoverageSuite.scala
+++ b/src/test/scala/leon/integration/solvers/InputCoverageSuite.scala
@@ -46,6 +46,7 @@ import leon.test.helpers.ExpressionsDSL
 import leon.synthesis.disambiguation.InputCoverage
 import leon.test.helpers.ExpressionsDSLProgram
 import leon.test.helpers.ExpressionsDSLVariables
+import leon.purescala.Extractors._
 
 class InputCoverageSuite extends LeonTestSuiteWithProgram with Matchers with ScalaFutures with ExpressionsDSLProgram with ExpressionsDSLVariables {
   val sources = List("""
@@ -63,6 +64,7 @@ class InputCoverageSuite extends LeonTestSuiteWithProgram with Matchers with Sca
     |      2
     |    }
     |  }
+    |  
     |  def withIfInIf(cond: Boolean) = {
     |    if(if(cond) false else true) {
     |      1
@@ -71,6 +73,20 @@ class InputCoverageSuite extends LeonTestSuiteWithProgram with Matchers with Sca
     |    }
     |  }
     |  
+    |  sealed abstract class A
+    |  case class B() extends A
+    |  case class C(a: Int, tail: A) extends A
+    |  case class D(a: String, tail: A, b: String) extends A
+    |  
+    |  def withMatch(a: A): String = {
+    |    a match {
+    |      case B() => "B"
+    |      case C(a, C(b, tail)) => b.toString + withMatch(tail) + a.toString
+    |      case C(a, tail) => withMatch(tail) + a.toString
+    |      case D(a, tail, b) => a + withMatch(tail) + b
+    |    }
+    |  }
+    |  
     |  def withCoveredFun1(input: Int) = {
     |    withCoveredFun2(input - 5) + withCoveredFun2(input + 5)
     |  }
@@ -93,10 +109,14 @@ class InputCoverageSuite extends LeonTestSuiteWithProgram with Matchers with Sca
     val dummy = funDef("InputCoverageSuite.dummy")
     val coverage = new InputCoverage(dummy, Set(dummy))
     val simpleExpr = Plus(IntLiteral(1), b)
-    coverage.wrapBranch((simpleExpr, Seq(p.id, q.id))) should equal ((simpleExpr, Seq(p.id, q.id)))
-    val (covered, ids) = coverage.wrapBranch((simpleExpr, Seq()))
-    ids should have size 1
-    covered should equal (Tuple(Seq(simpleExpr, Variable(ids.head))))
+    coverage.wrapBranch((simpleExpr, Some(Seq(p.id, q.id)))) should equal ((simpleExpr, Some(Seq(p.id, q.id))))
+    coverage.wrapBranch((simpleExpr, None)) match {
+      case (covered, Some(ids)) =>
+        ids should have size 1
+        covered should equal (Tuple(Seq(simpleExpr, Variable(ids.head))))
+      case _ =>
+        fail("No ids added")
+    }
   }
   
   test("If-coverage should work"){ ctxprogram =>
@@ -105,12 +125,16 @@ class InputCoverageSuite extends LeonTestSuiteWithProgram with Matchers with Sca
     val coverage = new InputCoverage(withIf, Set(withIf))
     val expr = withIf.body.get
     
-    val (res, ids) = coverage.markBranches(expr)
-    ids should have size 2
-    expr match {
-      case IfExpr(cond, thenn, elze) =>
-        res should equal (IfExpr(cond, Tuple(Seq(thenn, Variable(ids(0)))), Tuple(Seq(elze, Variable(ids(1))))))
-      case _ => fail(s"$expr is not an IfExpr")
+    coverage.markBranches(expr) match {
+      case (res, Some(ids)) =>
+        ids should have size 2
+        expr match {
+          case IfExpr(cond, thenn, elze) =>
+            res should equal (IfExpr(cond, Tuple(Seq(thenn, Variable(ids(0)))), Tuple(Seq(elze, Variable(ids(1))))))
+          case _ => fail(s"$expr is not an IfExpr")
+        }
+      case _ =>
+        fail("No ids added")
     }
   }
   
@@ -120,15 +144,52 @@ class InputCoverageSuite extends LeonTestSuiteWithProgram with Matchers with Sca
     val coverage = new InputCoverage(withIfInIf, Set(withIfInIf))
     val expr = withIfInIf.body.get
     
-    val (res, ids) = coverage.markBranches(expr)
-    ids should have size 4
-    expr match {
-      case IfExpr(IfExpr(cond, t1, e1), t2, e2) =>
-        res match {
-          case MatchExpr(IfExpr(c, t1, e1), Seq(MatchCase(TuplePattern(None, s), None, IfExpr(c2, t2, e2)))) if s.size == 2 =>
-          case _ => fail("should have a shape like (if() else ) match { case (a, b) => if(...) else }")
+    coverage.markBranches(expr) match {
+      case (res, None) => fail("No ids added")
+      case (res, Some(ids)) =>
+        ids should have size 4
+        expr match {
+          case IfExpr(IfExpr(cond, t1, e1), t2, e2) =>
+            res match {
+              case MatchExpr(IfExpr(c, t1, e1), Seq(MatchCase(TuplePattern(None, s), None, IfExpr(c2, t2, e2)))) if s.size == 2 =>
+              case _ => fail("should have a shape like (if() else ) match { case (a, b) => if(...) else }")
+            }
+          case _ => fail(s"$expr is not an IfExpr")
+        }
+    }
+  }
+  
+  test("Match coverage should work with recursive functions") { ctxprogram =>
+    implicit val (c, p) = ctxprogram
+    val withMatch = funDef("InputCoverageSuite.withMatch")
+    val coverage = new InputCoverage(withMatch, Set(withMatch))
+    val expr = withMatch.body.get
+    
+    coverage.markBranches(expr) match {
+      case (res, None) => fail("No ids added")
+      case (res, Some(ids)) =>
+        withClue(res.toString) {
+          ids should have size 4
+          res match {
+            case MatchExpr(scrut,
+                   Seq(
+                       MatchCase(CaseClassPattern(_, _, Seq()), None, rhs1),
+                       MatchCase(CaseClassPattern(_, _, _), None, rhs2),
+                       MatchCase(CaseClassPattern(_, _, _), None, rhs3),
+                       MatchCase(CaseClassPattern(_, _, _), None, rhs4))
+            ) =>
+              rhs1 match {
+                case Tuple(Seq(_, Variable(b))) => b shouldEqual ids(0)
+                case _ => fail(s"$rhs1 should be a Tuple")
+              }
+              rhs2 match {
+                case LetTuple(_, _, Tuple(Seq(_, Or(Seq(_, Variable(b)))))) => b shouldEqual ids(1)
+                case _ => fail(s"$rhs2 should be a val + tuple like val ... = ... ; (..., ... || ${ids(1)})")
+              }
+              
+            case _ => fail(s"$res does not have the format a match { case B() => .... x 4 }")
+          }
         }
-      case _ => fail(s"$expr is not an IfExpr")
     }
   }
   
@@ -139,24 +200,28 @@ class InputCoverageSuite extends LeonTestSuiteWithProgram with Matchers with Sca
     val coverage = new InputCoverage(withCoveredFun1, Set(withCoveredFun1, withCoveredFun2))
     val expr = withCoveredFun1.body.get
     
-    val (res, ids) = coverage.markBranches(expr)
-    ids should have size 2
-    
-    res match {
-      case MatchExpr(funCall, Seq(
-            MatchCase(TuplePattern(None, Seq(WildcardPattern(Some(a1)), WildcardPattern(Some(b1)))), None,
-              MatchExpr(funCall2, Seq(
-                MatchCase(TuplePattern(None, Seq(WildcardPattern(Some(a2)), WildcardPattern(Some(b2)))), None,
-                  Tuple(Seq(BVPlus(Variable(ida1), Variable(ida2)), Or(Seq(Variable(idb1), Variable(idb2)))))
-            )
-          ))))) =>
+    coverage.markBranches(expr) match {
+      case (res, Some(ids)) if ids.size > 0 =>
         withClue(res.toString) {
-          ida1 shouldEqual a1
-          ida2 shouldEqual a2
-          Set(idb1.uniqueName, idb2.uniqueName) shouldEqual Set(b1.uniqueName, b2.uniqueName)
+          fail(s"Should have not added any ids, but got $ids")
+        }
+      case (res, _) =>
+        res match {
+          case MatchExpr(funCall, Seq(
+                MatchCase(TuplePattern(None, Seq(WildcardPattern(Some(a1)), WildcardPattern(Some(b1)))), None,
+                  MatchExpr(funCall2, Seq(
+                    MatchCase(TuplePattern(None, Seq(WildcardPattern(Some(a2)), WildcardPattern(Some(b2)))), None,
+                      Tuple(Seq(BVPlus(Variable(ida1), Variable(ida2)), Or(Seq(Variable(idb1), Variable(idb2)))))
+                )
+              ))))) =>
+            withClue(res.toString) {
+              ida1.uniqueName shouldEqual a2.uniqueName
+              ida2.uniqueName shouldEqual a1.uniqueName
+              Set(idb1.uniqueName, idb2.uniqueName) shouldEqual Set(b1.uniqueName, b2.uniqueName)
+            }
+          case _ =>
+            fail(s"$res is not of type funCall() match { case (a1, b1) => funCall() match { case (a2, b2) => (a1 + a2, b1 || b2) } }")
         }
-      case _ =>
-        fail(s"$res is not of type funCall() match { case (a1, b1) => funCall() match { case (a2, b2) => (a1 + a2, b1 || b2) } }")
     }
   }
   
@@ -167,18 +232,20 @@ class InputCoverageSuite extends LeonTestSuiteWithProgram with Matchers with Sca
     val coverage = new InputCoverage(withCoveredFun3, Set(withCoveredFun3, withCoveredFun2))
     val expr = withCoveredFun3.body.get
     
-    val (res, ids) = coverage.markBranches(expr)
-    ids should have size 2
-    res match {
-      case MatchExpr(funCall, Seq(
-            MatchCase(TuplePattern(None, Seq(WildcardPattern(Some(a)), WildcardPattern(Some(b1)))), None,
-              MatchExpr(FunctionInvocation(_, Seq(Variable(ida))), Seq(
-                MatchCase(TuplePattern(None, Seq(WildcardPattern(_), WildcardPattern(Some(b2)))), None,
-                  Tuple(Seq(p, Or(Seq(Variable(id1), Variable(id2)))))
-            )
-          ))))) if ida == a && id1 == b1 && id2 == b2 =>
-      case _ =>
-        fail(s"$res is not of type funCall() match { case (a, b1) => funCall(a) match { case (c, b2) => (c, b1 || b2) } }")
+    coverage.markBranches(expr) match {
+      case (res, None) => fail("No ids added")
+      case (res, Some(ids)) =>
+        res match {
+          case MatchExpr(funCall, Seq(
+                MatchCase(TuplePattern(None, Seq(WildcardPattern(Some(a)), WildcardPattern(Some(b1)))), None,
+                  MatchExpr(FunctionInvocation(_, Seq(Variable(ida))), Seq(
+                    MatchCase(TuplePattern(None, Seq(WildcardPattern(_), WildcardPattern(Some(b2)))), None,
+                      Tuple(Seq(p, Or(Seq(Variable(id1), Variable(id2)))))
+                )
+              ))))) if ida == a && id1 == b1 && id2 == b2 =>
+          case _ =>
+            fail(s"$res is not of type funCall() match { case (a, b1) => funCall(a) match { case (c, b2) => (c, b1 || b2) } }")
+        }
     }
   }
 }
\ No newline at end of file
-- 
GitLab