diff --git a/src/main/scala/inox/solvers/smtlib/SMTLIBZ3Target.scala b/src/main/scala/inox/solvers/smtlib/SMTLIBZ3Target.scala index eb55f9b60ea47bc89f883ce9d4c5a18d0456eaa9..bf2cf05e2ad68fddec9f068fffbed3b6b6242baf 100644 --- a/src/main/scala/inox/solvers/smtlib/SMTLIBZ3Target.scala +++ b/src/main/scala/inox/solvers/smtlib/SMTLIBZ3Target.scala @@ -1,14 +1,9 @@ /* Copyright 2009-2016 EPFL, Lausanne */ -package leon +package inox package solvers package smtlib -import purescala.Common._ -import purescala.Expressions._ -import purescala.Constructors._ -import purescala.Types._ - import _root_.smtlib.common._ import _root_.smtlib.parser.Terms.{Identifier => SMTIdentifier, Let => SMTLet, _} import _root_.smtlib.parser.Commands.{FunDef => SMTFunDef, _} @@ -18,18 +13,22 @@ import _root_.smtlib.theories._ trait SMTLIBZ3Target extends SMTLIBTarget { + import program._ + import trees._ + import symbols._ + def targetName = "z3" - def interpreterOps(ctx: LeonContext) = { + def interpreterOps(ctx: InoxContext) = { Seq( "-in", "-smt2" ) } - def getNewInterpreter(ctx: LeonContext) = { + def getNewInterpreter(ctx: InoxContext) = { val opts = interpreterOps(ctx) - context.reporter.debug("Invoking solver "+targetName+" with "+opts.mkString(" ")) + ctx.reporter.debug("Invoking solver "+targetName+" with "+opts.mkString(" ")) new Z3Interpreter("z3", opts.toArray) } @@ -38,7 +37,7 @@ trait SMTLIBZ3Target extends SMTLIBTarget { protected var setSort: Option[SSymbol] = None - override protected def declareSort(t: TypeTree): Sort = { + override protected def declareSort(t: Type): Sort = { val tpe = normalizeType(t) sorts.cachedB(tpe) { tpe match { @@ -46,14 +45,14 @@ trait SMTLIBZ3Target extends SMTLIBTarget { super.declareSort(BooleanType) declareSetSort(base) case BagType(base) => - declareSort(RawArrayType(base, IntegerType)) + declareSort(MapType(base, IntegerType)) case _ => super.declareSort(t) } } } - protected def declareSetSort(of: TypeTree): Sort = { + protected def declareSetSort(of: Type): Sort = { setSort match { case None => val s = SSymbol("Set") @@ -72,15 +71,28 @@ trait SMTLIBZ3Target extends SMTLIBTarget { Sort(SMTIdentifier(setSort.get), Seq(declareSort(of))) } - override protected def fromSMT(t: Term, otpe: Option[TypeTree] = None) + override protected def fromSMT(t: Term, otpe: Option[Type] = None) (implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = { + + (t, otpe) match { case (QualifiedIdentifier(ExtendedIdentifier(SSymbol("as-array"), k: SSymbol), _), Some(tpe)) => if (letDefs contains k) { + val MapType(keyType, valueType) = tpe + val DefineFun(SMTFunDef(a, Seq(SortedVar(arg, akind)), rkind, body)) = letDefs(k) + + def extractCases(e: Term): (Map[Expr, Expr], Expr) = e match { + case ITE(SMTEquals(SimpleSymbol(`arg`), k), v, e) => + val (cs, d) = extractCases(e) + (Map(fromSMT(k, keyType) -> fromSMT(v, valueType)) ++ cs, d) + case e => + (Map(),fromSMT(e, valueType)) + } // Need to recover value form function model - fromRawArray(extractRawArray(letDefs(k), tpe), tpe) + val (cases, default) = extractCases(body) + FiniteMap(cases.toSeq, default, keyType) } else { - throw LeonFatalError("Array on non-function or unknown symbol "+k) + throw FatalError("Array on non-function or unknown symbol "+k) } case (FunctionApplication( @@ -89,8 +101,7 @@ trait SMTLIBZ3Target extends SMTLIBTarget { ), Some(tpe)) => val ktpe = sorts.fromB(k) val vtpe = sorts.fromB(v) - - fromRawArray(RawArrayValue(ktpe, Map(), fromSMT(defV, vtpe)), tpe) + FiniteMap(Seq(), fromSMT(defV), ktpe) case _ => super.fromSMT(t, otpe) @@ -103,11 +114,9 @@ trait SMTLIBZ3Target extends SMTLIBTarget { * ===== Set operations ===== */ case fs @ FiniteSet(elems, base) => - declareSort(fs.getType) - - toSMT(RawArrayValue(base, elems.map { - case k => k -> BooleanLiteral(true) - }.toMap, BooleanLiteral(false))) + val st @ SetType(base) = fs.getType + declareSort(st) + toSMT(FiniteMap(elems map ((_, BooleanLiteral(true))), BooleanLiteral(false), BooleanType)) case SubsetOf(ss, s) => // a isSubset b ==> (a zip b).map(implies) == (* => true) @@ -135,13 +144,13 @@ trait SMTLIBZ3Target extends SMTLIBTarget { ArrayMap(SSymbol("and"), toSMT(l), toSMT(r)) case fb @ FiniteBag(elems, base) => - declareSort(fb.getType) - - toSMT(RawArrayValue(base, elems, InfiniteIntegerLiteral(0))) + val BagType(t) = fb.getType + declareSort(BagType(t)) + toSMT(FiniteMap(elems, IntegerLiteral(0), t)) case BagAdd(b, e) => - val bid = FreshIdentifier("b", b.getType, true) - val eid = FreshIdentifier("e", e.getType, true) + val bid = FreshIdentifier("b", true) + val eid = FreshIdentifier("e", true) val (bSym, eSym) = (id2sym(bid), id2sym(eid)) SMTLet(VarBinding(bSym, toSMT(b)), Seq(VarBinding(eSym, toSMT(e))), ArraysEx.Store(bSym, eSym, Ints.Add(ArraysEx.Select(bSym, eSym), Ints.NumeralLit(1)))) @@ -162,7 +171,7 @@ trait SMTLIBZ3Target extends SMTLIBTarget { val minus = SortedSymbol("-", List(IntegerType, IntegerType).map(declareSort), declareSort(IntegerType)) val div = SortedSymbol("/", List(IntegerType, IntegerType).map(declareSort), declareSort(IntegerType)) - val did = FreshIdentifier("d", b1.getType, true) + val did = FreshIdentifier("d", true) val dSym = id2sym(did) val all2 = ArrayConst(declareSort(IntegerType), Ints.NumeralLit(2)) @@ -174,33 +183,6 @@ trait SMTLIBZ3Target extends SMTLIBTarget { super.toSMT(e) } - protected def extractRawArray(s: DefineFun, tpe: TypeTree)(implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): RawArrayValue = s match { - case DefineFun(SMTFunDef(a, Seq(SortedVar(arg, akind)), rkind, body)) => - val (argTpe, retTpe) = tpe match { - case SetType(base) => (base, BooleanType) - case MapType(from, to) => (from, library.optionType(to)) - case ArrayType(base) => (Int32Type, base) - case FunctionType(args, ret) => (tupleTypeWrap(args), ret) - case RawArrayType(from, to) => (from, to) - case _ => unsupported(tpe, "Unsupported type for (un)packing into raw arrays (got kinds "+akind+" -> "+rkind+")") - } - - def extractCases(e: Term): (Map[Expr, Expr], Expr) = e match { - case ITE(SMTEquals(SimpleSymbol(`arg`), k), v, e) => - val (cs, d) = extractCases(e) - (Map(fromSMT(k, argTpe) -> fromSMT(v, retTpe)) ++ cs, d) - case e => - (Map(),fromSMT(e, retTpe)) - } - - val (cases, default) = extractCases(body) - - RawArrayValue(argTpe, cases, default) - - case _ => - throw LeonFatalError("Unable to extract "+s) - } - protected object SortedSymbol { def apply(op: String, from: List[Sort], to: Sort) = { def simpleSort(s: Sort): Boolean = s.subSorts.isEmpty && !s.id.isIndexed @@ -212,7 +194,7 @@ trait SMTLIBZ3Target extends SMTLIBTarget { protected object ArrayMap { def apply(op: SExpr, arrs: Term*) = { FunctionApplication( - QualifiedIdentifier(SMTIdentifier(SSymbol("map"), List(op))), + QualifiedIdentifier(SMTIdentifier(SSymbol("map"), List(op)), None), arrs ) }