From 28b229eb4aa78af4f67ddb814226b4677d4b964c Mon Sep 17 00:00:00 2001 From: Regis Blanc <regwblanc@gmail.com> Date: Wed, 30 Dec 2015 15:50:02 +0100 Subject: [PATCH] refactor while case in xlang --- .../scala/leon/purescala/Definitions.scala | 4 +- .../leon/purescala/FunctionClosure.scala | 4 +- .../xlang/ImperativeCodeElimination.scala | 93 ++++++------------- .../xlang/valid/WhileAsFun1.scala | 25 +++++ .../xlang/valid/WhileAsFun2.scala | 33 +++++++ 5 files changed, 92 insertions(+), 67 deletions(-) create mode 100644 src/test/resources/regression/verification/xlang/valid/WhileAsFun1.scala create mode 100644 src/test/resources/regression/verification/xlang/valid/WhileAsFun2.scala diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala index a3d00ad98..d0d127fd8 100644 --- a/src/main/scala/leon/purescala/Definitions.scala +++ b/src/main/scala/leon/purescala/Definitions.scala @@ -204,8 +204,8 @@ object Definitions { // If this class was a method. owner is the original owner of the method case class IsMethod(owner: ClassDef) extends FunctionFlag // If this function represents a loop that was there before XLangElimination - // Contains a copy of the original looping function - case class IsLoop(orig: FunDef) extends FunctionFlag + // Contains a link to the FunDef where the loop was defined + case class IsLoop(owner: FunDef) extends FunctionFlag // If extraction fails of the function's body fais, it is marked as abstract case object IsAbstract extends FunctionFlag // Currently, the only synthetic functions are those that calculate default values of parameters diff --git a/src/main/scala/leon/purescala/FunctionClosure.scala b/src/main/scala/leon/purescala/FunctionClosure.scala index 65fd1de8a..92b3db79b 100644 --- a/src/main/scala/leon/purescala/FunctionClosure.scala +++ b/src/main/scala/leon/purescala/FunctionClosure.scala @@ -145,11 +145,11 @@ object FunctionClosure extends TransformationPhase { ) newFd.fullBody = preMap { - case FunctionInvocation(tfd, args) if tfd.fd == inner => + case fi@FunctionInvocation(tfd, args) if tfd.fd == inner => Some(FunctionInvocation( newFd.typed(tfd.tps ++ tpFresh.map{ _.tp }), args ++ freshVals.drop(args.length).map(Variable) - )) + ).setPos(fi)) case _ => None }(instBody) diff --git a/src/main/scala/leon/xlang/ImperativeCodeElimination.scala b/src/main/scala/leon/xlang/ImperativeCodeElimination.scala index e9802d500..210612e99 100644 --- a/src/main/scala/leon/xlang/ImperativeCodeElimination.scala +++ b/src/main/scala/leon/xlang/ImperativeCodeElimination.scala @@ -129,68 +129,22 @@ object ImperativeCodeElimination extends UnitPhase[Program] { (resId.toVariable, scope, scrutFun ++ modifiedVars.zip(freshIds).toMap) case wh@While(cond, body) => - //TODO: rewrite by re-using the nested function transformation code - val (condRes, condScope, condFun) = toFunction(cond) - val (_, bodyScope, bodyFun) = toFunction(body) - val condBodyFun = condFun ++ bodyFun - - val modifiedVars: Seq[Identifier] = condBodyFun.keys.toSet.intersect(varsInScope).toSeq - - if(modifiedVars.isEmpty) - (UnitLiteral(), (b: Expr) => b, Map()) - else { - val whileFunVars = modifiedVars.map(id => FreshIdentifier(id.name, id.getType)) - val modifiedVars2WhileFunVars = modifiedVars.zip(whileFunVars).toMap - val whileFunValDefs = whileFunVars.map(ValDef(_)) - val whileFunReturnType = tupleTypeWrap(whileFunVars.map(_.getType)) - val whileFunDef = new FunDef(parent.id.freshen, Nil, whileFunValDefs, whileFunReturnType).setPos(wh) - whileFunDef.addFlag(IsLoop(parent)) - - val whileFunCond = condScope(condRes) - val whileFunRecursiveCall = replaceNames(condFun, - bodyScope(FunctionInvocation(whileFunDef.typed, modifiedVars.map(id => condBodyFun(id).toVariable)).setPos(wh))) - val whileFunBaseCase = - tupleWrap(modifiedVars.map(id => condFun.getOrElse(id, modifiedVars2WhileFunVars(id)).toVariable)) - val whileFunBody = replaceNames(modifiedVars2WhileFunVars, - condScope(IfExpr(whileFunCond, whileFunRecursiveCall, whileFunBaseCase))) - whileFunDef.body = Some(whileFunBody) - - val resVar = Variable(FreshIdentifier("res", whileFunReturnType)) - val whileFunVars2ResultVars: Map[Expr, Expr] = - whileFunVars.zipWithIndex.map{ case (v, i) => - (v.toVariable, tupleSelect(resVar, i+1, whileFunVars.size)) - }.toMap - val modifiedVars2ResultVars: Map[Expr, Expr] = modifiedVars.map(id => - (id.toVariable, whileFunVars2ResultVars(modifiedVars2WhileFunVars(id).toVariable))).toMap - - //the mapping of the trivial post condition variables depends on whether the condition has had some side effect - val trivialPostcondition: Option[Expr] = Some(Not(replace( - modifiedVars.map(id => (condFun.getOrElse(id, id).toVariable, modifiedVars2ResultVars(id.toVariable))).toMap, - whileFunCond))) - val invariantPrecondition: Option[Expr] = wh.invariant.map(expr => replaceNames(modifiedVars2WhileFunVars, expr)) - val invariantPostcondition: Option[Expr] = wh.invariant.map(expr => replace(modifiedVars2ResultVars, expr)) - whileFunDef.precondition = invariantPrecondition - whileFunDef.postcondition = trivialPostcondition.map( expr => - Lambda( - Seq(ValDef(resVar.id)), - and(expr, invariantPostcondition.getOrElse(BooleanLiteral(true))).setPos(wh) - ).setPos(wh) - ) - - val finalVars = modifiedVars.map(_.freshen) - val finalScope = (body: Expr) => { - val tupleId = FreshIdentifier("t", whileFunReturnType) - LetDef(whileFunDef, Let( - tupleId, - FunctionInvocation(whileFunDef.typed, modifiedVars.map(_.toVariable)).setPos(wh), - finalVars.zipWithIndex.foldLeft(body) { (b, id) => - Let(id._1, tupleSelect(tupleId.toVariable, id._2 + 1, finalVars.size), b) - } - )) - } + val whileFunDef = new FunDef(parent.id.freshen, Nil, Nil, UnitType).setPos(wh) + whileFunDef.addFlag(IsLoop(parent)) + whileFunDef.body = Some( + IfExpr(cond, + Block(Seq(body), FunctionInvocation(whileFunDef.typed, Seq()).setPos(wh)), + UnitLiteral())) + whileFunDef.precondition = wh.invariant + whileFunDef.postcondition = Some( + Lambda( + Seq(ValDef(FreshIdentifier("bodyRes", UnitType))), + and(Not(getFunctionalResult(cond)), wh.invariant.getOrElse(BooleanLiteral(true))).setPos(wh) + ).setPos(wh) + ) - (UnitLiteral(), finalScope, modifiedVars.zip(finalVars).toMap) - } + val newExpr = LetDef(whileFunDef, FunctionInvocation(whileFunDef.typed, Seq()).setPos(wh)).setPos(wh) + toFunction(newExpr) case Block(Seq(), expr) => toFunction(expr) @@ -279,6 +233,7 @@ object ImperativeCodeElimination extends UnitPhase[Program] { val modifiedVars: List[Identifier] = collect[Identifier]({ case Assignment(v, _) => Set(v) + case FunctionInvocation(tfd, _) => state.funDefsMapping.get(tfd.fd).map(p => p._2.toSet).getOrElse(Set()) case _ => Set() })(bd).intersect(state.varsInScope).toList @@ -304,11 +259,14 @@ object ImperativeCodeElimination extends UnitPhase[Program] { val newReturnType = TupleType(fd.returnType :: modifiedVars.map(_.getType)) val newFd = new FunDef(fd.id.freshen, fd.tparams, newParams, newReturnType).setPos(fd) + newFd.addFlags(fd.flags) val (fdRes, fdScope, fdFun) = toFunction(wrappedBody)( - State(state.parent, Set(), - state.funDefsMapping + (fd -> ((newFd, freshVarDecls)))) + State(state.parent, + Set(), + state.funDefsMapping.map{case (fd, (nfd, mvs)) => (fd, (nfd, mvs.map(v => rewritingMap.getOrElse(v, v))))} + + (fd -> ((newFd, freshVarDecls)))) ) val newRes = Tuple(fdRes :: freshVarDecls.map(vd => fdFun(vd).toVariable)) val newBody = fdScope(newRes) @@ -367,4 +325,13 @@ object ImperativeCodeElimination extends UnitPhase[Program] { def replaceNames(fun: Map[Identifier, Identifier], expr: Expr) = replaceFromIDs(fun mapValues Variable, expr) + + /* Extract functional result value. Useful to remove side effect from conditions when moving it to post-condition */ + private def getFunctionalResult(expr: Expr): Expr = { + preMap({ + case Block(_, res) => Some(res) + case _ => None + })(expr) + } + } diff --git a/src/test/resources/regression/verification/xlang/valid/WhileAsFun1.scala b/src/test/resources/regression/verification/xlang/valid/WhileAsFun1.scala new file mode 100644 index 000000000..b81ea3859 --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/WhileAsFun1.scala @@ -0,0 +1,25 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ +import leon.lang._ + +object WhileAsFun1 { + + + def counterN(n: Int): Int = { + require(n > 0) + + var i = 0 + def rec(): Unit = { + require(i >= 0 && i <= n) + if(i < n) { + i += 1 + rec() + } else { + () + } + } ensuring(_ => i >= 0 && i <= n && i >= n) + rec() + + i + } ensuring(_ == n) + +} diff --git a/src/test/resources/regression/verification/xlang/valid/WhileAsFun2.scala b/src/test/resources/regression/verification/xlang/valid/WhileAsFun2.scala new file mode 100644 index 000000000..968aadfdb --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/WhileAsFun2.scala @@ -0,0 +1,33 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ +import leon.lang._ + +object WhileAsFun2 { + + + def counterN(n: Int): Int = { + require(n > 0) + + var counter = 0 + + def inc(): Unit = { + counter += 1 + } + + var i = 0 + def rec(): Unit = { + require(i >= 0 && counter == i && i <= n) + if(i < n) { + inc() + i += 1 + rec() + } else { + () + } + } ensuring(_ => i >= 0 && counter == i && i <= n && i >= n) + rec() + + + counter + } ensuring(_ == n) + +} -- GitLab