diff --git a/src/main/scala/leon/genc/CAST.scala b/src/main/scala/leon/genc/CAST.scala index b0ba2bf4d58ac619f3bedb81f60e082eb35c8b69..a8b7b835f07997ce76c7019e72e4997dd35004f7 100644 --- a/src/main/scala/leon/genc/CAST.scala +++ b/src/main/scala/leon/genc/CAST.scala @@ -230,11 +230,29 @@ object CAST { // C Abstract Syntax Tree /* ------------------------------------------------------------- DSL ----- */ // Operator ~~ appends and flattens nested compounds implicit class StmtOps(val stmt: Stmt) { - def ~(other: Stmt) = (stmt, other) match { - case (Compound(stmts), Compound(others)) => Compound(stmts ++ others) - case (stmt , Compound(others)) => Compound(stmt +: others) - case (Compound(stmts), other ) => Compound(stmts :+ other ) - case (stmt , other ) => Compound(stmt :: other :: Nil) + // In addition to combining statements together in a compound + // we remove the empty ones and if the resulting compound + // has only one statement we return this one without being + // wrapped into a Compound + def ~(other: Stmt) = { + val stmts = (stmt, other) match { + case (Compound(stmts), Compound(others)) => stmts ++ others + case (stmt , Compound(others)) => stmt +: others + case (Compound(stmts), other ) => stmts :+ other + case (stmt , other ) => stmt :: other :: Nil + } + + def isNoStmt(s: Stmt) = s match { + case NoStmt => true + case _ => false + } + + val compound = Compound(stmts filterNot isNoStmt) + compound match { + case Compound(stmts) if stmts.length == 0 => NoStmt + case Compound(stmts) if stmts.length == 1 => stmts.head + case compound => compound + } } def ~~(others: Seq[Stmt]) = stmt ~ Compound(others) diff --git a/src/main/scala/leon/genc/CPrinter.scala b/src/main/scala/leon/genc/CPrinter.scala index 8b9011781089f2d1f4d02d193f56b184af667268..202814b7cf665d168683fc4772826652164e8a75 100644 --- a/src/main/scala/leon/genc/CPrinter.scala +++ b/src/main/scala/leon/genc/CPrinter.scala @@ -57,25 +57,16 @@ class CPrinter(val sb: StringBuffer = new StringBuffer) { /* --------------------------------------------------------- Stmts ----- */ case NoStmt => c"/* empty */" - - // Try to print new lines and semicolon somewhat correctly - case Compound(stmts) if stmts.isEmpty => // should not happen - - case Compound(stmts) if stmts.length == 1 => - stmts.head match { - case s: Call => c"$s;" // for function calls whose returned value is not saved - case s => c"$s" - } - case Compound(stmts) => - val head = stmts.head - val tail = Compound(stmts.tail) + val lastIdx = stmts.length - 1 + + for ((stmt, idx) <- stmts.zipWithIndex) { + if (stmt.isValue) c"$stmt;" + else c"$stmt" - head match { - case s: Call => c"$s;" // for function calls whose returned value is not saved - case s => c"$s" + if (idx != lastIdx) + c"$NewLine" } - c"$NewLine$tail" case Assert(pred, Some(error)) => c"assert($pred); /* $error */" case Assert(pred, None) => c"assert($pred);"