From 028f3691e2aa36c5e641280158da440fa972b574 Mon Sep 17 00:00:00 2001
From: "Emmanouil (Manos) Koukoutos" <emmanouil.koukoutos@epfl.ch>
Date: Thu, 18 Dec 2014 19:13:09 +0100
Subject: [PATCH] Correct test minimization to handle preconditions correctly

---
 .../leon/evaluators/CollectingEvaluator.scala |  39 ------
 .../leon/repair/RepairTrackingEvaluator.scala | 114 ++++++++++++++++++
 src/main/scala/leon/repair/Repairman.scala    |  59 ++++-----
 3 files changed, 139 insertions(+), 73 deletions(-)
 delete mode 100644 src/main/scala/leon/evaluators/CollectingEvaluator.scala
 create mode 100644 src/main/scala/leon/repair/RepairTrackingEvaluator.scala

diff --git a/src/main/scala/leon/evaluators/CollectingEvaluator.scala b/src/main/scala/leon/evaluators/CollectingEvaluator.scala
deleted file mode 100644
index 0804d1ab2..000000000
--- a/src/main/scala/leon/evaluators/CollectingEvaluator.scala
+++ /dev/null
@@ -1,39 +0,0 @@
-package leon.evaluators
-
-import scala.collection.immutable.Map
-import leon.purescala.Common._
-import leon.purescala.Trees._
-import leon.purescala.Definitions._
-import leon.LeonContext
-
-abstract class CollectingEvaluator(ctx: LeonContext, prog: Program) extends RecursiveEvaluator(ctx, prog, 50000) {
-  type RC = DefaultRecContext
-  type GC = CollectingGlobalContext
-  type ES = Seq[Expr]
-  
-  def initRC(mappings: Map[Identifier, Expr]) = DefaultRecContext(mappings)
-  def initGC = new CollectingGlobalContext()
-  
-  class CollectingGlobalContext extends GlobalContext {
-    var collected : Set[Seq[Expr]] = Set()
-    def collect(es : ES) = collected += es
-  }
-  case class DefaultRecContext(mappings: Map[Identifier, Expr]) extends RecContext {
-    def withVars(news: Map[Identifier, Expr]) = copy(news)
-  }
-  
-  // A function that returns a Seq[Expr]
-  // This expressions will be evaluated in the current context and then collected in the global environment
-  def collecting(e : Expr) : Option[ES]
-  
-  override def e(expr: Expr)(implicit rctx: RC, gctx: GC): Expr = {
-    for {
-      es <- collecting(expr) 
-      evaled = es map e
-    } gctx.collect(evaled)
-    super.e(expr)
-  }
-  
-  def collected : Set[ES] = lastGC map { _.collected } getOrElse Set()
-  
-}
diff --git a/src/main/scala/leon/repair/RepairTrackingEvaluator.scala b/src/main/scala/leon/repair/RepairTrackingEvaluator.scala
new file mode 100644
index 000000000..3fcb7f331
--- /dev/null
+++ b/src/main/scala/leon/repair/RepairTrackingEvaluator.scala
@@ -0,0 +1,114 @@
+package leon.repair
+
+import scala.collection.immutable.Map
+import scala.collection.mutable.{Map => MMap}
+import leon.purescala.Common._
+import leon.purescala.Trees._
+import leon.purescala.TypeTrees._
+import leon.purescala.Definitions._
+import leon.LeonContext
+import leon.evaluators.RecursiveEvaluator
+
+abstract class RepairTrackingEvaluator(ctx: LeonContext, prog: Program) extends RecursiveEvaluator(ctx, prog, 50000) {
+  type RC = CollectingRecContext
+  type GC = GlobalContext
+  
+  def initRC(mappings: Map[Identifier, Expr]) = CollectingRecContext(mappings, None)
+  def initGC = new GlobalContext()
+  
+  type FI = (FunDef, Seq[Expr])
+  
+  // This is a call graph to track dependencies of function invocations.
+  // If fi1 calls fi2 but fails fi2's precondition, we consider it 
+  // fi1's fault and we don't register the dependency.
+  private val callGraph : MMap[FI, Set[FI]] = MMap().withDefaultValue(Set())
+  private def registerCall(fi : FI, lastFI : Option[FI]) = {
+    lastFI foreach { lfi => 
+      callGraph update (lfi, callGraph(lfi) + fi) 
+    }
+  }
+  def fullCallGraph = leon.utils.GraphOps.transitiveClosure(callGraph.toMap)
+  
+  // Tracks if every function invocation succeeded or failed
+  private val fiStatus_ : MMap[FI, Boolean] = MMap().withDefaultValue(false)
+  private def registerSuccessful(fi : FI) = fiStatus_ update (fi, true )
+  private def registerFailed    (fi : FI) = fiStatus_ update (fi, false)
+  def fiStatus = fiStatus_.toMap.withDefaultValue(false)
+  
+  case class CollectingRecContext(mappings: Map[Identifier, Expr], lastFI : Option[FI]) extends RecContext {
+    def withVars(news: Map[Identifier, Expr]) = copy(news, lastFI)
+    def withLastFI(fi : FI) = copy(lastFI = Some(fi))
+  }
+  
+  override def e(expr: Expr)(implicit rctx: RC, gctx: GC): Expr = expr match {
+    case FunctionInvocation(tfd, args) =>
+      if (gctx.stepsLeft < 0) {
+        throw RuntimeError("Exceeded number of allocated methods calls ("+gctx.maxSteps+")")
+      }
+      gctx.stepsLeft -= 1
+      
+      val evArgs = args.map(a => e(a))
+
+      // We consider this function invocation successful, unless the opposite is proven.
+      registerSuccessful(tfd.fd, evArgs)
+      
+      // build a mapping for the function...
+      val frameBlamingCaller = rctx.withVars((tfd.params.map(_.id) zip evArgs).toMap)
+      
+      if(tfd.hasPrecondition) {
+        e(tfd.precondition.get)(frameBlamingCaller, gctx) match {
+          case BooleanLiteral(true) => 
+            // Only register a call dependency if the call we depend on does not fail precondition
+            registerCall((tfd.fd, evArgs), rctx.lastFI)
+          case BooleanLiteral(false) =>
+            // Caller's fault!
+            rctx.lastFI foreach registerFailed
+            throw RuntimeError("Precondition violation for " + tfd.id.name + " reached in evaluation.: " + tfd.precondition.get)
+          case other =>
+            // Caller's fault!
+            rctx.lastFI foreach registerFailed
+            throw RuntimeError(typeErrorMsg(other, BooleanType))
+        }
+      } else {
+        registerCall((tfd.fd, evArgs), rctx.lastFI)
+      }
+
+      if(!tfd.hasBody && !rctx.mappings.isDefinedAt(tfd.id)) {
+        throw EvalError("Evaluation of function with unknown implementation.")
+      }
+
+      val body = tfd.body.getOrElse(rctx.mappings(tfd.id))
+
+      val frameBlamingCallee = frameBlamingCaller.withLastFI(tfd.fd, evArgs)
+      
+      val callResult = e(body)(frameBlamingCallee, gctx)
+
+      if(tfd.hasPostcondition) {
+        val (id, post) = tfd.postcondition.get
+
+        e(post)(frameBlamingCallee.withNewVar(id, callResult), gctx) match {
+          case BooleanLiteral(true) =>
+          case BooleanLiteral(false) =>
+            // Callee's fault
+            registerFailed(tfd.fd, evArgs)
+            throw RuntimeError("Postcondition violation for " + tfd.id.name + " reached in evaluation.")
+          case other =>
+            // Callee's fault
+            registerFailed(tfd.fd, evArgs)
+            throw EvalError(typeErrorMsg(other, BooleanType))
+        }
+      }
+
+      callResult
+
+    case other =>
+      try {
+        super.e(other)
+      } catch {
+        case t : Throwable =>
+          rctx.lastFI foreach registerFailed
+          throw t
+      }
+  }
+    
+}
diff --git a/src/main/scala/leon/repair/Repairman.scala b/src/main/scala/leon/repair/Repairman.scala
index 1c34786ad..5c01a389c 100644
--- a/src/main/scala/leon/repair/Repairman.scala
+++ b/src/main/scala/leon/repair/Repairman.scala
@@ -150,52 +150,43 @@ class Repairman(ctx: LeonContext, initProgram: Program, fd: FunDef, verifTimeout
 
     // We exclude redundant failing tests, and only select the minimal tests
     val minimalFailingTests = {
+            
+      type FI = (FunDef, Seq[Expr])
+      
       // We don't want tests whose invocation will call other failing tests.
       // This is because they will appear erroneous, 
       // even though the error comes from the called test
-      val testEval : CollectingEvaluator = new CollectingEvaluator(ctx, program){
-        def collecting(e : Expr) : Option[Seq[Expr]] = e match {
-          case fi@FunctionInvocation(TypedFunDef(`fd`, _), args) => 
-            Some(args)
-          case _ => None
-        }
+      val testEval : RepairTrackingEvaluator = new RepairTrackingEvaluator(ctx, program) {
+        def withFilter(fi : FI) = fi._1 == fd
       }
 
       val passingTs = for (test <- passingTests) yield InExample(test.ins)
       val failingTs = for (test <- failingTests) yield InExample(test.ins)
 
-      val test2Tests : Map[InExample, Set[InExample]] = (failingTs ++ passingTs).map{ ts => 
-        testEval.eval(body, args.zip(ts.ins).toMap)
-        (ts, testEval.collected map (InExample(_)))
-      }.toMap
-
-      val recursiveTests : Set[InExample] = test2Tests.values.toSet.flatten -- (failingTs ++ passingTs)
-
-      val testsTransitive : Map[InExample,Set[InExample]] = 
-        leon.utils.GraphOps.transitiveClosure[InExample](
-          test2Tests ++ recursiveTests.map ((_,Set[InExample]()))
-        )
-
-      val knownWithResults : Map[InExample, Boolean] = (failingTs.map((_, false)).toMap) ++ (passingTs.map((_,true)))
-
-      val recWithResults : Map[InExample, Boolean] = recursiveTests.map { ex =>
-        (ex, evaluator.eval(spec, (args zip ex.ins).toMap + (out -> body)) match {
-          case EvaluationResults.Successful(BooleanLiteral(true))  => true
-          case _ => false
-        })
-      }.toMap
-
-      val allWithResults = knownWithResults ++ recWithResults
+      (failingTs ++ passingTs) foreach { ts => 
+        testEval.eval(functionInvocation(fd, ts.ins))
+      }
+      
+      val test2Tests : Map[FI, Set[FI]] = testEval.fullCallGraph
+      
+      println("CALL GRAPH")
+      for {
+        ((fi, args), tos) <- test2Tests
+        (tofi, toArgs) <- tos
+      }{
+        println(s"${fi.id}(${args mkString ", "}) ----> ${tofi.id}(${toArgs mkString ", "})")
+      }
 
+      def isFailing(fi : FI) = !testEval.fiStatus(fi) && (fi._1 == fd)
+      val failing = test2Tests filter { case (from, to) => 
+         isFailing(from) && (to forall (!isFailing(_)) )
+      }
 
-      testsTransitive.collect {
-        case (rest, called) if !allWithResults(rest) && (called forall allWithResults) => 
-          rest
-      }.toSeq
+      failing.keySet map { case (_, args) => InExample(args) }
     }
 
     reporter.ifDebug { printer =>
-      printer(new ExamplesTable("Minimal failing:", minimalFailingTests).toString)
+      printer(new ExamplesTable("Minimal failing:", minimalFailingTests.toSeq).toString)
     }
 
     // Check how an expression behaves on tests
@@ -210,7 +201,7 @@ class Repairman(ctx: LeonContext, initProgram: Program, fd: FunDef, verifTimeout
           case EvaluationResults.Successful(BooleanLiteral(false)) => Some(false)
           case e => None
         }
-      }.distinct
+      }
 
       if (results.size == 1) {
         results.head
-- 
GitLab