From 0e06d5ac4de4eabdfc4969a852a3a0f8c7c696ba Mon Sep 17 00:00:00 2001
From: "Emmanouil (Manos) Koukoutos" <emmanouil.koukoutos@epfl.ch>
Date: Wed, 21 Jan 2015 16:47:09 +0100
Subject: [PATCH] canBeSubtypeOf allows to fix either side's tparams.
 instantiateType handles Passes

---
 .../scala/leon/purescala/TypeTreeOps.scala    | 138 ++++++++++--------
 .../synthesis/utils/ExpressionGrammar.scala   |   2 +-
 2 files changed, 80 insertions(+), 60 deletions(-)

diff --git a/src/main/scala/leon/purescala/TypeTreeOps.scala b/src/main/scala/leon/purescala/TypeTreeOps.scala
index 0cd0f7ba2..9a1c0559f 100644
--- a/src/main/scala/leon/purescala/TypeTreeOps.scala
+++ b/src/main/scala/leon/purescala/TypeTreeOps.scala
@@ -19,7 +19,13 @@ object TypeTreeOps {
       subs.map(typeParamsOf).foldLeft(Set[TypeParameter]())(_ ++ _)
   }
 
-  def canBeSubtypeOf(tpe: TypeTree, freeParams: Seq[TypeParameter], stpe: TypeTree): Option[Map[TypeParameter, TypeTree]] = {
+  def canBeSubtypeOf(
+    tpe: TypeTree,
+    freeParams: Seq[TypeParameter], 
+    stpe: TypeTree,
+    lhsFixed: Boolean = false, 
+    rhsFixed: Boolean = false
+  ): Option[Map[TypeParameter, TypeTree]] = {
 
     def unify(res: Seq[Option[Map[TypeParameter, TypeTree]]]): Option[Map[TypeParameter, TypeTree]] = {
       if (res.forall(_.isDefined)) {
@@ -47,7 +53,7 @@ object TypeTreeOps {
     } else {
       (tpe, stpe) match {
         case (t, tp1: TypeParameter) =>
-          if ((freeParams contains tp1) && !(typeParamsOf(t) contains tp1)) {
+          if ((freeParams contains tp1) && (!rhsFixed) && !(typeParamsOf(t) contains tp1)) {
             Some(Map(tp1 -> t))
           } else if (tp1 == t) {
             Some(Map())
@@ -56,7 +62,7 @@ object TypeTreeOps {
           }
 
         case (tp1: TypeParameter, t) =>
-          if ((freeParams contains tp1) && !(typeParamsOf(t) contains tp1)) {
+          if ((freeParams contains tp1) && (!lhsFixed) && !(typeParamsOf(t) contains tp1)) {
             Some(Map(tp1 -> t))
           } else if (tp1 == t) {
             Some(Map())
@@ -71,7 +77,7 @@ object TypeTreeOps {
 
           if (rt1.classDef == rt2.classDef) {
             unify((rt1.tps zip rt2.tps).map { case (tp1, tp2) =>
-              canBeSubtypeOf(tp1, freeParams, tp2)
+              canBeSubtypeOf(tp1, freeParams, tp2, lhsFixed, rhsFixed)
             })
           } else {
             None
@@ -87,7 +93,7 @@ object TypeTreeOps {
 
           if (ts1.size == ts2.size) {
             unify((ts1 zip ts2).map { case (tp1, tp2) =>
-              canBeSubtypeOf(tp1, freeParams, tp2)
+              canBeSubtypeOf(tp1, freeParams, tp2, lhsFixed, rhsFixed)
             })
           } else {
             None
@@ -186,7 +192,7 @@ object TypeTreeOps {
       def rec(idsMap: Map[Identifier, Identifier])(e: Expr): Expr = {
         def freshId(id: Identifier, newTpe: TypeTree) = {
           if (id.getType != newTpe) {
-            FreshIdentifier(id.name, true).setType(newTpe).copiedFrom(id)
+            FreshIdentifier(id.name).setType(newTpe).copiedFrom(id)
           } else {
             id
           }
@@ -195,6 +201,63 @@ object TypeTreeOps {
         // Simple rec without affecting map
         val srec = rec(idsMap) _
 
+        def onMatchLike(e: Expr, cases : Seq[MatchCase]) = {
+        
+          val newTpe = tpeSub(e.getType)
+         
+          def mapsUnion(maps: Seq[Map[Identifier, Identifier]]): Map[Identifier, Identifier] = {
+            maps.foldLeft(Map[Identifier, Identifier]())(_ ++ _)
+          }
+
+          def trCase(c: MatchCase) = c match {
+            case SimpleCase(p, b) => 
+              val (newP, newIds) = trPattern(p, newTpe)
+              SimpleCase(newP, rec(idsMap ++ newIds)(b))
+
+            case GuardedCase(p, g, b) => 
+              val (newP, newIds) = trPattern(p, newTpe)
+              GuardedCase(newP, rec(idsMap ++ newIds)(g), rec(idsMap ++ newIds)(b))
+          }
+
+          def trPattern(p: Pattern, expType: TypeTree): (Pattern, Map[Identifier, Identifier]) = (p, expType) match {
+            case (InstanceOfPattern(ob, ct), _) =>
+              val newCt = tpeSub(ct).asInstanceOf[ClassType]
+              val newOb = ob.map(id => freshId(id, newCt))
+
+              (InstanceOfPattern(newOb, newCt), (ob zip newOb).toMap)
+
+            case (TuplePattern(ob, sps), tpt @ TupleType(stps)) =>
+              val newOb = ob.map(id => freshId(id, tpt))
+
+              val (newSps, newMaps) = (sps zip stps).map { case (sp, stpe) => trPattern(sp, stpe) }.unzip
+
+              (TuplePattern(newOb, newSps), (ob zip newOb).toMap ++ mapsUnion(newMaps))
+
+            case (CaseClassPattern(ob, cct, sps), _) =>
+              val newCt = tpeSub(cct).asInstanceOf[CaseClassType]
+
+              val newOb = ob.map(id => freshId(id, newCt))
+
+              val (newSps, newMaps) = (sps zip newCt.fieldsTypes).map { case (sp, stpe) => trPattern(sp, stpe) }.unzip
+
+              (CaseClassPattern(newOb, newCt, newSps), (ob zip newOb).toMap ++ mapsUnion(newMaps))
+
+            case (WildcardPattern(ob), expTpe) =>
+              val newOb = ob.map(id => freshId(id, expTpe))
+
+              (WildcardPattern(newOb), (ob zip newOb).toMap)
+
+            case (LiteralPattern(ob, lit), expType) => 
+              val newOb = ob.map(id => freshId(id, expType))
+              (LiteralPattern(newOb,lit), (ob zip newOb).toMap)
+
+            case _ =>
+              sys.error("woot!?")
+          }
+
+          (srec(e), cases.map(trCase))//.copiedFrom(m)
+        }
+
         e match {
           case fi @ FunctionInvocation(TypedFunDef(fd, tps), args) =>
             FunctionInvocation(TypedFunDef(fd, tps.map(tpeSub)), args.map(srec)).copiedFrom(fi)
@@ -228,63 +291,20 @@ object TypeTreeOps {
             val mapping = args.map(_.id) zip newArgs.map(_.id)
             Lambda(newArgs, rec(idsMap ++ mapping)(body)).copiedFrom(l)
 
+          case p @ Passes(in, out, cases) =>
+            val (newIn, newCases) = onMatchLike(in, cases)
+            passes(newIn, srec(out), newCases).copiedFrom(p)
+            
           case m @ MatchExpr(e, cases) =>
-            val newTpe = tpeSub(e.getType)
-
-            def mapsUnion(maps: Seq[Map[Identifier, Identifier]]): Map[Identifier, Identifier] = {
-              maps.foldLeft(Map[Identifier, Identifier]())(_ ++ _)
-            }
-
-            def trCase(c: MatchCase) = c match {
-              case SimpleCase(p, b) => 
-                val (newP, newIds) = trPattern(p, newTpe)
-                SimpleCase(newP, rec(idsMap ++ newIds)(b))
-
-              case GuardedCase(p, g, b) => 
-                val (newP, newIds) = trPattern(p, newTpe)
-                GuardedCase(newP, rec(idsMap ++ newIds)(g), rec(idsMap ++ newIds)(b))
-            }
-
-            def trPattern(p: Pattern, expType: TypeTree): (Pattern, Map[Identifier, Identifier]) = (p, expType) match {
-              case (InstanceOfPattern(ob, ct), _) =>
-                val newCt = tpeSub(ct).asInstanceOf[ClassType]
-                val newOb = ob.map(id => freshId(id, newCt))
-
-                (InstanceOfPattern(newOb, newCt), (ob zip newOb).toMap)
-
-              case (TuplePattern(ob, sps), tpt @ TupleType(stps)) =>
-                val newOb = ob.map(id => freshId(id, tpt))
-
-                val (newSps, newMaps) = (sps zip stps).map { case (sp, stpe) => trPattern(sp, stpe) }.unzip
-
-                (TuplePattern(newOb, newSps), (ob zip newOb).toMap ++ mapsUnion(newMaps))
-
-              case (CaseClassPattern(ob, cct, sps), _) =>
-                val newCt = tpeSub(cct).asInstanceOf[CaseClassType]
-
-                val newOb = ob.map(id => freshId(id, newCt))
-
-                val (newSps, newMaps) = (sps zip newCt.fieldsTypes).map { case (sp, stpe) => trPattern(sp, stpe) }.unzip
-
-                (CaseClassPattern(newOb, newCt, newSps), (ob zip newOb).toMap ++ mapsUnion(newMaps))
-
-              case (WildcardPattern(ob), expTpe) =>
-                val newOb = ob.map(id => freshId(id, expTpe))
-
-                (WildcardPattern(newOb), (ob zip newOb).toMap)
-
-              case (LiteralPattern(ob, lit), expType) => 
-                val newOb = ob.map(id => freshId(id, expType))
-                (LiteralPattern(newOb,lit), (ob zip newOb).toMap)
-
-              case _ =>
-                sys.error("woot!?")
-            }
-
-            matchExpr(srec(e), cases.map(trCase)).copiedFrom(m)
+            val (newE, newCases) = onMatchLike(e, cases)
+            matchExpr(newE, newCases).copiedFrom(m)
 
           case Error(tpe, desc) =>
             Error(tpeSub(tpe), desc).copiedFrom(e)
+          
+          case ens @ Ensuring(body, id, pred) =>
+            val newId = freshId(id, tpeSub(id.getType))
+            Ensuring(srec(body), newId, rec(idsMap + (id -> newId))(pred)).copiedFrom(ens)
 
           case s @ FiniteSet(elements) if elements.isEmpty =>
             FiniteSet(Set()).setType(tpeSub(s.getType)).copiedFrom(s)
diff --git a/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala b/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala
index 55550d59f..8b3ff4d55 100644
--- a/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala
+++ b/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala
@@ -332,7 +332,7 @@ object ExpressionGrammars {
 
        if (!isRecursiveCall && isDet) {
          val free = fd.tparams.map(_.tp)
-         canBeSubtypeOf(fd.returnType, free, t) match {
+         canBeSubtypeOf(fd.returnType, free, t, rhsFixed = true) match {
            case Some(tpsMap) =>
              val tfd = fd.typed(free.map(tp => tpsMap.getOrElse(tp, tp)))
 
-- 
GitLab