diff --git a/src/main/scala/leon/FunctionClosure.scala b/src/main/scala/leon/FunctionClosure.scala index 4d76cd8b5e714b9799cc90b9ca0b6802de116ceb..33170247429c169a5023fff833e2b96ff61c37bf 100644 --- a/src/main/scala/leon/FunctionClosure.scala +++ b/src/main/scala/leon/FunctionClosure.scala @@ -13,7 +13,6 @@ object FunctionClosure extends Pass { private var pathConstraints: List[Expr] = Nil private var newFunDefs: Map[FunDef, FunDef] = Map() - //private var def apply(program: Program): Program = { newFunDefs = Map() @@ -22,6 +21,7 @@ object FunctionClosure extends Pass { enclosingPreconditions = fd.precondition.toList pathConstraints = fd.precondition.toList fd.body = fd.body.map(b => functionClosure(b, fd.args.map(_.id).toSet)) + fd.postcondition = fd.postcondition.map(b => functionClosure(b, fd.args.map(_.id).toSet)) }) program } @@ -35,10 +35,9 @@ object FunctionClosure extends Pass { val precondition = fd.precondition val postcondition = fd.postcondition - val bodyVars: Set[Identifier] = (fd.body match { - case Some(body) => variablesOf(body) - case None => Set() - }) ++ variablesOf(precondition.getOrElse(BooleanLiteral(true))) ++ variablesOf(postcondition.getOrElse(BooleanLiteral(true))) + val bodyVars: Set[Identifier] = variablesOf(fd.body.getOrElse(BooleanLiteral(true))) ++ + variablesOf(precondition.getOrElse(BooleanLiteral(true))) ++ + variablesOf(postcondition.getOrElse(BooleanLiteral(true))) val capturedVars = bodyVars.intersect(bindedVars)// this should be the variable used that are in the scope val (constraints, allCapturedVars) = filterConstraints(capturedVars) //all relevant path constraints @@ -55,32 +54,38 @@ object FunctionClosure extends Pass { newFunDef.parent = fd.parent val freshPrecondition = precondition.map(expr => replace(freshVarsExpr, expr)) + val freshPostcondition = postcondition.map(expr => replace(freshVarsExpr, expr)) + val freshBody = fd.body.map(b => replace(freshVarsExpr, b)) val freshConstraints = constraints.map(expr => replace(freshVarsExpr, expr)) - newFunDef.precondition = freshConstraints match { - case List() => freshPrecondition - case precs => Some(And(freshPrecondition.getOrElse(BooleanLiteral(true)) +: precs)) - } - newFunDef.postcondition = postcondition.map(expr => replace(freshVarsExpr, expr)) def substFunInvocInDef(expr: Expr): Option[Expr] = expr match { case fi@FunctionInvocation(fd, args) if fd.id == id => Some(FunctionInvocation(newFunDef, args ++ extraVarDecls.map(_.id.toVariable)).setPosInfo(fi)) case _ => None } - val freshBody = fd.body.map(b => replace(freshVarsExpr, b)) val oldPathConstraints = pathConstraints pathConstraints = (precondition.getOrElse(BooleanLiteral(true)) :: pathConstraints).map(e => replace(freshVarsExpr, e)) - val recBody = freshBody.map(b => functionClosure(b, bindedVars ++ newVarDecls.map(_.id))) + val recPrecondition = freshConstraints match { //Actually, we do not allow nested fundef in precondition + case List() => freshPrecondition + case precs => Some(And(freshPrecondition.getOrElse(BooleanLiteral(true)) +: precs)) + } + val recBody = freshBody.map(b => + functionClosure(b, bindedVars ++ newVarDecls.map(_.id)) + ).map(b => searchAndReplaceDFS(substFunInvocInDef)(b)) + val recPostcondition = freshPostcondition.map(expr => + functionClosure(expr, bindedVars ++ newVarDecls.map(_.id)) + ).map(expr => searchAndReplaceDFS(substFunInvocInDef)(expr)) pathConstraints = oldPathConstraints - val recBody2 = recBody.map(b => searchAndReplaceDFS(substFunInvocInDef)(b)) - newFunDef.body = recBody2 + + newFunDef.precondition = recPrecondition + newFunDef.body = recBody + newFunDef.postcondition = recPostcondition def substFunInvocInRest(expr: Expr): Option[Expr] = expr match { case fi@FunctionInvocation(fd, args) if fd.id == id => Some(FunctionInvocation(newFunDef, args ++ capturedVarsWithConstraints.map(_.toVariable)).setPosInfo(fi)) case _ => None } - val recRest = functionClosure(rest, bindedVars) - val recRest2 = searchAndReplaceDFS(substFunInvocInRest)(recRest) - LetDef(newFunDef, recRest2).setType(l.getType) + val recRest = searchAndReplaceDFS(substFunInvocInRest)(functionClosure(rest, bindedVars)) + LetDef(newFunDef, recRest).setType(l.getType) } case l @ Let(i,e,b) => { val re = functionClosure(e, bindedVars) diff --git a/src/main/scala/leon/FunctionHoisting.scala b/src/main/scala/leon/FunctionHoisting.scala index 2e92ef23568783deb28971ea9d1a5c0bdbe2e34e..483d53d7aeed2a7dfaafb3296b2a0f5ef3a4b7a4 100644 --- a/src/main/scala/leon/FunctionHoisting.scala +++ b/src/main/scala/leon/FunctionHoisting.scala @@ -12,11 +12,25 @@ object FunctionHoisting extends Pass { def apply(program: Program): Program = { val funDefs = program.definedFunctions var topLevelFuns: Set[FunDef] = Set() - funDefs.foreach(fd => fd.body.map(body => { - val (newBody, additionalTopLevelFun) = hoist(body) - fd.body = Some(newBody) - topLevelFuns ++= additionalTopLevelFun - })) + funDefs.foreach(fd => { + val s2 = fd.body match { + case Some(body) => { + val (e2, s2) = hoist(body) + fd.body = Some(e2) + s2 + } + case None => Set() + } + val s4 = fd.postcondition match { + case Some(expr) => { + val (e4, s4) = hoist(expr) + fd.postcondition = Some(e4) + s4 + } + case None => Set() + } + topLevelFuns ++= (s2 ++ s4) + }) val Program(id, ObjectDef(objId, defs, invariants)) = program Program(id, ObjectDef(objId, defs ++ topLevelFuns, invariants)) } @@ -32,7 +46,15 @@ object FunctionHoisting extends Pass { } case None => Set() } - (e, (s ++ s2) + fd) + val s4 = fd.postcondition match { + case Some(expr) => { + val (e4, s4) = hoist(expr) + fd.postcondition = Some(e4) + s4 + } + case None => Set() + } + (e, (s ++ s2 ++ s4) + fd) } case l @ Let(i,e,b) => { val (re, s1) = hoist(e) diff --git a/src/main/scala/leon/PassManager.scala b/src/main/scala/leon/PassManager.scala index f2830df79bfa129ce93d58157110a2107453595a..5381b7bee6a389ac120477c8316fcf1e97ccdc62 100644 --- a/src/main/scala/leon/PassManager.scala +++ b/src/main/scala/leon/PassManager.scala @@ -6,9 +6,9 @@ class PassManager(passes: Seq[Pass]) { def run(program: Program): Program = { passes.foldLeft(program)((pgm, pass) => { - //println("Running Pass: " + pass.description) + println("Running Pass: " + pass.description) val newPgm = pass(pgm) - //println("Resulting program: " + newPgm) + println("Resulting program: " + newPgm) newPgm }) } diff --git a/src/main/scala/leon/plugin/CodeExtraction.scala b/src/main/scala/leon/plugin/CodeExtraction.scala index c6757f2cb87a6e2ee0d12403b788a3589ef8da71..ff6e84b38af47fe7db035c3be2fd7b95b96e9e81 100644 --- a/src/main/scala/leon/plugin/CodeExtraction.scala +++ b/src/main/scala/leon/plugin/CodeExtraction.scala @@ -285,6 +285,11 @@ trait CodeExtraction extends Extractors { case e: ImpureCodeEncounteredException => None } + reqCont.map(e => + if(containsLetDef(e)) { + unit.error(realBody.pos, "Function precondtion should not contain nested function definition") + throw ImpureCodeEncounteredException(realBody) + }) funDef.body = bodyAttempt funDef.precondition = reqCont funDef.postcondition = ensCont @@ -414,6 +419,11 @@ trait CodeExtraction extends Extractors { case e: ImpureCodeEncounteredException => None } + reqCont.map(e => + if(containsLetDef(e)) { + unit.error(realBody.pos, "Function precondtion should not contain nested function definition") + throw ImpureCodeEncounteredException(realBody) + }) funDef.body = bodyAttempt funDef.precondition = reqCont funDef.postcondition = ensCont diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala index 6155d5da99d76f66a97dc6425690dcc26b8f8f02..814a9ea33896df2e0723706980f860808586319f 100644 --- a/src/main/scala/leon/purescala/Trees.scala +++ b/src/main/scala/leon/purescala/Trees.scala @@ -753,7 +753,10 @@ object Trees { def rec(expr: Expr) : A = expr match { case l @ Let(_, e, b) => compute(l, combine(rec(e), rec(b))) case l @ LetVar(_, e, b) => compute(l, combine(rec(e), rec(b))) - case l @ LetDef(fd, b) => compute(l, combine(rec(fd.getBody), rec(b))) //TODO, still not sure about the semantic + case l @ LetDef(fd, b) => {//TODO, still not sure about the semantic + val exprs: Seq[Expr] = fd.precondition.toSeq ++ fd.body.toSeq ++ fd.postcondition.toSeq ++ Seq(b) + compute(l, exprs.map(rec(_)).reduceLeft(combine)) + } case n @ NAryOperator(args, _) => { if(args.size == 0) compute(n, convert(n)) @@ -812,6 +815,19 @@ object Trees { treeCatamorphism(convert, combine, compute, expr) } + def containsLetDef(expr: Expr): Boolean = { + def convert(t : Expr) : Boolean = t match { + case (l : LetDef) => true + case _ => false + } + def combine(c1 : Boolean, c2 : Boolean) : Boolean = c1 || c2 + def compute(t : Expr, c : Boolean) = t match { + case (l : LetDef) => true + case _ => c + } + treeCatamorphism(convert, combine, compute, expr) + } + def variablesOf(expr: Expr) : Set[Identifier] = { def convert(t: Expr) : Set[Identifier] = t match { case Variable(i) => Set(i)