Skip to content
Snippets Groups Projects
Commit 0e06d5ac authored by Emmanouil (Manos) Koukoutos's avatar Emmanouil (Manos) Koukoutos Committed by Etienne Kneuss
Browse files

canBeSubtypeOf allows to fix either side's tparams. instantiateType handles Passes

parent 55931024
Branches
Tags
No related merge requests found
...@@ -19,7 +19,13 @@ object TypeTreeOps { ...@@ -19,7 +19,13 @@ object TypeTreeOps {
subs.map(typeParamsOf).foldLeft(Set[TypeParameter]())(_ ++ _) subs.map(typeParamsOf).foldLeft(Set[TypeParameter]())(_ ++ _)
} }
def canBeSubtypeOf(tpe: TypeTree, freeParams: Seq[TypeParameter], stpe: TypeTree): Option[Map[TypeParameter, TypeTree]] = { def canBeSubtypeOf(
tpe: TypeTree,
freeParams: Seq[TypeParameter],
stpe: TypeTree,
lhsFixed: Boolean = false,
rhsFixed: Boolean = false
): Option[Map[TypeParameter, TypeTree]] = {
def unify(res: Seq[Option[Map[TypeParameter, TypeTree]]]): Option[Map[TypeParameter, TypeTree]] = { def unify(res: Seq[Option[Map[TypeParameter, TypeTree]]]): Option[Map[TypeParameter, TypeTree]] = {
if (res.forall(_.isDefined)) { if (res.forall(_.isDefined)) {
...@@ -47,7 +53,7 @@ object TypeTreeOps { ...@@ -47,7 +53,7 @@ object TypeTreeOps {
} else { } else {
(tpe, stpe) match { (tpe, stpe) match {
case (t, tp1: TypeParameter) => case (t, tp1: TypeParameter) =>
if ((freeParams contains tp1) && !(typeParamsOf(t) contains tp1)) { if ((freeParams contains tp1) && (!rhsFixed) && !(typeParamsOf(t) contains tp1)) {
Some(Map(tp1 -> t)) Some(Map(tp1 -> t))
} else if (tp1 == t) { } else if (tp1 == t) {
Some(Map()) Some(Map())
...@@ -56,7 +62,7 @@ object TypeTreeOps { ...@@ -56,7 +62,7 @@ object TypeTreeOps {
} }
case (tp1: TypeParameter, t) => case (tp1: TypeParameter, t) =>
if ((freeParams contains tp1) && !(typeParamsOf(t) contains tp1)) { if ((freeParams contains tp1) && (!lhsFixed) && !(typeParamsOf(t) contains tp1)) {
Some(Map(tp1 -> t)) Some(Map(tp1 -> t))
} else if (tp1 == t) { } else if (tp1 == t) {
Some(Map()) Some(Map())
...@@ -71,7 +77,7 @@ object TypeTreeOps { ...@@ -71,7 +77,7 @@ object TypeTreeOps {
if (rt1.classDef == rt2.classDef) { if (rt1.classDef == rt2.classDef) {
unify((rt1.tps zip rt2.tps).map { case (tp1, tp2) => unify((rt1.tps zip rt2.tps).map { case (tp1, tp2) =>
canBeSubtypeOf(tp1, freeParams, tp2) canBeSubtypeOf(tp1, freeParams, tp2, lhsFixed, rhsFixed)
}) })
} else { } else {
None None
...@@ -87,7 +93,7 @@ object TypeTreeOps { ...@@ -87,7 +93,7 @@ object TypeTreeOps {
if (ts1.size == ts2.size) { if (ts1.size == ts2.size) {
unify((ts1 zip ts2).map { case (tp1, tp2) => unify((ts1 zip ts2).map { case (tp1, tp2) =>
canBeSubtypeOf(tp1, freeParams, tp2) canBeSubtypeOf(tp1, freeParams, tp2, lhsFixed, rhsFixed)
}) })
} else { } else {
None None
...@@ -186,7 +192,7 @@ object TypeTreeOps { ...@@ -186,7 +192,7 @@ object TypeTreeOps {
def rec(idsMap: Map[Identifier, Identifier])(e: Expr): Expr = { def rec(idsMap: Map[Identifier, Identifier])(e: Expr): Expr = {
def freshId(id: Identifier, newTpe: TypeTree) = { def freshId(id: Identifier, newTpe: TypeTree) = {
if (id.getType != newTpe) { if (id.getType != newTpe) {
FreshIdentifier(id.name, true).setType(newTpe).copiedFrom(id) FreshIdentifier(id.name).setType(newTpe).copiedFrom(id)
} else { } else {
id id
} }
...@@ -195,6 +201,63 @@ object TypeTreeOps { ...@@ -195,6 +201,63 @@ object TypeTreeOps {
// Simple rec without affecting map // Simple rec without affecting map
val srec = rec(idsMap) _ val srec = rec(idsMap) _
def onMatchLike(e: Expr, cases : Seq[MatchCase]) = {
val newTpe = tpeSub(e.getType)
def mapsUnion(maps: Seq[Map[Identifier, Identifier]]): Map[Identifier, Identifier] = {
maps.foldLeft(Map[Identifier, Identifier]())(_ ++ _)
}
def trCase(c: MatchCase) = c match {
case SimpleCase(p, b) =>
val (newP, newIds) = trPattern(p, newTpe)
SimpleCase(newP, rec(idsMap ++ newIds)(b))
case GuardedCase(p, g, b) =>
val (newP, newIds) = trPattern(p, newTpe)
GuardedCase(newP, rec(idsMap ++ newIds)(g), rec(idsMap ++ newIds)(b))
}
def trPattern(p: Pattern, expType: TypeTree): (Pattern, Map[Identifier, Identifier]) = (p, expType) match {
case (InstanceOfPattern(ob, ct), _) =>
val newCt = tpeSub(ct).asInstanceOf[ClassType]
val newOb = ob.map(id => freshId(id, newCt))
(InstanceOfPattern(newOb, newCt), (ob zip newOb).toMap)
case (TuplePattern(ob, sps), tpt @ TupleType(stps)) =>
val newOb = ob.map(id => freshId(id, tpt))
val (newSps, newMaps) = (sps zip stps).map { case (sp, stpe) => trPattern(sp, stpe) }.unzip
(TuplePattern(newOb, newSps), (ob zip newOb).toMap ++ mapsUnion(newMaps))
case (CaseClassPattern(ob, cct, sps), _) =>
val newCt = tpeSub(cct).asInstanceOf[CaseClassType]
val newOb = ob.map(id => freshId(id, newCt))
val (newSps, newMaps) = (sps zip newCt.fieldsTypes).map { case (sp, stpe) => trPattern(sp, stpe) }.unzip
(CaseClassPattern(newOb, newCt, newSps), (ob zip newOb).toMap ++ mapsUnion(newMaps))
case (WildcardPattern(ob), expTpe) =>
val newOb = ob.map(id => freshId(id, expTpe))
(WildcardPattern(newOb), (ob zip newOb).toMap)
case (LiteralPattern(ob, lit), expType) =>
val newOb = ob.map(id => freshId(id, expType))
(LiteralPattern(newOb,lit), (ob zip newOb).toMap)
case _ =>
sys.error("woot!?")
}
(srec(e), cases.map(trCase))//.copiedFrom(m)
}
e match { e match {
case fi @ FunctionInvocation(TypedFunDef(fd, tps), args) => case fi @ FunctionInvocation(TypedFunDef(fd, tps), args) =>
FunctionInvocation(TypedFunDef(fd, tps.map(tpeSub)), args.map(srec)).copiedFrom(fi) FunctionInvocation(TypedFunDef(fd, tps.map(tpeSub)), args.map(srec)).copiedFrom(fi)
...@@ -228,63 +291,20 @@ object TypeTreeOps { ...@@ -228,63 +291,20 @@ object TypeTreeOps {
val mapping = args.map(_.id) zip newArgs.map(_.id) val mapping = args.map(_.id) zip newArgs.map(_.id)
Lambda(newArgs, rec(idsMap ++ mapping)(body)).copiedFrom(l) Lambda(newArgs, rec(idsMap ++ mapping)(body)).copiedFrom(l)
case p @ Passes(in, out, cases) =>
val (newIn, newCases) = onMatchLike(in, cases)
passes(newIn, srec(out), newCases).copiedFrom(p)
case m @ MatchExpr(e, cases) => case m @ MatchExpr(e, cases) =>
val newTpe = tpeSub(e.getType) val (newE, newCases) = onMatchLike(e, cases)
matchExpr(newE, newCases).copiedFrom(m)
def mapsUnion(maps: Seq[Map[Identifier, Identifier]]): Map[Identifier, Identifier] = {
maps.foldLeft(Map[Identifier, Identifier]())(_ ++ _)
}
def trCase(c: MatchCase) = c match {
case SimpleCase(p, b) =>
val (newP, newIds) = trPattern(p, newTpe)
SimpleCase(newP, rec(idsMap ++ newIds)(b))
case GuardedCase(p, g, b) =>
val (newP, newIds) = trPattern(p, newTpe)
GuardedCase(newP, rec(idsMap ++ newIds)(g), rec(idsMap ++ newIds)(b))
}
def trPattern(p: Pattern, expType: TypeTree): (Pattern, Map[Identifier, Identifier]) = (p, expType) match {
case (InstanceOfPattern(ob, ct), _) =>
val newCt = tpeSub(ct).asInstanceOf[ClassType]
val newOb = ob.map(id => freshId(id, newCt))
(InstanceOfPattern(newOb, newCt), (ob zip newOb).toMap)
case (TuplePattern(ob, sps), tpt @ TupleType(stps)) =>
val newOb = ob.map(id => freshId(id, tpt))
val (newSps, newMaps) = (sps zip stps).map { case (sp, stpe) => trPattern(sp, stpe) }.unzip
(TuplePattern(newOb, newSps), (ob zip newOb).toMap ++ mapsUnion(newMaps))
case (CaseClassPattern(ob, cct, sps), _) =>
val newCt = tpeSub(cct).asInstanceOf[CaseClassType]
val newOb = ob.map(id => freshId(id, newCt))
val (newSps, newMaps) = (sps zip newCt.fieldsTypes).map { case (sp, stpe) => trPattern(sp, stpe) }.unzip
(CaseClassPattern(newOb, newCt, newSps), (ob zip newOb).toMap ++ mapsUnion(newMaps))
case (WildcardPattern(ob), expTpe) =>
val newOb = ob.map(id => freshId(id, expTpe))
(WildcardPattern(newOb), (ob zip newOb).toMap)
case (LiteralPattern(ob, lit), expType) =>
val newOb = ob.map(id => freshId(id, expType))
(LiteralPattern(newOb,lit), (ob zip newOb).toMap)
case _ =>
sys.error("woot!?")
}
matchExpr(srec(e), cases.map(trCase)).copiedFrom(m)
case Error(tpe, desc) => case Error(tpe, desc) =>
Error(tpeSub(tpe), desc).copiedFrom(e) Error(tpeSub(tpe), desc).copiedFrom(e)
case ens @ Ensuring(body, id, pred) =>
val newId = freshId(id, tpeSub(id.getType))
Ensuring(srec(body), newId, rec(idsMap + (id -> newId))(pred)).copiedFrom(ens)
case s @ FiniteSet(elements) if elements.isEmpty => case s @ FiniteSet(elements) if elements.isEmpty =>
FiniteSet(Set()).setType(tpeSub(s.getType)).copiedFrom(s) FiniteSet(Set()).setType(tpeSub(s.getType)).copiedFrom(s)
......
...@@ -332,7 +332,7 @@ object ExpressionGrammars { ...@@ -332,7 +332,7 @@ object ExpressionGrammars {
if (!isRecursiveCall && isDet) { if (!isRecursiveCall && isDet) {
val free = fd.tparams.map(_.tp) val free = fd.tparams.map(_.tp)
canBeSubtypeOf(fd.returnType, free, t) match { canBeSubtypeOf(fd.returnType, free, t, rhsFixed = true) match {
case Some(tpsMap) => case Some(tpsMap) =>
val tfd = fd.typed(free.map(tp => tpsMap.getOrElse(tp, tp))) val tfd = fd.typed(free.map(tp => tpsMap.getOrElse(tp, tp)))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment