From 11a91e68bcc28384b6c6aa34260164448eee0e4b Mon Sep 17 00:00:00 2001
From: "Emmanouil (Manos) Koukoutos" <emmanouil.koukoutos@epfl.ch>
Date: Tue, 25 Nov 2014 18:20:12 +0100
Subject: [PATCH] Automatically add type annotations to type ensuring

---
 .../frontends/scalac/AddTypeAnnotations.scala | 52 +++++++++++++++++++
 .../leon/frontends/scalac/ScalaCompiler.scala |  8 +++
 2 files changed, 60 insertions(+)
 create mode 100644 src/main/scala/leon/frontends/scalac/AddTypeAnnotations.scala

diff --git a/src/main/scala/leon/frontends/scalac/AddTypeAnnotations.scala b/src/main/scala/leon/frontends/scalac/AddTypeAnnotations.scala
new file mode 100644
index 000000000..3bfe34f49
--- /dev/null
+++ b/src/main/scala/leon/frontends/scalac/AddTypeAnnotations.scala
@@ -0,0 +1,52 @@
+/* Copyright 2009-2014 EPFL, Lausanne */
+
+package leon
+package frontends.scalac
+
+import scala.tools.nsc._
+import scala.tools.nsc.plugins._
+
+trait AddTypeAnnotations extends SubComponent with ASTExtractors {
+  import global._
+  import global.definitions._
+  import ExtractorHelpers._
+
+  val phaseName = "addtypeannotations"
+
+  val ctx: LeonContext
+
+  var imports : Map[RefTree,List[Import]] = Map()
+  
+  def newPhase(prev: scala.tools.nsc.Phase): StdPhase = new Phase(prev)
+
+  class Phase(prev: scala.tools.nsc.Phase) extends StdPhase(prev) {
+    def apply(unit: CompilationUnit): Unit = {
+      val transformer = new Transformer { 
+        override def transform(tree : Tree) : Tree = tree match { 
+          case d@DefDef(_,_,_,_,tpt,
+            a@Apply(
+              s@Select(
+                bd, 
+                nameEns@ExNamed("ensuring")
+              ),
+              f
+            )
+          ) => (bd,tpt.symbol) match {
+            case (Typed(_,_), _) | (_, null) => d
+            case _ => d.copy(rhs = 
+              Apply(
+                Select(
+                  Typed(bd,tpt.duplicate.setPos(bd.pos.focus)).setPos(bd.pos),
+                  nameEns
+                ).setPos(s.pos),
+                f         
+              ).setPos(a.pos)
+            ).setPos(d.pos)
+          }
+          case other => super.transform(other)
+        }
+      }
+      unit.body = transformer.transform(unit.body)
+    }
+  }
+}
diff --git a/src/main/scala/leon/frontends/scalac/ScalaCompiler.scala b/src/main/scala/leon/frontends/scalac/ScalaCompiler.scala
index 3113cbc1e..6fee922c2 100644
--- a/src/main/scala/leon/frontends/scalac/ScalaCompiler.scala
+++ b/src/main/scala/leon/frontends/scalac/ScalaCompiler.scala
@@ -22,9 +22,17 @@ class ScalaCompiler(settings : NSCSettings, ctx: LeonContext) extends Global(set
     val ctx = ScalaCompiler.this.ctx
   } with SaveImports
   
+  object addTypeAnnotations extends {
+    val global: ScalaCompiler.this.type = ScalaCompiler.this
+    val runsAfter = List[String]()
+    val runsRightAfter = Some("parser")
+    val ctx = ScalaCompiler.this.ctx
+  } with AddTypeAnnotations
+
   override protected def computeInternalPhases() : Unit = {
     val phs = List(
       syntaxAnalyzer          -> "parse source into ASTs, perform simple desugaring",
+      addTypeAnnotations      -> "add type annotations useful for Leon",
       analyzer.namerFactory   -> "resolve names, attach symbols to named trees",
       analyzer.packageObjects -> "load package objects",
       analyzer.typerFactory   -> "the meat and potatoes: type the trees",
-- 
GitLab