diff --git a/src/funcheck/CodeExtraction.scala b/src/funcheck/CodeExtraction.scala index 29a23510173f017df5d81ae4bb388bf4618f1a87..aee9044bbe68572b9c7fb74e26cd08e60bd6c05e 100644 --- a/src/funcheck/CodeExtraction.scala +++ b/src/funcheck/CodeExtraction.scala @@ -16,17 +16,21 @@ trait CodeExtraction extends Extractors { import StructuralExtractors._ import ExpressionExtractors._ + private val varSubsts: scala.collection.mutable.Map[Identifier,Function0[Expr]] = scala.collection.mutable.Map.empty[Identifier,Function0[Expr]] + def extractCode(unit: CompilationUnit): Program = { import scala.collection.mutable.HashMap - // register where the symbols where extracted from - // val symbolDefMap = new HashMap[purescala.Symbols.Symbol,Tree] - def s2ps(tree: Tree): Expr = toPureScala(unit)(tree) match { case Some(ex) => ex case None => stopIfErrors; scala.Predef.error("unreachable error.") } + def st2ps(tree: Tree): funcheck.purescala.TypeTrees.TypeTree = toPureScalaType(unit)(tree) match { + case Some(tt) => tt + case None => stopIfErrors; scala.Predef.error("unreachable error.") + } + def extractTopLevelDef: ObjectDef = { val top = unit.body match { case p @ PackageDef(name, lst) if lst.size == 0 => { @@ -63,7 +67,7 @@ trait CodeExtraction extends Extractors { var funDefs: List[FunDef] = Nil tmpl.body.foreach(tree => { - println("[[[ " + tree + "]]]\n"); + //println("[[[ " + tree + "]]]\n"); tree match { case ExObjectDef(o2, t2) => { objectDefs = extractObjectDef(o2, t2) :: objectDefs } case ExAbstractClass(o2) => ; @@ -82,7 +86,30 @@ trait CodeExtraction extends Extractors { } def extractFunDef(name: Identifier, params: Seq[ValDef], tpt: Tree, body: Tree) = { - FunDef(name, scalaType2PureScala(unit, false)(tpt), Nil, null, None, None) + var realBody = body + var reqCont: Option[Expr] = None + var ensCont: Option[Expr] = None + + realBody match { + case ExEnsuredExpression(body2, resId, contract) => { + varSubsts(resId) = (() => ResultVariable()) + val c1 = s2ps(contract) + varSubsts.remove(resId) + realBody = body2 + ensCont = Some(c1) + } + case _ => ; + } + + realBody match { + case ExRequiredExpression(body3, contract) => { + realBody = body3 + reqCont = Some(s2ps(contract)) + } + case _ => ; + } + + FunDef(name, st2ps(tpt), Nil, s2ps(realBody), reqCont, ensCont) } // THE EXTRACTION CODE STARTS HERE @@ -95,10 +122,6 @@ trait CodeExtraction extends Extractors { case _ => "<program>" } - println("Top level sym:") - println(topLevelObjDef) - - //Program(programName, ObjectDef("Object", Nil, Nil)) Program(programName, topLevelObjDef) } @@ -115,22 +138,44 @@ trait CodeExtraction extends Extractors { } } + def toPureScalaType(unit: CompilationUnit)(typeTree: Tree): Option[funcheck.purescala.TypeTrees.TypeTree] = { + try { + Some(scalaType2PureScala(unit, false)(typeTree)) + } catch { + case ImpureCodeEncounteredException(_) => None + } + } + /** Forces conversion from scalac AST to purescala AST, throws an Exception * if impossible. If not in 'silent mode', non-pure AST nodes are reported as * errors. */ private def scala2PureScala(unit: CompilationUnit, silent: Boolean)(tree: Tree): Expr = { - tree match { + def rec(tr: Tree): Expr = tr match { case ExInt32Literal(v) => IntLiteral(v) case ExBooleanLiteral(v) => BooleanLiteral(v) - + case ExIntIdentifier(id) => varSubsts.get(id) match { + case Some(fun) => fun() + case None => Variable(id) + } + case ExAnd(l, r) => And(rec(l), rec(r)) + case ExPlus(l, r) => Plus(rec(l), rec(r)) + case ExEquals(l, r) => Equals(rec(l), rec(r)) + case ExGreaterThan(l, r) => GreaterThan(rec(l), rec(r)) + case ExGreaterEqThan(l, r) => GreaterEquals(rec(l), rec(r)) + case ExLessThan(l, r) => LessThan(rec(l), rec(r)) + case ExLessEqThan(l, r) => LessEquals(rec(l), rec(r)) + // default behaviour is to complain :) case _ => { if(!silent) { + println(tr) unit.error(tree.pos, "Could not extract as PureScala.") } throw ImpureCodeEncounteredException(tree) } } + + rec(tree) } private def scalaType2PureScala(unit: CompilationUnit, silent: Boolean)(tree: Tree): funcheck.purescala.TypeTrees.TypeTree = { diff --git a/src/funcheck/Extractors.scala b/src/funcheck/Extractors.scala index 9ce171cdb6bfe46f1ca04111bb44327d7a34ff04..0e0d200dcb80fb7c069f5a25d5b2cabbff7ad863 100644 --- a/src/funcheck/Extractors.scala +++ b/src/funcheck/Extractors.scala @@ -21,9 +21,9 @@ trait Extractors { } } - object EnsuredExpression { + object ExEnsuredExpression { /** Extracts the 'ensuring' contract from an expression. */ - def unapply(tree: Tree): Option[(Tree,Function)] = tree match { + def unapply(tree: Tree): Option[(Tree,String,Tree)] = tree match { case Apply( Select( Apply( @@ -34,12 +34,12 @@ trait Extractors { ensuringName), (anonymousFun @ Function(ValDef(_, resultName, resultType, EmptyTree) :: Nil, contractBody)) :: Nil) - if("ensuring".equals(ensuringName.toString)) => Some((body,anonymousFun)) + if("ensuring".equals(ensuringName.toString)) => Some((body, resultName.toString, contractBody)) case _ => None } } - object RequiredExpression { + object ExRequiredExpression { /** Extracts the 'require' contract from an expression (only if it's the * first call in the block). */ def unapply(tree: Tree): Option[(Tree,Tree)] = tree match { @@ -88,7 +88,6 @@ trait Extractors { object ExMainFunctionDef { def unapply(dd: DefDef): Boolean = dd match { case DefDef(_, name, tparams, vparamss, tpt, rhs) if(name.toString == "main" && tparams.isEmpty && vparamss.size == 1 && vparamss(0).size == 1) => { - println("Looks like main " + vparamss(0)(0).symbol.tpe); true } case _ => false @@ -120,6 +119,85 @@ trait Extractors { case _ => None } } + + object ExIntIdentifier { + def unapply(tree: Tree): Option[String] = tree match { + case i: Ident if i.symbol.tpe == IntClass.tpe => Some(i.symbol.name.toString) + case _ => None + } + } + + object ExAnd { + def unapply(tree: Tree): Option[(Tree,Tree)] = tree match { + case Apply(s @ Select(lhs, _), List(rhs)) if (s.symbol == Boolean_and) => + Some((lhs,rhs)) + case _ => None + } + } + + object ExOr { + def unapply(tree: Tree): Option[(Tree,Tree)] = tree match { + case Apply(s @ Select(lhs, _), List(rhs)) if (s.symbol == Boolean_or) => + Some((lhs,rhs)) + case _ => None + } + } + + object ExNot { + def unapply(tree: Tree): Option[Tree] = tree match { + case Select(t, n) if (n == nme.UNARY_!) => Some(t) + case _ => None + } + } + + object ExEquals { + def unapply(tree: Tree): Option[(Tree,Tree)] = tree match { + case Apply(Select(lhs, n), List(rhs)) if (n == nme.EQ) => Some((lhs,rhs)) + case _ => None + } + } + + object ExNotEquals { + def unapply(tree: Tree): Option[(Tree,Tree)] = tree match { + case Apply(Select(lhs, n), List(rhs)) if (n == nme.NE) => Some((lhs,rhs)) + case _ => None + } + } + + object ExLessThan { + def unapply(tree: Tree): Option[(Tree,Tree)] = tree match { + case Apply(Select(lhs, n), List(rhs)) if (n == nme.LT) => Some((lhs,rhs)) + case _ => None + } + } + + object ExLessEqThan { + def unapply(tree: Tree): Option[(Tree,Tree)] = tree match { + case Apply(Select(lhs, n), List(rhs)) if (n == nme.LE) => Some((lhs,rhs)) + case _ => None + } + } + + object ExGreaterThan { + def unapply(tree: Tree): Option[(Tree,Tree)] = tree match { + case Apply(Select(lhs, n), List(rhs)) if (n == nme.GT) => Some((lhs,rhs)) + case _ => None + } + } + + object ExGreaterEqThan { + def unapply(tree: Tree): Option[(Tree,Tree)] = tree match { + case Apply(Select(lhs, n), List(rhs)) if (n == nme.GE) => Some((lhs,rhs)) + case _ => None + } + } + + object ExPlus { + def unapply(tree: Tree): Option[(Tree,Tree)] = tree match { + case Apply(Select(lhs, n), List(rhs)) if (n == nme.ADD) => Some((lhs,rhs)) + case _ => None + } + } } object TypeExtractors { diff --git a/src/funcheck/purescala/PrettyPrinter.scala b/src/funcheck/purescala/PrettyPrinter.scala index bf1f9f3036d450517292efba75ad3b527612c455..35391e4edf679c61949b30d363f567bfddee30c7 100644 --- a/src/funcheck/purescala/PrettyPrinter.scala +++ b/src/funcheck/purescala/PrettyPrinter.scala @@ -62,6 +62,7 @@ object PrettyPrinter { } private def pp(tree: Expr, sb: StringBuffer): StringBuffer = tree match { + case Variable(id) => sb.append(id) case And(exprs) => ppNary(sb, exprs, " \u2227 ") // \land case Or(exprs) => ppNary(sb, exprs, " \u2228 ") // \lor case Not(Equals(l, r)) => ppBinary(sb, l, r, " \u2260 ") // \neq @@ -69,6 +70,15 @@ object PrettyPrinter { case Equals(l,r) => ppBinary(sb, l, r, " = ") case IntLiteral(v) => sb.append(v) case BooleanLiteral(v) => sb.append(v) + case Plus(l,r) => ppBinary(sb, l, r, " + ") + case Minus(l,r) => ppBinary(sb, l, r, " - ") + case Times(l,r) => ppBinary(sb, l, r, " * ") + case Division(l,r) => ppBinary(sb, l, r, " / ") + case LessThan(l,r) => ppBinary(sb, l, r, " < ") + case GreaterThan(l,r) => ppBinary(sb, l, r, " > ") + case LessEquals(l,r) => ppBinary(sb, l, r, " \u2264 ") // \leq + case GreaterEquals(l,r) => ppBinary(sb, l, r, " \u2265 ") // \geq + case IfExpr(c, t, e) => { var nsb = sb nsb.append("if (") @@ -81,6 +91,8 @@ object PrettyPrinter { nsb } + case ResultVariable() => sb.append("<res>") + case _ => sb.append("Expr?") } @@ -94,7 +106,9 @@ object PrettyPrinter { // DEFINITIONS // all definitions are printed with an end-of-line private def pp(defn: Definition, sb: StringBuffer, lvl: Int): StringBuffer = { - def ind(sb: StringBuffer): Unit = { sb.append(" " * lvl) } + def ind(sb: StringBuffer, customLevel: Int = lvl) : Unit = { + sb.append(" " * customLevel) + } defn match { case Program(id, mainObj) => { @@ -124,6 +138,21 @@ object PrettyPrinter { case FunDef(id, rt, args, body, pre, post) => { var nsb = sb + + pre.foreach(prec => { + ind(nsb) + nsb.append("@pre : ") + nsb = pp(prec, nsb) + nsb.append("\n") + }) + + post.foreach(postc => { + ind(nsb) + nsb.append("@post: ") + nsb = pp(postc, nsb) + nsb.append("\n") + }) + ind(nsb) nsb.append("def ") nsb.append(id) @@ -137,7 +166,7 @@ object PrettyPrinter { nsb = pp(rt, nsb) nsb.append(" = {\n") - ind(nsb) + ind(nsb, lvl+1) nsb = pp(body, nsb) nsb.append("\n") diff --git a/src/funcheck/purescala/Trees.scala b/src/funcheck/purescala/Trees.scala index f995ded93f11b3fe7bded2c3c2ffa409f1088004..1a14d5d4b60adad9fe2b05373805b11c26e4d4c4 100644 --- a/src/funcheck/purescala/Trees.scala +++ b/src/funcheck/purescala/Trees.scala @@ -11,22 +11,6 @@ object Trees { sealed abstract class Expr extends Typed { override def toString: String = PrettyPrinter(this) - - // private var _scope: Option[Scope] = None - // - // def scope: Scope = - // if(_scope.isEmpty) - // throw new Exception("Undefined scope.") - // else - // _scope.get - - // def scope_=(s: Scope): Unit = { - // if(_scope.isEmpty) { - // _scope = Some(s) - // } else { - // throw new Exception("Redefining scope.") - // } - // } } /* Control flow */ @@ -51,6 +35,15 @@ object Trees { // We don't handle Seq stars for now. /* Propositional logic */ + case object And { + def apply(l: Expr, r: Expr): Expr = (l,r) match { + case (And(exs1), And(exs2)) => And(exs1 ++ exs2) + case (And(exs1), ex2) => And(exs1 :+ ex2) + case (ex1, And(exs2)) => And(exs2 :+ ex1) + case (ex1, ex2) => And(List(ex1, ex2)) + } + } + case class And(exprs: Seq[Expr]) extends Expr case class Or(exprs: Seq[Expr]) extends Expr case class Not(expr: Expr) extends Expr @@ -63,6 +56,9 @@ object Trees { // variable, which would also give us its type. case class Variable(id: Identifier) extends Expr + // represents the result in post-conditions + case class ResultVariable() extends Expr + sealed abstract class Literal[T] extends Expr { val value: T }