From 86345e4808fa93b1e7d90b836bdea4e7ea372b2e Mon Sep 17 00:00:00 2001
From: "Emmanouil (Manos) Koukoutos" <emmanouil.koukoutos@epfl.ch>
Date: Fri, 28 Nov 2014 16:23:00 +0100
Subject: [PATCH] Passes: examples in specs

---
 library/lang/package.scala                    |  7 +++
 .../scala/leon/codegen/CodeGeneration.scala   |  3 ++
 .../leon/evaluators/RecursiveEvaluator.scala  |  3 ++
 .../leon/evaluators/TracingEvaluator.scala    |  6 ++-
 .../leon/frontends/scalac/ASTExtractors.scala | 22 +++++++++
 .../frontends/scalac/CodeExtraction.scala     | 16 ++++++-
 .../scala/leon/purescala/Constructors.scala   | 20 +++++----
 .../scala/leon/purescala/Extractors.scala     | 12 +++++
 .../scala/leon/purescala/PrettyPrinter.scala  | 14 ++++--
 .../leon/purescala/ScopeSimplifier.scala      |  4 +-
 .../leon/purescala/TransformerWithPC.scala    |  7 ++-
 src/main/scala/leon/purescala/TreeOps.scala   | 45 +++++++------------
 src/main/scala/leon/purescala/Trees.scala     | 33 +++++++++-----
 src/main/scala/leon/refactor/Repairman.scala  | 12 +++--
 .../leon/solvers/z3/AbstractZ3Solver.scala    |  3 ++
 15 files changed, 143 insertions(+), 64 deletions(-)

diff --git a/library/lang/package.scala b/library/lang/package.scala
index 128f13df4..94578101e 100644
--- a/library/lang/package.scala
+++ b/library/lang/package.scala
@@ -32,5 +32,12 @@ package object lang {
   implicit class Gives[A](v : A) {
     def gives[B](tests : A => B) : B = tests(v)
   }
+ 
+  @ignore
+  implicit class Passes[A,B](io : (A,B)) {
+    val (in, out) = io
+    def passes(tests : A => B ) : Boolean = 
+      try { tests(in) == out } catch { case _ : MatchError => true }
+  }
 
 }
diff --git a/src/main/scala/leon/codegen/CodeGeneration.scala b/src/main/scala/leon/codegen/CodeGeneration.scala
index 04237df5e..0ac7ae01a 100644
--- a/src/main/scala/leon/codegen/CodeGeneration.scala
+++ b/src/main/scala/leon/codegen/CodeGeneration.scala
@@ -699,6 +699,9 @@ trait CodeGeneration {
       case This(ct) =>
         ch << ALoad(0) // FIXME what if doInstrument etc
         
+      case p : Passes => 
+        mkExpr(matchToIfThenElse(p.asConstraint), ch)
+
       case m : MatchExpr => 
         mkExpr(matchToIfThenElse(m), ch)
       
diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala
index 50643a440..dcc8c493f 100644
--- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala
+++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala
@@ -397,6 +397,9 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int
 
     case g : Gives =>
       e(convertHoles(g, ctx, true)) 
+  
+    case p : Passes => 
+      e(p.asConstraint)
 
     case choose: Choose =>
       import purescala.TreeOps.simplestValue
diff --git a/src/main/scala/leon/evaluators/TracingEvaluator.scala b/src/main/scala/leon/evaluators/TracingEvaluator.scala
index 2dd541997..c95b5f716 100644
--- a/src/main/scala/leon/evaluators/TracingEvaluator.scala
+++ b/src/main/scala/leon/evaluators/TracingEvaluator.scala
@@ -33,7 +33,11 @@ class TracingEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int = 1000) ex
           val res = e(b)(rctx.withNewVar(i, first), gctx)
           (res, first)
 
-        case MatchLike(scrut, cases, _) =>
+        case p: Passes =>
+           val r = e(p.asConstraint)
+           (r, r)
+
+        case MatchExpr(scrut, cases) =>
           val rscrut = e(scrut)
 
           val r = cases.toStream.map(c => matchesCase(rscrut, c)).find(_.nonEmpty) match {
diff --git a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala
index 7a13e372a..a7b5f1e87 100644
--- a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala
+++ b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala
@@ -178,6 +178,28 @@ trait ASTExtractors {
         case _ => None
       }
     }
+ 
+    object ExPasses { 
+      def unapply(tree : Apply) : Option[(Tree, Tree, List[CaseDef])] = tree match {
+        case  Apply(
+                Select(
+                  Apply(
+                    TypeApply(
+                      ExSelected("leon", "lang", "package", "Passes"), 
+                      _ :: _ :: Nil
+                    ), 
+                    ExpressionExtractors.ExTuple(_, Seq(in,out)) :: Nil
+                  ), 
+                  ExNamed("passes")
+                ),
+                (Function(
+                  (_ @ ValDef(_, _, _, EmptyTree)) :: Nil, 
+                  ExpressionExtractors.ExPatternMatching(_,tests))) :: Nil
+              )
+          => Some((in, out, tests))
+        case _ => None
+      }
+    }
 
 
 
diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala
index 0cbd20953..0b5bf20a7 100644
--- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala
+++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala
@@ -20,7 +20,7 @@ import purescala.Definitions.{
 import purescala.Trees.{Expr => LeonExpr, This => LeonThis, _}
 import purescala.TypeTrees.{TypeTree => LeonType, _}
 import purescala.Common._
-import purescala.Extractors.IsTyped
+import purescala.Extractors.{IsTyped,UnwrapTuple}
 import purescala.Constructors._
 import purescala.TreeOps._
 import purescala.TypeTreeOps._
@@ -1002,7 +1002,19 @@ trait CodeExtraction extends ASTExtractors {
           rest = None
 
           Require(pre, b)
- 
+
+        case ExPasses(in, out, cases) =>
+          val ine = extractTree(in)
+          val oute = extractTree(out)
+          val rc = cases.map(extractMatchCase(_))
+
+          val UnwrapTuple(ines) = ine
+          (oute +: ines) foreach {
+            case Variable(_) => { }
+            case other => ctx.reporter.fatalError(other.getPos, "Only i/o variables are allowed in i/o examples")
+          }
+          passes(ine, oute, rc)
+
         case ExGives(sel, cses) =>
           val rs = extractTree(sel)
           val rc = cses.map(extractMatchCase(_))
diff --git a/src/main/scala/leon/purescala/Constructors.scala b/src/main/scala/leon/purescala/Constructors.scala
index ffcd1b072..bb3b8156f 100644
--- a/src/main/scala/leon/purescala/Constructors.scala
+++ b/src/main/scala/leon/purescala/Constructors.scala
@@ -34,14 +34,14 @@ object Constructors {
     }
   }
 
-  def tupleWrap(es: Seq[Expr]): Expr = if (es.size > 1) {
-    Tuple(es)
-  } else {
-    es.head
+  def tupleWrap(es: Seq[Expr]): Expr = es match {
+    case Seq() => UnitLiteral()
+    case Seq(elem) => elem 
+    case more => Tuple(more)
   }
 
-  private def filterCases(scrutinee: Expr, cases: Seq[MatchCase]): Seq[MatchCase] = {
-    scrutinee.getType match {
+  private def filterCases(scrutType : TypeTree, cases: Seq[MatchCase]): Seq[MatchCase] = {
+    scrutType match {
       case c: CaseClassType =>
         cases.filter(_.pattern match {
           case CaseClassPattern(_, cct, _) if cct.classDef != c.classDef => false
@@ -57,10 +57,14 @@ object Constructors {
   }
 
   def gives(scrutinee : Expr, cases : Seq[MatchCase]) : Gives =
-    Gives(scrutinee, filterCases(scrutinee, cases))
+    Gives(scrutinee, filterCases(scrutinee.getType, cases))
   
+  def passes(in : Expr, out : Expr, cases : Seq[MatchCase]) : Passes = {
+    Passes(in, out, filterCases(in.getType, cases))
+  }
+
   def matchExpr(scrutinee : Expr, cases : Seq[MatchCase]) : MatchExpr = 
-    MatchExpr(scrutinee, filterCases(scrutinee, cases))
+    MatchExpr(scrutinee, filterCases(scrutinee.getType, cases))
 
   def and(exprs: Expr*): Expr = {
     val flat = exprs.flatMap(_ match {
diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala
index 5146387b9..439d0888c 100644
--- a/src/main/scala/leon/purescala/Extractors.scala
+++ b/src/main/scala/leon/purescala/Extractors.scala
@@ -231,6 +231,11 @@ object Extractors {
         (m.scrutinee, m.cases, m match {
           case _ : MatchExpr  => matchExpr
           case _ : Gives      => gives
+          case _ : Passes     => 
+            (s, cases) => {
+              val Tuple(Seq(in, out)) = s
+              passes(in,out,cases)
+            }
         })
       }
     }
@@ -250,4 +255,11 @@ object Extractors {
     }
   }
 
+  object UnwrapTuple {
+    def unapply(e : Expr) : Option[Seq[Expr]] = Option(e) map {
+      case Tuple(subs) => subs
+      case other => Seq(other)
+    }
+  }
+
 }
diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala
index b203c27a1..b8ebf7c95 100644
--- a/src/main/scala/leon/purescala/PrettyPrinter.scala
+++ b/src/main/scala/leon/purescala/PrettyPrinter.scala
@@ -223,10 +223,16 @@ class PrettyPrinter(opts: PrinterOptions, val sb: StringBuffer = new StringBuffe
         optP {
           p"""|$s gives {
               |  ${nary(tests, "\n")}
-              |}
-              |"""
+              |}"""
         }
       
+      case p@Passes(in, out, tests) =>
+        optP {
+          p"""|${p.scrutinee} passes {
+              |  ${nary(tests, "\n")}
+              |}"""
+        }
+
       case c @ WithOracle(vars, pred) =>
         p"""|withOracle { (${typed(vars)}) =>
             |  $pred
@@ -641,7 +647,7 @@ class PrettyPrinter(opts: PrinterOptions, val sb: StringBuffer = new StringBuffe
     case (_: Require, _) => true
     case (_: Assert, Some(_: Definition)) => true
     case (_, Some(_: Definition)) => false
-    case (_, Some(_: MatchExpr | _: MatchCase | _: Let | _: LetTuple | _: LetDef)) => false
+    case (_, Some(_: MatchExpr | _: MatchCase | _: Let | _: LetTuple | _: LetDef )) => false
     case (_, _) => true
   }
 
@@ -669,7 +675,7 @@ class PrettyPrinter(opts: PrinterOptions, val sb: StringBuffer = new StringBuffe
     case (BinaryMethodCall(_, _, _), Some(_: FunctionInvocation)) => true
     case (_, Some(_: FunctionInvocation)) => false
     case (ie: IfExpr, _) => true
-    case (me: MatchExpr, _ ) => true
+    case (me: MatchLike, _ ) => true
     case (e1: Expr, Some(e2: Expr)) if precedence(e1) > precedence(e2) => false
     case (_, _) => true
   }
diff --git a/src/main/scala/leon/purescala/ScopeSimplifier.scala b/src/main/scala/leon/purescala/ScopeSimplifier.scala
index 7d3024bde..010c70812 100644
--- a/src/main/scala/leon/purescala/ScopeSimplifier.scala
+++ b/src/main/scala/leon/purescala/ScopeSimplifier.scala
@@ -75,7 +75,7 @@ class ScopeSimplifier extends Transformer {
       val sb = rec(b, newScope)
       LetTuple(sis, se, sb)
 
-    case MatchLike(scrut, cases, builder) =>
+    case MatchExpr(scrut, cases) =>
       val rs = rec(scrut, scope)
 
       def trPattern(p: Pattern, scope: Scope): (Pattern, Scope) = {
@@ -112,7 +112,7 @@ class ScopeSimplifier extends Transformer {
         (newPattern, curScope)
       }
 
-      builder(rs, cases.map { c =>
+      MatchExpr(rs, cases.map { c =>
         val (newP, newScope) = trPattern(c.pattern, scope)
 
         c match {
diff --git a/src/main/scala/leon/purescala/TransformerWithPC.scala b/src/main/scala/leon/purescala/TransformerWithPC.scala
index 493b4639c..858f02a5f 100644
--- a/src/main/scala/leon/purescala/TransformerWithPC.scala
+++ b/src/main/scala/leon/purescala/TransformerWithPC.scala
@@ -21,12 +21,15 @@ abstract class TransformerWithPC extends Transformer {
       val sb = rec(b, register(Equals(Variable(i), se), path))
       Let(i, se, sb).copiedFrom(e)
 
-    case MatchLike(scrut, cases, builder) =>
+    case p: Passes =>
+      rec(p.asConstraint, path)
+
+    case MatchExpr(scrut, cases) =>
       val rs = rec(scrut, path)
 
       var soFar = path
 
-      builder(rs, cases.map { c =>
+      MatchExpr(rs, cases.map { c =>
         val patternExprPos = conditionForPattern(rs, c.pattern, includeBinders = true)
         val patternExprNeg = conditionForPattern(rs, c.pattern, includeBinders = false)
 
diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala
index bea32ebe2..33cc9699d 100644
--- a/src/main/scala/leon/purescala/TreeOps.scala
+++ b/src/main/scala/leon/purescala/TreeOps.scala
@@ -852,22 +852,22 @@ object TreeOps {
     postMap(rewritePM)(expr)
   }
 
-  def matchCasePathConditions(m : MatchLike, pathCond: List[Expr]) : Seq[List[Expr]] = m match {
-    case MatchLike(scrut, cases, _) => 
-      var pcSoFar = pathCond
-      for (c <- cases) yield {
-
-        val g = c.optGuard getOrElse BooleanLiteral(true)
-        val cond = conditionForPattern(scrut, c.pattern, includeBinders = true)
-        val localCond = pcSoFar :+ cond :+ g
-        
-        // These contain no binders defined in this MatchCase
-        val condSafe = conditionForPattern(scrut, c.pattern)
-        val gSafe = replaceFromIDs(mapForPattern(scrut, c.pattern),g)
-        pcSoFar ::= not(and(condSafe, gSafe))
+  def matchCasePathConditions(m: MatchExpr, pathCond: List[Expr]) : Seq[List[Expr]] = {
+    val MatchExpr(scrut, cases) = m
+    var pcSoFar = pathCond
+    for (c <- cases) yield {
+
+      val g = c.optGuard getOrElse BooleanLiteral(true)
+      val cond = conditionForPattern(scrut, c.pattern, includeBinders = true)
+      val localCond = pcSoFar :+ cond :+ g
+      
+      // These contain no binders defined in this MatchCase
+      val condSafe = conditionForPattern(scrut, c.pattern)
+      val gSafe = replaceFromIDs(mapForPattern(scrut, c.pattern),g)
+      pcSoFar ::= not(and(condSafe, gSafe))
 
-        localCond
-      }
+      localCond
+    }
   }
 
 
@@ -1550,19 +1550,8 @@ object TreeOps {
           fdHomo(fd1, fd2) &&
           isHomo(e1, e2)(map + (fd1.id -> fd2.id))
 
-        case (MatchExpr(s1, cs1), MatchExpr(s2, cs2)) =>
-          if (cs1.size == cs2.size) {
-            isHomo(s1, s2) && casesMatch(cs1,cs2)
-          } else {
-            false
-          }
-        
-        case (Gives(s1, cs1), Gives(s2, cs2)) =>
-          if (cs1.size == cs2.size) {
-            isHomo(s1, s2) && casesMatch(cs1,cs2)
-          } else {
-            false
-          }
+        case Same(MatchLike(s1, cs1, _), MatchLike(s2, cs2, _)) =>
+          cs1.size == cs2.size && isHomo(s1, s2) && casesMatch(cs1,cs2)
 
         case (FunctionInvocation(tfd1, args1), FunctionInvocation(tfd2, args2)) =>
           // TODO: Check type params
diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala
index bf90a8aeb..70b9c2116 100644
--- a/src/main/scala/leon/purescala/Trees.scala
+++ b/src/main/scala/leon/purescala/Trees.scala
@@ -46,7 +46,7 @@ object Trees {
   }
 
   case class Choose(vars: List[Identifier], pred: Expr) extends Expr with UnaryExtractable {
-    assert(!vars.isEmpty)
+    require(!vars.isEmpty)
 
     def getType = if (vars.size > 1) TupleType(vars.map(_.getType)) else vars.head.getType
 
@@ -60,7 +60,7 @@ object Trees {
   }
 
   case class LetTuple(binders: Seq[Identifier], value: Expr, body: Expr) extends Expr {
-    assert(value.getType.isInstanceOf[TupleType],
+    require(value.getType.isInstanceOf[TupleType],
            "The definition value in LetTuple must be of some tuple type; yet we got [%s]. In expr: \n%s".format(value.getType, this))
 
     def getType = body.getType
@@ -169,11 +169,11 @@ object Trees {
 
   // Index is 1-based, first element of tuple is 1.
   case class TupleSelect(tuple: Expr, index: Int) extends Expr {
-    assert(index >= 1)
+    require(index >= 1)
 
     def getType = tuple.getType match {
       case TupleType(ts) =>
-        assert(index <= ts.size)
+        require(index <= ts.size)
         ts(index - 1)
 
       case _ =>
@@ -188,17 +188,30 @@ object Trees {
   }
 
   case class MatchExpr(scrutinee: Expr, cases: Seq[MatchCase]) extends MatchLike {
-    assert(cases.nonEmpty)
+    require(cases.nonEmpty)
   }
   
   case class Gives(scrutinee: Expr, cases : Seq[MatchCase]) extends MatchLike {
-    assert(cases.nonEmpty)
+    require(cases.nonEmpty)
     def asIncompleteMatch = {
       val theHole = SimpleCase(WildcardPattern(None), Hole(this.getType, Seq()))
       MatchExpr(scrutinee, cases :+ theHole)
     }
+  } 
+  
+  case class Passes(in: Expr, out : Expr, cases : Seq[MatchCase]) extends MatchLike {
+    require(cases.nonEmpty)
+
+    override def getType = BooleanType
+    val scrutinee = Tuple(Seq(in, out))
+    
+    def asConstraint = {
+      val defaultCase = SimpleCase(WildcardPattern(None), out)
+      Equals(out, MatchExpr(in, cases :+ defaultCase))
+    }
   }
 
+
   sealed abstract class MatchCase extends Tree {
     val pattern: Pattern
     val rhs: Expr
@@ -246,7 +259,7 @@ object Trees {
   case class And(exprs: Seq[Expr]) extends Expr {
     def getType = BooleanType
 
-    assert(exprs.size >= 2)
+    require(exprs.size >= 2)
   }
 
   object And {
@@ -256,7 +269,7 @@ object Trees {
   case class Or(exprs: Seq[Expr]) extends Expr {
     def getType = BooleanType
 
-    assert(exprs.size >= 2)
+    require(exprs.size >= 2)
   }
 
   object Or {
@@ -456,7 +469,7 @@ object Trees {
 
   // Provide an oracle (synthesizable, all-seeing choose)
   case class WithOracle(oracles: List[Identifier], body: Expr) extends Expr with UnaryExtractable {
-    assert(!oracles.isEmpty)
+    require(!oracles.isEmpty)
 
     def getType = body.getType
 
@@ -501,7 +514,7 @@ object Trees {
     val getType = MultisetType(baseType)
   }
   case class FiniteMultiset(elements: Seq[Expr]) extends Expr {
-    assert(elements.size > 0)
+    require(elements.nonEmpty)
     def getType = MultisetType(leastUpperBound(elements.map(_.getType)).getOrElse(Untyped))
   }
   case class Multiplicity(element: Expr, multiset: Expr) extends Expr {
diff --git a/src/main/scala/leon/refactor/Repairman.scala b/src/main/scala/leon/refactor/Repairman.scala
index 1e8a6299a..c9bd7e9fc 100644
--- a/src/main/scala/leon/refactor/Repairman.scala
+++ b/src/main/scala/leon/refactor/Repairman.scala
@@ -41,17 +41,15 @@ class Repairman(ctx: LeonContext, program: Program, fd: FunDef) {
     // Compute tests
     val out = fd.postcondition.map(_._1).getOrElse(FreshIdentifier("res", true).setType(fd.returnType))
 
-    val tfd = program.library.passes.get.typed(Seq(argsWrapped.getType, out.getType))
-
     val inouts = testBank;
 
-    val testsExpr = FiniteMap(inouts.collect {
+    val testsCases = inouts.collect {
       case InOutExample(ins, outs) =>
-        tupleWrap(ins) -> tupleWrap(outs)
-    }.toList).setType(MapType(argsWrapped.getType, out.getType))
+        GuardedCase(WildcardPattern(None), Equals(argsWrapped, tupleWrap(ins)), tupleWrap(outs))
+    }.toList
 
-    val passes = if (testsExpr.singletons.nonEmpty) {
-      FunctionInvocation(tfd, Seq(argsWrapped, out.toVariable, testsExpr))
+    val passes = if (testsCases.nonEmpty) {
+      Passes(argsWrapped, out.toVariable, testsCases)
     } else {
       BooleanLiteral(true)
     }
diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala
index 87a1e8e06..ba1945ae4 100644
--- a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala
+++ b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala
@@ -495,6 +495,9 @@ trait AbstractZ3Solver
     }
 
     def rec(ex: Expr): Z3AST = ex match {
+      case p @ Passes(_, _, _) =>
+        rec(p.asConstraint)
+
       case me @ MatchExpr(s, cs) =>
         rec(matchToIfThenElse(me))
       
-- 
GitLab