diff --git a/src/main/scala/leon/xlang/AntiAliasingPhase.scala b/src/main/scala/leon/xlang/AntiAliasingPhase.scala index 2b7176fb3dffeeaa4f2c42b8d319176b962265d2..baa8a0d2f9c05791e13a8c3639293688feeaeb1c 100644 --- a/src/main/scala/leon/xlang/AntiAliasingPhase.scala +++ b/src/main/scala/leon/xlang/AntiAliasingPhase.scala @@ -15,6 +15,7 @@ import leon.purescala.DependencyFinder import leon.purescala.DefinitionTransformer import leon.utils.Bijection import leon.xlang.Expressions._ +import leon.xlang.ExprOps._ object AntiAliasingPhase extends TransformationPhase { @@ -210,7 +211,7 @@ object AntiAliasingPhase extends TransformationPhase { val newBody = fd.body.map(body => { - val freshBody = replaceFromIDs(rewritingMap.map(p => (p._1, p._2.toVariable)), body) + val freshBody = rewriteIDs(rewritingMap, body) val explicitBody = makeSideEffectsExplicit(freshBody, fd, freshLocalVars, effects, updatedFunDefs, varsInScope)(ctx) //only now we rewrite function parameters that changed names when the new function was introduced @@ -352,7 +353,7 @@ object AntiAliasingPhase extends TransformationPhase { } //we need to replace local fundef by the new updated fun defs. - case l@LetDef(fds, body) => { + case l@LetDef(fds, body) => { //this might be traversed several time in case of doubly nested fundef, //so we need to ignore the second times by checking if updatedFunDefs //contains a mapping or not diff --git a/src/main/scala/leon/xlang/ExprOps.scala b/src/main/scala/leon/xlang/ExprOps.scala index e4680286f9f52c1462f7510433606b93c862f8e9..35d7914e241d5dbf2bdb26e10bb2b249dfb0b336 100644 --- a/src/main/scala/leon/xlang/ExprOps.scala +++ b/src/main/scala/leon/xlang/ExprOps.scala @@ -6,6 +6,7 @@ package xlang import purescala.Expressions._ import xlang.Expressions._ import purescala.ExprOps._ +import purescala.Common._ object ExprOps { @@ -38,5 +39,14 @@ object ExprOps { None })(expr) } + + def rewriteIDs(substs: Map[Identifier, Identifier], expr: Expr) : Expr = { + postMap({ + case Assignment(i, v) => substs.get(i).map(ni => Assignment(ni, v)) + case FieldAssignment(o, i, v) => substs.get(i).map(ni => FieldAssignment(o, ni, v)) + case Variable(i) => substs.get(i).map(ni => Variable(ni)) + case _ => None + })(expr) + } } diff --git a/src/test/resources/regression/verification/xlang/valid/NestedFunParamsMutation3.scala b/src/test/resources/regression/verification/xlang/valid/NestedFunParamsMutation3.scala new file mode 100644 index 0000000000000000000000000000000000000000..e4f149333b03687a0532ab0d5d88e51b0cd8d5bf --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/NestedFunParamsMutation3.scala @@ -0,0 +1,20 @@ +object NestedFunParamsMutation3 { + + case class Counter(var i: BigInt) { + def reset() = { + i = 0 + } + } + + + def main(c: Counter): Unit = { + + def sub(): Unit = { + c.reset() + } + sub() + assert(c.i == 0) + } + +} +