From bb66077a19400c55f7537839239ee5f9749737b2 Mon Sep 17 00:00:00 2001
From: Manos Koukoutos <emmanouil.koukoutos@epfl.ch>
Date: Thu, 24 Sep 2015 14:48:19 +0200
Subject: [PATCH] Enrich/fix instantiateType

Add some more helper functions/restructure
Handle LetDef in instantiateType
---
 src/main/scala/leon/purescala/TypeOps.scala | 64 ++++++++++++++++-----
 1 file changed, 51 insertions(+), 13 deletions(-)

diff --git a/src/main/scala/leon/purescala/TypeOps.scala b/src/main/scala/leon/purescala/TypeOps.scala
index 4b458b5d8..2d79403a0 100644
--- a/src/main/scala/leon/purescala/TypeOps.scala
+++ b/src/main/scala/leon/purescala/TypeOps.scala
@@ -9,6 +9,7 @@ import Common._
 import Expressions._
 import Extractors._
 import Constructors._
+import ExprOps.preMap
 
 object TypeOps {
   def typeDepth(t: TypeTree): Int = t match {
@@ -166,6 +167,32 @@ object TypeOps {
     }
   }
 
+  // Helpers for instantiateType
+  private def typeParamSubst(map: Map[TypeParameter, TypeTree])(tpe: TypeTree): TypeTree = tpe match {
+    case (tp: TypeParameter) => map.getOrElse(tp, tp)
+    case NAryType(tps, builder) => builder(tps.map(typeParamSubst(map)))
+  }
+
+  private def freshId(id: Identifier, newTpe: TypeTree) = {
+    if (id.getType != newTpe) {
+      FreshIdentifier(id.name, newTpe).copiedFrom(id)
+    } else {
+      id
+    }
+  }
+  
+  def instantiateType(id: Identifier, tps: Map[TypeParameterDef, TypeTree]): Identifier = {
+    freshId(id, typeParamSubst(tps map { case (tpd, tp) => tpd.tp -> tp })(id.getType))
+  }
+
+  def instantiateType(vd: ValDef, tps: Map[TypeParameterDef, TypeTree]): ValDef = {
+    val ValDef(id, forcedType) = vd
+    ValDef(
+      freshId(id, instantiateType(id.getType, tps)),
+      forcedType map ((tp: TypeTree) => instantiateType(tp, tps))
+    )
+  }
+
   def instantiateType(tpe: TypeTree, tps: Map[TypeParameterDef, TypeTree]): TypeTree = {
     if (tps.isEmpty) {
       tpe
@@ -174,11 +201,6 @@ object TypeOps {
     }
   }
 
-  private def typeParamSubst(map: Map[TypeParameter, TypeTree])(tpe: TypeTree): TypeTree = tpe match {
-    case (tp: TypeParameter) => map.getOrElse(tp, tp)
-    case NAryType(tps, builder) => builder(tps.map(typeParamSubst(map)))
-  }
-
   def instantiateType(e: Expr, tps: Map[TypeParameterDef, TypeTree], ids: Map[Identifier, Identifier]): Expr = {
     if (tps.isEmpty && ids.isEmpty) {
       e
@@ -190,13 +212,6 @@ object TypeOps {
       }
 
       def rec(idsMap: Map[Identifier, Identifier])(e: Expr): Expr = {
-        def freshId(id: Identifier, newTpe: TypeTree) = {
-          if (id.getType != newTpe) {
-            FreshIdentifier(id.name, newTpe).copiedFrom(id)
-          } else {
-            id
-          }
-        }
 
         // Simple rec without affecting map
         val srec = rec(idsMap) _
@@ -292,6 +307,29 @@ object TypeOps {
             val newId = freshId(id, tpeSub(id.getType))
             Let(newId, srec(value), rec(idsMap + (id -> newId))(body)).copiedFrom(l)
 
+          case l @ LetDef(fd, bd) =>
+            val id = fd.id.freshen
+            val tparams = fd.tparams map { p => 
+              TypeParameterDef(tpeSub(p.tp).asInstanceOf[TypeParameter])
+            }
+            val returnType = tpeSub(fd.returnType)
+            val params = fd.params map (instantiateType(_, tps))
+            val newFd = new FunDef(id, tparams, returnType, params).copiedFrom(fd)
+            newFd.copyContentFrom(fd)
+
+            val subCalls = preMap {
+              case fi @ FunctionInvocation(tfd, args) if tfd.fd == fd =>
+                Some(FunctionInvocation(newFd.typed(tfd.tps), args).copiedFrom(fi))
+              case _ => 
+                None
+            } _
+            val fullBody = rec(idsMap ++ fd.paramIds.zip(newFd.paramIds))(subCalls(fd.fullBody))
+            newFd.fullBody = fullBody
+
+            val newBd = srec(subCalls(bd)).copiedFrom(bd)
+
+            LetDef(newFd, newBd).copiedFrom(l)
+
           case l @ Lambda(args, body) =>
             val newArgs = args.map { arg =>
               val tpe = tpeSub(arg.getType)
@@ -327,7 +365,7 @@ object TypeOps {
               case newTpar : TypeParameter => 
                 GenericValue(newTpar, id).copiedFrom(g)
               case other => // FIXME any better ideas?
-                sys.error(s"Tried to substitute $tpar with $other within GenericValue $g")
+                throw LeonFatalError(Some(s"Tried to substitute $tpar with $other within GenericValue $g"))
             }
 
           case s @ FiniteSet(elems, tpe) =>
-- 
GitLab