diff --git a/src/main/scala/inox/solvers/smtlib/SMTLIBCVC4Target.scala b/src/main/scala/inox/solvers/smtlib/SMTLIBCVC4Target.scala index 74bb6a83b7cbb15cdba9a3cf05740f436133b575..4e87fa6b914bdd44dedbd7bc2ba9c071e8e49893 100644 --- a/src/main/scala/inox/solvers/smtlib/SMTLIBCVC4Target.scala +++ b/src/main/scala/inox/solvers/smtlib/SMTLIBCVC4Target.scala @@ -1,16 +1,10 @@ /* Copyright 2009-2016 EPFL, Lausanne */ -package leon +package inox package solvers package smtlib -import purescala.Common._ -import purescala.Expressions._ -import purescala.Constructors._ -import purescala.Extractors._ -import purescala.Types._ - -import org.apache.commons.lang3.StringEscapeUtils; +import org.apache.commons.lang3.StringEscapeUtils import _root_.smtlib.parser.Terms.{Identifier => SMTIdentifier, Forall => SMTForall, _} import _root_.smtlib.parser.Commands._ @@ -19,15 +13,18 @@ import _root_.smtlib.theories.experimental.Sets import _root_.smtlib.theories.experimental.Strings trait SMTLIBCVC4Target extends SMTLIBTarget { + import program._ + import trees._ + import symbols._ - override def getNewInterpreter(ctx: LeonContext) = { + override def getNewInterpreter(ctx: InoxContext) = { val opts = interpreterOps(ctx) - context.reporter.debug("Invoking solver with "+opts.mkString(" ")) + ctx.reporter.debug("Invoking solver with "+opts.mkString(" ")) new CVC4Interpreter("cvc4", opts.toArray) } - override protected def declareSort(t: TypeTree): Sort = { + override protected def declareSort(t: Type): Sort = { val tpe = normalizeType(t) sorts.cachedB(tpe) { tpe match { @@ -40,63 +37,62 @@ trait SMTLIBCVC4Target extends SMTLIBTarget { } } - override protected def fromSMT(t: Term, otpe: Option[TypeTree] = None) + override protected def fromSMT(t: Term, otpe: Option[Type] = None) (implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = { (t, otpe) match { // EK: This hack is necessary for sygus which does not strictly follow smt-lib for negative literals case (SimpleSymbol(SSymbol(v)), Some(IntegerType)) if v.startsWith("-") => try { - InfiniteIntegerLiteral(v.toInt) + IntegerLiteral(v.toInt) } catch { case _: Throwable => super.fromSMT(t, otpe) } case (QualifiedIdentifier(SMTIdentifier(SSymbol("emptyset"), Seq()), _), Some(SetType(base))) => - FiniteSet(Set(), base) + FiniteSet(Seq(), base) case (FunctionApplication(QualifiedIdentifier(SMTIdentifier(SSymbol("const"), _), _), Seq(elem)), Some(tpe)) => tpe match { - case RawArrayType(k, v) => - RawArrayValue(k, Map(), fromSMT(elem, v)) - case ft @ FunctionType(from, to) => - FiniteLambda(Seq.empty, fromSMT(elem, to), ft) + finiteLambda(Seq.empty, fromSMT(elem, to), ft) case MapType(k, v) => - FiniteMap(Map(), k, v) + FiniteMap(Seq(), fromSMT(elem, v), k) } case (FunctionApplication(SimpleSymbol(SSymbol("__array_store_all__")), Seq(_, elem)), Some(tpe)) => tpe match { - case RawArrayType(k, v) => - RawArrayValue(k, Map(), fromSMT(elem, v)) - case ft @ FunctionType(from, to) => - FiniteLambda(Seq.empty, fromSMT(elem, to), ft) + finiteLambda(Seq.empty, fromSMT(elem, to), ft) case MapType(k, v) => - FiniteMap(Map(), k, v) + FiniteMap(Seq(), fromSMT(elem, v), k) } case (FunctionApplication(SimpleSymbol(SSymbol("store")), Seq(arr, key, elem)), Some(tpe)) => tpe match { - case RawArrayType(_, v) => - val RawArrayValue(k, elems, base) = fromSMT(arr, otpe) - RawArrayValue(k, elems + (fromSMT(key, k) -> fromSMT(elem, v)), base) - case FunctionType(from, v) => - val FiniteLambda(mapping, dflt, ft) = fromSMT(arr, otpe) - val args = unwrapTuple(fromSMT(key, tupleTypeWrap(from)), from.size) - FiniteLambda(mapping :+ (args -> fromSMT(elem, v)), dflt, ft) - - case MapType(k, v) => - val FiniteMap(elems, k, v) = fromSMT(arr, otpe) - FiniteMap(elems + (fromSMT(key, k) -> fromSMT(elem, v)), k, v) + val Lambda(args, bd) = fromSMT(arr, otpe) + Lambda(args, IfExpr( + Equals( + tupleWrap(args.map(_.toVariable)), + fromSMT(key, tupleTypeWrap(from)) + ), + fromSMT(elem, v), + bd + )) + + case MapType(kT, vT) => + val FiniteMap(elems, default, _) = fromSMT(arr, otpe) + val newKey = fromSMT(key, kT) + val newV = fromSMT(elem, vT) + val newElems = elems.filterNot(_._1 == newKey) :+ (newKey -> newV) + FiniteMap(newElems, default, kT) } case (FunctionApplication(SimpleSymbol(SSymbol("singleton")), elems), Some(SetType(base))) => - FiniteSet(elems.map(fromSMT(_, base)).toSet, base) + FiniteSet(elems.map(fromSMT(_, base)), base) case (FunctionApplication(SimpleSymbol(SSymbol("insert")), elems), Some(SetType(base))) => val selems = elems.init.map(fromSMT(_, base)) @@ -106,18 +102,14 @@ trait SMTLIBCVC4Target extends SMTLIBTarget { case (FunctionApplication(SimpleSymbol(SSymbol("union")), elems), Some(SetType(base))) => FiniteSet(elems.flatMap(fromSMT(_, otpe) match { case FiniteSet(elems, _) => elems - }).toSet, base) + }), base) case (SString(v), Some(StringType)) => StringLiteral(StringEscapeUtils.unescapeJava(v)) - - case (Strings.Length(a), Some(Int32Type)) => - val aa = fromSMT(a) - StringLength(aa) - + case (Strings.Length(a), Some(IntegerType)) => val aa = fromSMT(a) - StringBigLength(aa) + StringLength(aa) case (Strings.Concat(a, b, c @ _*), _) => val aa = fromSMT(a) @@ -131,14 +123,9 @@ trait SMTLIBCVC4Target extends SMTLIBTarget { val tt = fromSMT(start) val oo = fromSMT(offset) oo match { - case BVMinus(otherEnd, `tt`) => SubString(ss, tt, otherEnd) - case Minus(otherEnd, `tt`) => BigSubString(ss, tt, otherEnd) - case _ => - if(tt.getType == IntegerType) { - BigSubString(ss, tt, Plus(tt, oo)) - } else { - SubString(ss, tt, BVPlus(tt, oo)) - } + case Minus(otherEnd, `tt`) => SubString(ss, tt, otherEnd) + case _ => + SubString(ss, tt, Plus(tt, oo)) } case (Strings.At(a, b), _) => fromSMT(Strings.Substring(a, b, SNumeral(1))) @@ -156,7 +143,7 @@ trait SMTLIBCVC4Target extends SMTLIBTarget { if (elems.isEmpty) { Sets.EmptySet(declareSort(fs.getType)) } else { - val selems = elems.toSeq.map(toSMT) + val selems = elems.map(toSMT) val sgt = Sets.Singleton(selems.head) @@ -176,14 +163,10 @@ trait SMTLIBCVC4Target extends SMTLIBTarget { declareSort(StringType) Strings.StringLit(StringEscapeUtils.escapeJava(v)) case StringLength(a) => Strings.Length(toSMT(a)) - case StringBigLength(a) => Strings.Length(toSMT(a)) case StringConcat(a, b) => Strings.Concat(toSMT(a), toSMT(b)) - case SubString(a, start, BVPlus(start2, length)) if start == start2 => - Strings.Substring(toSMT(a),toSMT(start),toSMT(length)) - case SubString(a, start, end) => Strings.Substring(toSMT(a),toSMT(start),toSMT(BVMinus(end, start))) - case BigSubString(a, start, Plus(start2, length)) if start == start2 => + case SubString(a, start, Plus(start2, length)) if start == start2 => Strings.Substring(toSMT(a),toSMT(start),toSMT(length)) - case BigSubString(a, start, end) => Strings.Substring(toSMT(a),toSMT(start),toSMT(Minus(end, start))) + case SubString(a, start, end) => Strings.Substring(toSMT(a),toSMT(start),toSMT(Minus(end, start))) case _ => super.toSMT(e) }