From 0e06d5ac4de4eabdfc4969a852a3a0f8c7c696ba Mon Sep 17 00:00:00 2001 From: "Emmanouil (Manos) Koukoutos" <emmanouil.koukoutos@epfl.ch> Date: Wed, 21 Jan 2015 16:47:09 +0100 Subject: [PATCH] canBeSubtypeOf allows to fix either side's tparams. instantiateType handles Passes --- .../scala/leon/purescala/TypeTreeOps.scala | 138 ++++++++++-------- .../synthesis/utils/ExpressionGrammar.scala | 2 +- 2 files changed, 80 insertions(+), 60 deletions(-) diff --git a/src/main/scala/leon/purescala/TypeTreeOps.scala b/src/main/scala/leon/purescala/TypeTreeOps.scala index 0cd0f7ba2..9a1c0559f 100644 --- a/src/main/scala/leon/purescala/TypeTreeOps.scala +++ b/src/main/scala/leon/purescala/TypeTreeOps.scala @@ -19,7 +19,13 @@ object TypeTreeOps { subs.map(typeParamsOf).foldLeft(Set[TypeParameter]())(_ ++ _) } - def canBeSubtypeOf(tpe: TypeTree, freeParams: Seq[TypeParameter], stpe: TypeTree): Option[Map[TypeParameter, TypeTree]] = { + def canBeSubtypeOf( + tpe: TypeTree, + freeParams: Seq[TypeParameter], + stpe: TypeTree, + lhsFixed: Boolean = false, + rhsFixed: Boolean = false + ): Option[Map[TypeParameter, TypeTree]] = { def unify(res: Seq[Option[Map[TypeParameter, TypeTree]]]): Option[Map[TypeParameter, TypeTree]] = { if (res.forall(_.isDefined)) { @@ -47,7 +53,7 @@ object TypeTreeOps { } else { (tpe, stpe) match { case (t, tp1: TypeParameter) => - if ((freeParams contains tp1) && !(typeParamsOf(t) contains tp1)) { + if ((freeParams contains tp1) && (!rhsFixed) && !(typeParamsOf(t) contains tp1)) { Some(Map(tp1 -> t)) } else if (tp1 == t) { Some(Map()) @@ -56,7 +62,7 @@ object TypeTreeOps { } case (tp1: TypeParameter, t) => - if ((freeParams contains tp1) && !(typeParamsOf(t) contains tp1)) { + if ((freeParams contains tp1) && (!lhsFixed) && !(typeParamsOf(t) contains tp1)) { Some(Map(tp1 -> t)) } else if (tp1 == t) { Some(Map()) @@ -71,7 +77,7 @@ object TypeTreeOps { if (rt1.classDef == rt2.classDef) { unify((rt1.tps zip rt2.tps).map { case (tp1, tp2) => - canBeSubtypeOf(tp1, freeParams, tp2) + canBeSubtypeOf(tp1, freeParams, tp2, lhsFixed, rhsFixed) }) } else { None @@ -87,7 +93,7 @@ object TypeTreeOps { if (ts1.size == ts2.size) { unify((ts1 zip ts2).map { case (tp1, tp2) => - canBeSubtypeOf(tp1, freeParams, tp2) + canBeSubtypeOf(tp1, freeParams, tp2, lhsFixed, rhsFixed) }) } else { None @@ -186,7 +192,7 @@ object TypeTreeOps { def rec(idsMap: Map[Identifier, Identifier])(e: Expr): Expr = { def freshId(id: Identifier, newTpe: TypeTree) = { if (id.getType != newTpe) { - FreshIdentifier(id.name, true).setType(newTpe).copiedFrom(id) + FreshIdentifier(id.name).setType(newTpe).copiedFrom(id) } else { id } @@ -195,6 +201,63 @@ object TypeTreeOps { // Simple rec without affecting map val srec = rec(idsMap) _ + def onMatchLike(e: Expr, cases : Seq[MatchCase]) = { + + val newTpe = tpeSub(e.getType) + + def mapsUnion(maps: Seq[Map[Identifier, Identifier]]): Map[Identifier, Identifier] = { + maps.foldLeft(Map[Identifier, Identifier]())(_ ++ _) + } + + def trCase(c: MatchCase) = c match { + case SimpleCase(p, b) => + val (newP, newIds) = trPattern(p, newTpe) + SimpleCase(newP, rec(idsMap ++ newIds)(b)) + + case GuardedCase(p, g, b) => + val (newP, newIds) = trPattern(p, newTpe) + GuardedCase(newP, rec(idsMap ++ newIds)(g), rec(idsMap ++ newIds)(b)) + } + + def trPattern(p: Pattern, expType: TypeTree): (Pattern, Map[Identifier, Identifier]) = (p, expType) match { + case (InstanceOfPattern(ob, ct), _) => + val newCt = tpeSub(ct).asInstanceOf[ClassType] + val newOb = ob.map(id => freshId(id, newCt)) + + (InstanceOfPattern(newOb, newCt), (ob zip newOb).toMap) + + case (TuplePattern(ob, sps), tpt @ TupleType(stps)) => + val newOb = ob.map(id => freshId(id, tpt)) + + val (newSps, newMaps) = (sps zip stps).map { case (sp, stpe) => trPattern(sp, stpe) }.unzip + + (TuplePattern(newOb, newSps), (ob zip newOb).toMap ++ mapsUnion(newMaps)) + + case (CaseClassPattern(ob, cct, sps), _) => + val newCt = tpeSub(cct).asInstanceOf[CaseClassType] + + val newOb = ob.map(id => freshId(id, newCt)) + + val (newSps, newMaps) = (sps zip newCt.fieldsTypes).map { case (sp, stpe) => trPattern(sp, stpe) }.unzip + + (CaseClassPattern(newOb, newCt, newSps), (ob zip newOb).toMap ++ mapsUnion(newMaps)) + + case (WildcardPattern(ob), expTpe) => + val newOb = ob.map(id => freshId(id, expTpe)) + + (WildcardPattern(newOb), (ob zip newOb).toMap) + + case (LiteralPattern(ob, lit), expType) => + val newOb = ob.map(id => freshId(id, expType)) + (LiteralPattern(newOb,lit), (ob zip newOb).toMap) + + case _ => + sys.error("woot!?") + } + + (srec(e), cases.map(trCase))//.copiedFrom(m) + } + e match { case fi @ FunctionInvocation(TypedFunDef(fd, tps), args) => FunctionInvocation(TypedFunDef(fd, tps.map(tpeSub)), args.map(srec)).copiedFrom(fi) @@ -228,63 +291,20 @@ object TypeTreeOps { val mapping = args.map(_.id) zip newArgs.map(_.id) Lambda(newArgs, rec(idsMap ++ mapping)(body)).copiedFrom(l) + case p @ Passes(in, out, cases) => + val (newIn, newCases) = onMatchLike(in, cases) + passes(newIn, srec(out), newCases).copiedFrom(p) + case m @ MatchExpr(e, cases) => - val newTpe = tpeSub(e.getType) - - def mapsUnion(maps: Seq[Map[Identifier, Identifier]]): Map[Identifier, Identifier] = { - maps.foldLeft(Map[Identifier, Identifier]())(_ ++ _) - } - - def trCase(c: MatchCase) = c match { - case SimpleCase(p, b) => - val (newP, newIds) = trPattern(p, newTpe) - SimpleCase(newP, rec(idsMap ++ newIds)(b)) - - case GuardedCase(p, g, b) => - val (newP, newIds) = trPattern(p, newTpe) - GuardedCase(newP, rec(idsMap ++ newIds)(g), rec(idsMap ++ newIds)(b)) - } - - def trPattern(p: Pattern, expType: TypeTree): (Pattern, Map[Identifier, Identifier]) = (p, expType) match { - case (InstanceOfPattern(ob, ct), _) => - val newCt = tpeSub(ct).asInstanceOf[ClassType] - val newOb = ob.map(id => freshId(id, newCt)) - - (InstanceOfPattern(newOb, newCt), (ob zip newOb).toMap) - - case (TuplePattern(ob, sps), tpt @ TupleType(stps)) => - val newOb = ob.map(id => freshId(id, tpt)) - - val (newSps, newMaps) = (sps zip stps).map { case (sp, stpe) => trPattern(sp, stpe) }.unzip - - (TuplePattern(newOb, newSps), (ob zip newOb).toMap ++ mapsUnion(newMaps)) - - case (CaseClassPattern(ob, cct, sps), _) => - val newCt = tpeSub(cct).asInstanceOf[CaseClassType] - - val newOb = ob.map(id => freshId(id, newCt)) - - val (newSps, newMaps) = (sps zip newCt.fieldsTypes).map { case (sp, stpe) => trPattern(sp, stpe) }.unzip - - (CaseClassPattern(newOb, newCt, newSps), (ob zip newOb).toMap ++ mapsUnion(newMaps)) - - case (WildcardPattern(ob), expTpe) => - val newOb = ob.map(id => freshId(id, expTpe)) - - (WildcardPattern(newOb), (ob zip newOb).toMap) - - case (LiteralPattern(ob, lit), expType) => - val newOb = ob.map(id => freshId(id, expType)) - (LiteralPattern(newOb,lit), (ob zip newOb).toMap) - - case _ => - sys.error("woot!?") - } - - matchExpr(srec(e), cases.map(trCase)).copiedFrom(m) + val (newE, newCases) = onMatchLike(e, cases) + matchExpr(newE, newCases).copiedFrom(m) case Error(tpe, desc) => Error(tpeSub(tpe), desc).copiedFrom(e) + + case ens @ Ensuring(body, id, pred) => + val newId = freshId(id, tpeSub(id.getType)) + Ensuring(srec(body), newId, rec(idsMap + (id -> newId))(pred)).copiedFrom(ens) case s @ FiniteSet(elements) if elements.isEmpty => FiniteSet(Set()).setType(tpeSub(s.getType)).copiedFrom(s) diff --git a/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala b/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala index 55550d59f..8b3ff4d55 100644 --- a/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala +++ b/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala @@ -332,7 +332,7 @@ object ExpressionGrammars { if (!isRecursiveCall && isDet) { val free = fd.tparams.map(_.tp) - canBeSubtypeOf(fd.returnType, free, t) match { + canBeSubtypeOf(fd.returnType, free, t, rhsFixed = true) match { case Some(tpsMap) => val tfd = fd.typed(free.map(tp => tpsMap.getOrElse(tp, tp))) -- GitLab