From 5f5986a77478c85e3c087c3fbb8c8eafcece2112 Mon Sep 17 00:00:00 2001
From: Manos Koukoutos <emmanouil.koukoutos@epfl.ch>
Date: Fri, 4 Mar 2016 17:01:27 +0100
Subject: [PATCH] Introduce Hints in synthesis

Rules that decompose variables (e.g. ADTSplit) introduce the previous
recomposed expression as a Hint. These Hints are used by the default grammar
when they are of a specific size upwards (currently 4).
---
 src/main/scala/leon/grammars/Grammars.scala   |  9 +++--
 src/main/scala/leon/synthesis/Problem.scala   |  5 +++
 src/main/scala/leon/synthesis/Witnesses.scala |  8 +++++
 .../scala/leon/synthesis/rules/ADTSplit.scala | 35 ++++++++++---------
 .../leon/synthesis/rules/DetupleInput.scala   |  5 ++-
 5 files changed, 43 insertions(+), 19 deletions(-)

diff --git a/src/main/scala/leon/grammars/Grammars.scala b/src/main/scala/leon/grammars/Grammars.scala
index 867c5dcf2..1a9fb92dd 100644
--- a/src/main/scala/leon/grammars/Grammars.scala
+++ b/src/main/scala/leon/grammars/Grammars.scala
@@ -3,16 +3,19 @@
 package leon
 package grammars
 
+import synthesis.Witnesses.Hint
 import purescala.Expressions._
 import purescala.Definitions._
 import purescala.Types._
 import purescala.TypeOps._
+import purescala.Extractors.TopLevelAnds
+import purescala.ExprOps.formulaSize
 
 import synthesis.{SynthesisContext, Problem}
 
 object Grammars {
 
-  def default(prog: Program, inputs: Seq[Expr], currentFunction: FunDef, exclude: Set[FunDef], ws: Expr, pc: Expr): ExpressionGrammar[TypeTree] = {
+  def default(prog: Program, inputs: Seq[Expr], currentFunction: FunDef, exclude: Set[FunDef]): ExpressionGrammar[TypeTree] = {
     BaseGrammar ||
     EqualityGrammar(Set(IntegerType, Int32Type, BooleanType) ++ inputs.map { _.getType }) ||
     OneOf(inputs) ||
@@ -22,7 +25,9 @@ object Grammars {
   }
 
   def default(sctx: SynthesisContext, p: Problem): ExpressionGrammar[TypeTree] = {
-    default(sctx.program, p.as.map(_.toVariable), sctx.functionContext, sctx.settings.functionsToIgnore,  p.ws, p.pc)
+    val TopLevelAnds(ws) = p.ws
+    val hints = ws.collect{ case Hint(e) if formulaSize(e) >= 4 => e }
+    default(sctx.program, p.as.map(_.toVariable) ++ hints, sctx.functionContext, sctx.settings.functionsToIgnore)
   }
 
   def typeDepthBound[T <: Typed](g: ExpressionGrammar[T], b: Int) = {
diff --git a/src/main/scala/leon/synthesis/Problem.scala b/src/main/scala/leon/synthesis/Problem.scala
index a842e0c0e..f51fe3dd5 100644
--- a/src/main/scala/leon/synthesis/Problem.scala
+++ b/src/main/scala/leon/synthesis/Problem.scala
@@ -38,6 +38,11 @@ case class Problem(as: List[Identifier], ws: Expr, pc: Expr, phi: Expr, xs: List
         |⟧  """.stripMargin + ebInfo
   }
 
+  def withWs(es: Seq[Expr]) = {
+    val TopLevelAnds(prev) = ws
+    copy(ws = andJoin(prev ++ es))
+  }
+
   // Qualified example bank, allows us to perform operations (e.g. filter) with expressions
   def qeb(implicit sctx: SearchContext) = QualifiedExamplesBank(this.as, this.xs, eb)
 
diff --git a/src/main/scala/leon/synthesis/Witnesses.scala b/src/main/scala/leon/synthesis/Witnesses.scala
index ed33df173..60b1f5ea3 100644
--- a/src/main/scala/leon/synthesis/Witnesses.scala
+++ b/src/main/scala/leon/synthesis/Witnesses.scala
@@ -30,5 +30,13 @@ object Witnesses {
       p"↓$fi"
     }
   }
+
+  case class Hint(e: Expr) extends Witness {
+    def extract: Option[(Seq[Expr], Seq[Expr] => Expr)] = Some(( Seq(e), { case Seq(e) => Hint(e) }))
+
+    override def printWith(implicit pctx: PrinterContext): Unit = {
+      p"谶$e"
+    }
+  }
   
 }
diff --git a/src/main/scala/leon/synthesis/rules/ADTSplit.scala b/src/main/scala/leon/synthesis/rules/ADTSplit.scala
index d3dc63472..956340e86 100644
--- a/src/main/scala/leon/synthesis/rules/ADTSplit.scala
+++ b/src/main/scala/leon/synthesis/rules/ADTSplit.scala
@@ -4,6 +4,7 @@ package leon
 package synthesis
 package rules
 
+import Witnesses.Hint
 import purescala.Expressions._
 import purescala.Common._
 import purescala.Types._
@@ -66,27 +67,29 @@ case object ADTSplit extends Rule("ADT Split.") {
         val oas = p.as.filter(_ != id)
 
         val subInfo0 = for(ccd <- cases) yield {
-           val cct    = CaseClassType(ccd, act.tps)
+          val cct    = CaseClassType(ccd, act.tps)
 
-           val args   = cct.fields.map { vd => FreshIdentifier(vd.id.name, vd.getType, true) }.toList
+          val args   = cct.fields.map { vd => FreshIdentifier(vd.id.name, vd.getType, true) }.toList
 
-           val subPhi = subst(id -> CaseClass(cct, args.map(Variable)), p.phi)
-           val subPC  = subst(id -> CaseClass(cct, args.map(Variable)), p.pc)
-           val subWS  = subst(id -> CaseClass(cct, args.map(Variable)), p.ws)
+          val whole =  CaseClass(cct, args.map(Variable))
 
-           val eb2 = p.qeb.mapIns { inInfo =>
-              inInfo.toMap.apply(id) match {
-                case CaseClass(`cct`, vs) =>
-                  List(vs ++ inInfo.filter(_._1 != id).map(_._2))
-                case _ =>
-                  Nil
-              }
-           }
+          val subPhi = subst(id -> whole, p.phi)
+          val subPC  = subst(id -> whole, p.pc)
+          val subWS  = subst(id -> whole, p.ws)
+
+          val eb2 = p.qeb.mapIns { inInfo =>
+             inInfo.toMap.apply(id) match {
+               case CaseClass(`cct`, vs) =>
+                 List(vs ++ inInfo.filter(_._1 != id).map(_._2))
+               case _ =>
+                 Nil
+             }
+          }
 
-           val subProblem = Problem(args ::: oas, subWS, subPC, subPhi, p.xs, eb2)
-           val subPattern = CaseClassPattern(None, cct, args.map(id => WildcardPattern(Some(id))))
+          val subProblem = Problem(args ::: oas, subWS, subPC, subPhi, p.xs, eb2).withWs(Seq(Hint(whole)))
+          val subPattern = CaseClassPattern(None, cct, args.map(id => WildcardPattern(Some(id))))
 
-           (cct, subProblem, subPattern)
+          (cct, subProblem, subPattern)
         }
 
         val subInfo = subInfo0.sortBy{ case (cct, _, _) =>
diff --git a/src/main/scala/leon/synthesis/rules/DetupleInput.scala b/src/main/scala/leon/synthesis/rules/DetupleInput.scala
index d3b4c823d..3c38a0170 100644
--- a/src/main/scala/leon/synthesis/rules/DetupleInput.scala
+++ b/src/main/scala/leon/synthesis/rules/DetupleInput.scala
@@ -4,6 +4,7 @@ package leon
 package synthesis
 package rules
 
+import Witnesses.Hint
 import purescala.Expressions._
 import purescala.Common._
 import purescala.Types._
@@ -60,6 +61,7 @@ case object DetupleInput extends NormalizingRule("Detuple In") {
       var subProblem = p.phi
       var subPc      = p.pc
       var subWs      = p.ws
+      var hints: Seq[Expr] = Nil
 
       var reverseMap = Map[Identifier, Expr]()
 
@@ -72,6 +74,7 @@ case object DetupleInput extends NormalizingRule("Detuple In") {
           subProblem = subst(a -> expr, subProblem)
           subPc      = subst(a -> expr, subPc)
           subWs      = subst(a -> expr, subWs)
+          hints      +:= Hint(expr)
 
           reverseMap ++= map
 
@@ -125,7 +128,7 @@ case object DetupleInput extends NormalizingRule("Detuple In") {
         case other => other
       }
       
-      val sub = Problem(newAs, subWs, subPc, subProblem, p.xs, eb)
+      val sub = Problem(newAs, subWs, subPc, subProblem, p.xs, eb).withWs(hints)
 
       val s = (substAll(reverseMap, _:Expr)) andThen simplePostTransform(recompose)
      
-- 
GitLab