diff --git a/src/main/scala/leon/codegen/runtime/Monitor.scala b/src/main/scala/leon/codegen/runtime/Monitor.scala index a1323b4e1ddbc39d4103a094d9fe843a397280fa..efe230e8df471d05702d410bec41b9631f4dcddc 100644 --- a/src/main/scala/leon/codegen/runtime/Monitor.scala +++ b/src/main/scala/leon/codegen/runtime/Monitor.scala @@ -102,7 +102,7 @@ class StdMonitor(unit: CompilationUnit, invocationsMax: Int, bodies: ScalaMap[Id val tparams = params.toSeq.map(unit.runtimeIdToTypeMap(_).asInstanceOf[TypeParameter]) val static = tps.toSeq.map(unit.runtimeIdToTypeMap(_)) val newTypes = newTps.toSeq.map(unit.runtimeIdToTypeMap(_)) - val tpMap = (tparams.map(TypeParameterDef(_)) zip newTypes).toMap + val tpMap = (tparams zip newTypes).toMap static.map(tpe => unit.registerType(instantiateType(tpe, tpMap))).toArray } @@ -143,7 +143,7 @@ class StdMonitor(unit: CompilationUnit, invocationsMax: Int, bodies: ScalaMap[Id val solver = solverf.getNewSolver() val newTypes = tps.toSeq.map(unit.runtimeIdToTypeMap(_)) - val tpMap = (tparams.map(TypeParameterDef(_)) zip newTypes).toMap + val tpMap = (tparams zip newTypes).toMap val newXs = p.xs.map { id => val newTpe = instantiateType(id.getType, tpMap) @@ -229,7 +229,7 @@ class StdMonitor(unit: CompilationUnit, invocationsMax: Int, bodies: ScalaMap[Id val solver = solverf.getNewSolver() val newTypes = tps.toSeq.map(unit.runtimeIdToTypeMap(_)) - val tpMap = (tparams.map(TypeParameterDef(_)) zip newTypes).toMap + val tpMap = (tparams zip newTypes).toMap val vars = variablesOf(f).toSeq.sortBy(_.uniqueName) val newVars = vars.map(id => FreshIdentifier(id.name, instantiateType(id.getType, tpMap), true)) diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index e37ce4c9816598b07f181b341c426984e5d5e23d..be000b03b3d31d4f2511a5a4f82d70b36c0cf68f 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -21,9 +21,9 @@ import Types.{TypeTree => LeonType, _} import Common._ import Extractors._ import Constructors._ -import ExprOps._ -import TypeOps.{leastUpperBound, typesCompatible, typeParamsOf, canBeSubtypeOf} -import xlang.Expressions.{Block => LeonBlock, _} +import ExprOps.exists +import TypeOps.{exists => _, _} +import xlang.Expressions.{Block => _, _} import xlang.ExprOps._ import xlang.Constructors.{block => leonBlock} @@ -1086,10 +1086,11 @@ trait CodeExtraction extends ASTExtractors { extractType(up.tpe), tupleTypeWrap(args map { tr => extractType(tr.tpe)}) )) - val newTps = canBeSubtypeOf(realTypes, typeParamsOf(formalTypes).toSeq, formalTypes) match { + val newTps = subtypingInstantiation(realTypes, formalTypes) match { case Some(tmap) => fd.tparams map { tpd => tmap.getOrElse(tpd.tp, tpd.tp) } case None => + //println(realTypes, formalTypes) reporter.fatalError("Could not instantiate type of unapply method") } diff --git a/src/main/scala/leon/grammars/FunctionCalls.scala b/src/main/scala/leon/grammars/FunctionCalls.scala index 00b0ef2be3f0612119e3dda4ae965444be5bbfaa..f3b143d03d92aade5b746d892cfc0391b0a6cd4e 100644 --- a/src/main/scala/leon/grammars/FunctionCalls.scala +++ b/src/main/scala/leon/grammars/FunctionCalls.scala @@ -28,10 +28,9 @@ case class FunctionCalls(prog: Program, currentFunction: FunDef, types: Seq[Type val isDet = fd.body.exists(isDeterministic) if (!isRecursiveCall && isDet) { - val free = fd.tparams.map(_.tp) - - canBeSubtypeOf(fd.returnType, free, t, rhsFixed = true) match { + subtypingInstantiation(t, fd.returnType) match { case Some(tpsMap) => + val free = fd.tparams.map(_.tp) val tfd = fd.typed(free.map(tp => tpsMap.getOrElse(tp, tp))) if (tpsMap.size < free.size) { @@ -48,7 +47,7 @@ case class FunctionCalls(prog: Program, currentFunction: FunDef, types: Seq[Type var finalMap = tpsMap for (ptpe <- tfd.params.map(_.getType).distinct) { - canBeSubtypeOf(atpe, finalFree.toSeq, ptpe) match { + unify(atpe, ptpe, finalFree.toSeq) match { // FIXME!!!! this may allow weird things if lub!=ptpe case Some(ntpsMap) => finalFree --= ntpsMap.keySet finalMap ++= ntpsMap diff --git a/src/main/scala/leon/invariant/engine/RefinementEngine.scala b/src/main/scala/leon/invariant/engine/RefinementEngine.scala index 0c5189072dba45957c9c8a3d8a7ea7c9e0cdbe61..1be74fb8ffc882a6b96fd68886dbedd7364cb50c 100644 --- a/src/main/scala/leon/invariant/engine/RefinementEngine.scala +++ b/src/main/scala/leon/invariant/engine/RefinementEngine.scala @@ -151,7 +151,7 @@ class RefinementEngine(ctx: InferenceContext, prog: Program, ctrTracker: Constra if (shouldCreateVC(recFun, calldata.inSpec)) { reporter.info("Creating VC for " + recFun.id) // instantiate the body with new types - val tparamMap = (recFun.tparams zip recFunTyped.tps).toMap + val tparamMap = (recFun.typeArgs zip recFunTyped.tps).toMap val paramMap = recFun.params.map { pdef => pdef.id -> FreshIdentifier(pdef.id.name, instantiateType(pdef.id.getType, tparamMap)) }.toMap @@ -185,7 +185,7 @@ class RefinementEngine(ctx: InferenceContext, prog: Program, ctrTracker: Constra if (callee.isBodyVisible) { //here inline the body and conjoin it with the guard //Important: make sure we use a fresh body expression here, and freshenlocals - val tparamMap = (callee.tparams zip tfd.tps).toMap + val tparamMap = (callee.typeArgs zip tfd.tps).toMap val freshBody = instantiateType(replace(formalToActual(call), Equals(getFunctionReturnVariable(callee), freshenLocals(callee.body.get))), tparamMap, Map()) diff --git a/src/main/scala/leon/invariant/engine/SpecInstantiator.scala b/src/main/scala/leon/invariant/engine/SpecInstantiator.scala index 4e6de1675701345d31b0b9d85829f50250678641..84dbcb2bac1b19311ad40ba227eeb951b4393d47 100644 --- a/src/main/scala/leon/invariant/engine/SpecInstantiator.scala +++ b/src/main/scala/leon/invariant/engine/SpecInstantiator.scala @@ -128,7 +128,7 @@ class SpecInstantiator(ctx: InferenceContext, program: Program, ctrTracker: Cons rawpost } // instantiate the post - val tparamMap = (callee.tparams zip tfd.tps).toMap + val tparamMap = (callee.typeArgs zip tfd.tps).toMap val instSpec = instantiateType(replace(formalToActual(call), rawspec), tparamMap, Map()) val inlinedSpec = ExpressionTransformer.normalizeExpr(instSpec, ctx.multOp) Some(inlinedSpec) @@ -142,7 +142,7 @@ class SpecInstantiator(ctx: InferenceContext, program: Program, ctrTracker: Cons val callee = tfd.fd if (callee.hasTemplate) { val argmap = formalToActual(call) - val tparamMap = (callee.tparams zip tfd.tps).toMap + val tparamMap = (callee.typeArgs zip tfd.tps).toMap val tempExpr = instantiateType(replace(argmap, freshenLocals(callee.getTemplate)), tparamMap, Map()) val template = if (callee.hasPrecondition) { val pre = instantiateType(replace(argmap, freshenLocals(callee.precondition.get)), tparamMap, Map()) diff --git a/src/main/scala/leon/purescala/Constructors.scala b/src/main/scala/leon/purescala/Constructors.scala index 16b7932d6c5b0958fae5c2a68f2821398a1af532..e54cb95c83b485073d92c50d4ef322822ef0ae3f 100644 --- a/src/main/scala/leon/purescala/Constructors.scala +++ b/src/main/scala/leon/purescala/Constructors.scala @@ -122,7 +122,7 @@ object Constructors { val formalType = tupleTypeWrap(fd.params map { _.getType }) val actualType = tupleTypeWrap(args map { _.getType }) - canBeSubtypeOf(actualType, typeParamsOf(formalType).toSeq, formalType) match { + subtypingInstantiation(actualType, formalType) match { case Some(tmap) => FunctionInvocation(fd.typed(fd.tparams map { tpd => tmap.getOrElse(tpd.tp, tpd.tp) }), args) case None => throw LeonFatalError(s"$args:$actualType cannot be a subtype of $formalType!") diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala index 718a50e1561ca5aa85a475abc8240b25d4685016..a86700e5419d736b96f834606ad40aa3355431fa 100644 --- a/src/main/scala/leon/purescala/Definitions.scala +++ b/src/main/scala/leon/purescala/Definitions.scala @@ -332,6 +332,8 @@ object Definitions { lazy val definedFunctions : Seq[FunDef] = methods lazy val definedClasses = Seq(this) + def typeArgs = tparams map (_.tp) + def typed(tps: Seq[TypeTree]): ClassType def typed: ClassType } @@ -536,6 +538,8 @@ object Definitions { def paramIds = params map { _.id } + def typeArgs = tparams map (_.tp) + def applied(args: Seq[Expr]): FunctionInvocation = Constructors.functionInvocation(this, args) def applied = FunctionInvocation(this.typed, this.paramIds map Variable) } @@ -553,8 +557,8 @@ object Definitions { } } - private lazy val typesMap: Map[TypeParameterDef, TypeTree] = { - (fd.tparams zip tps).toMap.filter(tt => tt._1.tp != tt._2) + private lazy val typesMap: Map[TypeParameter, TypeTree] = { + (fd.typeArgs zip tps).toMap.filter(tt => tt._1 != tt._2) } def translated(t: TypeTree): TypeTree = instantiateType(t, typesMap) diff --git a/src/main/scala/leon/purescala/ExprOps.scala b/src/main/scala/leon/purescala/ExprOps.scala index 8e778b4d9fd91c4fdaa9a9302ee3179f76babc04..96a5605d2f27646e21bf753408b779478a8ebc95 100644 --- a/src/main/scala/leon/purescala/ExprOps.scala +++ b/src/main/scala/leon/purescala/ExprOps.scala @@ -1310,15 +1310,6 @@ object ExprOps extends GenTreeOps[Expr] { b: Apriori => Option[Apriori]): Option[Apriori] = a.flatMap(b) - object Same { - def unapply(tt: (Expr, Expr)): Option[(Expr, Expr)] = { - if (tt._1.getClass == tt._2.getClass) { - Some(tt) - } else { - None - } - } - } implicit class AugmentedContext(c: Option[Apriori]) { def &&(other: Apriori => Option[Apriori]): Option[Apriori] = mergeContexts(c, other) def --(other: Seq[Identifier]) = diff --git a/src/main/scala/leon/purescala/Expressions.scala b/src/main/scala/leon/purescala/Expressions.scala index 3b9e396b01e3fa5c0353e591aec14ec0d232eb3a..8b8f0e3002ceedbc12a92dcd533c106d80d6cdaa 100644 --- a/src/main/scala/leon/purescala/Expressions.scala +++ b/src/main/scala/leon/purescala/Expressions.scala @@ -183,9 +183,9 @@ object Expressions { val getType = { // We need ot instantiate the type based on the type of the function as well as receiver val fdret = tfd.returnType - val extraMap: Map[TypeParameterDef, TypeTree] = rec.getType match { + val extraMap: Map[TypeParameter, TypeTree] = rec.getType match { case ct: ClassType => - (cd.tparams zip ct.tps).toMap + (cd.typeArgs zip ct.tps).toMap case _ => Map() } diff --git a/src/main/scala/leon/purescala/FunctionClosure.scala b/src/main/scala/leon/purescala/FunctionClosure.scala index 1375cdf26fad10aa5c8bcf568086e1fdd69ba6c0..a98dbf64b8a52c9d049ca6b00260379940683945 100644 --- a/src/main/scala/leon/purescala/FunctionClosure.scala +++ b/src/main/scala/leon/purescala/FunctionClosure.scala @@ -6,10 +6,9 @@ package purescala import Definitions._ import Expressions._ import ExprOps._ -import Constructors._ import TypeOps.instantiateType import Common.Identifier -import Types.TypeParameter +import leon.purescala.Types.TypeParameter import utils.GraphOps._ object FunctionClosure extends TransformationPhase { @@ -152,7 +151,7 @@ object FunctionClosure extends TransformationPhase { val reqPC = pc.filterByIds(free.toSet) val tpFresh = outer.tparams map { _.freshen } - val tparamsMap = outer.tparams.zip(tpFresh map {_.tp}).toMap + val tparamsMap = outer.typeArgs.zip(tpFresh map {_.tp}).toMap val freshVals = (inner.paramIds ++ free).map{_.freshen}.map(instantiateType(_, tparamsMap)) val freeMap = (inner.paramIds ++ free).zip(freshVals).toMap @@ -185,7 +184,7 @@ object FunctionClosure extends TransformationPhase { newFd.fullBody = replaceFromIDs(freeMap.map(p => (p._1, p._2.toVariable)), newFd.fullBody) - FunSubst(newFd, freeMap, tparamsMap.map{ case (from, to) => from.tp -> to}) + FunSubst(newFd, freeMap, tparamsMap) } override def apply(ctx: LeonContext, program: Program): Program = { diff --git a/src/main/scala/leon/purescala/GenTreeOps.scala b/src/main/scala/leon/purescala/GenTreeOps.scala index 8a8645da8045d4f96b43f0d42144b654f3c4c4a0..a8ced270426bde209a0448e848e9cbe9f7a4280f 100644 --- a/src/main/scala/leon/purescala/GenTreeOps.scala +++ b/src/main/scala/leon/purescala/GenTreeOps.scala @@ -415,5 +415,14 @@ trait GenTreeOps[SubTree <: Tree] { fold[Int]{ (_, sub) => 1 + (0 +: sub).max }(e) } + object Same { + def unapply(tt: (SubTree, SubTree)): Option[(SubTree, SubTree)] = { + if (tt._1.getClass == tt._2.getClass) { + Some(tt) + } else { + None + } + } + } } diff --git a/src/main/scala/leon/purescala/MethodLifting.scala b/src/main/scala/leon/purescala/MethodLifting.scala index 8aaf2eafe8596626eac04bdbe15c6a7a5d56279f..64f4cb5827f082e986b67a225e20507f4a70f1cf 100644 --- a/src/main/scala/leon/purescala/MethodLifting.scala +++ b/src/main/scala/leon/purescala/MethodLifting.scala @@ -161,7 +161,7 @@ object MethodLifting extends TransformationPhase { if c.ancestors.forall(!_.methods.map{_.id}.contains(fd.id)) } { val root = c.ancestors.last - val tMap = c.tparams.zip(root.tparams.map{_.tp}).toMap + val tMap = c.typeArgs.zip(root.typeArgs).toMap val tSubst: TypeTree => TypeTree = instantiateType(_, tMap) val fdParams = fd.params map { vd => @@ -187,7 +187,7 @@ object MethodLifting extends TransformationPhase { for { cd <- u.classHierarchyRoots; fd <- cd.methods } { // We import class type params and freshen them val ctParams = cd.tparams map { _.freshen } - val tparamsMap = cd.tparams.zip(ctParams map { _.tp }).toMap + val tparamsMap = cd.typeArgs.zip(ctParams map { _.tp }).toMap val id = fd.id.freshen val recType = cd.typed(ctParams.map(_.tp)) @@ -239,13 +239,13 @@ object MethodLifting extends TransformationPhase { val classParamsMap = (for { c <- cd.knownDescendants :+ cd (from, to) <- c.tparams zip ctParams - } yield (from, to.tp)).toMap + } yield (from.tp, to.tp)).toMap val methodParamsMap = (for { c <- cd.knownDescendants :+ cd m <- c.methods if m.id == fd.id (from,to) <- m.tparams zip fd.tparams - } yield (from, to.tp)).toMap + } yield (from.tp, to.tp)).toMap def inst(cs: Seq[MatchCase]) = instantiateType( matchExpr(Variable(receiver), cs).setPos(fd), diff --git a/src/main/scala/leon/purescala/SelfPrettyPrinter.scala b/src/main/scala/leon/purescala/SelfPrettyPrinter.scala index 03aea55aae10911a9c5e6d9fa87cecc3d06b2c5d..f0817f3bc732995036113f3d2d81fedec81a5c80 100644 --- a/src/main/scala/leon/purescala/SelfPrettyPrinter.scala +++ b/src/main/scala/leon/purescala/SelfPrettyPrinter.scala @@ -64,10 +64,9 @@ trait PrettyPrinterFinder[T, U >: T] { def buildLambda(inputType: TypeTree, fd: FunDef, slu: Stream[List[U]]): Stream[T] def prettyPrinterFromCandidate(fd: FunDef, inputType: TypeTree)(implicit ctx: LeonContext, program: Program): Stream[T] = { - TypeOps.canBeSubtypeOf(inputType, fd.tparams.map(_.tp), fd.params.head.getType) match { + TypeOps.subtypingInstantiation(inputType, fd.params.head.getType) match { case Some(genericTypeMap) => //println("Found a mapping") - val defGenericTypeMap = genericTypeMap.map{ case (k, v) => (Definitions.TypeParameterDef(k), v) } def gatherPrettyPrinters(funIds: List[Identifier], acc: ListBuffer[Stream[U]] = ListBuffer[Stream[U]]()): Option[Stream[List[U]]] = funIds match { case Nil => Some(StreamUtils.cartesianProduct(acc.toList)) case funId::tail => // For each function, find an expression which could be provided if it exists. @@ -77,7 +76,7 @@ trait PrettyPrinterFinder[T, U >: T] { None } } - val funIds = fd.params.tail.map(x => TypeOps.instantiateType(x.id, defGenericTypeMap)).toList + val funIds = fd.params.tail.map(x => TypeOps.instantiateType(x.id, genericTypeMap)).toList gatherPrettyPrinters(funIds) match { case Some(l) => buildLambda(inputType, fd, l) case None => Stream.empty diff --git a/src/main/scala/leon/purescala/TypeOps.scala b/src/main/scala/leon/purescala/TypeOps.scala index ac28c93cb63879879ef395387277021c550b2eec..600127cc495fb45450d6e1c6f97d279e47e27c07 100644 --- a/src/main/scala/leon/purescala/TypeOps.scala +++ b/src/main/scala/leon/purescala/TypeOps.scala @@ -22,155 +22,126 @@ object TypeOps extends GenTreeOps[TypeTree] { subs.flatMap(typeParamsOf).toSet } - def canBeSubtypeOf( - tpe: TypeTree, - freeParams: Seq[TypeParameter], - stpe: TypeTree, - lhsFixed: Boolean = false, - rhsFixed: Boolean = false - ): Option[Map[TypeParameter, TypeTree]] = { - - def unify(res: Seq[Option[Map[TypeParameter, TypeTree]]]): Option[Map[TypeParameter, TypeTree]] = { - if (res.forall(_.isDefined)) { - var result = Map[TypeParameter, TypeTree]() - - for (Some(m) <- res; (f, t) <- m) { - result.get(f) match { - case Some(t2) if t != t2 => return None - case _ => result += (f -> t) - } - } - Some(result) - } else { - None - } + /** Generic type bounds between two types. Serves as a base for a set of subtyping/unification functions. + * It will allow subtyping between classes (but type parameters are invariant). + * It will also allow a set of free parameters to be unified if needed. + * + * @param allowSub Allow subtyping for class types + * @param freeParams The unifiable type parameters + * @param isLub Whether the bound calculated is the LUB + * @return An optional pair of (type, map) where type the resulting type bound + * (with type parameters instantiated as needed) + * and map the assignment of type variables. + * Result is empty if types are incompatible. + * @see [[leastUpperBound]], [[greatestLowerBound]], [[isSubtypeOf]], [[typesCompatible]], [[unify]] + */ + def typeBound(t1: TypeTree, t2: TypeTree, isLub: Boolean, allowSub: Boolean) + (implicit freeParams: Seq[TypeParameter]): Option[(TypeTree, Map[TypeParameter, TypeTree])] = { + + def flatten(res: Seq[Option[(TypeTree, Map[TypeParameter, TypeTree])]]): Option[(Seq[TypeTree], Map[TypeParameter, TypeTree])] = { + val (tps, subst) = res.map(_.getOrElse(return None)).unzip + val flat = subst.flatMap(_.toSeq).groupBy(_._1) + Some((tps, flat.mapValues { vs => + vs.map(_._2).distinct match { + case Seq(unique) => unique + case _ => return None + } + })) } - if (freeParams.isEmpty) { - if (isSubtypeOf(tpe, stpe)) { - Some(Map()) - } else { - None - } - } else { - (tpe, stpe) match { - case (t1, t2) if t1 == t2 => - Some(Map()) + (t1, t2) match { + case (_: TypeParameter, _: TypeParameter) if t1 == t2 => + Some((t1, Map())) - case (t, tp1: TypeParameter) if (freeParams contains tp1) && (!rhsFixed) && !(typeParamsOf(t) contains tp1) => - Some(Map(tp1 -> t)) + case (t, tp1: TypeParameter) if (freeParams contains tp1) && !(typeParamsOf(t) contains tp1) => + Some((t, Map(tp1 -> t))) - case (tp1: TypeParameter, t) if (freeParams contains tp1) && (!lhsFixed) && !(typeParamsOf(t) contains tp1) => - Some(Map(tp1 -> t)) + case (tp1: TypeParameter, t) if (freeParams contains tp1) && !(typeParamsOf(t) contains tp1) => + Some((t, Map(tp1 -> t))) - case (ct1: ClassType, ct2: ClassType) => - val rt1 = ct1.root - val rt2 = ct2.root + case (_: TypeParameter, _) => + None + case (_, _: TypeParameter) => + None - if (rt1.classDef == rt2.classDef) { - unify((rt1.tps zip rt2.tps).map { case (tp1, tp2) => - canBeSubtypeOf(tp1, freeParams, tp2, lhsFixed, rhsFixed) - }) + case (ct1: ClassType, ct2: ClassType) => + val cd1 = ct1.classDef + val cd2 = ct2.classDef + val bound: Option[ClassDef] = if (allowSub) { + val an1 = cd1 +: cd1.ancestors + val an2 = cd2 +: cd2.ancestors + if (isLub) { + (an1.reverse zip an2.reverse) + .takeWhile(((_: ClassDef) == (_: ClassDef)).tupled) + .lastOption.map(_._1) } else { - None + // Lower bound + if(an1.contains(cd2)) Some(cd1) + else if (an2.contains(cd1)) Some(cd2) + else None } + } else { + (cd1 == cd2).option(cd1) + } + for { + cd <- bound + (subs, map) <- flatten((ct1.tps zip ct2.tps).map { case (tp1, tp2) => + // Class types are invariant! + typeBound(tp1, tp2, isLub, allowSub = false) + }) + } yield (cd.typed(subs), map) - case (_: TupleType, _: TupleType) | - (_: SetType, _: SetType) | - (_: MapType, _: MapType) | - (_: BagType, _: BagType) | - (_: FunctionType, _: FunctionType) => - - val NAryType(ts1, _) = tpe - val NAryType(ts2, _) = stpe - - if (ts1.size == ts2.size) { - unify((ts1 zip ts2).map { case (tp1, tp2) => - canBeSubtypeOf(tp1, freeParams, tp2, lhsFixed, rhsFixed) - }) - } else { - None + case (FunctionType(from1, to1), FunctionType(from2, to2)) => + if (from1.size != from2.size) None + else { + val in = (from1 zip from2).map { case (tp1, tp2) => + typeBound(tp1, tp2, !isLub, allowSub) // Contravariant args } + val out = typeBound(to1, to2, isLub, allowSub) // Covariant result + flatten(out +: in) map { + case (Seq(newTo, newFrom@_*), map) => + (FunctionType(newFrom, newTo), map) + } + } - case (t1, t2) => - None - } - } - } - - def bestRealType(t: TypeTree) : TypeTree = t match { - case (c: ClassType) => c.root - case NAryType(tps, builder) => builder(tps.map(bestRealType)) - } - - def leastUpperBound(t1: TypeTree, t2: TypeTree): Option[TypeTree] = (t1,t2) match { - case (c1: ClassType, c2: ClassType) => - - def computeChain(ct: ClassType): List[ClassType] = ct.parent match { - case Some(pct) => - computeChain(pct) ::: List(ct) - case None => - List(ct) - } - - val chain1 = computeChain(c1) - val chain2 = computeChain(c2) - - val prefix = (chain1 zip chain2).takeWhile { case (ct1, ct2) => ct1 == ct2 }.map(_._1) - - prefix.lastOption - - case (TupleType(args1), TupleType(args2)) if args1.size == args2.size => - val args = (args1 zip args2).map(p => leastUpperBound(p._1, p._2)) - if (args.forall(_.isDefined)) Some(TupleType(args.map(_.get))) else None - - case (FunctionType(from1, to1), FunctionType(from2, to2)) => - val args = (from1 zip from2).map(p => greatestLowerBound(p._1, p._2)) - if (args.forall(_.isDefined)) { - leastUpperBound(to1, to2) map { FunctionType(args.map(_.get), _) } - } else { + case Same(t1, t2) => + // Only tuples are covariant + def allowVariance = t1 match { + case _ : TupleType => true + case _ => false + } + val NAryType(ts1, recon) = t1 + val NAryType(ts2, _) = t2 + if (ts1.size == ts2.size) { + flatten((ts1 zip ts2).map { case (tp1, tp2) => + typeBound(tp1, tp2, isLub, allowSub = allowVariance) + }).map { case (subs, map) => (recon(subs), map) } + } else None + + case _ => None - } - - case (o1, o2) if o1 == o2 => Some(o1) - case _ => None + } } - def greatestLowerBound(t1: TypeTree, t2: TypeTree): Option[TypeTree] = (t1,t2) match { - case (c1: ClassType, c2: ClassType) => + def unify(tp1: TypeTree, tp2: TypeTree, freeParams: Seq[TypeParameter]) = + typeBound(tp1, tp2, isLub = true, allowSub = false)(freeParams).map(_._2) - def computeChains(ct: ClassType): Set[ClassType] = ct.parent match { - case Some(pct) => - computeChains(pct) + ct - case None => - Set(ct) - } - - if (computeChains(c1)(c2)) { - Some(c2) - } else if (computeChains(c2)(c1)) { - Some(c1) - } else { - None - } - - case (TupleType(args1), TupleType(args2)) => - val args = (args1 zip args2).map(p => greatestLowerBound(p._1, p._2)) - if (args.forall(_.isDefined)) Some(TupleType(args.map(_.get))) else None + /** Will try to instantiate superT so that subT <: superT + * + * @return Mapping of instantiations + */ + def subtypingInstantiation(subT: TypeTree, superT: TypeTree) = + typeBound(subT, superT, isLub = true, allowSub = true)(typeParamsOf(superT).toSeq) collect { + case (tp, map) if instantiateType(superT, map) == tp => map + } - case (FunctionType(from1, to1), FunctionType(from2, to2)) => - val args = (from1 zip from2).map(p => leastUpperBound(p._1, p._2)) - if (args.forall(_.isDefined)) { - greatestLowerBound(to1, to2).map { FunctionType(args.map(_.get), _) } - } else { - None - } + def leastUpperBound(tp1: TypeTree, tp2: TypeTree): Option[TypeTree] = + typeBound(tp1, tp2, isLub = true, allowSub = true)(Seq()).map(_._1) - case (o1, o2) if o1 == o2 => Some(o1) - case _ => None - } + def greatestLowerBound(tp1: TypeTree, tp2: TypeTree): Option[TypeTree] = + typeBound(tp1, tp2, isLub = false, allowSub = true)(Seq()).map(_._1) def leastUpperBound(ts: Seq[TypeTree]): Option[TypeTree] = { def olub(ot1: Option[TypeTree], t2: Option[TypeTree]): Option[TypeTree] = ot1 match { @@ -201,6 +172,11 @@ object TypeOps extends GenTreeOps[TypeTree] { } } + def bestRealType(t: TypeTree) : TypeTree = t match { + case (c: ClassType) => c.root + case NAryType(tps, builder) => builder(tps.map(bestRealType)) + } + def isParametricType(tpe: TypeTree): Boolean = tpe match { case (tp: TypeParameter) => true case NAryType(tps, builder) => tps.exists(isParametricType) @@ -220,26 +196,26 @@ object TypeOps extends GenTreeOps[TypeTree] { } } - def instantiateType(id: Identifier, tps: Map[TypeParameterDef, TypeTree]): Identifier = { - freshId(id, typeParamSubst(tps map { case (tpd, tp) => tpd.tp -> tp })(id.getType)) + def instantiateType(id: Identifier, tps: Map[TypeParameter, TypeTree]): Identifier = { + freshId(id, typeParamSubst(tps)(id.getType)) } - def instantiateType(tpe: TypeTree, tps: Map[TypeParameterDef, TypeTree]): TypeTree = { + def instantiateType(tpe: TypeTree, tps: Map[TypeParameter, TypeTree]): TypeTree = { if (tps.isEmpty) { tpe } else { - typeParamSubst(tps.map { case (tpd, tp) => tpd.tp -> tp })(tpe) + typeParamSubst(tps)(tpe) } } - def instantiateType(e: Expr, tps: Map[TypeParameterDef, TypeTree], ids: Map[Identifier, Identifier]): Expr = { + def instantiateType(e: Expr, tps: Map[TypeParameter, TypeTree], ids: Map[Identifier, Identifier]): Expr = { if (tps.isEmpty && ids.isEmpty) { e } else { val tpeSub = if (tps.isEmpty) { { (tpe: TypeTree) => tpe } } else { - typeParamSubst(tps.map { case (tpd, tp) => tpd.tp -> tp }) _ + typeParamSubst(tps) _ } val transformer = new TreeTransformer { @@ -255,6 +231,8 @@ object TypeOps extends GenTreeOps[TypeTree] { case Untyped => Some(0) case BooleanType => Some(2) case UnitType => Some(1) + case TupleType(tps) => + Some(tps.map(typeCardinality).map(_.getOrElse(return None)).product) case SetType(base) => typeCardinality(base).map(b => Math.pow(2, b).toInt) case FunctionType(from, to) => diff --git a/src/main/scala/leon/purescala/Types.scala b/src/main/scala/leon/purescala/Types.scala index f5976191343c7fe08f3c4b9164c5f2815727c9c3..2ae80beacdfd6d4a9fb88a4b25b9543f461f0e47 100644 --- a/src/main/scala/leon/purescala/Types.scala +++ b/src/main/scala/leon/purescala/Types.scala @@ -98,7 +98,7 @@ object Types { assert(classDef.tparams.size == tps.size) lazy val fields = { - val tmap = (classDef.tparams zip tps).toMap + val tmap = (classDef.typeArgs zip tps).toMap if (tmap.isEmpty) { classDef.fields } else { @@ -124,11 +124,12 @@ object Types { lazy val root: ClassType = parent.map{ _.root }.getOrElse(this) lazy val parent = classDef.parent.map { pct => - instantiateType(pct, (classDef.tparams zip tps).toMap) match { + instantiateType(pct, (classDef.typeArgs zip tps).toMap) match { case act: AbstractClassType => act case t => throw LeonFatalError("Unexpected translated parent type: "+t) } } + } case class AbstractClassType(classDef: AbstractClassDef, tps: Seq[TypeTree]) extends ClassType diff --git a/src/main/scala/leon/synthesis/rules/Abduction.scala b/src/main/scala/leon/synthesis/rules/Abduction.scala index da32100f30ad7a14e71815c278d8c883fa6edc67..f8991f2d7953af4d34ca8956fff86f4365cff966 100644 --- a/src/main/scala/leon/synthesis/rules/Abduction.scala +++ b/src/main/scala/leon/synthesis/rules/Abduction.scala @@ -7,7 +7,7 @@ package rules import purescala.Common._ import purescala.DefOps._ import purescala.Expressions._ -import purescala.TypeOps.canBeSubtypeOf +import purescala.TypeOps.unify import purescala.Constructors._ import purescala.ExprOps._ import purescala.Definitions._ @@ -17,88 +17,7 @@ import leon.utils.Simplifiers object Abduction extends Rule("Abduction") { override def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - - // Let ⟦ as ⟨ ws && pc | phi ⟩ xs ⟧ be the original problem - def processFd(tfd: TypedFunDef): Option[RuleInstantiation] = { - // We assign xs = tfd(newXs), newXs fresh - val newXs = tfd.paramIds map (_.freshen) - val args = newXs map Variable - val call = FunctionInvocation(tfd, args) - // prec. of tfd(newXs) (newXs have to satisfy it) - val pre = replaceFromIDs(tfd.paramIds.zip(args).toMap, tfd.precOrTrue) - - // Fail if pre is not SAT - val solver = SimpleSolverAPI(SolverFactory.getFromSettings(hctx, hctx.program)) - if (!solver.solveSAT(p.pc and pre)._1.contains(true)) return None - - // postc. of tfd(newXs) will hold for xs - val post = application( - replaceFromIDs(tfd.paramIds.zip(args).toMap, tfd.postOrTrue), - Seq(tupleWrap(p.xs map Variable)) - ) - - // Conceptually, the new problem is - // ⟦ as ⟨ ws && pc && xs <- tfd(newXs) && post | pre && phi ⟩ newXs ⟧ - // But we will only accept this problem if xs is simplifiable in phi under the new assumptions - - def containsXs(e: Expr) = (variablesOf(e) & p.xs.toSet).nonEmpty - - val newPhi = { - - val newPhi0 = { - // Try to eliminate xs in trivial cases - val TopLevelAnds(newPhis) = and(pre, post) - val equalities = newPhis.collect { - case Equals(e1, e2) if containsXs(e1) && !containsXs(e2) => e1 -> e2 - case Equals(e1, e2) if containsXs(e2) && !containsXs(e1) => e2 -> e1 - }.toMap - - replace(equalities, p.phi) - } - - val bigX = FreshIdentifier("x", p.outType, alwaysShowUniqueID = true) - - val projections = unwrapTuple(call, p.xs.size).zipWithIndex.map{ - case (e, ind) => tupleSelect(e, ind + 1, p.xs.size) - } - - val pc = p.pc - .withCond(pre) - .withBinding(bigX, call) - .withBindings(newXs zip projections) - .withCond(post) - - Simplifiers.bestEffort(hctx, hctx.program)(newPhi0, pc) - } - - if (containsXs(newPhi)) { - None - } else { - // We do not need xs anymore, so we accept the decomposition. - // We can remove xs-related cluases from the PC. - // Notice however we need to synthesize newXs such that pre is satisfied as well. - // Final problem is: - // ⟦ as ⟨ ws && pc | pre && phi ⟩ newXs ⟧ - - val newP = p.copy(phi = newPhi, xs = newXs.toList) - - val onSuccess = forwardMap(letTuple(p.xs, call, _)) - - Some(decomp(List(newP), onSuccess, "Blah")) - } - } - - val filter = (fd: FunDef) => fd.isSynthetic || fd.isInner - - val funcs = visibleFunDefsFromMain(hctx.program).toSeq.sortBy(_.id).filterNot(filter) - - // For now, let's handle all outputs at once only - for { - fd <- funcs - inst <- canBeSubtypeOf(p.outType, fd.tparams.map(_.tp), fd.returnType) - decomp <- processFd(fd.typed(fd.tparams map (tpar => inst(tpar.tp)))) - } yield decomp - + Nil } } diff --git a/src/main/scala/leon/synthesis/rules/StringRender.scala b/src/main/scala/leon/synthesis/rules/StringRender.scala index b0bb9d50586264ce71665cf1e56da1fee0756100..dfb6da0752c5e788ad89377fc3c6e087f185c798 100644 --- a/src/main/scala/leon/synthesis/rules/StringRender.scala +++ b/src/main/scala/leon/synthesis/rules/StringRender.scala @@ -361,7 +361,7 @@ case object StringRender extends Rule("StringRender") { def extractCaseVariants(cct: CaseClassType, ctx: StringSynthesisContext) : (Stream[WithIds[MatchCase]], StringSynthesisResult) = cct match { case CaseClassType(ccd: CaseClassDef, tparams2) => - val typeMap = ccd.tparams.zip(tparams2).toMap + val typeMap = ccd.typeArgs.zip(tparams2).toMap val fields = ccd.fields.map(vd => TypeOps.instantiateType(vd.id, typeMap) ) val pattern = CaseClassPattern(None, ccd.typed(tparams2), fields.map(k => WildcardPattern(Some(k)))) val (rhs, result) = createFunDefsTemplates(ctx.copy(currentCaseClassParent=Some(cct)), fields.map(Variable)) // Invoke functions for each of the fields. @@ -387,7 +387,7 @@ case object StringRender extends Rule("StringRender") { def constantPatternMatching(fd: FunDef, act: AbstractClassType): WithIds[MatchExpr] = { val cases = (ListBuffer[WithIds[MatchCase]]() /: act.knownCCDescendants) { case (acc, cct @ CaseClassType(ccd, tparams2)) => - val typeMap = ccd.tparams.zip(tparams2).toMap + val typeMap = ccd.typeArgs.zip(tparams2).toMap val fields = ccd.fields.map(vd => TypeOps.instantiateType(vd.id, typeMap) ) val pattern = CaseClassPattern(None, ccd.typed(tparams2), fields.map(k => WildcardPattern(Some(k)))) val rhs = StringLiteral(ccd.id.asString) diff --git a/src/main/scala/leon/synthesis/utils/Helpers.scala b/src/main/scala/leon/synthesis/utils/Helpers.scala index 0a76050a1e99efb4a86a544af6e62ac036939b40..3bbefca1535aa51f16921d3b2edf4823c3d16de4 100644 --- a/src/main/scala/leon/synthesis/utils/Helpers.scala +++ b/src/main/scala/leon/synthesis/utils/Helpers.scala @@ -25,10 +25,9 @@ object Helpers { */ def functionsReturning(fds: Set[FunDef], tpe: TypeTree): Set[TypedFunDef] = { fds.flatMap { fd => - val free = fd.tparams.map(_.tp) - canBeSubtypeOf(fd.returnType, free, tpe) match { + subtypingInstantiation(tpe, fd.returnType) match { case Some(tpsMap) => - Some(fd.typed(free.map(tp => tpsMap.getOrElse(tp, tp)))) + Some(fd.typed(fd.typeArgs.map(tp => tpsMap.getOrElse(tp, tp)))) case None => None } diff --git a/src/main/scala/leon/verification/TraceInductionTactic.scala b/src/main/scala/leon/verification/TraceInductionTactic.scala index f1c7b455d359e61708a9c5e6e773031bd3756b86..73e9f464bf379c742a9f89346a5d7c6eff44d7d4 100644 --- a/src/main/scala/leon/verification/TraceInductionTactic.scala +++ b/src/main/scala/leon/verification/TraceInductionTactic.scala @@ -144,7 +144,7 @@ class TraceInductionTactic(vctx: VerificationContext) extends Tactic(vctx) { } } val argsMap = callee.params.map(_.id).zip(finv.args).toMap - val tparamMap = callee.tparams.zip(finv.tfd.tps).toMap + val tparamMap = callee.typeArgs.zip(finv.tfd.tps).toMap val inlinedBody = instantiateType(replaceFromIDs(argsMap, callee.body.get), tparamMap, Map()) val inductScheme = inductPattern(inlinedBody) // add body, pre and post for the tactFun diff --git a/src/test/scala/leon/unit/purescala/TypeOpsSuite.scala b/src/test/scala/leon/unit/purescala/TypeOpsSuite.scala index c6158d66a71bb8796da2ff6aa52a8f9ab4544c6d..29c47be9410d3b2a122258cdc93246ae9171b845 100644 --- a/src/test/scala/leon/unit/purescala/TypeOpsSuite.scala +++ b/src/test/scala/leon/unit/purescala/TypeOpsSuite.scala @@ -7,12 +7,11 @@ import leon.purescala.Common._ import leon.purescala.Expressions._ import leon.purescala.Definitions._ import leon.purescala.Types._ -import leon.purescala.ExprOps._ import leon.purescala.TypeOps._ class TypeOpsSuite extends LeonTestSuite with helpers.WithLikelyEq with helpers.ExpressionsDSL { - test("canBeSubtypeOf 1") { ctx => + test("type bounds") { ctx => val tp = TypeParameter.fresh("T") val tpD = new TypeParameterDef(tp) @@ -28,47 +27,66 @@ class TypeOpsSuite extends LeonTestSuite with helpers.WithLikelyEq with helpers. val consD = new CaseClassDef(FreshIdentifier("Cons"), Seq(tpD), Some(listT), false) val consT = consD.typed - // Simple tests for fixed types - assert(canBeSubtypeOf(tp, Seq(), tp).isDefined, "Same types can be subtypes") - assert(canBeSubtypeOf(listT, Seq(), listT).isDefined, "Different types are not subtypes") - - assert(canBeSubtypeOf(tp2, Seq(), tp3).isEmpty, "Different types are not subtypes") - assert(canBeSubtypeOf(BooleanType, Seq(), tp3).isEmpty, "Different types are not subtypes") - assert(canBeSubtypeOf(tp2, Seq(), BooleanType).isEmpty, "Different types are not subtypes") - assert(canBeSubtypeOf(IntegerType, Seq(), Int32Type).isEmpty, "Different types are not subtypes") + assert(isSubtypeOf(tp, tp), "T <: T") + assert(isSubtypeOf(listT, listT), "List[T] <: List[T]") + assert(isSubtypeOf(listD.typed, listD.typed), "List[T] <: List[T]") - assert(canBeSubtypeOf(nilT, Seq(), listT).isDefined, "Subtypes are subtypes") - assert(canBeSubtypeOf(consT, Seq(), listT).isDefined, "Subtypes are subtypes") + assert(isSubtypeOf(nilT, listT), "Subtypes are subtypes") + assert(isSubtypeOf(consT, listT), "Subtypes are subtypes") - assert(canBeSubtypeOf(listT, Seq(), nilT).isEmpty, "Supertypes are not subtypes") - assert(canBeSubtypeOf(listT, Seq(), consT).isEmpty, "Supertypes are not subtypes") + assert(!isSubtypeOf(listT, nilT ), "Supertypes are not subtypes") + assert(!isSubtypeOf(listT, consT), "Supertypes are not subtypes") - // Type parameters - assert(canBeSubtypeOf(tp, Seq(tp), tp2).isDefined, "T <: A with T free") - assert(canBeSubtypeOf(tp, Seq(tp2), tp2).isDefined, "T <: A with A free") + assert(!isSubtypeOf(nilD.typed(Seq(tp2)), listD.typed(Seq(tp3))), "Types are not subtypes with incompatible params") + assert(!isSubtypeOf(nilD.typed(Seq(tp2)), listD.typed(Seq(IntegerType))), "Types are not subtypes with incompatible params") + assert(!isSubtypeOf(SetType(tp2), SetType(tp3)), "Types are not subtypes with incompatible params") - assert(canBeSubtypeOf(listT, Seq(tp), listD.typed(Seq(tp2))).isDefined, "List[T] <: List[A] with T free") - assert(canBeSubtypeOf(listT, Seq(tp2), listD.typed(Seq(tp2))).isDefined, "List[T] <: List[A] with A free") + assert(!isSubtypeOf(nilD.typed(Seq(nilT)), listD.typed(Seq(listT))), "Classes are invariant") + assert(!isSubtypeOf(SetType(nilT), SetType(listT)), "Sets are invariant") - // Type parameters with fixed sides - assert(canBeSubtypeOf(tp, Seq(tp), tp2, lhsFixed = true).isEmpty, "T </: A with T free when lhs is fixed") - assert(canBeSubtypeOf(tp, Seq(tp), tp2, rhsFixed = true).isDefined, "T <: A with T free but rhs is fixed") - assert(canBeSubtypeOf(tp, Seq(tp2), tp2, lhsFixed = true).isDefined, "T <: A with A free when lhs is fixed") - assert(canBeSubtypeOf(tp, Seq(tp2), tp2, rhsFixed = true).isEmpty, "T </: A with A free but lhs is fixed") + assert(isSubtypeOf(FunctionType(Seq(listT), nilT), FunctionType(Seq(nilT), listT)), "Functions have contravariant params/ covariant result") - assert(canBeSubtypeOf(listT, Seq(tp), listD.typed(Seq(tp2)), rhsFixed = true).isDefined, "List[T] <: List[A] with T free and rhs fixed") + assert(!typesCompatible(tp2, tp3), "Different types should be incompatible") + assert(!typesCompatible(BooleanType, tp3), "Different types should be incompatible") + assert(!typesCompatible(tp2, BooleanType), "Different types should be incompatible") + assert(!typesCompatible(IntegerType, Int32Type), "Different types should be incompatible") - assert(isSubtypeOf(listD.typed, listD.typed), "List[T] <: List[T]") + // Type parameters + assert(unify(tp, tp2, Seq(tp) ).isDefined, "T <: A with T free") + assert(unify(tp, tp2, Seq(tp2)).isDefined, "T <: A with A free") + + assert(unify(listT, listD.typed(Seq(tp2)), Seq(tp) ).isDefined, "List[T] <: List[A] with T free") + assert(unify(listT, listD.typed(Seq(tp2)), Seq(tp2)).isDefined, "List[T] <: List[A] with A free") + assert(unify(listT, listD.typed(Seq(tp2)), Seq() ).isEmpty, "List[T] !<: List[A] with A,T not free") + assert(unify(listT, nilT, Seq(tp) ).isEmpty, "Subtypes not unifiable") + + assert( + unify(MapType(IntegerType, tp), MapType(tp2, IntegerType), Seq(tp, tp2)).contains(Map(tp -> IntegerType, tp2 -> IntegerType)), + "MapType unifiable" + ) + + assert( + subtypingInstantiation(consD.typed(Seq(tp)), listD.typed(Seq(tp2))) contains Map(tp2 -> tp), + "Cons[T] <: List[A] under A -> T" + ) + + assert( + subtypingInstantiation(consD.typed(Seq(IntegerType)), listD.typed(Seq(tp2))) contains Map(tp2 -> IntegerType), + "Cons[BigInt] <: List[A] under A -> BigInt" + ) + + assert( + subtypingInstantiation(consD.typed(Seq(tp)), listD.typed(Seq(IntegerType))).isEmpty, + "List[BigInt] cannot be instantiated such that Cons[T] <: List[BigInt]" + ) } test("instantiateType Hole") { ctx => val tp1 = TypeParameter.fresh("a") val tp2 = TypeParameter.fresh("b") - val tpd = TypeParameterDef(tp1) - val e1 = Hole(tp1, Nil) - val e2 = instantiateType(e1, Map(tpd -> tp2), Map()) + val e2 = instantiateType(e1, Map(tp1 -> tp2), Map()) e2 match { case Hole(tp, _) => assert(tp == tp2, "Type should have been substituted")