From 8ee9c66e29ff455de151dd6ba5506f53d1a9b243 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mika=C3=ABl=20Mayer?= <mikael.mayer@epfl.ch> Date: Wed, 20 Jan 2016 14:20:55 +0100 Subject: [PATCH] Translation from String to List[Char] for Z3 on demand. --- .../solvers/smtlib/SMTLIBCVC4Target.scala | 37 +++++++- .../leon/solvers/smtlib/SMTLIBTarget.scala | 35 -------- .../leon/solvers/smtlib/SMTLIBZ3Target.scala | 86 +++++++++++++++++++ 3 files changed, 121 insertions(+), 37 deletions(-) diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala index f1cf73142..87cc849b4 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala @@ -14,6 +14,7 @@ import _root_.smtlib.parser.Terms.{Identifier => SMTIdentifier, Forall => SMTFor import _root_.smtlib.parser.Commands._ import _root_.smtlib.interpreters.CVC4Interpreter import _root_.smtlib.theories.experimental.Sets +import _root_.smtlib.theories.experimental.Strings trait SMTLIBCVC4Target extends SMTLIBTarget { @@ -30,7 +31,7 @@ trait SMTLIBCVC4Target extends SMTLIBTarget { tpe match { case SetType(base) => Sets.SetSort(declareSort(base)) - + case StringType => Strings.StringSort() case _ => super.declareSort(t) } @@ -109,6 +110,31 @@ trait SMTLIBCVC4Target extends SMTLIBTarget { case FiniteSet(elems, _) => elems }).toSet, base) + case (SString(v), Some(StringType)) => + StringLiteral(v) + + case (Strings.Length(a), _) => + val aa = fromSMT(a) + StringLength(aa) + + case (Strings.Concat(a, b, c @ _*), _) => + val aa = fromSMT(a) + val bb = fromSMT(b) + (StringConcat(aa, bb) /: c.map(fromSMT(_))) { + case (s, cc) => StringConcat(s, cc) + } + + case (Strings.Substring(s, start, offset), _) => + val ss = fromSMT(s) + val tt = fromSMT(start) + val oo = fromSMT(offset) + oo match { + case Minus(otherEnd, `tt`) => SubString(ss, tt, otherEnd) + case _ => SubString(ss, tt, Plus(tt, oo)) + } + + case (Strings.At(a, b), _) => fromSMT(Strings.Substring(a, b, SNumeral(1))) + case _ => super.fromSMT(t, otpe) } @@ -138,7 +164,14 @@ trait SMTLIBCVC4Target extends SMTLIBTarget { case SetDifference(a, b) => Sets.Setminus(toSMT(a), toSMT(b)) case SetUnion(a, b) => Sets.Union(toSMT(a), toSMT(b)) case SetIntersection(a, b) => Sets.Intersection(toSMT(a), toSMT(b)) - + case StringLiteral(v) => + declareSort(StringType) + Strings.StringLit(v) + case StringLength(a) => Strings.Length(toSMT(a)) + case StringConcat(a, b) => Strings.Concat(toSMT(a), toSMT(b)) + case SubString(a, start, Plus(start2, length)) if start == start2 => + Strings.Substring(toSMT(a),toSMT(start),toSMT(length)) + case SubString(a, start, end) => Strings.Substring(toSMT(a),toSMT(start),toSMT(Minus(end, start))) case _ => super.toSMT(e) } diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala index b97297fb0..47017bcf1 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala @@ -241,7 +241,6 @@ trait SMTLIBTarget extends Interruptible { case RealType => Reals.RealSort() case Int32Type => FixedSizeBitVectors.BitVectorSort(32) case CharType => FixedSizeBitVectors.BitVectorSort(32) - case StringType => Strings.StringSort() case RawArrayType(from, to) => Sort(SMTIdentifier(SSymbol("Array")), Seq(declareSort(from), declareSort(to))) @@ -379,9 +378,6 @@ trait SMTLIBTarget extends Interruptible { case FractionalLiteral(n, d) => Reals.Div(Reals.NumeralLit(n), Reals.NumeralLit(d)) case CharLiteral(c) => FixedSizeBitVectors.BitVectorLit(Hexadecimal.fromInt(c.toInt)) case BooleanLiteral(v) => Core.BoolConst(v) - case StringLiteral(v) => - declareSort(StringType) - Strings.StringLit(v) case Let(b, d, e) => val id = id2sym(b) val value = toSMT(d) @@ -613,12 +609,6 @@ trait SMTLIBTarget extends Interruptible { case RealMinus(a, b) => Reals.Sub(toSMT(a), toSMT(b)) case RealTimes(a, b) => Reals.Mul(toSMT(a), toSMT(b)) case RealDivision(a, b) => Reals.Div(toSMT(a), toSMT(b)) - - case StringLength(a) => Strings.Length(toSMT(a)) - case StringConcat(a, b) => Strings.Concat(toSMT(a), toSMT(b)) - case SubString(a, start, Plus(start2, length)) if start == start2 => - Strings.Substring(toSMT(a),toSMT(start),toSMT(length)) - case SubString(a, start, end) => Strings.Substring(toSMT(a),toSMT(start),toSMT(Minus(end, start))) case And(sub) => Core.And(sub.map(toSMT): _*) case Or(sub) => Core.Or(sub.map(toSMT): _*) @@ -764,31 +754,6 @@ trait SMTLIBTarget extends Interruptible { case (SNumeral(n), Some(RealType)) => FractionalLiteral(n, 1) - case (SString(v), Some(StringType)) => - StringLiteral(v) - - case (Strings.Length(a), _) => - val aa = fromSMT(a) - StringLength(aa) - - case (Strings.Concat(a, b, c @ _*), _) => - val aa = fromSMT(a) - val bb = fromSMT(b) - (StringConcat(aa, bb) /: c.map(fromSMT(_))) { - case (s, cc) => StringConcat(s, cc) - } - - case (Strings.Substring(s, start, offset), _) => - val ss = fromSMT(s) - val tt = fromSMT(start) - val oo = fromSMT(offset) - oo match { - case Minus(otherEnd, `tt`) => SubString(ss, tt, otherEnd) - case _ => SubString(ss, tt, Plus(tt, oo)) - } - - case (Strings.At(a, b), _) => fromSMT(Strings.Substring(a, b, SNumeral(1))) - case (FunctionApplication(SimpleSymbol(SSymbol("ite")), Seq(cond, thenn, elze)), t) => IfExpr( fromSMT(cond, Some(BooleanType)), diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala index 3d4a06a83..506d7519d 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala @@ -8,6 +8,7 @@ import purescala.Common._ import purescala.Expressions._ import purescala.Constructors._ import purescala.Types._ +import purescala.Definitions._ import _root_.smtlib.parser.Terms.{Identifier => SMTIdentifier, _} import _root_.smtlib.parser.Commands.{FunDef => SMTFunDef, _} @@ -15,6 +16,8 @@ import _root_.smtlib.interpreters.Z3Interpreter import _root_.smtlib.theories.Core.{Equals => SMTEquals, _} import _root_.smtlib.theories.ArraysEx +import utils.Bijection + trait SMTLIBZ3Target extends SMTLIBTarget { def targetName = "z3" @@ -69,6 +72,31 @@ trait SMTLIBZ3Target extends SMTLIBTarget { Sort(SMTIdentifier(setSort.get), Seq(declareSort(of))) } + val stringBijection = new Bijection[String, CaseClass]() + + lazy val cons = program.lookup("leon.collection.Cons") match { + case Some(cc@CaseClassDef(id, tparams, parent, _)) => cc.typed + case _ => throw new Exception("Could not find Cons in Z3 solver") + } + lazy val nil = program.lookup("leon.collection.Nil") match { + case Some(cc@CaseClassDef(id, tparams, parent, _)) => cc.typed + case _ => throw new Exception("Could not find Nil in Z3 solver") + } + lazy val list = program.lookup("leon.collection.List") match { + case Some(cc@AbstractClassDef(id, tparams, parent)) => cc.typed + case _ => throw new Exception("Could not find List in Z3 solver") + } + def extractFunDef(s: String): FunDef = program.lookup(s) match { + case Some(fd: FunDef) => fd + case _ => throw new Exception("Could not find "+s+" in Z3 solver") + } + lazy val list_size = extractFunDef("leon.collection.List.size") + lazy val list_++ = extractFunDef("leon.collection.List.++") + lazy val list_take = extractFunDef("leon.collection.List.take") + lazy val list_drop = extractFunDef("leon.collection.List.drop") + lazy val list_slice = extractFunDef("leon.collection.List.slice") + + override protected def fromSMT(t: Term, otpe: Option[TypeTree] = None) (implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = { (t, otpe) match { @@ -93,6 +121,48 @@ trait SMTLIBZ3Target extends SMTLIBTarget { fromRawArray(RawArrayValue(ktpe, Map(), fromSMT(defV, vtpe)), tpe) + case (SimpleSymbol(s), Some(StringType)) if constructors.containsB(s) => + constructors.toA(s) match { + case cct: CaseClassType if cct == nil => + StringLiteral("") + case t => + unsupported(t, "woot? for a single constructor for non-case-object") + } + case (FunctionApplication(SimpleSymbol(s), args), Some(StringType)) if constructors.containsB(s) => + constructors.toA(s) match { + case cct: CaseClassType if cct == cons => + val rargs = args.zip(cct.fields.map(_.getType)).map(fromSMT) + val s = ("" /: rargs) { + case (acc, c@CharLiteral(s)) => acc + s + case _ => unsupported(cct, "Cannot extract string out of list of any") + } + StringLiteral(s) + case t => unsupported(t, "Cannot extract string") + } + + /*case (Strings.Length(a), _) => + val aa = fromSMT(a) + StringLength(aa) + + case (Strings.Concat(a, b, c @ _*), _) => + val aa = fromSMT(a) + val bb = fromSMT(b) + (StringConcat(aa, bb) /: c.map(fromSMT(_))) { + case (s, cc) => StringConcat(s, cc) + } + + case (Strings.Substring(s, start, offset), _) => + val ss = fromSMT(s) + val tt = fromSMT(start) + val oo = fromSMT(offset) + oo match { + case Minus(otherEnd, `tt`) => SubString(ss, tt, otherEnd) + case _ => SubString(ss, tt, Plus(tt, oo)) + } + + case (Strings.At(a, b), _) => fromSMT(Strings.Substring(a, b, SNumeral(1))) +*/ + case _ => super.fromSMT(t, otpe) } @@ -132,6 +202,22 @@ trait SMTLIBZ3Target extends SMTLIBTarget { case SetIntersection(l, r) => ArrayMap(SSymbol("and"), toSMT(l), toSMT(r)) + case StringLiteral(v) => + // No string support for z3 at this moment. + val stringEncoding = stringBijection.cachedB(v) { + v.toList.foldRight(CaseClass(nil, Seq())){ + case (char, l) => CaseClass(cons, Seq(CharLiteral(char), l)) + } + } + toSMT(stringEncoding) + case StringLength(a) => + toSMT(functionInvocation(list_size, Seq(a))) + case StringConcat(a, b) => + toSMT(functionInvocation(list_++, Seq(a, b))) + case SubString(a, start, Plus(start2, length)) if start == start2 => + toSMT(functionInvocation(list_take, Seq(functionInvocation(list_drop, Seq(a, start)), length))) + case SubString(a, start, end) => + toSMT(functionInvocation(list_slice, Seq(a, start, end))) case _ => super.toSMT(e) } -- GitLab