Skip to content
Snippets Groups Projects
Commit c07e0c96 authored by Regis Blanc's avatar Regis Blanc
Browse files

Primitive support for aliasing in pattern matching

parent 44a031bb
No related branches found
No related tags found
No related merge requests found
......@@ -667,6 +667,9 @@ object ExprOps extends { val Deconstructor = Operator } with SubTreeOps[Expr] {
case (e, p) => mapForPattern(e, p)
}.toMap
case InstanceOfPattern(b, ct) =>
bindIn(b, Some(ct))
case other =>
bindIn(other.binder)
}
......
......@@ -156,97 +156,116 @@ object AntiAliasingPhase extends TransformationPhase {
private def makeSideEffectsExplicit
(body: Expr, aliasedParams: Seq[Identifier], effects: Effects, updatedFunDefs: Map[FunDef, FunDef], varsInScope: Map[FunDef, Set[Identifier]])
(ctx: LeonContext): Expr = {
preMapWithContext[Set[Identifier]]((expr, bindings) => expr match {
preMapWithContext[(Set[Identifier], Map[Identifier, Expr])]((expr, context) => {
val bindings = context._1
val rewritings = context._2
expr match {
case l@Let(id, IsTyped(v, tpe), b) if isMutableType(tpe) => {
val varDecl = LetVar(id, v, b).setPos(l)
(Some(varDecl), (bindings + id, rewritings))
}
case up@ArrayUpdate(a, i, v) => {
val ra@Variable(id) = a
if(bindings.contains(id))
(Some(Assignment(id, ArrayUpdated(ra, i, v).setPos(up)).setPos(up)), bindings)
else
(None, bindings)
}
case l@LetVar(id, IsTyped(v, tpe), b) if isMutableType(tpe) => {
(None, (bindings + id, rewritings))
}
case as@FieldAssignment(o, id, v) => {
findReceiverId(o) match {
case None =>
ctx.reporter.fatalError(as.getPos, "Unsupported form of field assignment: " + as)
case Some(oid) => {
if(bindings.contains(oid))
(Some(Assignment(oid, deepCopy(o, id, v))), bindings)
else
(None, bindings)
}
case m@MatchExpr(scrut, cses) if isMutableType(scrut.getType) => {
val tmp: Map[Identifier, Expr] = cses.flatMap{ case MatchCase(pattern, guard, rhs) => {
mapForPattern(scrut, pattern)
//val binder = pattern.binder.get
//binder -> scrut
}}.toMap
(None, (bindings, rewritings ++ tmp))
}
}
case l@Let(id, IsTyped(v, tpe), b) if isMutableType(tpe) => {
val varDecl = LetVar(id, v, b).setPos(l)
(Some(varDecl), bindings + id)
}
case up@ArrayUpdate(a, i, v) => {
val ra@Variable(id) = a
if(bindings.contains(id))
(Some(Assignment(id, ArrayUpdated(ra, i, v).setPos(up)).setPos(up)), context)
else
(None, context)
}
case l@LetVar(id, IsTyped(v, tpe), b) if isMutableType(tpe) => {
(None, bindings + id)
}
case as@FieldAssignment(o, id, v) => {
val so = replaceFromIDs(rewritings, o)
findReceiverId(so) match {
case None =>
ctx.reporter.fatalError(as.getPos, "Unsupported form of field assignment: " + as)
case Some(oid) => {
if(bindings.contains(oid))
(Some(Assignment(oid, deepCopy(so, id, v))), context)
else
(None, context)
}
}
}
//we need to replace local fundef by the new updated fun defs.
case l@LetDef(fds, body) => {
//this might be traversed several time in case of doubly nested fundef,
//so we need to ignore the second times by checking if updatedFunDefs
//contains a mapping or not
val nfds = fds.map(fd => updatedFunDefs.get(fd).getOrElse(fd))
(Some(LetDef(nfds, body).copiedFrom(l)), bindings)
}
//we need to replace local fundef by the new updated fun defs.
case l@LetDef(fds, body) => {
//this might be traversed several time in case of doubly nested fundef,
//so we need to ignore the second times by checking if updatedFunDefs
//contains a mapping or not
val nfds = fds.map(fd => updatedFunDefs.get(fd).getOrElse(fd))
(Some(LetDef(nfds, body).copiedFrom(l)), context)
}
case fi@FunctionInvocation(fd, args) => {
val vis: Set[Identifier] = varsInScope.get(fd.fd).getOrElse(Set())
args.find({
case Variable(id) => vis.contains(id)
case _ => false
}).foreach(aliasedArg =>
ctx.reporter.fatalError(aliasedArg.getPos, "Illegal passing of aliased parameter: " + aliasedArg))
updatedFunDefs.get(fd.fd) match {
case None => (None, bindings)
case Some(nfd) => {
val nfi = FunctionInvocation(nfd.typed(fd.tps), args).copiedFrom(fi)
val fiEffects = effects.getOrElse(fd.fd, Set())
if(fiEffects.nonEmpty) {
val modifiedArgs: Seq[(Identifier, Expr)] =// functionInvocationEffects(fi, fiEffects)
args.zipWithIndex.filter{ case (arg, i) => fiEffects.contains(i) }
.map(arg => (findReceiverId(arg._1).get, arg._1))
val duplicatedParams = modifiedArgs.diff(modifiedArgs.distinct).distinct
if(duplicatedParams.nonEmpty)
ctx.reporter.fatalError(fi.getPos, "Illegal passing of aliased parameter: " + duplicatedParams.head)
val freshRes = FreshIdentifier("res", nfd.typed(fd.tps).returnType)
val extractResults = Block(
modifiedArgs.zipWithIndex.map{ case ((id, expr), index) => {
val resSelect = TupleSelect(freshRes.toVariable, index + 2)
expr match {
case CaseClassSelector(_, obj, mid) =>
Assignment(id, deepCopy(obj, mid, resSelect))
case _ =>
Assignment(id, resSelect)
}
}},
TupleSelect(freshRes.toVariable, 1))
val newExpr = Let(freshRes, nfi, extractResults)
(Some(newExpr), bindings)
} else {
(Some(nfi), bindings)
case fi@FunctionInvocation(fd, args) => {
val vis: Set[Identifier] = varsInScope.get(fd.fd).getOrElse(Set())
args.find({
case Variable(id) => vis.contains(id)
case _ => false
}).foreach(aliasedArg =>
ctx.reporter.fatalError(aliasedArg.getPos, "Illegal passing of aliased parameter: " + aliasedArg))
updatedFunDefs.get(fd.fd) match {
case None => (None, context)
case Some(nfd) => {
val nfi = FunctionInvocation(nfd.typed(fd.tps), args.map(arg => replaceFromIDs(rewritings, arg))).copiedFrom(fi)
val fiEffects = effects.getOrElse(fd.fd, Set())
if(fiEffects.nonEmpty) {
val modifiedArgs: Seq[(Identifier, Expr)] =
args.zipWithIndex.filter{ case (arg, i) => fiEffects.contains(i) }
.map(arg => {
val rArg = replaceFromIDs(rewritings, arg._1)
(findReceiverId(rArg).get, rArg)
})
val duplicatedParams = modifiedArgs.diff(modifiedArgs.distinct).distinct
if(duplicatedParams.nonEmpty)
ctx.reporter.fatalError(fi.getPos, "Illegal passing of aliased parameter: " + duplicatedParams.head)
val freshRes = FreshIdentifier("res", nfd.typed(fd.tps).returnType)
val extractResults = Block(
modifiedArgs.zipWithIndex.map{ case ((id, expr), index) => {
val resSelect = TupleSelect(freshRes.toVariable, index + 2)
expr match {
case CaseClassSelector(_, obj, mid) =>
Assignment(id, deepCopy(obj, mid, resSelect))
case _ =>
Assignment(id, resSelect)
}
}},
TupleSelect(freshRes.toVariable, 1))
val newExpr = Let(freshRes, nfi, extractResults)
(Some(newExpr), context)
} else {
(Some(nfi), context)
}
}
}
}
}
case _ => (None, bindings)
case _ => (None, context)
}
})(body, aliasedParams.toSet)
})(body, (aliasedParams.toSet, Map()))
}
//for each fundef, the set of modified params (by index)
......@@ -273,7 +292,8 @@ object AntiAliasingPhase extends TransformationPhase {
effects += (fd -> Set())
case Some(body) => {
val mutableParams = fd.params.filter(vd => isMutableType(vd.getType))
val mutatedParams = mutableParams.filter(vd => exists(expr => isMutationOf(expr, vd.id))(body))
val localAliases: Map[ValDef, Set[Identifier]] = mutableParams.map(vd => (vd, computeLocalAliases(vd.id, body))).toMap
val mutatedParams = mutableParams.filter(vd => exists(expr => localAliases(vd).exists(id => isMutationOf(expr, id)))(body))
val mutatedParamsIndices = fd.params.zipWithIndex.flatMap{
case (vd, i) if mutatedParams.contains(vd) => Some(i)
case _ => None
......@@ -326,6 +346,21 @@ object AntiAliasingPhase extends TransformationPhase {
effects
}
//for a given id, compute the identifiers that alias it or some part of the object refered by id
def computeLocalAliases(id: Identifier, body: Expr): Set[Identifier] = {
def pre(expr: Expr, ids: Set[Identifier]): Set[Identifier] = expr match {
case l@Let(i, Variable(v), _) if ids.contains(v) => ids + i
case m@MatchExpr(Variable(v), cses) if ids.contains(v) => {
val newIds = cses.flatMap(mc => mc.pattern.binders)
ids ++ newIds
}
case e => ids
}
def combiner(e: Expr, ctx: Set[Identifier], ids: Seq[Set[Identifier]]): Set[Identifier] = ctx ++ ids.toSet.flatten + id
val res = preFoldWithContext(pre, combiner)(body, Set(id))
res
}
def checkAliasing(fd: FunDef)(ctx: LeonContext): Unit = {
def checkReturnValue(body: Expr, bindings: Set[Identifier]): Unit = {
......@@ -400,13 +435,16 @@ object AntiAliasingPhase extends TransformationPhase {
private def findReceiverId(o: Expr): Option[Identifier] = o match {
case Variable(id) => Some(id)
case CaseClassSelector(_, e, _) => findReceiverId(e)
case AsInstanceOf(e, ct) => findReceiverId(e)
case _ => None
}
private def isMutableType(tpe: TypeTree): Boolean = tpe match {
private def isMutableType(tpe: TypeTree, abstractClasses: Set[ClassType] = Set()): Boolean = tpe match {
case (ct: ClassType) if abstractClasses.contains(ct) => false
case (arr: ArrayType) => true
case CaseClassType(ccd, _) if ccd.fields.exists(vd => vd.isVar || isMutableType(vd.getType)) => true
case CaseClassType(ccd, _) => ccd.fields.exists(vd => vd.isVar || isMutableType(vd.getType, abstractClasses))
case (ct: ClassType) => ct.knownDescendants.exists(c => isMutableType(c, abstractClasses + ct))
case _ => false
}
......
......@@ -3,6 +3,7 @@
package leon
package xlang
import utils._
import purescala.Definitions.Program
object XLangDesugaringPhase extends LeonPhase[Program, Program] {
......@@ -11,8 +12,13 @@ object XLangDesugaringPhase extends LeonPhase[Program, Program] {
val description = "Desugar xlang features into PureScala"
override def run(ctx: LeonContext, pgm: Program): (LeonContext, Program) = {
def debugTrees(title: String) =
PrintTreePhase(title).when(ctx.reporter.isDebugEnabled(DebugSectionTrees))
val phases =
AntiAliasingPhase andThen
debugTrees("Program after anti-aliasing") andThen
EpsilonElimination andThen
ImperativeCodeElimination
......
object PatternMatchingAliasingMutation1 {
abstract class A
case class B(var x: Int) extends A
case class C(var y: Int) extends A
def updateValue(a: A, newVal: Int): Unit = a match {
case (b: B) => b.x = newVal
case (c: C) => c.y = newVal
}
def f(): Int = {
val b = B(10)
updateValue(b, 15)
b.x
} ensuring(_ == 15)
}
object PatternMatchingAliasingMutation2 {
abstract class A
case class B(var x: Int) extends A
case class C(var y: Int) extends A
def updateValue(a: A, newVal: Int): Unit = a match {
case b@B(_) => b.x = newVal
case c@C(_) => c.y = newVal
}
def f(): Int = {
val b = B(10)
updateValue(b, 15)
b.x
} ensuring(_ == 15)
}
object PatternMatchingAliasingMutation3 {
case class MutableObject(var x: Int)
abstract class A
case class B(m: MutableObject) extends A
case class C(m: MutableObject) extends A
def updateValue(a: A, newVal: Int): Unit = a match {
case B(m) => m.x = newVal
case C(m) => m.x = newVal
}
def f(): Int = {
val b = B(MutableObject(10))
updateValue(b, 15)
b.m.x
} ensuring(_ == 15)
}
object PatternMatchingAliasingMutation4 {
case class A(var x: Int)
abstract class List
case class Cons(a: A, tail: List) extends List
case class Nil() extends List
def rec(l: List): Unit = (l match {
case Cons(a, as) =>
a.x = 0
rec(as)
case Nil() =>
()
}) ensuring(_ => allZero(l))
def allZero(l: List): Boolean = l match {
case Cons(a, tail) => a.x == 0 && allZero(tail)
case Nil() => true
}
def test(): List = {
val l = Cons(A(2), Cons(A(1), Cons(A(0), Nil())))
rec(l)
l
} ensuring(l => allZero(l))
}
object PatternMatchingAliasingMutation5 {
case class A(var x: Int)
abstract class List
case class Cons(a: A, tail: List) extends List
case class Nil() extends List
def rec(l: List, i: BigInt): Unit = {
require(allZero(l) && i >= 0)
l match {
case Cons(a, as) =>
if(i % 2 == 0)
a.x = 1
rec(as, i + 1)
case Nil() =>
()
}
} ensuring(_ => allZeroOrOne(l))
def allZeroOrOne(l: List): Boolean = l match {
case Cons(a, tail) => (a.x == 0 || a.x == 1) && allZeroOrOne(tail)
case Nil() => true
}
def allZero(l: List): Boolean = l match {
case Cons(a, tail) => a.x == 0 && allZero(tail)
case Nil() => true
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment