From 1c831e128d836996c32ac4fd004ea94b4357bfc4 Mon Sep 17 00:00:00 2001 From: Etienne Kneuss <colder@php.net> Date: Sat, 13 Jul 2013 19:17:49 +0200 Subject: [PATCH] Promote LetDef to purescala --- src/main/scala/leon/Main.scala | 2 +- .../scala/leon/purescala/Extractors.scala | 54 +++++++++ .../FunctionClosure.scala | 19 ++- .../scala/leon/purescala/PrettyPrinter.scala | 10 +- .../scala/leon/purescala/ScalaPrinter.scala | 10 ++ src/main/scala/leon/purescala/TreeOps.scala | 35 +++++- src/main/scala/leon/purescala/Trees.scala | 8 ++ src/main/scala/leon/synthesis/CostModel.scala | 1 - src/main/scala/leon/synthesis/Solution.scala | 4 +- src/main/scala/leon/xlang/TreeOps.scala | 42 ------- src/main/scala/leon/xlang/Trees.scala | 112 +----------------- .../scala/leon/xlang/XlangAnalysisPhase.scala | 2 +- 12 files changed, 126 insertions(+), 173 deletions(-) rename src/main/scala/leon/{xlang => purescala}/FunctionClosure.scala (96%) diff --git a/src/main/scala/leon/Main.scala b/src/main/scala/leon/Main.scala index bd9ee762e..6bec43bb3 100644 --- a/src/main/scala/leon/Main.scala +++ b/src/main/scala/leon/Main.scala @@ -11,7 +11,7 @@ object Main { xlang.ArrayTransformation, xlang.EpsilonElimination, xlang.ImperativeCodeElimination, - xlang.FunctionClosure, + purescala.FunctionClosure, xlang.XlangAnalysisPhase, synthesis.SynthesisPhase, termination.TerminationPhase, diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala index 5c2571719..02e0e7c29 100644 --- a/src/main/scala/leon/purescala/Extractors.scala +++ b/src/main/scala/leon/purescala/Extractors.scala @@ -128,6 +128,60 @@ object Extractors { MatchExpr(es(0), newcases) })) + case LetDef(fd, body) => + fd.body match { + case Some(b) => + (fd.precondition, fd.postcondition) match { + case (None, None) => + Some((Seq(b, body), (as: Seq[Expr]) => { + fd.body = Some(as(0)) + LetDef(fd, as(1)) + })) + case (Some(pre), None) => + Some((Seq(b, body, pre), (as: Seq[Expr]) => { + fd.body = Some(as(0)) + fd.precondition = Some(as(2)) + LetDef(fd, as(1)) + })) + case (None, Some(post)) => + Some((Seq(b, body, post), (as: Seq[Expr]) => { + fd.body = Some(as(0)) + fd.postcondition = Some(as(2)) + LetDef(fd, as(1)) + })) + case (Some(pre), Some(post)) => + Some((Seq(b, body, pre, post), (as: Seq[Expr]) => { + fd.body = Some(as(0)) + fd.precondition = Some(as(2)) + fd.postcondition = Some(as(3)) + LetDef(fd, as(1)) + })) + } + + case None => //case no body, we still need to handle remaining cases + (fd.precondition, fd.postcondition) match { + case (None, None) => + Some((Seq(body), (as: Seq[Expr]) => { + LetDef(fd, as(0)) + })) + case (Some(pre), None) => + Some((Seq(body, pre), (as: Seq[Expr]) => { + fd.precondition = Some(as(1)) + LetDef(fd, as(0)) + })) + case (None, Some(post)) => + Some((Seq(body, post), (as: Seq[Expr]) => { + fd.postcondition = Some(as(1)) + LetDef(fd, as(0)) + })) + case (Some(pre), Some(post)) => + Some((Seq(body, pre, post), (as: Seq[Expr]) => { + fd.precondition = Some(as(1)) + fd.postcondition = Some(as(2)) + LetDef(fd, as(0)) + })) + } + } case (ex: NAryExtractable) => ex.extract case _ => None } diff --git a/src/main/scala/leon/xlang/FunctionClosure.scala b/src/main/scala/leon/purescala/FunctionClosure.scala similarity index 96% rename from src/main/scala/leon/xlang/FunctionClosure.scala rename to src/main/scala/leon/purescala/FunctionClosure.scala index 3b6a74394..e4d006e2e 100644 --- a/src/main/scala/leon/xlang/FunctionClosure.scala +++ b/src/main/scala/leon/purescala/FunctionClosure.scala @@ -1,17 +1,14 @@ /* Copyright 2009-2013 EPFL, Lausanne */ package leon -package xlang - -import leon.TransformationPhase -import leon.LeonContext -import leon.purescala.Common._ -import leon.purescala.Definitions._ -import leon.purescala.Trees._ -import leon.purescala.Extractors._ -import leon.purescala.TreeOps._ -import leon.purescala.TypeTrees._ -import leon.xlang.Trees._ +package purescala + +import Common._ +import Definitions._ +import Trees._ +import Extractors._ +import TreeOps._ +import TypeTrees._ object FunctionClosure extends TransformationPhase { diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index d13fd83e5..147c4635b 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -57,7 +57,6 @@ class PrettyPrinter(sb: StringBuffer = new StringBuffer) { case Variable(id) => sb.append(idToString(id)) case DeBruijnIndex(idx) => sb.append("_" + idx) case LetTuple(bs,d,e) => - //pp(e, pp(d, sb.append("(let (" + b + " := "), lvl).append(") in "), lvl).append(")") sb.append("(let (" + bs.map(idToString _).mkString(",") + " := "); pp(d, lvl) sb.append(") in\n") @@ -66,7 +65,6 @@ class PrettyPrinter(sb: StringBuffer = new StringBuffer) { sb.append(")") case Let(b,d,e) => - //pp(e, pp(d, sb.append("(let (" + b + " := "), lvl).append(") in "), lvl).append(")") sb.append("(let (" + idToString(b) + " := "); pp(d, lvl) sb.append(") in\n") @@ -74,6 +72,14 @@ class PrettyPrinter(sb: StringBuffer = new StringBuffer) { pp(e, lvl+1) sb.append(")") + case LetDef(fd,body) => + sb.append("\n") + pp(fd, lvl+1) + sb.append("\n") + sb.append("\n") + ind(lvl) + pp(body, lvl) + case And(exprs) => ppNary(exprs, "(", " \u2227 ", ")", lvl) // \land case Or(exprs) => ppNary(exprs, "(", " \u2228 ", ")", lvl) // \lor case Not(Equals(l, r)) => ppBinary(l, r, " \u2260 ", lvl) // \neq diff --git a/src/main/scala/leon/purescala/ScalaPrinter.scala b/src/main/scala/leon/purescala/ScalaPrinter.scala index c0dbb34cf..b1d82854d 100644 --- a/src/main/scala/leon/purescala/ScalaPrinter.scala +++ b/src/main/scala/leon/purescala/ScalaPrinter.scala @@ -55,11 +55,21 @@ class ScalaPrinter(sb: StringBuffer = new StringBuffer) extends PrettyPrinter(sb sb.append("}\n") ind(lvl) + case LetDef(fd, body) => + sb.append("{\n") + pp(fd, lvl+1) + sb.append("\n") + sb.append("\n") + ind(lvl) + pp(body, lvl) + sb.append("}\n") + case And(exprs) => ppNary(exprs, "(", " && ", ")", lvl) // \land case Or(exprs) => ppNary(exprs, "(", " || ", ")", lvl) // \lor case Not(Equals(l, r)) => ppBinary(l, r, " != ", lvl) // \neq case UMinus(expr) => ppUnary(expr, "-(", ")", lvl) case Equals(l,r) => ppBinary(l, r, " == ", lvl) + case IntLiteral(v) => sb.append(v) case BooleanLiteral(v) => sb.append(v) case StringLiteral(s) => sb.append("\"" + s + "\"") diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala index 77ded68f1..45a5ba008 100644 --- a/src/main/scala/leon/purescala/TreeOps.scala +++ b/src/main/scala/leon/purescala/TreeOps.scala @@ -11,7 +11,6 @@ object TreeOps { import Common._ import TypeTrees._ import Definitions._ - import xlang.Trees.LetDef import Trees._ import Extractors._ @@ -1441,6 +1440,27 @@ object TreeOps { val sb = rec(b, scope.register(i -> si)) Let(si, se, sb) + case LetDef(fd: FunDef, body: Expr) => + val newId = genId(fd.id, scope) + var newScope = scope.register(fd.id -> newId) + + val newArgs = for(VarDecl(id, tpe) <- fd.args) yield { + val newArg = genId(id, newScope) + newScope = newScope.register(id -> newArg) + VarDecl(newArg, tpe) + } + + val newFd = new FunDef(newId, fd.returnType, newArgs) + + newScope = newScope.registerFunDef(fd -> newFd) + + newFd.body = fd.body.map(b => rec(b, newScope)) + newFd.precondition = fd.precondition.map(pre => rec(pre, newScope)) + newFd.postcondition = fd.postcondition.map(post => rec(post, newScope)) + + + LetDef(newFd, rec(body, newScope)) + case LetTuple(is, e, b) => var newScope = scope val sis = for (i <- is) yield { @@ -1802,4 +1822,17 @@ object TreeOps { false } + def containsLetDef(expr: Expr): Boolean = { + def convert(t : Expr) : Boolean = t match { + case (l : LetDef) => true + case _ => false + } + def combine(c1 : Boolean, c2 : Boolean) : Boolean = c1 || c2 + def compute(t : Expr, c : Boolean) = t match { + case (l : LetDef) => true + case _ => c + } + treeCatamorphism(convert, combine, compute, expr) + } + } diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala index b0a7d3220..f4f18b1b7 100644 --- a/src/main/scala/leon/purescala/Trees.scala +++ b/src/main/scala/leon/purescala/Trees.scala @@ -51,6 +51,14 @@ object Trees { val fixedType = body.getType } + case class LetDef(fd: FunDef, body: Expr) extends Expr { + val et = body.getType + if(et != Untyped) + setType(et) + + } + + /* Control flow */ case class FunctionInvocation(funDef: FunDef, args: Seq[Expr]) extends Expr with FixedType with ScalacPositional { val fixedType = funDef.returnType diff --git a/src/main/scala/leon/synthesis/CostModel.scala b/src/main/scala/leon/synthesis/CostModel.scala index 128529c25..b4e42a0b1 100644 --- a/src/main/scala/leon/synthesis/CostModel.scala +++ b/src/main/scala/leon/synthesis/CostModel.scala @@ -5,7 +5,6 @@ package synthesis import purescala.Trees._ import purescala.TreeOps._ -import leon.xlang.Trees.LetDef import synthesis.search.Cost diff --git a/src/main/scala/leon/synthesis/Solution.scala b/src/main/scala/leon/synthesis/Solution.scala index 61a60a2d8..e6a5620cf 100644 --- a/src/main/scala/leon/synthesis/Solution.scala +++ b/src/main/scala/leon/synthesis/Solution.scala @@ -7,8 +7,6 @@ import purescala.Trees._ import purescala.TypeTrees.{TypeTree,TupleType} import purescala.Definitions._ import purescala.TreeOps._ -import xlang.Trees.LetDef -import xlang.TreeOps.{ScopeSimplifier => XLangScopeSimplifier} import solvers.z3.UninterpretedZ3Solver // Defines a synthesis solution of the form: @@ -45,7 +43,7 @@ class Solution(val pre: Expr, val defs: Set[FunDef], val term: Expr) { simplifyTautologies(uninterpretedZ3)(_), simplifyLets _, rewriteTuples _, - (new ScopeSimplifier with XLangScopeSimplifier).transform _ + (new ScopeSimplifier).transform _ ) simplifiers.foldLeft(toExpr){ (x, sim) => sim(x) } diff --git a/src/main/scala/leon/xlang/TreeOps.scala b/src/main/scala/leon/xlang/TreeOps.scala index 0e8c362aa..4c01aa639 100644 --- a/src/main/scala/leon/xlang/TreeOps.scala +++ b/src/main/scala/leon/xlang/TreeOps.scala @@ -76,46 +76,4 @@ object TreeOps { } searchAndReplaceDFS(applyToTree)(expr) } - - def containsLetDef(expr: Expr): Boolean = { - def convert(t : Expr) : Boolean = t match { - case (l : LetDef) => true - case _ => false - } - def combine(c1 : Boolean, c2 : Boolean) : Boolean = c1 || c2 - def compute(t : Expr, c : Boolean) = t match { - case (l : LetDef) => true - case _ => c - } - treeCatamorphism(convert, combine, compute, expr) - } - - trait ScopeSimplifier extends purescala.TreeOps.ScopeSimplifier { - override def rec(e: Expr, scope: Scope) = e match { - case LetDef(fd: FunDef, body: Expr) => - val newId = genId(fd.id, scope) - var newScope = scope.register(fd.id -> newId) - - val newArgs = for(VarDecl(id, tpe) <- fd.args) yield { - val newArg = genId(id, newScope) - newScope = newScope.register(id -> newArg) - VarDecl(newArg, tpe) - } - - val newFd = new FunDef(newId, fd.returnType, newArgs) - - newScope = newScope.registerFunDef(fd -> newFd) - - newFd.body = fd.body.map(b => rec(b, newScope)) - newFd.precondition = fd.precondition.map(pre => rec(pre, newScope)) - newFd.postcondition = fd.postcondition.map(post => rec(post, newScope)) - - - LetDef(newFd, rec(body, newScope)) - - case _ => - super.rec(e, scope) - } - } - } diff --git a/src/main/scala/leon/xlang/Trees.scala b/src/main/scala/leon/xlang/Trees.scala index 22cb9e639..866cfa5d9 100644 --- a/src/main/scala/leon/xlang/Trees.scala +++ b/src/main/scala/leon/xlang/Trees.scala @@ -152,117 +152,7 @@ object Trees { } } - case class LetDef(fd: FunDef, body: Expr) extends Expr with NAryExtractable with PrettyPrintable { - val et = body.getType - if(et != Untyped) - setType(et) - - def extract: Option[(Seq[Expr], (Seq[Expr])=>Expr)] = { - fd.body match { - case Some(b) => - (fd.precondition, fd.postcondition) match { - case (None, None) => - Some((Seq(b, body), (as: Seq[Expr]) => { - //val nfd = new FunDef(fd.id, fd.returnType, fd.args) - //nfd.body = Some(as(0)) - //LetDef(nfd, as(1)) - fd.body = Some(as(0)) - LetDef(fd, as(1)) - })) - case (Some(pre), None) => - Some((Seq(b, body, pre), (as: Seq[Expr]) => { - //val nfd = new FunDef(fd.id, fd.returnType, fd.args) - //nfd.body = Some(as(0)) - //nfd.precondition = Some(as(2)) - //LetDef(nfd, as(1)) - fd.body = Some(as(0)) - fd.precondition = Some(as(2)) - LetDef(fd, as(1)) - })) - case (None, Some(post)) => - Some((Seq(b, body, post), (as: Seq[Expr]) => { - //val nfd = new FunDef(fd.id, fd.returnType, fd.args) - //nfd.body = Some(as(0)) - //nfd.postcondition = Some(as(2)) - //LetDef(nfd, as(1)) - fd.body = Some(as(0)) - fd.postcondition = Some(as(2)) - LetDef(fd, as(1)) - })) - case (Some(pre), Some(post)) => - Some((Seq(b, body, pre, post), (as: Seq[Expr]) => { - //val nfd = new FunDef(fd.id, fd.returnType, fd.args) - //nfd.body = Some(as(0)) - //nfd.precondition = Some(as(2)) - //nfd.postcondition = Some(as(3)) - //LetDef(nfd, as(1)) - fd.body = Some(as(0)) - fd.precondition = Some(as(2)) - fd.postcondition = Some(as(3)) - LetDef(fd, as(1)) - })) - } - - case None => //case no body, we still need to handle remaining cases - (fd.precondition, fd.postcondition) match { - case (None, None) => - Some((Seq(body), (as: Seq[Expr]) => { - //val nfd = new FunDef(fd.id, fd.returnType, fd.args) - //LetDef(nfd, as(0)) - LetDef(fd, as(0)) - })) - case (Some(pre), None) => - Some((Seq(body, pre), (as: Seq[Expr]) => { - //val nfd = new FunDef(fd.id, fd.returnType, fd.args) - //nfd.precondition = Some(as(1)) - //LetDef(nfd, as(0)) - fd.precondition = Some(as(1)) - LetDef(fd, as(0)) - })) - case (None, Some(post)) => - Some((Seq(body, post), (as: Seq[Expr]) => { - //val nfd = new FunDef(fd.id, fd.returnType, fd.args) - //nfd.postcondition = Some(as(1)) - //LetDef(nfd, as(0)) - fd.postcondition = Some(as(1)) - LetDef(fd, as(0)) - })) - case (Some(pre), Some(post)) => - Some((Seq(body, pre, post), (as: Seq[Expr]) => { - //val nfd = new FunDef(fd.id, fd.returnType, fd.args) - //nfd.precondition = Some(as(1)) - //nfd.postcondition = Some(as(2)) - //LetDef(nfd, as(0)) - fd.precondition = Some(as(1)) - fd.postcondition = Some(as(2)) - LetDef(fd, as(0)) - })) - } - } - } - - def printWith(lvl: Int, printer: PrettyPrinter) { - printer match { - case _: ScalaPrinter => - printer.append("{\n") - printer.pp(fd, lvl+1) - printer.append("\n") - printer.append("\n") - printer.ind(lvl) - printer.pp(body, lvl) - printer.append("}\n") - case _ => - printer.append("\n") - printer.pp(fd, lvl+1) - printer.append("\n") - printer.append("\n") - printer.ind(lvl) - printer.pp(body, lvl) - } - } - } - - case class Waypoint(i: Int, expr: Expr) extends Expr with UnaryExtractable with PrettyPrintable{ + case class Waypoint(i: Int, expr: Expr) extends Expr with UnaryExtractable with PrettyPrintable { def extract: Option[(Expr, (Expr)=>Expr)] = { Some((expr, (e: Expr) => Waypoint(i, e))) } diff --git a/src/main/scala/leon/xlang/XlangAnalysisPhase.scala b/src/main/scala/leon/xlang/XlangAnalysisPhase.scala index 4fb3c98bc..3ec72ee09 100644 --- a/src/main/scala/leon/xlang/XlangAnalysisPhase.scala +++ b/src/main/scala/leon/xlang/XlangAnalysisPhase.scala @@ -16,7 +16,7 @@ object XlangAnalysisPhase extends LeonPhase[Program, VerificationReport] { val pgm1 = ArrayTransformation(ctx, pgm) val pgm2 = EpsilonElimination(ctx, pgm1) val (pgm3, wasLoop) = ImperativeCodeElimination.run(ctx)(pgm2) - val pgm4 = FunctionClosure.run(ctx)(pgm3) + val pgm4 = purescala.FunctionClosure.run(ctx)(pgm3) def functionWasLoop(fd: FunDef): Boolean = fd.orig match { case None => false //meaning, this was a top level function -- GitLab