diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index be000b03b3d31d4f2511a5a4f82d70b36c0cf68f..99ca8aa072982fe75065ca5dbcd1671442092ad8 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -1086,7 +1086,7 @@ trait CodeExtraction extends ASTExtractors { extractType(up.tpe), tupleTypeWrap(args map { tr => extractType(tr.tpe)}) )) - val newTps = subtypingInstantiation(realTypes, formalTypes) match { + val newTps = canBeSupertypeOf(formalTypes, realTypes) match { case Some(tmap) => fd.tparams map { tpd => tmap.getOrElse(tpd.tp, tpd.tp) } case None => diff --git a/src/main/scala/leon/grammars/FunctionCalls.scala b/src/main/scala/leon/grammars/FunctionCalls.scala index f3b143d03d92aade5b746d892cfc0391b0a6cd4e..0c235d66836c3c7da55d44f0ff1e300d75349c8b 100644 --- a/src/main/scala/leon/grammars/FunctionCalls.scala +++ b/src/main/scala/leon/grammars/FunctionCalls.scala @@ -28,7 +28,7 @@ case class FunctionCalls(prog: Program, currentFunction: FunDef, types: Seq[Type val isDet = fd.body.exists(isDeterministic) if (!isRecursiveCall && isDet) { - subtypingInstantiation(t, fd.returnType) match { + canBeSubtypeOf(fd.returnType, t) match { case Some(tpsMap) => val free = fd.tparams.map(_.tp) val tfd = fd.typed(free.map(tp => tpsMap.getOrElse(tp, tp))) diff --git a/src/main/scala/leon/purescala/Constructors.scala b/src/main/scala/leon/purescala/Constructors.scala index e54cb95c83b485073d92c50d4ef322822ef0ae3f..e82595f30c57768db3aa908264708c3b8640932a 100644 --- a/src/main/scala/leon/purescala/Constructors.scala +++ b/src/main/scala/leon/purescala/Constructors.scala @@ -122,7 +122,7 @@ object Constructors { val formalType = tupleTypeWrap(fd.params map { _.getType }) val actualType = tupleTypeWrap(args map { _.getType }) - subtypingInstantiation(actualType, formalType) match { + canBeSupertypeOf(formalType, actualType) match { case Some(tmap) => FunctionInvocation(fd.typed(fd.tparams map { tpd => tmap.getOrElse(tpd.tp, tpd.tp) }), args) case None => throw LeonFatalError(s"$args:$actualType cannot be a subtype of $formalType!") diff --git a/src/main/scala/leon/purescala/SelfPrettyPrinter.scala b/src/main/scala/leon/purescala/SelfPrettyPrinter.scala index 71c91c551c42e8848f2d025439cfd0970c3f9d94..08ecdbbe20aafa203835c48abb52aff33a71005d 100644 --- a/src/main/scala/leon/purescala/SelfPrettyPrinter.scala +++ b/src/main/scala/leon/purescala/SelfPrettyPrinter.scala @@ -53,7 +53,7 @@ trait PrettyPrinterFinder[T, U >: T] { def buildLambda(inputType: TypeTree, fd: FunDef, slu: Stream[List[U]]): Stream[T] def prettyPrinterFromCandidate(fd: FunDef, inputType: TypeTree)(implicit ctx: LeonContext, program: Program): Stream[T] = { - TypeOps.subtypingInstantiation(inputType, fd.params.head.getType) match { + TypeOps.canBeSupertypeOf(fd.params.head.getType, inputType) match { case Some(genericTypeMap) => //println("Found a mapping") def gatherPrettyPrinters(funIds: List[Identifier], acc: ListBuffer[Stream[U]] = ListBuffer[Stream[U]]()): Option[Stream[List[U]]] = funIds match { diff --git a/src/main/scala/leon/purescala/TypeOps.scala b/src/main/scala/leon/purescala/TypeOps.scala index 600127cc495fb45450d6e1c6f97d279e47e27c07..3fd35f7190d8cdafb617bac9375eae9d1d08ffc4 100644 --- a/src/main/scala/leon/purescala/TypeOps.scala +++ b/src/main/scala/leon/purescala/TypeOps.scala @@ -128,15 +128,23 @@ object TypeOps extends GenTreeOps[TypeTree] { def unify(tp1: TypeTree, tp2: TypeTree, freeParams: Seq[TypeParameter]) = typeBound(tp1, tp2, isLub = true, allowSub = false)(freeParams).map(_._2) - /** Will try to instantiate superT so that subT <: superT + /** Will try to instantiate subT and superT so that subT <: superT * * @return Mapping of instantiations */ - def subtypingInstantiation(subT: TypeTree, superT: TypeTree) = - typeBound(subT, superT, isLub = true, allowSub = true)(typeParamsOf(superT).toSeq) collect { + private def subtypingInstantiation(subT: TypeTree, superT: TypeTree, free: Seq[TypeParameter]) = + typeBound(subT, superT, isLub = true, allowSub = true)(free) collect { case (tp, map) if instantiateType(superT, map) == tp => map } + def canBeSubtypeOf(subT: TypeTree, superT: TypeTree) = { + subtypingInstantiation(subT, superT, (typeParamsOf(subT) -- typeParamsOf(superT)).toSeq) + } + + def canBeSupertypeOf(superT: TypeTree, subT: TypeTree) = { + subtypingInstantiation(subT, superT, (typeParamsOf(superT) -- typeParamsOf(subT)).toSeq) + } + def leastUpperBound(tp1: TypeTree, tp2: TypeTree): Option[TypeTree] = typeBound(tp1, tp2, isLub = true, allowSub = true)(Seq()).map(_._1) diff --git a/src/main/scala/leon/synthesis/rules/Abduction.scala b/src/main/scala/leon/synthesis/rules/Abduction.scala index f8991f2d7953af4d34ca8956fff86f4365cff966..4fe724e162d8fc030c7f39cceff8b2edc4498c3a 100644 --- a/src/main/scala/leon/synthesis/rules/Abduction.scala +++ b/src/main/scala/leon/synthesis/rules/Abduction.scala @@ -8,6 +8,7 @@ import purescala.Common._ import purescala.DefOps._ import purescala.Expressions._ import purescala.TypeOps.unify +import purescala.TypeOps.canBeSubtypeOf import purescala.Constructors._ import purescala.ExprOps._ import purescala.Definitions._ @@ -19,5 +20,4 @@ object Abduction extends Rule("Abduction") { override def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { Nil } - } diff --git a/src/main/scala/leon/synthesis/utils/Helpers.scala b/src/main/scala/leon/synthesis/utils/Helpers.scala index 3bbefca1535aa51f16921d3b2edf4823c3d16de4..767b703c08e2dbb175712e016cf36a3d96050dd5 100644 --- a/src/main/scala/leon/synthesis/utils/Helpers.scala +++ b/src/main/scala/leon/synthesis/utils/Helpers.scala @@ -25,7 +25,7 @@ object Helpers { */ def functionsReturning(fds: Set[FunDef], tpe: TypeTree): Set[TypedFunDef] = { fds.flatMap { fd => - subtypingInstantiation(tpe, fd.returnType) match { + canBeSubtypeOf(fd.returnType, tpe) match { case Some(tpsMap) => Some(fd.typed(fd.typeArgs.map(tp => tpsMap.getOrElse(tp, tp)))) case None => diff --git a/src/test/scala/leon/unit/purescala/TypeOpsSuite.scala b/src/test/scala/leon/unit/purescala/TypeOpsSuite.scala index 29c47be9410d3b2a122258cdc93246ae9171b845..0ef77ddc5fffe600cc70beeea644ad8663ce289d 100644 --- a/src/test/scala/leon/unit/purescala/TypeOpsSuite.scala +++ b/src/test/scala/leon/unit/purescala/TypeOpsSuite.scala @@ -66,19 +66,25 @@ class TypeOpsSuite extends LeonTestSuite with helpers.WithLikelyEq with helpers. ) assert( - subtypingInstantiation(consD.typed(Seq(tp)), listD.typed(Seq(tp2))) contains Map(tp2 -> tp), - "Cons[T] <: List[A] under A -> T" + canBeSupertypeOf(listD.typed(Seq(tp2)), consD.typed(Seq(tp))) contains Map(tp2 -> tp), + "List[A] >: Cons[T] under A -> T" ) assert( - subtypingInstantiation(consD.typed(Seq(IntegerType)), listD.typed(Seq(tp2))) contains Map(tp2 -> IntegerType), - "Cons[BigInt] <: List[A] under A -> BigInt" + canBeSubtypeOf(consD.typed(Seq(tp)), listD.typed(Seq(tp2))) contains Map(tp -> tp2), + "Cons[T] <: List[A] under T -> A" ) assert( - subtypingInstantiation(consD.typed(Seq(tp)), listD.typed(Seq(IntegerType))).isEmpty, - "List[BigInt] cannot be instantiated such that Cons[T] <: List[BigInt]" + canBeSubtypeOf(consD.typed(Seq(IntegerType)), listD.typed(Seq(tp2))).isEmpty, + "Cons[BigInt] cannot be instantiated so that it is <: List[A]" ) + + assert( + canBeSupertypeOf(listD.typed(Seq(tp2)), consD.typed(Seq(IntegerType))) contains Map(tp2 -> IntegerType), + "List[A] >: Cons[BigInt] under A -> BigInt" + ) + } test("instantiateType Hole") { ctx =>