From 463524185e3646f4d3e2f145c7b9377fcb493762 Mon Sep 17 00:00:00 2001 From: Etienne Kneuss <ekneuss@gmail.com> Date: Tue, 30 Jun 2015 16:16:16 +0200 Subject: [PATCH] Make sure we don't inline recursive functions, test inlinings --- src/main/scala/leon/utils/InliningPhase.scala | 16 +++-- src/test/scala/leon/test/LeonTests.scala | 2 + .../leon/test/purescala/InliningSuite.scala | 63 +++++++++++++++++++ 3 files changed, 77 insertions(+), 4 deletions(-) create mode 100644 src/test/scala/leon/test/purescala/InliningSuite.scala diff --git a/src/main/scala/leon/utils/InliningPhase.scala b/src/main/scala/leon/utils/InliningPhase.scala index e11fec3f6..62e62a9e2 100644 --- a/src/main/scala/leon/utils/InliningPhase.scala +++ b/src/main/scala/leon/utils/InliningPhase.scala @@ -11,12 +11,20 @@ import purescala.ExprOps._ import purescala.DefOps._ object InliningPhase extends TransformationPhase { - + val name = "Inline @inline functions" val description = "Inline functions marked as @inline and remove their definitions" - + def apply(ctx: LeonContext, p: Program): Program = { + // Detect inlined functions that are recursive + val doNotInline = (for (fd <- p.definedFunctions.filter(_.flags(IsInlined)) if p.callGraph.isRecursive(fd)) yield { + ctx.reporter.warning("Refusing to inline recursive function '"+fd.id.asString(ctx)+"'!") + fd + }).toSet + + def doInline(fd: FunDef) = fd.flags(IsInlined) && !doNotInline(fd) + def simplifyImplicitClass(e: Expr) = e match { case CaseClassSelector(cct, cc: CaseClass, id) => Some(CaseClassSelector(cct, cc, id)) @@ -31,7 +39,7 @@ object InliningPhase extends TransformationPhase { for (fd <- p.definedFunctions) { fd.fullBody = simplify(preMap { - case FunctionInvocation(TypedFunDef(fd, tps), args) if fd.flags(IsInlined) => + case FunctionInvocation(TypedFunDef(fd, tps), args) if doInline(fd) => val newBody = replaceFromIDs(fd.params.map(_.id).zip(args).toMap, fd.fullBody) Some(instantiateType(newBody, (fd.tparams zip tps).toMap, Map())) case _ => @@ -39,7 +47,7 @@ object InliningPhase extends TransformationPhase { }(fd.fullBody)) } - filterFunDefs(p, fd => !fd.flags(IsInlined)) + filterFunDefs(p, fd => !doInline(fd)) } } diff --git a/src/test/scala/leon/test/LeonTests.scala b/src/test/scala/leon/test/LeonTests.scala index 6e943463e..2b64a137e 100644 --- a/src/test/scala/leon/test/LeonTests.scala +++ b/src/test/scala/leon/test/LeonTests.scala @@ -32,6 +32,8 @@ class LeonFunTests extends Suites( new SynthesisSuite, new SynthesisRegressionSuite, + new InliningSuite, + new LibraryVerificationSuite, new PureScalaVerificationSuite, new XLangVerificationSuite diff --git a/src/test/scala/leon/test/purescala/InliningSuite.scala b/src/test/scala/leon/test/purescala/InliningSuite.scala new file mode 100644 index 000000000..36dded6ff --- /dev/null +++ b/src/test/scala/leon/test/purescala/InliningSuite.scala @@ -0,0 +1,63 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon.test.purescala + +import leon._ +import purescala.Definitions._ +import purescala.DefOps._ +import purescala.Expressions._ +import frontends.scalac._ +import utils._ +import leon.test.LeonTestSuite + +class InliningSuite extends LeonTestSuite { + private def parseProgram(str: String): (Program, LeonContext) = { + val context = createLeonContext() + + val pipeline = + TemporaryInputPhase andThen + ExtractionPhase andThen + PreprocessingPhase + + val program = pipeline.run(context)((str, Nil)) + + (program, context) + } + + test("Simple Inlining") { + val (pgm, ctx) = parseProgram( + """| + |import leon.lang._ + |import leon.annotation._ + | + |object InlineGood { + | + | @inline + | def foo(a: BigInt) = true + | + | def bar(a: BigInt) = foo(a) + | + |} """.stripMargin) + + val bar = pgm.lookup("InlineGood.bar").collect { case fd: FunDef => fd }.get + + assert(bar.fullBody == BooleanLiteral(true), "Function not inlined?") + } + + test("Recursive Inlining") { + val (pgm, ctx) = parseProgram( + """ |import leon.lang._ + |import leon.annotation._ + | + |object InlineBad { + | + | @inline + | def foo(a: BigInt): BigInt = if (a > 42) foo(a-1) else 0 + | + | def bar(a: BigInt) = foo(a) + | + |}""".stripMargin) + + assert(ctx.reporter.warningCount > 0, "Warning received for the invalid inline") + } +} -- GitLab