diff --git a/src/main/scala/leon/genc/CPrinter.scala b/src/main/scala/leon/genc/CPrinter.scala new file mode 100644 index 0000000000000000000000000000000000000000..8b9011781089f2d1f4d02d193f56b184af667268 --- /dev/null +++ b/src/main/scala/leon/genc/CPrinter.scala @@ -0,0 +1,200 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package genc + +import CAST._ +import CPrinterHelpers._ + +class CPrinter(val sb: StringBuffer = new StringBuffer) { + override def toString = sb.toString + + def print(tree: Tree) = pp(tree)(PrinterContext(0, this)) + + def pp(tree: Tree)(implicit ctx: PrinterContext): Unit = tree match { + /* ---------------------------------------------------------- Types ----- */ + case typ: Type => c"${typ.toString}" + + + /* ------------------------------------------------------- Literals ----- */ + case IntLiteral(v) => c"$v" + case BoolLiteral(b) => c"$b" + + + /* --------------------------------------------------- Definitions ----- */ + case Prog(structs, functions) => + c"""|/* ------------------------------------ includes ----- */ + | + |${nary(includeStmts, sep = "\n")} + | + |/* ---------------------- data type declarations ----- */ + | + |${nary(structs map StructDecl, sep = "\n")} + | + |/* ----------------------- data type definitions ----- */ + | + |${nary(structs map StructDef, sep = "\n")} + | + |/* ----------------------- function declarations ----- */ + | + |${nary(functions map FunDecl, sep = "\n")} + | + |/* ------------------------ function definitions ----- */ + | + |${nary(functions, sep = "\n")} + |""" + + case f @ Fun(_, _, _, body) => + c"""|${FunSign(f)} + |{ + | $body + |} + |""" + + case Id(name) => c"$name" + + + /* --------------------------------------------------------- Stmts ----- */ + case NoStmt => c"/* empty */" + + + // Try to print new lines and semicolon somewhat correctly + case Compound(stmts) if stmts.isEmpty => // should not happen + + case Compound(stmts) if stmts.length == 1 => + stmts.head match { + case s: Call => c"$s;" // for function calls whose returned value is not saved + case s => c"$s" + } + + case Compound(stmts) => + val head = stmts.head + val tail = Compound(stmts.tail) + + head match { + case s: Call => c"$s;" // for function calls whose returned value is not saved + case s => c"$s" + } + c"$NewLine$tail" + + case Assert(pred, Some(error)) => c"assert($pred); /* $error */" + case Assert(pred, None) => c"assert($pred);" + + case Var(id, _) => c"$id" + case DeclVar(Var(id, typ)) => c"$typ $id;" + + // TODO depending on the type of array (e.g. `char`) or the value (e.g. `0`), we could use `memset`. + case DeclInitVar(Var(id, typ), ai: ArrayInit) => // Note that `typ` is a struct here + val buffer = FreshId("vla_buffer") + val i = FreshId("i") + c"""|${ai.valueType} $buffer[${ai.length}]; + |for (${Int32} $i = 0; $i < ${ai.length}; ++$i) { + | $buffer[$i] = ${ai.defaultValue}; + |} + |$typ $id = { .length = ${ai.length}, .data = $buffer }; + |""" + + case DeclInitVar(Var(id, typ), ai: ArrayInitWithValues) => // Note that `typ` is a struct here + val buffer = FreshId("vla_buffer") + c"$NewLine${ai.valueType} $buffer[${ai.length}];$NewLine" + for ((v, i) <- ai.values.zipWithIndex) { + c"$buffer[$i] = $v;$NewLine" + } + c"$typ $id = { .length = ${ai.length}, .data = $buffer };$NewLine" + + case DeclInitVar(Var(id, typ), value) => + c"$typ $id = $value;" + + case Assign(lhs, rhs) => + c"$lhs = $rhs;" + + case UnOp(op, rhs) => c"($op $rhs)" + case MultiOp(op, stmts) => c"""${nary(stmts, sep = s" ${op.fixMargin} ")}""" + case SubscriptOp(ptr, idx) => c"$ptr[$idx]" + + case Break => c"break;" + case Return(stmt) => c"return $stmt;" + + case IfElse(cond, thn, elze) => + c"""|if ($cond) + |{ + | $thn + |} + |else + |{ + | $elze + |} + |""" + + case While(cond, body) => + c"""|while ($cond) + |{ + | $body + |} + |""" + + case AccessVar(id) => c"$id" + case AccessRef(id) => c"(*$id)" + case AccessAddr(id) => c"(&$id)" + case AccessField(struct, field) => c"$struct.$field" + case Call(id, args) => c"$id($args)" + + case StructInit(args, struct) => + c"(${struct.id}) { " + for ((id, stmt) <- args.init) { + c".$id = $stmt, " + } + if (!args.isEmpty) { + val (id, stmt) = args.last + c".$id = $stmt " + } + c"}" + + /* --------------------------------------------------------- Error ----- */ + case tree => sys.error(s"CPrinter: <<$tree>> was not handled properly") + } + + + def pp(wt: WrapperTree)(implicit ctx: PrinterContext): Unit = wt match { + case FunDecl(f) => + c"${FunSign(f)};$NewLine" + + case FunSign(Fun(id, retType, Nil, _)) => + c"""|$retType + |$id($Void)""" + + case FunSign(Fun(id, retType, params, _)) => + c"""|$retType + |$id(${nary(params map DeclParam)})""" + + case DeclParam(Var(id, typ)) => + c"$typ $id" + + case StructDecl(s) => + c"struct $s;" + + case StructDef(Struct(name, fields)) => + c"""|typedef struct $name { + | ${nary(fields map DeclParam, sep = ";\n", closing = ";")} + |} $name; + |""" + + case NewLine => + c"""| + |""" + } + + /** Hardcoded list of required include files from C standard library **/ + lazy val includes = "assert.h" :: "stdbool.h" :: "stdint.h" :: Nil + lazy val includeStmts = includes map { i => s"#include <$i>" } + + /** Wrappers to distinguish how the data should be printed **/ + sealed abstract class WrapperTree + case class FunDecl(f: Fun) extends WrapperTree + case class FunSign(f: Fun) extends WrapperTree + case class DeclParam(x: Var) extends WrapperTree + case class StructDecl(s: Struct) extends WrapperTree + case class StructDef(s: Struct) extends WrapperTree + case object NewLine extends WrapperTree +} + diff --git a/src/main/scala/leon/genc/CPrinterHelper.scala b/src/main/scala/leon/genc/CPrinterHelper.scala new file mode 100644 index 0000000000000000000000000000000000000000..5c32df8b7289b0ca5510d04680ab449e915a12dd --- /dev/null +++ b/src/main/scala/leon/genc/CPrinterHelper.scala @@ -0,0 +1,86 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package genc + +import CAST.Tree + +/* Printer helpers adapted to C code generation */ + +case class PrinterContext( + indent: Int, + printer: CPrinter +) + +object CPrinterHelpers { + implicit class Printable(val f: PrinterContext => Any) extends AnyVal { + def print(ctx: PrinterContext) = f(ctx) + } + + implicit class PrinterHelper(val sc: StringContext) extends AnyVal { + def c(args: Any*)(implicit ctx: PrinterContext): Unit = { + val printer = ctx.printer + import printer.WrapperTree + val sb = printer.sb + + val strings = sc.parts.iterator + val expressions = args.iterator + + var extraInd = 0 + var firstElem = true + + while(strings.hasNext) { + val s = strings.next.stripMargin + + // Compute indentation + val start = s.lastIndexOf('\n') + if(start >= 0 || firstElem) { + var i = start + 1 + while(i < s.length && s(i) == ' ') { + i += 1 + } + extraInd = (i - start - 1) / 2 + } + + firstElem = false + + // Make sure new lines are also indented + sb.append(s.replaceAll("\n", "\n" + (" " * ctx.indent))) + + val nctx = ctx.copy(indent = ctx.indent + extraInd) + + if (expressions.hasNext) { + val e = expressions.next + + e match { + case ts: Seq[Any] => + nary(ts).print(nctx) + + case t: Tree => + printer.pp(t)(nctx) + + case wt: WrapperTree => + printer.pp(wt)(nctx) + + case p: Printable => + p.print(nctx) + + case e => + sb.append(e.toString) + } + } + } + } + } + + def nary(ls: Seq[Any], sep: String = ", ", opening: String = "", closing: String = ""): Printable = { + val (o, c) = if(ls.isEmpty) ("", "") else (opening, closing) + val strs = o +: List.fill(ls.size-1)(sep) :+ c + + implicit pctx: PrinterContext => + new StringContext(strs: _*).c(ls: _*) + } + +} + +