diff --git a/src/main/scala/leon/CallGraph.scala b/src/main/scala/leon/CallGraph.scala index d37cf16827641642dcc110ab39ed39d943843202..6a086805f203d71cec7fc629ef4b64610830bafe 100644 --- a/src/main/scala/leon/CallGraph.scala +++ b/src/main/scala/leon/CallGraph.scala @@ -9,15 +9,10 @@ class CallGraph(val program: Program) { sealed abstract class ProgramPoint case class FunctionStart(fd: FunDef) extends ProgramPoint - case class ExpressionPoint(wp: Expr) extends ProgramPoint + case class ExpressionPoint(wp: Expr, id: Int) extends ProgramPoint + private var epid = -1 + private def freshExpressionPoint(wp: Expr) = {epid += 1; ExpressionPoint(wp, epid)} - //sealed abstract class EdgeLabel - //case class ConditionLabel(expr: Expr) extends EdgeLabel { - // require(expr.getType == BooleanType) - //} - //case class FunctionInvocLabel(fd: FunDef, args: List[Expr]) extends EdgeLabel { - // require(args.zip(fd.args).forall(p => p._1.getType == p._2.getType)) - //} case class TransitionLabel(cond: Expr, assignment: Map[Variable, Expr]) private lazy val graph: Map[ProgramPoint, Set[(ProgramPoint, TransitionLabel)]] = buildGraph @@ -32,17 +27,62 @@ class CallGraph(val program: Program) { val body = fd.body.get //val cleanBody = hoistIte(expandLets(matchToIfThenElse(body))) val cleanBody = expandLets(matchToIfThenElse(body)) - //println("Clean Body: " + cleanBody) val subgraph = collectWithPathCondition(cleanBody, FunctionStart(fd)) - //println("Subgraph: " + subgraph) callGraph ++= subgraph }) - //println(callGraph) + callGraph = addFunctionInvocationsEdges(callGraph) + + callGraph = simplifyGraph(callGraph) callGraph } + private def simplifyGraph(graph: Map[ProgramPoint, Set[(ProgramPoint, TransitionLabel)]]): Map[ProgramPoint, Set[(ProgramPoint, TransitionLabel)]] = { + def fix[A](f: (A) => A, a: A): A = { + val na = f(a) + if(a == na) a else fix(f, na) + } + fix(compressGraph, graph) + } + + //does a one level compression of the graph + private def compressGraph(graph: Map[ProgramPoint, Set[(ProgramPoint, TransitionLabel)]]): Map[ProgramPoint, Set[(ProgramPoint, TransitionLabel)]] = { + var newGraph = graph + + graph.find{ + case (point, edges) => { + edges.exists{ + case edge@(p2, TransitionLabel(BooleanLiteral(true), assign)) if assign.isEmpty => { + val edgesOfPoint: Set[(ProgramPoint, TransitionLabel)] = graph.get(p2).getOrElse(Set()) //should be unique entry point and cannot be a FunctionStart + newGraph += (point -> ((edges - edge) ++ edgesOfPoint)) + newGraph -= p2 + true + } + case _ => false + } + } + } + + newGraph + } + + + private def addFunctionInvocationsEdges(graph: Map[ProgramPoint, Set[(ProgramPoint, TransitionLabel)]]): Map[ProgramPoint, Set[(ProgramPoint, TransitionLabel)]] = { + var augmentedGraph = graph + + graph.foreach{ + case (point@ExpressionPoint(FunctionInvocation(fd, args), _), edges) => { + val newPoint = FunctionStart(fd) + val newTransition = TransitionLabel(BooleanLiteral(true), fd.args.zip(args).map{ case (VarDecl(id, _), arg) => (id.toVariable, arg) }.toMap) + augmentedGraph += (point -> (edges + ((newPoint, newTransition)))) + } + case _ => ; + } + + augmentedGraph + } + private def collectWithPathCondition(expression: Expr, startingPoint: ProgramPoint): Map[ProgramPoint, Set[(ProgramPoint, TransitionLabel)]] = { var callGraph: Map[ProgramPoint, Set[(ProgramPoint, TransitionLabel)]] = Map() @@ -51,14 +91,13 @@ class CallGraph(val program: Program) { case None => Set() case Some(s) => s } - expr match { - case FunctionInvocation(fd, args) => { - val newPoint = FunctionStart(fd) - val newTransition = TransitionLabel(And(path.toSeq), fd.args.zip(args).map{ case (VarDecl(id, _), arg) => (id.toVariable, arg) }.toMap) - callGraph += (startingPoint -> (transitions + ((newPoint, newTransition)))) - args.foreach(arg => rec(arg, path, startingPoint)) - } + //case FunctionInvocation(fd, args) => { + // val newPoint = FunctionStart(fd) + // val newTransition = TransitionLabel(And(path.toSeq), fd.args.zip(args).map{ case (VarDecl(id, _), arg) => (id.toVariable, arg) }.toMap) + // callGraph += (startingPoint -> (transitions + ((newPoint, newTransition)))) + // args.foreach(arg => rec(arg, path, startingPoint)) + //} //this case is actually now handled in the unaryOp case //case way@Waypoint(i, e) => { // val newPoint = ExpressionPoint(way) @@ -67,44 +106,39 @@ class CallGraph(val program: Program) { // rec(e, List(), newPoint) //} case IfExpr(cond, then, elze) => { - //rec(cond, path, startingPoint) + rec(cond, path, startingPoint) rec(then, cond :: path, startingPoint) rec(elze, Not(cond) :: path, startingPoint) } case n@NAryOperator(args, _) => { - val newPoint = ExpressionPoint(n) + val newPoint = freshExpressionPoint(n) val newTransition = TransitionLabel(And(path.toSeq), Map()) callGraph += (startingPoint -> (transitions + ((newPoint, newTransition)))) args.foreach(rec(_, List(), newPoint)) } case b@BinaryOperator(t1, t2, _) => { - val newPoint = ExpressionPoint(b) + val newPoint = freshExpressionPoint(b) val newTransition = TransitionLabel(And(path.toSeq), Map()) callGraph += (startingPoint -> (transitions + ((newPoint, newTransition)))) rec(t1, List(), newPoint) rec(t2, List(), newPoint) } case u@UnaryOperator(t, _) => { - val newPoint = ExpressionPoint(u) + val newPoint = freshExpressionPoint(u) val newTransition = TransitionLabel(And(path.toSeq), Map()) callGraph += (startingPoint -> (transitions + ((newPoint, newTransition)))) rec(t, List(), newPoint) } case t : Terminal => { - val newPoint = ExpressionPoint(t) + val newPoint = freshExpressionPoint(t) val newTransition = TransitionLabel(And(path.toSeq), Map()) callGraph += (startingPoint -> (transitions + ((newPoint, newTransition)))) } - //case NAryOperator(args, _) => args.foreach(rec(_, path, startingPoint)) - //case BinaryOperator(t1, t2, _) => rec(t1, path, startingPoint); rec(t2, path, startingPoint) - //case UnaryOperator(t, _) => rec(t, path, startingPoint) - //case t : Terminal => ; case _ => scala.sys.error("Unhandled tree in collectWithPathCondition : " + expr) } } rec(expression, List(), startingPoint) - callGraph } @@ -122,15 +156,24 @@ class CallGraph(val program: Program) { } } + private def isMain(fd: FunDef): Boolean = { + fd.annotations.exists(_ == "main") + } + def findAllPaths: Set[Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]] = { - val waypoints: Set[ProgramPoint] = programPoints.filter{ case ExpressionPoint(Waypoint(_, _)) => true case _ => false } + val waypoints: Set[ProgramPoint] = programPoints.filter{ case ExpressionPoint(Waypoint(_, _), _) => true case _ => false } val sortedWaypoints: Seq[ProgramPoint] = waypoints.toSeq.sortWith((p1, p2) => { - val (ExpressionPoint(Waypoint(i1, _)), ExpressionPoint(Waypoint(i2, _))) = (p1, p2) + val (ExpressionPoint(Waypoint(i1, _), _), ExpressionPoint(Waypoint(i2, _), _)) = (p1, p2) i1 <= i2 }) - assert(!sortedWaypoints.isEmpty) - if(sortedWaypoints.size == 1) { //if only one waypoint then we want to cover all static statements starting from the waypoint - findSimplePaths(sortedWaypoints.head) + + val functionPoints: Set[ProgramPoint] = programPoints.flatMap{ case f@FunctionStart(fd) => Set[ProgramPoint](f) case _ => Set[ProgramPoint]() } + val mainPoint: Option[ProgramPoint] = functionPoints.find{ case FunctionStart(fd) => isMain(fd) case p => sys.error("unexpected: " + p) } + + assert(mainPoint != None || sortedWaypoints.size > 1) + + if(mainPoint != None) { + findSimplePaths(mainPoint.get) } else { Set( sortedWaypoints.zip(sortedWaypoints.tail).foldLeft(Seq[(ProgramPoint, ProgramPoint, TransitionLabel)]())((path, waypoint) => @@ -215,8 +258,8 @@ class CallGraph(val program: Program) { def ppPoint(p: ProgramPoint): String = p match { case FunctionStart(fd) => fd.id.name - case ExpressionPoint(Waypoint(i, e)) => "WayPoint " + i - case ExpressionPoint(e) => e.toString + case ExpressionPoint(Waypoint(i, e), _) => "WayPoint " + i + case ExpressionPoint(e, _) => e.toString } def ppLabel(l: TransitionLabel): String = { val TransitionLabel(cond, assignments) = l diff --git a/src/main/scala/leon/TestGeneration.scala b/src/main/scala/leon/TestGeneration.scala index 5c268c84da85a74f3fb5a98f9d32acb30d024317..cc800918df3c720680075e727fde6bece9210071 100644 --- a/src/main/scala/leon/TestGeneration.scala +++ b/src/main/scala/leon/TestGeneration.scala @@ -22,11 +22,11 @@ class TestGeneration(reporter: Reporter) extends Analyser(reporter) { val testcases = generateTestCases(program) - val topFunDef = program.definedFunctions.find(fd => fd.body.exists(body => body match { - case Waypoint(1, _) => true - case _ => false - })).get - + val topFunDef = program.definedFunctions.find(fd => isMain(fd)).get +//fd.body.exists(body => body match { +// case Waypoint(1, _) => true +// case _ => false +// }) val testFun = new FunDef(FreshIdentifier("test"), UnitType, Seq()) val funInvocs = testcases.map(testcase => { val params = topFunDef.args diff --git a/testcases/testgen/Abs2.scala b/testcases/testgen/Abs2.scala new file mode 100644 index 0000000000000000000000000000000000000000..d640f941bf299ce06e0e84e40e2a2f98c39a4c4c --- /dev/null +++ b/testcases/testgen/Abs2.scala @@ -0,0 +1,11 @@ +import leon.Utils._ +import leon.Annotations._ + +object Abs2 { + + @main + def f(x: Int): Int = if(x < 0) g(-x) else g(x) + + def g(y: Int): Int = if(y < 0) -y else y + +} diff --git a/testcases/testgen/Imp.scala b/testcases/testgen/Imp.scala new file mode 100644 index 0000000000000000000000000000000000000000..8257e6619ca299a8e56fd3d78076c54f8b731882 --- /dev/null +++ b/testcases/testgen/Imp.scala @@ -0,0 +1,17 @@ +import leon.Utils._ +import leon.Annotations._ + +object Imp { + + @main + def foo(i: Int): Int = { + var a = 0 + a = a + 3 + if(i < a) + a = a + 1 + else + a = a - 1 + a + } ensuring(_ >= 0) + +} diff --git a/testcases/testgen/MultiCall.scala b/testcases/testgen/MultiCall.scala index e536d275b86ef1c2f5989a3399c8d242d4d9ccfb..6742cc4ca224b8bef044745dc6cd12d60683ac80 100644 --- a/testcases/testgen/MultiCall.scala +++ b/testcases/testgen/MultiCall.scala @@ -1,9 +1,10 @@ import leon.Utils._ +import leon.Annotations._ object MultiCall { - - def a(i: Int): Int = waypoint(1, if(i < 0) b(i) else c(i)) + @main + def a(i: Int): Int = if(i < 0) b(i) else c(i) def b(j: Int): Int = if(j == -5) d(j) else e(j) def c(k: Int): Int = if(k == 5) d(k) else e(k)