From cf6bb9a221befd5f69122529721b621c71fbc021 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Mika=C3=ABl=20Mayer?= <a-mikmay@microsoft.com>
Date: Tue, 1 Dec 2015 18:00:12 +0100
Subject: [PATCH] Added tracking Ids in templates to fill missing ones with
 "_edit_me_" Corrected pretty printing of "\n"

---
 .../scala/leon/purescala/PrettyPrinter.scala  |   2 +-
 .../scala/leon/purescala/ScalaPrinter.scala   |   2 +-
 .../leon/synthesis/rules/StringRender.scala   | 100 ++++++++++--------
 3 files changed, 56 insertions(+), 48 deletions(-)

diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala
index 9acbe6c5a..33c200a5a 100644
--- a/src/main/scala/leon/purescala/PrettyPrinter.scala
+++ b/src/main/scala/leon/purescala/PrettyPrinter.scala
@@ -186,7 +186,7 @@ class PrettyPrinter(opts: PrinterOptions,
         if(v.count(c => c == '\n') >= 1 && v.length >= 80 && v.indexOf("\"\"\"") == -1) {
           p"$dbquote$dbquote$dbquote$v$dbquote$dbquote$dbquote"
         } else {
-          val escaped = v.replaceAll(dbquote, "\\\\\"").replaceAll("\n","\\n").replaceAll("\r","\\r")
+          val escaped = v.replaceAll(dbquote, "\\\\\"").replaceAll("\n","\\\\n").replaceAll("\r","\\\\r")
           p"$dbquote$escaped$dbquote"
         }
       case GenericValue(tp, id) => p"$tp#$id"
diff --git a/src/main/scala/leon/purescala/ScalaPrinter.scala b/src/main/scala/leon/purescala/ScalaPrinter.scala
index 20b8942e9..a72c1a37f 100644
--- a/src/main/scala/leon/purescala/ScalaPrinter.scala
+++ b/src/main/scala/leon/purescala/ScalaPrinter.scala
@@ -44,7 +44,7 @@ class ScalaPrinter(opts: PrinterOptions,
         if(v.count(c => c == '\n') >= 1 && v.length >= 80 && v.indexOf("\"\"\"") == -1) {
           p"$dbquote$dbquote$dbquote$v$dbquote$dbquote$dbquote"
         } else {
-          val escaped = v.replaceAll(dbquote, "\\\\\"").replaceAll("\n","\\n").replaceAll("\r","\\r")
+          val escaped = v.replaceAll(dbquote, "\\\\\"").replaceAll("\n","\\\\n").replaceAll("\r","\\\\r")
           p"$dbquote$escaped$dbquote"
         }
 
diff --git a/src/main/scala/leon/synthesis/rules/StringRender.scala b/src/main/scala/leon/synthesis/rules/StringRender.scala
index 1aebc53d7..133b00d4e 100644
--- a/src/main/scala/leon/synthesis/rules/StringRender.scala
+++ b/src/main/scala/leon/synthesis/rules/StringRender.scala
@@ -36,24 +36,31 @@ import leon.programsets.{UnionProgramSet, DirectProgramSet, JoinProgramSet}
   * Each call to the ``.instantiate` method of the subsequent Template will provide different instances at each position of the hole.
   */
 abstract class TypedTemplateGenerator(t: TypeTree) {
+  import StringRender.WithIds
   /** Provides a hole which can be */
   def apply(f: Expr => Expr): TemplateGenerator = {
     val id = FreshIdentifier("ConstToInstantiate", t, true)
     new TemplateGenerator(f(Variable(id)), id, t)
   }
-  class TemplateGenerator(template: Expr, varId: Identifier, t: TypeTree) {
-    private val optimizationVars = ListBuffer[Identifier]()
+  def nested(f: Expr => WithIds[Expr]): TemplateGenerator = {
+    val id = FreshIdentifier("ConstToInstantiate", t, true)
+    val res = f(Variable(id))
+    new TemplateGenerator(res._1, id, t, res._2)
+  }
+  class TemplateGenerator(template: Expr, varId: Identifier, t: TypeTree, initialHoles: List[Identifier] = Nil) {
+    private val optimizationVars = ListBuffer[Identifier]() ++= initialHoles
     private def Const: Variable = {
       val res = FreshIdentifier("const", t, true)
       optimizationVars += res
       Variable(res)
     }
-    def instantiate = {
+    private def instantiate: Expr = {
       ExprOps.postMap({
         case Variable(id) if id == varId => Some(Const)
         case _ => None
       })(template)
     }
+    def instantiateWithVars: WithIds[Expr] = (instantiate, optimizationVars.toList)
   }
 }
 
@@ -61,6 +68,7 @@ abstract class TypedTemplateGenerator(t: TypeTree) {
  * @author Mikael
  */
 case object StringRender extends Rule("StringRender") {
+  type WithIds[T] = (T, List[Identifier])
   
   var _defaultTypeToString: Option[Map[TypeTree, FunDef]] = None
   
@@ -155,16 +163,16 @@ case object StringRender extends Rule("StringRender") {
     }
   }
   
-  def findSolutions(examples: ExamplesBank, template: Stream[Expr], funDefs: Seq[(FunDef, Stream[Expr])])(implicit hctx: SearchContext, p: Problem): RuleApplication = {
+  def findSolutions(examples: ExamplesBank, template: Stream[WithIds[Expr]], funDefs: Seq[(FunDef, Stream[WithIds[Expr]])])(implicit hctx: SearchContext, p: Problem): RuleApplication = {
     // Fun is a stream of many function applications.
     val funs= JoinProgramSet.direct(funDefs.map(fbody => fbody._2.map((fbody._1, _))).map(d => DirectProgramSet(d)))
     
     val wholeTemplates = JoinProgramSet.direct(funs, DirectProgramSet(template))
     
-    def computeSolutions(funDefsBodies: Seq[(FunDef, Expr)], template: Expr): Stream[Assignment] = {
-      val funDefs = for((funDef, body) <- funDefsBodies) yield  { funDef.body = Some(body); funDef }
+    def computeSolutions(funDefsBodies: Seq[(FunDef, WithIds[Expr])], template: WithIds[Expr]): Stream[Assignment] = {
+      val funDefs = for((funDef, body) <- funDefsBodies) yield  { funDef.body = Some(body._1); funDef }
       val newProgram = DefOps.addFunDefs(hctx.program, funDefs, hctx.sctx.functionContext)
-      findAssignments(newProgram, p.as, examples, template)
+      findAssignments(newProgram, p.as, examples, template._1)
     }
     
     val tagged_solutions =
@@ -173,17 +181,18 @@ case object StringRender extends Rule("StringRender") {
     solutionStreamToRuleApplication(p, leon.utils.StreamUtils.interleave(tagged_solutions))
   }
   
-  def solutionStreamToRuleApplication(p: Problem, solutions: Stream[(Seq[(FunDef, Expr)], Expr, Assignment)]): RuleApplication = {
+  def solutionStreamToRuleApplication(p: Problem, solutions: Stream[(Seq[(FunDef, WithIds[Expr])], WithIds[Expr], Assignment)]): RuleApplication = {
     if(solutions.isEmpty) RuleFailed() else {
       RuleClosed(
-          for((funDefsBodies, singleTemplate, assignment) <- solutions) yield {
-            val template = (singleTemplate /: funDefsBodies) {
-              case (e, (fd, body)) =>
-                fd.body = Some(body)
-                LetDef(fd, e)
+          for((funDefsBodies, (singleTemplate, ids), assignment) <- solutions) yield {
+            val fds = for((fd, (body, ids)) <- funDefsBodies) yield {
+              val initMap = ids.map(_ -> StringLiteral("_edit_me_")).toMap
+              fd.body = Some(ExprOps.simplifyString(ExprOps.replaceFromIDs(initMap ++ assignment.mapValues(StringLiteral), body)))
+              fd
             }
-            val term = ExprOps.simplifyString(ExprOps.replaceFromIDs(assignment.mapValues(StringLiteral), template))
-            Solution(pre=p.pc, defs=Set(), term=term)
+            val initMap = ids.map(_ -> StringLiteral("_edit_me_")).toMap
+            val term = ExprOps.simplifyString(ExprOps.replaceFromIDs(initMap ++ assignment.mapValues(StringLiteral), singleTemplate))
+            Solution(pre=p.pc, defs=fds.toSet, term=term)
           })
     }
   }
@@ -214,7 +223,7 @@ case object StringRender extends Rule("StringRender") {
   }
   
   /** Pre-updates of the function definition */
-  def preUpdateFunDefBody(tpe: DependentType, fd: FunDef, assignments: Map[DependentType, (FunDef, Stream[Expr])]): Map[DependentType, (FunDef, Stream[Expr])] = {
+  def preUpdateFunDefBody(tpe: DependentType, fd: FunDef, assignments: Map[DependentType, (FunDef, Stream[WithIds[Expr]])]): Map[DependentType, (FunDef, Stream[WithIds[Expr]])] = {
     assignments.get(tpe) match {
       case None => assignments + (tpe -> ((fd, Stream.Empty)))
       case Some(_) => assignments
@@ -222,7 +231,7 @@ case object StringRender extends Rule("StringRender") {
   }
 
   /** Assembles multiple MatchCase to a singleMatchExpr using the function definition fd */
-  private val mergeMatchCases = (fd: FunDef) => (cases: Seq[MatchCase]) => MatchExpr(Variable(fd.params(0).id), cases)
+  private val mergeMatchCases = (fd: FunDef) => (cases: Seq[WithIds[MatchCase]]) => (MatchExpr(Variable(fd.params(0).id), cases.map(_._1)), cases.map(_._2).flatten.toList)
   
   /** Returns a (possibly recursive) template which can render the inputs in their order.
     * Returns an expression and path-dependent pretty printers which can be used.
@@ -231,16 +240,17 @@ case object StringRender extends Rule("StringRender") {
     **/
   def createFunDefsTemplates(
       currentCaseClassParent: Option[TypeTree],
-      adtToString: Map[DependentType, (FunDef, Stream[Expr])],
-      inputs: Seq[Identifier])(implicit hctx: SearchContext): (Stream[Expr], Map[DependentType, (FunDef, Stream[Expr])]) = {
+      adtToString: Map[DependentType, (FunDef, Stream[WithIds[Expr]])],
+      inputs: Seq[Identifier])(implicit hctx: SearchContext): (Stream[WithIds[Expr]], Map[DependentType, (FunDef, Stream[WithIds[Expr]])]) = {
     
-    def extractCaseVariants(cct: CaseClassType, assignments2: Map[DependentType, (FunDef, Stream[Expr])]) = cct match {
+    def extractCaseVariants(cct: CaseClassType, assignments2: Map[DependentType, (FunDef, Stream[WithIds[Expr]])])
+      : (Map[DependentType, (FunDef, Stream[WithIds[Expr]])], Stream[WithIds[MatchCase]]) = cct match {
       case CaseClassType(ccd@CaseClassDef(id, tparams, parent, isCaseObject), tparams2) =>
         val typeMap = tparams.zip(tparams2).toMap
         val fields = ccd.fields.map(vd => TypeOps.instantiateType(vd, typeMap).id )
         val pattern = CaseClassPattern(None, ccd.typed(tparams2), fields.map(k => WildcardPattern(Some(k))))
         val (rhs, assignments2tmp2) = createFunDefsTemplates(Some(cct), assignments2, fields) // Invoke functions for each of the fields.
-        val newCases = rhs.map(MatchCase(pattern, None, _))
+        val newCases = rhs.map(e => (MatchCase(pattern, None, e._1), e._2))
         (assignments2tmp2, newCases)
     }
     
@@ -259,46 +269,43 @@ case object StringRender extends Rule("StringRender") {
       * }}}
       * 
       * */
-    def constantPatternMatching(fd: FunDef, act: AbstractClassType): MatchExpr = {
-      val cases = (ListBuffer[MatchCase]() /: act.knownCCDescendants) {
+    def constantPatternMatching(fd: FunDef, act: AbstractClassType): WithIds[MatchExpr] = {
+      val cases = (ListBuffer[WithIds[MatchCase]]() /: act.knownCCDescendants) {
         case (acc, cct@CaseClassType(ccd@CaseClassDef(id, tparams, parent, isCaseObject), tparams2)) =>
           val typeMap = tparams.zip(tparams2).toMap
           val fields = ccd.fields.map(vd => TypeOps.instantiateType(vd, typeMap).id )
           val pattern = CaseClassPattern(None, ccd.typed(tparams2), fields.map(k => WildcardPattern(Some(k))))
           val rhs = StringLiteral(id.asString)
           MatchCase(pattern, None, rhs)
-          acc += MatchCase(pattern, None, rhs)
+          acc += ((MatchCase(pattern, None, rhs), Nil))
         case (acc, e) => hctx.reporter.fatalError("Could not handle this class definition for string rendering " + e)
       }
       mergeMatchCases(fd)(cases)
     }
     
     /* Returns a list of expressions converting the list of inputs to string.
-     * Holes should be inserted before, after and in-between for solving concatenation.
+     * Each expression is tagged with a list of identifiers, which is the list of variables which need to be found.
      * @return Along with the list, an updated function definitions to transform (parent-dependent) types to strings */
     @tailrec def gatherInputs(
         currentCaseClassParent: Option[TypeTree],
-        assignments1: Map[DependentType, (FunDef, Stream[Expr])],
+        assignments1: Map[DependentType, (FunDef, Stream[WithIds[Expr]])],
         inputs: List[Identifier],
-        result: ListBuffer[Stream[Expr]] = ListBuffer()): (List[Stream[Expr]], Map[DependentType, (FunDef, Stream[Expr])]) = inputs match {
+        result: ListBuffer[Stream[WithIds[Expr]]] = ListBuffer()): (List[Stream[WithIds[Expr]]], Map[DependentType, (FunDef, Stream[WithIds[Expr]])]) = inputs match {
       case Nil => (result.toList, assignments1)
       case input::q => 
         val dependentType = DependentType(currentCaseClassParent, input.asString(hctx.program)(hctx.context), input.getType)
         assignments1.get(dependentType) match {
         case Some(fd) =>
-          gatherInputs(currentCaseClassParent, assignments1, q, result += Stream(functionInvocation(fd._1, Seq(Variable(input)))))
+          gatherInputs(currentCaseClassParent, assignments1, q, result += Stream((functionInvocation(fd._1, Seq(Variable(input))), Nil)))
         case None => // No function can render the current type.
           input.getType match {
             case StringType =>
-              gatherInputs(currentCaseClassParent, assignments1, q, result += Stream(Variable(input)))
+              gatherInputs(currentCaseClassParent, assignments1, q, result += Stream((Variable(input), Nil)))
             case BooleanType =>
-              // Special case. But might be overkill.
-              // It should be possible to have generic conversion instead, else it needs two examples, which might be cu
-              // gatherInputs(currentCaseClassParent, assignments1, q, result += booleanTemplate(input).instantiate)
-              // OR
-              gatherInputs(currentCaseClassParent, assignments1, q, result += Stream(BooleanToString(Variable(input)), booleanTemplate(input).instantiate))
+              val (bTemplate, vs) = booleanTemplate(input).instantiateWithVars
+              gatherInputs(currentCaseClassParent, assignments1, q, result += Stream((BooleanToString(Variable(input)), Nil)) #::: Stream((bTemplate, vs)))
             case WithStringconverter(converter) => // Base case
-              gatherInputs(currentCaseClassParent, assignments1, q, result += Stream(converter(Variable(input))))
+              gatherInputs(currentCaseClassParent, assignments1, q, result += Stream((converter(Variable(input)), Nil)))
             case t: ClassType =>
               // Create the empty function body and updates the assignments parts.
               val fd = createEmptyFunDef(dependentType)
@@ -313,7 +320,7 @@ case object StringRender extends Rule("StringRender") {
                   }}
                   
                   //TODO: Test other templates not only with Wilcard patterns, but more cases options for non-recursive classes (e.g. Option, Boolean, Finite parameterless case classes.)
-                  val (assignments3, cases) = ((assignments2, ListBuffer[Stream[MatchCase]]()) /: act.knownCCDescendants) {
+                  val (assignments3, cases) = ((assignments2, ListBuffer[Stream[WithIds[MatchCase]]]()) /: act.knownCCDescendants) {
                     case ((assignments2, acc), cct@CaseClassType(ccd@CaseClassDef(id, tparams, parent, isCaseObject), tparams2)) =>
                       val (assignments2tmp2, newCases) = extractCaseVariants(cct, assignments2)
                       (assignments2tmp2, acc += newCases)
@@ -325,12 +332,12 @@ case object StringRender extends Rule("StringRender") {
                     Stream(constantPatternMatching(fd, act)) ++ allMatchExprsEnd
                   } else allMatchExprsEnd
                   val assignments4 = assignments3 + (dependentType -> (fd, allMatchExprs))
-                  gatherInputs(currentCaseClassParent, assignments4, q, result += Stream(functionInvocation(fd, Seq(Variable(input)))))
+                  gatherInputs(currentCaseClassParent, assignments4, q, result += Stream((functionInvocation(fd, Seq(Variable(input))), Nil)))
                 case cct@CaseClassType(ccd@CaseClassDef(id, tparams, parent, isCaseObject), tparams2) =>
                   val (assignments3, newCases) = extractCaseVariants(cct, assignments2)
                   val allMatchExprs = newCases.map(acase => mergeMatchCases(fd)(Seq(acase)))
                   val assignments4 = assignments3 + (dependentType -> (fd, allMatchExprs))
-                  gatherInputs(currentCaseClassParent, assignments4, q, result += Stream(functionInvocation(fd, Seq(Variable(input)))))
+                  gatherInputs(currentCaseClassParent, assignments4, q, result += Stream((functionInvocation(fd, Seq(Variable(input))), Nil)))
               }
             case TypeParameter(t) =>
               hctx.reporter.fatalError("Could not handle type parameter for string rendering " + t)
@@ -342,16 +349,17 @@ case object StringRender extends Rule("StringRender") {
     val (exprs, assignments) = gatherInputs(currentCaseClassParent, adtToString, inputs.toList)
     /** Add post, pre and in-between holes, and returns a single expr along with the new assignments. */
     
-    val template: Stream[Expr] = exprs match {
+    val template: Stream[WithIds[Expr]] = exprs match {
       case Nil =>
-        Stream(StringTemplateGenerator(Hole => Hole).instantiate)
+        Stream(StringTemplateGenerator(Hole => Hole).instantiateWithVars)
       case exprList =>
-        JoinProgramSet(exprList.map(DirectProgramSet(_)), (exprs: Seq[Expr]) =>
-            StringTemplateGenerator(Hole => {
-              StringConcat((StringConcat(Hole, exprs.head) /: exprs.tail) {
-                case (finalExpr, expr) => StringConcat(StringConcat(finalExpr, Hole), expr)
-              }, Hole)
-            }).instantiate
+        JoinProgramSet(exprList.map(DirectProgramSet(_)), (exprs: Seq[WithIds[Expr]]) =>
+            StringTemplateGenerator.nested(Hole => {
+              val res = ((StringConcat(Hole, exprs.head._1), exprs.head._2) /: exprs.tail) {
+                case ((finalExpr, finalIds), (expr, ids)) => (StringConcat(StringConcat(finalExpr, Hole), expr), finalIds ++ ids)
+              }
+              (StringConcat(res._1, Hole), res._2)
+            }).instantiateWithVars
         ).programs
     }
     (template, assignments)
-- 
GitLab