From da6f47b31d113ae4648ac70b97bcb8a4514bbb0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9gis=20Blanc?= <regwblanc@gmail.com> Date: Thu, 7 Apr 2016 15:16:11 +0200 Subject: [PATCH] introduce a global state for true non-determinism --- .../xlang/IntroduceGlobalStatePhase.scala | 166 ++++++++++++++++++ .../leon/xlang/XLangDesugaringPhase.scala | 1 + 2 files changed, 167 insertions(+) create mode 100644 src/main/scala/leon/xlang/IntroduceGlobalStatePhase.scala diff --git a/src/main/scala/leon/xlang/IntroduceGlobalStatePhase.scala b/src/main/scala/leon/xlang/IntroduceGlobalStatePhase.scala new file mode 100644 index 000000000..7bc975309 --- /dev/null +++ b/src/main/scala/leon/xlang/IntroduceGlobalStatePhase.scala @@ -0,0 +1,166 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ +package leon.xlang + +import leon.TransformationPhase +import leon.LeonContext +import leon.purescala.Common._ +import leon.purescala.Definitions._ +import leon.purescala.Expressions._ +import leon.purescala.ExprOps._ +import leon.purescala.DefOps._ +import leon.purescala.Types._ +import leon.purescala.Constructors._ +import leon.purescala.Extractors._ +import leon.xlang.Expressions._ + +object IntroduceGlobalStatePhase extends TransformationPhase { + + val name = "introduce-global-state" + val description = "Introduce a global state passed around to all functions that depend on it" + + + private val globalStateCCD = new CaseClassDef(FreshIdentifier("GlobalState"), Seq(), None, false) + private val epsilonSeed = FreshIdentifier("epsilonSeed", IntegerType) + globalStateCCD.setFields(Seq(ValDef(epsilonSeed).setIsVar(true))) + + override def apply(ctx: LeonContext, pgm: Program): Program = { + + val fds = allFunDefs(pgm) + var updatedFunctions: Map[FunDef, FunDef] = Map() + + val statefulFunDefs = funDefsNeedingState(pgm) + //println("Stateful fun def: " + statefulFunDefs.map(_.id)) + + + /* + * The first pass will introduce all new function definitions, + * so that in the next pass we can update function invocations + */ + for { + fd <- fds if statefulFunDefs.contains(fd) + } { + updatedFunctions += (fd -> extendFunDefWithState(fd, globalStateCCD)(ctx)) + } + + for { + fd <- fds if statefulFunDefs.contains(fd) + } { + updateBody(fd, updatedFunctions)(ctx) + } + + replaceDefsInProgram(pgm)(updatedFunctions) + } + + private def extendFunDefWithState(fd: FunDef, stateCCD: CaseClassDef)(ctx: LeonContext): FunDef = { + val newParams = fd.params :+ ValDef(FreshIdentifier("state", stateCCD.typed)) + val newFunDef = new FunDef(fd.id.freshen, fd.tparams, newParams, fd.returnType) + newFunDef.addFlags(fd.flags) + newFunDef.setPos(fd) + newFunDef + } + + private def updateBody(fd: FunDef, updatedFunctions: Map[FunDef, FunDef])(ctx: LeonContext): FunDef = { + val nfd = updatedFunctions(fd) + val stateParam: ValDef = nfd.params.last + + nfd.body = fd.body.map(body => postMap{ + case fi@FunctionInvocation(efd, args) if updatedFunctions.contains(efd.fd) => { + Some(FunctionInvocation(updatedFunctions(efd.fd).typed(efd.tps), args :+ stateParam.id.toVariable)) + } + case eps@Epsilon(pred, _) => { + val nextEpsilonSeed = Plus( + CaseClassSelector(globalStateCCD.typed, stateParam.id.toVariable, epsilonSeed), + InfiniteIntegerLiteral(1)) + Some(Block(Seq(FieldAssignment(stateParam.id.toVariable, epsilonSeed, nextEpsilonSeed)), eps)) + } + case _ => None + }(body)) + + nfd.precondition = fd.precondition + nfd.postcondition = fd.postcondition + + nfd + } + + def funDefsNeedingState(pgm: Program): Set[FunDef] = { + + def compute(body: Expr): Boolean = exists{ + case Epsilon(_, _) => true + case _ => false + }(body) + + def combine(b1: Boolean, b2: Boolean) = b1 || b2 + + programFixpoint(pgm, compute, combine).filter(p => p._2).keySet + } + + /* + * compute some A for each function in the program, including any nested + * functions (LetDef). The A value is transitive, combined with the A value + * of all called functions in the body. The combine function combines the current + * value computed with a new value from a function invocation. + */ + private def programFixpoint[A](pgm: Program, compute: (Expr) => A, combine: (A, A) => A): Map[FunDef, A] = { + + //currently computed results (incremental) + var res: Map[FunDef, A] = Map() + //missing dependencies for a function + var missingDependencies: Map[FunDef, Set[FunctionInvocation]] = Map() + + def fullyComputed(fd: FunDef): Boolean = !missingDependencies.isDefinedAt(fd) + + for { + fd <- allFunDefs(pgm) + } { + fd.body match { + case None => + () //TODO: maybe some default value? res += (fd -> Set()) + case Some(body) => { + res = res + (fd -> compute(body)) + val missingCalls: Set[FunctionInvocation] = functionCallsOf(body).filterNot(fi => fi.tfd.fd == fd) + if(missingCalls.nonEmpty) + missingDependencies += (fd -> missingCalls) + } + } + } + + def rec(): Unit = { + val previousMissingDependencies = missingDependencies + + for{ (fd, calls) <- missingDependencies } { + var newMissingCalls: Set[FunctionInvocation] = calls + for(fi <- calls) { + val newA = res.get(fi.tfd.fd).map(ra => combine(res(fd), ra)).getOrElse(res(fd)) + res += (fd -> newA) + + if(fullyComputed(fi.tfd.fd)) { + newMissingCalls -= fi + } + } + if(newMissingCalls.isEmpty) + missingDependencies = missingDependencies - fd + else + missingDependencies += (fd -> newMissingCalls) + } + + if(missingDependencies != previousMissingDependencies) { + rec() + } + } + + rec() + res + } + + + /* + * returns all fun def in the program, including local definitions inside + * other functions (LetDef). + */ + private def allFunDefs(pgm: Program): Seq[FunDef] = + pgm.definedFunctions.flatMap(fd => + fd.body.toSet.flatMap((bd: Expr) => + nestedFunDefsOf(bd)) + fd) + + +} diff --git a/src/main/scala/leon/xlang/XLangDesugaringPhase.scala b/src/main/scala/leon/xlang/XLangDesugaringPhase.scala index e3325e3a9..b00935fcd 100644 --- a/src/main/scala/leon/xlang/XLangDesugaringPhase.scala +++ b/src/main/scala/leon/xlang/XLangDesugaringPhase.scala @@ -17,6 +17,7 @@ object XLangDesugaringPhase extends LeonPhase[Program, Program] { PrintTreePhase(title).when(ctx.reporter.isDebugEnabled(DebugSectionTrees)) val phases = + IntroduceGlobalStatePhase andThen AntiAliasingPhase andThen debugTrees("Program after anti-aliasing") andThen EpsilonElimination andThen -- GitLab