From 8ee9c66e29ff455de151dd6ba5506f53d1a9b243 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Mika=C3=ABl=20Mayer?= <mikael.mayer@epfl.ch>
Date: Wed, 20 Jan 2016 14:20:55 +0100
Subject: [PATCH] Translation from String to List[Char] for Z3 on demand.

---
 .../solvers/smtlib/SMTLIBCVC4Target.scala     | 37 +++++++-
 .../leon/solvers/smtlib/SMTLIBTarget.scala    | 35 --------
 .../leon/solvers/smtlib/SMTLIBZ3Target.scala  | 86 +++++++++++++++++++
 3 files changed, 121 insertions(+), 37 deletions(-)

diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala
index f1cf73142..87cc849b4 100644
--- a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala
+++ b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala
@@ -14,6 +14,7 @@ import _root_.smtlib.parser.Terms.{Identifier => SMTIdentifier, Forall => SMTFor
 import _root_.smtlib.parser.Commands._
 import _root_.smtlib.interpreters.CVC4Interpreter
 import _root_.smtlib.theories.experimental.Sets
+import _root_.smtlib.theories.experimental.Strings
 
 trait SMTLIBCVC4Target extends SMTLIBTarget {
 
@@ -30,7 +31,7 @@ trait SMTLIBCVC4Target extends SMTLIBTarget {
       tpe match {
         case SetType(base) =>
           Sets.SetSort(declareSort(base))
-
+        case StringType  => Strings.StringSort()
         case _ =>
           super.declareSort(t)
       }
@@ -109,6 +110,31 @@ trait SMTLIBCVC4Target extends SMTLIBTarget {
           case FiniteSet(elems, _) => elems
         }).toSet, base)
 
+      case (SString(v), Some(StringType)) =>
+        StringLiteral(v)
+        
+      case (Strings.Length(a), _) =>
+        val aa = fromSMT(a)
+        StringLength(aa)
+
+      case (Strings.Concat(a, b, c @ _*), _) =>
+        val aa = fromSMT(a)
+        val bb = fromSMT(b)
+        (StringConcat(aa, bb) /: c.map(fromSMT(_))) {
+          case (s, cc) => StringConcat(s, cc)
+        }
+      
+      case (Strings.Substring(s, start, offset), _) =>
+        val ss = fromSMT(s)
+        val tt = fromSMT(start)
+        val oo = fromSMT(offset)
+        oo match {
+          case Minus(otherEnd, `tt`) => SubString(ss, tt, otherEnd)
+          case _ => SubString(ss, tt, Plus(tt, oo))
+        }
+        
+      case (Strings.At(a, b), _) => fromSMT(Strings.Substring(a, b, SNumeral(1)))
+
       case _ =>
         super.fromSMT(t, otpe)
     }
@@ -138,7 +164,14 @@ trait SMTLIBCVC4Target extends SMTLIBTarget {
     case SetDifference(a, b) => Sets.Setminus(toSMT(a), toSMT(b))
     case SetUnion(a, b) => Sets.Union(toSMT(a), toSMT(b))
     case SetIntersection(a, b) => Sets.Intersection(toSMT(a), toSMT(b))
-
+    case StringLiteral(v)          =>
+        declareSort(StringType)
+        Strings.StringLit(v)
+    case StringLength(a)           => Strings.Length(toSMT(a))
+    case StringConcat(a, b)        => Strings.Concat(toSMT(a), toSMT(b))
+    case SubString(a, start, Plus(start2, length)) if start == start2  =>
+                                      Strings.Substring(toSMT(a),toSMT(start),toSMT(length))
+    case SubString(a, start, end)  => Strings.Substring(toSMT(a),toSMT(start),toSMT(Minus(end, start)))
     case _ =>
       super.toSMT(e)
   }
diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala
index b97297fb0..47017bcf1 100644
--- a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala
+++ b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala
@@ -241,7 +241,6 @@ trait SMTLIBTarget extends Interruptible {
         case RealType    => Reals.RealSort()
         case Int32Type   => FixedSizeBitVectors.BitVectorSort(32)
         case CharType    => FixedSizeBitVectors.BitVectorSort(32)
-        case StringType  => Strings.StringSort()
 
         case RawArrayType(from, to) =>
           Sort(SMTIdentifier(SSymbol("Array")), Seq(declareSort(from), declareSort(to)))
@@ -379,9 +378,6 @@ trait SMTLIBTarget extends Interruptible {
       case FractionalLiteral(n, d)   => Reals.Div(Reals.NumeralLit(n), Reals.NumeralLit(d))
       case CharLiteral(c)            => FixedSizeBitVectors.BitVectorLit(Hexadecimal.fromInt(c.toInt))
       case BooleanLiteral(v)         => Core.BoolConst(v)
-      case StringLiteral(v)          =>
-        declareSort(StringType)
-        Strings.StringLit(v)
       case Let(b, d, e) =>
         val id = id2sym(b)
         val value = toSMT(d)
@@ -613,12 +609,6 @@ trait SMTLIBTarget extends Interruptible {
       case RealMinus(a, b)           => Reals.Sub(toSMT(a), toSMT(b))
       case RealTimes(a, b)           => Reals.Mul(toSMT(a), toSMT(b))
       case RealDivision(a, b)        => Reals.Div(toSMT(a), toSMT(b))
-      
-      case StringLength(a)           => Strings.Length(toSMT(a))
-      case StringConcat(a, b)        => Strings.Concat(toSMT(a), toSMT(b))
-      case SubString(a, start, Plus(start2, length)) if start == start2  =>
-                                        Strings.Substring(toSMT(a),toSMT(start),toSMT(length))
-      case SubString(a, start, end)  => Strings.Substring(toSMT(a),toSMT(start),toSMT(Minus(end, start)))
 
       case And(sub)                  => Core.And(sub.map(toSMT): _*)
       case Or(sub)                   => Core.Or(sub.map(toSMT): _*)
@@ -764,31 +754,6 @@ trait SMTLIBTarget extends Interruptible {
       case (SNumeral(n), Some(RealType)) =>
         FractionalLiteral(n, 1)
       
-      case (SString(v), Some(StringType)) =>
-        StringLiteral(v)
-        
-      case (Strings.Length(a), _) =>
-        val aa = fromSMT(a)
-        StringLength(aa)
-
-      case (Strings.Concat(a, b, c @ _*), _) =>
-        val aa = fromSMT(a)
-        val bb = fromSMT(b)
-        (StringConcat(aa, bb) /: c.map(fromSMT(_))) {
-          case (s, cc) => StringConcat(s, cc)
-        }
-      
-      case (Strings.Substring(s, start, offset), _) =>
-        val ss = fromSMT(s)
-        val tt = fromSMT(start)
-        val oo = fromSMT(offset)
-        oo match {
-          case Minus(otherEnd, `tt`) => SubString(ss, tt, otherEnd)
-          case _ => SubString(ss, tt, Plus(tt, oo))
-        }
-        
-      case (Strings.At(a, b), _) => fromSMT(Strings.Substring(a, b, SNumeral(1)))
-
       case (FunctionApplication(SimpleSymbol(SSymbol("ite")), Seq(cond, thenn, elze)), t) =>
         IfExpr(
           fromSMT(cond, Some(BooleanType)),
diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala
index 3d4a06a83..506d7519d 100644
--- a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala
+++ b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala
@@ -8,6 +8,7 @@ import purescala.Common._
 import purescala.Expressions._
 import purescala.Constructors._
 import purescala.Types._
+import purescala.Definitions._
 
 import _root_.smtlib.parser.Terms.{Identifier => SMTIdentifier, _}
 import _root_.smtlib.parser.Commands.{FunDef => SMTFunDef, _}
@@ -15,6 +16,8 @@ import _root_.smtlib.interpreters.Z3Interpreter
 import _root_.smtlib.theories.Core.{Equals => SMTEquals, _}
 import _root_.smtlib.theories.ArraysEx
 
+import utils.Bijection
+
 trait SMTLIBZ3Target extends SMTLIBTarget {
 
   def targetName = "z3"
@@ -69,6 +72,31 @@ trait SMTLIBZ3Target extends SMTLIBTarget {
     Sort(SMTIdentifier(setSort.get), Seq(declareSort(of)))
   }
 
+  val stringBijection = new Bijection[String, CaseClass]()
+  
+  lazy val cons = program.lookup("leon.collection.Cons") match {
+    case Some(cc@CaseClassDef(id, tparams, parent, _)) => cc.typed
+    case _ => throw new Exception("Could not find Cons in Z3 solver")
+  }
+  lazy val nil = program.lookup("leon.collection.Nil") match {
+    case Some(cc@CaseClassDef(id, tparams, parent, _)) => cc.typed
+    case _ => throw new Exception("Could not find Nil in Z3 solver")
+  }
+  lazy val list = program.lookup("leon.collection.List") match {
+    case Some(cc@AbstractClassDef(id, tparams, parent)) => cc.typed
+    case _ => throw new Exception("Could not find List in Z3 solver")
+  }
+  def extractFunDef(s: String): FunDef = program.lookup(s) match {
+    case Some(fd: FunDef) => fd
+    case _ => throw new Exception("Could not find "+s+" in Z3 solver")
+  }
+  lazy val list_size = extractFunDef("leon.collection.List.size")
+  lazy val list_++ = extractFunDef("leon.collection.List.++")
+  lazy val list_take = extractFunDef("leon.collection.List.take")
+  lazy val list_drop = extractFunDef("leon.collection.List.drop")
+  lazy val list_slice = extractFunDef("leon.collection.List.slice")
+  
+  
   override protected def fromSMT(t: Term, otpe: Option[TypeTree] = None)
                                 (implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = {
     (t, otpe) match {
@@ -93,6 +121,48 @@ trait SMTLIBZ3Target extends SMTLIBTarget {
 
         fromRawArray(RawArrayValue(ktpe, Map(), fromSMT(defV, vtpe)), tpe)
 
+      case (SimpleSymbol(s), Some(StringType)) if constructors.containsB(s) =>
+        constructors.toA(s) match {
+          case cct: CaseClassType if cct == nil =>
+            StringLiteral("")
+          case t =>
+            unsupported(t, "woot? for a single constructor for non-case-object")
+        }
+      case (FunctionApplication(SimpleSymbol(s), args), Some(StringType)) if constructors.containsB(s) =>
+        constructors.toA(s) match {
+          case cct: CaseClassType if cct == cons =>
+            val rargs = args.zip(cct.fields.map(_.getType)).map(fromSMT)
+            val s = ("" /: rargs)  {
+              case (acc, c@CharLiteral(s)) => acc + s
+              case _ => unsupported(cct, "Cannot extract string out of list of any")
+            }
+            StringLiteral(s)
+          case t => unsupported(t, "Cannot extract string")
+        }
+
+      /*case (Strings.Length(a), _) =>
+        val aa = fromSMT(a)
+        StringLength(aa)
+
+      case (Strings.Concat(a, b, c @ _*), _) =>
+        val aa = fromSMT(a)
+        val bb = fromSMT(b)
+        (StringConcat(aa, bb) /: c.map(fromSMT(_))) {
+          case (s, cc) => StringConcat(s, cc)
+        }
+      
+      case (Strings.Substring(s, start, offset), _) =>
+        val ss = fromSMT(s)
+        val tt = fromSMT(start)
+        val oo = fromSMT(offset)
+        oo match {
+          case Minus(otherEnd, `tt`) => SubString(ss, tt, otherEnd)
+          case _ => SubString(ss, tt, Plus(tt, oo))
+        }
+        
+      case (Strings.At(a, b), _) => fromSMT(Strings.Substring(a, b, SNumeral(1)))
+*/
+
       case _ =>
         super.fromSMT(t, otpe)
     }
@@ -132,6 +202,22 @@ trait SMTLIBZ3Target extends SMTLIBTarget {
     case SetIntersection(l, r) =>
       ArrayMap(SSymbol("and"), toSMT(l), toSMT(r))
 
+    case StringLiteral(v)          =>
+      // No string support for z3 at this moment.
+      val stringEncoding = stringBijection.cachedB(v) {
+        v.toList.foldRight(CaseClass(nil, Seq())){
+          case (char, l) => CaseClass(cons, Seq(CharLiteral(char), l))
+        }
+      }
+      toSMT(stringEncoding)
+    case StringLength(a)           =>
+      toSMT(functionInvocation(list_size, Seq(a)))
+    case StringConcat(a, b)        =>
+      toSMT(functionInvocation(list_++, Seq(a, b)))
+    case SubString(a, start, Plus(start2, length)) if start == start2  =>
+      toSMT(functionInvocation(list_take, Seq(functionInvocation(list_drop, Seq(a, start)), length)))
+    case SubString(a, start, end)  => 
+      toSMT(functionInvocation(list_slice, Seq(a, start, end)))
     case _ =>
       super.toSMT(e)
   }
-- 
GitLab