Skip to content
Snippets Groups Projects
Commit d7f386a0 authored by Régis Blanc's avatar Régis Blanc
Browse files

block can be used anywhere

parent c3169f91
No related branches found
No related tags found
No related merge requests found
...@@ -3,11 +3,11 @@ object IfExpr1 { ...@@ -3,11 +3,11 @@ object IfExpr1 {
def foo(): Int = { def foo(): Int = {
var a = 1 var a = 1
var b = 2 var b = 2
if(a == b) if({a = a + 1; a != b})
a = a + 3 a = a + 3
else else
b = a + b b = a + b
a a
} ensuring(_ == 1) } ensuring(_ == 2)
} }
...@@ -3,9 +3,11 @@ object Plus { ...@@ -3,9 +3,11 @@ object Plus {
def foo(): Int = ({ def foo(): Int = ({
var a = 2 var a = 2
var b = 1
{a = a + 1; a} + {a = 5 - a; a} a = {b = b + 2; a = a + 1; a} + {a = 5 - a; a}
}) ensuring(_ == 5) a + b
}) ensuring(_ == 8)
} }
object ValSideEffect {
def foo(): Int = ({
var a = 2
var a2 = 1
val b = {a = a + 1; a2 = a2 + 1; a} + {a = 5 - a; a}
a = a + 1
a2 = a2 + 3
a + a2 + b
}) ensuring(_ == 13)
}
// vim: set ts=4 sw=4 et:
object While1 {
def foo(): Int = {
var a = 0
var i = 0
while(i < 10) {
a = a + i
i = i + 1
}
a
}
}
object While1 {
def foo(): Int = {
var a = 0
var i = 0
while({i = i+2; i <= 10}) {
a = a + i
i = i - 1
}
a
} ensuring(_ == 54)
}
// vim: set ts=4 sw=4 et:
...@@ -20,15 +20,19 @@ object ImperativeCodeElimination extends Pass { ...@@ -20,15 +20,19 @@ object ImperativeCodeElimination extends Pass {
//return a "scope" consisting of purely functional code that defines potentially needed //return a "scope" consisting of purely functional code that defines potentially needed
//new variables (val, not var) and a mapping for each modified variable (var, not val :) ) //new variables (val, not var) and a mapping for each modified variable (var, not val :) )
//to their new name defined in the scope //to their new name defined in the scope. The first returned valued is the value of the expression
//that should be introduced as such in the returned scope (the val already refers to the new names)
private def toFunction(expr: Expr): (Expr, Expr => Expr, Map[Identifier, Identifier]) = { private def toFunction(expr: Expr): (Expr, Expr => Expr, Map[Identifier, Identifier]) = {
val res = expr match { val res = expr match {
case Assignment(id, e) => { case Assignment(id, e) => {
val newId = FreshIdentifier(id.name).setType(id.getType) val newId = FreshIdentifier(id.name).setType(id.getType)
val scope = ((body: Expr) => Let(newId, e, body)) val (rhsVal, rhsScope, rhsFun) = toFunction(e)
(UnitLiteral, scope, Map(id -> newId)) val scope = (body: Expr) => rhsScope(Let(newId, rhsVal, body))
(UnitLiteral, scope, rhsFun + (id -> newId))
} }
case ite@IfExpr(cond, tExpr, eExpr) => { case ite@IfExpr(cond, tExpr, eExpr) => {
val (cRes, cScope, cFun) = toFunction(cond)
val (tRes, tScope, tFun) = toFunction(tExpr) val (tRes, tScope, tFun) = toFunction(tExpr)
val (eRes, eScope, eFun) = toFunction(eExpr) val (eRes, eScope, eFun) = toFunction(eExpr)
...@@ -49,11 +53,12 @@ object ImperativeCodeElimination extends Pass { ...@@ -49,11 +53,12 @@ object ImperativeCodeElimination extends Pass {
})) }))
elseVal.setType(iteType) elseVal.setType(iteType)
val iteExpr = IfExpr(cond, tScope(thenVal), eScope(elseVal)).setType(iteType) val iteExpr = IfExpr(cRes, replaceNames(cFun, tScope(thenVal)), replaceNames(cFun, eScope(elseVal))).setType(iteType)
val scope = ((body: Expr) => { val scope = ((body: Expr) => {
val tupleId = FreshIdentifier("t").setType(iteType) val tupleId = FreshIdentifier("t").setType(iteType)
Let(tupleId, iteExpr, cScope(
Let(tupleId, iteExpr,
if(freshIds.isEmpty) if(freshIds.isEmpty)
Let(resId, tupleId.toVariable, body) Let(resId, tupleId.toVariable, body)
else else
...@@ -61,28 +66,39 @@ object ImperativeCodeElimination extends Pass { ...@@ -61,28 +66,39 @@ object ImperativeCodeElimination extends Pass {
freshIds.zipWithIndex.foldLeft(body)((b, id) => freshIds.zipWithIndex.foldLeft(body)((b, id) =>
Let(id._1, Let(id._1,
TupleSelect(tupleId.toVariable, id._2 + 2).setType(id._1.getType), TupleSelect(tupleId.toVariable, id._2 + 2).setType(id._1.getType),
b)))) b)))))
}) })
(resId.toVariable, scope, modifiedVars.zip(freshIds).toMap) (resId.toVariable, scope, cFun ++ modifiedVars.zip(freshIds).toMap)
} }
case wh@While(cond, body) => { case wh@While(cond, body) => {
val (condRes, condScope, condFun) = toFunction(cond)
val (_, bodyScope, bodyFun) = toFunction(body) val (_, bodyScope, bodyFun) = toFunction(body)
val modifiedVars: Seq[Identifier] = bodyFun.keys.toSeq val condBodyFun = condFun ++ bodyFun
val modifiedVars: Seq[Identifier] = condBodyFun.keys.toSeq
if(modifiedVars.isEmpty) if(modifiedVars.isEmpty)
(UnitLiteral, (b: Expr) => b, Map()) (UnitLiteral, (b: Expr) => b, Map())
else { else {
val whileFunVars = modifiedVars.map(id => FreshIdentifier(id.name).setType(id.getType)) val whileFunVars = modifiedVars.map(id => FreshIdentifier(id.name).setType(id.getType))
val modifiedVars2WhileFunVars = modifiedVars.zip(whileFunVars).toMap
val whileFunVarDecls = whileFunVars.map(id => VarDecl(id, id.getType)) val whileFunVarDecls = whileFunVars.map(id => VarDecl(id, id.getType))
val whileFunReturnType = if(whileFunVars.size == 1) whileFunVars.head.getType else TupleType(whileFunVars.map(_.getType)) val whileFunReturnType = if(whileFunVars.size == 1) whileFunVars.head.getType else TupleType(whileFunVars.map(_.getType))
val whileFunDef = new FunDef(FreshIdentifier("while"), whileFunReturnType, whileFunVarDecls).setPosInfo(wh) val whileFunDef = new FunDef(FreshIdentifier("while"), whileFunReturnType, whileFunVarDecls).setPosInfo(wh)
val modifiedVars2WhileFunVars: Map[Expr, Expr] = modifiedVars.zip(whileFunVars).map(p => (p._1.toVariable, p._2.toVariable)).toMap val whileFunCond = condRes
val whileFunCond = replace(modifiedVars2WhileFunVars, cond) val whileFunRecursiveCall = replaceNames(condFun,
val whileFunRecursiveCall = replace(modifiedVars2WhileFunVars, bodyScope(FunctionInvocation(whileFunDef, modifiedVars.map(id => bodyFun(id).toVariable)))) bodyScope(FunctionInvocation(whileFunDef, modifiedVars.map(id => condBodyFun(id).toVariable))))
val whileFunBaseCase = (if(whileFunVars.size == 1) whileFunVars.head.toVariable else Tuple(whileFunVars.map(_.toVariable))).setType(whileFunReturnType) val whileFunBaseCase =
val whileFunBody = IfExpr(whileFunCond, whileFunRecursiveCall, whileFunBaseCase).setType(whileFunReturnType) (if(whileFunVars.size == 1)
condFun.get(modifiedVars.head).getOrElse(whileFunVars.head).toVariable
else
Tuple(modifiedVars.map(id => condFun.get(id).getOrElse(modifiedVars2WhileFunVars(id)).toVariable))
).setType(whileFunReturnType)
val whileFunBody = replaceNames(modifiedVars2WhileFunVars,
condScope(IfExpr(whileFunCond, whileFunRecursiveCall, whileFunBaseCase).setType(whileFunReturnType)))
whileFunDef.body = Some(whileFunBody) whileFunDef.body = Some(whileFunBody)
val resVar = ResultVariable().setType(whileFunReturnType) val resVar = ResultVariable().setType(whileFunReturnType)
...@@ -91,10 +107,14 @@ object ImperativeCodeElimination extends Pass { ...@@ -91,10 +107,14 @@ object ImperativeCodeElimination extends Pass {
Map(whileFunVars.head.toVariable -> resVar) Map(whileFunVars.head.toVariable -> resVar)
else else
whileFunVars.zipWithIndex.map{ case (v, i) => (v.toVariable, TupleSelect(resVar, i+1).setType(v.getType)) }.toMap whileFunVars.zipWithIndex.map{ case (v, i) => (v.toVariable, TupleSelect(resVar, i+1).setType(v.getType)) }.toMap
val modifiedVars2ResultVars: Map[Expr, Expr] = modifiedVars.map(v => (v.toVariable, whileFunVars2ResultVars(modifiedVars2WhileFunVars(v.toVariable)))).toMap val modifiedVars2ResultVars: Map[Expr, Expr] = modifiedVars.map(id =>
(id.toVariable, whileFunVars2ResultVars(modifiedVars2WhileFunVars(id).toVariable))).toMap
val trivialPostcondition: Option[Expr] = Some(Not(replace(whileFunVars2ResultVars, whileFunCond)))
val invariantPrecondition: Option[Expr] = wh.invariant.map(expr => replace(modifiedVars2WhileFunVars, expr)) //the mapping of the trivial post condition variables depends on whether the condition has had some side effect
val trivialPostcondition: Option[Expr] = Some(Not(replace(
modifiedVars.map(id => (condFun.get(id).getOrElse(id).toVariable, modifiedVars2ResultVars(id.toVariable))).toMap,
whileFunCond)))
val invariantPrecondition: Option[Expr] = wh.invariant.map(expr => replaceNames(modifiedVars2WhileFunVars, expr))
val invariantPostcondition: Option[Expr] = wh.invariant.map(expr => replace(modifiedVars2ResultVars, expr)) val invariantPostcondition: Option[Expr] = wh.invariant.map(expr => replace(modifiedVars2ResultVars, expr))
whileFunDef.precondition = invariantPrecondition whileFunDef.precondition = invariantPrecondition
whileFunDef.postcondition = trivialPostcondition.map(expr => whileFunDef.postcondition = trivialPostcondition.map(expr =>
...@@ -122,37 +142,42 @@ object ImperativeCodeElimination extends Pass { ...@@ -122,37 +142,42 @@ object ImperativeCodeElimination extends Pass {
(UnitLiteral, finalScope, modifiedVars.zip(finalVars).toMap) (UnitLiteral, finalScope, modifiedVars.zip(finalVars).toMap)
} }
} }
case Block(head::exprs, expr) => {
val (_, headScope, headFun) = toFunction(head) case Block(Seq(), expr) => toFunction(expr)
val (scope, fun) = exprs.foldLeft((headScope, headFun))((acc, e) => { case Block(exprs, expr) => {
val (scope, fun) = exprs.foldRight((body: Expr) => body, Map[Identifier, Identifier]())((e, acc) => {
val (accScope, accFun) = acc val (accScope, accFun) = acc
val (_, rScope, rFun) = toFunction(e) val (_, rScope, rFun) = toFunction(e)
val scope = ((body: Expr) => val scope = (body: Expr) => rScope(replaceNames(rFun, accScope(body)))
accScope(replace(accFun.map{case (i1, i2) => (i1.toVariable, i2.toVariable)}, rScope(body)))) (scope, rFun ++ accFun)
(scope, accFun ++ rFun)
}) })
val (lastRes, lastScope, lastFun) = toFunction(expr) val (lastRes, lastScope, lastFun) = toFunction(expr)
(lastRes, val finalFun = fun ++ lastFun
(body: Expr) => scope(replace(fun.map{ case (i1, i2) => (i1.toVariable, i2.toVariable) }, lastScope(body))), (replaceNames(finalFun, lastRes),
fun ++ lastFun) (body: Expr) => scope(replaceNames(fun, lastScope(body))),
finalFun)
} }
//pure expression (that could still contain side effects as a subexpression) (evaluation order is from left to right) //pure expression (that could still contain side effects as a subexpression) (evaluation order is from left to right)
case Let(id, e, b) => { case Let(id, e, b) => {
val (bindRes, bindScope, bindFun) = toFunction(e)
val (bodyRes, bodyScope, bodyFun) = toFunction(b) val (bodyRes, bodyScope, bodyFun) = toFunction(b)
(bodyRes, (b: Expr) => Let(id, e, bodyScope(b)), bodyFun) (bodyRes,
(b2: Expr) => bindScope(Let(id, replaceNames(bindFun, bindRes), bodyScope(b2))),
bindFun ++ bodyFun)
} }
case LetDef(fd, b) => { case LetDef(fd, b) => {
//Recall that here the nested function should not access mutable variables from an outside scope
val (bodyRes, bodyScope, bodyFun) = toFunction(b) val (bodyRes, bodyScope, bodyFun) = toFunction(b)
(bodyRes, (b: Expr) => LetDef(fd, bodyScope(b)), bodyFun) (bodyRes, (b2: Expr) => LetDef(fd, bodyScope(b2)), bodyFun)
} }
case n @ NAryOperator(Seq(), recons) => (n, (body: Expr) => body, Map())
case n @ NAryOperator(args, recons) => { case n @ NAryOperator(args, recons) => {
val (recArgs, scope, fun) = args.foldRight((Seq[Expr](), (body: Expr) => body, Map[Identifier, Identifier]()))((arg, acc) => { val (recArgs, scope, fun) = args.foldRight((Seq[Expr](), (body: Expr) => body, Map[Identifier, Identifier]()))((arg, acc) => {
val (accArgs, scope, fun) = acc val (accArgs, accScope, accFun) = acc
val (argVal, argScope, argFun) = toFunction(arg) val (argVal, argScope, argFun) = toFunction(arg)
val argInScope = replaceNames(argFun, argVal) val newScope = (body: Expr) => argScope(replaceNames(argFun, accScope(body)))
val newScope = (body: Expr) => argScope(replaceNames(argFun, scope(body))) (argVal +: accArgs, newScope, argFun ++ accFun)
(argInScope +: accArgs, newScope, argFun ++ fun)
}) })
(recons(recArgs), scope, fun) (recons(recArgs), scope, fun)
} }
...@@ -164,11 +189,11 @@ object ImperativeCodeElimination extends Pass { ...@@ -164,11 +189,11 @@ object ImperativeCodeElimination extends Pass {
val lhs = argScope1(replaceNames(argFun1, rhs)) val lhs = argScope1(replaceNames(argFun1, rhs))
lhs lhs
} }
(recons(replaceNames(argFun1, argVal1), replaceNames(argFun2, argVal2)), scope, argFun1 ++ argFun2) (recons(argVal1, argVal2), scope, argFun1 ++ argFun2)
} }
case u @ UnaryOperator(a, recons) => { case u @ UnaryOperator(a, recons) => {
val (argVal, argScope, argFun) = toFunction(a) val (argVal, argScope, argFun) = toFunction(a)
(recons(replaceNames(argFun, argVal)), argScope, argFun) (recons(argVal), argScope, argFun)
} }
case (t: Terminal) => (t, (body: Expr) => body, Map()) case (t: Terminal) => (t, (body: Expr) => body, Map())
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment