diff --git a/demo/Maps.scala b/demo/Maps.scala new file mode 100644 index 0000000000000000000000000000000000000000..77b4dac0e07c3a7737d909401194a82ecd61bbec --- /dev/null +++ b/demo/Maps.scala @@ -0,0 +1,39 @@ +import scala.collection.immutable.Set +import scala.collection.immutable.Map +import funcheck.Utils._ +import funcheck.Annotations._ + +object Maps { + // To implement: + // - updated -> MapUnion (simply mkStore) + // - isDefinedAt -> MapIsDefinedAt (look it up by using mkSelect, check if it is MapSome(bla) + // - apply -> MapGet (look it up by using mkSelect, and return the value if it is MapSome(v) + // + // - constant maps -> FiniteMap , empty map -> EmptyMap (use store on mkArrayConst with default value MapNone) + + // deal with it in: + // - trees OK (we assume current structure is suitable) + // - evaluator + // - codeextraction OK + // - extractors OK + // - solver + // - printer OK + + def applyTest(m : Map[Int,Int], i : Int) : Int = m(i) + + def isDefinedAtTest(m : Map[Int,Int], i : Int) : Boolean = m.isDefinedAt(i) + + def emptyTest() : Map[Int,Int] = Map.empty[Int,Int] + + def updatedTest(m : Map[Int,Int]) : Map[Int,Int] = m.updated(1, 2) + + def useCase(map : Map[Int,Int], k : Int, v : Int) : Boolean = { + val map2 = map.updated(k, v) + map2.isDefinedAt(k) + } holds + + def useCase2(map : Map[Int,Int], k : Int, v : Int) : Map[Int,Int] = { + val map2 = map.updated(k, v) + map2 + } ensuring (res => res.isDefinedAt(k)) +} diff --git a/src/funcheck/CodeExtraction.scala b/src/funcheck/CodeExtraction.scala index e86f1938e428d725d4771fa992859cf40bfe5d6c..ad8de4eca2716133a15df539fdb1f06fb44e504a 100644 --- a/src/funcheck/CodeExtraction.scala +++ b/src/funcheck/CodeExtraction.scala @@ -17,6 +17,7 @@ trait CodeExtraction extends Extractors { import ExpressionExtractors._ private lazy val setTraitSym = definitions.getClass("scala.collection.immutable.Set") + private lazy val mapTraitSym = definitions.getClass("scala.collection.immutable.Map") private lazy val multisetTraitSym = try { definitions.getClass("scala.collection.immutable.Multiset") } catch { @@ -29,6 +30,10 @@ trait CodeExtraction extends Extractors { sym == setTraitSym || sym.tpe.toString.startsWith("scala.Predef.Set") } + def isMapTraitSym(sym : Symbol) : Boolean = { + sym == mapTraitSym || sym.tpe.toString.startsWith("scala.Predef.Map") + } + def isMultisetTraitSym(sym : Symbol) : Boolean = { sym == multisetTraitSym } @@ -421,6 +426,11 @@ trait CodeExtraction extends Extractors { val underlying = scalaType2PureScala(unit, silent)(tt.tpe) EmptyMultiset(underlying).setType(MultisetType(underlying)) } + case ExEmptyMap(ft, tt) => { + val fromUnderlying = scalaType2PureScala(unit, silent)(ft.tpe) + val toUnderlying = scalaType2PureScala(unit, silent)(tt.tpe) + EmptyMap(fromUnderlying, toUnderlying).setType(MapType(fromUnderlying, toUnderlying)) + } case ExSetMin(t) => { val set = rec(t) if(!set.getType.isInstanceOf[SetType]) { @@ -516,6 +526,38 @@ trait CodeExtraction extends Extractors { } } } + case ExMapUpdated(m,f,t) => { + val rm = rec(m) + val rf = rec(f) + val rt = rec(t) + val newSingleton = SingletonMap(rf, rt).setType(rm.getType) + println("singleton: " + newSingleton) + rm.getType match { + case MapType(ft, tt) => + println("extracted maptype: " + MapType(ft, tt)) + MapUnion(rm, newSingleton).setType(rm.getType) + case _ => { + if (!silent) unit.error(tree.pos, "updated can only be applied to maps.") + throw ImpureCodeEncounteredException(tree) + } + } + } + case ExMapApply(m,f) => { + val rm = rec(m) + val rf = rec(f) + MapGet(rm, rf).setType(rm.getType match { + case MapType(_,toType) => toType + case _ => { + if (!silent) unit.error(tree.pos, "apply on non-map expression") + throw ImpureCodeEncounteredException(tree) + } + }) + } + case ExMapIsDefinedAt(m,k) => { + val rm = rec(m) + val rk = rec(k) + MapIsDefinedAt(rm, rk) + } case ExPlusPlusPlus(t1,t2) => { val rl = rec(t1) @@ -590,6 +632,7 @@ trait CodeExtraction extends Extractors { case TypeRef(_, sym, btt :: Nil) if isSetTraitSym(sym) => SetType(rec(btt)) case TypeRef(_, sym, btt :: Nil) if isMultisetTraitSym(sym) => MultisetType(rec(btt)) case TypeRef(_, sym, btt :: Nil) if isOptionClassSym(sym) => OptionType(rec(btt)) + case TypeRef(_, sym, List(ftt,ttt)) if isMapTraitSym(sym) => MapType(rec(ftt),rec(ttt)) case TypeRef(_, sym, Nil) if classesToClasses.keySet.contains(sym) => classDefToClassType(classesToClasses(sym)) case _ => { diff --git a/src/funcheck/Extractors.scala b/src/funcheck/Extractors.scala index c1266088d59fb714812461049e21f0fab9feb0c2..48d6cc92a8498c5b34cc2b3d2e04a9eac74ebd19 100644 --- a/src/funcheck/Extractors.scala +++ b/src/funcheck/Extractors.scala @@ -384,6 +384,23 @@ trait Extractors { } } + object ExEmptyMap { + def unapply(tree: TypeApply): Option[(Tree,Tree)] = tree match { + case TypeApply( + Select( + Select( + Select( + Select(Ident(s), collectionName), + immutableName), + mapName), + emptyName), fromTypeTree :: toTypeTree :: Nil) if ( + collectionName.toString == "collection" && immutableName.toString == "immutable" && mapName.toString == "Map" && emptyName.toString == "empty" + ) => Some((fromTypeTree, toTypeTree)) + case TypeApply(Select(Select(Select(This(scalaName), predefName), mapName), emptyName), fromTypeTree :: toTypeTree :: Nil) if (scalaName.toString == "scala" && predefName.toString == "Predef" && emptyName.toString == "empty") => Some((fromTypeTree, toTypeTree)) + case _ => None + } + } + object ExFiniteSet { def unapply(tree: Apply): Option[(Tree,List[Tree])] = tree match { case Apply(TypeApply(Select(Select(Select(Select(Ident(s), collectionName), immutableName), setName), applyName), theTypeTree :: Nil), args) if (collectionName.toString == "collection" && immutableName.toString == "immutable" && setName.toString == "Set" && applyName.toString == "apply") => Some((theTypeTree, args)) @@ -409,6 +426,7 @@ trait Extractors { } } + object ExUnion { def unapply(tree: Apply): Option[(Tree,Tree)] = tree match { case Apply(Select(lhs, n), List(rhs)) if (n == nme.PLUSPLUS) => Some((lhs,rhs)) @@ -464,5 +482,27 @@ trait Extractors { case _ => None } } + + object ExMapUpdated { + def unapply(tree: Apply): Option[(Tree,Tree,Tree)] = tree match { + case Apply(TypeApply(Select(lhs, n), typeTreeList), List(from, to)) if (n.toString == "updated") => + Some((lhs, from, to)) + case _ => None + } + } + + object ExMapApply { + def unapply(tree: Apply): Option[(Tree,Tree)] = tree match { + case Apply(Select(lhs, n), List(rhs)) if (n.toString == "apply") => Some((lhs, rhs)) + case _ => None + } + } + + object ExMapIsDefinedAt { + def unapply(tree: Apply): Option[(Tree,Tree)] = tree match { + case Apply(Select(lhs, n), List(rhs)) if (n.toString == "isDefinedAt") => Some((lhs, rhs)) + case _ => None + } + } } } diff --git a/src/purescala/FairZ3Solver.scala b/src/purescala/FairZ3Solver.scala index 5e5c17f93c9e492ed57da45190b92de94f49f9b5..3cd19b8343b0ac1d254ba8c283a485a076e57cdf 100644 --- a/src/purescala/FairZ3Solver.scala +++ b/src/purescala/FairZ3Solver.scala @@ -38,6 +38,54 @@ class FairZ3Solver(val reporter: Reporter) extends Solver(reporter) with Abstrac private var toCheckAgainstModels: Expr = BooleanLiteral(true) private var varsInVC: Set[Identifier] = Set.empty + private val mapRangeSorts: MutableMap[TypeTree, Z3Sort] = MutableMap.empty + private val mapRangeSomeConstructors: MutableMap[TypeTree, Z3FuncDecl] = MutableMap.empty + private val mapRangeNoneConstructors: MutableMap[TypeTree, Z3FuncDecl] = MutableMap.empty + private val mapRangeSomeTesters: MutableMap[TypeTree, Z3FuncDecl] = MutableMap.empty + private val mapRangeNoneTesters: MutableMap[TypeTree, Z3FuncDecl] = MutableMap.empty + private val mapRangeValueSelectors: MutableMap[TypeTree, Z3FuncDecl] = MutableMap.empty + + private def mapRangeSort(toType : TypeTree) : Z3Sort = mapRangeSorts.get(toType) match { + case Some(z3sort) => z3sort + case None => { + import Z3Context.{ADTSortReference, RecursiveType, RegularSort} + intSort = z3.mkIntSort + boolSort = z3.mkBoolSort + + def typeToSortRef(tt: TypeTree): ADTSortReference = tt match { + case BooleanType => RegularSort(boolSort) + case Int32Type => RegularSort(intSort) + case AbstractClassType(d) => RegularSort(adtSorts(d)) + case CaseClassType(d) => RegularSort(adtSorts(d)) + case _ => throw UntranslatableTypeException("Can't handle type " + tt) + } + + val z3info = z3.mkADTSorts( + Seq( + ( + toType.toString + "Option", + Seq(toType.toString + "Some", toType.toString + "None"), + Seq( + Seq(("value", typeToSortRef(toType))), + Seq() + ) + ) + ) + ) + + z3info match { + case Seq((optionSort, Seq(someCons, noneCons), Seq(someTester, noneTester), Seq(Seq(valueSelector), Seq()))) => + mapRangeSorts += ((toType, optionSort)) + mapRangeSomeConstructors += ((toType, someCons)) + mapRangeNoneConstructors += ((toType, noneCons)) + mapRangeSomeTesters += ((toType, someTester)) + mapRangeNoneTesters += ((toType, noneTester)) + mapRangeValueSelectors += ((toType, valueSelector)) + optionSort + } + } + } + override def setProgram(prog: Program): Unit = { program = prog } @@ -53,6 +101,16 @@ class FairZ3Solver(val reporter: Reporter) extends Solver(reporter) with Abstrac exprToZ3Id = Map.empty z3IdToExpr = Map.empty + + mapSorts = Map.empty + + mapRangeSorts.clear + mapRangeSomeConstructors.clear + mapRangeNoneConstructors.clear + mapRangeSomeTesters.clear + mapRangeNoneTesters.clear + mapRangeValueSelectors.clear + counter = 0 prepareSorts prepareFunctions @@ -75,6 +133,7 @@ class FairZ3Solver(val reporter: Reporter) extends Solver(reporter) with Abstrac private var intSort: Z3Sort = null private var boolSort: Z3Sort = null private var setSorts: Map[TypeTree, Z3Sort] = Map.empty + private var mapSorts: Map[TypeTree, Z3Sort] = Map.empty private var intSetMinFun: Z3FuncDecl = null private var intSetMaxFun: Z3FuncDecl = null private var setCardFuns: Map[TypeTree, Z3FuncDecl] = Map.empty @@ -285,6 +344,16 @@ class FairZ3Solver(val reporter: Reporter) extends Solver(reporter) with Abstrac newSetSort } } + case mt @ MapType(fromType, toType) => mapSorts.get(mt) match { + case Some(s) => s + case None => { + val fromSort = typeToSort(fromType) + val toSort = mapRangeSort(toType) + val ms = z3.mkArraySort(fromSort, toSort) + mapSorts += ((mt, ms)) + ms + } + } case other => fallbackSorts.get(other) match { case Some(s) => s case None => { @@ -322,7 +391,7 @@ class FairZ3Solver(val reporter: Reporter) extends Solver(reporter) with Abstrac val validatingStopwatch = new Stopwatch("validating", false) val decideTopLevelSw = new Stopwatch("top-level", false).start - // println("Deciding : " + vc) + println("Deciding : " + vc) initializationStopwatch.start @@ -793,6 +862,42 @@ class FairZ3Solver(val reporter: Reporter) extends Solver(reporter) with Abstrac } case SetMin(s) => intSetMinFun(rec(s)) case SetMax(s) => intSetMaxFun(rec(s)) + case s @ SingletonMap(from,to) => s.getType match { + case MapType(fromType, toType) => + val fromSort = typeToSort(fromType) + val toSort = typeToSort(toType) + val constArray = z3.mkConstArray(toSort, mapRangeNoneConstructors(toType)()) + z3.mkStore(constArray, rec(from), mapRangeSomeConstructors(toType)(rec(to))) + case errorType => scala.Predef.error("Unexpected type for singleton map: " + errorType) + } + case e @ EmptyMap(fromType, toType) => { + val fromSort = typeToSort(fromType) + val toSort = typeToSort(toType) + z3.mkConstArray(toSort, mapRangeNoneConstructors(toType)()) + } + case f @ FiniteMap(elems) => f.getType match { + case MapType(fromType, toType) => + val fromSort = typeToSort(fromType) + val toSort = typeToSort(toType) + elems.foldLeft(z3.mkConstArray(toSort, mapRangeNoneConstructors(toType)())){ case (ast, SingletonMap(k,v)) => z3.mkStore(ast, rec(k), mapRangeSomeConstructors(toType)(rec(v))) } + case errorType => scala.Predef.error("Unexpected type for finite map: " + errorType) + } + case MapGet(m,k) => z3.mkSelect(rec(m), rec(k)) + case MapUnion(m1,m2) => m1.getType match { + case MapType(ft, tt) => m2 match { + case FiniteMap(ss) => + ss.foldLeft(rec(m1)){ + case (ast, SingletonMap(k, v)) => z3.mkStore(ast, rec(k), mapRangeSomeConstructors(tt)(rec(v))) + } + case SingletonMap(k, v) => z3.mkStore(rec(m1), rec(k), mapRangeSomeConstructors(tt)(rec(v))) + case _ => scala.Predef.error("map updates can only be applied with concrete map instances") + } + case errorType => scala.Predef.error("Unexpected type for map: " + errorType) + } + case MapIsDefinedAt(m,k) => m.getType match { + case MapType(ft, tt) => z3.mkDistinct(z3.mkSelect(rec(m), rec(k)), mapRangeNoneConstructors(tt)()) + case errorType => scala.Predef.error("Unexpected type for map: " + errorType) + } case Distinct(exs) => z3.mkDistinct(exs.map(rec(_)): _*) diff --git a/src/purescala/PrettyPrinter.scala b/src/purescala/PrettyPrinter.scala index 5f93ef629904907539a4fed770e7374f4b02c664..a654580a7c780af282e057b1b0626f6fbb55ca6e 100644 --- a/src/purescala/PrettyPrinter.scala +++ b/src/purescala/PrettyPrinter.scala @@ -131,6 +131,7 @@ object PrettyPrinter { case SetMax(s) => pp(s, sb, lvl).append(".max") case SetUnion(l,r) => ppBinary(sb, l, r, " \u222A ", lvl) // \cup case MultisetUnion(l,r) => ppBinary(sb, l, r, " \u222A ", lvl) // \cup + case MapUnion(l,r) => ppBinary(sb, l, r, " \u222A ", lvl) // \cup case SetDifference(l,r) => ppBinary(sb, l, r, " \\ ", lvl) case MultisetDifference(l,r) => ppBinary(sb, l, r, " \\ ", lvl) case SetIntersection(l,r) => ppBinary(sb, l, r, " \u2229 ", lvl) // \cap @@ -139,6 +140,22 @@ object PrettyPrinter { case MultisetCardinality(t) => ppUnary(sb, t, "|", "|", lvl) case MultisetPlus(l,r) => ppBinary(sb, l, r, " \u228E ", lvl) // U+ case MultisetToSet(e) => pp(e, sb, lvl).append(".toSet") + case EmptyMap(_,_) => sb.append("{}") + case SingletonMap(f,t) => ppBinary(sb, f, t, " -> ", lvl) + case FiniteMap(rs) => ppNary(sb, rs, "{", ", ", "}", lvl) + case MapGet(m,k) => { + var nsb = sb + pp(m, nsb, lvl) + nsb = ppNary(nsb, Seq(k), "(", ",", ")", lvl) + nsb + } + case MapIsDefinedAt(m,k) => { + var nsb = sb + pp(m, nsb, lvl) + nsb.append(".isDefinedAt") + nsb = ppNary(nsb, Seq(k), "(", ",", ")", lvl) + nsb + } case Distinct(exprs) => { var nsb = sb @@ -233,6 +250,7 @@ object PrettyPrinter { case Int32Type => sb.append("Int") case BooleanType => sb.append("Boolean") case SetType(bt) => pp(bt, sb.append("Set["), lvl).append("]") + case MapType(ft,tt) => pp(tt, pp(ft, sb.append("Map["), lvl).append(","), lvl).append("]") case MultisetType(bt) => pp(bt, sb.append("Multiset["), lvl).append("]") case OptionType(bt) => pp(bt, sb.append("Option["), lvl).append("]") case c: ClassType => sb.append(c.classDef.id) diff --git a/src/purescala/Trees.scala b/src/purescala/Trees.scala index 902c3d54b0016296b23d7f052fedcbbb1966c2b0..e5259abd796b466e584f2a6a561785e4839475bf 100644 --- a/src/purescala/Trees.scala +++ b/src/purescala/Trees.scala @@ -319,6 +319,9 @@ object Trees { @serializable case class MapGet(map: Expr, key: Expr) extends Expr @serializable case class MapUnion(map1: Expr, map2: Expr) extends Expr @serializable case class MapDifference(map: Expr, keys: Expr) extends Expr + @serializable case class MapIsDefinedAt(map: Expr, key: Expr) extends Expr with FixedType { + val fixedType = BooleanType + } /* List operations */ @serializable case class NilList(baseType: TypeTree) extends Expr with Terminal @@ -378,6 +381,7 @@ object Trees { case MapGet(t1,t2) => Some((t1,t2,MapGet)) case MapUnion(t1,t2) => Some((t1,t2,MapUnion)) case MapDifference(t1,t2) => Some((t1,t2,MapDifference)) + case MapIsDefinedAt(t1,t2) => Some((t1,t2, MapIsDefinedAt)) case Concat(t1,t2) => Some((t1,t2,Concat)) case ListAt(t1,t2) => Some((t1,t2,ListAt)) case _ => None