From 88241dd79be337f2f9933a378e7e3147d655215b Mon Sep 17 00:00:00 2001
From: Philippe Suter <philippe.suter@gmail.com>
Date: Wed, 28 Apr 2010 22:46:19 +0000
Subject: [PATCH] now extracting some contracts. This is getting somewhere
 faster than I thought.

---
 src/funcheck/CodeExtraction.scala          | 67 +++++++++++++---
 src/funcheck/Extractors.scala              | 88 ++++++++++++++++++++--
 src/funcheck/purescala/PrettyPrinter.scala | 33 +++++++-
 src/funcheck/purescala/Trees.scala         | 28 +++----
 4 files changed, 182 insertions(+), 34 deletions(-)

diff --git a/src/funcheck/CodeExtraction.scala b/src/funcheck/CodeExtraction.scala
index 29a235101..aee9044bb 100644
--- a/src/funcheck/CodeExtraction.scala
+++ b/src/funcheck/CodeExtraction.scala
@@ -16,17 +16,21 @@ trait CodeExtraction extends Extractors {
   import StructuralExtractors._
   import ExpressionExtractors._
 
+  private val varSubsts: scala.collection.mutable.Map[Identifier,Function0[Expr]] = scala.collection.mutable.Map.empty[Identifier,Function0[Expr]]
+
   def extractCode(unit: CompilationUnit): Program = { 
     import scala.collection.mutable.HashMap
 
-    // register where the symbols where extracted from
-    // val symbolDefMap = new HashMap[purescala.Symbols.Symbol,Tree]
-
     def s2ps(tree: Tree): Expr = toPureScala(unit)(tree) match {
       case Some(ex) => ex
       case None => stopIfErrors; scala.Predef.error("unreachable error.")
     }
 
+    def st2ps(tree: Tree): funcheck.purescala.TypeTrees.TypeTree = toPureScalaType(unit)(tree) match {
+      case Some(tt) => tt
+      case None => stopIfErrors; scala.Predef.error("unreachable error.")
+    }
+
     def extractTopLevelDef: ObjectDef = {
       val top = unit.body match {
         case p @ PackageDef(name, lst) if lst.size == 0 => {
@@ -63,7 +67,7 @@ trait CodeExtraction extends Extractors {
       var funDefs: List[FunDef] = Nil
 
       tmpl.body.foreach(tree => {
-        println("[[[ " + tree + "]]]\n");
+        //println("[[[ " + tree + "]]]\n");
         tree match {
         case ExObjectDef(o2, t2) => { objectDefs = extractObjectDef(o2, t2) :: objectDefs }
         case ExAbstractClass(o2) => ;
@@ -82,7 +86,30 @@ trait CodeExtraction extends Extractors {
     }
 
     def extractFunDef(name: Identifier, params: Seq[ValDef], tpt: Tree, body: Tree) = {
-      FunDef(name, scalaType2PureScala(unit, false)(tpt), Nil, null, None, None)
+      var realBody = body
+      var reqCont: Option[Expr] = None
+      var ensCont: Option[Expr] = None
+
+      realBody match {
+        case ExEnsuredExpression(body2, resId, contract) => {
+          varSubsts(resId) = (() => ResultVariable())
+          val c1 = s2ps(contract)
+          varSubsts.remove(resId)
+          realBody = body2
+          ensCont = Some(c1)
+        }
+        case _ => ;
+      }
+
+      realBody match {
+        case ExRequiredExpression(body3, contract) => {
+          realBody = body3
+          reqCont = Some(s2ps(contract))
+        }
+        case _ => ;
+      }
+      
+      FunDef(name, st2ps(tpt), Nil, s2ps(realBody), reqCont, ensCont)
     }
 
     // THE EXTRACTION CODE STARTS HERE
@@ -95,10 +122,6 @@ trait CodeExtraction extends Extractors {
       case _ => "<program>"
     }
 
-    println("Top level sym:")
-    println(topLevelObjDef)
-
-
     //Program(programName, ObjectDef("Object", Nil, Nil))
     Program(programName, topLevelObjDef)
   }
@@ -115,22 +138,44 @@ trait CodeExtraction extends Extractors {
     }
   }
 
+  def toPureScalaType(unit: CompilationUnit)(typeTree: Tree): Option[funcheck.purescala.TypeTrees.TypeTree] = {
+    try {
+      Some(scalaType2PureScala(unit, false)(typeTree))
+    } catch {
+      case ImpureCodeEncounteredException(_) => None
+    }
+  }
+
   /** Forces conversion from scalac AST to purescala AST, throws an Exception
    * if impossible. If not in 'silent mode', non-pure AST nodes are reported as
    * errors. */
   private def scala2PureScala(unit: CompilationUnit, silent: Boolean)(tree: Tree): Expr = {
-    tree match {
+    def rec(tr: Tree): Expr = tr match {
       case ExInt32Literal(v) => IntLiteral(v)
       case ExBooleanLiteral(v) => BooleanLiteral(v)
-
+      case ExIntIdentifier(id) => varSubsts.get(id) match {
+        case Some(fun) => fun()
+        case None => Variable(id)
+      }
+      case ExAnd(l, r) => And(rec(l), rec(r))
+      case ExPlus(l, r) => Plus(rec(l), rec(r))
+      case ExEquals(l, r) => Equals(rec(l), rec(r))
+      case ExGreaterThan(l, r) => GreaterThan(rec(l), rec(r))
+      case ExGreaterEqThan(l, r) => GreaterEquals(rec(l), rec(r))
+      case ExLessThan(l, r) => LessThan(rec(l), rec(r))
+      case ExLessEqThan(l, r) => LessEquals(rec(l), rec(r))
+  
       // default behaviour is to complain :)
       case _ => {
         if(!silent) {
+          println(tr)
           unit.error(tree.pos, "Could not extract as PureScala.")
         }
         throw ImpureCodeEncounteredException(tree)
       }
     }
+
+    rec(tree)
   }
 
   private def scalaType2PureScala(unit: CompilationUnit, silent: Boolean)(tree: Tree): funcheck.purescala.TypeTrees.TypeTree = {
diff --git a/src/funcheck/Extractors.scala b/src/funcheck/Extractors.scala
index 9ce171cdb..0e0d200dc 100644
--- a/src/funcheck/Extractors.scala
+++ b/src/funcheck/Extractors.scala
@@ -21,9 +21,9 @@ trait Extractors {
       }
     }
 
-    object EnsuredExpression {
+    object ExEnsuredExpression {
       /** Extracts the 'ensuring' contract from an expression. */
-      def unapply(tree: Tree): Option[(Tree,Function)] = tree match {
+      def unapply(tree: Tree): Option[(Tree,String,Tree)] = tree match {
         case Apply(
           Select(
             Apply(
@@ -34,12 +34,12 @@ trait Extractors {
             ensuringName),
           (anonymousFun @ Function(ValDef(_, resultName, resultType, EmptyTree) :: Nil,
             contractBody)) :: Nil)
-          if("ensuring".equals(ensuringName.toString)) => Some((body,anonymousFun))
+          if("ensuring".equals(ensuringName.toString)) => Some((body, resultName.toString, contractBody))
         case _ => None
       }
     }
 
-    object RequiredExpression {
+    object ExRequiredExpression {
       /** Extracts the 'require' contract from an expression (only if it's the
        * first call in the block). */
       def unapply(tree: Tree): Option[(Tree,Tree)] = tree match {
@@ -88,7 +88,6 @@ trait Extractors {
     object ExMainFunctionDef {
       def unapply(dd: DefDef): Boolean = dd match {
         case DefDef(_, name, tparams, vparamss, tpt, rhs) if(name.toString == "main" && tparams.isEmpty && vparamss.size == 1 && vparamss(0).size == 1) => {
-          println("Looks like main " + vparamss(0)(0).symbol.tpe);
           true
         }
         case _ => false
@@ -120,6 +119,85 @@ trait Extractors {
         case _ => None
       }
     }
+
+    object ExIntIdentifier {
+      def unapply(tree: Tree): Option[String] = tree match {
+        case i: Ident if i.symbol.tpe == IntClass.tpe => Some(i.symbol.name.toString)
+        case _ => None
+      }
+    }
+
+    object ExAnd {
+      def unapply(tree: Tree): Option[(Tree,Tree)] = tree match {
+        case Apply(s @ Select(lhs, _), List(rhs)) if (s.symbol == Boolean_and) =>
+          Some((lhs,rhs))
+        case _ => None
+      }
+    }
+  
+    object ExOr {
+      def unapply(tree: Tree): Option[(Tree,Tree)] = tree match {
+        case Apply(s @ Select(lhs, _), List(rhs)) if (s.symbol == Boolean_or) =>
+          Some((lhs,rhs))
+        case _ => None
+      }
+    }
+  
+    object ExNot {
+      def unapply(tree: Tree): Option[Tree] = tree match {
+        case Select(t, n) if (n == nme.UNARY_!) => Some(t)
+        case _ => None
+      }
+    }
+  
+    object ExEquals {
+      def unapply(tree: Tree): Option[(Tree,Tree)] = tree match {
+        case Apply(Select(lhs, n), List(rhs)) if (n == nme.EQ) => Some((lhs,rhs))
+        case _ => None
+      }
+    }
+  
+    object ExNotEquals {
+      def unapply(tree: Tree): Option[(Tree,Tree)] = tree match {
+        case Apply(Select(lhs, n), List(rhs)) if (n == nme.NE) => Some((lhs,rhs))
+        case _ => None
+      }
+    }
+  
+    object ExLessThan {
+      def unapply(tree: Tree): Option[(Tree,Tree)] = tree match {
+        case Apply(Select(lhs, n), List(rhs)) if (n == nme.LT) => Some((lhs,rhs))
+        case _ => None
+      }
+    }
+  
+    object ExLessEqThan {
+      def unapply(tree: Tree): Option[(Tree,Tree)] = tree match {
+        case Apply(Select(lhs, n), List(rhs)) if (n == nme.LE) => Some((lhs,rhs))
+        case _ => None
+      }
+    }
+  
+    object ExGreaterThan {
+      def unapply(tree: Tree): Option[(Tree,Tree)] = tree match {
+        case Apply(Select(lhs, n), List(rhs)) if (n == nme.GT) => Some((lhs,rhs))
+        case _ => None
+      }
+    }
+  
+    object ExGreaterEqThan {
+      def unapply(tree: Tree): Option[(Tree,Tree)] = tree match {
+        case Apply(Select(lhs, n), List(rhs)) if (n == nme.GE) => Some((lhs,rhs))
+        case _ => None
+      }
+    }
+
+    object ExPlus {
+      def unapply(tree: Tree): Option[(Tree,Tree)] = tree match {
+        case Apply(Select(lhs, n), List(rhs)) if (n == nme.ADD) => Some((lhs,rhs))
+        case _ => None
+      }
+    }
   }
 
   object TypeExtractors {
diff --git a/src/funcheck/purescala/PrettyPrinter.scala b/src/funcheck/purescala/PrettyPrinter.scala
index bf1f9f303..35391e4ed 100644
--- a/src/funcheck/purescala/PrettyPrinter.scala
+++ b/src/funcheck/purescala/PrettyPrinter.scala
@@ -62,6 +62,7 @@ object PrettyPrinter {
   }
 
   private def pp(tree: Expr, sb: StringBuffer): StringBuffer = tree match {
+    case Variable(id) => sb.append(id)
     case And(exprs) => ppNary(sb, exprs, " \u2227 ")            // \land
     case Or(exprs) => ppNary(sb, exprs, " \u2228 ")             // \lor
     case Not(Equals(l, r)) => ppBinary(sb, l, r, " \u2260 ")    // \neq
@@ -69,6 +70,15 @@ object PrettyPrinter {
     case Equals(l,r) => ppBinary(sb, l, r, " = ")
     case IntLiteral(v) => sb.append(v)
     case BooleanLiteral(v) => sb.append(v)
+    case Plus(l,r) => ppBinary(sb, l, r, " + ")
+    case Minus(l,r) => ppBinary(sb, l, r, " - ")
+    case Times(l,r) => ppBinary(sb, l, r, " * ")
+    case Division(l,r) => ppBinary(sb, l, r, " / ")
+    case LessThan(l,r) => ppBinary(sb, l, r, " < ")
+    case GreaterThan(l,r) => ppBinary(sb, l, r, " > ")
+    case LessEquals(l,r) => ppBinary(sb, l, r, " \u2264 ")      // \leq
+    case GreaterEquals(l,r) => ppBinary(sb, l, r, " \u2265 ")   // \geq
+    
     case IfExpr(c, t, e) => {
       var nsb = sb
       nsb.append("if (")
@@ -81,6 +91,8 @@ object PrettyPrinter {
       nsb
     }
 
+    case ResultVariable() => sb.append("<res>")
+
     case _ => sb.append("Expr?")
   }
 
@@ -94,7 +106,9 @@ object PrettyPrinter {
   // DEFINITIONS
   // all definitions are printed with an end-of-line
   private def pp(defn: Definition, sb: StringBuffer, lvl: Int): StringBuffer = {
-    def ind(sb: StringBuffer): Unit = { sb.append("  " * lvl) }
+    def ind(sb: StringBuffer, customLevel: Int = lvl) : Unit = {
+        sb.append("  " * customLevel)
+    }
 
     defn match {
       case Program(id, mainObj) => {
@@ -124,6 +138,21 @@ object PrettyPrinter {
 
       case FunDef(id, rt, args, body, pre, post) => {
         var nsb = sb
+
+        pre.foreach(prec => {
+          ind(nsb)
+          nsb.append("@pre : ")
+          nsb = pp(prec, nsb)
+          nsb.append("\n")
+        })
+
+        post.foreach(postc => {
+          ind(nsb)
+          nsb.append("@post: ")
+          nsb = pp(postc, nsb)
+          nsb.append("\n")
+        })
+
         ind(nsb)
         nsb.append("def ")
         nsb.append(id)
@@ -137,7 +166,7 @@ object PrettyPrinter {
         nsb = pp(rt, nsb)
         nsb.append(" = {\n")
 
-        ind(nsb)
+        ind(nsb, lvl+1)
         nsb = pp(body, nsb)
         nsb.append("\n")
 
diff --git a/src/funcheck/purescala/Trees.scala b/src/funcheck/purescala/Trees.scala
index f995ded93..1a14d5d4b 100644
--- a/src/funcheck/purescala/Trees.scala
+++ b/src/funcheck/purescala/Trees.scala
@@ -11,22 +11,6 @@ object Trees {
 
   sealed abstract class Expr extends Typed {
     override def toString: String = PrettyPrinter(this)
-
-    // private var _scope: Option[Scope] = None
-    //
-    // def scope: Scope =
-    //   if(_scope.isEmpty)
-    //     throw new Exception("Undefined scope.")
-    //   else
-    //     _scope.get
-
-    // def scope_=(s: Scope): Unit = {
-    //   if(_scope.isEmpty) {
-    //     _scope = Some(s)
-    //   } else {
-    //     throw new Exception("Redefining scope.")
-    //   }
-    // }
   }
 
   /* Control flow */
@@ -51,6 +35,15 @@ object Trees {
   // We don't handle Seq stars for now.
 
   /* Propositional logic */
+  case object And {
+    def apply(l: Expr, r: Expr): Expr = (l,r) match {
+      case (And(exs1), And(exs2)) => And(exs1 ++ exs2)
+      case (And(exs1), ex2) => And(exs1 :+ ex2)
+      case (ex1, And(exs2)) => And(exs2 :+ ex1)
+      case (ex1, ex2) => And(List(ex1, ex2))
+    }
+  }
+
   case class And(exprs: Seq[Expr]) extends Expr
   case class Or(exprs: Seq[Expr]) extends Expr
   case class Not(expr: Expr) extends Expr 
@@ -63,6 +56,9 @@ object Trees {
   // variable, which would also give us its type.
   case class Variable(id: Identifier) extends Expr
 
+  // represents the result in post-conditions
+  case class ResultVariable() extends Expr
+
   sealed abstract class Literal[T] extends Expr {
     val value: T
   }
-- 
GitLab