diff --git a/src/main/scala/leon/CallGraph.scala b/src/main/scala/leon/CallGraph.scala index 212600925c4dee506ea24a20628df4e05a34f028..d37cf16827641642dcc110ab39ed39d943843202 100644 --- a/src/main/scala/leon/CallGraph.scala +++ b/src/main/scala/leon/CallGraph.scala @@ -32,9 +32,9 @@ class CallGraph(val program: Program) { val body = fd.body.get //val cleanBody = hoistIte(expandLets(matchToIfThenElse(body))) val cleanBody = expandLets(matchToIfThenElse(body)) - //println(cleanBody) + //println("Clean Body: " + cleanBody) val subgraph = collectWithPathCondition(cleanBody, FunctionStart(fd)) - //println(subgraph) + //println("Subgraph: " + subgraph) callGraph ++= subgraph }) @@ -46,37 +46,61 @@ class CallGraph(val program: Program) { private def collectWithPathCondition(expression: Expr, startingPoint: ProgramPoint): Map[ProgramPoint, Set[(ProgramPoint, TransitionLabel)]] = { var callGraph: Map[ProgramPoint, Set[(ProgramPoint, TransitionLabel)]] = Map() - def rec(expr: Expr, path: List[Expr], startingPoint: ProgramPoint): Unit = expr match { - case FunctionInvocation(fd, args) => { - val transitions: Set[(ProgramPoint, TransitionLabel)] = callGraph.get(startingPoint) match { - case None => Set() - case Some(s) => s - } - val newPoint = FunctionStart(fd) - val newTransition = TransitionLabel(And(path.toSeq), fd.args.zip(args).map{ case (VarDecl(id, _), arg) => (id.toVariable, arg) }.toMap) - callGraph += (startingPoint -> (transitions + ((newPoint, newTransition)))) - args.foreach(arg => rec(arg, path, startingPoint)) + def rec(expr: Expr, path: List[Expr], startingPoint: ProgramPoint): Unit = { + val transitions: Set[(ProgramPoint, TransitionLabel)] = callGraph.get(startingPoint) match { + case None => Set() + case Some(s) => s } - case way@Waypoint(i, e) => { - val transitions: Set[(ProgramPoint, TransitionLabel)] = callGraph.get(startingPoint) match { - case None => Set() - case Some(s) => s + + expr match { + case FunctionInvocation(fd, args) => { + val newPoint = FunctionStart(fd) + val newTransition = TransitionLabel(And(path.toSeq), fd.args.zip(args).map{ case (VarDecl(id, _), arg) => (id.toVariable, arg) }.toMap) + callGraph += (startingPoint -> (transitions + ((newPoint, newTransition)))) + args.foreach(arg => rec(arg, path, startingPoint)) } - val newPoint = ExpressionPoint(way) - val newTransition = TransitionLabel(And(path.toSeq), Map()) - callGraph += (startingPoint -> (transitions + ((newPoint, newTransition)))) - rec(e, List(), newPoint) - } - case IfExpr(cond, then, elze) => { - rec(cond, path, startingPoint) - rec(then, cond :: path, startingPoint) - rec(elze, Not(cond) :: path, startingPoint) + //this case is actually now handled in the unaryOp case + //case way@Waypoint(i, e) => { + // val newPoint = ExpressionPoint(way) + // val newTransition = TransitionLabel(And(path.toSeq), Map()) + // callGraph += (startingPoint -> (transitions + ((newPoint, newTransition)))) + // rec(e, List(), newPoint) + //} + case IfExpr(cond, then, elze) => { + //rec(cond, path, startingPoint) + rec(then, cond :: path, startingPoint) + rec(elze, Not(cond) :: path, startingPoint) + } + case n@NAryOperator(args, _) => { + val newPoint = ExpressionPoint(n) + val newTransition = TransitionLabel(And(path.toSeq), Map()) + callGraph += (startingPoint -> (transitions + ((newPoint, newTransition)))) + args.foreach(rec(_, List(), newPoint)) + } + case b@BinaryOperator(t1, t2, _) => { + val newPoint = ExpressionPoint(b) + val newTransition = TransitionLabel(And(path.toSeq), Map()) + callGraph += (startingPoint -> (transitions + ((newPoint, newTransition)))) + rec(t1, List(), newPoint) + rec(t2, List(), newPoint) + } + case u@UnaryOperator(t, _) => { + val newPoint = ExpressionPoint(u) + val newTransition = TransitionLabel(And(path.toSeq), Map()) + callGraph += (startingPoint -> (transitions + ((newPoint, newTransition)))) + rec(t, List(), newPoint) + } + case t : Terminal => { + val newPoint = ExpressionPoint(t) + val newTransition = TransitionLabel(And(path.toSeq), Map()) + callGraph += (startingPoint -> (transitions + ((newPoint, newTransition)))) + } + //case NAryOperator(args, _) => args.foreach(rec(_, path, startingPoint)) + //case BinaryOperator(t1, t2, _) => rec(t1, path, startingPoint); rec(t2, path, startingPoint) + //case UnaryOperator(t, _) => rec(t, path, startingPoint) + //case t : Terminal => ; + case _ => scala.sys.error("Unhandled tree in collectWithPathCondition : " + expr) } - case NAryOperator(args, _) => args.foreach(rec(_, path, startingPoint)) - case BinaryOperator(t1, t2, _) => rec(t1, path, startingPoint); rec(t2, path, startingPoint) - case UnaryOperator(t, _) => rec(t, path, startingPoint) - case t : Terminal => ; - case _ => scala.sys.error("Unhandled tree in collectWithPathCondition : " + expr) } rec(expression, List(), startingPoint) @@ -89,20 +113,53 @@ class CallGraph(val program: Program) { if(path.isEmpty) BooleanLiteral(true) else { val (_, _, TransitionLabel(cond, assign)) = path.head val finalCond = assigns.foldRight(cond)((map, acc) => replace(map, acc)) - And(finalCond, pathConstraint(path.tail, assign.asInstanceOf[Map[Expr, Expr]] :: assigns)) + And(finalCond, + pathConstraint( + path.tail, + if(assign.isEmpty) assigns else assign.asInstanceOf[Map[Expr, Expr]] :: assigns + ) + ) } } - def findAllPathes: Set[Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]] = { - val waypoints: Set[ProgramPoint] = programPoints.filter{ case ExpressionPoint(_) => true case _ => false } + def findAllPaths: Set[Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]] = { + val waypoints: Set[ProgramPoint] = programPoints.filter{ case ExpressionPoint(Waypoint(_, _)) => true case _ => false } val sortedWaypoints: Seq[ProgramPoint] = waypoints.toSeq.sortWith((p1, p2) => { val (ExpressionPoint(Waypoint(i1, _)), ExpressionPoint(Waypoint(i2, _))) = (p1, p2) i1 <= i2 }) - Set( - sortedWaypoints.zip(sortedWaypoints.tail).foldLeft(Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]())((path, waypoint) => - path ++ findPath(waypoint._1, waypoint._2)) - ) + assert(!sortedWaypoints.isEmpty) + if(sortedWaypoints.size == 1) { //if only one waypoint then we want to cover all static statements starting from the waypoint + findSimplePaths(sortedWaypoints.head) + } else { + Set( + sortedWaypoints.zip(sortedWaypoints.tail).foldLeft(Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]())((path, waypoint) => + path ++ findPath(waypoint._1, waypoint._2)) + ) + } + } + + def findSimplePaths(from: ProgramPoint): Set[Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]] = { + + def dfs(point: ProgramPoint, path: List[(ProgramPoint, ProgramPoint, TransitionLabel)], visitedPoints: Set[ProgramPoint]): + Set[Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]] = graph.get(point) match { + case None => Set(path.reverse) + case Some(edges) => { + if(edges.forall((edge: (ProgramPoint, TransitionLabel)) => visitedPoints.contains(edge._1) || point == edge._1)) + Set(path.reverse) + else { + edges.flatMap((edge: (ProgramPoint, TransitionLabel)) => { + val (neighbour, transition) = edge + if(visitedPoints.contains(neighbour) || point == neighbour) + Set[Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]]() + else + dfs(neighbour, (point, neighbour, transition) :: path, visitedPoints + point) + }) + } + } + } + + dfs(from, List(), Set()) } //find a path that goes through all waypoint in order @@ -159,7 +216,7 @@ class CallGraph(val program: Program) { def ppPoint(p: ProgramPoint): String = p match { case FunctionStart(fd) => fd.id.name case ExpressionPoint(Waypoint(i, e)) => "WayPoint " + i - case _ => sys.error("Unexpected programPoint: " + p) + case ExpressionPoint(e) => e.toString } def ppLabel(l: TransitionLabel): String = { val TransitionLabel(cond, assignments) = l