diff --git a/src/main/scala/leon/purescala/CallGraph.scala b/src/main/scala/leon/purescala/CallGraph.scala index 63886dfd27c2eaf8af28bd993ab26de98cc9e873..a07b807ddaad26817322878c46d575f0a4ec1306 100644 --- a/src/main/scala/leon/purescala/CallGraph.scala +++ b/src/main/scala/leon/purescala/CallGraph.scala @@ -69,13 +69,20 @@ class CallGraph(p: Program) { transitiveClosure() } + private def collectCallsInPats(fd: FunDef)(p: Pattern): Set[(FunDef, FunDef)] = + (p match { + case u: UnapplyPattern => Set((fd, u.unapplyFun.fd)) + case _ => Set() + }) ++ p.subPatterns.flatMap(collectCallsInPats(fd)) + private def collectCalls(fd: FunDef)(e: Expr): Set[(FunDef, FunDef)] = e match { case f @ FunctionInvocation(f2, _) => Set((fd, f2.fd)) + case MatchExpr(_, cases) => cases.toSet.flatMap((mc: MatchCase) => collectCallsInPats(fd)(mc.pattern)) case _ => Set() } private def scanForCalls(fd: FunDef) { - for( (from, to) <- collect(collectCalls(fd)(_))(fd.fullBody) ) { + for( (from, to) <- collect(collectCalls(fd))(fd.fullBody) ) { _calls += (from -> to) _callees += (from -> (_callees.getOrElse(from, Set()) + to)) _callers += (to -> (_callers.getOrElse(to, Set()) + from)) diff --git a/src/test/scala/leon/integration/purescala/CallGraphSuite.scala b/src/test/scala/leon/integration/purescala/CallGraphSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..7c47baa4b1abf2ec0c6493503fd7c944983c8c50 --- /dev/null +++ b/src/test/scala/leon/integration/purescala/CallGraphSuite.scala @@ -0,0 +1,29 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon.integration.purescala + +import leon.test._ + +import leon._ +import leon.purescala.Definitions._ +import leon.utils._ + +class CallGraphSuite extends LeonTestSuiteWithProgram with helpers.ExpressionsDSL { + + val sources = List( + """object Matches { + | import leon.collection._ + | def aMatch(a: List[Int]) = a match { + | case _ :: _ => 0 + | } + |}""".stripMargin + ) + + test("CallGraph tracks dependency to unapply pattern") { implicit fix => + val fd1 = funDef("Matches.aMatch") + val fd2 = funDef("leon.collection.::.unapply") + + assert(implicitly[Program].callGraph.calls(fd1, fd2)) + } + +}