From b166dee3b36832971d765690eb8670e7ff9472cb Mon Sep 17 00:00:00 2001
From: Etienne Kneuss <colder@php.net>
Date: Mon, 11 Aug 2014 14:14:07 +0200
Subject: [PATCH] Base for canBeSubtypeOf

---
 .../scala/leon/purescala/TypeTreeOps.scala    | 72 +++++++++++++++----
 src/main/scala/leon/purescala/TypeTrees.scala |  2 +
 .../scala/leon/synthesis/rules/Cegis.scala    |  8 ++-
 3 files changed, 66 insertions(+), 16 deletions(-)

diff --git a/src/main/scala/leon/purescala/TypeTreeOps.scala b/src/main/scala/leon/purescala/TypeTreeOps.scala
index 6bdb5fa24..fc7374937 100644
--- a/src/main/scala/leon/purescala/TypeTreeOps.scala
+++ b/src/main/scala/leon/purescala/TypeTreeOps.scala
@@ -11,29 +11,75 @@ import Trees._
 import Extractors._
 
 object TypeTreeOps {
-  def canBeSubtypeOf(tpe: TypeTree, freeParams: Seq[TypeParameterDef], stpe: TypeTree): Option[Seq[TypeParameter]] = {
+  def canBeSubtypeOf(tpe: TypeTree, freeParams: Seq[TypeParameter], stpe: TypeTree): Option[Map[TypeParameter, TypeTree]] = {
+
+    def unify(res: Seq[Option[Map[TypeParameter, TypeTree]]]): Option[Map[TypeParameter, TypeTree]] = {
+      if (res.forall(_.isDefined)) {
+        var result = Map[TypeParameter, TypeTree]()
+
+        for (Some(m) <- res) {
+          result ++= m
+        }
+
+        Some(result)
+      } else {
+        None
+      }
+    }
+
     if (freeParams.isEmpty) {
       if (isSubtypeOf(tpe, stpe)) {
-        Some(Nil)
+        Some(Map())
       } else {
         None
       }
     } else {
-      // TODO
-      None
-    }
-  }
+      (tpe, stpe) match {
+        case (tp1: TypeParameter, t) =>
+          if (freeParams contains tp1) {
+            Some(Map(tp1 -> t))
+          } else if (tp1 == t) {
+            Some(Map())
+          } else {
+            None
+          }
 
-  def bestRealType(t: TypeTree) : TypeTree = t match {
-    case c: CaseClassType =>
-      c.classDef.parent match {
-        case None    =>
-          c
+        case (ct1: ClassType, ct2: ClassType) =>
+          val rt1 = ct1.root
+          val rt2 = ct2.root
+
+
+          if (rt1.classDef == rt2.classDef) {
+            unify((rt1.tps zip rt2.tps).map { case (tp1, tp2) =>
+              canBeSubtypeOf(tp1, freeParams, tp2)
+            })
+          } else {
+            None
+          }
+
+        case (_: TupleType, _: TupleType) |
+             (_: SetType, _: SetType) |
+             (_: MapType, _: MapType) |
+             (_: FunctionType, _: FunctionType) =>
+          val NAryType(ts1, _) = tpe
+          val NAryType(ts2, _) = stpe
 
-        case Some(p) =>
-          instantiateType(p, (c.classDef.tparams zip c.tps).toMap)
+          unify((ts1 zip ts2).map { case (tp1, tp2) =>
+            canBeSubtypeOf(tp1, freeParams, tp2)
+          })
+
+        case (t1, t2) =>
+          if (t1 == t2) {
+            Some(Map())
+          } else {
+            None
+          }
       }
+    }
+  }
 
+  def bestRealType(t: TypeTree) : TypeTree = t match {
+    case (c: CaseClassType) => c.root
     case NAryType(tps, builder) => builder(tps.map(bestRealType))
   }
 
diff --git a/src/main/scala/leon/purescala/TypeTrees.scala b/src/main/scala/leon/purescala/TypeTrees.scala
index 7e13555d0..9b518db9d 100644
--- a/src/main/scala/leon/purescala/TypeTrees.scala
+++ b/src/main/scala/leon/purescala/TypeTrees.scala
@@ -151,6 +151,8 @@ object TypeTrees {
 
     lazy val fieldsTypes = fields.map(_.tpe)
 
+    lazy val root = parent.getOrElse(this)
+
     lazy val parent = classDef.parent.map {
       pct => instantiateType(pct, (classDef.tparams zip tps).toMap) match {
         case act: AbstractClassType  => act
diff --git a/src/main/scala/leon/synthesis/rules/Cegis.scala b/src/main/scala/leon/synthesis/rules/Cegis.scala
index 37d19bb36..7a6c7f0b9 100644
--- a/src/main/scala/leon/synthesis/rules/Cegis.scala
+++ b/src/main/scala/leon/synthesis/rules/Cegis.scala
@@ -130,10 +130,12 @@ case object CEGIS extends Rule("CEGIS") {
               false
           }
 
+
           if (!isRecursiveCall && isNotSynthesizable) {
-            canBeSubtypeOf(fd.returnType, fd.tparams, t) match {
-              case Some(tps) =>
-                Some(fd.typed(tps))
+            val free = fd.tparams.map(_.tp)
+            canBeSubtypeOf(fd.returnType, free, t) match {
+              case Some(tpsMap) =>
+                Some(fd.typed(free.map(tp => tpsMap.getOrElse(tp, tp))))
               case None =>
                 None
             }
-- 
GitLab