From 1c831e128d836996c32ac4fd004ea94b4357bfc4 Mon Sep 17 00:00:00 2001
From: Etienne Kneuss <colder@php.net>
Date: Sat, 13 Jul 2013 19:17:49 +0200
Subject: [PATCH] Promote LetDef to purescala

---
 src/main/scala/leon/Main.scala                |   2 +-
 .../scala/leon/purescala/Extractors.scala     |  54 +++++++++
 .../FunctionClosure.scala                     |  19 ++-
 .../scala/leon/purescala/PrettyPrinter.scala  |  10 +-
 .../scala/leon/purescala/ScalaPrinter.scala   |  10 ++
 src/main/scala/leon/purescala/TreeOps.scala   |  35 +++++-
 src/main/scala/leon/purescala/Trees.scala     |   8 ++
 src/main/scala/leon/synthesis/CostModel.scala |   1 -
 src/main/scala/leon/synthesis/Solution.scala  |   4 +-
 src/main/scala/leon/xlang/TreeOps.scala       |  42 -------
 src/main/scala/leon/xlang/Trees.scala         | 112 +-----------------
 .../scala/leon/xlang/XlangAnalysisPhase.scala |   2 +-
 12 files changed, 126 insertions(+), 173 deletions(-)
 rename src/main/scala/leon/{xlang => purescala}/FunctionClosure.scala (96%)

diff --git a/src/main/scala/leon/Main.scala b/src/main/scala/leon/Main.scala
index bd9ee762e..6bec43bb3 100644
--- a/src/main/scala/leon/Main.scala
+++ b/src/main/scala/leon/Main.scala
@@ -11,7 +11,7 @@ object Main {
       xlang.ArrayTransformation,
       xlang.EpsilonElimination,
       xlang.ImperativeCodeElimination,
-      xlang.FunctionClosure,
+      purescala.FunctionClosure,
       xlang.XlangAnalysisPhase,
       synthesis.SynthesisPhase,
       termination.TerminationPhase,
diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala
index 5c2571719..02e0e7c29 100644
--- a/src/main/scala/leon/purescala/Extractors.scala
+++ b/src/main/scala/leon/purescala/Extractors.scala
@@ -128,6 +128,60 @@ object Extractors {
 
            MatchExpr(es(0), newcases)
            }))
+      case LetDef(fd, body) =>
+        fd.body match {
+          case Some(b) =>
+            (fd.precondition, fd.postcondition) match {
+              case (None, None) =>
+                  Some((Seq(b, body), (as: Seq[Expr]) => {
+                    fd.body = Some(as(0))
+                    LetDef(fd, as(1))
+                  }))
+              case (Some(pre), None) =>
+                  Some((Seq(b, body, pre), (as: Seq[Expr]) => {
+                    fd.body = Some(as(0))
+                    fd.precondition = Some(as(2))
+                    LetDef(fd, as(1))
+                  }))
+              case (None, Some(post)) =>
+                  Some((Seq(b, body, post), (as: Seq[Expr]) => {
+                    fd.body = Some(as(0))
+                    fd.postcondition = Some(as(2))
+                    LetDef(fd, as(1))
+                  }))
+              case (Some(pre), Some(post)) =>
+                  Some((Seq(b, body, pre, post), (as: Seq[Expr]) => {
+                    fd.body = Some(as(0))
+                    fd.precondition = Some(as(2))
+                    fd.postcondition = Some(as(3))
+                    LetDef(fd, as(1))
+                  }))
+            }
+
+          case None => //case no body, we still need to handle remaining cases
+            (fd.precondition, fd.postcondition) match {
+              case (None, None) =>
+                  Some((Seq(body), (as: Seq[Expr]) => {
+                    LetDef(fd, as(0))
+                  }))
+              case (Some(pre), None) =>
+                  Some((Seq(body, pre), (as: Seq[Expr]) => {
+                    fd.precondition = Some(as(1))
+                    LetDef(fd, as(0))
+                  }))
+              case (None, Some(post)) =>
+                  Some((Seq(body, post), (as: Seq[Expr]) => {
+                    fd.postcondition = Some(as(1))
+                    LetDef(fd, as(0))
+                  }))
+              case (Some(pre), Some(post)) =>
+                  Some((Seq(body, pre, post), (as: Seq[Expr]) => {
+                    fd.precondition = Some(as(1))
+                    fd.postcondition = Some(as(2))
+                    LetDef(fd, as(0))
+                  }))
+            }
+        }
       case (ex: NAryExtractable) => ex.extract
       case _ => None
     }
diff --git a/src/main/scala/leon/xlang/FunctionClosure.scala b/src/main/scala/leon/purescala/FunctionClosure.scala
similarity index 96%
rename from src/main/scala/leon/xlang/FunctionClosure.scala
rename to src/main/scala/leon/purescala/FunctionClosure.scala
index 3b6a74394..e4d006e2e 100644
--- a/src/main/scala/leon/xlang/FunctionClosure.scala
+++ b/src/main/scala/leon/purescala/FunctionClosure.scala
@@ -1,17 +1,14 @@
 /* Copyright 2009-2013 EPFL, Lausanne */
 
 package leon
-package xlang
-
-import leon.TransformationPhase
-import leon.LeonContext
-import leon.purescala.Common._
-import leon.purescala.Definitions._
-import leon.purescala.Trees._
-import leon.purescala.Extractors._
-import leon.purescala.TreeOps._
-import leon.purescala.TypeTrees._
-import leon.xlang.Trees._
+package purescala
+
+import Common._
+import Definitions._
+import Trees._
+import Extractors._
+import TreeOps._
+import TypeTrees._
 
 object FunctionClosure extends TransformationPhase {
 
diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala
index d13fd83e5..147c4635b 100644
--- a/src/main/scala/leon/purescala/PrettyPrinter.scala
+++ b/src/main/scala/leon/purescala/PrettyPrinter.scala
@@ -57,7 +57,6 @@ class PrettyPrinter(sb: StringBuffer = new StringBuffer) {
     case Variable(id) => sb.append(idToString(id))
     case DeBruijnIndex(idx) => sb.append("_" + idx)
     case LetTuple(bs,d,e) =>
-        //pp(e, pp(d, sb.append("(let (" + b + " := "), lvl).append(") in "), lvl).append(")")
       sb.append("(let (" + bs.map(idToString _).mkString(",") + " := ");
       pp(d, lvl)
       sb.append(") in\n")
@@ -66,7 +65,6 @@ class PrettyPrinter(sb: StringBuffer = new StringBuffer) {
       sb.append(")")
 
     case Let(b,d,e) =>
-        //pp(e, pp(d, sb.append("(let (" + b + " := "), lvl).append(") in "), lvl).append(")")
       sb.append("(let (" + idToString(b) + " := ");
       pp(d, lvl)
       sb.append(") in\n")
@@ -74,6 +72,14 @@ class PrettyPrinter(sb: StringBuffer = new StringBuffer) {
       pp(e, lvl+1)
       sb.append(")")
 
+    case LetDef(fd,body) =>
+      sb.append("\n")
+      pp(fd, lvl+1)
+      sb.append("\n")
+      sb.append("\n")
+      ind(lvl)
+      pp(body, lvl)
+
     case And(exprs) => ppNary(exprs, "(", " \u2227 ", ")", lvl)            // \land
     case Or(exprs) => ppNary(exprs, "(", " \u2228 ", ")", lvl)             // \lor
     case Not(Equals(l, r)) => ppBinary(l, r, " \u2260 ", lvl)    // \neq
diff --git a/src/main/scala/leon/purescala/ScalaPrinter.scala b/src/main/scala/leon/purescala/ScalaPrinter.scala
index c0dbb34cf..b1d82854d 100644
--- a/src/main/scala/leon/purescala/ScalaPrinter.scala
+++ b/src/main/scala/leon/purescala/ScalaPrinter.scala
@@ -55,11 +55,21 @@ class ScalaPrinter(sb: StringBuffer = new StringBuffer) extends PrettyPrinter(sb
       sb.append("}\n")
       ind(lvl)
 
+    case LetDef(fd, body) =>
+      sb.append("{\n")
+      pp(fd, lvl+1)
+      sb.append("\n")
+      sb.append("\n")
+      ind(lvl)
+      pp(body, lvl)
+      sb.append("}\n")
+
     case And(exprs) => ppNary(exprs, "(", " && ", ")", lvl)            // \land
     case Or(exprs) => ppNary(exprs, "(", " || ", ")", lvl)             // \lor
     case Not(Equals(l, r)) => ppBinary(l, r, " != ", lvl)    // \neq
     case UMinus(expr) => ppUnary(expr, "-(", ")", lvl)
     case Equals(l,r) => ppBinary(l, r, " == ", lvl)
+
     case IntLiteral(v) => sb.append(v)
     case BooleanLiteral(v) => sb.append(v)
     case StringLiteral(s) => sb.append("\"" + s + "\"")
diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala
index 77ded68f1..45a5ba008 100644
--- a/src/main/scala/leon/purescala/TreeOps.scala
+++ b/src/main/scala/leon/purescala/TreeOps.scala
@@ -11,7 +11,6 @@ object TreeOps {
   import Common._
   import TypeTrees._
   import Definitions._
-  import xlang.Trees.LetDef
   import Trees._
   import Extractors._
 
@@ -1441,6 +1440,27 @@ object TreeOps {
         val sb = rec(b, scope.register(i -> si))
         Let(si, se, sb)
 
+      case LetDef(fd: FunDef, body: Expr) =>
+        val newId    = genId(fd.id, scope)
+        var newScope = scope.register(fd.id -> newId)
+
+        val newArgs = for(VarDecl(id, tpe) <- fd.args) yield {
+          val newArg = genId(id, newScope)
+          newScope = newScope.register(id -> newArg)
+          VarDecl(newArg, tpe)
+        }
+
+        val newFd = new FunDef(newId, fd.returnType, newArgs)
+
+        newScope = newScope.registerFunDef(fd -> newFd)
+
+        newFd.body          = fd.body.map(b => rec(b, newScope))
+        newFd.precondition  = fd.precondition.map(pre => rec(pre, newScope))
+        newFd.postcondition = fd.postcondition.map(post => rec(post, newScope))
+
+
+        LetDef(newFd, rec(body, newScope))
+
       case LetTuple(is, e, b) =>
         var newScope = scope
         val sis = for (i <- is) yield {
@@ -1802,4 +1822,17 @@ object TreeOps {
       false
   }
 
+  def containsLetDef(expr: Expr): Boolean = {
+    def convert(t : Expr) : Boolean = t match {
+      case (l : LetDef) => true
+      case _ => false
+    }
+    def combine(c1 : Boolean, c2 : Boolean) : Boolean = c1 || c2
+    def compute(t : Expr, c : Boolean) = t match {
+      case (l : LetDef) => true
+      case _ => c
+    }
+    treeCatamorphism(convert, combine, compute, expr)
+  }
+
 }
diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala
index b0a7d3220..f4f18b1b7 100644
--- a/src/main/scala/leon/purescala/Trees.scala
+++ b/src/main/scala/leon/purescala/Trees.scala
@@ -51,6 +51,14 @@ object Trees {
     val fixedType = body.getType
   }
 
+  case class LetDef(fd: FunDef, body: Expr) extends Expr {
+    val et = body.getType
+    if(et != Untyped)
+      setType(et)
+
+  }
+
+
   /* Control flow */
   case class FunctionInvocation(funDef: FunDef, args: Seq[Expr]) extends Expr with FixedType with ScalacPositional {
     val fixedType = funDef.returnType
diff --git a/src/main/scala/leon/synthesis/CostModel.scala b/src/main/scala/leon/synthesis/CostModel.scala
index 128529c25..b4e42a0b1 100644
--- a/src/main/scala/leon/synthesis/CostModel.scala
+++ b/src/main/scala/leon/synthesis/CostModel.scala
@@ -5,7 +5,6 @@ package synthesis
 
 import purescala.Trees._
 import purescala.TreeOps._
-import leon.xlang.Trees.LetDef
 
 import synthesis.search.Cost
 
diff --git a/src/main/scala/leon/synthesis/Solution.scala b/src/main/scala/leon/synthesis/Solution.scala
index 61a60a2d8..e6a5620cf 100644
--- a/src/main/scala/leon/synthesis/Solution.scala
+++ b/src/main/scala/leon/synthesis/Solution.scala
@@ -7,8 +7,6 @@ import purescala.Trees._
 import purescala.TypeTrees.{TypeTree,TupleType}
 import purescala.Definitions._
 import purescala.TreeOps._
-import xlang.Trees.LetDef
-import xlang.TreeOps.{ScopeSimplifier => XLangScopeSimplifier}
 import solvers.z3.UninterpretedZ3Solver
 
 // Defines a synthesis solution of the form:
@@ -45,7 +43,7 @@ class Solution(val pre: Expr, val defs: Set[FunDef], val term: Expr) {
       simplifyTautologies(uninterpretedZ3)(_),
       simplifyLets _,
       rewriteTuples _,
-      (new ScopeSimplifier with XLangScopeSimplifier).transform _
+      (new ScopeSimplifier).transform _
     )
 
     simplifiers.foldLeft(toExpr){ (x, sim) => sim(x) }
diff --git a/src/main/scala/leon/xlang/TreeOps.scala b/src/main/scala/leon/xlang/TreeOps.scala
index 0e8c362aa..4c01aa639 100644
--- a/src/main/scala/leon/xlang/TreeOps.scala
+++ b/src/main/scala/leon/xlang/TreeOps.scala
@@ -76,46 +76,4 @@ object TreeOps {
     }
     searchAndReplaceDFS(applyToTree)(expr)
   }
-
-  def containsLetDef(expr: Expr): Boolean = {
-    def convert(t : Expr) : Boolean = t match {
-      case (l : LetDef) => true
-      case _ => false
-    }
-    def combine(c1 : Boolean, c2 : Boolean) : Boolean = c1 || c2
-    def compute(t : Expr, c : Boolean) = t match {
-      case (l : LetDef) => true
-      case _ => c
-    }
-    treeCatamorphism(convert, combine, compute, expr)
-  }
-
-  trait ScopeSimplifier extends purescala.TreeOps.ScopeSimplifier {
-    override def rec(e: Expr, scope: Scope) = e match { 
-      case LetDef(fd: FunDef, body: Expr) =>
-        val newId    = genId(fd.id, scope)
-        var newScope = scope.register(fd.id -> newId)
-
-        val newArgs = for(VarDecl(id, tpe) <- fd.args) yield {
-          val newArg = genId(id, newScope)
-          newScope = newScope.register(id -> newArg)
-          VarDecl(newArg, tpe)
-        }
-
-        val newFd = new FunDef(newId, fd.returnType, newArgs)
-
-        newScope = newScope.registerFunDef(fd -> newFd)
-
-        newFd.body          = fd.body.map(b => rec(b, newScope))
-        newFd.precondition  = fd.precondition.map(pre => rec(pre, newScope))
-        newFd.postcondition = fd.postcondition.map(post => rec(post, newScope))
-
-
-        LetDef(newFd, rec(body, newScope))
-
-      case _ =>
-        super.rec(e, scope)
-    }
-  }
-
 }
diff --git a/src/main/scala/leon/xlang/Trees.scala b/src/main/scala/leon/xlang/Trees.scala
index 22cb9e639..866cfa5d9 100644
--- a/src/main/scala/leon/xlang/Trees.scala
+++ b/src/main/scala/leon/xlang/Trees.scala
@@ -152,117 +152,7 @@ object Trees {
     }
   }
 
-  case class LetDef(fd: FunDef, body: Expr) extends Expr with NAryExtractable with PrettyPrintable {
-    val et = body.getType
-    if(et != Untyped)
-      setType(et)
-
-    def extract: Option[(Seq[Expr], (Seq[Expr])=>Expr)] = {
-      fd.body match {
-        case Some(b) =>
-          (fd.precondition, fd.postcondition) match {
-            case (None, None) =>
-                Some((Seq(b, body), (as: Seq[Expr]) => {
-                  //val nfd = new FunDef(fd.id, fd.returnType, fd.args)
-                  //nfd.body = Some(as(0))
-                  //LetDef(nfd, as(1))
-                  fd.body = Some(as(0))
-                  LetDef(fd, as(1))
-                }))
-            case (Some(pre), None) =>
-                Some((Seq(b, body, pre), (as: Seq[Expr]) => {
-                  //val nfd = new FunDef(fd.id, fd.returnType, fd.args)
-                  //nfd.body = Some(as(0))
-                  //nfd.precondition = Some(as(2))
-                  //LetDef(nfd, as(1))
-                  fd.body = Some(as(0))
-                  fd.precondition = Some(as(2))
-                  LetDef(fd, as(1))
-                }))
-            case (None, Some(post)) =>
-                Some((Seq(b, body, post), (as: Seq[Expr]) => {
-                  //val nfd = new FunDef(fd.id, fd.returnType, fd.args)
-                  //nfd.body = Some(as(0))
-                  //nfd.postcondition = Some(as(2))
-                  //LetDef(nfd, as(1))
-                  fd.body = Some(as(0))
-                  fd.postcondition = Some(as(2))
-                  LetDef(fd, as(1))
-                }))
-            case (Some(pre), Some(post)) =>
-                Some((Seq(b, body, pre, post), (as: Seq[Expr]) => {
-                  //val nfd = new FunDef(fd.id, fd.returnType, fd.args)
-                  //nfd.body = Some(as(0))
-                  //nfd.precondition = Some(as(2))
-                  //nfd.postcondition = Some(as(3))
-                  //LetDef(nfd, as(1))
-                  fd.body = Some(as(0))
-                  fd.precondition = Some(as(2))
-                  fd.postcondition = Some(as(3))
-                  LetDef(fd, as(1))
-                }))
-          }
-            
-        case None => //case no body, we still need to handle remaining cases
-          (fd.precondition, fd.postcondition) match {
-            case (None, None) =>
-                Some((Seq(body), (as: Seq[Expr]) => {
-                  //val nfd = new FunDef(fd.id, fd.returnType, fd.args)
-                  //LetDef(nfd, as(0))
-                  LetDef(fd, as(0))
-                }))
-            case (Some(pre), None) =>
-                Some((Seq(body, pre), (as: Seq[Expr]) => {
-                  //val nfd = new FunDef(fd.id, fd.returnType, fd.args)
-                  //nfd.precondition = Some(as(1))
-                  //LetDef(nfd, as(0))
-                  fd.precondition = Some(as(1))
-                  LetDef(fd, as(0))
-                }))
-            case (None, Some(post)) =>
-                Some((Seq(body, post), (as: Seq[Expr]) => {
-                  //val nfd = new FunDef(fd.id, fd.returnType, fd.args)
-                  //nfd.postcondition = Some(as(1))
-                  //LetDef(nfd, as(0))
-                  fd.postcondition = Some(as(1))
-                  LetDef(fd, as(0))
-                }))
-            case (Some(pre), Some(post)) =>
-                Some((Seq(body, pre, post), (as: Seq[Expr]) => {
-                  //val nfd = new FunDef(fd.id, fd.returnType, fd.args)
-                  //nfd.precondition = Some(as(1))
-                  //nfd.postcondition = Some(as(2))
-                  //LetDef(nfd, as(0))
-                  fd.precondition = Some(as(1))
-                  fd.postcondition = Some(as(2))
-                  LetDef(fd, as(0))
-                }))
-          }
-      }
-    }
-
-    def printWith(lvl: Int, printer: PrettyPrinter) {
-      printer match {
-        case _: ScalaPrinter =>
-          printer.append("{\n")
-          printer.pp(fd, lvl+1)
-          printer.append("\n")
-          printer.append("\n")
-          printer.ind(lvl)
-          printer.pp(body, lvl)
-          printer.append("}\n")
-        case _ =>
-          printer.append("\n")
-          printer.pp(fd, lvl+1)
-          printer.append("\n")
-          printer.append("\n")
-          printer.ind(lvl)
-          printer.pp(body, lvl)
-      }
-    }
-  }
-
-  case class Waypoint(i: Int, expr: Expr) extends Expr with UnaryExtractable with PrettyPrintable{
+  case class Waypoint(i: Int, expr: Expr) extends Expr with UnaryExtractable with PrettyPrintable {
     def extract: Option[(Expr, (Expr)=>Expr)] = {
       Some((expr, (e: Expr) => Waypoint(i, e)))
     }
diff --git a/src/main/scala/leon/xlang/XlangAnalysisPhase.scala b/src/main/scala/leon/xlang/XlangAnalysisPhase.scala
index 4fb3c98bc..3ec72ee09 100644
--- a/src/main/scala/leon/xlang/XlangAnalysisPhase.scala
+++ b/src/main/scala/leon/xlang/XlangAnalysisPhase.scala
@@ -16,7 +16,7 @@ object XlangAnalysisPhase extends LeonPhase[Program, VerificationReport] {
     val pgm1 = ArrayTransformation(ctx, pgm)
     val pgm2 = EpsilonElimination(ctx, pgm1)
     val (pgm3, wasLoop) = ImperativeCodeElimination.run(ctx)(pgm2)
-    val pgm4 = FunctionClosure.run(ctx)(pgm3)
+    val pgm4 = purescala.FunctionClosure.run(ctx)(pgm3)
 
     def functionWasLoop(fd: FunDef): Boolean = fd.orig match {
       case None => false //meaning, this was a top level function
-- 
GitLab