diff --git a/src/main/scala/leon/solvers/ADTManager.scala b/src/main/scala/leon/solvers/ADTManager.scala new file mode 100644 index 0000000000000000000000000000000000000000..7987ca439bc1553372b860ce727560e5a624df6a --- /dev/null +++ b/src/main/scala/leon/solvers/ADTManager.scala @@ -0,0 +1,106 @@ +package leon +package solvers + +import purescala.Types._ +import purescala.Common._ + +case class DataType(sym: Identifier, cases: Seq[Constructor]) +case class Constructor(sym: Identifier, tpe: TypeTree, fields: Seq[(Identifier, TypeTree)]) + +class ADTManager { + protected def freshId(id: Identifier): Identifier = freshId(id.name) + protected def freshId(name: String): Identifier = FreshIdentifier(name) + + protected def getHierarchy(ct: ClassType): (ClassType, Seq[CaseClassType]) = ct match { + case act: AbstractClassType => + (act, act.knownCCDescendents) + case cct: CaseClassType => + cct.parent match { + case Some(p) => + getHierarchy(p) + case None => + (cct, List(cct)) + } + } + + protected var defined = Set[TypeTree]() + + def defineADT(t: TypeTree): Map[TypeTree, DataType] = { + val adts = findDependencies(t) + for ((t, dt) <- adts) { + defined += t + } + adts + } + + protected def findDependencies(t: TypeTree, dts: Map[TypeTree, DataType] = Map()): Map[TypeTree, DataType] = t match { + case ct: ClassType => + val (root, sub) = getHierarchy(ct) + + if (!(dts contains root) && !(defined contains root)) { + val sym = freshId(ct.id) + + val conss = sub.map { case cct => + Constructor(freshId(cct.id), cct, cct.fields.map(vd => (freshId(vd.id), vd.getType))) + } + + var cdts = dts + (root -> DataType(sym, conss)) + + // look for dependencies + for (ct <- root +: sub; f <- ct.fields) { + cdts ++= findDependencies(f.getType, cdts) + } + + cdts + } else { + dts + } + + case tt @ TupleType(bases) => + if (!(dts contains t) && !(defined contains t)) { + val sym = freshId("tuple"+bases.size) + + val c = Constructor(freshId(sym.name), tt, bases.zipWithIndex.map { + case (tpe, i) => (freshId("_"+(i+1)), tpe) + }) + + var cdts = dts + (tt -> DataType(sym, Seq(c))) + + for (b <- bases) { + cdts ++= findDependencies(b, cdts) + } + cdts + } else { + dts + } + + case UnitType => + if (!(dts contains t) && !(defined contains t)) { + + val sym = freshId("Unit") + + dts + (t -> DataType(sym, Seq(Constructor(freshId(sym.name), t, Nil)))) + } else { + dts + } + + case at @ ArrayType(base) => + if (!(dts contains t) && !(defined contains t)) { + val sym = freshId("array") + + val c = Constructor(freshId(sym.name), at, List( + (freshId("size"), Int32Type), + (freshId("content"), RawArrayType(Int32Type, base)) + )) + + val cdts = dts + (at -> DataType(sym, Seq(c))) + + findDependencies(base, cdts) + } else { + dts + } + + case _ => + dts + } +} diff --git a/src/main/scala/leon/solvers/RawArray.scala b/src/main/scala/leon/solvers/RawArray.scala new file mode 100644 index 0000000000000000000000000000000000000000..85c104ae7a39e2ee9051a62861daed44b93a9cd4 --- /dev/null +++ b/src/main/scala/leon/solvers/RawArray.scala @@ -0,0 +1,13 @@ +package leon +package solvers + +import purescala.Types._ +import purescala.Expressions._ + +// Corresponds to a smt map, not a leon/scala array +private[solvers] case class RawArrayType(from: TypeTree, to: TypeTree) extends TypeTree + +// Corresponds to a raw array value, which is coerced to a Leon expr depending on target type (set/array) +private[solvers] case class RawArrayValue(keyTpe: TypeTree, elems: Map[Expr, Expr], default: Expr) extends Expr { + val getType = RawArrayType(keyTpe, default.getType) +} diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Solver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Solver.scala index 1b91684774ba01cc5e794ed25dfd57fd60800704..f43e4bd831c16cb87a1eb671261e438ee5b5e33c 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Solver.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Solver.scala @@ -114,38 +114,6 @@ class SMTLIBCVC4Solver(context: LeonContext, program: Program) extends SMTLIBSol } override def toSMT(e: Expr)(implicit bindings: Map[Identifier, Term]) = e match { - case a @ FiniteArray(elems, default, size) => - val tpe @ ArrayType(base) = normalizeType(a.getType) - declareSort(tpe) - - var ar: Term = declareVariable(FreshIdentifier("arrayconst", RawArrayType(Int32Type, base))) - - for ((i, e) <- elems) { - ar = FunctionApplication(SSymbol("store"), Seq(ar, toSMT(IntLiteral(i)), toSMT(e))) - } - - FunctionApplication(constructors.toB(tpe), Seq(toSMT(size), ar)) - - case fm @ FiniteMap(elems, from, to) => - import OptionManager._ - declareSort(MapType(from, to)) - - var m: Term = declareVariable(FreshIdentifier("mapconst", RawArrayType(from, leonOptionType(to)))) - - sendCommand(Assert(SMTForall( - SortedVar(SSymbol("mapelem"), declareSort(from)), Seq(), - Core.Equals( - ArraysEx.Select(m, SSymbol("mapelem")), - toSMT(mkLeonNone(to)) - ) - ))) - - for ((k, v) <- elems) { - m = FunctionApplication(SSymbol("store"), Seq(m, toSMT(k), toSMT(mkLeonSome(v)))) - } - - m - /** * ===== Set operations ===== */ diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala index 0ac1f11e6b173fc6be0ab5b0cb297a3e2af4c7d4..85f09790298b1fa951f89186efa556b8c5406c4f 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala @@ -93,11 +93,16 @@ abstract class SMTLIBSolver(val context: LeonContext, case _ => None } } + import scala.language.implicitConversions protected implicit def symbolToQualifiedId(s: SSymbol): QualifiedIdentifier = { QualifiedIdentifier(SMTIdentifier(s)) } + val adtManager = new ADTManager + + val library = program.library + protected def id2sym(id: Identifier): SSymbol = SSymbol(id.name+"!"+id.globalId) protected def freshSym(id: Identifier): SSymbol = freshSym(id.name) @@ -116,58 +121,8 @@ abstract class SMTLIBSolver(val context: LeonContext, protected def hasError = errors.getB(()) contains true protected def addError() = errors += () -> true - /* A manager object for the Option type (since it is hard-coded for Maps) */ - protected object OptionManager { - lazy val leonOption = program.library.Option.get - lazy val leonSome = program.library.Some.get - lazy val leonNone = program.library.None.get - def leonOptionType(tp: TypeTree) = AbstractClassType(leonOption, Seq(tp)) - - def mkLeonSome(e: Expr) = CaseClass(CaseClassType(leonSome, Seq(e.getType)), Seq(e)) - def mkLeonNone(tp: TypeTree) = CaseClass(CaseClassType(leonNone, Seq(tp)), Seq()) - - def someTester(tp: TypeTree): SSymbol = { - val someTp = CaseClassType(leonSome, Seq(tp)) - testers.getB(someTp) match { - case Some(s) => s - case None => - declareOptionSort(tp) - someTester(tp) - } - } - def someConstructor(tp: TypeTree): SSymbol = { - val someTp = CaseClassType(leonSome, Seq(tp)) - constructors.getB(someTp) match { - case Some(s) => s - case None => - declareOptionSort(tp) - someConstructor(tp) - } - } - def someSelector(tp: TypeTree): SSymbol = { - val someTp = CaseClassType(leonSome, Seq(tp)) - selectors.getB(someTp,0) match { - case Some(s) => s - case None => - declareOptionSort(tp) - someSelector(tp) - } - } - - def inlinedOptionGet(t : Term, tp: TypeTree): Term = { - FunctionApplication(SSymbol("ite"), Seq( - FunctionApplication(someTester(tp), Seq(t)), - FunctionApplication(someSelector(tp), Seq(t)), - declareVariable(FreshIdentifier("error_value", tp)) - )) - } - - } - - /* Helper functions */ - protected def normalizeType(t: TypeTree): TypeTree = t match { case ct: ClassType if ct.parent.isDefined => ct.parent.get case tt: TupleType => tupleTypeWrap(tt.bases.map(normalizeType)) @@ -196,16 +151,6 @@ abstract class SMTLIBSolver(val context: LeonContext, protected def quantifiedTerm(quantifier: (SortedVar, Seq[SortedVar], Term) => Term, body: Expr): Term = quantifiedTerm(quantifier, variablesOf(body).toSeq, body) - // Corresponds to a smt map, not a leon/scala array - // Should NEVER escape past SMT-world - private[smtlib] case class RawArrayType(from: TypeTree, to: TypeTree) extends TypeTree - - // Corresponds to a raw array value, which is coerced to a Leon expr depending on target type (set/array) - // Should NEVER escape past SMT-world - private[smtlib] case class RawArrayValue(keyTpe: TypeTree, elems: Map[Expr, Expr], default: Expr) extends Expr { - val getType = RawArrayType(keyTpe, default.getType) - } - protected def fromRawArray(r: RawArrayValue, tpe: TypeTree): Expr = tpe match { case SetType(base) => if (r.default != BooleanLiteral(false)) { @@ -225,7 +170,7 @@ abstract class SMTLIBSolver(val context: LeonContext, case MapType(from, to) => // We expect a RawArrayValue with keys in from and values in Option[to], // with default value == None - if (r.default != OptionManager.mkLeonNone(to)) { + if (r.default.getType != library.noneType(to)) { reporter.warning("Co-finite maps are not supported. (Default was "+r.default+")") throw new IllegalArgumentException } @@ -256,7 +201,7 @@ abstract class SMTLIBSolver(val context: LeonContext, Sort(SMTIdentifier(SSymbol("Array")), Seq(declareSort(from), declareSort(to))) case MapType(from, to) => - declareMapSort(from, to) + declareSort(RawArrayType(from, library.optionType(to))) case FunctionType(from, to) => Sort(SMTIdentifier(SSymbol("Array")), Seq(declareSort(tupleTypeWrap(from)), declareSort(to))) @@ -276,133 +221,27 @@ abstract class SMTLIBSolver(val context: LeonContext, } } - protected def declareOptionSort(of: TypeTree): Sort = { - declareSort(OptionManager.leonOptionType(of)) - } - - protected def declareMapSort(from: TypeTree, to: TypeTree): Sort = { - sorts.cachedB(MapType(from, to)) { - val m = freshSym("Map") - - val toSort = declareOptionSort(to) - val fromSort = declareSort(from) - - val arraySort = Sort(SMTIdentifier(SSymbol("Array")), - Seq(fromSort, toSort)) - val cmd = DefineSort(m, Seq(), arraySort) - - sendCommand(cmd) - Sort(SMTIdentifier(m), Seq()) - } - } - - protected def getHierarchy(ct: ClassType): (ClassType, Seq[CaseClassType]) = ct match { - case act: AbstractClassType => - (act, act.knownCCDescendents) - case cct: CaseClassType => - cct.parent match { - case Some(p) => - getHierarchy(p) - case None => - (cct, List(cct)) - } - } - - protected case class DataType(sym: SSymbol, cases: Seq[Constructor]) - protected case class Constructor(sym: SSymbol, tpe: TypeTree, fields: Seq[(SSymbol, TypeTree)]) - - protected def findDependencies(t: TypeTree, dts: Map[TypeTree, DataType] = Map()): Map[TypeTree, DataType] = t match { - case ct: ClassType => - val (root, sub) = getHierarchy(ct) - - if (!(dts contains root) && !(sorts containsA root)) { - val sym = freshSym(ct.id) - - val conss = sub.map { case cct => - Constructor(freshSym(cct.id), cct, cct.fields.map(vd => (freshSym(vd.id), vd.getType))) - } - - var cdts = dts + (root -> DataType(sym, conss)) - - // look for dependencies - for (ct <- root +: sub; f <- ct.fields) { - cdts ++= findDependencies(f.getType, cdts) - } - - cdts - } else { - dts - } - - case tt @ TupleType(bases) => - if (!(dts contains t) && !(sorts containsA t)) { - val sym = freshSym("tuple"+bases.size) - - val c = Constructor(freshSym(sym.name), tt, bases.zipWithIndex.map { - case (tpe, i) => (freshSym("_"+(i+1)), tpe) - }) - - var cdts = dts + (tt -> DataType(sym, Seq(c))) - - for (b <- bases) { - cdts ++= findDependencies(b, cdts) - } - cdts - } else { - dts - } - - case UnitType => - if (!(dts contains t) && !(sorts containsA t)) { - - val sym = freshSym("Unit") - - dts + (t -> DataType(sym, Seq(Constructor(freshSym(sym.name), t, Nil)))) - } else { - dts - } - - case at @ ArrayType(base) => - if (!(dts contains t) && !(sorts containsA t)) { - val sym = freshSym("array") - - val c = Constructor(freshSym(sym.name), at, List( - (freshSym("size"), Int32Type), - (freshSym("content"), RawArrayType(Int32Type, base)) - )) - - val cdts = dts + (at -> DataType(sym, Seq(c))) - - findDependencies(base, cdts) - } else { - dts - } - - case _ => - dts - } - protected def declareDatatypes(datatypes: Map[TypeTree, DataType]): Unit = { // We pre-declare ADTs for ((tpe, DataType(sym, _)) <- datatypes) { - sorts += tpe -> Sort(SMTIdentifier(sym)) + sorts += tpe -> Sort(SMTIdentifier(id2sym(sym))) } def toDecl(c: Constructor): SMTConstructor = { - val s = c.sym + val s = id2sym(c.sym) testers += c.tpe -> SSymbol("is-"+s.name) constructors += c.tpe -> s SMTConstructor(s, c.fields.zipWithIndex.map { case ((cs, t), i) => - selectors += (c.tpe, i) -> cs - (cs, declareSort(t)) + selectors += (c.tpe, i) -> id2sym(cs) + (id2sym(cs), declareSort(t)) }) } val adts = for ((tpe, DataType(sym, cases)) <- datatypes.toList) yield { - (sym, cases.map(toDecl)) + (id2sym(sym), cases.map(toDecl)) } @@ -412,7 +251,7 @@ abstract class SMTLIBSolver(val context: LeonContext, protected def declareStructuralSort(t: TypeTree): Sort = { // Populates the dependencies of the structural type to define. - val datatypes = findDependencies(t) + val datatypes = adtManager.defineADT(t) declareDatatypes(datatypes) @@ -445,42 +284,6 @@ abstract class SMTLIBSolver(val context: LeonContext, } } - protected def declareMapUnion(from: TypeTree, to: TypeTree): SSymbol = { - // FIXME cache results - val a = declareSort(from) - val b = declareSort(OptionManager.leonOptionType(to)) - val arraySort = Sort(SMTIdentifier(SSymbol("Array")), Seq(a, b)) - - val f = freshSym("map_union") - - sendCommand(DeclareFun(f, Seq(arraySort, arraySort), arraySort)) - - val v = SSymbol("v") - val a1 = SSymbol("a1") - val a2 = SSymbol("a2") - - val axiom = SMTForall( - SortedVar(a1, arraySort), Seq(SortedVar(a2, arraySort), SortedVar(v,a)), - Core.Equals( - ArraysEx.Select( - FunctionApplication(f: QualifiedIdentifier, Seq(a1: Term, a2: Term)), - v: Term - ), - Core.ITE( - FunctionApplication( - OptionManager.someTester(to), - Seq(ArraysEx.Select(a2: Term, v: Term)) - ), - ArraysEx.Select(a2,v), - ArraysEx.Select(a1,v) - ) - ) - ) - - sendCommand(SMTAssert(axiom)) - f - } - /* Translate a Leon Expr to an SMTLIB term */ protected def toSMT(e: Expr)(implicit bindings: Map[Identifier, Term]): Term = { @@ -568,49 +371,69 @@ abstract class SMTLIBSolver(val context: LeonContext, val constructor = constructors.toB(tpe) FunctionApplication(constructor, Seq(ssize, newcontent)) - /** - * ===== Map operations ===== - */ - case m @ FiniteMap(elems, _, _) => - import OptionManager._ - val mt @ MapType(_, to) = m.getType - val ms = declareSort(mt) + case ra @ RawArrayValue(keyTpe, elems, default) => + val s = declareSort(ra.getType) var res: Term = FunctionApplication( - QualifiedIdentifier(SMTIdentifier(SSymbol("const")), Some(ms)), - List(toSMT(mkLeonNone(to))) + QualifiedIdentifier(SMTIdentifier(SSymbol("const")), Some(s)), + List(toSMT(default)) ) for ((k, v) <- elems) { - res = ArraysEx.Store(res, toSMT(k), toSMT(mkLeonSome(v))) + res = ArraysEx.Store(res, toSMT(k), toSMT(v)) } res + case a @ FiniteArray(elems, oDef, size) => + val tpe @ ArrayType(to) = normalizeType(a.getType) + declareSort(tpe) + + val default: Expr = oDef.getOrElse(simplestValue(to)) + + val arr = toSMT(RawArrayValue(Int32Type, elems.map { + case (k, v) => IntLiteral(k) -> v + }, default)) + + FunctionApplication(constructors.toB(tpe), List(toSMT(size), arr)) + + /** + * ===== Map operations ===== + */ + case m @ FiniteMap(elems, _, _) => + val mt @ MapType(from, to) = m.getType + val ms = declareSort(mt) + + toSMT(RawArrayValue(from, elems.map { + case (k, v) => k -> CaseClass(library.someType(to), Seq(v)) + }.toMap, CaseClass(library.noneType(to), Seq()))) + + case MapGet(m, k) => - import OptionManager._ - val mt@MapType(_, vt) = m.getType + val mt @ MapType(from, to) = m.getType declareSort(mt) // m(k) becomes - // (Option$get (select m k)) - inlinedOptionGet(ArraysEx.Select(toSMT(m), toSMT(k)), vt) + // (Some-value (select m k)) + FunctionApplication( + selectors.toB((library.someType(to), 0)), + Seq(ArraysEx.Select(toSMT(m), toSMT(k))) + ) case MapIsDefinedAt(m, k) => - import OptionManager._ - val mt@MapType(_, vt) = m.getType + val mt @ MapType(from, to) = m.getType declareSort(mt) // m.isDefinedAt(k) becomes - // (Option$isDefined (select m k)) + // (is-Some (select m k)) FunctionApplication( - someTester(vt), + testers.toB(library.someType(to)), Seq(ArraysEx.Select(toSMT(m), toSMT(k))) ) - case MapUnion(m1, m2) => - val MapType(vk, vt) = m1.getType - FunctionApplication( - declareMapUnion(vk, vt), - Seq(toSMT(m1), toSMT(m2)) - ) + case MapUnion(m1, FiniteMap(elems, _, _)) => + val mt @ MapType(f, t) = m1.getType + + elems.foldLeft(toSMT(m1)) { case (m, (k,v)) => + ArraysEx.Store(m, toSMT(k), toSMT(CaseClass(library.someType(t), Seq(v)))) + } case p : Passes => toSMT(matchToIfThenElse(p.asConstraint)) diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Solver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Solver.scala index cc6e2161c3ee3d7c1ebd0ee27546a13dfe56f403..3a7b345fb4d0c4d8664c90a2904ca487bb1199fd 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Solver.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Solver.scala @@ -96,35 +96,16 @@ class SMTLIBZ3Solver(context: LeonContext, program: Program) extends SMTLIBSolve } override def toSMT(e: Expr)(implicit bindings: Map[Identifier, Term]): Term = e match { - case a @ FiniteArray(elems, oDef, size) => - val tpe @ ArrayType(base) = normalizeType(a.getType) - declareSort(tpe) - - val default: Expr = oDef.getOrElse(simplestValue(base)) - - var ar: Term = ArrayConst(declareSort(RawArrayType(Int32Type, base)), toSMT(default)) - - for ((i, e) <- elems) { - ar = ArraysEx.Store(ar, toSMT(IntLiteral(i)), toSMT(e)) - } - - FunctionApplication(constructors.toB(tpe), List(toSMT(size), ar)) /** * ===== Set operations ===== */ case fs @ FiniteSet(elems, base) => val ss = declareSort(fs.getType) - var res: Term = FunctionApplication( - QualifiedIdentifier(SMTIdentifier(SSymbol("const")), Some(ss)), - Seq(toSMT(BooleanLiteral(false))) - ) - - for (e <- elems) { - res = ArraysEx.Store(res, toSMT(e), toSMT(BooleanLiteral(true))) - } - res + toSMT(RawArrayValue(base, elems.map { + case k => k -> BooleanLiteral(true) + }.toMap, BooleanLiteral(false))) case SubsetOf(ss, s) => // a isSubset b ==> (a zip b).map(implies) == (* => true) @@ -156,7 +137,7 @@ class SMTLIBZ3Solver(context: LeonContext, program: Program) extends SMTLIBSolve case DefineFun(SMTFunDef(a, Seq(SortedVar(arg, akind)), rkind, body)) => val (argTpe, retTpe) = tpe match { case SetType(base) => (base, BooleanType) - case MapType(from, to) => (from, OptionManager.leonOptionType(to)) + case MapType(from, to) => (from, library.optionType(to)) case ArrayType(base) => (Int32Type, base) case FunctionType(args, ret) => (tupleTypeWrap(args), ret) case RawArrayType(from, to) => (from, to) diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala index e036c86eba71fdd38cfeface511816fc0837e808..1ded84a6fcd702cb5c448650e236c6d942de5611 100644 --- a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala @@ -30,6 +30,8 @@ trait AbstractZ3Solver val context : LeonContext val program : Program + val library = program.library + protected[z3] val reporter : Reporter = context.reporter context.interruptManager.registerForInterrupts(this) @@ -131,55 +133,24 @@ trait AbstractZ3Solver def containsZ3(b: B): Boolean = z3ToLeon contains b } + // ADT Manager + protected[leon] val adtManager = new ADTManager + // Bijections between Leon Types/Functions/Ids to Z3 Sorts/Decls/ASTs protected[leon] var functions = new Bijection[TypedFunDef, Z3FuncDecl] protected[leon] var generics = new Bijection[GenericValue, Z3FuncDecl] protected[leon] var sorts = new Bijection[TypeTree, Z3Sort] protected[leon] var variables = new Bijection[Expr, Z3AST] - // Meta decls and information used by several sorts - case class ArrayDecls(cons: Z3FuncDecl, select: Z3FuncDecl, length: Z3FuncDecl) - case class TupleDecls(cons: Z3FuncDecl, selects: Seq[Z3FuncDecl]) - - protected[leon] var unitValue: Z3AST = null - protected[leon] var intSetMinFun: Z3FuncDecl = null - protected[leon] var intSetMaxFun: Z3FuncDecl = null - - protected[leon] var arrayMetaDecls: Map[TypeTree, ArrayDecls] = Map.empty - protected[leon] var tupleMetaDecls: Map[TypeTree, TupleDecls] = Map.empty - protected[leon] var setCardDecls: Map[TypeTree, Z3FuncDecl] = Map.empty - - protected[leon] var adtTesters: Map[CaseClassType, Z3FuncDecl] = Map.empty - protected[leon] var adtConstructors: Map[CaseClassType, Z3FuncDecl] = Map.empty - protected[leon] var adtFieldSelectors: Map[(CaseClassType, Identifier), Z3FuncDecl] = Map.empty - - protected[leon] var reverseADTTesters: Map[Z3FuncDecl, CaseClassType] = Map.empty - protected[leon] var reverseADTConstructors: Map[Z3FuncDecl, CaseClassType] = Map.empty - protected[leon] var reverseADTFieldSelectors: Map[Z3FuncDecl, (CaseClassType,Identifier)] = Map.empty - - protected[leon] val mapRangeSorts: MutableMap[TypeTree, Z3Sort] = MutableMap.empty - protected[leon] val mapRangeSomeConstructors: MutableMap[TypeTree, Z3FuncDecl] = MutableMap.empty - protected[leon] val mapRangeNoneConstructors: MutableMap[TypeTree, Z3FuncDecl] = MutableMap.empty - protected[leon] val mapRangeSomeTesters: MutableMap[TypeTree, Z3FuncDecl] = MutableMap.empty - protected[leon] val mapRangeNoneTesters: MutableMap[TypeTree, Z3FuncDecl] = MutableMap.empty - protected[leon] val mapRangeValueSelectors: MutableMap[TypeTree, Z3FuncDecl] = MutableMap.empty - - private var counter = 0 - private object nextIntForSymbol { - def apply(): Int = { - val res = counter - counter = counter + 1 - res - } - } + protected val constructors = new IncrementalBijection[TypeTree, Z3FuncDecl]() + protected val selectors = new IncrementalBijection[(TypeTree, Int), Z3FuncDecl]() + protected val testers = new IncrementalBijection[TypeTree, Z3FuncDecl]() var isInitialized = false protected[leon] def initZ3() { if (!isInitialized) { val timer = context.timers.solvers.z3.init.start() - counter = 0 - z3 = new Z3Context(z3cfg) functions.clear() @@ -187,10 +158,6 @@ trait AbstractZ3Solver sorts.clear() variables.clear() - arrayMetaDecls = Map() - tupleMetaDecls = Map() - setCardDecls = Map() - prepareSorts() isInitialized = true @@ -205,39 +172,6 @@ trait AbstractZ3Solver initZ3() } - protected[leon] def mapRangeSort(toType : TypeTree) : Z3Sort = mapRangeSorts.get(toType) match { - case Some(z3sort) => z3sort - case None => { - import Z3Context.RegularSort - - val z3info = z3.mkADTSorts( - Seq( - ( - toType.toString + "Option", - Seq(toType.toString + "Some", toType.toString + "None"), - Seq( - Seq(("value", RegularSort(typeToSort(toType)))), - Seq() - ) - ) - ) - ) - - z3info match { - case Seq((optionSort, Seq(someCons, noneCons), Seq(someTester, noneTester), Seq(Seq(valueSelector), Seq()))) => - mapRangeSorts += ((toType, optionSort)) - mapRangeSomeConstructors += ((toType, someCons)) - mapRangeNoneConstructors += ((toType, noneCons)) - mapRangeSomeTesters += ((toType, someTester)) - mapRangeNoneTesters += ((toType, noneTester)) - mapRangeValueSelectors += ((toType, valueSelector)) - optionSort - } - } - } - - case class UntranslatableTypeException(msg: String) extends Exception(msg) - def rootType(ct: TypeTree): TypeTree = ct match { case ct: ClassType => ct.parent match { @@ -247,111 +181,56 @@ trait AbstractZ3Solver case t => t } - def declareADTSort(ct: ClassType): Z3Sort = { + def declareStructuralSort(t: TypeTree): Z3Sort = { import Z3Context.{ADTSortReference, RecursiveType, RegularSort} //println("///"*40) //println("Declaring for: "+ct) - def getHierarchy(ct: ClassType): (ClassType, Seq[CaseClassType]) = ct match { - case act: AbstractClassType => - (act, act.knownCCDescendents) - case cct: CaseClassType => - cct.parent match { - case Some(p) => - getHierarchy(p) - case None => - (cct, List(cct)) - } - } + val adts = adtManager.defineADT(t).toSeq - def resolveTypes(ct: ClassType) = { - var newHierarchiesMap = Map[ClassType, Seq[CaseClassType]]() + val indexMap: Map[TypeTree, Int] = adts.map(_._1).zipWithIndex.toMap - def findDependencies(ct: ClassType): Unit = { - val (root, sub) = getHierarchy(ct) + def typeToSortRef(tt: TypeTree): ADTSortReference = { + val tpe = rootType(tt) - if (!(newHierarchiesMap contains root) && !(sorts containsLeon root)) { - newHierarchiesMap += root -> sub - - // look for dependencies - for (ct <- root +: sub; f <- ct.fields) f.getType match { - case fct: ClassType => - findDependencies(fct) - case _ => - } - } + if (indexMap contains tpe) { + RecursiveType(indexMap(tpe)) + } else { + RegularSort(typeToSort(tt)) } + } - // Populates the dependencies of the ADT to define. - findDependencies(ct) - - //println("Dependencies: ") - //for ((r, sub) <- newHierarchiesMap) { - // println(s" - $r >: $sub") - //} - - val newHierarchies = newHierarchiesMap.toSeq - - val indexMap: Map[ClassType, Int] = Map()++newHierarchies.map(_._1).zipWithIndex + // Define stuff + val defs = for ((_, DataType(sym, cases)) <- adts) yield {( + sym.uniqueName, + cases.map(c => c.sym.uniqueName), + cases.map(c => c.fields.map{ case(id, tpe) => (id.uniqueName, typeToSortRef(tpe))}) + )} - def typeToSortRef(tt: TypeTree): ADTSortReference = rootType(tt) match { - case ct: ClassType if sorts containsLeon ct => - RegularSort(sorts.toZ3(ct)) + val resultingZ3Info = z3.mkADTSorts(defs) - case act: ClassType => - // It has to be here - RecursiveType(indexMap(act)) + for ((z3Inf, (tpe, DataType(sym, cases))) <- resultingZ3Info zip adts) { + sorts += (tpe -> z3Inf._1) + assert(cases.size == z3Inf._2.size) - case _=> - RegularSort(typeToSort(tt)) + for ((c, (consFun, testFun)) <- cases zip (z3Inf._2 zip z3Inf._3)) { + testers += (c.tpe -> testFun) + constructors += (c.tpe -> consFun) } - // Define stuff - val defs = for ((root, childrenList) <- newHierarchies) yield {( - root.toString, - childrenList.map(ccd => ccd.id.uniqueName), - childrenList.map(ccd => ccd.fields.map(f => (f.id.uniqueName, typeToSortRef(f.getType)))) - )} - (defs, newHierarchies) - } - - // @EK: the first step is needed to introduce ADT sorts referenced inside Sets of this CT - // When defining Map(s: Set[Pos], p: Pos), it will need Pos, but Pos will be defined through Set[Pos] in the first pass - resolveTypes(ct) - val (defs, newHierarchies) = resolveTypes(ct) - - //for ((n, sub, cstrs) <- defs) { - // println(n+":") - // for ((s,css) <- sub zip cstrs) { - // println(" "+s) - // println(" -> "+css) - // } - //} - - val resultingZ3Info = z3.mkADTSorts(defs) + for ((c, fieldFuns) <- cases zip z3Inf._4) { + assert(c.fields.size == fieldFuns.size) - for ((z3Inf, (root, childrenList)) <- resultingZ3Info zip newHierarchies) { - sorts += (root -> z3Inf._1) - assert(childrenList.size == z3Inf._2.size) - for ((child, (consFun, testFun)) <- childrenList zip (z3Inf._2 zip z3Inf._3)) { - adtTesters += (child -> testFun) - reverseADTTesters += (testFun -> child) - adtConstructors += (child -> consFun) - reverseADTConstructors += (consFun -> child) - } - for ((child, fieldFuns) <- childrenList zip z3Inf._4) { - assert(child.fields.size == fieldFuns.size) - for ((fid, selFun) <- child.fields.map(_.id) zip fieldFuns) { - adtFieldSelectors += ((child, fid) -> selFun) - reverseADTFieldSelectors += (selFun -> (child, fid)) + for ((selFun, index) <- fieldFuns.zipWithIndex) { + selectors += (c.tpe, index) -> selFun } } } //println("\\\\\\"*40) - sorts.toZ3(ct) + sorts.toZ3(t) } // Prepares some of the Z3 sorts, but *not* the tuple sorts; these are created on-demand. @@ -372,23 +251,10 @@ trait AbstractZ3Solver sorts += CharType -> z3.mkBVSort(32) sorts += IntegerType -> z3.mkIntSort sorts += BooleanType -> z3.mkBoolSort - sorts += UnitType -> us - - unitValue = unitCons() - - val intSetSort = typeToSort(SetType(IntegerType)) - val intSort = typeToSort(IntegerType) - intSetMinFun = z3.mkFreshFuncDecl("setMin", Seq(intSetSort), intSort) - intSetMaxFun = z3.mkFreshFuncDecl("setMax", Seq(intSetSort), intSort) - - // Empty everything - adtTesters = Map.empty - adtConstructors = Map.empty - adtFieldSelectors = Map.empty - reverseADTTesters = Map.empty - reverseADTConstructors = Map.empty - reverseADTFieldSelectors = Map.empty + testers.clear + constructors.clear + selectors.clear } def normalizeType(t: TypeTree): TypeTree = { @@ -397,58 +263,29 @@ trait AbstractZ3Solver // assumes prepareSorts has been called.... protected[leon] def typeToSort(oldtt: TypeTree): Z3Sort = normalizeType(oldtt) match { - case Int32Type | BooleanType | UnitType | IntegerType | CharType => + case Int32Type | BooleanType | IntegerType | CharType => sorts.toZ3(oldtt) - case act: AbstractClassType => - sorts.toZ3OrCompute(act) { - declareADTSort(act) + case tpe @ (_: ClassType | _: ArrayType | _: TupleType | UnitType) => + sorts.toZ3OrCompute(tpe) { + declareStructuralSort(tpe) } - case cct: CaseClassType => - sorts.toZ3OrCompute(cct) { - declareADTSort(cct) - } - case tt @ SetType(base) => sorts.toZ3OrCompute(tt) { - val newSetSort = z3.mkSetSort(typeToSort(base)) - val card = z3.mkFreshFuncDecl("card", Seq(newSetSort), typeToSort(Int32Type)) - setCardDecls += tt -> card - - newSetSort + z3.mkSetSort(typeToSort(base)) } case tt @ MapType(fromType, toType) => - sorts.toZ3OrCompute(tt) { - val fromSort = typeToSort(fromType) - val toSort = mapRangeSort(toType) - - z3.mkArraySort(fromSort, toSort) - } + typeToSort(RawArrayType(fromType, library.optionType(toType))) - case tt @ ArrayType(base) => - sorts.toZ3OrCompute(tt) { - val intSort = typeToSort(Int32Type) - val toSort = typeToSort(base) - val as = z3.mkArraySort(intSort, toSort) - val tupleSortSymbol = z3.mkFreshStringSymbol("Array") - val (ats, atcons, Seq(atsel, atlength)) = z3.mkTupleSort(tupleSortSymbol, as, intSort) - - arrayMetaDecls += tt -> ArrayDecls(atcons, atsel, atlength) - - ats - } - case tt @ TupleType(tpes) => - sorts.toZ3OrCompute(tt) { - val tpesSorts = tpes.map(typeToSort) - val sortSymbol = z3.mkFreshStringSymbol("Tuple") - val (tupleSort, consTuple, projsTuple) = z3.mkTupleSort(sortSymbol, tpesSorts: _*) - - tupleMetaDecls += tt -> TupleDecls(consTuple, projsTuple) + case rat @ RawArrayType(from, to) => + sorts.toZ3OrCompute(rat) { + val fromSort = typeToSort(from) + val toSort = typeToSort(to) - tupleSort + z3.mkArraySort(fromSort, toSort) } case tt @ TypeParameter(id) => @@ -470,7 +307,7 @@ trait AbstractZ3Solver case other => sorts.toZ3OrCompute(other) { reporter.warning(other.getPos, "Resorting to uninterpreted type for : " + other) - val symbol = z3.mkIntSymbol(nextIntForSymbol()) + val symbol = z3.mkIntSymbol(FreshIdentifier("unint").globalId) z3.mkUninterpretedSort(symbol) } } @@ -508,17 +345,6 @@ trait AbstractZ3Solver case me @ MatchExpr(s, cs) => rec(matchToIfThenElse(me)) - case tu @ Tuple(args) => - typeToSort(tu.getType) // Make sure we generate sort & meta info - val meta = tupleMetaDecls(normalizeType(tu.getType)) - - meta.cons(args.map(rec): _*) - - case ts @ TupleSelect(tu, i) => - typeToSort(tu.getType) // Make sure we generate sort & meta info - val meta = tupleMetaDecls(normalizeType(tu.getType)) - - meta.selects(i-1)(rec(tu)) case Let(i, e, b) => { val re = rec(e) @@ -555,7 +381,6 @@ trait AbstractZ3Solver case InfiniteIntegerLiteral(v) => z3.mkNumeral(v.toString, typeToSort(IntegerType)) case CharLiteral(c) => z3.mkInt(c, typeToSort(CharType)) case BooleanLiteral(v) => if (v) z3.mkTrue() else z3.mkFalse() - case UnitLiteral() => unitValue case Equals(l, r) => z3.mkEq(rec( l ), rec( r ) ) case Plus(l, r) => z3.mkAdd(rec(l), rec(r)) case Minus(l, r) => z3.mkSub(rec(l), rec(r)) @@ -611,21 +436,77 @@ trait AbstractZ3Solver case CharType => z3.mkBVSge(rec(l), rec(r)) } + case u : UnitLiteral => + val tpe = normalizeType(u.getType) + typeToSort(tpe) + val constructor = constructors.toB(tpe) + constructor() + + case t @ Tuple(es) => + val tpe = normalizeType(t.getType) + typeToSort(tpe) + val constructor = constructors.toB(tpe) + constructor(es.map(rec): _*) + + case ts @ TupleSelect(t, i) => + val tpe = normalizeType(t.getType) + typeToSort(tpe) + val selector = selectors.toB((tpe, i-1)) + selector(rec(t)) + case c @ CaseClass(ct, args) => typeToSort(ct) // Making sure the sort is defined - val constructor = adtConstructors(ct) + val constructor = constructors.toB(ct) constructor(args.map(rec): _*) case c @ CaseClassSelector(cct, cc, sel) => typeToSort(cct) // Making sure the sort is defined - val selector = adtFieldSelectors(cct, sel) + val selector = selectors.toB(cct, c.selectorIndex) selector(rec(cc)) case c @ CaseClassInstanceOf(cct, e) => typeToSort(cct) // Making sure the sort is defined - val tester = adtTesters(cct) + val tester = testers.toB(cct) tester(rec(e)) + case al @ ArraySelect(a, i) => + val tpe = normalizeType(a.getType) + + val sa = rec(a) + val content = selectors.toB((tpe, 1))(sa) + + z3.mkSelect(content, rec(i)) + + case al @ ArrayUpdated(a, i, e) => + val tpe = normalizeType(a.getType) + + val sa = rec(a) + val ssize = selectors.toB((tpe, 0))(sa) + val scontent = selectors.toB((tpe, 1))(sa) + + val newcontent = z3.mkStore(scontent, rec(i), rec(e)) + + val constructor = constructors.toB(tpe) + + constructor(ssize, newcontent) + + case al @ ArrayLength(a) => + val tpe = normalizeType(a.getType) + val sa = rec(a) + selectors.toB((tpe, 0))(sa) + + case arr @ FiniteArray(elems, oDefault, length) => + val at @ ArrayType(base) = normalizeType(arr.getType) + typeToSort(at) + + val default = oDefault.getOrElse(simplestValue(base)) + + val ar = rec(RawArrayValue(Int32Type, elems.map { + case (i, e) => IntLiteral(i) -> e + }, default)) + + constructors.toB(at)(rec(length), ar) + case f @ FunctionInvocation(tfd, args) => z3.mkApp(functionDefToDecl(tfd), args.map(rec): _*) @@ -638,75 +519,49 @@ trait AbstractZ3Solver case SetUnion(s1, s2) => z3.mkSetUnion(rec(s1), rec(s2)) case SetDifference(s1, s2) => z3.mkSetDifference(rec(s1), rec(s2)) case f @ FiniteSet(elems, base) => elems.foldLeft(z3.mkEmptySet(typeToSort(base)))((ast, el) => z3.mkSetAdd(ast, rec(el))) - case SetCardinality(s) => - val rs = rec(s) - setCardDecls(s.getType)(rs) - - case SetMin(s) => intSetMinFun(rec(s)) - case SetMax(s) => intSetMaxFun(rec(s)) - case f @ FiniteMap(elems, fromType, toType) => - typeToSort(MapType(fromType, toType)) //had to add this here because the mapRangeNoneConstructors was not yet constructed... - val fromSort = typeToSort(fromType) - elems.foldLeft(z3.mkConstArray(fromSort, mapRangeNoneConstructors(toType)())){ - case (ast, (k,v)) => z3.mkStore(ast, rec(k), mapRangeSomeConstructors(toType)(rec(v))) - } - case mg @ MapGet(m,k) => m.getType match { - case MapType(fromType, toType) => - val selected = z3.mkSelect(rec(m), rec(k)) - mapRangeValueSelectors(toType)(selected) - case errorType => scala.sys.error("Unexpected type for map: " + (ex, errorType)) - } - case MapUnion(m1,m2) => m1.getType match { - case MapType(ft, tt) => m2 match { - case FiniteMap(ss, _, _) => - ss.foldLeft(rec(m1)){ - case (ast, (k, v)) => z3.mkStore(ast, rec(k), mapRangeSomeConstructors(tt)(rec(v))) - } - case _ => scala.sys.error("map updates can only be applied with concrete map instances") + case RawArrayValue(keyTpe, elems, default) => + val ar = z3.mkConstArray(typeToSort(keyTpe), rec(default)) + + elems.foldLeft(ar) { + case (array, (k, v)) => z3.mkStore(array, rec(k), rec(v)) } - case errorType => scala.sys.error("Unexpected type for map: " + (ex, errorType)) - } - case MapIsDefinedAt(m,k) => m.getType match { - case MapType(ft, tt) => z3.mkDistinct(z3.mkSelect(rec(m), rec(k)), mapRangeNoneConstructors(tt)()) - case errorType => scala.sys.error("Unexpected type for map: " + (ex, errorType)) - } - case ArraySelect(a, index) => - typeToSort(a.getType) - val ar = rec(a) - val getArray = arrayMetaDecls(normalizeType(a.getType)).select - val res = z3.mkSelect(getArray(ar), rec(index)) - res + /** + * ===== Map operations ===== + */ + case m @ FiniteMap(elems, from, to) => + val mt @ MapType(f, t) = normalizeType(m.getType) - case ArrayUpdated(a, index, newVal) => - typeToSort(a.getType) - val ar = rec(a) - val meta = arrayMetaDecls(normalizeType(a.getType)) + rec(RawArrayValue(from, elems.map{ + case (k, v) => (k, CaseClass(library.someType(t), Seq(v))) + }.toMap, CaseClass(library.noneType(t), Seq()))) - val store = z3.mkStore(meta.select(ar), rec(index), rec(newVal)) - val res = meta.cons(store, meta.length(ar)) - res + case MapGet(m, k) => + val mt @ MapType(f, t) = normalizeType(m.getType) + typeToSort(mt) - case ArrayLength(a) => - typeToSort(a.getType) - val ar = rec(a) - val meta = arrayMetaDecls(normalizeType(a.getType)) - val res = meta.length(ar) - res + val el = z3.mkSelect(rec(m), rec(k)) - case arr @ FiniteArray(elems, oDefault, length) => - val at @ ArrayType(base) = arr.getType - typeToSort(at) - val meta = arrayMetaDecls(normalizeType(at)) + // Really ?!? We don't check that it is actually != None? + selectors.toB(library.someType(t), 0)(el) - val default = oDefault.getOrElse(simplestValue(base)) + case MapIsDefinedAt(m, k) => + val mt @ MapType(f, t) = normalizeType(m.getType) + typeToSort(mt) + + val el = z3.mkSelect(rec(m), rec(k)) + + testers.toB(library.someType(t))(el) + + case MapUnion(m1, FiniteMap(elems, _, _)) => + val mt @ MapType(f, t) = normalizeType(m1.getType) + typeToSort(mt) + + elems.foldLeft(rec(m1)) { case (m, (k,v)) => + z3.mkStore(m, rec(k), rec(CaseClass(library.someType(t), Seq(v)))) + } - val ar = z3.mkConstArray(typeToSort(Int32Type), rec(default)) - val u = elems.foldLeft(ar)((array, el) => { - z3.mkStore(array, rec(IntLiteral(el._1)), rec(el._2)) - }) - meta.cons(u, rec(length)) case gv @ GenericValue(tp, id) => z3.mkApp(genericValueToDecl(gv)) @@ -725,15 +580,26 @@ trait AbstractZ3Solver } } - protected[leon] def fromZ3Formula(model: Z3Model, tree: Z3AST, tpe: TypeTree) : Expr = { + protected def fromRawArray(r: Expr, tpe: TypeTree): Expr = r match { + case rav: RawArrayValue => + fromRawArray(rav, tpe) + case _ => + scala.sys.error("Unable to extract from raw array for "+r) + } - // This is unsafe and should be avoided because sorts<->types are not a bijection - // For instance, both Int32Type and CharType compile to BV32 + protected def fromRawArray(r: RawArrayValue, tpe: TypeTree): Expr = tpe match { + case RawArrayType(from, to) => + r + + case ft @ FunctionType(from, to) => + finiteLambda(r.default, r.elems.toSeq, from) - def recGuess(t: Z3AST): Expr = { - val s = z3.getSort(t) - rec(t, sorts.toLeon(s)) - } + + case _ => + scala.sys.error("Unable to extract from raw array for "+tpe) + } + + protected[leon] def fromZ3Formula(model: Z3Model, tree: Z3AST, tpe: TypeTree): Expr = { def rec(t: Z3AST, tpe: TypeTree): Expr = { val kind = z3.getASTKind(t) @@ -784,61 +650,73 @@ trait AbstractZ3Solver val tfd = functions.toLeon(decl) assert(tfd.params.size == argsSize) FunctionInvocation(tfd, args.zip(tfd.params).map{ case (a, p) => rec(a, p.getType) }) - } else if(argsSize == 1 && (reverseADTTesters contains decl)) { - val cct = reverseADTTesters(decl) - CaseClassInstanceOf(cct, rec(args.head, cct.root)) - } else if(argsSize == 1 && (reverseADTFieldSelectors contains decl)) { - val (cct, fid) = reverseADTFieldSelectors(decl) - CaseClassSelector(cct, rec(args.head, cct), fid) - } else if(reverseADTConstructors contains decl) { - val cct = reverseADTConstructors(decl) - assert(argsSize == cct.fields.size) - CaseClass(cct, args.zip(cct.fieldsTypes).map{ case (a, t) => rec(a, t) }) } else if (generics containsZ3 decl) { generics.toLeon(decl) - } else { - tpe match { - case tp: TypeParameter => - val id = t.toString.split("!").last.toInt - GenericValue(tp, id) + } else if (constructors containsB decl) { + constructors.toA(decl) match { + case cct: CaseClassType => + CaseClass(cct, args.zip(cct.fieldsTypes).map { case (a, t) => rec(a, t) }) + + case UnitType => + UnitLiteral() case TupleType(ts) => - val rargs = args.zip(ts).map{ case (a, t) => rec(a, t) } - tupleWrap(rargs) - - case at @ ArrayType(dt) => - assert(args.size == 2) - val length = rec(args(1), Int32Type) match { - case IntLiteral(length) => length - case _ => throw new CantTranslateException(t) - } - model.getArrayValue(args(0)) match { - case None => throw new CantTranslateException(t) - case Some((map, elseZ3Value)) => - val elseValue = rec(elseZ3Value, dt) - val valuesMap = map.map { case (k,v) => - val index = rec(k, Int32Type) match { - case IntLiteral(index) => index - case _ => throw new CantTranslateException(t) - } - index -> rec(v, dt) + tupleWrap(args.zip(ts).map { case (a, t) => rec(a, t) }) + + case ArrayType(to) => + val size = rec(args(0), Int32Type) + val map = rec(args(1), RawArrayType(Int32Type, to)) + + (size, map) match { + case (s : IntLiteral, RawArrayValue(_, elems, default)) => + val entries = elems.map { + case (IntLiteral(i), v) => i -> v + case _ => throw new CantTranslateException(t) } - finiteArray(valuesMap, Some(elseValue, IntLiteral(length)), dt) + finiteArray(entries, Some(s, default), to) + case _ => + throw new CantTranslateException(t) } - - case tpe @ MapType(kt, vt) => + } + } else { + tpe match { + case RawArrayType(from, to) => model.getArrayValue(t) match { + case Some((z3map, z3default)) => + val default = rec(z3default, to) + val entries = z3map.map { + case (k,v) => (rec(k, from), rec(v, to)) + } + + RawArrayValue(from, entries, default) case None => throw new CantTranslateException(t) - case Some((map, elseZ3Value)) => - val values = map.toSeq.map { case (k, v) => (k, z3.getASTKind(v)) }.collect { - case (k, Z3AppAST(cons, arg :: Nil)) if cons == mapRangeSomeConstructors(vt) => - (rec(k, kt), rec(arg, vt)) + } + + case tp: TypeParameter => + val id = t.toString.split("!").last.toInt + GenericValue(tp, id) + + case MapType(from, to) => + rec(t, RawArrayType(from, library.optionType(to))) match { + case r: RawArrayValue => + // We expect a RawArrayValue with keys in from and values in Option[to], + // with default value == None + if (r.default.getType != library.noneType(to)) { + reporter.warning("Co-finite maps are not supported. (Default was "+r.default+")") + throw new IllegalArgumentException } + require(r.keyTpe == from, s"Type error in solver model, expected $from, found ${r.keyTpe}") + + val elems = r.elems.flatMap { + case (k, CaseClass(leonSome, Seq(x))) => Some(k -> x) + case (k, _) => None + }.toSeq - finiteMap(values, kt, vt) + finiteMap(elems, from, to) } + case FunctionType(fts, tt) => model.getArrayValue(t) match { case None => throw new CantTranslateException(t) @@ -856,37 +734,36 @@ trait AbstractZ3Solver finiteSet(elems, dt) } - case UnitType => - UnitLiteral() - case _ => import Z3DeclKind._ - val rargs = args.map(recGuess) z3.getDeclKind(decl) match { case OpTrue => BooleanLiteral(true) case OpFalse => BooleanLiteral(false) - case OpEq => Equals(rargs(0), rargs(1)) - case OpITE => IfExpr(rargs(0), rargs(1), rargs(2)) - case OpAnd => andJoin(rargs) - case OpOr => orJoin(rargs) - case OpIff => Equals(rargs(0), rargs(1)) - case OpXor => not(Equals(rargs(0), rargs(1))) - case OpNot => not(rargs(0)) - case OpImplies => implies(rargs(0), rargs(1)) - case OpLE => LessEquals(rargs(0), rargs(1)) - case OpGE => GreaterEquals(rargs(0), rargs(1)) - case OpLT => LessThan(rargs(0), rargs(1)) - case OpGT => GreaterThan(rargs(0), rargs(1)) - case OpAdd => Plus(rargs(0), rargs(1)) - case OpSub => Minus(rargs(0), rargs(1)) - case OpUMinus => UMinus(rargs(0)) - case OpMul => Times(rargs(0), rargs(1)) - case OpDiv => Division(rargs(0), rargs(1)) - case OpIDiv => Division(rargs(0), rargs(1)) - case OpMod => Modulo(rargs(0), rargs(1)) + // case OpEq => Equals(rargs(0), rargs(1)) + // case OpITE => IfExpr(rargs(0), rargs(1), rargs(2)) + // case OpAnd => andJoin(rargs) + // case OpOr => orJoin(rargs) + // case OpIff => Equals(rargs(0), rargs(1)) + // case OpXor => not(Equals(rargs(0), rargs(1))) + // case OpNot => not(rargs(0)) + // case OpImplies => implies(rargs(0), rargs(1)) + // case OpLE => LessEquals(rargs(0), rargs(1)) + // case OpGE => GreaterEquals(rargs(0), rargs(1)) + // case OpLT => LessThan(rargs(0), rargs(1)) + // case OpGT => GreaterThan(rargs(0), rargs(1)) + // case OpAdd => Plus(rargs(0), rargs(1)) + // case OpSub => Minus(rargs(0), rargs(1)) + case OpUMinus => UMinus(rec(args(0), tpe)) + // case OpMul => Times(rargs(0), rargs(1)) + // case OpDiv => Division(rargs(0), rargs(1)) + // case OpIDiv => Division(rargs(0), rargs(1)) + // case OpMod => Modulo(rargs(0), rargs(1)) case other => System.err.println("Don't know what to do with this declKind : " + other) + System.err.println("Expected type: " + tpe) + System.err.println("Tree: " + t) System.err.println("The arguments are : " + args) + new Exception().printStackTrace throw new CantTranslateException(t) } } diff --git a/src/main/scala/leon/utils/Library.scala b/src/main/scala/leon/utils/Library.scala index e7c76c83165969a7cf12b32b6f96af12a1657010..920a45b5060cde5508c8acaa4279c8cc14cbf02e 100644 --- a/src/main/scala/leon/utils/Library.scala +++ b/src/main/scala/leon/utils/Library.scala @@ -4,6 +4,7 @@ package leon package utils import purescala.Definitions._ +import purescala.Types._ import purescala.DefOps.searchByFullName case class Library(pgm: Program) { @@ -22,4 +23,8 @@ case class Library(pgm: Program) { def lookup(name: String): Option[Definition] = { searchByFullName(name, pgm) } + + def optionType(tp: TypeTree) = AbstractClassType(Option.get, Seq(tp)) + def someType(tp: TypeTree) = CaseClassType(Some.get, Seq(tp)) + def noneType(tp: TypeTree) = CaseClassType(None.get, Seq(tp)) }