diff --git a/library/lang/StrOps.scala b/library/lang/StrOps.scala index 0c4d480eeb5ac0ccc1021c7bac1cf37b07520424..3c59590d073de60b109c8d8dc3d48b330749a72c 100644 --- a/library/lang/StrOps.scala +++ b/library/lang/StrOps.scala @@ -18,4 +18,6 @@ object StrOps { def substring(a: String, start: BigInt, end: BigInt): String = { if(start > end || start >= length(a) || end <= 0) "" else a.substring(start.toInt, end.toInt) } + @ignore + def escape(a: String): String = a // Wrong definition, but it will eventually use StringEscapeUtils.escapeJava(s) at parsing and compile time. } \ No newline at end of file diff --git a/src/main/java/leon/codegen/runtime/StrOps.java b/src/main/java/leon/codegen/runtime/StrOps.java index 1e5d7f60d5086dd1a45b4e3e0a52fda11db842a1..d8a889101e9fd90e9208ad30ec9231d3264a36cf 100644 --- a/src/main/java/leon/codegen/runtime/StrOps.java +++ b/src/main/java/leon/codegen/runtime/StrOps.java @@ -1,5 +1,7 @@ package leon.codegen.runtime; +import org.apache.commons.lang3.StringEscapeUtils; + public class StrOps { public static String concat(String a, String b) { return a + b; @@ -34,4 +36,7 @@ public class StrOps { public static String realToString (Real a) { return ""; // TODO: Not supported at this moment. } + public static String escape(String s) { + return StringEscapeUtils.escapeJava(s); + } } diff --git a/src/main/scala/leon/codegen/CodeGeneration.scala b/src/main/scala/leon/codegen/CodeGeneration.scala index dc674dfbf2e7f35ba89fe2fbe6bc402d63e1bbde..f1fced4905e524d9185602756a8971479faca9aa 100644 --- a/src/main/scala/leon/codegen/CodeGeneration.scala +++ b/src/main/scala/leon/codegen/CodeGeneration.scala @@ -1199,6 +1199,10 @@ trait CodeGeneration { mkExpr(end, ch) ch << InvokeStatic(StrOpsClass, "substring", s"(L$JavaStringClass;L$BigIntClass;L$BigIntClass;)L$JavaStringClass;") + case StringEscape(a) => + mkExpr(a, ch) + ch << InvokeStatic(StrOpsClass, "escape", s"(L$JavaStringClass;)L$JavaStringClass;") + // Arithmetic case Plus(l, r) => mkExpr(l, ch) diff --git a/src/main/scala/leon/datagen/VanuatooDataGen.scala b/src/main/scala/leon/datagen/VanuatooDataGen.scala index 0f3df8aede362bb3e3c7fb2c6431c3daf059556c..7b259f2ea8fe20ab55f75c8fbce7dc34fbc7ee29 100644 --- a/src/main/scala/leon/datagen/VanuatooDataGen.scala +++ b/src/main/scala/leon/datagen/VanuatooDataGen.scala @@ -41,7 +41,7 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { (n, d) -> Constructor[Expr, TypeTree](List(), RealType, s => FractionalLiteral(n, d), "" + n + "/" + d) }).toMap - val strings = (for (b <- Set("", "a", "b", "Abcd")) yield { + val strings = (for (b <- Set("", "a", "\"\t\n", "Abcd")) yield { b -> Constructor[Expr, TypeTree](List(), StringType, s => StringLiteral(b), b) }).toMap diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala index 7ef36c23cfdbc5a540f8eb4b11f9dddc1012ba59..ceb23eb38c62649e2cb1a1d8f1bbcea10f109100 100644 --- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala +++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala @@ -268,6 +268,10 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int case FractionalLiteral(n, d) => StringLiteral(n.toString + "/" + d.toString) case res => throw EvalError(typeErrorMsg(res, RealType)) } + case StringEscape(a) => e(a) match { + case StringLiteral(i) => StringLiteral(codegen.runtime.StrOps.escape(i)) + case res => throw EvalError(typeErrorMsg(res, StringType)) + } case BVPlus(l,r) => (e(l), e(r)) match { diff --git a/src/main/scala/leon/evaluators/StringTracingEvaluator.scala b/src/main/scala/leon/evaluators/StringTracingEvaluator.scala index 43ff4bc2459b9a125d60d2ce16992e90b491505b..22d4cd8c186235f391c418e0c2b8fe2b649dc302 100644 --- a/src/main/scala/leon/evaluators/StringTracingEvaluator.scala +++ b/src/main/scala/leon/evaluators/StringTracingEvaluator.scala @@ -56,7 +56,12 @@ class StringTracingEvaluator(ctx: LeonContext, prog: Program) extends Contextual case _ => StringConcat(es1, es2) } - + case StringEscape(a) => + val ea = e(a) + ea match { + case StringLiteral(_) => super.e(StringEscape(a)) + case _ => StringEscape(ea) + } case expr => super.e(expr) } @@ -94,6 +99,13 @@ class StringTracingEvaluator(ctx: LeonContext, prog: Program) extends Contextual case _ => (StringLength(es1), StringLength(t1)) } + + case StringEscape(a) => + val (ea, ta) = e(a) + ea match { + case StringLiteral(_) => (underlying.e(StringEscape(ea)), StringEscape(ta)) + case _ => (StringEscape(ea), StringEscape(ta)) + } case expr@StringLiteral(s) => (expr, expr) diff --git a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala index dc4b7481cc545219455e6b4cbb8068f1ce04617d..561d267dfe7af1a3531d7a6b9c59c8a391628c5c 100644 --- a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala +++ b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala @@ -200,6 +200,14 @@ trait ASTExtractors { case _ => None } } + + /** Matches a call to StrOps.escape */ + object ExStringEscape { + def unapply(tree: Apply): Option[Tree] = tree match { + case Apply(ExSelected("leon", "lang", "StrOps", "escape"), List(arg)) => Some(arg) + case _ => None + } + } /** Extracts the 'require' contract from an expression (only if it's the * first call in the block). */ diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index 8c01b7731c2dd412d22da5816b6f6873ca0e0219..4d6a1b113e10292b32ab91d0bd757fa6a0befc8e 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -1509,6 +1509,9 @@ trait CodeExtraction extends ASTExtractors { case ExImplies(lhs, rhs) => Implies(extractTree(lhs), extractTree(rhs)).setPos(current.pos) + case ExStringEscape(s) => + StringEscape(extractTree(s)) + case c @ ExCall(rec, sym, tps, args) => // The object on which it is called is null if the symbol sym is a valid function in the scope and not a method. val rrec = rec match { diff --git a/src/main/scala/leon/grammars/ValueGrammar.scala b/src/main/scala/leon/grammars/ValueGrammar.scala index 5442b4ee7723fb124aab8aea65647ed762a9f0ec..7fa70729c68bfe1a3463e9f3b6ccc6e7de92517e 100644 --- a/src/main/scala/leon/grammars/ValueGrammar.scala +++ b/src/main/scala/leon/grammars/ValueGrammar.scala @@ -25,6 +25,13 @@ case object ValueGrammar extends ExpressionGrammar[TypeTree] { terminal(InfiniteIntegerLiteral(1)), terminal(InfiniteIntegerLiteral(5)) ) + case StringType => + List( + terminal(StringLiteral("")), + terminal(StringLiteral("a")), + terminal(StringLiteral("\"'\n\r\t")), + terminal(StringLiteral("Lara 2007")) + ) case tp: TypeParameter => for (ind <- (1 to 3).toList) yield { diff --git a/src/main/scala/leon/purescala/Expressions.scala b/src/main/scala/leon/purescala/Expressions.scala index ae93221bfe8313cf8cef20ed084d559a2a472133..aa216b852507ef66780aa47885c84c87e8e79f7e 100644 --- a/src/main/scala/leon/purescala/Expressions.scala +++ b/src/main/scala/leon/purescala/Expressions.scala @@ -588,6 +588,13 @@ object Expressions { else Untyped } } + /** $encodingof `StrOps.escape(expr)` for strings */ + case class StringEscape(expr: Expr) extends Expr { + val getType = { + if (expr.getType == StringType) StringType + else Untyped + } + } /* Integer arithmetic */ diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala index e2581dd8cdb33e3d025e8206d72dd32fc4ca59f7..7e8277a4fab21f835eafc85e94d567bc3b7050bc 100644 --- a/src/main/scala/leon/purescala/Extractors.scala +++ b/src/main/scala/leon/purescala/Extractors.scala @@ -39,6 +39,8 @@ object Extractors { Some((Seq(t), (es: Seq[Expr]) => CharToString(es.head))) case RealToString(t) => Some((Seq(t), (es: Seq[Expr]) => RealToString(es.head))) + case StringEscape(t) => + Some((Seq(t), (es: Seq[Expr]) => StringEscape(es.head))) case SetCardinality(t) => Some((Seq(t), (es: Seq[Expr]) => SetCardinality(es.head))) case CaseClassSelector(cd, e, sel) => diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index e77825dd212725a46f346fe6582463b4b7f90f2a..554e70f4343beb38d4ed2773a292cb31476ffe0d 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -12,6 +12,7 @@ import PrinterHelpers._ import ExprOps.{isListLiteral, simplestValue} import Expressions._ import Types._ +import org.apache.commons.lang3.StringEscapeUtils case class PrinterContext( current: Tree, @@ -174,10 +175,11 @@ class PrettyPrinter(opts: PrinterOptions, case IntegerToString(expr) => p"$expr.toString" case CharToString(expr) => p"$expr.toString" case RealToString(expr) => p"$expr.toString" + case StringEscape(expr) => p"leon.lang.StrOps.escape($expr)" case StringConcat(lhs, rhs) => optP { p"$lhs + $rhs" } - case SubString(expr, start, end) => p"StrOps.substring($expr, $start, $end)" - case StringLength(expr) => p"StrOps.length($expr)" + case SubString(expr, start, end) => p"leon.lang.StrOps.substring($expr, $start, $end)" + case StringLength(expr) => p"leon.lang.StrOps.length($expr)" case IntLiteral(v) => p"$v" case InfiniteIntegerLiteral(v) => p"$v" @@ -191,7 +193,7 @@ class PrettyPrinter(opts: PrinterOptions, if(v.count(c => c == '\n') >= 1 && v.length >= 80 && v.indexOf("\"\"\"") == -1) { p"$dbquote$dbquote$dbquote$v$dbquote$dbquote$dbquote" } else { - val escaped = v.replaceAll(dbquote, "\\\\\"").replaceAll("\n","\\\\n").replaceAll("\r","\\\\r") + val escaped = StringEscapeUtils.escapeJava(v) p"$dbquote$escaped$dbquote" } case GenericValue(tp, id) => p"$tp#$id" diff --git a/src/main/scala/leon/purescala/ScalaPrinter.scala b/src/main/scala/leon/purescala/ScalaPrinter.scala index a72c1a37ff9b3cb9752aac57bc97d6ac2fc30e52..390e1af2c62ede0bc0bf934f582d248832650226 100644 --- a/src/main/scala/leon/purescala/ScalaPrinter.scala +++ b/src/main/scala/leon/purescala/ScalaPrinter.scala @@ -9,6 +9,7 @@ import Common._ import Expressions._ import Types._ import Definitions._ +import org.apache.commons.lang3.StringEscapeUtils /** This pretty-printer only print valid scala syntax */ class ScalaPrinter(opts: PrinterOptions, @@ -44,7 +45,7 @@ class ScalaPrinter(opts: PrinterOptions, if(v.count(c => c == '\n') >= 1 && v.length >= 80 && v.indexOf("\"\"\"") == -1) { p"$dbquote$dbquote$dbquote$v$dbquote$dbquote$dbquote" } else { - val escaped = v.replaceAll(dbquote, "\\\\\"").replaceAll("\n","\\\\n").replaceAll("\r","\\\\r") + val escaped = StringEscapeUtils.escapeJava(v) p"$dbquote$escaped$dbquote" } diff --git a/src/main/scala/leon/synthesis/rules/StringRender.scala b/src/main/scala/leon/synthesis/rules/StringRender.scala index 5c393ad4ced2026ca5cea966fdbcc175a53153b1..d2491d52f10a932b1c956c2b74f4f7690f68eed7 100644 --- a/src/main/scala/leon/synthesis/rules/StringRender.scala +++ b/src/main/scala/leon/synthesis/rules/StringRender.scala @@ -377,7 +377,7 @@ case object StringRender extends Rule("StringRender") { case None => // No function can render the current type. input.getType match { case StringType => - gatherInputs(ctx, q, result += Stream((input, Nil))) + gatherInputs(ctx, q, result += Stream((input, Nil)) #::: Stream((StringEscape(input): Expr, Nil))) case BooleanType => val (bTemplate, vs) = booleanTemplate(input).instantiateWithVars gatherInputs(ctx, q, result += Stream((BooleanToString(input), Nil)) #::: Stream((bTemplate, vs)))