From 895a196227835a46702170124a01a6b9771b9aae Mon Sep 17 00:00:00 2001
From: Etienne Kneuss <colder@php.net>
Date: Mon, 16 Jun 2014 12:08:44 +0200
Subject: [PATCH] Extract tests from specification: passes(in,out)(Map(1 -> 2))

---
 library/lang/package.scala                    |   8 ++
 .../frontends/scalac/CodeExtraction.scala     |  22 +++-
 .../scala/leon/purescala/Constructors.scala   |  27 +++++
 .../scala/leon/synthesis/InOutExample.scala   |  13 ++
 src/main/scala/leon/synthesis/Problem.scala   | 113 ++++++++++++++++++
 src/main/scala/leon/synthesis/Rules.scala     |  24 ++--
 .../scala/leon/synthesis/rules/Cegis.scala    |  29 ++++-
 7 files changed, 218 insertions(+), 18 deletions(-)
 create mode 100644 src/main/scala/leon/purescala/Constructors.scala
 create mode 100644 src/main/scala/leon/synthesis/InOutExample.scala

diff --git a/library/lang/package.scala b/library/lang/package.scala
index 41cd25b65..a3600b868 100644
--- a/library/lang/package.scala
+++ b/library/lang/package.scala
@@ -27,4 +27,12 @@ package object lang {
 
   @ignore
   def error[T](reason: String): T = sys.error(reason)
+
+  def passes[A, B](in: A, out: B)(tests: Map[A,B]): Boolean = {
+    if (tests contains in) {
+      tests(in) == out
+    } else {
+      true
+    }
+  }
 }
diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala
index 3a98f859a..9826e01a0 100644
--- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala
+++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala
@@ -394,8 +394,11 @@ trait CodeExtraction extends ASTExtractors {
 
       // We collect the methods
       for (d <- tmpl.body) d match {
+        case EmptyTree =>
+          // ignore
+
         case t if isIgnored(t.symbol) =>
-          //ignore
+          // ignore
 
         case t @ ExFunctionDef(fsym, _, _, _, _) if !fsym.isSynthetic && !fsym.isAccessor =>
           if (parent.isDefined) {
@@ -1045,10 +1048,6 @@ trait CodeExtraction extends ASTExtractors {
         case ExOr(l, r)            => Or(extractTree(l), extractTree(r))
         case ExNot(e)              => Not(extractTree(e))
         case ExUMinus(e)           => UMinus(extractTree(e))
-        case ExPlus(l, r)          => Plus(extractTree(l), extractTree(r))
-        case ExMinus(l, r)         => Minus(extractTree(l), extractTree(r))
-        case ExTimes(l, r)         => Times(extractTree(l), extractTree(r))
-        case ExDiv(l, r)           => Division(extractTree(l), extractTree(r))
         case ExMod(l, r)           => Modulo(extractTree(l), extractTree(r))
         case ExNotEquals(l, r)     => Not(Equals(extractTree(l), extractTree(r)))
         case ExGreaterThan(l, r)   => GreaterThan(extractTree(l), extractTree(r))
@@ -1214,6 +1213,19 @@ trait CodeExtraction extends ASTExtractors {
 
               CaseClassSelector(cct, rec, fieldID)
 
+            // Int methods
+            case (IsTyped(a1, Int32Type), "+", List(IsTyped(a2, Int32Type))) =>
+              Plus(a1, a2)
+
+            case (IsTyped(a1, Int32Type), "-", List(IsTyped(a2, Int32Type))) =>
+              Minus(a1, a2)
+
+            case (IsTyped(a1, Int32Type), "*", List(IsTyped(a2, Int32Type))) =>
+              Times(a1, a2)
+
+            case (IsTyped(a1, Int32Type), "/", List(IsTyped(a2, Int32Type))) =>
+              Division(a1, a2)
+
             // Set methods
             case (IsTyped(a1, SetType(b1)), "min", Nil) =>
               SetMin(a1).setType(b1)
diff --git a/src/main/scala/leon/purescala/Constructors.scala b/src/main/scala/leon/purescala/Constructors.scala
new file mode 100644
index 000000000..d02804e45
--- /dev/null
+++ b/src/main/scala/leon/purescala/Constructors.scala
@@ -0,0 +1,27 @@
+/* Copyright 2009-2014 EPFL, Lausanne */
+
+package leon
+package purescala
+
+import utils._
+
+object Constructors {
+  import Trees._
+  import Common._
+
+  def tupleSelect(t: Expr, index: Int) = t match {
+    case Tuple(es) =>
+      es(index-1)
+    case _ =>
+      TupleSelect(t, index)
+  }
+
+  def letTuple(binders: Seq[Identifier], value: Expr, body: Expr) = binders match {
+    case Nil =>
+      body
+    case x :: Nil =>
+      Let(x, tupleSelect(value, 1), body)
+    case xs =>
+      LetTuple(xs, value, body)
+  }
+}
diff --git a/src/main/scala/leon/synthesis/InOutExample.scala b/src/main/scala/leon/synthesis/InOutExample.scala
new file mode 100644
index 000000000..6c64f3977
--- /dev/null
+++ b/src/main/scala/leon/synthesis/InOutExample.scala
@@ -0,0 +1,13 @@
+/* Copyright 2009-2014 EPFL, Lausanne */
+
+package leon
+package synthesis
+
+import purescala.Trees.Expr
+
+case class InOutExample(ins: Seq[Expr], outs: Seq[Expr]) {
+  def inExample = InExample(ins)
+}
+
+
+case class InExample(ins: Seq[Expr])
diff --git a/src/main/scala/leon/synthesis/Problem.scala b/src/main/scala/leon/synthesis/Problem.scala
index a227eb7c1..2ed3e5d17 100644
--- a/src/main/scala/leon/synthesis/Problem.scala
+++ b/src/main/scala/leon/synthesis/Problem.scala
@@ -11,6 +11,119 @@ import leon.purescala.Common._
 // ⟦ as ⟨ C | phi ⟩ xs ⟧
 case class Problem(as: List[Identifier], pc: Expr, phi: Expr, xs: List[Identifier]) {
   override def toString = "⟦ "+as.mkString(";")+", "+(if (pc != BooleanLiteral(true)) pc+" ≺ " else "")+" ⟨ "+phi+" ⟩ "+xs.mkString(";")+" ⟧ "
+
+  def getTests(sctx: SynthesisContext): Seq[InOutExample] = {
+    import purescala.Extractors._
+    import evaluators._
+
+    val TopLevelAnds(predicates) = And(pc, phi)
+
+    val ev = new DefaultEvaluator(sctx.context, sctx.program)
+
+    def isValidExample(ex: InOutExample): Boolean = {
+      val mapping = Map((as zip ex.ins) ++ (xs zip ex.outs): _*)
+
+      ev.eval(And(pc, phi), mapping) match {
+        case EvaluationResults.Successful(BooleanLiteral(true)) => true
+        case _ => false
+      }
+    }
+
+    // Returns a list of identifiers, and extractors
+    def andThen(pf1: PartialFunction[Expr, Expr], pf2: PartialFunction[Expr, Expr]): PartialFunction[Expr, Expr] = {
+      Function.unlift(pf1.lift(_) flatMap pf2.lift)
+    }
+
+    /**
+     * Extract ids in ins/outs args, and compute corresponding extractors for values map
+     *
+     * Examples:
+     * (a,b) =>
+     *     a -> _.1
+     *     b -> _.2
+     *
+     * Cons(a, Cons(b, c)) =>
+     *     a -> _.head
+     *     b -> _.tail.head
+     *     c -> _.tail.tail
+     */
+    def extractIds(e: Expr): Seq[(Identifier, PartialFunction[Expr, Expr])] = e match {
+      case Variable(id) =>
+        List((id, { case e => e }))
+      case Tuple(vs) =>
+        vs.map(extractIds).zipWithIndex.flatMap{ case (ids, i) =>
+          ids.map{ case (id, e) =>
+            (id, andThen({ case Tuple(vs) => vs(i) }, e))
+          }
+        }
+      case CaseClass(cct, args) =>
+        args.map(extractIds).zipWithIndex.flatMap { case (ids, i) =>
+          ids.map{ case (id, e) =>
+            (id, andThen({ case CaseClass(cct2, vs) if cct2 == cct => vs(i) } ,e))
+          }
+        }
+
+      case _ =>
+        sctx.reporter.warning("Unnexpected pattern in test-ids extraction: "+e)
+        Nil
+    }
+
+    def exprToIds(e: Expr): List[Identifier] = e match {
+      case Variable(i) => List(i)
+      case Tuple(is) => is.collect { case Variable(i) => i }.toList
+      case _ => Nil
+    }
+
+    val testClusters = predicates.collect {
+      case FunctionInvocation(tfd, List(in, out, FiniteMap(inouts))) if tfd.id.name == "passes" =>
+        val infos = extractIds(Tuple(Seq(in, out)))
+        val exs   = inouts.map{ case (i, o) => Tuple(Seq(i, o)) }
+
+        // Check whether we can extract all ids from example
+        val results = exs.collect { case e if infos.forall(_._2.isDefinedAt(e)) =>
+          infos.map{ case (id, f) => id -> f(e) }.toMap
+        }
+
+        results
+    }
+
+    /**
+     * we now need to consolidate different clusters of compatible tests together
+     * t1: a->1, c->3
+     * t2: a->1, b->4
+     *   => a->1, b->4, c->3
+     */
+
+    def isCompatible(m1: Map[Identifier, Expr], m2: Map[Identifier, Expr]) = {
+      val ks = m1.keySet & m2.keySet
+      ks.nonEmpty && ks.map(m1) == ks.map(m2)
+    }
+
+    def mergeTest(m1: Map[Identifier, Expr], m2: Map[Identifier, Expr]) = {
+      if (!isCompatible(m1, m2)) {
+        m1
+      } else {
+        m1 ++ m2
+      }
+    }
+
+    var consolidated = Set[Map[Identifier, Expr]]()
+    for (ts <- testClusters; t <- ts) {
+      consolidated += t
+
+      consolidated = consolidated.map { c =>
+        mergeTest(c, t)
+      }
+    }
+
+    // Finally, we keep complete tests covering all as++xs
+    val requiredIds = (as ++ xs).toSet
+    val complete = consolidated.filter{ t => (t.keySet & requiredIds) == requiredIds }
+
+    complete.toSeq.map { m =>
+      InOutExample(as.map(m), xs.map(m))
+    }.filter(isValidExample)
+  }
 }
 
 object Problem {
diff --git a/src/main/scala/leon/synthesis/Rules.scala b/src/main/scala/leon/synthesis/Rules.scala
index 5f75db650..2a1b90973 100644
--- a/src/main/scala/leon/synthesis/Rules.scala
+++ b/src/main/scala/leon/synthesis/Rules.scala
@@ -141,16 +141,24 @@ object RuleInstantiation {
   }
 }
 
-abstract class Rule(val name: String) {
+abstract class Rule(val name: String) extends RuleHelpers {
   def instantiateOn(sctx: SynthesisContext, problem: Problem): Traversable[RuleInstantiation]
 
-  val priority: RulePriority = RulePriorityDefault 
-
-  def subst(what: Tuple2[Identifier, Expr], in: Expr): Expr = replace(Map(Variable(what._1) -> what._2), in)
-  def substAll(what: Map[Identifier, Expr], in: Expr): Expr = replace(what.map(w => Variable(w._1) -> w._2), in)
+  val priority: RulePriority = RulePriorityDefault
 
   implicit val debugSection = leon.utils.DebugSectionSynthesis
 
+  override def toString = "R: "+name
+}
+
+abstract class NormalizingRule(name: String) extends Rule(name) {
+  override val priority = RulePriorityNormalizing
+}
+
+trait RuleHelpers {
+  def subst(what: Tuple2[Identifier, Expr], in: Expr): Expr = replaceFromIDs(Map(what), in)
+  def substAll(what: Map[Identifier, Expr], in: Expr): Expr = replaceFromIDs(what, in)
+
   val forward: List[Solution] => Option[Solution] = {
     case List(s) =>
       Some(Solution(s.pre, s.defs, s.term))
@@ -169,10 +177,4 @@ abstract class Rule(val name: String) {
     case _ =>
       None
   }
-
-  override def toString = "R: "+name
-}
-
-abstract class NormalizingRule(name: String) extends Rule(name) {
-  override val priority = RulePriorityNormalizing
 }
diff --git a/src/main/scala/leon/synthesis/rules/Cegis.scala b/src/main/scala/leon/synthesis/rules/Cegis.scala
index 99cdcda04..37d19bb36 100644
--- a/src/main/scala/leon/synthesis/rules/Cegis.scala
+++ b/src/main/scala/leon/synthesis/rules/Cegis.scala
@@ -55,7 +55,19 @@ case object CEGIS extends Rule("CEGIS") {
             { () => List((BooleanLiteral(true), Set()), (BooleanLiteral(false), Set())) }
 
           case Int32Type =>
-            { () => List((IntLiteral(0), Set()), (IntLiteral(1), Set())) }
+            { () =>
+              val ground = List((IntLiteral(0), Set[Identifier]()), (IntLiteral(1), Set[Identifier]()))
+              val ops    = List[Function2[Expr, Expr, Expr]](
+                (a,b) => Plus(a,b),
+                (a,b) => Minus(a,b),
+                (a,b) => Times(a,b)
+              )
+
+              ops.map{f =>
+                val ids = List(FreshIdentifier("a", true).setType(Int32Type), FreshIdentifier("b", true).setType(Int32Type))
+                (f(ids(0).toVariable, ids(1).toVariable), ids.toSet)
+              } ++ ground
+            }
 
           case TupleType(tps) =>
             { () =>
@@ -460,6 +472,8 @@ case object CEGIS extends Rule("CEGIS") {
         // We populate the list of examples with a predefined one
         sctx.reporter.debug("Acquiring list of examples")
 
+        baseExampleInputs ++= p.getTests(sctx).map(_.ins).toSet
+
         if (p.pc == BooleanLiteral(true)) {
           baseExampleInputs = p.as.map(a => simplestValue(a.getType)) +: baseExampleInputs
         } else {
@@ -485,6 +499,12 @@ case object CEGIS extends Rule("CEGIS") {
           }
         }
 
+        sctx.reporter.ifDebug { debug =>
+          baseExampleInputs.foreach { in =>
+            debug("  - "+in.mkString(", "))
+          }
+        }
+
         val inputIterator: Iterator[Seq[Expr]] = if (useVanuatoo) {
           new VanuatooDataGen(sctx.context, sctx.program).generateFor(p.as, p.pc, 20, 3000)
         } else {
@@ -509,6 +529,7 @@ case object CEGIS extends Rule("CEGIS") {
           for (prog <- programs) {
             val expr = ndProgram.determinize(prog)
             val res = Equals(Tuple(p.xs.map(Variable(_))), expr)
+
             val solver3 = sctx.newSolver.setTimeout(cexSolverTo)
             solver3.assertCnstr(And(p.pc :: res :: Not(p.phi) :: Nil))
 
@@ -609,7 +630,11 @@ case object CEGIS extends Rule("CEGIS") {
               needMoreUnrolling = true;
             } else if (nPassing <= testUpTo) {
               // Immediate Test
-              result = Some(checkForPrograms(prunedPrograms))
+              checkForPrograms(prunedPrograms) match {
+                case rs: RuleSuccess =>
+                  result = Some(rs)
+                case _ =>
+              }
             } else if (((nPassing < allPrograms*filterThreshold) || didFilterAlready) && useBssFiltering) {
               // We filter the Bss so that the formula we give to z3 is much smalled
               val bssToKeep = prunedPrograms.foldLeft(Set[Identifier]())(_ ++ _)
-- 
GitLab