package leon

import leon.purescala.Definitions._
import leon.purescala.Trees._
import leon.purescala.TypeTrees._
import leon.purescala.Common._

/** Call graph of a program, annotated with path conditions.
 *
 *  Vertices are [[ProgramPoint]]s: the entry of each defined function and every
 *  `WayPoint` expression found in function bodies. An edge carries a
 *  [[TransitionLabel]] with the conjunction of the enclosing `if` conditions under
 *  which the target is reached and, for function calls, the binding of the callee's
 *  formal parameters to the actual arguments.
 *
 *  @param program the program whose defined functions are analysed
 */
class CallGraph(val program: Program) {

  /** A vertex of the graph. */
  sealed abstract class ProgramPoint
  /** Entry point of function `fd`. */
  case class FunctionStart(fd: FunDef) extends ProgramPoint
  /** A `WayPoint` expression occurring in some function body. */
  case class ExpressionPoint(wp: Expr) extends ProgramPoint

  //sealed abstract class EdgeLabel
  //case class ConditionLabel(expr: Expr) extends EdgeLabel {
  //  require(expr.getType == BooleanType)
  //}
  //case class FunctionInvocLabel(fd: FunDef, args: List[Expr]) extends EdgeLabel {
  //  require(args.zip(fd.args).forall(p => p._1.getType == p._2.getType))
  //}

  /** Edge label.
   *
   *  @param cond       conjunction of the `if` conditions (or their negations) that
   *                    guard the edge
   *  @param assignment callee formal parameters (as variables) mapped to the actual
   *                    argument expressions; empty for edges into a way-point
   */
  case class TransitionLabel(cond: Expr, assignment: Map[Variable, Expr])

  /** Adjacency map: each point to its set of (target, label) edges. */
  private type Graph = Map[ProgramPoint, Set[(ProgramPoint, TransitionLabel)]]

  private lazy val graph: Graph = buildGraph

  /** Builds the whole graph by merging the sub-graph of every defined function.
   *  Functions without a body contribute no outgoing edges (they can still be the
   *  target of a call edge).
   */
  private def buildGraph: Graph =
    program.definedFunctions.foldLeft[Graph](Map()) { (acc, fd) =>
      fd.body match {
        case Some(body) =>
          // ITE hoisting is currently disabled:
          //val cleanBody = hoistIte(expandLets(matchToIfThenElse(body)))
          val cleanBody = expandLets(matchToIfThenElse(body))
          mergeGraphs(acc, collectWithPathCondition(cleanBody, FunctionStart(fd)))
        case None => acc
      }
    }

  /** Union of two graphs. When a point has edges in both, the edge sets are
   *  combined instead of `g2`'s set replacing `g1`'s (as a plain `++` would do).
   */
  private def mergeGraphs(g1: Graph, g2: Graph): Graph =
    g2.foldLeft(g1) { case (acc, (point, edges)) =>
      acc + (point -> (acc.getOrElse(point, Set()) ++ edges))
    }

  /** Collects the edges reachable while walking `expression` from `startingPoint`.
   *
   *  A `FunctionInvocation` adds an edge to the callee's [[FunctionStart]] and the
   *  walk continues into the arguments. A `WayPoint` adds an edge to an
   *  [[ExpressionPoint]] for it, and its body is then walked from that new point
   *  with an empty path condition. `IfExpr` conditions are pushed onto the path:
   *  `cond` for the then-branch and `Not(cond)` for the else-branch.
   *
   *  @throws RuntimeException if a tree kind not covered by the cases below is met
   */
  private def collectWithPathCondition(expression: Expr, startingPoint: ProgramPoint): Graph = {
    var callGraph: Graph = Map()

    // Adds one edge, keeping the edges already recorded for `from`.
    def addEdge(from: ProgramPoint, to: ProgramPoint, label: TransitionLabel): Unit =
      callGraph += (from -> (callGraph.getOrElse(from, Set()) + ((to, label))))

    def rec(expr: Expr, path: List[Expr], startingPoint: ProgramPoint): Unit = expr match {
      case FunctionInvocation(fd, args) =>
        val assignment: Map[Variable, Expr] =
          fd.args.zip(args).map { case (VarDecl(id, _), arg) => (id.toVariable, arg) }.toMap
        addEdge(startingPoint, FunctionStart(fd), TransitionLabel(And(path.toSeq), assignment))
        args.foreach(arg => rec(arg, path, startingPoint))
      case WayPoint(e) =>
        val newPoint = ExpressionPoint(expr)
        addEdge(startingPoint, newPoint, TransitionLabel(And(path.toSeq), Map()))
        rec(e, List(), newPoint)
      // `then` is a reserved word since Scala 2.10, hence `thenn`.
      case IfExpr(cond, thenn, elze) =>
        rec(cond, path, startingPoint)
        rec(thenn, cond :: path, startingPoint)
        rec(elze, Not(cond) :: path, startingPoint)
      case NAryOperator(args, _) => args.foreach(rec(_, path, startingPoint))
      case BinaryOperator(t1, t2, _) =>
        rec(t1, path, startingPoint)
        rec(t2, path, startingPoint)
      case UnaryOperator(t, _) => rec(t, path, startingPoint)
      case _: Terminal => ()
      case _ => scala.sys.error("Unhandled tree in collectWithPathCondition : " + expr)
    }

    rec(expression, List(), startingPoint)
    callGraph
  }

  /** Rewrites `expr` until a fixpoint is reached, pulling every `IfExpr` operand of a
   *  unary, binary or n-ary operator up above that operator (the operator is
   *  duplicated into both branches). The intent is to guarantee that all `IfExpr`
   *  nodes are at the top level: as soon as a non-`IfExpr` node is encountered, no
   *  `IfExpr` can be found in its sub-expressions.
   *
   *  Because operands are duplicated into both branches, the result can be much
   *  larger than the input.
   */
  def hoistIte(expr: Expr): Expr = {
    def transform(expr: Expr): Option[Expr] = expr match {
      case uop@UnaryOperator(IfExpr(c, t, e), op) =>
        Some(IfExpr(c, op(t).setType(uop.getType), op(e).setType(uop.getType)).setType(uop.getType))
      case bop@BinaryOperator(IfExpr(c, t, e), t2, op) =>
        Some(IfExpr(c, op(t, t2).setType(bop.getType), op(e, t2).setType(bop.getType)).setType(bop.getType))
      case bop@BinaryOperator(t1, IfExpr(c, t, e), op) =>
        Some(IfExpr(c, op(t1, t).setType(bop.getType), op(t1, e).setType(bop.getType)).setType(bop.getType))
      case nop@NAryOperator(ts, op) =>
        // Only the first IfExpr operand is hoisted per step; `fix` repeats until none is left.
        val iteIndex = ts.indexWhere {
          case IfExpr(_, _, _) => true
          case _ => false
        }
        if (iteIndex == -1) None
        else {
          val (beforeIte, startIte) = ts.splitAt(iteIndex)
          val afterIte = startIte.tail
          val IfExpr(c, t, e) = startIte.head
          Some(IfExpr(c,
                      op(beforeIte ++ Seq(t) ++ afterIte).setType(nop.getType),
                      op(beforeIte ++ Seq(e) ++ afterIte).setType(nop.getType)
                     ).setType(nop.getType))
        }
      case _ => None
    }

    // Applies `f` repeatedly until the value stops changing.
    @scala.annotation.tailrec
    def fix[A](f: (A) => A, a: A): A = {
      val na = f(a)
      if (a == na) a else fix(f, na)
    }

    fix(searchAndReplaceDFS(transform), expr)
  }

  /** The graph rendered in Graphviz DOT syntax: one vertex per program point
   *  (labelled with the function name, or "WayPoint") and one labelled edge per
   *  transition. Double quotes inside labels are escaped so the output stays
   *  well-formed DOT.
   */
  lazy val toDotString: String = {
    var vertexLabels: Set[(String, String)] = Set()
    var vertexId = -1
    var point2vertex: Map[ProgramPoint, Int] = Map()

    // Returns the (id, label) of p's vertex, allocating a fresh id the first time p is seen.
    def getVertex(p: ProgramPoint): (String, String) = point2vertex.get(p) match {
      case Some(id) => ("v_" + id, ppPoint(p))
      case None =>
        vertexId += 1
        point2vertex += (p -> vertexId)
        val pair = ("v_" + vertexId, ppPoint(p))
        vertexLabels += pair
        pair
    }

    def ppPoint(p: ProgramPoint): String = p match {
      case FunctionStart(fd) => fd.id.name
      case ExpressionPoint(WayPoint(_)) => "WayPoint"
      case _ => sys.error("Unexpected programPoint: " + p)
    }

    // In DOT quoted strings the only escape is \" for a double quote.
    def escape(s: String): String = s.replace("\"", "\\\"")

    def ppLabel(l: TransitionLabel): String = {
      val TransitionLabel(cond, assignments) = l
      cond.toString + ", " + assignments.map(p => p._1 + " -> " + p._2).mkString("\n")
    }

    // Must be computed before rendering vertices: getVertex fills `vertexLabels`.
    val edges: List[(String, String, String)] = graph.flatMap { case (startPoint, targets) =>
      val (startId, _) = getVertex(startPoint)
      targets.map { case (endPoint, label) =>
        val (endId, _) = getVertex(endPoint)
        (startId, endId, escape(ppLabel(label)))
      }.toList
    }.toList

    "digraph " + program.id.name + " {\n" +
      vertexLabels.map(p => p._1 + " [label=\"" + escape(p._2) + "\"];").mkString("\n") + "\n" +
      edges.map(p => p._1 + " -> " + p._2 + " [label=\"" + p._3 + "\"];").mkString("\n") + "\n" +
      "}"
  }

  //def analyse(program: Program) {
  //  z3Solver.setProgram(program)
  //  reporter.info("Running test generation")
  //  val allFuns = program.definedFunctions
  //  allFuns.foreach(fd => {
  //    val testcases = generateTestCases(fd)
  //    reporter.info("Running " + fd.id + " with the following testcases:\n")
  //    reporter.info(testcases.mkString("\n"))
  //  })
  //}

  //private def generatePathConditions(funDef: FunDef): Seq[Expr] = if(!funDef.hasImplementation) Seq() else {
  //  val body = funDef.body.get
  //  val cleanBody = expandLets(matchToIfThenElse(body))
  //  collectWithPathCondition(cleanBody)
  //}
}

// Abandoned, non-compiling draft of a continuation-based hoistIte, kept for reference.
//def hoistIte(expr: Expr): (Seq[Expr] => Expr, Seq[Expr]) = expr match {
//  case ite@IfExpr(c, t, e) => {
//    val (iteThen, valsThen) = hoistIte(t)
//    val nbValsThen = valsThen.size
//    val (iteElse, valsElse) = hoistIte(e)
//    val nbValsElse = valsElse.size
//    def ite(es: Seq[Expr]): Expr = {
//      val argsThen = es.take(nbValsThen)
//      val argsElse = es.drop(nbValsThen)
//      IfExpr(c, iteThen(argsThen), iteElse(argsElse), e2)
//    }
//    (ite, valsThen ++ valsElse)
//  }
//  case BinaryOperator(t1, t2, op) => {
//    val (iteLeft, valsLeft) = hoistIte(t1)
//    val (iteRight, valsRight) = hoistIte(t2)
//    def ite(e1: Expr, e2: Expr): Expr = {
//    }
//    iteLeft(
//      iteRight(
//        op(thenValRight, thenValLeft),
//        op(thenValRight, elseValLeft)
//      ), iteRight(
//        op(elseValRight, thenValLeft),
//        op(elseValRight, elseValLeft)
//      )
//    )
//  }
//  case NAryOperator(args, op) => {
//  }
//  case (t: Terminal) => {
//    def ite(es: Seq[Expr]): Expr = {
//      require(es.size == 1)
//      es.head
//    }
//    (ite, Seq(t))
//  }
//  case _ => scala.sys.error("Unhandled tree in hoistIte : " + expr)
//}