diff --git a/src/main/scala/leon/purescala/ExprOps.scala b/src/main/scala/leon/purescala/ExprOps.scala index 9724d846cedbcc10b2c22b26cd095832b53e1b27..4f7bc819650c5316ab4dc3ff3bf1cf1203cb9a1d 100644 --- a/src/main/scala/leon/purescala/ExprOps.scala +++ b/src/main/scala/leon/purescala/ExprOps.scala @@ -667,6 +667,9 @@ object ExprOps extends { val Deconstructor = Operator } with SubTreeOps[Expr] { case (e, p) => mapForPattern(e, p) }.toMap + case InstanceOfPattern(b, ct) => + bindIn(b, Some(ct)) + case other => bindIn(other.binder) } diff --git a/src/main/scala/leon/xlang/AntiAliasingPhase.scala b/src/main/scala/leon/xlang/AntiAliasingPhase.scala index 0cfc68d50381276e7c3ed5c28c2566086f8cb315..0a7b0c56316365a91414615b9814bde7f8f4bca9 100644 --- a/src/main/scala/leon/xlang/AntiAliasingPhase.scala +++ b/src/main/scala/leon/xlang/AntiAliasingPhase.scala @@ -156,97 +156,116 @@ object AntiAliasingPhase extends TransformationPhase { private def makeSideEffectsExplicit (body: Expr, aliasedParams: Seq[Identifier], effects: Effects, updatedFunDefs: Map[FunDef, FunDef], varsInScope: Map[FunDef, Set[Identifier]]) (ctx: LeonContext): Expr = { - preMapWithContext[Set[Identifier]]((expr, bindings) => expr match { + preMapWithContext[(Set[Identifier], Map[Identifier, Expr])]((expr, context) => { + val bindings = context._1 + val rewritings = context._2 + expr match { + + case l@Let(id, IsTyped(v, tpe), b) if isMutableType(tpe) => { + val varDecl = LetVar(id, v, b).setPos(l) + (Some(varDecl), (bindings + id, rewritings)) + } - case up@ArrayUpdate(a, i, v) => { - val ra@Variable(id) = a - if(bindings.contains(id)) - (Some(Assignment(id, ArrayUpdated(ra, i, v).setPos(up)).setPos(up)), bindings) - else - (None, bindings) - } + case l@LetVar(id, IsTyped(v, tpe), b) if isMutableType(tpe) => { + (None, (bindings + id, rewritings)) + } - case as@FieldAssignment(o, id, v) => { - findReceiverId(o) match { - case None => - ctx.reporter.fatalError(as.getPos, "Unsupported form of field assignment: " + as) - case Some(oid) => { - if(bindings.contains(oid)) - (Some(Assignment(oid, deepCopy(o, id, v))), bindings) - else - (None, bindings) - } + case m@MatchExpr(scrut, cses) if isMutableType(scrut.getType) => { + + val tmp: Map[Identifier, Expr] = cses.flatMap{ case MatchCase(pattern, guard, rhs) => { + mapForPattern(scrut, pattern) + //val binder = pattern.binder.get + //binder -> scrut + }}.toMap + + (None, (bindings, rewritings ++ tmp)) } - } - case l@Let(id, IsTyped(v, tpe), b) if isMutableType(tpe) => { - val varDecl = LetVar(id, v, b).setPos(l) - (Some(varDecl), bindings + id) - } + case up@ArrayUpdate(a, i, v) => { + val ra@Variable(id) = a + if(bindings.contains(id)) + (Some(Assignment(id, ArrayUpdated(ra, i, v).setPos(up)).setPos(up)), context) + else + (None, context) + } - case l@LetVar(id, IsTyped(v, tpe), b) if isMutableType(tpe) => { - (None, bindings + id) - } + case as@FieldAssignment(o, id, v) => { + val so = replaceFromIDs(rewritings, o) + findReceiverId(so) match { + case None => + ctx.reporter.fatalError(as.getPos, "Unsupported form of field assignment: " + as) + case Some(oid) => { + if(bindings.contains(oid)) + (Some(Assignment(oid, deepCopy(so, id, v))), context) + else + (None, context) + } + } + } - //we need to replace local fundef by the new updated fun defs. - case l@LetDef(fds, body) => { - //this might be traversed several time in case of doubly nested fundef, - //so we need to ignore the second times by checking if updatedFunDefs - //contains a mapping or not - val nfds = fds.map(fd => updatedFunDefs.get(fd).getOrElse(fd)) - (Some(LetDef(nfds, body).copiedFrom(l)), bindings) - } + //we need to replace local fundef by the new updated fun defs. + case l@LetDef(fds, body) => { + //this might be traversed several time in case of doubly nested fundef, + //so we need to ignore the second times by checking if updatedFunDefs + //contains a mapping or not + val nfds = fds.map(fd => updatedFunDefs.get(fd).getOrElse(fd)) + (Some(LetDef(nfds, body).copiedFrom(l)), context) + } - case fi@FunctionInvocation(fd, args) => { - - val vis: Set[Identifier] = varsInScope.get(fd.fd).getOrElse(Set()) - args.find({ - case Variable(id) => vis.contains(id) - case _ => false - }).foreach(aliasedArg => - ctx.reporter.fatalError(aliasedArg.getPos, "Illegal passing of aliased parameter: " + aliasedArg)) - - updatedFunDefs.get(fd.fd) match { - case None => (None, bindings) - case Some(nfd) => { - val nfi = FunctionInvocation(nfd.typed(fd.tps), args).copiedFrom(fi) - val fiEffects = effects.getOrElse(fd.fd, Set()) - if(fiEffects.nonEmpty) { - val modifiedArgs: Seq[(Identifier, Expr)] =// functionInvocationEffects(fi, fiEffects) - args.zipWithIndex.filter{ case (arg, i) => fiEffects.contains(i) } - .map(arg => (findReceiverId(arg._1).get, arg._1)) - - val duplicatedParams = modifiedArgs.diff(modifiedArgs.distinct).distinct - if(duplicatedParams.nonEmpty) - ctx.reporter.fatalError(fi.getPos, "Illegal passing of aliased parameter: " + duplicatedParams.head) - - val freshRes = FreshIdentifier("res", nfd.typed(fd.tps).returnType) - - val extractResults = Block( - modifiedArgs.zipWithIndex.map{ case ((id, expr), index) => { - val resSelect = TupleSelect(freshRes.toVariable, index + 2) - expr match { - case CaseClassSelector(_, obj, mid) => - Assignment(id, deepCopy(obj, mid, resSelect)) - case _ => - Assignment(id, resSelect) - } - }}, - TupleSelect(freshRes.toVariable, 1)) - - - val newExpr = Let(freshRes, nfi, extractResults) - (Some(newExpr), bindings) - } else { - (Some(nfi), bindings) + case fi@FunctionInvocation(fd, args) => { + + val vis: Set[Identifier] = varsInScope.get(fd.fd).getOrElse(Set()) + args.find({ + case Variable(id) => vis.contains(id) + case _ => false + }).foreach(aliasedArg => + ctx.reporter.fatalError(aliasedArg.getPos, "Illegal passing of aliased parameter: " + aliasedArg)) + + updatedFunDefs.get(fd.fd) match { + case None => (None, context) + case Some(nfd) => { + val nfi = FunctionInvocation(nfd.typed(fd.tps), args.map(arg => replaceFromIDs(rewritings, arg))).copiedFrom(fi) + val fiEffects = effects.getOrElse(fd.fd, Set()) + if(fiEffects.nonEmpty) { + val modifiedArgs: Seq[(Identifier, Expr)] = + args.zipWithIndex.filter{ case (arg, i) => fiEffects.contains(i) } + .map(arg => { + val rArg = replaceFromIDs(rewritings, arg._1) + (findReceiverId(rArg).get, rArg) + }) + + val duplicatedParams = modifiedArgs.diff(modifiedArgs.distinct).distinct + if(duplicatedParams.nonEmpty) + ctx.reporter.fatalError(fi.getPos, "Illegal passing of aliased parameter: " + duplicatedParams.head) + + val freshRes = FreshIdentifier("res", nfd.typed(fd.tps).returnType) + + val extractResults = Block( + modifiedArgs.zipWithIndex.map{ case ((id, expr), index) => { + val resSelect = TupleSelect(freshRes.toVariable, index + 2) + expr match { + case CaseClassSelector(_, obj, mid) => + Assignment(id, deepCopy(obj, mid, resSelect)) + case _ => + Assignment(id, resSelect) + } + }}, + TupleSelect(freshRes.toVariable, 1)) + + + val newExpr = Let(freshRes, nfi, extractResults) + (Some(newExpr), context) + } else { + (Some(nfi), context) + } } } } - } - case _ => (None, bindings) + case _ => (None, context) + } - })(body, aliasedParams.toSet) + })(body, (aliasedParams.toSet, Map())) } //for each fundef, the set of modified params (by index) @@ -273,7 +292,8 @@ object AntiAliasingPhase extends TransformationPhase { effects += (fd -> Set()) case Some(body) => { val mutableParams = fd.params.filter(vd => isMutableType(vd.getType)) - val mutatedParams = mutableParams.filter(vd => exists(expr => isMutationOf(expr, vd.id))(body)) + val localAliases: Map[ValDef, Set[Identifier]] = mutableParams.map(vd => (vd, computeLocalAliases(vd.id, body))).toMap + val mutatedParams = mutableParams.filter(vd => exists(expr => localAliases(vd).exists(id => isMutationOf(expr, id)))(body)) val mutatedParamsIndices = fd.params.zipWithIndex.flatMap{ case (vd, i) if mutatedParams.contains(vd) => Some(i) case _ => None @@ -326,6 +346,21 @@ object AntiAliasingPhase extends TransformationPhase { effects } + //for a given id, compute the identifiers that alias it or some part of the object refered by id + def computeLocalAliases(id: Identifier, body: Expr): Set[Identifier] = { + def pre(expr: Expr, ids: Set[Identifier]): Set[Identifier] = expr match { + case l@Let(i, Variable(v), _) if ids.contains(v) => ids + i + case m@MatchExpr(Variable(v), cses) if ids.contains(v) => { + val newIds = cses.flatMap(mc => mc.pattern.binders) + ids ++ newIds + } + case e => ids + } + def combiner(e: Expr, ctx: Set[Identifier], ids: Seq[Set[Identifier]]): Set[Identifier] = ctx ++ ids.toSet.flatten + id + val res = preFoldWithContext(pre, combiner)(body, Set(id)) + res + } + def checkAliasing(fd: FunDef)(ctx: LeonContext): Unit = { def checkReturnValue(body: Expr, bindings: Set[Identifier]): Unit = { @@ -400,13 +435,16 @@ object AntiAliasingPhase extends TransformationPhase { private def findReceiverId(o: Expr): Option[Identifier] = o match { case Variable(id) => Some(id) case CaseClassSelector(_, e, _) => findReceiverId(e) + case AsInstanceOf(e, ct) => findReceiverId(e) case _ => None } - private def isMutableType(tpe: TypeTree): Boolean = tpe match { + private def isMutableType(tpe: TypeTree, abstractClasses: Set[ClassType] = Set()): Boolean = tpe match { + case (ct: ClassType) if abstractClasses.contains(ct) => false case (arr: ArrayType) => true - case CaseClassType(ccd, _) if ccd.fields.exists(vd => vd.isVar || isMutableType(vd.getType)) => true + case CaseClassType(ccd, _) => ccd.fields.exists(vd => vd.isVar || isMutableType(vd.getType, abstractClasses)) + case (ct: ClassType) => ct.knownDescendants.exists(c => isMutableType(c, abstractClasses + ct)) case _ => false } diff --git a/src/main/scala/leon/xlang/XLangDesugaringPhase.scala b/src/main/scala/leon/xlang/XLangDesugaringPhase.scala index 54e298a226704a78fb99abc445e55df8895bc98f..e3325e3a922b1b7f7adb6919cb6c9fd797451da9 100644 --- a/src/main/scala/leon/xlang/XLangDesugaringPhase.scala +++ b/src/main/scala/leon/xlang/XLangDesugaringPhase.scala @@ -3,6 +3,7 @@ package leon package xlang +import utils._ import purescala.Definitions.Program object XLangDesugaringPhase extends LeonPhase[Program, Program] { @@ -11,8 +12,13 @@ object XLangDesugaringPhase extends LeonPhase[Program, Program] { val description = "Desugar xlang features into PureScala" override def run(ctx: LeonContext, pgm: Program): (LeonContext, Program) = { + + def debugTrees(title: String) = + PrintTreePhase(title).when(ctx.reporter.isDebugEnabled(DebugSectionTrees)) + val phases = AntiAliasingPhase andThen + debugTrees("Program after anti-aliasing") andThen EpsilonElimination andThen ImperativeCodeElimination diff --git a/src/test/resources/regression/verification/xlang/valid/PatternMatchingAliasingMutation1.scala b/src/test/resources/regression/verification/xlang/valid/PatternMatchingAliasingMutation1.scala new file mode 100644 index 0000000000000000000000000000000000000000..136752a4e3d0aa84366489fa3553d83500f2f867 --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/PatternMatchingAliasingMutation1.scala @@ -0,0 +1,17 @@ +object PatternMatchingAliasingMutation1 { + + abstract class A + case class B(var x: Int) extends A + case class C(var y: Int) extends A + + def updateValue(a: A, newVal: Int): Unit = a match { + case (b: B) => b.x = newVal + case (c: C) => c.y = newVal + } + + def f(): Int = { + val b = B(10) + updateValue(b, 15) + b.x + } ensuring(_ == 15) +} diff --git a/src/test/resources/regression/verification/xlang/valid/PatternMatchingAliasingMutation2.scala b/src/test/resources/regression/verification/xlang/valid/PatternMatchingAliasingMutation2.scala new file mode 100644 index 0000000000000000000000000000000000000000..fd8f16bd032660ffa5c484c120e81e8b7aba1280 --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/PatternMatchingAliasingMutation2.scala @@ -0,0 +1,17 @@ +object PatternMatchingAliasingMutation2 { + + abstract class A + case class B(var x: Int) extends A + case class C(var y: Int) extends A + + def updateValue(a: A, newVal: Int): Unit = a match { + case b@B(_) => b.x = newVal + case c@C(_) => c.y = newVal + } + + def f(): Int = { + val b = B(10) + updateValue(b, 15) + b.x + } ensuring(_ == 15) +} diff --git a/src/test/resources/regression/verification/xlang/valid/PatternMatchingAliasingMutation3.scala b/src/test/resources/regression/verification/xlang/valid/PatternMatchingAliasingMutation3.scala new file mode 100644 index 0000000000000000000000000000000000000000..17cf06b14105289792afba610189f02966bee360 --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/PatternMatchingAliasingMutation3.scala @@ -0,0 +1,19 @@ +object PatternMatchingAliasingMutation3 { + + case class MutableObject(var x: Int) + + abstract class A + case class B(m: MutableObject) extends A + case class C(m: MutableObject) extends A + + def updateValue(a: A, newVal: Int): Unit = a match { + case B(m) => m.x = newVal + case C(m) => m.x = newVal + } + + def f(): Int = { + val b = B(MutableObject(10)) + updateValue(b, 15) + b.m.x + } ensuring(_ == 15) +} diff --git a/src/test/resources/regression/verification/xlang/valid/PatternMatchingAliasingMutation4.scala b/src/test/resources/regression/verification/xlang/valid/PatternMatchingAliasingMutation4.scala new file mode 100644 index 0000000000000000000000000000000000000000..1829b5c51f1570ff528efb3cd05aafddcb2dee6f --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/PatternMatchingAliasingMutation4.scala @@ -0,0 +1,28 @@ +object PatternMatchingAliasingMutation4 { + + case class A(var x: Int) + + abstract class List + case class Cons(a: A, tail: List) extends List + case class Nil() extends List + + def rec(l: List): Unit = (l match { + case Cons(a, as) => + a.x = 0 + rec(as) + case Nil() => + () + }) ensuring(_ => allZero(l)) + + def allZero(l: List): Boolean = l match { + case Cons(a, tail) => a.x == 0 && allZero(tail) + case Nil() => true + } + + def test(): List = { + val l = Cons(A(2), Cons(A(1), Cons(A(0), Nil()))) + rec(l) + l + } ensuring(l => allZero(l)) + +} diff --git a/src/test/resources/regression/verification/xlang/valid/PatternMatchingAliasingMutation5.scala b/src/test/resources/regression/verification/xlang/valid/PatternMatchingAliasingMutation5.scala new file mode 100644 index 0000000000000000000000000000000000000000..0d415a3fa4adcbd2d67781d6d6f1eab565e6161a --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/PatternMatchingAliasingMutation5.scala @@ -0,0 +1,31 @@ +object PatternMatchingAliasingMutation5 { + + case class A(var x: Int) + + abstract class List + case class Cons(a: A, tail: List) extends List + case class Nil() extends List + + def rec(l: List, i: BigInt): Unit = { + require(allZero(l) && i >= 0) + l match { + case Cons(a, as) => + if(i % 2 == 0) + a.x = 1 + rec(as, i + 1) + case Nil() => + () + } + } ensuring(_ => allZeroOrOne(l)) + + def allZeroOrOne(l: List): Boolean = l match { + case Cons(a, tail) => (a.x == 0 || a.x == 1) && allZeroOrOne(tail) + case Nil() => true + } + + def allZero(l: List): Boolean = l match { + case Cons(a, tail) => a.x == 0 && allZero(tail) + case Nil() => true + } + +}