From cbf033c50bc09189c7d85932afb79a417934e754 Mon Sep 17 00:00:00 2001
From: Manos Koukoutos <emmanouil.koukoutos@epfl.ch>
Date: Thu, 25 Jun 2015 14:08:29 +0200
Subject: [PATCH] Fix XLang to use flags

---
 .../xlang/ImperativeCodeElimination.scala     | 10 ++--
 .../scala/leon/xlang/XLangAnalysisPhase.scala | 48 +++++++++----------
 2 files changed, 26 insertions(+), 32 deletions(-)

diff --git a/src/main/scala/leon/xlang/ImperativeCodeElimination.scala b/src/main/scala/leon/xlang/ImperativeCodeElimination.scala
index 5abf808f4..9678009e6 100644
--- a/src/main/scala/leon/xlang/ImperativeCodeElimination.scala
+++ b/src/main/scala/leon/xlang/ImperativeCodeElimination.scala
@@ -12,18 +12,16 @@ import leon.purescala.ExprOps._
 import leon.purescala.TypeOps._
 import leon.xlang.Expressions._
 
-object ImperativeCodeElimination extends LeonPhase[Program, (Program, Set[FunDef])] {
+object ImperativeCodeElimination extends TransformationPhase {
 
   val name = "Imperative Code Elimination"
   val description = "Transform imperative constructs into purely functional code"
 
   private var varInScope = Set[Identifier]()
   private var parent: FunDef = null //the enclosing fundef
-  private var wasLoop: Set[FunDef] = null //record FunDef that are the transformation of loops
 
-  def run(ctx: LeonContext)(pgm: Program): (Program, Set[FunDef]) = {
+  def apply(ctx: LeonContext, pgm: Program): Program = {
     varInScope = Set()
-    wasLoop = Set()
     parent = null
 
     val allFuns = pgm.definedFunctions
@@ -35,7 +33,7 @@ object ImperativeCodeElimination extends LeonPhase[Program, (Program, Set[FunDef
       val (res, scope, _) = toFunction(body)
       fd.body = Some(scope(res))
     }
-    (pgm, wasLoop)
+    pgm
   }
 
   //return a "scope" consisting of purely functional code that defines potentially needed 
@@ -153,7 +151,7 @@ object ImperativeCodeElimination extends LeonPhase[Program, (Program, Set[FunDef
           val whileFunValDefs = whileFunVars.map(ValDef(_))
           val whileFunReturnType = tupleTypeWrap(whileFunVars.map(_.getType))
           val whileFunDef = new FunDef(FreshIdentifier(parent.id.name), Nil, whileFunReturnType, whileFunValDefs).setPos(wh)
-          wasLoop += whileFunDef
+          whileFunDef.addFlag(IsLoop(parent))
           
           val whileFunCond = condScope(condRes)
           val whileFunRecursiveCall = replaceNames(condFun,
diff --git a/src/main/scala/leon/xlang/XLangAnalysisPhase.scala b/src/main/scala/leon/xlang/XLangAnalysisPhase.scala
index 1cd0961bc..5a0f24812 100644
--- a/src/main/scala/leon/xlang/XLangAnalysisPhase.scala
+++ b/src/main/scala/leon/xlang/XLangAnalysisPhase.scala
@@ -14,6 +14,8 @@ object XLangAnalysisPhase extends LeonPhase[Program, VerificationReport] {
   val description = "apply analysis on xlang"
 
   object VCXLangKinds {
+    // TODO: something of this sort should be included
+    // case object InvariantEntry extends VCKind("invariant init",           "inv. init.")
     case object InvariantPost extends VCKind("invariant postcondition", "inv. post.")
     case object InvariantInd  extends VCKind("invariant inductive",     "inv. ind.")
   }
@@ -22,19 +24,13 @@ object XLangAnalysisPhase extends LeonPhase[Program, VerificationReport] {
 
     ArrayTransformation(ctx, pgm) // In-place
     EpsilonElimination(ctx, pgm)  // In-place
-    val (pgm1, wasLoop) = ImperativeCodeElimination.run(ctx)(pgm)
+    val pgm1 = ImperativeCodeElimination.run(ctx)(pgm)
     val pgm2 = purescala.FunctionClosure.run(ctx)(pgm1)
 
     if (ctx.reporter.isDebugEnabled(DebugSectionTrees)) {
       PrintTreePhase("Program after xlang transformations").run(ctx)(pgm2)
     }
 
-    def functionWasLoop(fd: FunDef): Boolean = fd.flags.collectFirst{ case IsLoop(fd) => fd } match {
-      case Some(nested) => // could have been a LetDef originally
-        wasLoop.contains(nested)
-      case _ => false //meaning, this was a top level function
-    }
-
     val subFunctionsOf = Map[FunDef, Set[FunDef]]().withDefaultValue(Set())
 
     val newOptions = ctx.options map {
@@ -57,10 +53,10 @@ object XLangAnalysisPhase extends LeonPhase[Program, VerificationReport] {
     }
 
     val vr = AnalysisPhase.run(ctx.copy(options = newOptions))(pgm2)
-    completeVerificationReport(vr, functionWasLoop)
+    completeVerificationReport(vr)
   }
 
-  def completeVerificationReport(vr: VerificationReport, functionWasLoop: FunDef => Boolean): VerificationReport = {
+  def completeVerificationReport(vr: VerificationReport): VerificationReport = {
 
     //this is enough to convert invariant postcondition and inductive conditions. However the initial validity
     //of the invariant (before entering the loop) will still appear as a regular function precondition
@@ -69,24 +65,24 @@ object XLangAnalysisPhase extends LeonPhase[Program, VerificationReport] {
     //precondition and a function invocation precondition
 
     val newResults = for ((vc, ovr) <- vr.results) yield {
-      if(functionWasLoop(vc.fd)) {
-        val nvc = VC(vc.condition, 
-                     vc.fd,
-                     //vc.fd.owner match {
-                     //  case Some(fd: FunDef) => fd
-                     //  case _ => vc.fd
-                     //},
-                     vc.kind.underlying match {
-                       case VCKinds.Postcondition => VCXLangKinds.InvariantPost
-                       case VCKinds.Precondition => VCXLangKinds.InvariantInd
-                       case _ => vc.kind
-                     },
-                     vc.tactic).setPos(vc.getPos)
-
-        nvc -> ovr
-      } else {
-        vc -> ovr
+      val (vcKind, fd) = vc.fd.flags.collectFirst { case IsLoop(orig) => orig } match {
+        case None => (vc.kind, vc.fd)
+        case Some(owner) => (vc.kind.underlying match {
+          case VCKinds.Precondition => VCXLangKinds.InvariantInd
+          case VCKinds.Postcondition => VCXLangKinds.InvariantPost
+          case _ => vc.kind
+        }, owner)
       }
+
+      val nvc = VC(
+        vc.condition,
+        fd,
+        vcKind,
+        vc.tactic
+      ).setPos(vc.getPos)
+
+      nvc -> ovr
+
     }
 
     VerificationReport(newResults)
-- 
GitLab