From cc2f0edb73b07cd4c58c9f63d0a07fd1db304365 Mon Sep 17 00:00:00 2001
From: Etienne Kneuss <etienne.kneuss@epfl.ch>
Date: Fri, 12 Dec 2014 13:41:23 +0100
Subject: [PATCH] labels should be bestRealType. Filter cases by incompatible
 return type

---
 .../scala/leon/purescala/Constructors.scala   |  18 ++-
 src/main/scala/leon/repair/Repairman.scala    | 115 +++++++++---------
 .../synthesis/rules/EquivalentInputs.scala    |  10 +-
 .../synthesis/utils/ExpressionGrammar.scala   |  13 +-
 4 files changed, 89 insertions(+), 67 deletions(-)

diff --git a/src/main/scala/leon/purescala/Constructors.scala b/src/main/scala/leon/purescala/Constructors.scala
index b68f8903a..7f56b2c32 100644
--- a/src/main/scala/leon/purescala/Constructors.scala
+++ b/src/main/scala/leon/purescala/Constructors.scala
@@ -7,6 +7,7 @@ import utils._
 
 object Constructors {
   import Trees._
+  import TypeTreeOps._
   import Common._
   import TypeTrees._
 
@@ -52,8 +53,8 @@ object Constructors {
     case more => TupleType(more)
   }
 
-  private def filterCases(scrutType : TypeTree, cases: Seq[MatchCase]): Seq[MatchCase] = {
-    scrutType match {
+  private def filterCases(scrutType : TypeTree, resType: Option[TypeTree], cases: Seq[MatchCase]): Seq[MatchCase] = {
+    val casesFiltered = scrutType match {
       case c: CaseClassType =>
         cases.filter(_.pattern match {
           case CaseClassPattern(_, cct, _) if cct.classDef != c.classDef => false
@@ -66,13 +67,20 @@ object Constructors {
       case t =>
         scala.sys.error("Constructing match expression on non-supported type: "+t)
     }
+
+    resType match {
+      case Some(tpe) =>
+        casesFiltered.filter(c => isSubtypeOf(c.rhs.getType, tpe) || isSubtypeOf(tpe, c.rhs.getType))
+      case None =>
+        casesFiltered
+    }
   }
 
   def gives(scrutinee : Expr, cases : Seq[MatchCase]) : Gives =
-    Gives(scrutinee, filterCases(scrutinee.getType, cases))
+    Gives(scrutinee, filterCases(scrutinee.getType, None, cases))
   
   def passes(in : Expr, out : Expr, cases : Seq[MatchCase]): Expr = {
-    val resultingCases = filterCases(in.getType, cases)
+    val resultingCases = filterCases(in.getType, Some(out.getType), cases)
     if (resultingCases.nonEmpty) {
       Passes(in, out, resultingCases)
     } else {
@@ -81,7 +89,7 @@ object Constructors {
   }
 
   def matchExpr(scrutinee : Expr, cases : Seq[MatchCase]) : Expr ={
-    val filtered = filterCases(scrutinee.getType, cases)
+    val filtered = filterCases(scrutinee.getType, None, cases)
     if (filtered.nonEmpty)
       MatchExpr(scrutinee, filtered)
     else 
diff --git a/src/main/scala/leon/repair/Repairman.scala b/src/main/scala/leon/repair/Repairman.scala
index 561359715..0d87c91ce 100644
--- a/src/main/scala/leon/repair/Repairman.scala
+++ b/src/main/scala/leon/repair/Repairman.scala
@@ -30,6 +30,64 @@ class Repairman(ctx: LeonContext, initProgram: Program, fd: FunDef, verifTimeout
 
   implicit val debugSection = DebugSectionRepair
 
+  def repair() = {
+    reporter.info(ASCIIHelpers.title("1. Discovering tests for "+fd.id))
+    val (tests, isVerified) = discoverTests
+
+    if (isVerified) {
+      reporter.info("Program verifies!")
+    }
+
+    reporter.info(ASCIIHelpers.title("2. Locating/Focusing synthesis problem"))
+    val synth = getSynthesizer(tests)
+    val p     = synth.problem
+
+    var solutions = List[Solution]()
+
+    reporter.info(ASCIIHelpers.title("3. Synthesizing"))
+    reporter.info(p)
+
+    synth.synthesize() match {
+      case (search, sols) =>
+        for (sol <- sols) {
+
+          // Validate solution if not trusted
+          if (!sol.isTrusted) {
+            reporter.info("Found untrusted solution! Verifying...")
+            val (npr, fds) = synth.solutionToProgram(sol)
+
+            getVerificationCounterExamples(fds.head, npr) match {
+              case Some(ces) =>
+                reporter.error("I ended up finding this counter example:\n"+ces.mkString("  |  "))
+
+              case None =>
+                solutions ::= sol
+                reporter.info("Solution was not trusted but verification passed!")
+            }
+          } else {
+            reporter.info("Found trusted solution!")
+            solutions ::= sol
+          }
+        }
+
+        if (synth.options.generateDerivationTrees) {
+          val dot = new DotGenerator(search.g)
+          dot.writeFile("derivation"+DotGenerator.nextId()+".dot")
+        }
+
+        if (solutions.isEmpty) {
+          reporter.error(ASCIIHelpers.title("Failed to repair!"))
+        } else {
+          reporter.info(ASCIIHelpers.title("Repair successful:"))
+          for ((sol, i) <- solutions.zipWithIndex) {
+            reporter.info(ASCIIHelpers.subTitle("Solution "+(i+1)+":"))
+            val expr = sol.toSimplifiedExpr(ctx, program)
+            reporter.info(ScalaPrinter(expr));
+          }
+        }
+      }
+  }
+
   def getSynthesizer(tests: List[Example]): Synthesizer = {
     // Create a fresh function
     val nid = FreshIdentifier(fd.id.name+"_repair").copiedFrom(fd.id)
@@ -224,7 +282,7 @@ class Repairman(ctx: LeonContext, initProgram: Program, fd: FunDef, verifTimeout
   }
 
 
-  def discoverTests: List[Example] = {
+  def discoverTests: (List[Example], Boolean) = {
     
     import bonsai._
     import bonsai.enumerators._
@@ -288,62 +346,9 @@ class Repairman(ctx: LeonContext, initProgram: Program, fd: FunDef, verifTimeout
     // Try to verify, if it fails, we have at least one CE
     val ces = getVerificationCounterExamples(fd, program) getOrElse Nil 
 
-    tests ++ ces
+    (tests ++ ces, ces.isEmpty)
   }
 
-  def repair() = {
-    reporter.info(ASCIIHelpers.title("1. Discovering tests for "+fd.id))
-    val tests = discoverTests
-
-    reporter.info(ASCIIHelpers.title("2. Locating/Focusing synthesis problem"))
-    val synth = getSynthesizer(tests)
-    val p     = synth.problem
-
-    var solutions = List[Solution]()
-
-    reporter.info(ASCIIHelpers.title("3. Synthesizing"))
-    reporter.info(p)
-
-    synth.synthesize() match {
-      case (search, sols) =>
-        for (sol <- sols) {
-
-          // Validate solution if not trusted
-          if (!sol.isTrusted) {
-            reporter.info("Found untrusted solution! Verifying...")
-            val (npr, fds) = synth.solutionToProgram(sol)
-
-            getVerificationCounterExamples(fds.head, npr) match {
-              case Some(ces) =>
-                reporter.error("I ended up finding this counter example:\n"+ces.mkString("  |  "))
-
-              case None =>
-                solutions ::= sol
-                reporter.info("Solution was not trusted but verification passed!")
-            }
-          } else {
-            reporter.info("Found trusted solution!")
-            solutions ::= sol
-          }
-        }
-
-        if (synth.options.generateDerivationTrees) {
-          val dot = new DotGenerator(search.g)
-          dot.writeFile("derivation"+DotGenerator.nextId()+".dot")
-        }
-
-        if (solutions.isEmpty) {
-          reporter.error(ASCIIHelpers.title("Failed to repair!"))
-        } else {
-          reporter.info(ASCIIHelpers.title("Repair successful:"))
-          for ((sol, i) <- solutions.zipWithIndex) {
-            reporter.info(ASCIIHelpers.subTitle("Solution "+(i+1)+":"))
-            val expr = sol.toSimplifiedExpr(ctx, program)
-            reporter.info(ScalaPrinter(expr));
-          }
-        }
-      }
-  }
 
   // ununsed for now, but implementation could be useful later
   private def disambiguate(p: Problem, sol1: Solution, sol2: Solution): Option[(InOutExample, InOutExample)] = {
diff --git a/src/main/scala/leon/synthesis/rules/EquivalentInputs.scala b/src/main/scala/leon/synthesis/rules/EquivalentInputs.scala
index 99057db0c..3a0847389 100644
--- a/src/main/scala/leon/synthesis/rules/EquivalentInputs.scala
+++ b/src/main/scala/leon/synthesis/rules/EquivalentInputs.scala
@@ -8,6 +8,7 @@ import leon.utils._
 import purescala.Trees._
 import purescala.TreeOps._
 import purescala.Extractors._
+import purescala.Constructors._
 
 case object EquivalentInputs extends NormalizingRule("EquivalentInputs") {
   def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = {
@@ -51,11 +52,18 @@ case object EquivalentInputs extends NormalizingRule("EquivalentInputs") {
 
     val substs = discoverEquivalences(clauses)
 
+    val postsToInject = substs.collect {
+      case (FunctionInvocation(tfd, args), e) if tfd.hasPostcondition =>
+        val Some((id, post)) = tfd.postcondition
+
+        replaceFromIDs((tfd.params.map(_.id) zip args).toMap + (id -> e), post)
+    }
+
     if (substs.nonEmpty) {
       val simplifier = Simplifiers.bestEffort(sctx.context, sctx.program) _
 
       val sub = p.copy(ws = replaceSeq(substs, p.ws), 
-                       pc = simplifier(replaceSeq(substs, p.pc)),
+                       pc = simplifier(andJoin(replaceSeq(substs, p.pc) +: postsToInject)),
                        phi = simplifier(replaceSeq(substs, p.phi)))
 
       List(RuleInstantiation.immediateDecomp(p, this, List(sub), forward, this.name))
diff --git a/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala b/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala
index 565a6c2bc..8ea105ce8 100644
--- a/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala
+++ b/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala
@@ -186,8 +186,9 @@ object ExpressionGrammars {
     def computeSimilar(e : Expr) : Seq[(L, Gen)] = {
 
       def getLabelPair(t: TypeTree) = {
+        val tpe = bestRealType(t)
         val c = getNext
-        (Label(t, "E"+c), Label(t, "G"+c))
+        (Label(tpe, "E"+c), Label(tpe, "G"+c))
       }
 
       def isCommutative(e: Expr) = e match {
@@ -282,12 +283,12 @@ object ExpressionGrammars {
 
       val res = rec(e, el, gl)
 
-      //for ((t, g) <- res) {
-      //  val subs = g.subTrees.map { t => FreshIdentifier(t.toString).setType(t.getType).toVariable}
-      //  val gen = g.builder(subs)
+      for ((t, g) <- res) {
+        val subs = g.subTrees.map { t => FreshIdentifier(t.toString).setType(t.getType).toVariable}
+        val gen = g.builder(subs)
 
-      //  println(f"$t%30s ::= "+gen)
-      //}
+        println(f"$t%30s ::= "+gen)
+      }
       res
     }
   }
-- 
GitLab