From 61c730295afa1fe0b275b323b09bb72bfdbd5ef1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Mika=C3=ABl=20Mayer?= <mikael.mayer@epfl.ch>
Date: Wed, 4 May 2016 16:17:45 +0200
Subject: [PATCH] adding bigSubstring and bigLength for strings.

---
 library/lang/StrOps.scala                     |  8 ++++
 library/lang/package.scala                    |  7 ++++
 library/theories/String.scala                 | 16 ++++++++
 .../java/leon/codegen/runtime/StrOps.java     |  8 ++++
 .../scala/leon/codegen/CodeGeneration.scala   | 14 ++++++-
 .../leon/evaluators/RecursiveEvaluator.scala  |  9 +++++
 .../frontends/scalac/CodeExtraction.scala     |  6 +++
 src/main/scala/leon/purescala/ExprOps.scala   |  6 ++-
 .../scala/leon/purescala/Expressions.scala    | 17 +++++++++
 .../scala/leon/purescala/Extractors.scala     |  3 ++
 .../scala/leon/purescala/PrettyPrinter.scala  |  2 +
 .../solvers/smtlib/SMTLIBCVC4Target.scala     | 22 +++++++++--
 .../leon/solvers/theories/StringEncoder.scala | 37 +++++++++++++++----
 13 files changed, 140 insertions(+), 15 deletions(-)

diff --git a/library/lang/StrOps.scala b/library/lang/StrOps.scala
index 631c2e3ad..1276a06f3 100644
--- a/library/lang/StrOps.scala
+++ b/library/lang/StrOps.scala
@@ -12,6 +12,14 @@ object StrOps {
   def concat(a: String, b: String): String = {
     a + b
   }
+  @ignore
+  def bigLength(s: String): BigInt = {
+    BigInt(s.length)
+  }
+  @ignore
+  def bigSubstring(s: String, start: BigInt, end: BigInt): String = {
+    s.substring(start.toInt, end.toInt)
+  }
   @internal @library
   def escape(s: String): String = s // Wrong definition, but it will eventually use StringEscapeUtils.escapeJava(s) at parsing and compile time.
 }
\ No newline at end of file
diff --git a/library/lang/package.scala b/library/lang/package.scala
index 823788a88..c10c02582 100644
--- a/library/lang/package.scala
+++ b/library/lang/package.scala
@@ -68,6 +68,13 @@ package object lang {
       (res: A) => byExample(input, res)
     }
   }
+  
+  implicit class StringDecorations(val underlying: String) {
+    @ignore
+    def bigLength = BigInt(underlying.length)
+    @ignore
+    def bigsubstring(start: BigInt, end: BigInt): String = underlying.substring(start.toInt, end.toInt)
+  }
 
   @ignore
   object BigInt {
diff --git a/library/theories/String.scala b/library/theories/String.scala
index ca2b9027d..745501578 100644
--- a/library/theories/String.scala
+++ b/library/theories/String.scala
@@ -9,6 +9,11 @@ sealed abstract class String {
     case StringCons(_, tail) => 1 + tail.size
     case StringNil() => BigInt(0)
   }) ensuring (_ >= 0)
+  
+  def sizeI: Int = this match {
+    case StringCons(_, tail) => 1 + tail.length
+    case StringNil() => 0
+  }
 
   def concat(that: String): String = this match {
     case StringCons(head, tail) => StringCons(head, tail concat that)
@@ -24,8 +29,19 @@ sealed abstract class String {
     case StringCons(head, tail) if i > 0 => tail drop (i - 1)
     case _ => this
   }
+  
+  def takeI(i: Int): String = this match {
+    case StringCons(head, tail) if i > 0 => StringCons(head, tail takeI (i - 1))
+    case _ => StringNil()
+  }
+
+  def dropI(i: Int): String = this match {
+    case StringCons(head, tail) if i > 0 => tail dropI (i - 1)
+    case _ => this
+  }
 
   def slice(from: BigInt, to: BigInt): String = drop(from).take(to - from)
+  def sliceI(from: Int, to: Int): String = dropI(from).takeI(to - from)
 }
 
 case class StringCons(head: Char, tail: String) extends String
diff --git a/src/main/java/leon/codegen/runtime/StrOps.java b/src/main/java/leon/codegen/runtime/StrOps.java
index 26a95aa7c..9f97e37a5 100644
--- a/src/main/java/leon/codegen/runtime/StrOps.java
+++ b/src/main/java/leon/codegen/runtime/StrOps.java
@@ -8,6 +8,14 @@ public class StrOps {
 	public static String concat(String a, String b) {
 		return a + b;
 	}
+	
+	public static BigInt bigLength(String a) {
+		return new BigInt(a.length() + "");
+	}
+	
+	public static String bigSubstring(String s, BigInt start, BigInt end) {
+		return s.substring(Integer.parseInt(start.toString()), Integer.parseInt(end.toString()));
+	}
 
 	public static String bigIntToString(BigInt a) {
 		return a.toString();
diff --git a/src/main/scala/leon/codegen/CodeGeneration.scala b/src/main/scala/leon/codegen/CodeGeneration.scala
index c352b45e3..199fd8669 100644
--- a/src/main/scala/leon/codegen/CodeGeneration.scala
+++ b/src/main/scala/leon/codegen/CodeGeneration.scala
@@ -872,7 +872,11 @@ trait CodeGeneration {
         
       case StringLength(a) =>
         mkExpr(a, ch)
-        ch << InvokeSpecial(JavaStringClass, "length", s"()I")
+        ch << InvokeVirtual(JavaStringClass, "length", s"()I")
+        
+      case StringBigLength(a) =>
+        mkExpr(a, ch)
+        ch << InvokeStatic(StrOpsClass, "bigLength", s"(L$JavaStringClass;)L$BigIntClass;")
         
       case Int32ToString(a) =>
         mkExpr(a, ch)
@@ -894,7 +898,13 @@ trait CodeGeneration {
         mkExpr(a, ch)
         mkExpr(start, ch)
         mkExpr(end, ch)
-        ch << InvokeSpecial(JavaStringClass, "substring", s"(L$JavaStringClass;II)L$JavaStringClass;")
+        ch << InvokeVirtual(JavaStringClass, "substring", s"(II)L$JavaStringClass;")
+      
+      case BigSubString(a, start, end) =>
+        mkExpr(a, ch)
+        mkExpr(start, ch)
+        mkExpr(end, ch)
+        ch << InvokeStatic(StrOpsClass, "bigSubstring", s"(L$JavaStringClass;II)L$JavaStringClass;")
         
       // Arithmetic
       case Plus(l, r) =>
diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala
index 348f06fb4..a709de91f 100644
--- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala
+++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala
@@ -285,11 +285,20 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, val bank: Eva
       case StringLiteral(a) => IntLiteral(a.length)
       case res => throw EvalError(typeErrorMsg(res, Int32Type))
     }
+    case StringBigLength(a) => e(a) match {
+      case StringLiteral(a) => InfiniteIntegerLiteral(a.length)
+      case res => throw EvalError(typeErrorMsg(res, IntegerType))
+    }
     case SubString(a, start, end) => (e(a), e(start), e(end)) match {
       case (StringLiteral(a), IntLiteral(b), IntLiteral(c))  =>
         StringLiteral(a.substring(b, c))
       case res => throw EvalError(typeErrorMsg(res._1, StringType))
     }
+    case BigSubString(a, start, end) => (e(a), e(start), e(end)) match {
+      case (StringLiteral(a), InfiniteIntegerLiteral(b), InfiniteIntegerLiteral(c))  =>
+        StringLiteral(a.substring(b.toInt, c.toInt))
+      case res => throw EvalError(typeErrorMsg(res._1, StringType))
+    }
     case Int32ToString(a) => e(a) match {
       case IntLiteral(i) => StringLiteral(i.toString)
       case res =>  throw EvalError(typeErrorMsg(res, Int32Type))
diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala
index c58644473..9c3e4b393 100644
--- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala
+++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala
@@ -1728,10 +1728,16 @@ trait CodeExtraction extends ASTExtractors {
               StringConcat(converter(a1), a2)
             case (IsTyped(a1, StringType), "length", List()) =>
               StringLength(a1)
+            case (IsTyped(a1, StringType), "bigLength", List()) =>
+              StringBigLength(a1)
             case (IsTyped(a1, StringType), "substring", List(IsTyped(start, Int32Type))) =>
               SubString(a1, start, StringLength(a1))
             case (IsTyped(a1, StringType), "substring", List(IsTyped(start, Int32Type), IsTyped(end, Int32Type))) =>
               SubString(a1, start, end)
+            case (IsTyped(a1, StringType), "bigSubstring", List(IsTyped(start, IntegerType))) =>
+              BigSubString(a1, start, StringBigLength(a1))
+            case (IsTyped(a1, StringType), "bigSubstring", List(IsTyped(start, IntegerType), IsTyped(end, IntegerType))) =>
+              BigSubString(a1, start, end)
 
             //BigInt methods
             case (IsTyped(a1, IntegerType), "+", List(IsTyped(a2, IntegerType))) =>
diff --git a/src/main/scala/leon/purescala/ExprOps.scala b/src/main/scala/leon/purescala/ExprOps.scala
index 2f63ac23b..c0a219c4d 100644
--- a/src/main/scala/leon/purescala/ExprOps.scala
+++ b/src/main/scala/leon/purescala/ExprOps.scala
@@ -1102,8 +1102,10 @@ object ExprOps extends GenTreeOps[Expr] {
       case StringConcat(StringLiteral(""), b) => b
       case StringConcat(b, StringLiteral("")) => b
       case StringConcat(StringLiteral(a), StringLiteral(b)) => StringLiteral(a + b)
-      case StringLength(StringLiteral(a)) => InfiniteIntegerLiteral(a.length)
-      case SubString(StringLiteral(a), InfiniteIntegerLiteral(start), InfiniteIntegerLiteral(end)) => StringLiteral(a.substring(start.toInt, end.toInt))
+      case StringLength(StringLiteral(a)) => IntLiteral(a.length)
+      case StringBigLength(StringLiteral(a)) => InfiniteIntegerLiteral(a.length)
+      case SubString(StringLiteral(a), IntLiteral(start), IntLiteral(end)) => StringLiteral(a.substring(start.toInt, end.toInt))
+      case BigSubString(StringLiteral(a), InfiniteIntegerLiteral(start), InfiniteIntegerLiteral(end)) => StringLiteral(a.substring(start.toInt, end.toInt))
       case _ => expr
     }).copiedFrom(expr)
     simplify0(expr)
diff --git a/src/main/scala/leon/purescala/Expressions.scala b/src/main/scala/leon/purescala/Expressions.scala
index adeee57f3..6620ebb6a 100644
--- a/src/main/scala/leon/purescala/Expressions.scala
+++ b/src/main/scala/leon/purescala/Expressions.scala
@@ -599,6 +599,16 @@ object Expressions {
       else Untyped
     }
   }
+  /** $encodingof `lhs.subString(start, end)` for strings */
+  case class BigSubString(expr: Expr, start: Expr, end: Expr) extends Expr {
+    val getType = {
+      val ext = expr.getType
+      val st = start.getType
+      val et = end.getType
+      if (ext == StringType && st == IntegerType && et == IntegerType) StringType
+      else Untyped
+    }
+  }
   /** $encodingof `lhs.length` for strings */
   case class StringLength(expr: Expr) extends Expr {
     val getType = {
@@ -606,6 +616,13 @@ object Expressions {
       else Untyped
     }
   }
+  /** $encodingof `lhs.length` for strings */
+  case class StringBigLength(expr: Expr) extends Expr {
+    val getType = {
+      if (expr.getType == StringType) IntegerType
+      else Untyped
+    }
+  }
 
   /* Integer arithmetic */
 
diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala
index 42cc07526..36b0887dc 100644
--- a/src/main/scala/leon/purescala/Extractors.scala
+++ b/src/main/scala/leon/purescala/Extractors.scala
@@ -40,6 +40,8 @@ object Extractors {
         Some((Seq(t), (es: Seq[Expr]) => BVNot(es.head)))
       case StringLength(t) =>
         Some((Seq(t), (es: Seq[Expr]) => StringLength(es.head)))
+      case StringBigLength(t) =>
+        Some((Seq(t), (es: Seq[Expr]) => StringBigLength(es.head)))
       case Int32ToString(t) =>
         Some((Seq(t), (es: Seq[Expr]) => Int32ToString(es.head)))
       case BooleanToString(t) =>
@@ -196,6 +198,7 @@ object Extractors {
       case And(args) => Some((args, es => And(es)))
       case Or(args) => Some((args, es => Or(es)))
       case SubString(t1, a, b) => Some((t1::a::b::Nil, es => SubString(es(0), es(1), es(2))))
+      case BigSubString(t1, a, b) => Some((t1::a::b::Nil, es => BigSubString(es(0), es(1), es(2))))
       case FiniteSet(els, base) =>
         Some((els.toSeq, els => FiniteSet(els.toSet, base)))
       case FiniteBag(els, base) =>
diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala
index e822cae94..bda0e0612 100644
--- a/src/main/scala/leon/purescala/PrettyPrinter.scala
+++ b/src/main/scala/leon/purescala/PrettyPrinter.scala
@@ -190,7 +190,9 @@ class PrettyPrinter(opts: PrinterOptions,
       case StringConcat(lhs, rhs) => optP { p"$lhs + $rhs" }
     
       case SubString(expr, start, end) => p"$expr.substring($start, $end)"
+      case BigSubString(expr, start, end) => p"$expr.bigSubstring($start, $end)"
       case StringLength(expr)          => p"$expr.length"
+      case StringBigLength(expr)       => p"$expr.bigLength"
 
       case IntLiteral(v)        => p"$v"
       case InfiniteIntegerLiteral(v) => p"$v"
diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala
index 04763b1b5..bd1ef5448 100644
--- a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala
+++ b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala
@@ -109,9 +109,13 @@ trait SMTLIBCVC4Target extends SMTLIBTarget {
       case (SString(v), Some(StringType)) =>
         StringLiteral(v)
         
-      case (Strings.Length(a), _) =>
+      case (Strings.Length(a), Some(Int32Type)) =>
         val aa = fromSMT(a)
         StringLength(aa)
+        
+      case (Strings.Length(a), Some(IntegerType)) =>
+        val aa = fromSMT(a)
+        StringBigLength(aa)
 
       case (Strings.Concat(a, b, c @ _*), _) =>
         val aa = fromSMT(a)
@@ -125,8 +129,14 @@ trait SMTLIBCVC4Target extends SMTLIBTarget {
         val tt = fromSMT(start)
         val oo = fromSMT(offset)
         oo match {
-          case Minus(otherEnd, `tt`) => SubString(ss, tt, otherEnd)
-          case _ => SubString(ss, tt, Plus(tt, oo))
+          case BVMinus(otherEnd, `tt`) => SubString(ss, tt, otherEnd)
+          case Minus(otherEnd, `tt`) => BigSubString(ss, tt, otherEnd)
+          case _ => 
+            if(tt.getType == IntegerType) {
+              BigSubString(ss, tt, Plus(tt, oo))
+            } else {
+              SubString(ss, tt, BVPlus(tt, oo))
+            }
         }
         
       case (Strings.At(a, b), _) => fromSMT(Strings.Substring(a, b, SNumeral(1)))
@@ -164,10 +174,14 @@ trait SMTLIBCVC4Target extends SMTLIBTarget {
         declareSort(StringType)
         Strings.StringLit(v)
     case StringLength(a)           => Strings.Length(toSMT(a))
+    case StringBigLength(a)        => Strings.Length(toSMT(a))
     case StringConcat(a, b)        => Strings.Concat(toSMT(a), toSMT(b))
-    case SubString(a, start, Plus(start2, length)) if start == start2  =>
+    case SubString(a, start, BVPlus(start2, length)) if start == start2  =>
                                       Strings.Substring(toSMT(a),toSMT(start),toSMT(length))
     case SubString(a, start, end)  => Strings.Substring(toSMT(a),toSMT(start),toSMT(Minus(end, start)))
+    case BigSubString(a, start, Plus(start2, length)) if start == start2  =>
+                                      Strings.Substring(toSMT(a),toSMT(start),toSMT(length))
+    case BigSubString(a, start, end)  => Strings.Substring(toSMT(a),toSMT(start),toSMT(Minus(end, start)))
     case _ =>
       super.toSMT(e)
   }
diff --git a/src/main/scala/leon/solvers/theories/StringEncoder.scala b/src/main/scala/leon/solvers/theories/StringEncoder.scala
index 02293fab3..1b7538cd8 100644
--- a/src/main/scala/leon/solvers/theories/StringEncoder.scala
+++ b/src/main/scala/leon/solvers/theories/StringEncoder.scala
@@ -22,6 +22,11 @@ class StringEncoder(ctx: LeonContext, p: Program) extends TheoryEncoder {
   val Drop   = p.library.lookupUnique[FunDef]("leon.theories.String.drop").typed
   val Slice  = p.library.lookupUnique[FunDef]("leon.theories.String.slice").typed
   val Concat = p.library.lookupUnique[FunDef]("leon.theories.String.concat").typed
+
+  val SizeI   = p.library.lookupUnique[FunDef]("leon.theories.String.sizeI").typed
+  val TakeI   = p.library.lookupUnique[FunDef]("leon.theories.String.takeI").typed
+  val DropI   = p.library.lookupUnique[FunDef]("leon.theories.String.dropI").typed
+  val SliceI  = p.library.lookupUnique[FunDef]("leon.theories.String.sliceI").typed
   
   val FromInt      = p.library.lookupUnique[FunDef]("leon.theories.String.fromInt").typed
   val FromChar     = p.library.lookupUnique[FunDef]("leon.theories.String.fromChar").typed
@@ -46,13 +51,19 @@ class StringEncoder(ctx: LeonContext, p: Program) extends TheoryEncoder {
     override def transformExpr(e: Expr)(implicit binders: Map[Identifier, Identifier]): Option[Expr] = e match {
       case StringLiteral(v)          =>
         Some(convertFromString(v))
-      case StringLength(a)           =>
+      case StringBigLength(a)           =>
         Some(FunctionInvocation(Size, Seq(transform(a))).copiedFrom(e))
+      case StringLength(a)           =>
+        Some(FunctionInvocation(SizeI, Seq(transform(a))).copiedFrom(e))
       case StringConcat(a, b)        =>
         Some(FunctionInvocation(Concat, Seq(transform(a), transform(b))).copiedFrom(e))
       case SubString(a, start, Plus(start2, length)) if start == start2  =>
-        Some(FunctionInvocation(Take, Seq(FunctionInvocation(Drop, Seq(transform(a), transform(start))), transform(length))).copiedFrom(e))
+        Some(FunctionInvocation(TakeI, Seq(FunctionInvocation(DropI, Seq(transform(a), transform(start))), transform(length))).copiedFrom(e))
       case SubString(a, start, end)  => 
+        Some(FunctionInvocation(SliceI, Seq(transform(a), transform(start), transform(end))).copiedFrom(e))
+      case BigSubString(a, start, Plus(start2, length)) if start == start2  =>
+        Some(FunctionInvocation(Take, Seq(FunctionInvocation(Drop, Seq(transform(a), transform(start))), transform(length))).copiedFrom(e))
+      case BigSubString(a, start, end)  => 
         Some(FunctionInvocation(Slice, Seq(transform(a), transform(start), transform(end))).copiedFrom(e))
       case Int32ToString(a) => 
         Some(FunctionInvocation(FromInt, Seq(transform(a))).copiedFrom(e))
@@ -85,20 +96,32 @@ class StringEncoder(ctx: LeonContext, p: Program) extends TheoryEncoder {
     override def transformExpr(e: Expr)(implicit binders: Map[Identifier, Identifier]): Option[Expr] = e match {
       case cc @ CaseClass(cct, args) if TypeOps.isSubtypeOf(cct, String)=>
         Some(StringLiteral(convertToString(cc)).copiedFrom(cc))
-      case FunctionInvocation(Size, Seq(a)) =>
+      case FunctionInvocation(SizeI, Seq(a)) =>
         Some(StringLength(transform(a)).copiedFrom(e))
+      case FunctionInvocation(Size, Seq(a)) =>
+        Some(StringBigLength(transform(a)).copiedFrom(e))
       case FunctionInvocation(Concat, Seq(a, b)) =>
         Some(StringConcat(transform(a), transform(b)).copiedFrom(e))
-      case FunctionInvocation(Slice, Seq(a, from, to)) =>
+      case FunctionInvocation(SliceI, Seq(a, from, to)) =>
         Some(SubString(transform(a), transform(from), transform(to)).copiedFrom(e))
-      case FunctionInvocation(Take, Seq(FunctionInvocation(Drop, Seq(a, start)), length)) =>
+      case FunctionInvocation(Slice, Seq(a, from, to)) =>
+        Some(BigSubString(transform(a), transform(from), transform(to)).copiedFrom(e))
+      case FunctionInvocation(TakeI, Seq(FunctionInvocation(DropI, Seq(a, start)), length)) =>
         val rstart = transform(start)
         Some(SubString(transform(a), rstart, plus(rstart, transform(length))).copiedFrom(e))
+      case FunctionInvocation(Take, Seq(FunctionInvocation(Drop, Seq(a, start)), length)) =>
+        val rstart = transform(start)
+        Some(BigSubString(transform(a), rstart, plus(rstart, transform(length))).copiedFrom(e))
+      case FunctionInvocation(TakeI, Seq(a, length)) =>
+        Some(SubString(transform(a), IntLiteral(0), transform(length)).copiedFrom(e))
       case FunctionInvocation(Take, Seq(a, length)) =>
-        Some(SubString(transform(a), InfiniteIntegerLiteral(0), transform(length)).copiedFrom(e))
-      case FunctionInvocation(Drop, Seq(a, count)) =>
+        Some(BigSubString(transform(a), InfiniteIntegerLiteral(0), transform(length)).copiedFrom(e))
+      case FunctionInvocation(DropI, Seq(a, count)) =>
         val ra = transform(a)
         Some(SubString(ra, transform(count), StringLength(ra)).copiedFrom(e))
+      case FunctionInvocation(Drop, Seq(a, count)) =>
+        val ra = transform(a)
+        Some(BigSubString(ra, transform(count), StringBigLength(ra)).copiedFrom(e))
       case FunctionInvocation(FromInt, Seq(a)) =>
         Some(Int32ToString(transform(a)).copiedFrom(e))
       case FunctionInvocation(FromBigInt, Seq(a)) =>
-- 
GitLab