From f34d906ffc40a3688de866d01309e047c45d50b9 Mon Sep 17 00:00:00 2001 From: Nicolas Voirol <voirol.nicolas@gmail.com> Date: Fri, 28 Oct 2016 12:08:43 +0200 Subject: [PATCH] FiniteMap with to type + fixes in ProgramEncoder --- src/main/scala/inox/Main.scala | 2 +- src/main/scala/inox/ast/Expressions.scala | 4 ++-- src/main/scala/inox/ast/Extractors.scala | 8 ++++---- src/main/scala/inox/ast/Printers.scala | 2 +- src/main/scala/inox/ast/ProgramEncoder.scala | 14 ++++++++++++++ src/main/scala/inox/ast/SymbolOps.scala | 9 +++++---- .../inox/evaluators/RecursiveEvaluator.scala | 12 ++++++------ .../scala/inox/solvers/smtlib/CVC4Target.scala | 8 ++++---- .../scala/inox/solvers/smtlib/SMTLIBTarget.scala | 2 +- .../scala/inox/solvers/smtlib/Z3Target.scala | 16 +++++++--------- .../scala/inox/solvers/z3/AbstractZ3Solver.scala | 8 ++++---- src/main/scala/inox/tip/Parser.scala | 3 ++- 12 files changed, 51 insertions(+), 37 deletions(-) diff --git a/src/main/scala/inox/Main.scala b/src/main/scala/inox/Main.scala index 625152626..c6be43a62 100644 --- a/src/main/scala/inox/Main.scala +++ b/src/main/scala/inox/Main.scala @@ -110,7 +110,7 @@ trait MainHelpers { Context( reporter = reporter, - options = Options(inoxOptions), + options = Options(inoxOptions :+ optFiles(files)), interruptManager = new utils.InterruptManager(reporter) ) } diff --git a/src/main/scala/inox/ast/Expressions.scala b/src/main/scala/inox/ast/Expressions.scala index 66f2f55c5..bf4ebabd6 100644 --- a/src/main/scala/inox/ast/Expressions.scala +++ b/src/main/scala/inox/ast/Expressions.scala @@ -670,10 +670,10 @@ trait Expressions { self: Trees => /* Total map operations */ /** $encodingof `Map[keyType, valueType](key1 -> value1, key2 -> value2 ...)` */ - case class FiniteMap(pairs: Seq[(Expr, Expr)], default: Expr, keyType: Type) extends Expr with CachingTyped { + case class FiniteMap(pairs: Seq[(Expr, Expr)], default: Expr, keyType: Type, valueType: Type) extends Expr with CachingTyped { protected def computeType(implicit s: Symbols): Type = MapType( checkParamTypes(pairs.map(_._1.getType), List.fill(pairs.size)(keyType), keyType), - s.leastUpperBound(pairs.map(_._2.getType) :+ default.getType).getOrElse(Untyped) + checkParamTypes(pairs.map(_._2.getType) :+ default.getType, List.fill(pairs.size + 1)(valueType), valueType) ).unveilUntyped } diff --git a/src/main/scala/inox/ast/Extractors.scala b/src/main/scala/inox/ast/Extractors.scala index f7e261e88..f09ba2b8b 100644 --- a/src/main/scala/inox/ast/Extractors.scala +++ b/src/main/scala/inox/ast/Extractors.scala @@ -150,9 +150,9 @@ trait TreeDeconstructor { t.FiniteBag(rec(as), tps.head) } (Seq(), subArgs, Seq(base), builder) - case s.FiniteMap(elems, default, kT) => + case s.FiniteMap(elems, default, kT, vT) => val subArgs = elems.flatMap { case (k, v) => Seq(k, v) } :+ default - val builder = (vs: Seq[t.Variable], as: Seq[t.Expr], kT: Seq[t.Type]) => { + val builder = (vs: Seq[t.Variable], as: Seq[t.Expr], tps: Seq[t.Type]) => { def rec(kvs: Seq[t.Expr]): (Seq[(t.Expr, t.Expr)], t.Expr) = kvs match { case Seq(k, v, t @ _*) => val (kvs, default) = rec(t) @@ -160,9 +160,9 @@ trait TreeDeconstructor { case Seq(default) => (Seq(), default) } val (pairs, default) = rec(as) - t.FiniteMap(pairs, default, kT.head) + t.FiniteMap(pairs, default, tps(0), tps(1)) } - (Seq(), subArgs, Seq(kT), builder) + (Seq(), subArgs, Seq(kT, vT), builder) case s.Tuple(args) => (Seq(), args, Seq(), (_, es, _) => t.Tuple(es)) case s.IfExpr(cond, thenn, elze) => ( diff --git a/src/main/scala/inox/ast/Printers.scala b/src/main/scala/inox/ast/Printers.scala index e180bc0f3..27ac8db61 100644 --- a/src/main/scala/inox/ast/Printers.scala +++ b/src/main/scala/inox/ast/Printers.scala @@ -239,7 +239,7 @@ trait Printers { } case fs @ FiniteSet(rs, _) => p"{${rs.distinct}}" case fs @ FiniteBag(rs, _) => p"{${rs.toMap.toSeq}}" - case fm @ FiniteMap(rs, _, _) => p"{${rs.toMap.toSeq}}" + case fm @ FiniteMap(rs, _, _, _) => p"{${rs.toMap.toSeq}}" case Not(ElementOfSet(e, s)) => p"$e \u2209 $s" case ElementOfSet(e, s) => p"$e \u2208 $s" case SubsetOf(l, r) => p"$l \u2286 $r" diff --git a/src/main/scala/inox/ast/ProgramEncoder.scala b/src/main/scala/inox/ast/ProgramEncoder.scala index 0f8edaf9d..0d60a6ed6 100644 --- a/src/main/scala/inox/ast/ProgramEncoder.scala +++ b/src/main/scala/inox/ast/ProgramEncoder.scala @@ -42,6 +42,13 @@ trait ProgramEncoder { self => val sourceProgram: that.sourceProgram.type = that.sourceProgram val t: self.t.type = self.t + // make sure we don't ignore potential `encodedProgram` overrides + // note that we don't actually need to look at `that.encodedProgram` since the type + // of the compose method ensures the override is not ignored + override protected def encodedProgram: Program { val trees: self.t.type } = self.encodedProgram + override protected val extraFunctions: Seq[t.FunDef] = self.extraFunctions + override protected val extraADTs: Seq[t.ADTDefinition] = self.extraADTs + val encoder = self.encoder compose that.encoder val decoder = that.decoder compose self.decoder } @@ -53,6 +60,13 @@ trait ProgramEncoder { self => val sourceProgram: self.sourceProgram.type = self.sourceProgram val t: that.t.type = that.t + // make sure we don't ignore potential `encodedProgram` overrides + // note that we don't actually need to look at `that.encodedProgram` since the type + // of the andThen method ensures the override is not ignored + override protected def encodedProgram: Program { val trees: that.t.type } = that.encodedProgram + override protected val extraFunctions: Seq[t.FunDef] = that.extraFunctions + override protected val extraADTs: Seq[t.ADTDefinition] = that.extraADTs + val encoder = self.encoder andThen that.encoder val decoder = that.decoder andThen self.decoder } diff --git a/src/main/scala/inox/ast/SymbolOps.scala b/src/main/scala/inox/ast/SymbolOps.scala index b9f5f56b7..7bd557dae 100644 --- a/src/main/scala/inox/ast/SymbolOps.scala +++ b/src/main/scala/inox/ast/SymbolOps.scala @@ -327,7 +327,7 @@ trait SymbolOps { self: TypeOps => case UnitType => UnitLiteral() case SetType(baseType) => FiniteSet(Seq(), baseType) case BagType(baseType) => FiniteBag(Seq(), baseType) - case MapType(fromType, toType) => FiniteMap(Seq(), simplestValue(toType), fromType) + case MapType(fromType, toType) => FiniteMap(Seq(), simplestValue(toType), fromType, toType) case TupleType(tpes) => Tuple(tpes.map(simplestValue)) case adt @ ADTType(id, tps) => @@ -389,7 +389,7 @@ trait SymbolOps { self: TypeOps => val seqs = elems.scanLeft(Stream(Seq[(Expr, Expr)]())) { (prev, curr) => prev flatMap { case seq => Stream(seq, seq :+ curr) } }.flatten - cartesianProduct(seqs, valuesOf(to)) map { case (values, default) => FiniteMap(values, default, from) } + cartesianProduct(seqs, valuesOf(to)) map { case (values, default) => FiniteMap(values, default, from, to) } case adt: ADTType => adt.getADT match { case tcons: TypedADTConstructor => cartesianProduct(tcons.fieldsTypes map valuesOf) map (ADT(adt, _)) @@ -620,8 +620,9 @@ trait SymbolOps { self: TypeOps => case (FiniteBag(elements, fbtpe), BagType(tpe)) => fbtpe == tpe && elements.forall{ case (key, value) => isValueOfType(key, tpe) && isValueOfType(value, IntegerType) } - case (FiniteMap(elems, default, kt), MapType(from, to)) => - (kt == from) < s"$kt not equal to $from" && (default.getType == to) < s"${default.getType} not equal to $to" && + case (FiniteMap(elems, default, kt, vt), MapType(from, to)) => + (kt == from) < s"$kt not equal to $from" && (vt == to) < s"${default.getType} not equal to $to" && + isValueOfType(default, to) < s"${default} not a value of type $to" && (elems forall (kv => isValueOfType(kv._1, from) < s"${kv._1} not a value of type $from" && isValueOfType(unWrapSome(kv._2), to) < s"${unWrapSome(kv._2)} not a value of type ${to}" )) case (ADT(adt, args), adt2: ADTType) => isSubtypeOf(adt, adt2) < s"$adt not a subtype of $adt2" && diff --git a/src/main/scala/inox/evaluators/RecursiveEvaluator.scala b/src/main/scala/inox/evaluators/RecursiveEvaluator.scala index 56ea64bf6..6675d5915 100644 --- a/src/main/scala/inox/evaluators/RecursiveEvaluator.scala +++ b/src/main/scala/inox/evaluators/RecursiveEvaluator.scala @@ -118,7 +118,7 @@ trait RecursiveEvaluator BooleanLiteral(el1.toSet == el2.toSet) case (FiniteBag(el1, _),FiniteBag(el2, _)) => BooleanLiteral(el1.toMap == el2.toMap) - case (FiniteMap(el1, dflt1, _),FiniteMap(el2, dflt2, _)) => + case (FiniteMap(el1, dflt1, _, _),FiniteMap(el2, dflt2, _, _)) => BooleanLiteral(el1.toMap == el2.toMap && dflt1 == dflt2) case (l1: Lambda, l2: Lambda) => val (nl1, subst1) = normalizeStructure(l1) @@ -481,20 +481,20 @@ trait RecursiveEvaluator replaceFromSymbols(variablesOf(c).map(v => v -> e(v)).toMap, c).asInstanceOf[Choose] } - case f @ FiniteMap(ss, dflt, vT) => + case f @ FiniteMap(ss, dflt, kT, vT) => // we use toMap.toSeq to reduce dupplicate keys - FiniteMap(ss.map{ case (k, v) => (e(k), e(v)) }.toMap.toSeq, e(dflt), vT) + FiniteMap(ss.map{ case (k, v) => (e(k), e(v)) }.toMap.toSeq, e(dflt), kT, vT) case g @ MapApply(m,k) => (e(m), e(k)) match { - case (FiniteMap(ss, dflt, _), e) => + case (FiniteMap(ss, dflt, _, _), e) => ss.toMap.getOrElse(e, dflt) case (l,r) => throw EvalError(typeErrorMsg(l, MapType(r.getType, g.getType))) } case g @ MapUpdated(m, k, v) => (e(m), e(k), e(v)) match { - case (FiniteMap(ss, dflt, tpe), ek, ev) => - FiniteMap((ss.toMap + (ek -> ev)).toSeq, dflt, tpe) + case (FiniteMap(ss, dflt, kT, vT), ek, ev) => + FiniteMap((ss.toMap + (ek -> ev)).toSeq, dflt, kT, vT) case (m,l,r) => throw EvalError("Unexpected operation: " + m.asString + ".updated(" + l.asString + ", " + r.asString + ")") diff --git a/src/main/scala/inox/solvers/smtlib/CVC4Target.scala b/src/main/scala/inox/solvers/smtlib/CVC4Target.scala index b28755ba2..5e07a5d49 100644 --- a/src/main/scala/inox/solvers/smtlib/CVC4Target.scala +++ b/src/main/scala/inox/solvers/smtlib/CVC4Target.scala @@ -54,17 +54,17 @@ trait CVC4Target extends SMTLIBTarget with SMTLIBDebugger { FiniteSet(Seq(), base) case (FunctionApplication(QualifiedIdentifier(SMTIdentifier(SSymbol("const"), _), _), Seq(elem)), Some(MapType(k, v))) => - FiniteMap(Seq(), fromSMT(elem, v), k) + FiniteMap(Seq(), fromSMT(elem, v), k, v) case (FunctionApplication(SimpleSymbol(SSymbol("__array_store_all__")), Seq(_, elem)), Some(MapType(k, v))) => - FiniteMap(Seq(), fromSMT(elem, v), k) + FiniteMap(Seq(), fromSMT(elem, v), k, v) case (FunctionApplication(SimpleSymbol(SSymbol("store")), Seq(arr, key, elem)), Some(MapType(kT, vT))) => - val FiniteMap(elems, default, _) = fromSMT(arr, otpe) + val FiniteMap(elems, default, _, _) = fromSMT(arr, otpe) val newKey = fromSMT(key, kT) val newV = fromSMT(elem, vT) val newElems = elems.filterNot(_._1 == newKey) :+ (newKey -> newV) - FiniteMap(newElems, default, kT) + FiniteMap(newElems, default, kT, vT) case (FunctionApplication(SimpleSymbol(SSymbol("singleton")), elems), Some(SetType(base))) => FiniteSet(elems.map(fromSMT(_, base)), base) diff --git a/src/main/scala/inox/solvers/smtlib/SMTLIBTarget.scala b/src/main/scala/inox/solvers/smtlib/SMTLIBTarget.scala index 890761d96..2b71f2ccf 100644 --- a/src/main/scala/inox/solvers/smtlib/SMTLIBTarget.scala +++ b/src/main/scala/inox/solvers/smtlib/SMTLIBTarget.scala @@ -324,7 +324,7 @@ trait SMTLIBTarget extends Interruptible with ADTManagers { ArraysEx.Select(toSMT(a), toSMT(i)) case al @ MapUpdated(map, k, v) => ArraysEx.Store(toSMT(map), toSMT(k), toSMT(v)) - case ra @ FiniteMap(elems, default, keyTpe) => + case ra @ FiniteMap(elems, default, keyTpe, valueType) => val s = declareSort(ra.getType) var res: Term = FunctionApplication( diff --git a/src/main/scala/inox/solvers/smtlib/Z3Target.scala b/src/main/scala/inox/solvers/smtlib/Z3Target.scala index e803c54a3..c9b7fc76e 100644 --- a/src/main/scala/inox/solvers/smtlib/Z3Target.scala +++ b/src/main/scala/inox/solvers/smtlib/Z3Target.scala @@ -75,28 +75,26 @@ trait Z3Target extends SMTLIBTarget with SMTLIBDebugger { } // Need to recover value form function model val (cases, default) = extractCases(body) - FiniteMap(cases.toSeq, default, keyType) + FiniteMap(cases.toSeq, default, keyType, valueType) } else { throw FatalError("Array on non-function or unknown symbol "+k) } case (QualifiedIdentifier(ExtendedIdentifier(SSymbol("as-array"), k: SSymbol), _), Some(tpe @ SetType(base))) => - val fm @ FiniteMap(cases, dflt, _) = fromSMT(t, Some(MapType(base, BooleanType))) + val fm @ FiniteMap(cases, dflt, _, _) = fromSMT(t, Some(MapType(base, BooleanType))) if (dflt != BooleanLiteral(false)) unsupported(fm, "Solver returned a co-finite set which is not supported") FiniteSet(cases.collect { case (k, BooleanLiteral(true)) => k }, base) case (QualifiedIdentifier(ExtendedIdentifier(SSymbol("as-array"), k: SSymbol), _), Some(tpe @ BagType(base))) => - val fm @ FiniteMap(cases, dflt, _) = fromSMT(t, Some(MapType(base, IntegerType))) + val fm @ FiniteMap(cases, dflt, _, _) = fromSMT(t, Some(MapType(base, IntegerType))) if (dflt != IntegerLiteral(0)) unsupported(fm, "Solver returned a co-finite bag which is not supported") FiniteBag(cases.filter(_._2 != IntegerLiteral(BigInt(0))), base) case (FunctionApplication( QualifiedIdentifier(SMTIdentifier(SSymbol("const"), _), Some(ArraysEx.ArraySort(k, v))), Seq(defV) - ), Some(tpe: MapType)) => - val ktpe = sorts.fromB(k) - val vtpe = sorts.fromB(v) - FiniteMap(Seq(), fromSMT(defV, Some(vtpe)), ktpe) + ), Some(MapType(from, to))) => + FiniteMap(Seq(), fromSMT(defV, Some(to)), from, to) case (FunctionApplication( QualifiedIdentifier(SMTIdentifier(SSymbol("const"), _), Some(ArraysEx.ArraySort(k, v))), @@ -126,7 +124,7 @@ trait Z3Target extends SMTLIBTarget with SMTLIBDebugger { */ case fs @ FiniteSet(elems, base) => declareSort(fs.getType) - toSMT(FiniteMap(elems map ((_, BooleanLiteral(true))), BooleanLiteral(false), base)) + toSMT(FiniteMap(elems map ((_, BooleanLiteral(true))), BooleanLiteral(false), base, BooleanType)) case SubsetOf(ss, s) => // a isSubset b ==> (a zip b).map(implies) == (* => true) @@ -155,7 +153,7 @@ trait Z3Target extends SMTLIBTarget with SMTLIBDebugger { case fb @ FiniteBag(elems, base) => val BagType(t) = fb.getType declareSort(BagType(t)) - toSMT(FiniteMap(elems, IntegerLiteral(0), t)) + toSMT(FiniteMap(elems, IntegerLiteral(0), t, IntegerType)) case BagAdd(b, e) => val bid = FreshIdentifier("b", true) diff --git a/src/main/scala/inox/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/inox/solvers/z3/AbstractZ3Solver.scala index a2fc56cf3..fd394dfb8 100644 --- a/src/main/scala/inox/solvers/z3/AbstractZ3Solver.scala +++ b/src/main/scala/inox/solvers/z3/AbstractZ3Solver.scala @@ -441,7 +441,7 @@ trait AbstractZ3Solver */ case fb @ FiniteBag(elems, base) => typeToSort(fb.getType) - rec(FiniteMap(elems, IntegerLiteral(0), base)) + rec(FiniteMap(elems, IntegerLiteral(0), base, IntegerType)) case BagAdd(b, e) => val (bag, elem) = (rec(b), rec(e)) @@ -476,7 +476,7 @@ trait AbstractZ3Solver case al @ MapUpdated(a, i, e) => z3.mkStore(rec(a), rec(i), rec(e)) - case FiniteMap(elems, default, keyTpe) => + case FiniteMap(elems, default, keyTpe, valueType) => val ar = z3.mkConstArray(typeToSort(keyTpe), rec(default)) elems.foldLeft(ar) { @@ -612,12 +612,12 @@ trait AbstractZ3Solver case (k,v) => (rec(k, from), rec(v, to)) } - FiniteMap(entries.toSeq, default, from) + FiniteMap(entries.toSeq, default, from, to) case None => unsound(t, "invalid array AST") } case BagType(base) => - val fm @ FiniteMap(entries, default, from) = rec(t, MapType(base, IntegerType)) + val fm @ FiniteMap(entries, default, from, IntegerType) = rec(t, MapType(base, IntegerType)) if (default != IntegerLiteral(0)) { unsound(t, "co-finite bag AST") } diff --git a/src/main/scala/inox/tip/Parser.scala b/src/main/scala/inox/tip/Parser.scala index 72245aeb8..2b6122036 100644 --- a/src/main/scala/inox/tip/Parser.scala +++ b/src/main/scala/inox/tip/Parser.scala @@ -513,7 +513,8 @@ class Parser(file: File) { case ArraysEx.Select(e1, e2) => MapApply(extractTerm(e1), extractTerm(e2)) case ArraysEx.Store(e1, e2, e3) => MapUpdated(extractTerm(e1), extractTerm(e2), extractTerm(e3)) case FunctionApplication(QualifiedIdentifier(SimpleIdentifier(SSymbol("const")), Some(sort)), Seq(dflt)) => - FiniteMap(Seq.empty, extractTerm(dflt), extractSort(sort)) + val d = extractTerm(dflt) + FiniteMap(Seq.empty, d, extractSort(sort), locals.symbols.bestRealType(d.getType(locals.symbols))) case Sets.Union(e1, e2) => SetUnion(extractTerm(e1), extractTerm(e2)) case Sets.Intersection(e1, e2) => SetIntersection(extractTerm(e1), extractTerm(e2)) -- GitLab