From 2d101d0438b84d2feb9c493075e772f721790906 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9gis=20Blanc?= <regwblanc@gmail.com>
Date: Tue, 11 Dec 2012 16:51:13 +0100
Subject: [PATCH] fix a bug in the nary extractor of letdef: need to apply it
 even without body

---
 src/main/scala/leon/xlang/Trees.scala | 301 +++++++++++++++++++++++---
 1 file changed, 274 insertions(+), 27 deletions(-)

diff --git a/src/main/scala/leon/xlang/Trees.scala b/src/main/scala/leon/xlang/Trees.scala
index e04fcbb7e..8995aba73 100644
--- a/src/main/scala/leon/xlang/Trees.scala
+++ b/src/main/scala/leon/xlang/Trees.scala
@@ -7,6 +7,7 @@ import leon.purescala.Trees._
 import leon.purescala.Definitions._
 import leon.purescala.Extractors._
 import leon.purescala.PrettyPrinter._
+import leon.purescala.ScalaPrinter._
 
 object Trees {
 
@@ -15,7 +16,7 @@ object Trees {
     sb
   }
 
-  case class Block(exprs: Seq[Expr], last: Expr) extends Expr with NAryExtractable with PrettyPrintable {
+  case class Block(exprs: Seq[Expr], last: Expr) extends Expr with NAryExtractable with PrettyPrintable with ScalaPrintable {
     //val t = last.getType
     //if(t != Untyped)
      // setType(t)
@@ -25,11 +26,15 @@ object Trees {
       Some((args :+ rest, exprs => Block(exprs.init, exprs.last)))
     }
 
-    def pp(sb: StringBuffer, lvl: Int, rp: (Expr, StringBuffer, Int) => StringBuffer): StringBuffer = {
+    def pp(sb: StringBuffer, lvl: Int, 
+      ep: (Expr, StringBuffer, Int) => StringBuffer, 
+      tp: (TypeTree, StringBuffer, Int) => StringBuffer,
+      dp: (Definition, StringBuffer, Int) => StringBuffer
+    ): StringBuffer = {
       sb.append("{\n")
       (exprs :+ last).foreach(e => {
         ind(sb, lvl+1)
-        rp(e, sb, lvl+1)
+        ep(e, sb, lvl+1)
         sb.append("\n")
       })
       ind(sb, lvl)
@@ -37,25 +42,58 @@ object Trees {
       sb
     }
 
+    def ppScala(sb: StringBuffer, lvl: Int, 
+      ep: (Expr, StringBuffer, Int) => Unit, 
+      tp: (TypeTree, StringBuffer, Int) => Unit,
+      dp: (Definition, StringBuffer, Int) => Unit
+    ): StringBuffer = {
+      sb.append("{\n")
+      (exprs :+ last).foreach(e => {
+        ind(sb, lvl+1)
+        ep(e, sb, lvl+1)
+        sb.append("\n")
+      })
+      ind(sb, lvl)
+      sb.append("}\n")
+      sb
+    }
   }
 
-  case class Assignment(varId: Identifier, expr: Expr) extends Expr with FixedType with UnaryExtractable with PrettyPrintable {
+  case class Assignment(varId: Identifier, expr: Expr) extends Expr with FixedType with UnaryExtractable with PrettyPrintable with ScalaPrintable {
     val fixedType = UnitType
 
     def extract: Option[(Expr, (Expr)=>Expr)] = {
       Some((expr, Assignment(varId, _)))
     }
-    def pp(sb: StringBuffer, lvl: Int, rp: (Expr, StringBuffer, Int) => StringBuffer): StringBuffer = {
+    def pp(sb: StringBuffer, lvl: Int, 
+      ep: (Expr, StringBuffer, Int) => StringBuffer, 
+      tp: (TypeTree, StringBuffer, Int) => StringBuffer,
+      dp: (Definition, StringBuffer, Int) => StringBuffer
+    ): StringBuffer = {
+      var nsb: StringBuffer = sb
+      nsb.append("(")
+      nsb.append(varId.name)
+      nsb.append(" = ")
+      nsb = ep(expr, nsb, lvl)
+      nsb.append(")")
+      nsb
+    }
+
+    def ppScala(sb: StringBuffer, lvl: Int, 
+      ep: (Expr, StringBuffer, Int) => Unit, 
+      tp: (TypeTree, StringBuffer, Int) => Unit,
+      dp: (Definition, StringBuffer, Int) => Unit
+    ): StringBuffer = {
       var nsb: StringBuffer = sb
       nsb.append("(")
       nsb.append(varId.name)
       nsb.append(" = ")
-      nsb = rp(expr, nsb, lvl)
+      ep(expr, nsb, lvl)
       nsb.append(")")
       nsb
     }
   }
-  case class While(cond: Expr, body: Expr) extends Expr with FixedType with ScalacPositional with BinaryExtractable with PrettyPrintable {
+  case class While(cond: Expr, body: Expr) extends Expr with FixedType with ScalacPositional with BinaryExtractable with PrettyPrintable with ScalaPrintable {
     val fixedType = UnitType
     var invariant: Option[Expr] = None
 
@@ -67,51 +105,105 @@ object Trees {
       Some((cond, body, (t1, t2) => While(t1, t2).setInvariant(this.invariant).setPosInfo(this)))
     }
 
-    def pp(sb: StringBuffer, lvl: Int, rp: (Expr, StringBuffer, Int) => StringBuffer): StringBuffer = {
+    def pp(sb: StringBuffer, lvl: Int, 
+      ep: (Expr, StringBuffer, Int) => StringBuffer, 
+      tp: (TypeTree, StringBuffer, Int) => StringBuffer,
+      dp: (Definition, StringBuffer, Int) => StringBuffer
+    ): StringBuffer = {
       invariant match {
         case Some(inv) => {
           sb.append("\n")
           ind(sb, lvl)
           sb.append("@invariant: ")
-          rp(inv, sb, lvl)
+          ep(inv, sb, lvl)
           sb.append("\n")
           ind(sb, lvl)
         }
         case None =>
       }
       sb.append("while(")
-      rp(cond, sb, lvl)
+      ep(cond, sb, lvl)
       sb.append(")\n")
       ind(sb, lvl+1)
-      rp(body, sb, lvl+1)
+      ep(body, sb, lvl+1)
+      sb.append("\n")
+    }
+
+    def ppScala(sb: StringBuffer, lvl: Int, 
+      ep: (Expr, StringBuffer, Int) => Unit, 
+      tp: (TypeTree, StringBuffer, Int) => Unit,
+      dp: (Definition, StringBuffer, Int) => Unit
+    ): StringBuffer = {
+      invariant match {
+        case Some(inv) => {
+          sb.append("\n")
+          ind(sb, lvl)
+          sb.append("@invariant: ")
+          ep(inv, sb, lvl)
+          sb.append("\n")
+          ind(sb, lvl)
+        }
+        case None =>
+      }
+      sb.append("while(")
+      ep(cond, sb, lvl)
+      sb.append(")\n")
+      ind(sb, lvl+1)
+      ep(body, sb, lvl+1)
       sb.append("\n")
     }
 
   }
 
-  case class Epsilon(pred: Expr) extends Expr with ScalacPositional with UnaryExtractable with PrettyPrintable {
+  case class Epsilon(pred: Expr) extends Expr with ScalacPositional with UnaryExtractable with PrettyPrintable with ScalaPrintable {
     def extract: Option[(Expr, (Expr)=>Expr)] = {
       Some((pred, (expr: Expr) => Epsilon(expr).setType(this.getType).setPosInfo(this)))
     }
 
-    def pp(sb: StringBuffer, lvl: Int, rp: (Expr, StringBuffer, Int) => StringBuffer): StringBuffer = {
+    def pp(sb: StringBuffer, lvl: Int, 
+      ep: (Expr, StringBuffer, Int) => StringBuffer, 
+      tp: (TypeTree, StringBuffer, Int) => StringBuffer,
+      dp: (Definition, StringBuffer, Int) => StringBuffer
+    ): StringBuffer = {
       var nsb = sb
       nsb.append("epsilon(x" + this.posIntInfo._1 + "_" + this.posIntInfo._2 + ". ")
-      nsb = rp(pred, nsb, lvl)
+      nsb = ep(pred, nsb, lvl)
       nsb.append(")")
       nsb
     }
 
+    def ppScala(sb: StringBuffer, lvl: Int, 
+      ep: (Expr, StringBuffer, Int) => Unit, 
+      tp: (TypeTree, StringBuffer, Int) => Unit,
+      dp: (Definition, StringBuffer, Int) => Unit
+    ): StringBuffer = {
+      sys.error("Not Scala Code")
+    }
+
   }
-  case class EpsilonVariable(pos: (Int, Int)) extends Expr with Terminal with PrettyPrintable {
-    def pp(sb: StringBuffer, lvl: Int, rp: (Expr, StringBuffer, Int) => StringBuffer): StringBuffer = {
+  case class EpsilonVariable(pos: (Int, Int)) extends Expr with Terminal with PrettyPrintable with ScalaPrintable {
+
+    def pp(sb: StringBuffer, lvl: Int, 
+      ep: (Expr, StringBuffer, Int) => StringBuffer, 
+      tp: (TypeTree, StringBuffer, Int) => StringBuffer,
+      dp: (Definition, StringBuffer, Int) => StringBuffer
+    ): StringBuffer = {
+      val (row, col) = pos
+      sb.append("x" + row + "_" + col)
+    }
+
+    def ppScala(sb: StringBuffer, lvl: Int, 
+      ep: (Expr, StringBuffer, Int) => Unit, 
+      tp: (TypeTree, StringBuffer, Int) => Unit,
+      dp: (Definition, StringBuffer, Int) => Unit
+    ): StringBuffer = {
       val (row, col) = pos
       sb.append("x" + row + "_" + col)
     }
   }
 
   //same as let, buf for mutable variable declaration
-  case class LetVar(binder: Identifier, value: Expr, body: Expr) extends Expr with BinaryExtractable with PrettyPrintable {
+  case class LetVar(binder: Identifier, value: Expr, body: Expr) extends Expr with BinaryExtractable with PrettyPrintable with ScalaPrintable {
     binder.markAsLetBinder
     val et = body.getType
     if(et != Untyped)
@@ -122,28 +214,178 @@ object Trees {
       Some((expr, body, (e: Expr, b: Expr) => LetVar(binders, e, b)))
     }
 
-    def pp(sb: StringBuffer, lvl: Int, rp: (Expr, StringBuffer, Int) => StringBuffer): StringBuffer = {
+    def pp(sb: StringBuffer, lvl: Int, 
+      ep: (Expr, StringBuffer, Int) => StringBuffer, 
+      tp: (TypeTree, StringBuffer, Int) => StringBuffer,
+      dp: (Definition, StringBuffer, Int) => StringBuffer
+    ): StringBuffer = {
       val LetVar(b,d,e) = this
       sb.append("(letvar (" + b + " := ");
-      rp(d, sb, lvl)
+      ep(d, sb, lvl)
       sb.append(") in\n")
       ind(sb, lvl+1)
-      rp(e, sb, lvl+1)
+      ep(e, sb, lvl+1)
       sb.append(")")
       sb
     }
+
+    def ppScala(sb: StringBuffer, lvl: Int, 
+      ep: (Expr, StringBuffer, Int) => Unit, 
+      tp: (TypeTree, StringBuffer, Int) => Unit,
+      dp: (Definition, StringBuffer, Int) => Unit
+    ): StringBuffer = {
+      val LetVar(b,d,e) = this
+      sb.append("locally {\n")
+      ind(sb, lvl+1)
+      sb.append("var " + b + " = ")
+      ep(d, sb, lvl+1)
+      sb.append("\n")
+      ind(sb, lvl+1)
+      ep(e, sb, lvl+1)
+      sb.append("\n")
+      ind(sb, lvl)
+      sb.append("}\n")
+      ind(sb, lvl)
+    }
+  }
+
+  case class LetDef(fd: FunDef, body: Expr) extends Expr with NAryExtractable with PrettyPrintable with ScalaPrintable {
+    val et = body.getType
+    if(et != Untyped)
+      setType(et)
+
+    def extract: Option[(Seq[Expr], (Seq[Expr])=>Expr)] = {
+      fd.body match {
+        case Some(b) =>
+          (fd.precondition, fd.postcondition) match {
+            case (None, None) =>
+                Some((Seq(b, body), (as: Seq[Expr]) => {
+                  val nfd = new FunDef(fd.id, fd.returnType, fd.args)
+                  nfd.fromLoop = fd.fromLoop
+                  nfd.parent = fd.parent
+                  nfd.body = Some(as(0))
+                  LetDef(nfd, as(1))
+                }))
+            case (Some(pre), None) =>
+                Some((Seq(b, body, pre), (as: Seq[Expr]) => {
+                  val nfd = new FunDef(fd.id, fd.returnType, fd.args)
+                  nfd.fromLoop = fd.fromLoop
+                  nfd.parent = fd.parent
+                  nfd.body = Some(as(0))
+                  nfd.precondition = Some(as(2))
+                  LetDef(nfd, as(1))
+                }))
+            case (None, Some(post)) =>
+                Some((Seq(b, body, post), (as: Seq[Expr]) => {
+                  val nfd = new FunDef(fd.id, fd.returnType, fd.args)
+                  nfd.fromLoop = fd.fromLoop
+                  nfd.parent = fd.parent
+                  nfd.body = Some(as(0))
+                  nfd.postcondition = Some(as(2))
+                  LetDef(nfd, as(1))
+                }))
+            case (Some(pre), Some(post)) =>
+                Some((Seq(b, body, pre, post), (as: Seq[Expr]) => {
+                  val nfd = new FunDef(fd.id, fd.returnType, fd.args)
+                  nfd.fromLoop = fd.fromLoop
+                  nfd.parent = fd.parent
+                  nfd.body = Some(as(0))
+                  nfd.precondition = Some(as(2))
+                  nfd.postcondition = Some(as(3))
+                  LetDef(nfd, as(1))
+                }))
+          }
+            
+        case None => //case no body, we still need to handle remaining cases
+          (fd.precondition, fd.postcondition) match {
+            case (None, None) =>
+                Some((Seq(body), (as: Seq[Expr]) => {
+                  val nfd = new FunDef(fd.id, fd.returnType, fd.args)
+                  nfd.fromLoop = fd.fromLoop
+                  nfd.parent = fd.parent
+                  LetDef(nfd, as(0))
+                }))
+            case (Some(pre), None) =>
+                Some((Seq(body, pre), (as: Seq[Expr]) => {
+                  val nfd = new FunDef(fd.id, fd.returnType, fd.args)
+                  nfd.fromLoop = fd.fromLoop
+                  nfd.parent = fd.parent
+                  nfd.precondition = Some(as(1))
+                  LetDef(nfd, as(0))
+                }))
+            case (None, Some(post)) =>
+                Some((Seq(body, post), (as: Seq[Expr]) => {
+                  val nfd = new FunDef(fd.id, fd.returnType, fd.args)
+                  nfd.fromLoop = fd.fromLoop
+                  nfd.parent = fd.parent
+                  nfd.postcondition = Some(as(1))
+                  LetDef(nfd, as(0))
+                }))
+            case (Some(pre), Some(post)) =>
+                Some((Seq(body, pre, post), (as: Seq[Expr]) => {
+                  val nfd = new FunDef(fd.id, fd.returnType, fd.args)
+                  nfd.fromLoop = fd.fromLoop
+                  nfd.parent = fd.parent
+                  nfd.precondition = Some(as(1))
+                  nfd.postcondition = Some(as(2))
+                  LetDef(nfd, as(0))
+                }))
+          }
+      }
+    }
+
+    def pp(sb: StringBuffer, lvl: Int, 
+      ep: (Expr, StringBuffer, Int) => StringBuffer, 
+      tp: (TypeTree, StringBuffer, Int) => StringBuffer,
+      dp: (Definition, StringBuffer, Int) => StringBuffer
+    ): StringBuffer = {
+      sb.append("\n")
+      dp(fd, sb, lvl+1)
+      sb.append("\n")
+      sb.append("\n")
+      ind(sb, lvl)
+      ep(body, sb, lvl)
+      sb
+    }
+
+    def ppScala(sb: StringBuffer, lvl: Int, 
+      ep: (Expr, StringBuffer, Int) => Unit, 
+      tp: (TypeTree, StringBuffer, Int) => Unit,
+      dp: (Definition, StringBuffer, Int) => Unit
+    ): StringBuffer = {
+      sb.append("\n")
+      dp(fd, sb, lvl+1)
+      sb.append("\n")
+      sb.append("\n")
+      ind(sb, lvl)
+      ep(body, sb, lvl)
+      sb
+    }
+
   }
 
-  case class Waypoint(i: Int, expr: Expr) extends Expr with UnaryExtractable with PrettyPrintable {
+  case class Waypoint(i: Int, expr: Expr) extends Expr with UnaryExtractable with PrettyPrintable with ScalaPrintable {
     def extract: Option[(Expr, (Expr)=>Expr)] = {
       Some((expr, (e: Expr) => Waypoint(i, e)))
     }
 
-    def pp(sb: StringBuffer, lvl: Int, rp: (Expr, StringBuffer, Int) => StringBuffer): StringBuffer = {
+    def pp(sb: StringBuffer, lvl: Int, 
+      ep: (Expr, StringBuffer, Int) => StringBuffer, 
+      tp: (TypeTree, StringBuffer, Int) => StringBuffer,
+      dp: (Definition, StringBuffer, Int) => StringBuffer
+    ): StringBuffer = {
       sb.append("waypoint_" + i + "(")
-      rp(expr, sb, lvl)
+      ep(expr, sb, lvl)
       sb.append(")")
     }
+
+    def ppScala(sb: StringBuffer, lvl: Int, 
+      ep: (Expr, StringBuffer, Int) => Unit, 
+      tp: (TypeTree, StringBuffer, Int) => Unit,
+      dp: (Definition, StringBuffer, Int) => Unit
+    ): StringBuffer = {
+      sys.error("Not Scala Code")
+    }
   }
 
   //the difference between ArrayUpdate and ArrayUpdated is that the former has a side effect while the latter is the functional version
@@ -156,12 +398,17 @@ object Trees {
       Some((Seq(t1,t2,t3), (as: Seq[Expr]) => ArrayUpdate(as(0), as(1), as(2))))
     }
 
-    def pp(sb: StringBuffer, lvl: Int, rp: (Expr, StringBuffer, Int) => StringBuffer): StringBuffer = {
-      rp(array, sb, lvl)
+    def pp(sb: StringBuffer, lvl: Int, 
+      ep: (Expr, StringBuffer, Int) => StringBuffer, 
+      tp: (TypeTree, StringBuffer, Int) => StringBuffer,
+      dp: (Definition, StringBuffer, Int) => StringBuffer
+    ): StringBuffer = {
+      ep(array, sb, lvl)
       sb.append("(")
-      rp(index, sb, lvl)
+      ep(index, sb, lvl)
       sb.append(") = ")
-      rp(newValue, sb, lvl)
+      ep(newValue, sb, lvl)
     }
   }
+
 }
-- 
GitLab