Skip to content
Snippets Groups Projects
Commit 33d8bd6e authored by Regis Blanc's avatar Regis Blanc
Browse files

anti-aliasing handling of higher-order functions

parent c6ec9578
No related branches found
No related tags found
No related merge requests found
......@@ -63,6 +63,10 @@ object AntiAliasingPhase extends TransformationPhase {
* Adapt the signature to express its effects. In case the
* function has no effect, this will still return the original
* fundef.
*
* Also update FunctionType parameters that need to become explicit
* about the effect they could perform (returning any mutable type that
* they receive).
*/
private def updateFunDef(fd: FunDef, effects: Effects)(ctx: LeonContext): FunDef = {
......@@ -72,22 +76,21 @@ object AntiAliasingPhase extends TransformationPhase {
case _ => None
}.map(_.id)
val newParams = fd.params.map(vd => vd.getType match {
case (ft: FunctionType) => ValDef(vd.id.duplicate(tpe = makeFunctionTypeExplicit(ft)))
case _ => vd
})
fd.body.foreach(body => getReturnedExpr(body).foreach{
case v@Variable(id) if aliasedParams.contains(id) =>
ctx.reporter.fatalError(v.getPos, "Cannot return a shared reference to a mutable object")
case _ => ()
})
//val allBodies: Set[Expr] =
// fd.body.toSet.flatMap((bd: Expr) => nestedFunDefsOf(bd).flatMap(_.body)) ++ fd.body
//allBodies.foreach(body => getReturnedExpr(body).foreach{
// case v@Variable(id) if aliasedParams.contains(id) =>
// ctx.reporter.fatalError(v.getPos, "Cannot return a shared reference to a mutable object: "k+ v)
// case _ => ()
//})
if(aliasedParams.isEmpty) fd else {
if(aliasedParams.isEmpty && fd.params == newParams) fd else {
val newReturnType: TypeTree = tupleTypeWrap(fd.returnType +: aliasedParams.map(_.getType))
val newFunDef = new FunDef(fd.id.freshen, fd.tparams, fd.params, newReturnType)
val newFunDef = new FunDef(fd.id.freshen, fd.tparams, newParams, newReturnType)
newFunDef.addFlags(fd.flags)
newFunDef.setPos(fd)
newFunDef
......@@ -107,7 +110,7 @@ object AntiAliasingPhase extends TransformationPhase {
if(aliasedParams.isEmpty) {
val newBody = fd.body.map(body => {
makeSideEffectsExplicit(body, Seq(), effects, updatedFunDefs, varsInScope)(ctx)
makeSideEffectsExplicit(body, fd, Seq(), effects, updatedFunDefs, varsInScope)(ctx)
})
newFunDef.body = newBody
newFunDef.precondition = fd.precondition
......@@ -119,7 +122,7 @@ object AntiAliasingPhase extends TransformationPhase {
val newBody = fd.body.map(body => {
val freshBody = replaceFromIDs(rewritingMap.map(p => (p._1, p._2.toVariable)), body)
val explicitBody = makeSideEffectsExplicit(freshBody, freshLocalVars, effects, updatedFunDefs, varsInScope)(ctx)
val explicitBody = makeSideEffectsExplicit(freshBody, fd, freshLocalVars, effects, updatedFunDefs, varsInScope)(ctx)
//WARNING: only works if side effects in Tuples are extracted from left to right,
// in the ImperativeTransformation phase.
......@@ -154,8 +157,46 @@ object AntiAliasingPhase extends TransformationPhase {
//We turn all local val of mutable objects into vars and explicit side effects
//using assignments. We also make sure that no aliasing is being done.
private def makeSideEffectsExplicit
(body: Expr, aliasedParams: Seq[Identifier], effects: Effects, updatedFunDefs: Map[FunDef, FunDef], varsInScope: Map[FunDef, Set[Identifier]])
(body: Expr, originalFd: FunDef, aliasedParams: Seq[Identifier], effects: Effects, updatedFunDefs: Map[FunDef, FunDef], varsInScope: Map[FunDef, Set[Identifier]])
(ctx: LeonContext): Expr = {
val newFunDef = updatedFunDefs(originalFd)
def mapApplication(args: Seq[Expr], nfi: Expr, nfiType: TypeTree, fiEffects: Set[Int], rewritings: Map[Identifier, Expr]): Expr = {
if(fiEffects.nonEmpty) {
val modifiedArgs: Seq[(Identifier, Expr)] =
args.zipWithIndex.filter{ case (arg, i) => fiEffects.contains(i) }
.map(arg => {
val rArg = replaceFromIDs(rewritings, arg._1)
(findReceiverId(rArg).get, rArg)
})
val duplicatedParams = modifiedArgs.diff(modifiedArgs.distinct).distinct
if(duplicatedParams.nonEmpty)
ctx.reporter.fatalError(nfi.getPos, "Illegal passing of aliased parameter: " + duplicatedParams.head)
val freshRes = FreshIdentifier("res", nfiType)
val extractResults = Block(
modifiedArgs.zipWithIndex.map{ case ((id, expr), index) => {
val resSelect = TupleSelect(freshRes.toVariable, index + 2)
expr match {
case CaseClassSelector(_, obj, mid) =>
Assignment(id, deepCopy(obj, mid, resSelect))
case _ =>
Assignment(id, resSelect)
}
}},
TupleSelect(freshRes.toVariable, 1))
val newExpr = Let(freshRes, nfi, extractResults)
newExpr
} else {
nfi
}
}
preMapWithContext[(Set[Identifier], Map[Identifier, Expr])]((expr, context) => {
val bindings = context._1
val rewritings = context._2
......@@ -212,6 +253,36 @@ object AntiAliasingPhase extends TransformationPhase {
(Some(LetDef(nfds, body).copiedFrom(l)), context)
}
case lambda@Lambda(params, body) => {
val ft@FunctionType(_, _) = lambda.getType
val ownEffects = functionTypeEffects(ft)
val aliasedParams: Seq[Identifier] = params.zipWithIndex.flatMap{
case (vd, i) if ownEffects.contains(i) => Some(vd)
case _ => None
}.map(_.id)
if(aliasedParams.isEmpty) {
(None, context)
} else {
val freshLocalVars: Seq[Identifier] = aliasedParams.map(v => v.freshen)
val rewritingMap: Map[Identifier, Identifier] = aliasedParams.zip(freshLocalVars).toMap
val freshBody = replaceFromIDs(rewritingMap.map(p => (p._1, p._2.toVariable)), body)
val explicitBody = makeSideEffectsExplicit(freshBody, originalFd, freshLocalVars, effects, updatedFunDefs, varsInScope)(ctx)
//WARNING: only works if side effects in Tuples are extracted from left to right,
// in the ImperativeTransformation phase.
val finalBody: Expr = Tuple(explicitBody +: freshLocalVars.map(_.toVariable))
val wrappedBody: Expr = freshLocalVars.zip(aliasedParams).foldLeft(finalBody)((bd, vp) => {
LetVar(vp._1, Variable(vp._2), bd)
})
val finalLambda = Lambda(params, wrappedBody).copiedFrom(lambda)
(Some(finalLambda), context)
}
}
case fi@FunctionInvocation(fd, args) => {
val vis: Set[Identifier] = varsInScope.get(fd.fd).getOrElse(Set())
......@@ -221,47 +292,35 @@ object AntiAliasingPhase extends TransformationPhase {
}).foreach(aliasedArg =>
ctx.reporter.fatalError(aliasedArg.getPos, "Illegal passing of aliased parameter: " + aliasedArg))
updatedFunDefs.get(fd.fd) match {
case None => (None, context)
case Some(nfd) => {
val nfi = FunctionInvocation(nfd.typed(fd.tps), args.map(arg => replaceFromIDs(rewritings, arg))).copiedFrom(fi)
val fiEffects = effects.getOrElse(fd.fd, Set())
if(fiEffects.nonEmpty) {
val modifiedArgs: Seq[(Identifier, Expr)] =
args.zipWithIndex.filter{ case (arg, i) => fiEffects.contains(i) }
.map(arg => {
val rArg = replaceFromIDs(rewritings, arg._1)
(findReceiverId(rArg).get, rArg)
})
val duplicatedParams = modifiedArgs.diff(modifiedArgs.distinct).distinct
if(duplicatedParams.nonEmpty)
ctx.reporter.fatalError(fi.getPos, "Illegal passing of aliased parameter: " + duplicatedParams.head)
val freshRes = FreshIdentifier("res", nfd.typed(fd.tps).returnType)
val extractResults = Block(
modifiedArgs.zipWithIndex.map{ case ((id, expr), index) => {
val resSelect = TupleSelect(freshRes.toVariable, index + 2)
expr match {
case CaseClassSelector(_, obj, mid) =>
Assignment(id, deepCopy(obj, mid, resSelect))
case _ =>
Assignment(id, resSelect)
}
}},
TupleSelect(freshRes.toVariable, 1))
val newExpr = Let(freshRes, nfi, extractResults)
(Some(newExpr), context)
} else {
(Some(nfi), context)
}
(Some(mapApplication(args, nfi, nfd.typed(fd.tps).returnType, fiEffects, rewritings)), context)
}
}
}
case app@Application(callee@Variable(id), args) => {
originalFd.params.zip(newFunDef.params)
.find(p => p._1.id == id)
.map(p => p._2.id) match {
case Some(newId) =>
val ft@FunctionType(_, _) = callee.getType
val nft = makeFunctionTypeExplicit(ft)
if(ft == nft) (None, context) else {
val nfi = Application(Variable(newId).copiedFrom(callee), args.map(arg => replaceFromIDs(rewritings, arg))).copiedFrom(app)
val fiEffects = functionTypeEffects(ft)
(Some(mapApplication(args, nfi, nft.to, fiEffects, rewritings)), context)
}
case None => (None, context)
}
}
case _ => (None, context)
}
......@@ -346,6 +405,25 @@ object AntiAliasingPhase extends TransformationPhase {
effects
}
//convert a function type with mutable parameters, into a function type
//that returns the mutable parameters. This makes explicit all possible
//effects of the function. This should be used for higher order functions
//declared as parameters.
def makeFunctionTypeExplicit(tpe: FunctionType): FunctionType = {
val newReturnTypes = tpe.from.filter(t => isMutableType(t))
if(newReturnTypes.isEmpty)
tpe
else {
FunctionType(tpe.from, TupleType(tpe.to +: newReturnTypes))
}
}
def functionTypeEffects(ft: FunctionType): Set[Int] = {
ft.from.zipWithIndex.flatMap{ case (vd, i) =>
if(isMutableType(vd.getType)) Some(i) else None
}.toSet
}
//for a given id, compute the identifiers that alias it or some part of the object refered by id
def computeLocalAliases(id: Identifier, body: Expr): Set[Identifier] = {
def pre(expr: Expr, ids: Set[Identifier]): Set[Identifier] = expr match {
......@@ -455,6 +533,14 @@ object AntiAliasingPhase extends TransformationPhase {
private def isMutationOf(expr: Expr, id: Identifier): Boolean = expr match {
case ArrayUpdate(Variable(a), _, _) => a == id
case FieldAssignment(obj, _, _) => findReceiverId(obj).exists(_ == id)
case Application(callee, args) => {
val ft@FunctionType(_, _) = callee.getType
val effects = functionTypeEffects(ft)
args.zipWithIndex.exists{
case (Variable(argId), index) => argId == id && effects.contains(index)
case _ => false
}
}
case _ => false
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment