From 21f88f68438c0c57ef03ebee4c1e2c795f978f86 Mon Sep 17 00:00:00 2001
From: Philippe Suter <philippe.suter@gmail.com>
Date: Wed, 10 Jun 2009 15:32:58 +0000
Subject: [PATCH] progress on extraction

---
 src/funcheck/AnalysisComponent.scala       | 18 ++---
 src/funcheck/CodeExtraction.scala          | 81 ++++++++++++++--------
 src/funcheck/Extractors.scala              |  4 +-
 src/funcheck/purescala/Definitions.scala   | 14 ++++
 src/funcheck/purescala/PrettyPrinter.scala | 36 +++++++++-
 5 files changed, 113 insertions(+), 40 deletions(-)

diff --git a/src/funcheck/AnalysisComponent.scala b/src/funcheck/AnalysisComponent.scala
index 286d6f716..44bb5ba95 100644
--- a/src/funcheck/AnalysisComponent.scala
+++ b/src/funcheck/AnalysisComponent.scala
@@ -4,7 +4,6 @@ import scala.tools.nsc._
 import scala.tools.nsc.plugins._
 
 class AnalysisComponent(val global: Global, val pluginInstance: FunCheckPlugin) extends PluginComponent
-  with Extractors
   with CodeExtraction
   with ForallInjection
 {
@@ -16,7 +15,7 @@ class AnalysisComponent(val global: Global, val pluginInstance: FunCheckPlugin)
 
   val phaseName = pluginInstance.name
 
-  private def stopIfErrors: Unit = {
+  protected def stopIfErrors: Unit = {
     if(reporter.hasErrors) {
       println("There were errors.")
       exit(0)
@@ -27,12 +26,17 @@ class AnalysisComponent(val global: Global, val pluginInstance: FunCheckPlugin)
 
   class AnalysisPhase(prev: Phase) extends StdPhase(prev) {
     def apply(unit: CompilationUnit): Unit = {
+      // That filter just helps getting meaningful errors before the attempt to
+      // extract the code, but it's really optional.
       (new ForeachTreeTraverser(firstFilter(unit))).traverse(unit.body)
       stopIfErrors
-      // (new ForeachTreeTraverser(findContracts)).traverse(unit.body)
-      // stopIfErrors
 
-      extractCode(unit)
+      val prog: purescala.Definitions.Program = extractCode(unit)
+      println("Extracted program for " + unit + ": ")
+      println(prog)
+
+      // Mirco your component can do its job here, as I leave the trees
+      // unmodified.
 
       if(pluginInstance.stopAfterAnalysis) {
         println("Analysis complete. Now terminating the compiler process.")
@@ -40,7 +44,7 @@ class AnalysisComponent(val global: Global, val pluginInstance: FunCheckPlugin)
       }
     }
 
-    /** Weeds out programs containing unsupported features. */
+    /** Weeds out some programs containing unsupported features. */
     def firstFilter(unit: CompilationUnit)(tree: Tree): Unit = {
       def unsup(s: String): String = "FunCheck: Unsupported construct: " + s
 
@@ -51,8 +55,6 @@ class AnalysisComponent(val global: Global, val pluginInstance: FunCheckPlugin)
         case Assign(lhs, rhs) => unit.error(tree.pos, unsup("assignment to mutable variable/field."))
         case Return(expr) => unit.error(tree.pos, unsup("return statement."))
         case Try(block, catches, finalizer) => unit.error(tree.pos, unsup("try block."))
-        // case Throw(expr) => unit.error(tree.pos, unsup("throw statement."))
-        // case New(tpt) => unit.error(tree.pos, unsup("'new' operator."))
         case SingletonTypeTree(ref) => unit.error(tree.pos, unsup("singleton type."))
         case SelectFromTypeTree(qualifier, selector) => unit.error(tree.pos, unsup("path-dependent type."))
         case CompoundTypeTree(templ: Template) => unit.error(tree.pos, unsup("compound/refinement type."))
diff --git a/src/funcheck/CodeExtraction.scala b/src/funcheck/CodeExtraction.scala
index d371cae6c..432e2d08d 100644
--- a/src/funcheck/CodeExtraction.scala
+++ b/src/funcheck/CodeExtraction.scala
@@ -8,39 +8,14 @@ import purescala.Trees._
 import purescala.TypeTrees._
 import purescala.Common._
 
-trait CodeExtraction {
+trait CodeExtraction extends Extractors {
   self: AnalysisComponent =>
 
   import global._
   import StructuralExtractors._
   import ExpressionExtractors._
 
-  def findContracts(tree: Tree): Unit = tree match {
-    case DefDef(/*mods*/ _, name, /*tparams*/ _, /*vparamss*/ _, /*tpt*/ _, body) => {
-      var realBody = body
-      var reqCont: Option[Tree] = None
-      var ensCont: Option[Function] = None
-
-      body match {
-        case EnsuredExpression(body2, contract) => realBody = body2; ensCont = Some(contract)
-        case _ => ;
-      }
-
-      realBody match {
-        case RequiredExpression(body3, contract) => realBody = body3; reqCont = Some(contract)
-        case _ => ;
-      }
-
-      println("In: " + name) 
-      println("  Requires clause: " + reqCont)
-      println("  Ensures  clause: " + ensCont)
-      println("  Body:            " + realBody)
-    }
-
-    case _ => ;
-  }
-
-  def extractCode(unit: CompilationUnit): Unit = { 
+  def extractCode(unit: CompilationUnit): Program = { 
     def trav(tree: Tree): Unit = tree match {
       case d @ DefDef(mods, name, tparams, vparamss, tpt, body) if !d.symbol.isConstructor => {
         println("In: " + name)
@@ -55,8 +30,32 @@ trait CodeExtraction {
       case _ => ;
     }
 
-    (new ForeachTreeTraverser(trav)).traverse(unit.body)
+    // (new ForeachTreeTraverser(trav)).traverse(unit.body)
+
+    val program = unit.body match {
+      case p @ PackageDef(name, lst) if lst.size == 0 => {
+        unit.error(p.pos, "No top-level definition found.")
+        None
+      }
+
+      case PackageDef(name, lst) if lst.size > 1 => {
+        unit.error(lst(1).pos, "Too many top-level definitions.")
+        None
+      }
+
+      case PackageDef(name, lst) => {
+        assert(lst.size == 1)
+        lst(0) match {
+          case ExObjectDef(n, templ) => Some(Program(name.toString, ObjectDef(n.toString, Nil, Nil)))
+          case other @ _ => unit.error(other.pos, "Expected: top-level single object.")
+          None
+        }
+      }
+    }
+
+    stopIfErrors
 
+    program.get
   }
 
   /** An exception thrown when non-purescala compatible code is encountered. */
@@ -88,4 +87,30 @@ trait CodeExtraction {
       }
     }
   }
+
+//  def findContracts(tree: Tree): Unit = tree match {
+//    case DefDef(/*mods*/ _, name, /*tparams*/ _, /*vparamss*/ _, /*tpt*/ _, body) => {
+//      var realBody = body
+//      var reqCont: Option[Tree] = None
+//      var ensCont: Option[Function] = None
+//
+//      body match {
+//        case EnsuredExpression(body2, contract) => realBody = body2; ensCont = Some(contract)
+//        case _ => ;
+//      }
+//
+//      realBody match {
+//        case RequiredExpression(body3, contract) => realBody = body3; reqCont = Some(contract)
+//        case _ => ;
+//      }
+//
+//      println("In: " + name) 
+//      println("  Requires clause: " + reqCont)
+//      println("  Ensures  clause: " + ensCont)
+//      println("  Body:            " + realBody)
+//    }
+//
+//    case _ => ;
+//  }
+
 }
diff --git a/src/funcheck/Extractors.scala b/src/funcheck/Extractors.scala
index 838442b8a..cd44ac23c 100644
--- a/src/funcheck/Extractors.scala
+++ b/src/funcheck/Extractors.scala
@@ -49,7 +49,7 @@ trait Extractors {
       }
     }
 
-    object ObjectDefn {
+    object ExObjectDef {
       /** Matches an object with no type parameters, and regardless of its
        * visibility. */
       def unapply(cd: ClassDef): Option[(String,Template)] = cd match {
@@ -58,7 +58,7 @@ trait Extractors {
       }
     }
 
-    object FunctionDefn {
+    object ExFunctionDef {
       /** Matches a function with a single list of arguments, no type
        * parameters and regardless of its visibility. */
       def unapply(dd: DefDef): Option[(String,Seq[ValDef],Tree,Tree)] = dd match {
diff --git a/src/funcheck/purescala/Definitions.scala b/src/funcheck/purescala/Definitions.scala
index f2ad9e0b4..3debcbc32 100644
--- a/src/funcheck/purescala/Definitions.scala
+++ b/src/funcheck/purescala/Definitions.scala
@@ -29,6 +29,20 @@ object Definitions {
   final case class VarDecl(id: Identifier, tpe: TypeTree)
   type VarDecls = Seq[VarDecl]
 
+  /** A wrapper for a program. For now a program is simply a single object. The
+   * name is meaningless and we can just the package name. */
+  final case class Program(id: Identifier, mainObject: ObjectDef) extends Definition {
+    override val parentScope = None
+
+    override def lookupObject(id: Identifier) = {
+      if(id == mainObject.id) {
+        Some(mainObject)
+      } else {
+        None
+      }
+    }
+  }
+
   /** Objects work as containers for class definitions, functions (def's) and
    * val's. */
   case class ObjectDef(id: Identifier, defs : Seq[Definition], invariants: Seq[Expr]) extends Definition
diff --git a/src/funcheck/purescala/PrettyPrinter.scala b/src/funcheck/purescala/PrettyPrinter.scala
index 9338f6b84..855da6bd1 100644
--- a/src/funcheck/purescala/PrettyPrinter.scala
+++ b/src/funcheck/purescala/PrettyPrinter.scala
@@ -27,6 +27,7 @@ object PrettyPrinter {
   }
 
   // EXPRESSIONS
+  // all expressions are printed in-line
   private def ppUnary(sb: StringBuffer, expr: Expr, op: String): StringBuffer = {
     var nsb: StringBuffer = sb
     nsb.append(op)
@@ -84,12 +85,43 @@ object PrettyPrinter {
   }
 
   // TYPE TREES
+  // all type trees are printed in-line
   private def pp(tpe: TypeTree, sb: StringBuffer): StringBuffer = tpe match {
     case _ => sb.append("Type?")
   }
 
   // DEFINITIONS
-  private def pp(defn: Definition, sb: StringBuffer, lvl: Int): StringBuffer = defn match {
-    case _ => sb.append("Defn?")
+  // all definitions are printed with an end-of-line
+  private def pp(defn: Definition, sb: StringBuffer, lvl: Int): StringBuffer = {
+    def ind(sb: StringBuffer): Unit = { sb.append("  " * lvl) }
+
+    defn match {
+      case Program(id, mainObj) => {
+        assert(lvl == 0)
+        sb.append("package ")
+        sb.append(id)
+        sb.append(" {\n")
+        pp(mainObj, sb, lvl+1).append("}\n")
+      }
+
+      case ObjectDef(id, defs, invs) => {
+        var nsb = sb
+        ind(nsb)
+        nsb.append("object ")
+        nsb.append(id)
+        nsb.append(" {\n")
+
+        val sz = defs.size
+        var c = 0
+
+        defs.foreach(df => {
+          nsb = pp(df, nsb, lvl+1) 
+        })
+
+        ind(nsb); nsb.append("}\n")
+      }
+
+      case _ => sb.append("Defn?")
+    }
   }
 }
-- 
GitLab