diff --git a/src/main/scala/leon/CallGraph.scala b/src/main/scala/leon/CallGraph.scala index 6a086805f203d71cec7fc629ef4b64610830bafe..f0c37c3c9056b9de27d4f977e024c06073acc89a 100644 --- a/src/main/scala/leon/CallGraph.scala +++ b/src/main/scala/leon/CallGraph.scala @@ -53,7 +53,7 @@ class CallGraph(val program: Program) { graph.find{ case (point, edges) => { edges.exists{ - case edge@(p2, TransitionLabel(BooleanLiteral(true), assign)) if assign.isEmpty => { + case edge@(p2@ExpressionPoint(e, _), TransitionLabel(BooleanLiteral(true), assign)) if assign.isEmpty && !e.isInstanceOf[Waypoint] => { val edgesOfPoint: Set[(ProgramPoint, TransitionLabel)] = graph.get(p2).getOrElse(Set()) //should be unique entry point and cannot be a FunctionStart newGraph += (point -> ((edges - edge) ++ edgesOfPoint)) newGraph -= p2 @@ -160,7 +160,7 @@ class CallGraph(val program: Program) { fd.annotations.exists(_ == "main") } - def findAllPaths: Set[Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]] = { + def findAllPaths(z3Solver: FairZ3Solver): Set[Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]] = { val waypoints: Set[ProgramPoint] = programPoints.filter{ case ExpressionPoint(Waypoint(_, _), _) => true case _ => false } val sortedWaypoints: Seq[ProgramPoint] = waypoints.toSeq.sortWith((p1, p2) => { val (ExpressionPoint(Waypoint(i1, _), _), ExpressionPoint(Waypoint(i2, _), _)) = (p1, p2) @@ -170,25 +170,72 @@ class CallGraph(val program: Program) { val functionPoints: Set[ProgramPoint] = programPoints.flatMap{ case f@FunctionStart(fd) => Set[ProgramPoint](f) case _ => Set[ProgramPoint]() } val mainPoint: Option[ProgramPoint] = functionPoints.find{ case FunctionStart(fd) => isMain(fd) case p => sys.error("unexpected: " + p) } - assert(mainPoint != None || sortedWaypoints.size > 1) + assert(mainPoint != None) - if(mainPoint != None) { + if(sortedWaypoints.size == 0) { findSimplePaths(mainPoint.get) } else { - Set( - sortedWaypoints.zip(sortedWaypoints.tail).foldLeft(Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]())((path, waypoint) => - path ++ findPath(waypoint._1, waypoint._2)) - ) + visitAllWaypoints(mainPoint.get :: sortedWaypoints.toList, z3Solver) match { + case None => Set() + case Some(p) => Set(p) + } + //Set( + // sortedWaypoints.zip(sortedWaypoints.tail).foldLeft(Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]())((path, waypoint) => + // path ++ findPath(waypoint._1, waypoint._2)) + //) } } - def findSimplePaths(from: ProgramPoint): Set[Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]] = { + def visitAllWaypoints(waypoints: List[ProgramPoint], z3Solver: FairZ3Solver): Option[Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]] = { + def rec(head: ProgramPoint, tail: List[ProgramPoint], path: Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]): + Option[Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]] = { + tail match { + case Nil => Some(path) + case x::xs => { + val allPaths = findSimplePaths(head, Some(x)) + var completePath: Option[Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]] = None + allPaths.find(intermediatePath => { + val pc = pathConstraint(path ++ intermediatePath) + z3Solver.init() + z3Solver.restartZ3 + + var testcase: Option[Map[Identifier, Expr]] = None + + val (solverResult, model) = z3Solver.decideWithModel(pc, false) + solverResult match { + case None => { + false + } + case Some(true) => { + false + } + case Some(false) => { + val recPath = rec(x, xs, path ++ intermediatePath) + recPath match { + case None => false + case Some(path) => { + completePath = Some(path) + true + } + } + } + } + }) + completePath + } + } + } + rec(waypoints.head, waypoints.tail, Seq()) + } + def findSimplePaths(from: ProgramPoint, to: Option[ProgramPoint] = None): Set[Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]] = { def dfs(point: ProgramPoint, path: List[(ProgramPoint, ProgramPoint, TransitionLabel)], visitedPoints: Set[ProgramPoint]): Set[Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]] = graph.get(point) match { case None => Set(path.reverse) case Some(edges) => { - if(edges.forall((edge: (ProgramPoint, TransitionLabel)) => visitedPoints.contains(edge._1) || point == edge._1)) + if(to != None && to.get == point) + Set(path.reverse) + else if(to == None && edges.forall((edge: (ProgramPoint, TransitionLabel)) => visitedPoints.contains(edge._1) || point == edge._1)) Set(path.reverse) else { edges.flatMap((edge: (ProgramPoint, TransitionLabel)) => { diff --git a/src/main/scala/leon/TestGeneration.scala b/src/main/scala/leon/TestGeneration.scala index cc800918df3c720680075e727fde6bece9210071..2935816eb9d80bd76bcf0371d3ceeba5cf62595d 100644 --- a/src/main/scala/leon/TestGeneration.scala +++ b/src/main/scala/leon/TestGeneration.scala @@ -42,7 +42,7 @@ class TestGeneration(reporter: Reporter) extends Analyser(reporter) { val Program(id, ObjectDef(objId, defs, invariants)) = program val testProgram = Program(id, ObjectDef(objId, testFun +: defs , invariants)) - testProgram.writeScalaFile("TestGen.scala") + testProgram.writeScalaFile("TestGen.scalax") reporter.info("Running from waypoint with the following testcases:\n") reporter.info(testcases.mkString("\n")) @@ -56,7 +56,7 @@ class TestGeneration(reporter: Reporter) extends Analyser(reporter) { val callGraph = new CallGraph(program) callGraph.writeDotFile("testgen.dot") - val constraints = callGraph.findAllPaths.map(path => { + val constraints = callGraph.findAllPaths(z3Solver).map(path => { println("Path is: " + path) val cnstr = callGraph.pathConstraint(path) println("constraint is: " + cnstr) diff --git a/testcases/testgen/Sum.scala b/testcases/testgen/Sum.scala index 786b009808f86d6e6db3d9ef91f0820ee4d2a2c3..5f1c6c481f36d195f1fe6d58c591a5d91918871b 100644 --- a/testcases/testgen/Sum.scala +++ b/testcases/testgen/Sum.scala @@ -1,9 +1,11 @@ import leon.Utils._ +import leon.Annotations._ object Sum { + @main def sum(n: Int): Int = { - waypoint(1, if(n <= 0) 0 else n + sum(n-1)) + if(n <= 0) waypoint(4, 0) else waypoint(3, waypoint(2, n + sum(n-1))) } ensuring(_ >= 0) }