Skip to content
Snippets Groups Projects
Commit da6f47b3 authored by Régis Blanc's avatar Régis Blanc
Browse files

introduce a global state for true non-determinism

parent c85e1da2
No related branches found
No related tags found
No related merge requests found
/* Copyright 2009-2015 EPFL, Lausanne */
package leon.xlang
import leon.TransformationPhase
import leon.LeonContext
import leon.purescala.Common._
import leon.purescala.Definitions._
import leon.purescala.Expressions._
import leon.purescala.ExprOps._
import leon.purescala.DefOps._
import leon.purescala.Types._
import leon.purescala.Constructors._
import leon.purescala.Extractors._
import leon.xlang.Expressions._
object IntroduceGlobalStatePhase extends TransformationPhase {
val name = "introduce-global-state"
val description = "Introduce a global state passed around to all functions that depend on it"
private val globalStateCCD = new CaseClassDef(FreshIdentifier("GlobalState"), Seq(), None, false)
private val epsilonSeed = FreshIdentifier("epsilonSeed", IntegerType)
globalStateCCD.setFields(Seq(ValDef(epsilonSeed).setIsVar(true)))
override def apply(ctx: LeonContext, pgm: Program): Program = {
val fds = allFunDefs(pgm)
var updatedFunctions: Map[FunDef, FunDef] = Map()
val statefulFunDefs = funDefsNeedingState(pgm)
//println("Stateful fun def: " + statefulFunDefs.map(_.id))
/*
* The first pass will introduce all new function definitions,
* so that in the next pass we can update function invocations
*/
for {
fd <- fds if statefulFunDefs.contains(fd)
} {
updatedFunctions += (fd -> extendFunDefWithState(fd, globalStateCCD)(ctx))
}
for {
fd <- fds if statefulFunDefs.contains(fd)
} {
updateBody(fd, updatedFunctions)(ctx)
}
replaceDefsInProgram(pgm)(updatedFunctions)
}
private def extendFunDefWithState(fd: FunDef, stateCCD: CaseClassDef)(ctx: LeonContext): FunDef = {
val newParams = fd.params :+ ValDef(FreshIdentifier("state", stateCCD.typed))
val newFunDef = new FunDef(fd.id.freshen, fd.tparams, newParams, fd.returnType)
newFunDef.addFlags(fd.flags)
newFunDef.setPos(fd)
newFunDef
}
private def updateBody(fd: FunDef, updatedFunctions: Map[FunDef, FunDef])(ctx: LeonContext): FunDef = {
val nfd = updatedFunctions(fd)
val stateParam: ValDef = nfd.params.last
nfd.body = fd.body.map(body => postMap{
case fi@FunctionInvocation(efd, args) if updatedFunctions.contains(efd.fd) => {
Some(FunctionInvocation(updatedFunctions(efd.fd).typed(efd.tps), args :+ stateParam.id.toVariable))
}
case eps@Epsilon(pred, _) => {
val nextEpsilonSeed = Plus(
CaseClassSelector(globalStateCCD.typed, stateParam.id.toVariable, epsilonSeed),
InfiniteIntegerLiteral(1))
Some(Block(Seq(FieldAssignment(stateParam.id.toVariable, epsilonSeed, nextEpsilonSeed)), eps))
}
case _ => None
}(body))
nfd.precondition = fd.precondition
nfd.postcondition = fd.postcondition
nfd
}
def funDefsNeedingState(pgm: Program): Set[FunDef] = {
def compute(body: Expr): Boolean = exists{
case Epsilon(_, _) => true
case _ => false
}(body)
def combine(b1: Boolean, b2: Boolean) = b1 || b2
programFixpoint(pgm, compute, combine).filter(p => p._2).keySet
}
/*
* compute some A for each function in the program, including any nested
* functions (LetDef). The A value is transitive, combined with the A value
* of all called functions in the body. The combine function combines the current
* value computed with a new value from a function invocation.
*/
private def programFixpoint[A](pgm: Program, compute: (Expr) => A, combine: (A, A) => A): Map[FunDef, A] = {
//currently computed results (incremental)
var res: Map[FunDef, A] = Map()
//missing dependencies for a function
var missingDependencies: Map[FunDef, Set[FunctionInvocation]] = Map()
def fullyComputed(fd: FunDef): Boolean = !missingDependencies.isDefinedAt(fd)
for {
fd <- allFunDefs(pgm)
} {
fd.body match {
case None =>
() //TODO: maybe some default value? res += (fd -> Set())
case Some(body) => {
res = res + (fd -> compute(body))
val missingCalls: Set[FunctionInvocation] = functionCallsOf(body).filterNot(fi => fi.tfd.fd == fd)
if(missingCalls.nonEmpty)
missingDependencies += (fd -> missingCalls)
}
}
}
def rec(): Unit = {
val previousMissingDependencies = missingDependencies
for{ (fd, calls) <- missingDependencies } {
var newMissingCalls: Set[FunctionInvocation] = calls
for(fi <- calls) {
val newA = res.get(fi.tfd.fd).map(ra => combine(res(fd), ra)).getOrElse(res(fd))
res += (fd -> newA)
if(fullyComputed(fi.tfd.fd)) {
newMissingCalls -= fi
}
}
if(newMissingCalls.isEmpty)
missingDependencies = missingDependencies - fd
else
missingDependencies += (fd -> newMissingCalls)
}
if(missingDependencies != previousMissingDependencies) {
rec()
}
}
rec()
res
}
/*
* returns all fun def in the program, including local definitions inside
* other functions (LetDef).
*/
private def allFunDefs(pgm: Program): Seq[FunDef] =
pgm.definedFunctions.flatMap(fd =>
fd.body.toSet.flatMap((bd: Expr) =>
nestedFunDefsOf(bd)) + fd)
}
...@@ -17,6 +17,7 @@ object XLangDesugaringPhase extends LeonPhase[Program, Program] { ...@@ -17,6 +17,7 @@ object XLangDesugaringPhase extends LeonPhase[Program, Program] {
PrintTreePhase(title).when(ctx.reporter.isDebugEnabled(DebugSectionTrees)) PrintTreePhase(title).when(ctx.reporter.isDebugEnabled(DebugSectionTrees))
val phases = val phases =
IntroduceGlobalStatePhase andThen
AntiAliasingPhase andThen AntiAliasingPhase andThen
debugTrees("Program after anti-aliasing") andThen debugTrees("Program after anti-aliasing") andThen
EpsilonElimination andThen EpsilonElimination andThen
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment