Skip to content
Snippets Groups Projects
Commit e1ea834a authored by Régis Blanc's avatar Régis Blanc
Browse files

simplify the function closure pass

parent 43a89884
No related branches found
No related tags found
No related merge requests found
......@@ -9,16 +9,14 @@ object FunctionClosure extends Pass {
val description = "Closing function with its scoping variables"
private var enclosingPreconditions: List[Expr] = Nil
private var pathConstraints: List[Expr] = Nil
private var enclosingLets: List[(Identifier, Expr)] = Nil
private var newFunDefs: Map[FunDef, FunDef] = Map()
def apply(program: Program): Program = {
newFunDefs = Map()
val funDefs = program.definedFunctions
funDefs.foreach(fd => {
enclosingPreconditions = fd.precondition.toList
pathConstraints = fd.precondition.toList
fd.body = fd.body.map(b => functionClosure(b, fd.args.map(_.id).toSet))
fd.postcondition = fd.postcondition.map(b => functionClosure(b, fd.args.map(_.id).toSet))
......@@ -28,58 +26,66 @@ object FunctionClosure extends Pass {
private def functionClosure(expr: Expr, bindedVars: Set[Identifier]): Expr = expr match {
case l @ LetDef(fd, rest) => {
val id = fd.id
val rt = fd.returnType
val varDecl = fd.args
val precondition = fd.precondition
val postcondition = fd.postcondition
val bodyVars: Set[Identifier] = variablesOf(fd.body.getOrElse(BooleanLiteral(true))) ++
variablesOf(precondition.getOrElse(BooleanLiteral(true))) ++
variablesOf(postcondition.getOrElse(BooleanLiteral(true)))
val capturedVars = bodyVars.intersect(bindedVars)// this should be the variable used that are in the scope
val (constraints, allCapturedVars) = filterConstraints(capturedVars) //all relevant path constraints
val capturedVarsWithConstraints = allCapturedVars.toSeq
val freshVars: Map[Identifier, Identifier] = capturedVarsWithConstraints.map(v => (v, FreshIdentifier(v.name).setType(v.getType))).toMap
val freshVarsExpr: Map[Expr, Expr] = freshVars.map(p => (p._1.toVariable, p._2.toVariable))
val extraVarDecls = freshVars.map{ case (_, v2) => VarDecl(v2, v2.getType) }
val newVarDecls = varDecl ++ extraVarDecls
val newFunId = FreshIdentifier(id.name)
val newFunDef = new FunDef(newFunId, rt, newVarDecls).setPosInfo(fd)
//val bodyVars: Set[Identifier] = variablesOf(fd.body.getOrElse(BooleanLiteral(true))) ++
// variablesOf(precondition.getOrElse(BooleanLiteral(true))) ++
// variablesOf(postcondition.getOrElse(BooleanLiteral(true)))
//val capturedVars = bodyVars.intersect(bindedVars)// this should be the variable used that are in the scope
//val capturedLets = enclosingLets.filter(let => capturedVars.contains(let._1))
//val (constraints, allCapturedVars) = filterConstraints(capturedVars) //all relevant path constraints
//val capturedVarsWithConstraints = allCapturedVars.toSeq
/* let's just take everything for now */
val capturedVars = bindedVars.toSeq
val capturedLets = enclosingLets
val capturedConstraints = pathConstraints
val freshIds: Map[Identifier, Identifier] = capturedVars.map(v => (v, FreshIdentifier(v.name).setType(v.getType))).toMap
val freshVars: Map[Expr, Expr] = freshIds.map(p => (p._1.toVariable, p._2.toVariable))
val extraVarDeclIds = capturedVars.toSet.diff(capturedLets.map(p => p._1).toSet).toSeq
val extraVarDecls = extraVarDeclIds.map(id => VarDecl(freshIds(id), id.getType))
val newVarDecls = fd.args ++ extraVarDecls
val newFunId = FreshIdentifier(fd.id.name)
val newFunDef = new FunDef(newFunId, fd.returnType, newVarDecls).setPosInfo(fd)
newFunDef.fromLoop = fd.fromLoop
newFunDef.parent = fd.parent
newFunDef.addAnnotation(fd.annotations.toSeq:_*)
val freshPrecondition = precondition.map(expr => replace(freshVarsExpr, expr))
val freshPostcondition = postcondition.map(expr => replace(freshVarsExpr, expr))
val freshBody = fd.body.map(b => replace(freshVarsExpr, b))
val freshConstraints = constraints.map(expr => replace(freshVarsExpr, expr))
val freshPrecondition = fd.precondition.map(expr => replace(freshVars, expr))
val freshPostcondition = fd.postcondition.map(expr => replace(freshVars, expr))
val freshBody = fd.body.map(expr => replace(freshVars, expr))
val freshConstraints = capturedConstraints.map(expr => replace(freshVars, expr))
val freshLets = capturedLets.map{ case (i, v) => (freshIds(i), replace(freshVars, v)) }
def substFunInvocInDef(expr: Expr): Option[Expr] = expr match {
case fi@FunctionInvocation(fd, args) if fd.id == id => Some(FunctionInvocation(newFunDef, args ++ extraVarDecls.map(_.id.toVariable)).setPosInfo(fi))
case fi@FunctionInvocation(fd2, args) if fd2.id == fd.id => Some(FunctionInvocation(newFunDef, args ++ extraVarDeclIds.map(_.toVariable)).setPosInfo(fi))
case _ => None
}
val oldPathConstraints = pathConstraints
pathConstraints = (precondition.getOrElse(BooleanLiteral(true)) :: pathConstraints).map(e => replace(freshVarsExpr, e))
val recPrecondition = freshConstraints match { //Actually, we do not allow nested fundef in precondition
val oldEnclosingLets = enclosingLets
pathConstraints = (fd.precondition.getOrElse(BooleanLiteral(true)) :: pathConstraints).map(e => replace(freshVars, e))
enclosingLets = enclosingLets.map{ case (i, v) => (freshIds(i), replace(freshVars, v)) }
val recPrecondition = freshConstraints match {
case List() => freshPrecondition
case precs => Some(And(freshPrecondition.getOrElse(BooleanLiteral(true)) +: precs))
}
val recBody = freshBody.map(b =>
functionClosure(b, bindedVars ++ newVarDecls.map(_.id))
).map(b => searchAndReplaceDFS(substFunInvocInDef)(b))
val finalBody = recBody.map(b => freshLets.foldLeft(b){ case (bacc, (i, v)) => Let(i, v, bacc) })
pathConstraints = oldPathConstraints
enclosingLets = oldEnclosingLets
newFunDef.precondition = recPrecondition
newFunDef.body = recBody
newFunDef.body = finalBody
newFunDef.postcondition = freshPostcondition
def substFunInvocInRest(expr: Expr): Option[Expr] = expr match {
case fi@FunctionInvocation(fd, args) if fd.id == id => Some(FunctionInvocation(newFunDef, args ++ capturedVarsWithConstraints.map(_.toVariable)).setPosInfo(fi))
case fi@FunctionInvocation(fd2, args) if fd2.id == fd.id => Some(FunctionInvocation(newFunDef, args ++ extraVarDeclIds.map(_.toVariable)).setPosInfo(fi))
case _ => None
}
val recRest = searchAndReplaceDFS(substFunInvocInRest)(functionClosure(rest, bindedVars))
......@@ -87,13 +93,22 @@ object FunctionClosure extends Pass {
}
case l @ Let(i,e,b) => {
val re = functionClosure(e, bindedVars)
pathConstraints ::= Equals(Variable(i), re)
enclosingLets ::= (i, re)
val rb = functionClosure(b, bindedVars + i)
pathConstraints = pathConstraints.tail
enclosingLets = enclosingLets.tail
Let(i, re, rb).setType(l.getType)
}
case i @ IfExpr(cond,then,elze) => {
val rCond = functionClosure(cond, bindedVars)
pathConstraints ::= rCond
val rThen = functionClosure(then, bindedVars)
pathConstraints = pathConstraints.tail
pathConstraints ::= Not(rCond)
val rElze = functionClosure(elze, bindedVars)
pathConstraints = pathConstraints.tail
IfExpr(rCond, rThen, rElze).setType(i.getType)
}
case n @ NAryOperator(args, recons) => {
var change = false
val rargs = args.map(a => functionClosure(a, bindedVars))
recons(rargs).setType(n.getType)
}
......@@ -106,16 +121,6 @@ object FunctionClosure extends Pass {
val r = functionClosure(t, bindedVars)
recons(r).setType(u.getType)
}
case i @ IfExpr(cond,then,elze) => {
val rCond = functionClosure(cond, bindedVars)
pathConstraints ::= rCond
val rThen = functionClosure(then, bindedVars)
pathConstraints = pathConstraints.tail
pathConstraints ::= Not(rCond)
val rElze = functionClosure(elze, bindedVars)
pathConstraints = pathConstraints.tail
IfExpr(rCond, rThen, rElze).setType(i.getType)
}
case m @ MatchExpr(scrut,cses) => { //TODO: will not work if there are actual nested function in cases
//val rScrut = functionClosure(scrut, bindedVars)
m
......
......@@ -2,44 +2,42 @@ import leon.Utils._
/* The calculus of Computation textbook */
object Bubble {
object BubbleSort {
def sort(a: Map[Int, Int], size: Int): Map[Int, Int] = ({
require(size < 5 && isArray(a, size))
var i = size - 1
def sort(a: Array[Int]): Array[Int] = ({
require(a.length >= 5)
var i = a.length - 1
var j = 0
var sortedArray = a
val sa = a
(while(i > 0) {
j = 0
(while(j < i) {
if(sortedArray(j) > sortedArray(j+1)) {
val tmp = sortedArray(j)
sortedArray = sortedArray.updated(j, sortedArray(j+1))
sortedArray = sortedArray.updated(j+1, tmp)
}
if(sa(j) > sa(j+1)) {
val tmp = sa(j)
sa(j) = sa(j+1)
sa(j+1) = tmp
} else 0
j = j + 1
}) invariant(
j >= 0 &&
j <= i &&
i < size &&
isArray(sortedArray, size) &&
partitioned(sortedArray, size, 0, i, i+1, size-1) &&
sorted(sortedArray, size, i, size-1) &&
partitioned(sortedArray, size, 0, j-1, j, j)
i < sa.length &&
partitioned(sa, 0, i, i+1, sa.length-1) &&
sorted(sa, i, sa.length-1) &&
partitioned(sa, 0, j-1, j, j)
)
i = i - 1
}) invariant(
i >= 0 &&
i < size &&
isArray(sortedArray, size) &&
partitioned(sortedArray, size, 0, i, i+1, size-1) &&
sorted(sortedArray, size, i, size-1)
i < sa.length &&
partitioned(sa, 0, i, i+1, sa.length-1) &&
sorted(sa, i, sa.length-1)
)
sortedArray
}) ensuring(res => sorted(res, size, 0, size-1))
sa
}) ensuring(res => sorted(res, 0, a.length-1))
def sorted(a: Map[Int, Int], size: Int, l: Int, u: Int): Boolean = {
require(isArray(a, size) && size < 5 && l >= 0 && u < size && l <= u)
def sorted(a: Array[Int], l: Int, u: Int): Boolean = {
require(a.length >= 0 && l >= 0 && u < a.length && l <= u)
var k = l
var isSorted = true
(while(k < u) {
......@@ -49,57 +47,9 @@ object Bubble {
}) invariant(k <= u && k >= l)
isSorted
}
/*
// --------------------- sorted --------------------
def sorted(a: Map[Int,Int], size: Int, l: Int, u: Int) : Boolean = {
require(isArray(a, size) && size < 5 && l >= 0 && l <= u && u < size)
val t = sortedWhile(true, l, l, u, a, size)
t._1
}
def sortedWhile(isSorted: Boolean, k: Int, l: Int, u: Int, a: Map[Int,Int], size: Int) : (Boolean, Int) = {
require(isArray(a, size) && size < 5 && l >= 0 && l <= u && u < size && k >= l && k <= u)
if(k < u) {
sortedWhile(if(a(k) > a(k + 1)) false else isSorted, k + 1, l, u, a, size)
} else (isSorted, k)
}
*/
/*
// ------------- partitioned ------------------
def partitioned(a: Map[Int,Int], size: Int, l1: Int, u1: Int, l2: Int, u2: Int) : Boolean = {
require(isArray(a, size) && size < 5 && l1 >= 0 && u1 < l2 && u2 < size)
if(l2 > u2 || l1 > u1)
true
else {
val t = partitionedWhile(l2, true, l1, l1, size, u2, l2, u1, a)
t._2
}
}
def partitionedWhile(j: Int, isPartitionned: Boolean, i: Int, l1: Int, size: Int, u2: Int, l2: Int, u1: Int, a: Map[Int,Int]) : (Int, Boolean, Int) = {
require(isArray(a, size) && size < 5 && l1 >= 0 && l1 <= u1 && u1 < l2 && l2 <= u2 && u2 < size && i >= l1)
if(i <= u1) {
val t = partitionedNestedWhile(isPartitionned, l2, i, l1, u1, size, u2, a, l2)
partitionedWhile(t._2, t._1, i + 1, l1, size, u2, l2, u1, a)
} else (j, isPartitionned, i)
}
def partitionedNestedWhile(isPartitionned: Boolean, j: Int, i: Int, l1: Int, u1: Int, size: Int, u2: Int, a: Map[Int,Int], l2: Int): (Boolean, Int) = {
require(isArray(a, size) && size < 5 && l1 >= 0 && l1 <= u1 && u1 < l2 && l2 <= u2 && u2 < size && j >= l2 && i >= l1 && i <= u1)
if (j <= u2) {
partitionedNestedWhile(
(if (a(i) > a(j))
false
else
isPartitionned),
j + 1, i, l1, u1, size, u2, a, l2)
} else (isPartitionned, j)
}
*/
def partitioned(a: Map[Int, Int], size: Int, l1: Int, u1: Int, l2: Int, u2: Int): Boolean = {
require(l1 >= 0 && u1 < l2 && u2 < size && isArray(a, size) && size < 5)
def partitioned(a: Array[Int], l1: Int, u1: Int, l2: Int, u2: Int): Boolean = {
require(a.length >= 0 && l1 >= 0 && u1 < l2 && u2 < a.length)
if(l2 > u2 || l1 > u1)
true
else {
......@@ -119,14 +69,4 @@ object Bubble {
}
}
def isArray(a: Map[Int, Int], size: Int): Boolean = {
def rec(i: Int): Boolean = if(i >= size) true else {
if(a.isDefinedAt(i)) rec(i+1) else false
}
if(size <= 0)
false
else
rec(0)
}
}
object Nested2 {
def foo(a: Int): Int = {
require(a >= 0)
val b = a + 2
def rec1(c: Int): Int = {
require(c >= 0)
b + c
}
rec1(2)
} ensuring(_ > 0)
}
object Nested3 {
def foo(a: Int): Int = {
require(a >= 0 && a <= 50)
val b = a + 2
val c = a + b
def rec1(d: Int): Int = {
require(d >= 0 && d <= 50)
val e = d + b + c
e
}
rec1(2)
} ensuring(_ > 0)
}
object Nested4 {
def foo(a: Int, a2: Int): Int = {
require(a >= 0 && a <= 50)
val b = a + 2
val c = a + b
if(a2 > a) {
def rec1(d: Int): Int = {
require(d >= 0 && d <= 50)
val e = d + b + c + a2
e
} ensuring(_ > 0)
rec1(2)
} else {
5
}
} ensuring(_ > 0)
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment