From 28b229eb4aa78af4f67ddb814226b4677d4b964c Mon Sep 17 00:00:00 2001
From: Regis Blanc <regwblanc@gmail.com>
Date: Wed, 30 Dec 2015 15:50:02 +0100
Subject: [PATCH] refactor while case in xlang

---
 .../scala/leon/purescala/Definitions.scala    |  4 +-
 .../leon/purescala/FunctionClosure.scala      |  4 +-
 .../xlang/ImperativeCodeElimination.scala     | 93 ++++++-------------
 .../xlang/valid/WhileAsFun1.scala             | 25 +++++
 .../xlang/valid/WhileAsFun2.scala             | 33 +++++++
 5 files changed, 92 insertions(+), 67 deletions(-)
 create mode 100644 src/test/resources/regression/verification/xlang/valid/WhileAsFun1.scala
 create mode 100644 src/test/resources/regression/verification/xlang/valid/WhileAsFun2.scala

diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala
index a3d00ad98..d0d127fd8 100644
--- a/src/main/scala/leon/purescala/Definitions.scala
+++ b/src/main/scala/leon/purescala/Definitions.scala
@@ -204,8 +204,8 @@ object Definitions {
   // If this class was a method. owner is the original owner of the method
   case class IsMethod(owner: ClassDef) extends FunctionFlag
   // If this function represents a loop that was there before XLangElimination
-  // Contains a copy of the original looping function
-  case class IsLoop(orig: FunDef) extends FunctionFlag
+  // Contains a link to the FunDef where the loop was defined
+  case class IsLoop(owner: FunDef) extends FunctionFlag
   // If extraction fails of the function's body fais, it is marked as abstract
   case object IsAbstract extends FunctionFlag
   // Currently, the only synthetic functions are those that calculate default values of parameters
diff --git a/src/main/scala/leon/purescala/FunctionClosure.scala b/src/main/scala/leon/purescala/FunctionClosure.scala
index 65fd1de8a..92b3db79b 100644
--- a/src/main/scala/leon/purescala/FunctionClosure.scala
+++ b/src/main/scala/leon/purescala/FunctionClosure.scala
@@ -145,11 +145,11 @@ object FunctionClosure extends TransformationPhase {
     )
 
     newFd.fullBody = preMap {
-      case FunctionInvocation(tfd, args) if tfd.fd == inner =>
+      case fi@FunctionInvocation(tfd, args) if tfd.fd == inner =>
         Some(FunctionInvocation(
           newFd.typed(tfd.tps ++ tpFresh.map{ _.tp }),
           args ++ freshVals.drop(args.length).map(Variable)
-        ))
+        ).setPos(fi))
       case _ => None
     }(instBody)
 
diff --git a/src/main/scala/leon/xlang/ImperativeCodeElimination.scala b/src/main/scala/leon/xlang/ImperativeCodeElimination.scala
index e9802d500..210612e99 100644
--- a/src/main/scala/leon/xlang/ImperativeCodeElimination.scala
+++ b/src/main/scala/leon/xlang/ImperativeCodeElimination.scala
@@ -129,68 +129,22 @@ object ImperativeCodeElimination extends UnitPhase[Program] {
         (resId.toVariable, scope, scrutFun ++ modifiedVars.zip(freshIds).toMap)
  
       case wh@While(cond, body) =>
-        //TODO: rewrite by re-using the nested function transformation code
-        val (condRes, condScope, condFun) = toFunction(cond)
-        val (_, bodyScope, bodyFun) = toFunction(body)
-        val condBodyFun = condFun ++ bodyFun
-
-        val modifiedVars: Seq[Identifier] = condBodyFun.keys.toSet.intersect(varsInScope).toSeq
-
-        if(modifiedVars.isEmpty)
-          (UnitLiteral(), (b: Expr) => b, Map())
-        else {
-          val whileFunVars = modifiedVars.map(id => FreshIdentifier(id.name, id.getType))
-          val modifiedVars2WhileFunVars = modifiedVars.zip(whileFunVars).toMap
-          val whileFunValDefs = whileFunVars.map(ValDef(_))
-          val whileFunReturnType = tupleTypeWrap(whileFunVars.map(_.getType))
-          val whileFunDef = new FunDef(parent.id.freshen, Nil, whileFunValDefs, whileFunReturnType).setPos(wh)
-          whileFunDef.addFlag(IsLoop(parent))
-          
-          val whileFunCond = condScope(condRes)
-          val whileFunRecursiveCall = replaceNames(condFun,
-            bodyScope(FunctionInvocation(whileFunDef.typed, modifiedVars.map(id => condBodyFun(id).toVariable)).setPos(wh)))
-          val whileFunBaseCase =
-            tupleWrap(modifiedVars.map(id => condFun.getOrElse(id, modifiedVars2WhileFunVars(id)).toVariable))
-          val whileFunBody = replaceNames(modifiedVars2WhileFunVars, 
-            condScope(IfExpr(whileFunCond, whileFunRecursiveCall, whileFunBaseCase)))
-          whileFunDef.body = Some(whileFunBody)
-
-          val resVar = Variable(FreshIdentifier("res", whileFunReturnType))
-          val whileFunVars2ResultVars: Map[Expr, Expr] = 
-            whileFunVars.zipWithIndex.map{ case (v, i) => 
-              (v.toVariable, tupleSelect(resVar, i+1, whileFunVars.size))
-            }.toMap
-          val modifiedVars2ResultVars: Map[Expr, Expr]  = modifiedVars.map(id => 
-            (id.toVariable, whileFunVars2ResultVars(modifiedVars2WhileFunVars(id).toVariable))).toMap
-
-          //the mapping of the trivial post condition variables depends on whether the condition has had some side effect
-          val trivialPostcondition: Option[Expr] = Some(Not(replace(
-            modifiedVars.map(id => (condFun.getOrElse(id, id).toVariable, modifiedVars2ResultVars(id.toVariable))).toMap,
-            whileFunCond)))
-          val invariantPrecondition: Option[Expr] = wh.invariant.map(expr => replaceNames(modifiedVars2WhileFunVars, expr))
-          val invariantPostcondition: Option[Expr] = wh.invariant.map(expr => replace(modifiedVars2ResultVars, expr))
-          whileFunDef.precondition = invariantPrecondition
-          whileFunDef.postcondition = trivialPostcondition.map( expr => 
-            Lambda(
-              Seq(ValDef(resVar.id)), 
-              and(expr, invariantPostcondition.getOrElse(BooleanLiteral(true))).setPos(wh)
-            ).setPos(wh)
-          )
-
-          val finalVars = modifiedVars.map(_.freshen)
-          val finalScope = (body: Expr) => {
-            val tupleId = FreshIdentifier("t", whileFunReturnType)
-            LetDef(whileFunDef, Let(
-              tupleId,
-              FunctionInvocation(whileFunDef.typed, modifiedVars.map(_.toVariable)).setPos(wh),
-              finalVars.zipWithIndex.foldLeft(body) { (b, id) =>
-                Let(id._1, tupleSelect(tupleId.toVariable, id._2 + 1, finalVars.size), b)
-              }
-            ))
-          }
+        val whileFunDef = new FunDef(parent.id.freshen, Nil, Nil, UnitType).setPos(wh)
+        whileFunDef.addFlag(IsLoop(parent))
+        whileFunDef.body = Some(
+          IfExpr(cond, 
+                 Block(Seq(body), FunctionInvocation(whileFunDef.typed, Seq()).setPos(wh)),
+                 UnitLiteral()))
+        whileFunDef.precondition = wh.invariant
+        whileFunDef.postcondition = Some(
+          Lambda(
+            Seq(ValDef(FreshIdentifier("bodyRes", UnitType))),
+            and(Not(getFunctionalResult(cond)), wh.invariant.getOrElse(BooleanLiteral(true))).setPos(wh)
+          ).setPos(wh)
+        )
 
-          (UnitLiteral(), finalScope, modifiedVars.zip(finalVars).toMap)
-        }
+        val newExpr = LetDef(whileFunDef, FunctionInvocation(whileFunDef.typed, Seq()).setPos(wh)).setPos(wh)
+        toFunction(newExpr)
 
       case Block(Seq(), expr) =>
         toFunction(expr)
@@ -279,6 +233,7 @@ object ImperativeCodeElimination extends UnitPhase[Program] {
             val modifiedVars: List[Identifier] =
               collect[Identifier]({
                 case Assignment(v, _) => Set(v)
+                case FunctionInvocation(tfd, _) => state.funDefsMapping.get(tfd.fd).map(p => p._2.toSet).getOrElse(Set())
                 case _ => Set()
               })(bd).intersect(state.varsInScope).toList
 
@@ -304,11 +259,14 @@ object ImperativeCodeElimination extends UnitPhase[Program] {
               val newReturnType = TupleType(fd.returnType :: modifiedVars.map(_.getType))
 
               val newFd = new FunDef(fd.id.freshen, fd.tparams, newParams, newReturnType).setPos(fd)
+              newFd.addFlags(fd.flags)
 
               val (fdRes, fdScope, fdFun) = 
                 toFunction(wrappedBody)(
-                  State(state.parent, Set(), 
-                        state.funDefsMapping + (fd -> ((newFd, freshVarDecls))))
+                  State(state.parent, 
+                        Set(), 
+                        state.funDefsMapping.map{case (fd, (nfd, mvs)) => (fd, (nfd, mvs.map(v => rewritingMap.getOrElse(v, v))))} + 
+                               (fd -> ((newFd, freshVarDecls))))
                 )
               val newRes = Tuple(fdRes :: freshVarDecls.map(vd => fdFun(vd).toVariable))
               val newBody = fdScope(newRes)
@@ -367,4 +325,13 @@ object ImperativeCodeElimination extends UnitPhase[Program] {
 
   def replaceNames(fun: Map[Identifier, Identifier], expr: Expr) = replaceFromIDs(fun mapValues Variable, expr)
 
+  
+  /* Extract functional result value. Useful to remove side effect from conditions when moving it to post-condition */
+  private def getFunctionalResult(expr: Expr): Expr = {
+    preMap({
+      case Block(_, res) => Some(res)
+      case _ => None
+    })(expr)
+  }
+
 }
diff --git a/src/test/resources/regression/verification/xlang/valid/WhileAsFun1.scala b/src/test/resources/regression/verification/xlang/valid/WhileAsFun1.scala
new file mode 100644
index 000000000..b81ea3859
--- /dev/null
+++ b/src/test/resources/regression/verification/xlang/valid/WhileAsFun1.scala
@@ -0,0 +1,25 @@
+/* Copyright 2009-2015 EPFL, Lausanne */
+import leon.lang._
+
+object WhileAsFun1 {
+
+
+  def counterN(n: Int): Int = {
+    require(n > 0)
+
+    var i = 0
+    def rec(): Unit = {
+      require(i >= 0 && i <= n)
+      if(i < n) {
+        i += 1
+        rec()
+      } else {
+        ()
+      }
+    } ensuring(_ => i >= 0 && i <= n && i >= n)
+    rec()
+
+    i
+  } ensuring(_ == n)
+
+}
diff --git a/src/test/resources/regression/verification/xlang/valid/WhileAsFun2.scala b/src/test/resources/regression/verification/xlang/valid/WhileAsFun2.scala
new file mode 100644
index 000000000..968aadfdb
--- /dev/null
+++ b/src/test/resources/regression/verification/xlang/valid/WhileAsFun2.scala
@@ -0,0 +1,33 @@
+/* Copyright 2009-2015 EPFL, Lausanne */
+import leon.lang._
+
+object WhileAsFun2 {
+
+
+  def counterN(n: Int): Int = {
+    require(n > 0)
+
+    var counter = 0
+
+    def inc(): Unit = {
+      counter += 1
+    }
+
+    var i = 0
+    def rec(): Unit = {
+      require(i >= 0 && counter == i && i <= n)
+      if(i < n) {
+        inc()
+        i += 1
+        rec()
+      } else {
+        ()
+      }
+    } ensuring(_ => i >= 0 && counter == i && i <= n && i >= n)
+    rec()
+
+
+    counter
+  } ensuring(_ == n)
+
+}
-- 
GitLab