diff --git a/src/main/scala/leon/purescala/Common.scala b/src/main/scala/leon/purescala/Common.scala index 6ca933b74dffb167e0c8e6ec5ef43db3fb92e960..149530bb3727742ae6127b663697f671aa599c86 100644 --- a/src/main/scala/leon/purescala/Common.scala +++ b/src/main/scala/leon/purescala/Common.scala @@ -7,8 +7,10 @@ object Common { import Trees.Variable import TypeTrees.Typed + abstract class Tree extends Serializable + // the type is left blank (Untyped) for Identifiers that are not variables - class Identifier private[Common](val name: String, private val globalId: Int, val id: Int, alwaysShowUniqueID: Boolean = false) extends Typed { + class Identifier private[Common](val name: String, private val globalId: Int, val id: Int, alwaysShowUniqueID: Boolean = false) extends Tree with Typed { self : Serializable => override def equals(other: Any): Boolean = { diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala index 0c65ed814e7c2c2f3b5bb2f30dbd4e15059a1391..7bac0be4ca4700ee47e10aab52def1e11a9ed9c6 100644 --- a/src/main/scala/leon/purescala/Definitions.scala +++ b/src/main/scala/leon/purescala/Definitions.scala @@ -10,7 +10,7 @@ object Definitions { import Extractors._ import TypeTrees._ - sealed abstract class Definition extends Serializable { + sealed abstract class Definition extends Tree { val id: Identifier override def toString: String = PrettyPrinter(this) override def hashCode : Int = id.hashCode @@ -21,7 +21,7 @@ object Definitions { } /** A VarDecl declares a new identifier to be of a certain type. */ - case class VarDecl(id: Identifier, tpe: TypeTree) extends Typed { + case class VarDecl(id: Identifier, tpe: TypeTree) extends Definition with Typed { self: Serializable => override def getType = tpe diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index f8b91d2b9993acfd670f2b401c66b906fd6d185d..acc091cefda34fd2882c11ab48ecb44639391901 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -19,33 +19,37 @@ class PrettyPrinter(sb: StringBuffer = new StringBuffer) { sb.append(str) } - def ind(lvl: Int) { + def ind(implicit lvl: Int) { sb.append(" " * lvl) } + def nl(implicit lvl: Int) { + sb.append("\n") + ind(lvl) + } // EXPRESSIONS // all expressions are printed in-line - def ppUnary(expr: Expr, op1: String, op2: String, lvl: Int) { + def ppUnary(expr: Tree, op1: String, op2: String)(implicit parent: Option[Tree], lvl: Int) { sb.append(op1) - pp(expr, lvl) + pp(expr, parent) sb.append(op2) } - def ppBinary(left: Expr, right: Expr, op: String, lvl: Int) { + def ppBinary(left: Tree, right: Tree, op: String)(implicit parent: Option[Tree], lvl: Int) { sb.append("(") - pp(left, lvl) + pp(left, parent) sb.append(op) - pp(right, lvl) + pp(right, parent) sb.append(")") } - def ppNary(exprs: Seq[Expr], pre: String, op: String, post: String, lvl: Int) { + def ppNary(exprs: Seq[Tree], pre: String, op: String, post: String)(implicit parent: Option[Tree], lvl: Int) { sb.append(pre) val sz = exprs.size var c = 0 exprs.foreach(ex => { - pp(ex, lvl) ; c += 1 ; if(c < sz) sb.append(op) + pp(ex, parent) ; c += 1 ; if(c < sz) sb.append(op) }) sb.append(post) @@ -53,321 +57,316 @@ class PrettyPrinter(sb: StringBuffer = new StringBuffer) { def idToString(id: Identifier): String = id.toString - def pp(tree: Expr, lvl: Int): Unit = tree match { - case Variable(id) => sb.append(idToString(id)) - case DeBruijnIndex(idx) => sb.append("_" + idx) - case LetTuple(bs,d,e) => - sb.append("(let (" + bs.map(idToString _).mkString(",") + " := "); - pp(d, lvl) - sb.append(") in\n") - ind(lvl+1) - pp(e, lvl+1) - sb.append(")") - - case Let(b,d,e) => - sb.append("(let (" + idToString(b) + " := "); - pp(d, lvl) - sb.append(") in\n") - ind(lvl+1) - pp(e, lvl+1) - sb.append(")") - - case LetDef(fd,body) => - sb.append("\n") - pp(fd, lvl+1) - sb.append("\n") - sb.append("\n") - ind(lvl) - pp(body, lvl) - - case And(exprs) => ppNary(exprs, "(", " \u2227 ", ")", lvl) // \land - case Or(exprs) => ppNary(exprs, "(", " \u2228 ", ")", lvl) // \lor - case Not(Equals(l, r)) => ppBinary(l, r, " \u2260 ", lvl) // \neq - case Iff(l,r) => ppBinary(l, r, " <=> ", lvl) - case Implies(l,r) => ppBinary(l, r, " ==> ", lvl) - case UMinus(expr) => ppUnary(expr, "-(", ")", lvl) - case Equals(l,r) => ppBinary(l, r, " == ", lvl) - case IntLiteral(v) => sb.append(v) - case BooleanLiteral(v) => sb.append(v) - case StringLiteral(s) => sb.append("\"" + s + "\"") - case UnitLiteral => sb.append("()") - case t@Tuple(exprs) => ppNary(exprs, "(", ", ", ")", lvl) - case s@TupleSelect(t, i) => - pp(t, lvl) - sb.append("._" + i) - - case c@Choose(vars, pred) => - sb.append("choose("+vars.map(idToString _).mkString(", ")+" => ") - pp(pred, lvl) - sb.append(")") - - case CaseClass(cd, args) => - sb.append(idToString(cd.id)) - if (cd.isCaseObject) { - ppNary(args, "", "", "", lvl) - } else { - ppNary(args, "(", ", ", ")", lvl) - } - - case CaseClassInstanceOf(cd, e) => - pp(e, lvl) - sb.append(".isInstanceOf[" + idToString(cd.id) + "]") - - case CaseClassSelector(_, cc, id) => - pp(cc, lvl) - sb.append("." + idToString(id)) - - case FunctionInvocation(fd, args) => - sb.append(idToString(fd.id)) - ppNary(args, "(", ", ", ")", lvl) - - case Plus(l,r) => ppBinary(l, r, " + ", lvl) - case Minus(l,r) => ppBinary(l, r, " - ", lvl) - case Times(l,r) => ppBinary(l, r, " * ", lvl) - case Division(l,r) => ppBinary(l, r, " / ", lvl) - case Modulo(l,r) => ppBinary(l, r, " % ", lvl) - case LessThan(l,r) => ppBinary(l, r, " < ", lvl) - case GreaterThan(l,r) => ppBinary(l, r, " > ", lvl) - case LessEquals(l,r) => ppBinary(l, r, " \u2264 ", lvl) // \leq - case GreaterEquals(l,r) => ppBinary(l, r, " \u2265 ", lvl) // \geq - case FiniteSet(rs) => if(rs.isEmpty) sb.append("\u2205") /* Ø */ else ppNary(rs, "{", ", ", "}", lvl) - case FiniteMultiset(rs) => ppNary(rs, "{|", ", ", "|}", lvl) - case EmptyMultiset(_) => sb.append("\u2205") // Ø - case Not(ElementOfSet(s,e)) => ppBinary(s, e, " \u2209 ", lvl) // \notin - case ElementOfSet(s,e) => ppBinary(s, e, " \u2208 ", lvl) // \in - case SubsetOf(l,r) => ppBinary(l, r, " \u2286 ", lvl) // \subseteq - case Not(SubsetOf(l,r)) => ppBinary(l, r, " \u2288 ", lvl) // \notsubseteq - case SetMin(s) => pp(s, lvl); sb.append(".min") - case SetMax(s) => pp(s, lvl); sb.append(".max") - case SetUnion(l,r) => ppBinary(l, r, " \u222A ", lvl) // \cup - case MultisetUnion(l,r) => ppBinary(l, r, " \u222A ", lvl) // \cup - case MapUnion(l,r) => ppBinary(l, r, " \u222A ", lvl) // \cup - case SetDifference(l,r) => ppBinary(l, r, " \\ ", lvl) - case MultisetDifference(l,r) => ppBinary(l, r, " \\ ", lvl) - case SetIntersection(l,r) => ppBinary(l, r, " \u2229 ", lvl) // \cap - case MultisetIntersection(l,r) => ppBinary(l, r, " \u2229 ", lvl) // \cap - case SetCardinality(t) => ppUnary(t, "|", "|", lvl) - case MultisetCardinality(t) => ppUnary(t, "|", "|", lvl) - case MultisetPlus(l,r) => ppBinary(l, r, " \u228E ", lvl) // U+ - case MultisetToSet(e) => pp(e, lvl); sb.append(".toSet") - case FiniteMap(rs) => - sb.append("{") - val sz = rs.size - var c = 0 - rs.foreach{case (k, v) => { - pp(k, lvl); sb.append(" -> "); pp(v, lvl); c += 1 ; if(c < sz) sb.append(", ") - }} - sb.append("}") - - case MapGet(m,k) => - pp(m, lvl) - ppNary(Seq(k), "(", ",", ")", lvl) - - case MapIsDefinedAt(m,k) => - pp(m, lvl) - sb.append(".isDefinedAt") - ppNary(Seq(k), "(", ",", ")", lvl) - - case ArrayLength(a) => - pp(a, lvl) - sb.append(".length") - - case ArrayClone(a) => - pp(a, lvl) - sb.append(".clone") - - case fill@ArrayFill(size, v) => - sb.append("Array.fill(") - pp(size, lvl) - sb.append(")(") - pp(v, lvl) - sb.append(")") - - case am@ArrayMake(v) => - sb.append("Array.make(") - pp(v, lvl) - sb.append(")") - - case sel@ArraySelect(ar, i) => - pp(ar, lvl) - sb.append("(") - pp(i, lvl) - sb.append(")") - - case up@ArrayUpdated(ar, i, v) => - pp(ar, lvl) - sb.append(".updated(") - pp(i, lvl) - sb.append(", ") - pp(v, lvl) - sb.append(")") - - case FiniteArray(exprs) => - ppNary(exprs, "Array(", ", ", ")", lvl) - - case Distinct(exprs) => - sb.append("distinct") - ppNary(exprs, "(", ", ", ")", lvl) - - case IfExpr(c, t, e) => - sb.append("if (") - pp(c, lvl) - sb.append(")\n") - ind(lvl+1) - pp(t, lvl+1) - sb.append("\n") - ind(lvl) - sb.append("else\n") - ind(lvl+1) - pp(e, lvl+1) - - case mex @ MatchExpr(s, csc) => { - def ppc(p: Pattern): Unit = p match { - //case InstanceOfPattern(None, ctd) => - //case InstanceOfPattern(Some(id), ctd) => - case CaseClassPattern(bndr, ccd, subps) => { - bndr.foreach(b => sb.append(idToString(b) + " @ ")) - sb.append(idToString(ccd.id)).append("(") - var c = 0 - val sz = subps.size - subps.foreach(sp => { - ppc(sp) - if(c < sz - 1) - sb.append(", ") - c = c + 1 - }) - sb.append(")") - } - case WildcardPattern(None) => sb.append("_") - case WildcardPattern(Some(id)) => sb.append(idToString(id)) - case TuplePattern(bndr, subPatterns) => { - bndr.foreach(b => sb.append(b + " @ ")) - sb.append("(") - subPatterns.init.foreach(p => { - ppc(p) - sb.append(", ") - }) - ppc(subPatterns.last) - sb.append(")") - } - case _ => sb.append("Pattern?") - } + def pp(tree: Tree, parent: Option[Tree])(implicit lvl: Int): Unit = { + implicit val p = Some(tree) + + tree match { + case Variable(id) => sb.append(idToString(id)) + case DeBruijnIndex(idx) => sb.append("_" + idx) + case LetTuple(bs,d,e) => + sb.append("(let (" + bs.map(idToString _).mkString(",") + " := "); + pp(d, p) + sb.append(") in") + nl(lvl+1) + pp(e, p)(lvl+1) + sb.append(")") - pp(s, lvl) - // if(mex.posInfo != "") { - // sb.append(" match@(" + mex.posInfo + ") {\n") - // } else { - sb.append(" match {\n") - // } + case Let(b,d,e) => + sb.append("(let (" + idToString(b) + " := "); + pp(d, p) + sb.append(") in") + nl(lvl+1) + pp(e, p)(lvl+1) + sb.append(")") - csc.foreach(cs => { - ind(lvl+1) - sb.append("case ") - ppc(cs.pattern) - cs.theGuard.foreach(g => { - sb.append(" if ") - pp(g, lvl+1) - }) - sb.append(" => ") - pp(cs.rhs, lvl+1) + case LetDef(fd,body) => sb.append("\n") - }) - ind(lvl) - sb.append("}") - } + pp(fd, p)(lvl+1) + sb.append("\n") + sb.append("\n") + nl + pp(body, p) + + case And(exprs) => ppNary(exprs, "(", " \u2227 ", ")") // \land + case Or(exprs) => ppNary(exprs, "(", " \u2228 ", ")") // \lor + case Not(Equals(l, r)) => ppBinary(l, r, " \u2260 ") // \neq + case Iff(l,r) => ppBinary(l, r, " <=> ") + case Implies(l,r) => ppBinary(l, r, " ==> ") + case UMinus(expr) => ppUnary(expr, "-(", ")") + case Equals(l,r) => ppBinary(l, r, " == ") + case IntLiteral(v) => sb.append(v) + case BooleanLiteral(v) => sb.append(v) + case StringLiteral(s) => sb.append("\"" + s + "\"") + case UnitLiteral => sb.append("()") + case t@Tuple(exprs) => ppNary(exprs, "(", ", ", ")") + case s@TupleSelect(t, i) => + pp(t, p) + sb.append("._" + i) + + case c@Choose(vars, pred) => + sb.append("choose("+vars.map(idToString _).mkString(", ")+" => ") + pp(pred, p) + sb.append(")") - case Not(expr) => ppUnary(expr, "\u00AC(", ")", lvl) // \neg + case CaseClass(cd, args) => + sb.append(idToString(cd.id)) + if (cd.isCaseObject) { + ppNary(args, "", "", "") + } else { + ppNary(args, "(", ", ", ")") + } - case e @ Error(desc) => - sb.append("error(\"" + desc + "\")[") - pp(e.getType, lvl) - sb.append("]") + case CaseClassInstanceOf(cd, e) => + pp(e, p) + sb.append(".isInstanceOf[" + idToString(cd.id) + "]") - case (expr: PrettyPrintable) => expr.printWith(lvl, this) + case CaseClassSelector(_, cc, id) => + pp(cc, p) + sb.append("." + idToString(id)) - case _ => sb.append("Expr? (" + tree.getClass + ")") - } + case FunctionInvocation(fd, args) => + sb.append(idToString(fd.id)) + ppNary(args, "(", ", ", ")") + + case Plus(l,r) => ppBinary(l, r, " + ") + case Minus(l,r) => ppBinary(l, r, " - ") + case Times(l,r) => ppBinary(l, r, " * ") + case Division(l,r) => ppBinary(l, r, " / ") + case Modulo(l,r) => ppBinary(l, r, " % ") + case LessThan(l,r) => ppBinary(l, r, " < ") + case GreaterThan(l,r) => ppBinary(l, r, " > ") + case LessEquals(l,r) => ppBinary(l, r, " \u2264 ") // \leq + case GreaterEquals(l,r) => ppBinary(l, r, " \u2265 ") // \geq + case FiniteSet(rs) => if(rs.isEmpty) sb.append("\u2205") /* Ø */ else ppNary(rs, "{", ", ", "}") + case FiniteMultiset(rs) => ppNary(rs, "{|", ", ", "|}") + case EmptyMultiset(_) => sb.append("\u2205") // Ø + case Not(ElementOfSet(s,e)) => ppBinary(s, e, " \u2209 ") // \notin + case ElementOfSet(s,e) => ppBinary(s, e, " \u2208 ") // \in + case SubsetOf(l,r) => ppBinary(l, r, " \u2286 ") // \subseteq + case Not(SubsetOf(l,r)) => ppBinary(l, r, " \u2288 ") // \notsubseteq + case SetMin(s) => pp(s, p); sb.append(".min") + case SetMax(s) => pp(s, p); sb.append(".max") + case SetUnion(l,r) => ppBinary(l, r, " \u222A ") // \cup + case MultisetUnion(l,r) => ppBinary(l, r, " \u222A ") // \cup + case MapUnion(l,r) => ppBinary(l, r, " \u222A ") // \cup + case SetDifference(l,r) => ppBinary(l, r, " \\ ") + case MultisetDifference(l,r) => ppBinary(l, r, " \\ ") + case SetIntersection(l,r) => ppBinary(l, r, " \u2229 ") // \cap + case MultisetIntersection(l,r) => ppBinary(l, r, " \u2229 ") // \cap + case SetCardinality(t) => ppUnary(t, "|", "|") + case MultisetCardinality(t) => ppUnary(t, "|", "|") + case MultisetPlus(l,r) => ppBinary(l, r, " \u228E ") // U+ + case MultisetToSet(e) => pp(e, p); sb.append(".toSet") + case FiniteMap(rs) => + sb.append("{") + val sz = rs.size + var c = 0 + rs.foreach{case (k, v) => { + pp(k, p); sb.append(" -> "); pp(v, p); c += 1 ; if(c < sz) sb.append(", ") + }} + sb.append("}") + + case MapGet(m,k) => + pp(m, p) + ppNary(Seq(k), "(", ",", ")") + + case MapIsDefinedAt(m,k) => + pp(m, p) + sb.append(".isDefinedAt") + ppNary(Seq(k), "(", ",", ")") + + case ArrayLength(a) => + pp(a, p) + sb.append(".length") + + case ArrayClone(a) => + pp(a, p) + sb.append(".clone") + + case fill@ArrayFill(size, v) => + sb.append("Array.fill(") + pp(size, p) + sb.append(")(") + pp(v, p) + sb.append(")") - // TYPE TREES - // all type trees are printed in-line - def ppNaryType(tpes: Seq[TypeTree], pre: String, op: String, post: String, lvl: Int): Unit = { - sb.append(pre) - val sz = tpes.size - var c = 0 + case am@ArrayMake(v) => + sb.append("Array.make(") + pp(v, p) + sb.append(")") - tpes.foreach(t => { - pp(t, lvl) ; c += 1 ; if(c < sz) sb.append(op) - }) + case sel@ArraySelect(ar, i) => + pp(ar, p) + sb.append("(") + pp(i, p) + sb.append(")") - sb.append(post) - } + case up@ArrayUpdated(ar, i, v) => + pp(ar, p) + sb.append(".updated(") + pp(i, p) + sb.append(", ") + pp(v, p) + sb.append(")") + + case FiniteArray(exprs) => + ppNary(exprs, "Array(", ", ", ")") + + case Distinct(exprs) => + sb.append("distinct") + ppNary(exprs, "(", ", ", ")") + + case IfExpr(c, t, e) => + sb.append("if (") + pp(c, p) + sb.append(")") + nl(lvl+1) + pp(t, p)(lvl+1) + nl + sb.append("else") + nl(lvl+1) + pp(e, p)(lvl+1) + + case mex @ MatchExpr(s, csc) => + pp(s, p) + sb.append(" match {\n") - def pp(tpe: TypeTree, lvl: Int): Unit = tpe match { - case Untyped => sb.append("???") - case UnitType => sb.append("Unit") - case Int32Type => sb.append("Int") - case BooleanType => sb.append("Boolean") - case ArrayType(bt) => sb.append("Array["); pp(bt, lvl); sb.append("]") - case SetType(bt) => sb.append("Set["); pp(bt, lvl); sb.append("]") - case MapType(ft,tt) => sb.append("Map["); pp(ft, lvl); sb.append(","); pp(tt, lvl); sb.append("]") - case MultisetType(bt) => sb.append("Multiset["); pp(bt, lvl); sb.append("]") - case TupleType(tpes) => ppNaryType(tpes, "(", ", ", ")", lvl) - case c: ClassType => sb.append(c.classDef.id) - case FunctionType(fts, tt) => { - if (fts.size > 1) - ppNaryType(fts, "(", ", ", ")", lvl) - else if (fts.size == 1) - pp(fts.head, lvl) - sb.append(" => ") - pp(tt, lvl) - } - case _ => sb.append("Type?") - } + csc.foreach(cs => { + nl(lvl+1) + pp(cs, p) + sb.append("\n") + }) + nl(lvl) + sb.append("}") + + case Not(expr) => ppUnary(expr, "\u00AC(", ")") // \neg + + case e @ Error(desc) => + sb.append("error(\"" + desc + "\")[") + pp(e.getType, p) + sb.append("]") + + case (tree: PrettyPrintable) => tree.printWith(this) + + // Cases + case SimpleCase(pat, rhs) => + sb.append("case ") + pp(pat, p) + sb.append(" =>\n") + ind(lvl+1) + pp(rhs, p)(lvl+2) + case GuardedCase(pat, guard, rhs) => + sb.append("case ") + pp(pat, p) + sb.append(" if ") + pp(guard, p) + sb.append(" =>\n") + ind(lvl+1) + pp(rhs, p)(lvl+2) + + // Patterns + case CaseClassPattern(bndr, ccd, subps) => + bndr.foreach(b => sb.append(b + " @ ")) + sb.append(idToString(ccd.id)).append("(") + var c = 0 + val sz = subps.size + subps.foreach(sp => { + pp(sp, p) + if(c < sz - 1) + sb.append(", ") + c = c + 1 + }) + sb.append(")") + + case WildcardPattern(None) => sb.append("_") + case WildcardPattern(Some(id)) => sb.append(idToString(id)) + case InstanceOfPattern(bndr, ccd) => + bndr.foreach(b => sb.append(b + " : ")) + sb.append(idToString(ccd.id)) + + case TuplePattern(bndr, subPatterns) => + bndr.foreach(b => sb.append(b + " @ ")) + sb.append("(") + subPatterns.init.foreach(pat => { + pp(pat, p) + sb.append(", ") + }) + pp(subPatterns.last, p) + sb.append(")") + + + // Types + case Untyped => sb.append("???") + case UnitType => sb.append("Unit") + case Int32Type => sb.append("Int") + case BooleanType => sb.append("Boolean") + case ArrayType(bt) => + sb.append("Array[") + pp(bt, p) + sb.append("]") + case SetType(bt) => + sb.append("Set[") + pp(bt, p) + sb.append("]") + case MapType(ft,tt) => + sb.append("Map[") + pp(ft, p) + sb.append(",") + pp(tt, p) + sb.append("]") + case MultisetType(bt) => + sb.append("Multiset[") + pp(bt, p) + sb.append("]") + case TupleType(tpes) => ppNary(tpes, "(", ", ", ")") + case FunctionType(fts, tt) => + if (fts.size > 1) { + ppNary(fts, "(", ", ", ")") + } else if (fts.size == 1) { + pp(fts.head, p) + } + sb.append(" => ") + pp(tt, p) + case c: ClassType => sb.append(idToString(c.classDef.id)) - // DEFINITIONS - // all definitions are printed with an end-of-line - def pp(defn: Definition, lvl: Int) { - defn match { - case Program(id, mainObj) => { + + // Definitions + case Program(id, mainObj) => assert(lvl == 0) sb.append("package ") sb.append(idToString(id)) sb.append(" {\n") - pp(mainObj, lvl+1) + pp(mainObj, p)(lvl+1) sb.append("}\n") - } - case ObjectDef(id, defs, invs) => { - ind(lvl) + case ObjectDef(id, defs, invs) => + nl sb.append("object ") sb.append(idToString(id)) - sb.append(" {\n") + sb.append(" {") var c = 0 val sz = defs.size defs.foreach(df => { - pp(df, lvl+1) + pp(df, p)(lvl+1) if(c < sz - 1) { sb.append("\n\n") } c = c + 1 }) - sb.append("\n") - ind(lvl) + nl sb.append("}\n") - } - case AbstractClassDef(id, parent) => { - ind(lvl) + case AbstractClassDef(id, parent) => + nl sb.append("sealed abstract class ") sb.append(idToString(id)) parent.foreach(p => sb.append(" extends " + idToString(p.id))) - } - case CaseClassDef(id, parent, varDecls) => { - ind(lvl) + case CaseClassDef(id, parent, varDecls) => + nl sb.append("case class ") sb.append(idToString(id)) sb.append("(") @@ -377,7 +376,7 @@ class PrettyPrinter(sb: StringBuffer = new StringBuffer) { varDecls.foreach(vd => { sb.append(idToString(vd.id)) sb.append(": ") - pp(vd.tpe, lvl) + pp(vd.tpe, p) if(c < sz - 1) { sb.append(", ") } @@ -385,30 +384,29 @@ class PrettyPrinter(sb: StringBuffer = new StringBuffer) { }) sb.append(")") parent.foreach(p => sb.append(" extends " + idToString(p.id))) - } case fd: FunDef => for(a <- fd.annotations) { - ind(lvl) + ind sb.append("@" + a + "\n") } fd.precondition.foreach(prec => { - ind(lvl) + ind sb.append("@pre : ") - pp(prec, lvl) + pp(prec, p)(lvl) sb.append("\n") }) fd.postcondition.foreach{ case (id, postc) => { - ind(lvl) + ind sb.append("@post: ") sb.append(idToString(id)+" => ") - pp(postc, lvl) + pp(postc, p)(lvl) sb.append("\n") }} - ind(lvl) + ind sb.append("def ") sb.append(idToString(fd.id)) sb.append("(") @@ -419,7 +417,7 @@ class PrettyPrinter(sb: StringBuffer = new StringBuffer) { fd.args.foreach(arg => { sb.append(arg.id) sb.append(" : ") - pp(arg.tpe, lvl) + pp(arg.tpe, p) if(c < sz - 1) { sb.append(", ") @@ -428,46 +426,45 @@ class PrettyPrinter(sb: StringBuffer = new StringBuffer) { }) sb.append(") : ") - pp(fd.returnType, lvl) + pp(fd.returnType, p) sb.append(" = ") fd.body match { case Some(body) => - pp(body, lvl) + pp(body, p)(lvl) case None => sb.append("[unknown function implementation]") } - case _ => sb.append("Defn?") + case _ => sb.append("Tree? (" + tree.getClass + ")") } } } +trait PrettyPrintable { + self: Tree => + + def printWith(printer: PrettyPrinter)(implicit lvl: Int): Unit +} + class EquivalencePrettyPrinter() extends PrettyPrinter() { override def idToString(id: Identifier) = id.name } -object PrettyPrinter { - def apply(tree: Expr): String = { - val printer = new PrettyPrinter() - printer.pp(tree, 0) - printer.toString - } - - def apply(tpe: TypeTree): String = { - val printer = new PrettyPrinter() - printer.pp(tpe, 0) - printer.toString - } +abstract class PrettyPrinterFactory { + def create: PrettyPrinter - def apply(defn: Definition): String = { - val printer = new PrettyPrinter() - printer.pp(defn, 0) + def apply(tree: Tree, ind: Int = 0): String = { + val printer = create + printer.pp(tree, None)(ind) printer.toString } } -trait PrettyPrintable { - def printWith(lvl: Int, printer: PrettyPrinter): Unit +object PrettyPrinter extends PrettyPrinterFactory { + def create = new PrettyPrinter() } +object EquivalencePrettyPrinter extends PrettyPrinterFactory { + def create = new EquivalencePrettyPrinter() +} diff --git a/src/main/scala/leon/purescala/ScalaPrinter.scala b/src/main/scala/leon/purescala/ScalaPrinter.scala index b3936af8240528244a091bb21d5e33f9c72660f6..348e2ef146423fadb5937ba1780a85800b3e28bb 100644 --- a/src/main/scala/leon/purescala/ScalaPrinter.scala +++ b/src/main/scala/leon/purescala/ScalaPrinter.scala @@ -16,305 +16,277 @@ class ScalaPrinter(sb: StringBuffer = new StringBuffer) extends PrettyPrinter(sb import java.lang.StringBuffer + override def ppBinary(left: Tree, right: Tree, op: String)(implicit parent: Option[Tree], lvl: Int) { + pp(left, parent) + sb.append(op) + pp(right, parent) + } + // EXPRESSIONS // all expressions are printed in-line - override def pp(tree: Expr, lvl: Int): Unit = tree match { - case Variable(id) => sb.append(idToString(id)) - case DeBruijnIndex(idx) => sys.error("Not Valid Scala") - case LetTuple(ids,d,e) => - sb.append("locally {\n") - ind(lvl+1) - sb.append("val (" ) - for (((id, tpe), i) <- ids.map(id => (id, id.getType)).zipWithIndex) { - sb.append(idToString(id)+": ") - pp(tpe, lvl) - if (i != ids.size-1) { - sb.append(", ") - } - } - sb.append(") = ") - pp(d, lvl+1) - sb.append("\n") - ind(lvl+1) - pp(e, lvl+1) - sb.append("\n") - ind(lvl) - sb.append("}\n") - ind(lvl) - - case Let(b,d,e) => - sb.append("locally {\n") - ind(lvl+1) - sb.append("val " + b + " = ") - pp(d, lvl+1) - sb.append("\n") - ind(lvl+1) - pp(e, lvl+1) - sb.append("\n") - ind(lvl) - sb.append("}\n") - ind(lvl) - - case LetDef(fd, body) => - sb.append("{\n") - pp(fd, lvl+1) - sb.append("\n") - sb.append("\n") - ind(lvl) - pp(body, lvl) - sb.append("}\n") - - case And(exprs) => ppNary(exprs, "(", " && ", ")", lvl) // \land - case Or(exprs) => ppNary(exprs, "(", " || ", ")", lvl) // \lor - case Not(Equals(l, r)) => ppBinary(l, r, " != ", lvl) // \neq - case UMinus(expr) => ppUnary(expr, "-(", ")", lvl) - case Equals(l,r) => ppBinary(l, r, " == ", lvl) - - case IntLiteral(v) => sb.append(v) - case BooleanLiteral(v) => sb.append(v) - case StringLiteral(s) => sb.append("\"" + s + "\"") - case UnitLiteral => sb.append("()") - - /* These two aren't really supported in Scala, but we know how to encode them. */ - case Implies(l,r) => pp(Or(Not(l), r), lvl) - case Iff(l,r) => ppBinary(l, r, " == ", lvl) - - case Tuple(exprs) => ppNary(exprs, "(", ", ", ")", lvl) - case TupleSelect(t, i) => - pp(t, lvl) - sb.append("._" + i) - - case CaseClass(cd, args) => - sb.append(idToString(cd.id)) - if (cd.isCaseObject) { - ppNary(args, "", "", "", lvl) + override def pp(tree: Tree, parent: Option[Tree])(implicit lvl: Int): Unit = { + implicit val p = Some(tree) + + def optParentheses(body: => Unit) { + val rp = requiresParentheses(tree, parent) + if (rp) sb.append("(") + body + if (rp) sb.append(")") + } + + def optBraces(body: Int => Unit) { + val rp = requiresBraces(tree, parent) + if (rp) { + sb.append("{\n") + ind(lvl+1) + + body(lvl+1) + + sb.append("\n") + ind(lvl) + sb.append("}\n") } else { - ppNary(args, "(", ", ", ")", lvl) + body(lvl) } - - case CaseClassInstanceOf(cd, e) => - pp(e, lvl) - sb.append(".isInstanceOf[" + idToString(cd.id) + "]") - - case CaseClassSelector(_, cc, id) => - pp(cc, lvl) - sb.append("." + idToString(id)) - - case FunctionInvocation(fd, args) => - sb.append(idToString(fd.id)) - ppNary(args, "(", ", ", ")", lvl) - - case Plus(l,r) => ppBinary(l, r, " + ", lvl) - case Minus(l,r) => ppBinary(l, r, " - ", lvl) - case Times(l,r) => ppBinary(l, r, " * ", lvl) - case Division(l,r) => ppBinary(l, r, " / ", lvl) - case Modulo(l,r) => ppBinary(l, r, " % ", lvl) - case LessThan(l,r) => ppBinary(l, r, " < ", lvl) - case GreaterThan(l,r) => ppBinary(l, r, " > ", lvl) - case LessEquals(l,r) => ppBinary(l, r, " <= ", lvl) // \leq - case GreaterEquals(l,r) => ppBinary(l, r, " >= ", lvl) // \geq - case FiniteSet(rs) => ppNary(rs, "Set(", ", ", ")", lvl) - case FiniteMultiset(rs) => ppNary(rs, "{|", ", ", "|}", lvl) - case EmptyMultiset(_) => sys.error("Not Valid Scala") - case ElementOfSet(e, s) => ppBinary(s, e, " contains ", lvl) - //case ElementOfSet(s,e) => ppBinary(s, e, " \u2208 ", lvl) // \in - //case SubsetOf(l,r) => ppBinary(l, r, " \u2286 ", lvl) // \subseteq - //case Not(SubsetOf(l,r)) => ppBinary(l, r, " \u2288 ", lvl) // \notsubseteq - case SetMin(s) => - pp(s, lvl) - sb.append(".min") - case SetMax(s) => - pp(s, lvl) - sb.append(".max") - case SetUnion(l,r) => ppBinary(l, r, " ++ ", lvl) // \cup - // case MultisetUnion(l,r) => ppBinary(l, r, " \u222A ", lvl) // \cup - // case MapUnion(l,r) => ppBinary(l, r, " \u222A ", lvl) // \cup - case SetDifference(l,r) => ppBinary(l, r, " -- ", lvl) - // case MultisetDifference(l,r) => ppBinary(l, r, " \\ ", lvl) - case SetIntersection(l,r) => ppBinary(l, r, " & ", lvl) // \cap - // case MultisetIntersection(l,r) => ppBinary(l, r, " \u2229 ", lvl) // \cap - case SetCardinality(t) => ppUnary(t, "", ".size", lvl) - // case MultisetCardinality(t) => ppUnary(t, "|", "|", lvl) - // case MultisetPlus(l,r) => ppBinary(l, r, " \u228E ", lvl) // U+ - // case MultisetToSet(e) => pp(e, lvl).append(".toSet") - case FiniteMap(rs) => - sb.append("{") - val sz = rs.size - var c = 0 - rs.foreach{case (k, v) => { - pp(k, lvl); sb.append(" -> "); pp(v, lvl); c += 1 ; if(c < sz) sb.append(", ") - }} - sb.append("}") - - case MapGet(m,k) => - pp(m, lvl) - ppNary(Seq(k), "(", ",", ")", lvl) - - case MapIsDefinedAt(m,k) => { - pp(m, lvl) - sb.append(".isDefinedAt") - ppNary(Seq(k), "(", ",", ")", lvl) } - case ArrayLength(a) => - pp(a, lvl) - sb.append(".length") - - case ArrayClone(a) => - pp(a, lvl) - sb.append(".clone") - - case ArrayFill(size, v) => - sb.append("Array.fill(") - pp(size, lvl) - sb.append(")(") - pp(v, lvl) - sb.append(")") - - case ArrayMake(v) => sys.error("Not Scala Code") - case ArraySelect(ar, i) => - pp(ar, lvl) - sb.append("(") - pp(i, lvl) - sb.append(")") - - case ArrayUpdated(ar, i, v) => - pp(ar, lvl) - sb.append(".updated(") - pp(i, lvl) - sb.append(", ") - pp(v, lvl) - sb.append(")") - - case FiniteArray(exprs) => - ppNary(exprs, "Array(", ", ", ")", lvl) - - case Distinct(exprs) => - sb.append("distinct") - ppNary(exprs, "(", ", ", ")", lvl) - - case IfExpr(c, t, e) => - sb.append("if (") - pp(c, lvl) - sb.append(") {\n") - ind(lvl+1) - pp(t, lvl+1) - sb.append("\n") - ind(lvl) - sb.append("} else {\n") - ind(lvl+1) - pp(e, lvl+1) - sb.append("\n") - ind(lvl) - sb.append("}") - - case Choose(ids, pred) => - sb.append("(choose { (") - for (((id, tpe), i) <- ids.map(id => (id, id.getType)).zipWithIndex) { - sb.append(idToString(id)+": ") - pp(tpe, lvl) - if (i != ids.size-1) { - sb.append(", ") + + tree match { + case Variable(id) => sb.append(idToString(id)) + case DeBruijnIndex(idx) => sys.error("Not Valid Scala") + case LetTuple(ids,d,e) => + optBraces { implicit lvl => + sb.append("val (" ) + for (((id, tpe), i) <- ids.map(id => (id, id.getType)).zipWithIndex) { + sb.append(idToString(id)+": ") + pp(tpe, p) + if (i != ids.size-1) { + sb.append(", ") + } + } + sb.append(") = ") + pp(d, p) + sb.append("\n") + ind + pp(e, p) + sb.append("\n") + } + + case Let(b,d,e) => + optBraces { implicit lvl => + sb.append("val " + b + " = ") + pp(d, p) + sb.append("\n") + ind + pp(e, p) + sb.append("\n") + } + + case LetDef(fd, body) => + optBraces { implicit lvl => + pp(fd, p) + sb.append("\n") + sb.append("\n") + ind + pp(body, p) + } + + case And(exprs) => optParentheses { ppNary(exprs, "", " && ", "") } + case Or(exprs) => optParentheses { ppNary(exprs, "", " || ", "") } + case Not(Equals(l, r)) => optParentheses { ppBinary(l, r, " != ") } + case UMinus(expr) => ppUnary(expr, "-(", ")") + case Equals(l,r) => optParentheses { ppBinary(l, r, " == ") } + + case IntLiteral(v) => sb.append(v) + case BooleanLiteral(v) => sb.append(v) + case StringLiteral(s) => sb.append("\"" + s + "\"") + case UnitLiteral => sb.append("()") + + /* These two aren't really supported in Scala, but we know how to encode them. */ + case Implies(l,r) => pp(Or(Not(l), r), p) + case Iff(l,r) => optParentheses { ppBinary(l, r, " == ") } + + case Tuple(exprs) => ppNary(exprs, "(", ", ", ")") + case TupleSelect(t, i) => + pp(t, p) + sb.append("._" + i) + + case CaseClass(cd, args) => + sb.append(idToString(cd.id)) + if (cd.isCaseObject) { + ppNary(args, "", "", "") + } else { + ppNary(args, "(", ", ", ")") + } + + case CaseClassInstanceOf(cd, e) => + pp(e, p) + sb.append(".isInstanceOf[" + idToString(cd.id) + "]") + + case CaseClassSelector(_, cc, id) => + pp(cc, p) + sb.append("." + idToString(id)) + + case FunctionInvocation(fd, args) => + sb.append(idToString(fd.id)) + ppNary(args, "(", ", ", ")") + + case Plus(l,r) => optParentheses { ppBinary(l, r, " + ") } + case Minus(l,r) => optParentheses { ppBinary(l, r, " - ") } + case Times(l,r) => optParentheses { ppBinary(l, r, " * ") } + case Division(l,r) => optParentheses { ppBinary(l, r, " / ") } + case Modulo(l,r) => optParentheses { ppBinary(l, r, " % ") } + case LessThan(l,r) => optParentheses { ppBinary(l, r, " < ") } + case GreaterThan(l,r) => optParentheses { ppBinary(l, r, " > ") } + case LessEquals(l,r) => optParentheses { ppBinary(l, r, " <= ") } + case GreaterEquals(l,r) => optParentheses { ppBinary(l, r, " >= ") } + case fs @ FiniteSet(rs) => + if (rs.isEmpty) { + fs.getType match { + case SetType(b) => + sb.append("Set[") + pp(b, p) + sb.append("]()") + case _ => + sb.append("Set()") } + } else { + ppNary(rs, "Set(", ", ", ")") + } + case FiniteMultiset(rs) => ppNary(rs, "{|", ", ", "|}") + case EmptyMultiset(_) => sys.error("Not Valid Scala") + case ElementOfSet(e, s) => optParentheses { ppBinary(s, e, " contains ") } + case SetUnion(l,r) => optParentheses { ppBinary(l, r, " ++ ") } + case SetDifference(l,r) => optParentheses { ppBinary(l, r, " -- ") } + case SetIntersection(l,r) => optParentheses { ppBinary(l, r, " & ") } + case SetMin(s) => + pp(s, p) + sb.append(".min") + case SetMax(s) => + pp(s, p) + sb.append(".max") + case SetCardinality(t) => ppUnary(t, "", ".size") + case FiniteMap(rs) => + sb.append("{") + val sz = rs.size + var c = 0 + rs.foreach{case (k, v) => { + pp(k, p); sb.append(" -> "); pp(v, p); c += 1 ; if(c < sz) sb.append(", ") + }} + sb.append("}") + + case MapGet(m,k) => + pp(m, p) + ppNary(Seq(k), "(", ",", ")") + + case MapIsDefinedAt(m,k) => { + pp(m, p) + sb.append(".isDefinedAt") + ppNary(Seq(k), "(", ",", ")") } - sb.append(") =>\n") - ind(lvl+1) - pp(pred, lvl+1) - sb.append("\n") - ind(lvl) - sb.append("})") - - case mex @ MatchExpr(s, csc) => { - - sb.append("(") - pp(s, lvl) - // if(mex.posInfo != "") { - // sb.append(" match@(" + mex.posInfo + ") {\n") - // } else { - sb.append(" match {\n") - // } - - csc.foreach(cs => { - ind(lvl+1) - sb.append("case ") - pp(cs.pattern) - cs.theGuard.foreach(g => { - sb.append(" if ") - pp(g, lvl+1) - }) - sb.append(" =>\n") - ind(lvl+2) - pp(cs.rhs, lvl+2) - sb.append("\n") - }) - ind(lvl) - sb.append("}") - sb.append(")") - } + case ArrayLength(a) => + pp(a, p) + sb.append(".length") + + case ArrayClone(a) => + pp(a, p) + sb.append(".clone") + + case ArrayFill(size, v) => + sb.append("Array.fill(") + pp(size, p) + sb.append(")(") + pp(v, p) + sb.append(")") - case Not(expr) => ppUnary(expr, "!(", ")", lvl) // \neg + case ArrayMake(v) => sys.error("Not Scala Code") + case ArraySelect(ar, i) => + pp(ar, p) + sb.append("(") + pp(i, p) + sb.append(")") - case e @ Error(desc) => { - sb.append("leon.Utils.error[") - pp(e.getType, lvl) - sb.append("](\"" + desc + "\")") - } + case ArrayUpdated(ar, i, v) => + pp(ar, p) + sb.append(".updated(") + pp(i, p) + sb.append(", ") + pp(v, p) + sb.append(")") - case (expr: PrettyPrintable) => expr.printWith(lvl, this) + case FiniteArray(exprs) => + ppNary(exprs, "Array(", ", ", ")") + + case Distinct(exprs) => + sb.append("distinct") + ppNary(exprs, "(", ", ", ")") + + case IfExpr(c, t, e) => + optParentheses { + sb.append("if (") + pp(c, p) + sb.append(") {\n") + ind(lvl+1) + pp(t, p)(lvl+1) + sb.append("\n") + ind(lvl) + sb.append("} else {\n") + ind(lvl+1) + pp(e, p)(lvl+1) + sb.append("\n") + ind(lvl) + sb.append("}") + } - case _ => sb.append("Expr?") - } + case Choose(ids, pred) => + optParentheses { + sb.append("choose { (") + for (((id, tpe), i) <- ids.map(id => (id, id.getType)).zipWithIndex) { + sb.append(idToString(id)+": ") + pp(tpe, p) + if (i != ids.size-1) { + sb.append(", ") + } + } + sb.append(") =>\n") + ind(lvl+1) + pp(pred, p)(lvl+1) + sb.append("\n") + ind(lvl) + sb.append("}") + } + + case mex @ MatchExpr(s, csc) => { + optParentheses { + pp(s, p) + sb.append(" match {\n") - // TYPE TREES - // all type trees are printed in-line - - override def pp(tpe: TypeTree, lvl: Int): Unit = tpe match { - case Untyped => sb.append("???") - case UnitType => sb.append("Unit") - case Int32Type => sb.append("Int") - case BooleanType => sb.append("Boolean") - case ArrayType(bt) => - sb.append("Array[") - pp(bt, lvl) - sb.append("]") - case SetType(bt) => - sb.append("Set[") - pp(bt, lvl) - sb.append("]") - case MapType(ft,tt) => - sb.append("Map[") - pp(ft, lvl) - sb.append(",") - pp(tt, lvl) - sb.append("]") - case MultisetType(bt) => - sb.append("Multiset[") - pp(bt, lvl) - sb.append("]") - case TupleType(tpes) => ppNaryType(tpes, "(", ", ", ")", lvl) - case FunctionType(fts, tt) => - if (fts.size > 1) { - ppNaryType(fts, "(", ", ", ")", lvl) - } else if (fts.size == 1) { - pp(fts.head, lvl) + csc.foreach { cs => + ind(lvl+1) + pp(cs, p)(lvl+1) + sb.append("\n") + } + + ind(lvl) + sb.append("}") + } } - sb.append(" => ") - pp(tt, lvl) - case c: ClassType => sb.append(idToString(c.classDef.id)) - case _ => sb.append("Type?") - } - // DEFINITIONS - // all definitions are printed with an end-of-line - override def pp(defn: Definition, lvl: Int): Unit = { + case Not(expr) => sb.append("!"); optParentheses { pp(expr, p) } - defn match { - case Program(id, mainObj) => { - assert(lvl == 0) - pp(mainObj, lvl) + case e @ Error(desc) => { + sb.append("leon.Utils.error[") + pp(e.getType, p) + sb.append("](\"" + desc + "\")") } - case ObjectDef(id, defs, invs) => { - ind(lvl) + case (expr: PrettyPrintable) => expr.printWith(this) + + // Definitions + case Program(id, mainObj) => + assert(lvl == 0) + pp(mainObj, p) + + case ObjectDef(id, defs, invs) => sb.append("object ") sb.append(idToString(id)) sb.append(" {\n") @@ -323,7 +295,8 @@ class ScalaPrinter(sb: StringBuffer = new StringBuffer) extends PrettyPrinter(sb val sz = defs.size defs.foreach(df => { - pp(df, lvl+1) + ind(lvl+1) + pp(df, p)(lvl+1) if(c < sz - 1) { sb.append("\n\n") } @@ -333,16 +306,13 @@ class ScalaPrinter(sb: StringBuffer = new StringBuffer) extends PrettyPrinter(sb sb.append("\n") ind(lvl) sb.append("}\n") - } case AbstractClassDef(id, parent) => - ind(lvl) sb.append("sealed abstract class ") sb.append(idToString(id)) - parent.foreach(p => sb.append(" extends " + p.id)) + parent.foreach(p => sb.append(" extends " + idToString(p.id))) case CaseClassDef(id, parent, varDecls) => - ind(lvl) sb.append("case class ") sb.append(idToString(id)) sb.append("(") @@ -352,7 +322,7 @@ class ScalaPrinter(sb: StringBuffer = new StringBuffer) extends PrettyPrinter(sb varDecls.foreach(vd => { sb.append(idToString(vd.id)) sb.append(": ") - pp(vd.tpe, lvl) + pp(vd.tpe, p) if(c < sz - 1) { sb.append(", ") } @@ -362,19 +332,17 @@ class ScalaPrinter(sb: StringBuffer = new StringBuffer) extends PrettyPrinter(sb parent.foreach(p => sb.append(" extends " + idToString(p.id))) case fd: FunDef => - - ind(lvl) sb.append("def ") sb.append(idToString(fd.id)) sb.append("(") val sz = fd.args.size var c = 0 - + fd.args.foreach(arg => { sb.append(idToString(arg.id)) sb.append(" : ") - pp(arg.tpe, lvl) + pp(arg.tpe, p) if(c < sz - 1) { sb.append(", ") @@ -383,100 +351,74 @@ class ScalaPrinter(sb: StringBuffer = new StringBuffer) extends PrettyPrinter(sb }) sb.append(") : ") - pp(fd.returnType, lvl) - sb.append(" = {") + pp(fd.returnType, p) + sb.append(" = {\n") + ind(lvl+1) fd.precondition match { case None => case Some(prec) => - ind(lvl+1) sb.append("require(") - pp(prec, lvl+1) + pp(prec, p)(lvl+1) sb.append(");\n") + ind(lvl+1) } fd.body match { case Some(body) => - pp(body, lvl) + pp(body, p)(lvl+1) case None => sb.append("???") } + sb.append("\n") + ind + fd.postcondition match { case None => sb.append("}") case Some((id, postc)) => - sb.append("} ensuring(") - pp(Variable(id), lvl) + sb.append("} ensuring { ") + pp(Variable(id), p) sb.append(" => ") - pp(postc, lvl) - sb.append(")") + pp(postc, p) + sb.append(" }") } - case _ => sb.append("Defn?") - } - } - - def pp(p: Pattern): Unit = p match { - //case InstanceOfPattern(None, ctd) => - //case InstanceOfPattern(Some(id), ctd) => - case CaseClassPattern(bndr, ccd, subps) => { - bndr.foreach(b => sb.append(b + " @ ")) - sb.append(idToString(ccd.id)).append("(") - var c = 0 - val sz = subps.size - subps.foreach(sp => { - pp(sp) - if(c < sz - 1) - sb.append(", ") - c = c + 1 - }) - sb.append(")") - } - case WildcardPattern(None) => sb.append("_") - case WildcardPattern(Some(id)) => sb.append(idToString(id)) - case InstanceOfPattern(bndr, ccd) => { - bndr.foreach(b => sb.append(b + " : ")) - sb.append(idToString(ccd.id)) - } - case TuplePattern(bndr, subPatterns) => { - bndr.foreach(b => sb.append(b + " @ ")) - sb.append("(") - subPatterns.init.foreach(p => { - pp(p) - sb.append(", ") - }) - pp(subPatterns.last) - sb.append(")") + case _ => super.pp(tree, parent)(lvl) } - case _ => sb.append("Pattern?") } -} - -object ScalaPrinter { - def apply(tree: Expr, indent: Int): String = { - val printer = new ScalaPrinter() - printer.pp(tree, indent) - printer.toString + private def requiresBraces(ex: Tree, within: Option[Tree]): Boolean = (ex, within) match { + case (_, None) => false + case (_, Some(_: Definition)) => false + case (_, Some(_: MatchExpr | _: Let | _: LetTuple | _: LetDef)) => false + case (_, _) => true } - def apply(tree: Expr): String = { - val printer = new ScalaPrinter() - printer.pp(tree, 0) - printer.toString + private def precedence(ex: Expr): Int = ex match { + case (_: ElementOfSet) => 0 + case (_: Or) => 1 + case (_: And) => 3 + case (_: GreaterThan | _: GreaterEquals | _: LessEquals | _: LessThan) => 4 + case (_: Equals | _: Iff | _: Not) => 5 + case (_: Plus | _: Minus | _: SetUnion| _: SetDifference) => 6 + case (_: Times | _: Division | _: Modulo) => 7 + case _ => 7 } - def apply(tpe: TypeTree): String = { - val printer = new ScalaPrinter() - printer.pp(tpe, 0) - printer.toString + private def requiresParentheses(ex: Tree, within: Option[Tree]): Boolean = (ex, within) match { + case (_, None) => false + case (_, Some(_: Definition)) => false + case (_, Some(_: MatchExpr | _: Let | _: LetTuple | _: LetDef | _: IfExpr)) => false + case (_, Some(_: FunctionInvocation)) => false + case (ie: IfExpr, _) => true + case (e1: Expr, Some(e2: Expr)) if precedence(e1) > precedence(e2) => false + case (_, _) => true } +} - def apply(defn: Definition): String = { - val printer = new ScalaPrinter() - printer.pp(defn, 0) - printer.toString - } +object ScalaPrinter extends PrettyPrinterFactory { + def create = new ScalaPrinter() } diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala index 68d87237ee13b36861249246002f6c87183a1718..16dda8f063d0d75c96c5b4df726dbee59559a920 100644 --- a/src/main/scala/leon/purescala/Trees.scala +++ b/src/main/scala/leon/purescala/Trees.scala @@ -12,7 +12,7 @@ object Trees { /* EXPRESSIONS */ - abstract class Expr extends Typed with Serializable { + abstract class Expr extends Tree with Typed with Serializable { override def toString: String = PrettyPrinter(this) } @@ -141,7 +141,7 @@ object Trees { override def hashCode: Int = scrutinee.hashCode+cases.hashCode } - sealed abstract class MatchCase extends Serializable { + sealed abstract class MatchCase extends Tree { val pattern: Pattern val rhs: Expr val theGuard: Option[Expr] @@ -158,7 +158,7 @@ object Trees { def expressions = List(guard, rhs) } - sealed abstract class Pattern extends Serializable { + sealed abstract class Pattern extends Tree { val subPatterns: Seq[Pattern] val binder: Option[Identifier] diff --git a/src/main/scala/leon/purescala/TypeTrees.scala b/src/main/scala/leon/purescala/TypeTrees.scala index ab7c03cae05bf4aa2b42f5ae095ecd732ca370be..0c37ca43f298ffb478aca7ec86fbaeb87c2736dd 100644 --- a/src/main/scala/leon/purescala/TypeTrees.scala +++ b/src/main/scala/leon/purescala/TypeTrees.scala @@ -56,7 +56,7 @@ object TypeTrees { } - sealed abstract class TypeTree extends Serializable { + sealed abstract class TypeTree extends Tree { override def toString: String = PrettyPrinter(this) } diff --git a/src/main/scala/leon/synthesis/FileInterface.scala b/src/main/scala/leon/synthesis/FileInterface.scala index aee3f4342f64733e8e2cb6efce066659de971080..b97bcc7c7a7fe597e1acdabd37b92e578c7e4352 100644 --- a/src/main/scala/leon/synthesis/FileInterface.scala +++ b/src/main/scala/leon/synthesis/FileInterface.scala @@ -4,12 +4,13 @@ package leon package synthesis import purescala.Trees._ +import purescala.Common.Tree import purescala.ScalaPrinter import java.io.File class FileInterface(reporter: Reporter) { - def updateFile(origFile: File, solutions: Map[ChooseInfo, Expr], ignoreMissing: Boolean = false) { + def updateFile(origFile: File, solutions: Map[ChooseInfo, Expr]) { import java.io.{File, BufferedWriter, FileWriter} val FileExt = """^(.+)\.([^.]+)$""".r @@ -17,6 +18,7 @@ class FileInterface(reporter: Reporter) { case FileExt(path, "scala") => var i = 0 def savePath = path+".scala."+i + while (new File(savePath).isFile()) { i += 1 } @@ -26,8 +28,10 @@ class FileInterface(reporter: Reporter) { val newFile = new File(origFile.getAbsolutePath()) origFile.renameTo(backup) - - val newCode = substitueChooses(origCode, solutions, ignoreMissing) + var newCode = origCode + for ( (ci, e) <- solutions) { + newCode = substitute(newCode, CodePattern.forChoose(ci), e) + } val out = new BufferedWriter(new FileWriter(newFile)) out.write(newCode) @@ -37,7 +41,13 @@ class FileInterface(reporter: Reporter) { } } - def substitueChooses(str: String, solutions: Map[ChooseInfo, Expr], ignoreMissing: Boolean = false): String = { + case class CodePattern(startWith: String, posIntInfo: (Int, Int), blocks: Int) + + object CodePattern { + def forChoose(ci: ChooseInfo) = CodePattern("choose", ci.ch.posIntInfo, 1) + } + + def substitute(str: String, pattern: CodePattern, subst: Tree): String = { var lines = List[Int]() // Compute line positions @@ -87,7 +97,7 @@ class FileInterface(reporter: Reporter) { var newStrOffset = 0 do { - lastFound = str.indexOf("choose", lastFound+1) + lastFound = str.indexOf(pattern.startWith, lastFound+1) if (lastFound > -1) { val (lineno, lineoffset) = lineOf(lastFound) @@ -96,33 +106,31 @@ class FileInterface(reporter: Reporter) { val indent = getLineIndentation(lastFound) - solutions.find(_._1.ch.posIntInfo == (lineno, scalaOffset)) match { - case Some((choose, solution)) => - var lvl = 0; - var i = lastFound + 6; - var continue = true; - do { - val c = str.charAt(i) - if (c == '(' || c == '{') { - lvl += 1 - } else if (c == ')' || c == '}') { - lvl -= 1 - if (lvl == 0) { + if (pattern.posIntInfo == (lineno, scalaOffset)) { + var lvl = 0; + var i = lastFound + 6; + var continue = true; + do { + var blocksRemaining = pattern.blocks + val c = str.charAt(i) + if (c == '(' || c == '{') { + lvl += 1 + } else if (c == ')' || c == '}') { + lvl -= 1 + if (lvl == 0) { + blocksRemaining -= 1 + if (blocksRemaining == 0) { continue = false } } - i += 1 - } while(continue) - - val newCode = ScalaPrinter(solution, indent/2) - newStr = (newStr.substring(0, lastFound+newStrOffset))+newCode+(newStr.substring(i+newStrOffset, newStr.length)) + } + i += 1 + } while(continue) - newStrOffset += -(i-lastFound)+newCode.length + val newCode = ScalaPrinter(subst, indent/2) + newStr = (newStr.substring(0, lastFound+newStrOffset))+newCode+(newStr.substring(i+newStrOffset, newStr.length)) - case _ => - if (!ignoreMissing) { - reporter.warning("Could not find solution corresponding to choose at "+lineno+":"+scalaOffset) - } + newStrOffset += -(i-lastFound)+newCode.length } } } while(lastFound> 0) diff --git a/src/main/scala/leon/synthesis/SynthesisPhase.scala b/src/main/scala/leon/synthesis/SynthesisPhase.scala index a4eaaa451a2c38a6160624a21d3a8bd518b1ae26..180b8f554a9c0614ff3722d449cf1e45222e1696 100644 --- a/src/main/scala/leon/synthesis/SynthesisPhase.scala +++ b/src/main/scala/leon/synthesis/SynthesisPhase.scala @@ -7,6 +7,7 @@ import purescala.TreeOps._ import solvers.z3._ import purescala.Trees._ +import purescala.Common._ import purescala.ScalaPrinter import purescala.Definitions.{Program, FunDef} @@ -118,10 +119,19 @@ object SynthesisPhase extends LeonPhase[Program, Program] { var chooses = ChooseInfo.extractFromProgram(ctx, p, options).filter(toProcess) + var functions = Set[FunDef]() + val results = chooses.map { ci => val (sol, isComplete) = ci.synthesizer.synthesize() - ci -> sol.toSimplifiedExpr(ctx, p) + val fd = ci.fd + + val expr = sol.toSimplifiedExpr(ctx, p) + fd.body = fd.body.map(b => replace(Map(ci.ch -> expr), b)) + + functions += fd + + ci -> expr }.toMap if (options.inPlace) { @@ -129,15 +139,12 @@ object SynthesisPhase extends LeonPhase[Program, Program] { new FileInterface(ctx.reporter).updateFile(file, results) } } else { - for ((ci, ex) <- results) { - val middle = " In "+ci.fd.id.toString+", synthesis of: " - + for (fd <- functions) { + val middle = " "+fd.id.name+" " val remSize = (80-middle.length) - ctx.reporter.info("-"*math.floor(remSize/2).toInt+middle+"-"*math.ceil(remSize/2).toInt) - ctx.reporter.info(ci.ch) - ctx.reporter.info("-"*35+" Result: "+"-"*36) - ctx.reporter.info(ScalaPrinter(ex)) + + ctx.reporter.info(ScalaPrinter(fd)) ctx.reporter.info("") } } diff --git a/src/main/scala/leon/xlang/Trees.scala b/src/main/scala/leon/xlang/Trees.scala index 866cfa5d96306a8ee57cfc19d7b1957136e076e8..8595fd3f2805ff822e5869b0a870f59ae8aaa2d1 100644 --- a/src/main/scala/leon/xlang/Trees.scala +++ b/src/main/scala/leon/xlang/Trees.scala @@ -23,11 +23,11 @@ object Trees { Some((args :+ rest, exprs => Block(exprs.init, exprs.last))) } - def printWith(lvl: Int, printer: PrettyPrinter) { + def printWith(printer: PrettyPrinter)(implicit lvl: Int) { printer.append("{\n") (exprs :+ last).foreach(e => { printer.ind(lvl+1) - printer.pp(e, lvl+1) + printer.pp(e, Some(this))(lvl+1) printer.append("\n") }) printer.ind(lvl) @@ -44,11 +44,11 @@ object Trees { Some((expr, Assignment(varId, _))) } - def printWith(lvl: Int, printer: PrettyPrinter) { + def printWith(printer: PrettyPrinter)(implicit lvl: Int) { printer.append("(") printer.append(varId.name) printer.append(" = ") - printer.pp(expr,lvl) + printer.pp(expr, Some(this)) printer.append(")") } } @@ -65,23 +65,23 @@ object Trees { Some((cond, body, (t1, t2) => While(t1, t2).setInvariant(this.invariant).setPosInfo(this))) } - def printWith(lvl: Int, printer: PrettyPrinter) { + def printWith(printer: PrettyPrinter)(implicit lvl: Int) { invariant match { case Some(inv) => { printer.append("\n") - printer.ind(lvl) + printer.ind printer.append("@invariant: ") - printer.pp(inv, lvl) + printer.pp(inv, Some(this)) printer.append("\n") - printer.ind(lvl) + printer.ind } case None => } printer.append("while(") - printer.pp(cond, lvl) + printer.pp(cond, Some(this)) printer.append(")\n") printer.ind(lvl+1) - printer.pp(body, lvl+1) + printer.pp(body, Some(this))(lvl+1) printer.append("\n") } } @@ -91,14 +91,14 @@ object Trees { Some((pred, (expr: Expr) => Epsilon(expr).setType(this.getType).setPosInfo(this))) } - def printWith(lvl: Int, printer: PrettyPrinter) { + def printWith(printer: PrettyPrinter)(implicit lvl: Int) { printer match { case _: ScalaPrinter => sys.error("Not Scala Code") case _ => printer.append("epsilon(x" + this.posIntInfo._1 + "_" + this.posIntInfo._2 + ". ") - printer.pp(pred, lvl) + printer.pp(pred, Some(this)) printer.append(")") } } @@ -106,7 +106,7 @@ object Trees { case class EpsilonVariable(pos: (Int, Int)) extends Expr with Terminal with PrettyPrintable{ - def printWith(lvl: Int, printer: PrettyPrinter) { + def printWith(printer: PrettyPrinter)(implicit lvl: Int) { val (row, col) = pos printer.append("x" + row + "_" + col) } @@ -124,29 +124,28 @@ object Trees { Some((expr, body, (e: Expr, b: Expr) => LetVar(binders, e, b))) } - def printWith(lvl: Int, printer: PrettyPrinter) { + def printWith(printer: PrettyPrinter)(implicit lvl: Int) { printer match { case _: ScalaPrinter => val LetVar(b,d,e) = this printer.append("locally {\n") printer.ind(lvl+1) printer.append("var " + b + " = ") - printer.pp(d, lvl+1) + printer.pp(d, Some(this))(lvl+1) printer.append("\n") printer.ind(lvl+1) - printer.pp(e, lvl+1) - printer.append("\n") - printer.ind(lvl) - printer.append("}\n") - printer.ind(lvl) + printer.pp(e, Some(this))(lvl+1) + printer.nl + printer.append("}") + printer.nl case _ => val LetVar(b,d,e) = this printer.append("(letvar (" + b + " := "); - printer.pp(d, lvl) + printer.pp(d, Some(this)) printer.append(") in\n") printer.ind(lvl+1) - printer.pp(e, lvl+1) + printer.pp(e, Some(this))(lvl+1) printer.append(")") } } @@ -157,14 +156,14 @@ object Trees { Some((expr, (e: Expr) => Waypoint(i, e))) } - def printWith(lvl: Int, printer: PrettyPrinter) { + def printWith(printer: PrettyPrinter)(implicit lvl: Int) { printer match { case _: ScalaPrinter => sys.error("Not Scala Code") case _ => printer.append("waypoint_" + i + "(") - printer.pp(expr, lvl) + printer.pp(expr, Some(this)) printer.append(")") } } @@ -180,12 +179,12 @@ object Trees { Some((Seq(t1,t2,t3), (as: Seq[Expr]) => ArrayUpdate(as(0), as(1), as(2)))) } - def printWith(lvl: Int, printer: PrettyPrinter) { - printer.pp(array, lvl) + def printWith(printer: PrettyPrinter)(implicit lvl: Int) { + printer.pp(array, Some(this)) printer.append("(") - printer.pp(index, lvl) + printer.pp(index, Some(this)) printer.append(") = ") - printer.pp(newValue, lvl) + printer.pp(newValue, Some(this)) } }