diff --git a/src/main/scala/leon/CallGraph.scala b/src/main/scala/leon/CallGraph.scala index 2da4663af10893a833b72f2d8a2a463b66bfe1f8..ad78bfa11bdf81688efc62e53d5f8cab2f93f4bc 100644 --- a/src/main/scala/leon/CallGraph.scala +++ b/src/main/scala/leon/CallGraph.scala @@ -21,6 +21,9 @@ class CallGraph(val program: Program) { case class TransitionLabel(cond: Expr, assignment: Map[Variable, Expr]) private lazy val graph: Map[ProgramPoint, Set[(ProgramPoint, TransitionLabel)]] = buildGraph + private lazy val programPoints: Set[ProgramPoint] = { + graph.flatMap(pair => pair._2.map(edge => edge._1).toSet + pair._1).toSet + } private def buildGraph: Map[ProgramPoint, Set[(ProgramPoint, TransitionLabel)]] = { var callGraph: Map[ProgramPoint, Set[(ProgramPoint, TransitionLabel)]] = Map() @@ -81,9 +84,51 @@ class CallGraph(val program: Program) { callGraph } + def findAllPathes: Set[Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]] = { + val waypoints: Set[ProgramPoint] = programPoints.filter{ case ExpressionPoint(_) => true case _ => false } + val sortedWaypoints: Seq[ProgramPoint] = waypoints.toSeq.sortWith((p1, p2) => { + val (ExpressionPoint(Waypoint(i1, _)), ExpressionPoint(Waypoint(i2, _))) = (p1, p2) + i1 <= i2 + }) + Set(findPath(sortedWaypoints(0), sortedWaypoints(1))) + } + //find a path that goes through all waypoint in order - //def findPath + def findPath(from: ProgramPoint, to: ProgramPoint): Seq[(ProgramPoint, ProgramPoint, TransitionLabel)] = { + var visitedPoints: Set[ProgramPoint] = Set() + var history: Map[ProgramPoint, (ProgramPoint, TransitionLabel)] = Map() + var toVisit: List[ProgramPoint] = List(from) + var currentPoint: ProgramPoint = null + while(!toVisit.isEmpty && currentPoint != to) { + currentPoint = toVisit.head + if(currentPoint != to) { + visitedPoints += currentPoint + toVisit = toVisit.tail + graph.get(currentPoint).foreach(edges => edges.foreach{ + case (neighbour, transition) => + if(!visitedPoints.contains(neighbour) && !toVisit.contains(neighbour)) { + toVisit ::= neighbour + history += (neighbour -> ((currentPoint, transition))) + } + }) + } + } + + def rebuildPath(point: ProgramPoint, path: List[(ProgramPoint, ProgramPoint, TransitionLabel)]): Seq[(ProgramPoint, ProgramPoint, TransitionLabel)] = { + if(point == from) path else { + val (previousPoint, transition) = history(point) + val newPath = (previousPoint, point, transition) :: path + rebuildPath(previousPoint, newPath) + } + } + + //TODO: handle case where the target node is not found + println(history) + println(from) + println(to) + rebuildPath(to, List()) + } //guarentee that all IfExpr will be at the top level and as soon as you encounter a non-IfExpr, then no more IfExpr can be find in the sub-expressions def hoistIte(expr: Expr): Expr = { diff --git a/src/main/scala/leon/TestGeneration.scala b/src/main/scala/leon/TestGeneration.scala index 421d4dec348235c4b8a456344927861d50ea433d..469c91fdd29ac65d1da286a962b4ef8caa067ca0 100644 --- a/src/main/scala/leon/TestGeneration.scala +++ b/src/main/scala/leon/TestGeneration.scala @@ -18,6 +18,7 @@ class TestGeneration(reporter: Reporter) extends Analyser(reporter) { def analyse(program: Program) { val callGraph = new CallGraph(program) println(callGraph.toDotString) + println(callGraph.findAllPathes) //z3Solver.setProgram(program) //reporter.info("Running test generation") //val allFuns = program.definedFunctions diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index 21a9393b142b177d49ba985988512c5dbe885c5f..b7b5bd343c2adc3bf973440776951a124a0ddd08 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -154,6 +154,12 @@ object PrettyPrinter { nsb } + case Waypoint(i, expr) => { + sb.append("waypoint_" + i + "(") + pp(expr, sb, lvl) + sb.append(")") + } + case OptionSome(a) => { var nsb = sb nsb.append("Some(")