From da6f47b31d113ae4648ac70b97bcb8a4514bbb0a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9gis=20Blanc?= <regwblanc@gmail.com>
Date: Thu, 7 Apr 2016 15:16:11 +0200
Subject: [PATCH] introduce a global state for true non-determinism

---
 .../xlang/IntroduceGlobalStatePhase.scala     | 166 ++++++++++++++++++
 .../leon/xlang/XLangDesugaringPhase.scala     |   1 +
 2 files changed, 167 insertions(+)
 create mode 100644 src/main/scala/leon/xlang/IntroduceGlobalStatePhase.scala

diff --git a/src/main/scala/leon/xlang/IntroduceGlobalStatePhase.scala b/src/main/scala/leon/xlang/IntroduceGlobalStatePhase.scala
new file mode 100644
index 000000000..7bc975309
--- /dev/null
+++ b/src/main/scala/leon/xlang/IntroduceGlobalStatePhase.scala
@@ -0,0 +1,166 @@
+/* Copyright 2009-2015 EPFL, Lausanne */
+package leon.xlang
+
+import leon.TransformationPhase
+import leon.LeonContext
+import leon.purescala.Common._
+import leon.purescala.Definitions._
+import leon.purescala.Expressions._
+import leon.purescala.ExprOps._
+import leon.purescala.DefOps._
+import leon.purescala.Types._
+import leon.purescala.Constructors._
+import leon.purescala.Extractors._
+import leon.xlang.Expressions._
+
+object IntroduceGlobalStatePhase extends TransformationPhase {
+
+  val name = "introduce-global-state"
+  val description = "Introduce a global state passed around to all functions that depend on it"
+
+
+  private val globalStateCCD = new CaseClassDef(FreshIdentifier("GlobalState"), Seq(), None, false)
+  private val epsilonSeed = FreshIdentifier("epsilonSeed", IntegerType)
+  globalStateCCD.setFields(Seq(ValDef(epsilonSeed).setIsVar(true)))
+
+  override def apply(ctx: LeonContext, pgm: Program): Program = {
+
+    val fds = allFunDefs(pgm)
+    var updatedFunctions: Map[FunDef, FunDef] = Map()
+    
+    val statefulFunDefs = funDefsNeedingState(pgm)
+    //println("Stateful fun def: " + statefulFunDefs.map(_.id))
+
+
+    /*
+     * The first pass will introduce all new function definitions,
+     * so that in the next pass we can update function invocations
+     */
+    for {
+      fd <- fds if statefulFunDefs.contains(fd)
+    } {
+      updatedFunctions += (fd -> extendFunDefWithState(fd, globalStateCCD)(ctx))
+    }
+
+    for {
+      fd <- fds if statefulFunDefs.contains(fd)
+    } {
+      updateBody(fd, updatedFunctions)(ctx)
+    }
+
+    replaceDefsInProgram(pgm)(updatedFunctions)
+  }
+
+  private def extendFunDefWithState(fd: FunDef, stateCCD: CaseClassDef)(ctx: LeonContext): FunDef = {
+    val newParams = fd.params :+ ValDef(FreshIdentifier("state", stateCCD.typed))
+    val newFunDef = new FunDef(fd.id.freshen, fd.tparams, newParams, fd.returnType)
+    newFunDef.addFlags(fd.flags)
+    newFunDef.setPos(fd)
+    newFunDef
+  }
+
+  private def updateBody(fd: FunDef, updatedFunctions: Map[FunDef, FunDef])(ctx: LeonContext): FunDef = {
+    val nfd = updatedFunctions(fd)
+    val stateParam: ValDef = nfd.params.last
+
+    nfd.body = fd.body.map(body => postMap{
+      case fi@FunctionInvocation(efd, args) if updatedFunctions.contains(efd.fd) => {
+        Some(FunctionInvocation(updatedFunctions(efd.fd).typed(efd.tps), args :+ stateParam.id.toVariable))
+      }
+      case eps@Epsilon(pred, _) => {
+        val nextEpsilonSeed = Plus(
+                                CaseClassSelector(globalStateCCD.typed, stateParam.id.toVariable, epsilonSeed),
+                                InfiniteIntegerLiteral(1))
+        Some(Block(Seq(FieldAssignment(stateParam.id.toVariable, epsilonSeed, nextEpsilonSeed)), eps))
+      }
+      case _ => None
+    }(body))
+
+    nfd.precondition = fd.precondition
+    nfd.postcondition = fd.postcondition
+
+    nfd
+  }
+
+  def funDefsNeedingState(pgm: Program): Set[FunDef] = {
+
+    def compute(body: Expr): Boolean = exists{ 
+      case Epsilon(_, _) => true 
+      case _ => false
+    }(body)
+
+    def combine(b1: Boolean, b2: Boolean) = b1 || b2
+
+    programFixpoint(pgm, compute, combine).filter(p => p._2).keySet
+  }
+
+  /*
+   * compute some A for each function in the program, including any nested
+   * functions (LetDef). The A value is transitive, combined with the A value
+   * of all called functions in the body. The combine function combines the current
+   * value computed with a new value from a function invocation.
+   */
+  private def programFixpoint[A](pgm: Program, compute: (Expr) => A, combine: (A, A) => A): Map[FunDef, A] = {
+
+    //currently computed results (incremental)
+    var res: Map[FunDef, A] = Map()
+    //missing dependencies for a function
+    var missingDependencies: Map[FunDef, Set[FunctionInvocation]] = Map()
+
+    def fullyComputed(fd: FunDef): Boolean = !missingDependencies.isDefinedAt(fd)
+
+    for {
+      fd <- allFunDefs(pgm)
+    } {
+      fd.body match {
+        case None =>
+          () //TODO: maybe some default value?  res += (fd -> Set())
+        case Some(body) => {
+          res = res + (fd -> compute(body))
+          val missingCalls: Set[FunctionInvocation] = functionCallsOf(body).filterNot(fi => fi.tfd.fd == fd)
+          if(missingCalls.nonEmpty)
+            missingDependencies += (fd -> missingCalls)
+        }
+      }
+    }
+
+    def rec(): Unit = {
+      val previousMissingDependencies = missingDependencies
+
+      for{ (fd, calls) <- missingDependencies } {
+        var newMissingCalls: Set[FunctionInvocation] = calls
+        for(fi <- calls) {
+          val newA = res.get(fi.tfd.fd).map(ra => combine(res(fd), ra)).getOrElse(res(fd))
+          res += (fd -> newA)
+
+          if(fullyComputed(fi.tfd.fd)) {
+            newMissingCalls -= fi
+          }
+        }
+        if(newMissingCalls.isEmpty)
+          missingDependencies = missingDependencies - fd
+        else
+          missingDependencies += (fd -> newMissingCalls)
+      }
+
+      if(missingDependencies != previousMissingDependencies) {
+        rec()
+      }
+    }
+
+    rec()
+    res
+  }
+
+
+  /*
+   * returns all fun def in the program, including local definitions inside
+   * other functions (LetDef).
+   */
+  private def allFunDefs(pgm: Program): Seq[FunDef] =
+      pgm.definedFunctions.flatMap(fd => 
+        fd.body.toSet.flatMap((bd: Expr) =>
+          nestedFunDefsOf(bd)) + fd)
+
+
+}
diff --git a/src/main/scala/leon/xlang/XLangDesugaringPhase.scala b/src/main/scala/leon/xlang/XLangDesugaringPhase.scala
index e3325e3a9..b00935fcd 100644
--- a/src/main/scala/leon/xlang/XLangDesugaringPhase.scala
+++ b/src/main/scala/leon/xlang/XLangDesugaringPhase.scala
@@ -17,6 +17,7 @@ object XLangDesugaringPhase extends LeonPhase[Program, Program] {
       PrintTreePhase(title).when(ctx.reporter.isDebugEnabled(DebugSectionTrees))
 
     val phases =
+      IntroduceGlobalStatePhase andThen
       AntiAliasingPhase andThen
       debugTrees("Program after anti-aliasing") andThen
       EpsilonElimination andThen
-- 
GitLab