From 42e1a27f7c24fe9c68bf7fbdb9adaf1e379d1519 Mon Sep 17 00:00:00 2001
From: Philippe Suter <philippe.suter@gmail.com>
Date: Mon, 28 Jun 2010 23:36:15 +0000
Subject: [PATCH] added extraction of vals, as well as Let trees in PureScala
 expressions.

---
 src/funcheck/CodeExtraction.scala | 10 ++++++++++
 src/funcheck/Extractors.scala     | 12 ++++++++++++
 src/purescala/Analysis.scala      | 32 +++++++++++++++++++++----------
 src/purescala/PrettyPrinter.scala |  3 +++
 src/purescala/Trees.scala         |  4 ++--
 testcases/IntOperations.scala     |  4 +++-
 6 files changed, 52 insertions(+), 13 deletions(-)

diff --git a/src/funcheck/CodeExtraction.scala b/src/funcheck/CodeExtraction.scala
index 1e189924c..c8c887a13 100644
--- a/src/funcheck/CodeExtraction.scala
+++ b/src/funcheck/CodeExtraction.scala
@@ -313,6 +313,16 @@ trait CodeExtraction extends Extractors {
     }
 
     def rec(tr: Tree): Expr = tr match {
+      case ExValDef(vs, tpt, bdy, rst) => {
+        val binderTpe = scalaType2PureScala(unit, silent)(tpt.tpe)
+        val newID = FreshIdentifier(vs.name.toString).setType(binderTpe)
+        val oldSubsts = varSubsts
+        val valTree = rec(bdy)
+        varSubsts(vs) = (() => Variable(newID))
+        val restTree = rec(rst)
+        varSubsts.remove(vs)
+        Let(newID, valTree, restTree)
+      }
       case ExInt32Literal(v) => IntLiteral(v).setType(Int32Type)
       case ExBooleanLiteral(v) => BooleanLiteral(v).setType(BooleanType)
       case ExIdentifier(sym,tpt) => varSubsts.get(sym) match {
diff --git a/src/funcheck/Extractors.scala b/src/funcheck/Extractors.scala
index 2d108abdc..da04302b6 100644
--- a/src/funcheck/Extractors.scala
+++ b/src/funcheck/Extractors.scala
@@ -53,6 +53,18 @@ trait Extractors {
       }
     }
 
+    object ExValDef {
+      /** Extracts val's in the head of blocks. */
+      def unapply(tree: Block): Option[(Symbol,Tree,Tree,Tree)] = tree match {
+        case Block((vd @ ValDef(_, _, tpt, rhs)) :: rest, expr) => 
+          if(rest.isEmpty)
+            Some((vd.symbol, tpt, rhs, expr))
+          else
+            Some((vd.symbol, tpt, rhs, Block(rest, expr)))
+        case _ => None
+      }
+    }
+
     object ExObjectDef {
       /** Matches an object with no type parameters, and regardless of its
        * visibility. Does not match on the automatically generated companion
diff --git a/src/purescala/Analysis.scala b/src/purescala/Analysis.scala
index 9018e7565..20defdec3 100644
--- a/src/purescala/Analysis.scala
+++ b/src/purescala/Analysis.scala
@@ -41,7 +41,7 @@ class Analysis(val program: Program) {
                   reporter.info(vc)
 
                   if(Settings.runDefaultExtensions) {
-                    val (z3f,stupidMap) = toZ3Formula(z3, vc)
+                    val z3f = toZ3Formula(z3, vc)
                     z3.assertCnstr(z3.mkNot(z3f))
                     //z3.print
                     z3.checkAndGetModel() match {
@@ -133,17 +133,30 @@ class Analysis(val program: Program) {
         rec(expr)
     }
 
-    def toZ3Formula(z3: Z3Context, expr: Expr) : (Z3AST,Map[Identifier,Z3AST]) = {
-        val intSort = z3.mkIntSort()
-        var varMap: Map[Identifier,Z3AST] = Map.empty
+    def toZ3Formula(z3: Z3Context, expr: Expr) : (Z3AST) = {
+        lazy val intSort  = z3.mkIntSort()
+        lazy val boolSort = z3.mkBoolSort()
+
+        // because we create identifiers the first time we see them, this is
+        // convenient.
+        var z3Vars: Map[Identifier,Z3AST] = Map.empty
 
         def rec(ex: Expr) : Z3AST = ex match {
-            case v @ Variable(id) => varMap.get(id) match {
+            case Let(i,e,b) => {
+              z3Vars = z3Vars + (i -> rec(e))
+              rec(b)
+            }
+            case v @ Variable(id) => z3Vars.get(id) match {
                 case Some(ast) => ast
                 case None => {
-                    assert(v.getType == Int32Type)
-                    val newAST = z3.mkConst(z3.mkStringSymbol(id.name), intSort)
-                    varMap = varMap + (id -> newAST)
+                    val newAST = if(v.getType == Int32Type) {
+                      z3.mkConst(z3.mkStringSymbol(id.name), intSort)
+                    } else if(v.getType == BooleanType) {
+                      z3.mkConst(z3.mkStringSymbol(id.name), boolSort)
+                    } else {
+                      reporter.fatalError("Unsupported type in Z3 transformation: " + v.getType)
+                    }
+                    z3Vars = z3Vars + (id -> newAST)
                     newAST
                 }
             } 
@@ -168,7 +181,6 @@ class Analysis(val program: Program) {
             case _ => scala.Predef.error("Can't handle this in translation to Z3: " + ex)
         }
 
-        val res = rec(expr)
-        (res,varMap)
+        rec(expr)
     }
 }
diff --git a/src/purescala/PrettyPrinter.scala b/src/purescala/PrettyPrinter.scala
index ad018307c..4e312210a 100644
--- a/src/purescala/PrettyPrinter.scala
+++ b/src/purescala/PrettyPrinter.scala
@@ -67,6 +67,9 @@ object PrettyPrinter {
 
   private def pp(tree: Expr, sb: StringBuffer, lvl: Int): StringBuffer = tree match {
     case Variable(id) => sb.append(id)
+    case Let(b,d,e) => {
+        pp(e, pp(d, sb.append("(let (" + b + " = "), lvl).append(") in "), lvl).append(")")
+    }
     case And(exprs) => ppNary(sb, exprs, "(", " \u2227 ", ")", lvl)            // \land
     case Or(exprs) => ppNary(sb, exprs, "(", " \u2228 ", ")", lvl)             // \lor
     case Not(Equals(l, r)) => ppBinary(sb, l, r, " \u2260 ", lvl)    // \neq
diff --git a/src/purescala/Trees.scala b/src/purescala/Trees.scala
index c0a61240a..2f58fd3cb 100644
--- a/src/purescala/Trees.scala
+++ b/src/purescala/Trees.scala
@@ -13,8 +13,8 @@ object Trees {
   }
 
   /* Like vals */
-  case class Let(binder: Identifier, expression: Expr) extends Expr {
-    val et = expression.getType
+  case class Let(binder: Identifier, value: Expr, body: Expr) extends Expr {
+    val et = body.getType
     if(et != NoType)
       setType(et)
   }
diff --git a/testcases/IntOperations.scala b/testcases/IntOperations.scala
index d38896f80..98a757434 100644
--- a/testcases/IntOperations.scala
+++ b/testcases/IntOperations.scala
@@ -1,7 +1,9 @@
 object IntOperations {
     def sum(a: Int, b: Int) : Int = {
         require(b >= 0)
-        a + b
+        val b2 = b - 1
+        val b3 = b2 + 1
+        a + b3
     } ensuring(_ >= a)
 
 
-- 
GitLab