diff --git a/src/main/scala/leon/utils/InliningPhase.scala b/src/main/scala/leon/utils/InliningPhase.scala index e11fec3f653302f783507c3d3034c435743e005f..62e62a9e2d03744c9d84b6ae872075ec885e407a 100644 --- a/src/main/scala/leon/utils/InliningPhase.scala +++ b/src/main/scala/leon/utils/InliningPhase.scala @@ -11,12 +11,20 @@ import purescala.ExprOps._ import purescala.DefOps._ object InliningPhase extends TransformationPhase { - + val name = "Inline @inline functions" val description = "Inline functions marked as @inline and remove their definitions" - + def apply(ctx: LeonContext, p: Program): Program = { + // Detect inlined functions that are recursive + val doNotInline = (for (fd <- p.definedFunctions.filter(_.flags(IsInlined)) if p.callGraph.isRecursive(fd)) yield { + ctx.reporter.warning("Refusing to inline recursive function '"+fd.id.asString(ctx)+"'!") + fd + }).toSet + + def doInline(fd: FunDef) = fd.flags(IsInlined) && !doNotInline(fd) + def simplifyImplicitClass(e: Expr) = e match { case CaseClassSelector(cct, cc: CaseClass, id) => Some(CaseClassSelector(cct, cc, id)) @@ -31,7 +39,7 @@ object InliningPhase extends TransformationPhase { for (fd <- p.definedFunctions) { fd.fullBody = simplify(preMap { - case FunctionInvocation(TypedFunDef(fd, tps), args) if fd.flags(IsInlined) => + case FunctionInvocation(TypedFunDef(fd, tps), args) if doInline(fd) => val newBody = replaceFromIDs(fd.params.map(_.id).zip(args).toMap, fd.fullBody) Some(instantiateType(newBody, (fd.tparams zip tps).toMap, Map())) case _ => @@ -39,7 +47,7 @@ object InliningPhase extends TransformationPhase { }(fd.fullBody)) } - filterFunDefs(p, fd => !fd.flags(IsInlined)) + filterFunDefs(p, fd => !doInline(fd)) } } diff --git a/src/test/scala/leon/test/LeonTests.scala b/src/test/scala/leon/test/LeonTests.scala index 6e943463e02066c115ae6bf9980c6f66682fc4e3..2b64a137e38517731834cd44e0e5c994660e7358 100644 --- a/src/test/scala/leon/test/LeonTests.scala +++ b/src/test/scala/leon/test/LeonTests.scala @@ -32,6 +32,8 @@ class LeonFunTests extends Suites( new SynthesisSuite, new SynthesisRegressionSuite, + new InliningSuite, + new LibraryVerificationSuite, new PureScalaVerificationSuite, new XLangVerificationSuite diff --git a/src/test/scala/leon/test/purescala/InliningSuite.scala b/src/test/scala/leon/test/purescala/InliningSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..36dded6ff3c8a700df4e5fbc2afbe851d721d25b --- /dev/null +++ b/src/test/scala/leon/test/purescala/InliningSuite.scala @@ -0,0 +1,63 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon.test.purescala + +import leon._ +import purescala.Definitions._ +import purescala.DefOps._ +import purescala.Expressions._ +import frontends.scalac._ +import utils._ +import leon.test.LeonTestSuite + +class InliningSuite extends LeonTestSuite { + private def parseProgram(str: String): (Program, LeonContext) = { + val context = createLeonContext() + + val pipeline = + TemporaryInputPhase andThen + ExtractionPhase andThen + PreprocessingPhase + + val program = pipeline.run(context)((str, Nil)) + + (program, context) + } + + test("Simple Inlining") { + val (pgm, ctx) = parseProgram( + """| + |import leon.lang._ + |import leon.annotation._ + | + |object InlineGood { + | + | @inline + | def foo(a: BigInt) = true + | + | def bar(a: BigInt) = foo(a) + | + |} """.stripMargin) + + val bar = pgm.lookup("InlineGood.bar").collect { case fd: FunDef => fd }.get + + assert(bar.fullBody == BooleanLiteral(true), "Function not inlined?") + } + + test("Recursive Inlining") { + val (pgm, ctx) = parseProgram( + """ |import leon.lang._ + |import leon.annotation._ + | + |object InlineBad { + | + | @inline + | def foo(a: BigInt): BigInt = if (a > 42) foo(a-1) else 0 + | + | def bar(a: BigInt) = foo(a) + | + |}""".stripMargin) + + assert(ctx.reporter.warningCount > 0, "Warning received for the invalid inline") + } +}