From dcc22b096cc2656ed49088572d3a0d23bd336564 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Mika=C3=ABl=20Mayer?= <a-mikmay@microsoft.com>
Date: Wed, 6 Jan 2016 17:28:55 +0100
Subject: [PATCH] Added string escape

---
 library/lang/StrOps.scala                          |  2 ++
 src/main/java/leon/codegen/runtime/StrOps.java     |  5 +++++
 src/main/scala/leon/codegen/CodeGeneration.scala   |  4 ++++
 src/main/scala/leon/datagen/VanuatooDataGen.scala  |  2 +-
 .../scala/leon/evaluators/RecursiveEvaluator.scala |  4 ++++
 .../leon/evaluators/StringTracingEvaluator.scala   | 14 +++++++++++++-
 .../leon/frontends/scalac/ASTExtractors.scala      |  8 ++++++++
 .../leon/frontends/scalac/CodeExtraction.scala     |  3 +++
 src/main/scala/leon/grammars/ValueGrammar.scala    |  7 +++++++
 src/main/scala/leon/purescala/Expressions.scala    |  7 +++++++
 src/main/scala/leon/purescala/Extractors.scala     |  2 ++
 src/main/scala/leon/purescala/PrettyPrinter.scala  |  8 +++++---
 src/main/scala/leon/purescala/ScalaPrinter.scala   |  3 ++-
 .../scala/leon/synthesis/rules/StringRender.scala  |  2 +-
 14 files changed, 64 insertions(+), 7 deletions(-)

diff --git a/library/lang/StrOps.scala b/library/lang/StrOps.scala
index 0c4d480ee..3c59590d0 100644
--- a/library/lang/StrOps.scala
+++ b/library/lang/StrOps.scala
@@ -18,4 +18,6 @@ object StrOps {
   def substring(a: String, start: BigInt, end: BigInt): String = {
     if(start > end || start >= length(a) || end <= 0) "" else a.substring(start.toInt, end.toInt)
   }
+  @ignore
+  def escape(a: String): String = a // Placeholder identity body: calls to this method are extracted as StringEscape nodes, which are evaluated and compiled with StringEscapeUtils.escapeJava.
 }
\ No newline at end of file
diff --git a/src/main/java/leon/codegen/runtime/StrOps.java b/src/main/java/leon/codegen/runtime/StrOps.java
index 1e5d7f60d..d8a889101 100644
--- a/src/main/java/leon/codegen/runtime/StrOps.java
+++ b/src/main/java/leon/codegen/runtime/StrOps.java
@@ -1,5 +1,7 @@
 package leon.codegen.runtime;
 
+import org.apache.commons.lang3.StringEscapeUtils;
+
 public class StrOps {
 	public static String concat(String a, String b) {
       return a + b;
@@ -34,4 +36,7 @@ public class StrOps {
 	public static String realToString (Real a) {
 	  return ""; // TODO: Not supported at this moment.
 	}
+	public static String escape(String s) { // Runtime counterpart of StringEscape: delegates to commons-lang3 StringEscapeUtils.escapeJava.
+	  return StringEscapeUtils.escapeJava(s);
+	}
 }
diff --git a/src/main/scala/leon/codegen/CodeGeneration.scala b/src/main/scala/leon/codegen/CodeGeneration.scala
index dc674dfbf..f1fced490 100644
--- a/src/main/scala/leon/codegen/CodeGeneration.scala
+++ b/src/main/scala/leon/codegen/CodeGeneration.scala
@@ -1199,6 +1199,10 @@ trait CodeGeneration {
         mkExpr(end, ch)
         ch << InvokeStatic(StrOpsClass, "substring", s"(L$JavaStringClass;L$BigIntClass;L$BigIntClass;)L$JavaStringClass;")
         
+      case StringEscape(a) =>
+        mkExpr(a, ch)
+        ch << InvokeStatic(StrOpsClass, "escape", s"(L$JavaStringClass;)L$JavaStringClass;")
+        
       // Arithmetic
       case Plus(l, r) =>
         mkExpr(l, ch)
diff --git a/src/main/scala/leon/datagen/VanuatooDataGen.scala b/src/main/scala/leon/datagen/VanuatooDataGen.scala
index 0f3df8aed..7b259f2ea 100644
--- a/src/main/scala/leon/datagen/VanuatooDataGen.scala
+++ b/src/main/scala/leon/datagen/VanuatooDataGen.scala
@@ -41,7 +41,7 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator {
     (n, d) -> Constructor[Expr, TypeTree](List(), RealType, s => FractionalLiteral(n, d), "" + n + "/" + d)
   }).toMap
 
-  val strings = (for (b <- Set("", "a", "b", "Abcd")) yield {
+  val strings = (for (b <- Set("", "a", "\"\t\n", "Abcd")) yield {
     b -> Constructor[Expr, TypeTree](List(), StringType, s => StringLiteral(b), b)
   }).toMap
 
diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala
index 7ef36c23c..ceb23eb38 100644
--- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala
+++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala
@@ -268,6 +268,10 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int
         case FractionalLiteral(n, d) => StringLiteral(n.toString + "/" + d.toString)
         case res =>  throw EvalError(typeErrorMsg(res, RealType))
       }
+    case StringEscape(a) => e(a) match {
+      case StringLiteral(i) => StringLiteral(codegen.runtime.StrOps.escape(i))
+      case res => throw EvalError(typeErrorMsg(res, StringType))
+    }
     
     case BVPlus(l,r) =>
       (e(l), e(r)) match {
diff --git a/src/main/scala/leon/evaluators/StringTracingEvaluator.scala b/src/main/scala/leon/evaluators/StringTracingEvaluator.scala
index 43ff4bc24..22d4cd8c1 100644
--- a/src/main/scala/leon/evaluators/StringTracingEvaluator.scala
+++ b/src/main/scala/leon/evaluators/StringTracingEvaluator.scala
@@ -56,7 +56,12 @@ class StringTracingEvaluator(ctx: LeonContext, prog: Program) extends Contextual
           case _ =>
             StringConcat(es1, es2)
         }
-        
+      case StringEscape(a) => // Escape concretely only once the operand has reduced to a string literal.
+        val ea = e(a)
+        ea match {
+          case StringLiteral(_) => super.e(StringEscape(ea)) // Pass the evaluated literal: wrapping `a` again would make super.e evaluate the operand a second time.
+          case _ => StringEscape(ea) // Operand did not reduce to a literal: keep the escape symbolic.
+        }
       case expr =>
         super.e(expr)
     }
@@ -94,6 +99,13 @@ class StringTracingEvaluator(ctx: LeonContext, prog: Program) extends Contextual
         case _ =>
           (StringLength(es1), StringLength(t1))
       }
+      
+    case StringEscape(a) =>
+      val (ea, ta) = e(a)
+      ea match {
+        case StringLiteral(_) => (underlying.e(StringEscape(ea)), StringEscape(ta))
+        case _ => (StringEscape(ea), StringEscape(ta))
+      }
 
     case expr@StringLiteral(s) => 
       (expr, expr)
diff --git a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala
index dc4b7481c..561d267df 100644
--- a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala
+++ b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala
@@ -200,6 +200,14 @@ trait ASTExtractors {
         case _ => None
       }
     }
+    
+    /** Matches a one-argument call to `leon.lang.StrOps.escape` and extracts the tree of that argument. */
+    object ExStringEscape {
+      def unapply(tree: Apply): Option[Tree] = tree match {
+        case Apply(ExSelected("leon", "lang", "StrOps", "escape"), List(arg)) => Some(arg)
+        case _ => None
+      }
+    }
 
     /** Extracts the 'require' contract from an expression (only if it's the
      * first call in the block). */
diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala
index 8c01b7731..4d6a1b113 100644
--- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala
+++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala
@@ -1509,6 +1509,9 @@ trait CodeExtraction extends ASTExtractors {
         case ExImplies(lhs, rhs) =>
           Implies(extractTree(lhs), extractTree(rhs)).setPos(current.pos)
 
+        case ExStringEscape(s) =>
+          StringEscape(extractTree(s))
+          
         case c @ ExCall(rec, sym, tps, args) =>
           // The object on which it is called is null if the symbol sym is a valid function in the scope and not a method.
           val rrec = rec match {
diff --git a/src/main/scala/leon/grammars/ValueGrammar.scala b/src/main/scala/leon/grammars/ValueGrammar.scala
index 5442b4ee7..7fa70729c 100644
--- a/src/main/scala/leon/grammars/ValueGrammar.scala
+++ b/src/main/scala/leon/grammars/ValueGrammar.scala
@@ -25,6 +25,13 @@ case object ValueGrammar extends ExpressionGrammar[TypeTree] {
         terminal(InfiniteIntegerLiteral(1)),
         terminal(InfiniteIntegerLiteral(5))
       )
+    case StringType =>
+      List(
+        terminal(StringLiteral("")),
+        terminal(StringLiteral("a")),
+        terminal(StringLiteral("\"'\n\r\t")),
+        terminal(StringLiteral("Lara 2007"))
+      )
 
     case tp: TypeParameter =>
       for (ind <- (1 to 3).toList) yield {
diff --git a/src/main/scala/leon/purescala/Expressions.scala b/src/main/scala/leon/purescala/Expressions.scala
index ae93221bf..aa216b852 100644
--- a/src/main/scala/leon/purescala/Expressions.scala
+++ b/src/main/scala/leon/purescala/Expressions.scala
@@ -588,6 +588,13 @@ object Expressions {
       else Untyped
     }
   }
+  /** $encodingof `StrOps.escape(expr)` for strings; its type is `StringType` when `expr` is a string and `Untyped` otherwise */
+  case class StringEscape(expr: Expr) extends Expr {
+    val getType = {
+      if (expr.getType == StringType) StringType
+      else Untyped
+    }
+  }
 
   /* Integer arithmetic */
 
diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala
index e2581dd8c..7e8277a4f 100644
--- a/src/main/scala/leon/purescala/Extractors.scala
+++ b/src/main/scala/leon/purescala/Extractors.scala
@@ -39,6 +39,8 @@ object Extractors {
         Some((Seq(t), (es: Seq[Expr]) => CharToString(es.head)))
       case RealToString(t) =>
         Some((Seq(t), (es: Seq[Expr]) => RealToString(es.head)))
+      case StringEscape(t) =>
+        Some((Seq(t), (es: Seq[Expr]) => StringEscape(es.head)))
       case SetCardinality(t) =>
         Some((Seq(t), (es: Seq[Expr]) => SetCardinality(es.head)))
       case CaseClassSelector(cd, e, sel) =>
diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala
index e77825dd2..554e70f43 100644
--- a/src/main/scala/leon/purescala/PrettyPrinter.scala
+++ b/src/main/scala/leon/purescala/PrettyPrinter.scala
@@ -12,6 +12,7 @@ import PrinterHelpers._
 import ExprOps.{isListLiteral, simplestValue}
 import Expressions._
 import Types._
+import org.apache.commons.lang3.StringEscapeUtils
 
 case class PrinterContext(
   current: Tree,
@@ -174,10 +175,11 @@ class PrettyPrinter(opts: PrinterOptions,
       case IntegerToString(expr)  => p"$expr.toString"
       case CharToString(expr)     => p"$expr.toString"
       case RealToString(expr)     => p"$expr.toString"
+      case StringEscape(expr)     => p"leon.lang.StrOps.escape($expr)"
       case StringConcat(lhs, rhs) => optP { p"$lhs + $rhs" }
     
-      case SubString(expr, start, end) => p"StrOps.substring($expr, $start, $end)"
-      case StringLength(expr)          => p"StrOps.length($expr)"
+      case SubString(expr, start, end) => p"leon.lang.StrOps.substring($expr, $start, $end)"
+      case StringLength(expr)          => p"leon.lang.StrOps.length($expr)"
 
       case IntLiteral(v)        => p"$v"
       case InfiniteIntegerLiteral(v) => p"$v"
@@ -191,7 +193,7 @@ class PrettyPrinter(opts: PrinterOptions,
         if(v.count(c => c == '\n') >= 1 && v.length >= 80 && v.indexOf("\"\"\"") == -1) {
           p"$dbquote$dbquote$dbquote$v$dbquote$dbquote$dbquote"
         } else {
-          val escaped = v.replaceAll(dbquote, "\\\\\"").replaceAll("\n","\\\\n").replaceAll("\r","\\\\r")
+          val escaped = StringEscapeUtils.escapeJava(v)
           p"$dbquote$escaped$dbquote"
         }
       case GenericValue(tp, id) => p"$tp#$id"
diff --git a/src/main/scala/leon/purescala/ScalaPrinter.scala b/src/main/scala/leon/purescala/ScalaPrinter.scala
index a72c1a37f..390e1af2c 100644
--- a/src/main/scala/leon/purescala/ScalaPrinter.scala
+++ b/src/main/scala/leon/purescala/ScalaPrinter.scala
@@ -9,6 +9,7 @@ import Common._
 import Expressions._
 import Types._
 import Definitions._
+import org.apache.commons.lang3.StringEscapeUtils
 
 /** This pretty-printer only print valid scala syntax */
 class ScalaPrinter(opts: PrinterOptions,
@@ -44,7 +45,7 @@ class ScalaPrinter(opts: PrinterOptions,
         if(v.count(c => c == '\n') >= 1 && v.length >= 80 && v.indexOf("\"\"\"") == -1) {
           p"$dbquote$dbquote$dbquote$v$dbquote$dbquote$dbquote"
         } else {
-          val escaped = v.replaceAll(dbquote, "\\\\\"").replaceAll("\n","\\\\n").replaceAll("\r","\\\\r")
+          val escaped = StringEscapeUtils.escapeJava(v)
           p"$dbquote$escaped$dbquote"
         }
 
diff --git a/src/main/scala/leon/synthesis/rules/StringRender.scala b/src/main/scala/leon/synthesis/rules/StringRender.scala
index 5c393ad4c..d2491d52f 100644
--- a/src/main/scala/leon/synthesis/rules/StringRender.scala
+++ b/src/main/scala/leon/synthesis/rules/StringRender.scala
@@ -377,7 +377,7 @@ case object StringRender extends Rule("StringRender") {
         case None => // No function can render the current type.
           input.getType match {
             case StringType =>
-              gatherInputs(ctx, q, result += Stream((input, Nil)))
+              gatherInputs(ctx, q, result += Stream((input, Nil))  #::: Stream((StringEscape(input): Expr, Nil)))
             case BooleanType =>
               val (bTemplate, vs) = booleanTemplate(input).instantiateWithVars
               gatherInputs(ctx, q, result += Stream((BooleanToString(input), Nil)) #::: Stream((bTemplate, vs)))
-- 
GitLab