Skip to content
Snippets Groups Projects
Commit 4a2709e4 authored by Marco Antognini's avatar Marco Antognini Committed by Etienne Kneuss
Browse files

Add C pretty printer

parent c7f64da8
No related branches found
No related tags found
No related merge requests found
/* Copyright 2009-2015 EPFL, Lausanne */
package leon
package genc
import CAST._
import CPrinterHelpers._
class CPrinter(val sb: StringBuffer = new StringBuffer) {
override def toString = sb.toString
def print(tree: Tree) = pp(tree)(PrinterContext(0, this))
def pp(tree: Tree)(implicit ctx: PrinterContext): Unit = tree match {
/* ---------------------------------------------------------- Types ----- */
case typ: Type => c"${typ.toString}"
/* ------------------------------------------------------- Literals ----- */
case IntLiteral(v) => c"$v"
case BoolLiteral(b) => c"$b"
/* --------------------------------------------------- Definitions ----- */
case Prog(structs, functions) =>
c"""|/* ------------------------------------ includes ----- */
|
|${nary(includeStmts, sep = "\n")}
|
|/* ---------------------- data type declarations ----- */
|
|${nary(structs map StructDecl, sep = "\n")}
|
|/* ----------------------- data type definitions ----- */
|
|${nary(structs map StructDef, sep = "\n")}
|
|/* ----------------------- function declarations ----- */
|
|${nary(functions map FunDecl, sep = "\n")}
|
|/* ------------------------ function definitions ----- */
|
|${nary(functions, sep = "\n")}
|"""
case f @ Fun(_, _, _, body) =>
c"""|${FunSign(f)}
|{
| $body
|}
|"""
case Id(name) => c"$name"
/* --------------------------------------------------------- Stmts ----- */
case NoStmt => c"/* empty */"
// Try to print new lines and semicolon somewhat correctly
case Compound(stmts) if stmts.isEmpty => // should not happen
case Compound(stmts) if stmts.length == 1 =>
stmts.head match {
case s: Call => c"$s;" // for function calls whose returned value is not saved
case s => c"$s"
}
case Compound(stmts) =>
val head = stmts.head
val tail = Compound(stmts.tail)
head match {
case s: Call => c"$s;" // for function calls whose returned value is not saved
case s => c"$s"
}
c"$NewLine$tail"
case Assert(pred, Some(error)) => c"assert($pred); /* $error */"
case Assert(pred, None) => c"assert($pred);"
case Var(id, _) => c"$id"
case DeclVar(Var(id, typ)) => c"$typ $id;"
// TODO depending on the type of array (e.g. `char`) or the value (e.g. `0`), we could use `memset`.
case DeclInitVar(Var(id, typ), ai: ArrayInit) => // Note that `typ` is a struct here
val buffer = FreshId("vla_buffer")
val i = FreshId("i")
c"""|${ai.valueType} $buffer[${ai.length}];
|for (${Int32} $i = 0; $i < ${ai.length}; ++$i) {
| $buffer[$i] = ${ai.defaultValue};
|}
|$typ $id = { .length = ${ai.length}, .data = $buffer };
|"""
case DeclInitVar(Var(id, typ), ai: ArrayInitWithValues) => // Note that `typ` is a struct here
val buffer = FreshId("vla_buffer")
c"$NewLine${ai.valueType} $buffer[${ai.length}];$NewLine"
for ((v, i) <- ai.values.zipWithIndex) {
c"$buffer[$i] = $v;$NewLine"
}
c"$typ $id = { .length = ${ai.length}, .data = $buffer };$NewLine"
case DeclInitVar(Var(id, typ), value) =>
c"$typ $id = $value;"
case Assign(lhs, rhs) =>
c"$lhs = $rhs;"
case UnOp(op, rhs) => c"($op $rhs)"
case MultiOp(op, stmts) => c"""${nary(stmts, sep = s" ${op.fixMargin} ")}"""
case SubscriptOp(ptr, idx) => c"$ptr[$idx]"
case Break => c"break;"
case Return(stmt) => c"return $stmt;"
case IfElse(cond, thn, elze) =>
c"""|if ($cond)
|{
| $thn
|}
|else
|{
| $elze
|}
|"""
case While(cond, body) =>
c"""|while ($cond)
|{
| $body
|}
|"""
case AccessVar(id) => c"$id"
case AccessRef(id) => c"(*$id)"
case AccessAddr(id) => c"(&$id)"
case AccessField(struct, field) => c"$struct.$field"
case Call(id, args) => c"$id($args)"
case StructInit(args, struct) =>
c"(${struct.id}) { "
for ((id, stmt) <- args.init) {
c".$id = $stmt, "
}
if (!args.isEmpty) {
val (id, stmt) = args.last
c".$id = $stmt "
}
c"}"
/* --------------------------------------------------------- Error ----- */
case tree => sys.error(s"CPrinter: <<$tree>> was not handled properly")
}
def pp(wt: WrapperTree)(implicit ctx: PrinterContext): Unit = wt match {
case FunDecl(f) =>
c"${FunSign(f)};$NewLine"
case FunSign(Fun(id, retType, Nil, _)) =>
c"""|$retType
|$id($Void)"""
case FunSign(Fun(id, retType, params, _)) =>
c"""|$retType
|$id(${nary(params map DeclParam)})"""
case DeclParam(Var(id, typ)) =>
c"$typ $id"
case StructDecl(s) =>
c"struct $s;"
case StructDef(Struct(name, fields)) =>
c"""|typedef struct $name {
| ${nary(fields map DeclParam, sep = ";\n", closing = ";")}
|} $name;
|"""
case NewLine =>
c"""|
|"""
}
/** Hardcoded list of required include files from C standard library **/
lazy val includes = "assert.h" :: "stdbool.h" :: "stdint.h" :: Nil
lazy val includeStmts = includes map { i => s"#include <$i>" }
/** Wrappers to distinguish how the data should be printed **/
sealed abstract class WrapperTree
case class FunDecl(f: Fun) extends WrapperTree
case class FunSign(f: Fun) extends WrapperTree
case class DeclParam(x: Var) extends WrapperTree
case class StructDecl(s: Struct) extends WrapperTree
case class StructDef(s: Struct) extends WrapperTree
case object NewLine extends WrapperTree
}
/* Copyright 2009-2015 EPFL, Lausanne */
package leon
package genc
import CAST.Tree
/* Printer helpers adapted to C code generation */
case class PrinterContext(
indent: Int,
printer: CPrinter
)
object CPrinterHelpers {
implicit class Printable(val f: PrinterContext => Any) extends AnyVal {
def print(ctx: PrinterContext) = f(ctx)
}
implicit class PrinterHelper(val sc: StringContext) extends AnyVal {
def c(args: Any*)(implicit ctx: PrinterContext): Unit = {
val printer = ctx.printer
import printer.WrapperTree
val sb = printer.sb
val strings = sc.parts.iterator
val expressions = args.iterator
var extraInd = 0
var firstElem = true
while(strings.hasNext) {
val s = strings.next.stripMargin
// Compute indentation
val start = s.lastIndexOf('\n')
if(start >= 0 || firstElem) {
var i = start + 1
while(i < s.length && s(i) == ' ') {
i += 1
}
extraInd = (i - start - 1) / 2
}
firstElem = false
// Make sure new lines are also indented
sb.append(s.replaceAll("\n", "\n" + (" " * ctx.indent)))
val nctx = ctx.copy(indent = ctx.indent + extraInd)
if (expressions.hasNext) {
val e = expressions.next
e match {
case ts: Seq[Any] =>
nary(ts).print(nctx)
case t: Tree =>
printer.pp(t)(nctx)
case wt: WrapperTree =>
printer.pp(wt)(nctx)
case p: Printable =>
p.print(nctx)
case e =>
sb.append(e.toString)
}
}
}
}
}
def nary(ls: Seq[Any], sep: String = ", ", opening: String = "", closing: String = ""): Printable = {
val (o, c) = if(ls.isEmpty) ("", "") else (opening, closing)
val strs = o +: List.fill(ls.size-1)(sep) :+ c
implicit pctx: PrinterContext =>
new StringContext(strs: _*).c(ls: _*)
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment