From 737e6c74a8286c4fd4613bde92829b6c6c7e8c92 Mon Sep 17 00:00:00 2001
From: Etienne Kneuss <colder@php.net>
Date: Tue, 4 Mar 2014 16:11:21 +0100
Subject: [PATCH] RecursiveEvaluator only considers calls as steps

---
 .../leon/evaluators/RecursiveEvaluator.scala  | 162 +++++++++---------
 .../leon/evaluators/TracingEvaluator.scala    |  22 ++-
 testcases/case-studies/Compiler.scala         |  74 +++++---
 3 files changed, 140 insertions(+), 118 deletions(-)

diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala
index 7432e33b3..8257f082c 100644
--- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala
+++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala
@@ -31,14 +31,16 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program) extends Evalu
     def withVars(news: Map[Identifier, Expr]): RC;
   }
 
-  class GlobalContext(var stepsLeft: Int)
+  class GlobalContext(var stepsLeft: Int) {
+    val maxSteps = stepsLeft
+  }
 
   def initRC(mappings: Map[Identifier, Expr]): RC
   def initGC: GC
 
-  def eval(e: Expr, mappings: Map[Identifier, Expr]) = {
+  def eval(ex: Expr, mappings: Map[Identifier, Expr]) = {
     try {
-      EvaluationResults.Successful(se(e)(initRC(mappings), initGC))
+      EvaluationResults.Successful(e(ex)(initRC(mappings), initGC))
     } catch {
       case so: StackOverflowError =>
         EvaluationResults.EvaluatorError("Stack overflow")
@@ -49,15 +51,6 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program) extends Evalu
     }
   }
 
-  def se(expr: Expr)(implicit rctx: RC, gctx: GC): Expr = {
-    if (gctx.stepsLeft < 0) {
-      throw RuntimeError("Exceeded number of allocated steps")
-    } else {
-      gctx.stepsLeft -= 1
-      e(expr)
-    }
-  }
-
   def e(expr: Expr)(implicit rctx: RC, gctx: GC): Expr = expr match {
     case Variable(id) =>
       rctx.mappings.get(id) match {
@@ -72,36 +65,41 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program) extends Evalu
       }
 
     case Tuple(ts) =>
-      val tsRec = ts.map(se)
+      val tsRec = ts.map(e)
       Tuple(tsRec)
 
     case TupleSelect(t, i) =>
-      val Tuple(rs) = se(t)
+      val Tuple(rs) = e(t)
       rs(i-1)
 
-    case Let(i,e,b) =>
-      val first = se(e)
-      se(b)(rctx.withNewVar(i, first), gctx)
+    case Let(i,ex,b) =>
+      val first = e(ex)
+      e(b)(rctx.withNewVar(i, first), gctx)
 
     case Error(desc) =>
       throw RuntimeError("Error reached in evaluation: " + desc)
 
     case IfExpr(cond, thenn, elze) =>
-      val first = se(cond)
+      val first = e(cond)
       first match {
-        case BooleanLiteral(true) => se(thenn)
-        case BooleanLiteral(false) => se(elze)
+        case BooleanLiteral(true) => e(thenn)
+        case BooleanLiteral(false) => e(elze)
         case _ => throw EvalError(typeErrorMsg(first, BooleanType))
       }
 
     case FunctionInvocation(tfd, args) =>
-      val evArgs = args.map(a => se(a))
+      if (gctx.stepsLeft < 0) {
+        throw RuntimeError("Exceeded number of allocated methods calls ("+gctx.maxSteps+")")
+      }
+      gctx.stepsLeft -= 1
+
+      val evArgs = args.map(a => e(a))
 
       // build a mapping for the function...
       val frame = rctx.withVars((tfd.params.map(_.id) zip evArgs).toMap)
       
       if(tfd.hasPrecondition) {
-        se(matchToIfThenElse(tfd.precondition.get))(frame, gctx) match {
+        e(matchToIfThenElse(tfd.precondition.get))(frame, gctx) match {
           case BooleanLiteral(true) =>
           case BooleanLiteral(false) =>
             throw RuntimeError("Precondition violation for " + tfd.id.name + " reached in evaluation.: " + tfd.precondition.get)
@@ -114,7 +112,7 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program) extends Evalu
       }
 
       val body = tfd.body.getOrElse(rctx.mappings(tfd.id))
-      val callResult = se(matchToIfThenElse(body))(frame, gctx)
+      val callResult = e(matchToIfThenElse(body))(frame, gctx)
 
       if(tfd.hasPostcondition) {
         val (id, post) = tfd.postcondition.get
@@ -122,7 +120,7 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program) extends Evalu
         val freshResID = FreshIdentifier("result").setType(tfd.returnType)
         val postBody = replace(Map(Variable(id) -> Variable(freshResID)), matchToIfThenElse(post))
 
-        se(matchToIfThenElse(post))(frame.withNewVar(id, callResult), gctx) match {
+        e(matchToIfThenElse(post))(frame.withNewVar(id, callResult), gctx) match {
           case BooleanLiteral(true) =>
           case BooleanLiteral(false) => throw RuntimeError("Postcondition violation for " + tfd.id.name + " reached in evaluation.")
           case other => throw EvalError(typeErrorMsg(other, BooleanType))
@@ -135,40 +133,40 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program) extends Evalu
       BooleanLiteral(true)
 
     case And(args) =>
-      se(args.head) match {
+      e(args.head) match {
         case BooleanLiteral(false) => BooleanLiteral(false)
-        case BooleanLiteral(true) => se(And(args.tail))
+        case BooleanLiteral(true) => e(And(args.tail))
         case other => throw EvalError(typeErrorMsg(other, BooleanType))
       }
 
     case Or(args) if args.isEmpty => BooleanLiteral(false)
     case Or(args) =>
-      se(args.head) match {
+      e(args.head) match {
         case BooleanLiteral(true) => BooleanLiteral(true)
-        case BooleanLiteral(false) => se(Or(args.tail))
+        case BooleanLiteral(false) => e(Or(args.tail))
         case other => throw EvalError(typeErrorMsg(other, BooleanType))
       }
 
     case Not(arg) =>
-      se(arg) match {
+      e(arg) match {
         case BooleanLiteral(v) => BooleanLiteral(!v)
         case other => throw EvalError(typeErrorMsg(other, BooleanType))
       }
 
     case Implies(l,r) =>
-      (se(l), se(r)) match {
+      (e(l), e(r)) match {
         case (BooleanLiteral(b1),BooleanLiteral(b2)) => BooleanLiteral(!b1 || b2)
         case (le, re) => throw EvalError(typeErrorMsg(le, BooleanType))
       }
 
     case Iff(le,re) =>
-      (se(le), se(re)) match {
+      (e(le), e(re)) match {
         case (BooleanLiteral(b1),BooleanLiteral(b2)) => BooleanLiteral(b1 == b2)
         case _ => throw EvalError(typeErrorMsg(le, BooleanType))
       }
     case Equals(le,re) =>
-      val lv = se(le)
-      val rv = se(re)
+      val lv = e(le)
+      val rv = e(re)
 
       (lv,rv) match {
         case (FiniteSet(el1),FiniteSet(el2)) => BooleanLiteral(el1.toSet == el2.toSet)
@@ -177,91 +175,91 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program) extends Evalu
       }
 
     case CaseClass(cd, args) =>
-      CaseClass(cd, args.map(se(_)))
+      CaseClass(cd, args.map(e(_)))
 
     case CaseClassInstanceOf(cct, expr) =>
-      val le = se(expr)
+      val le = e(expr)
       BooleanLiteral(le.getType match {
         case CaseClassType(cd2, _) if cd2 == cct.classDef => true
         case _ => false
       })
 
     case CaseClassSelector(ct1, expr, sel) =>
-      val le = se(expr)
+      val le = e(expr)
       le match {
         case CaseClass(ct2, args) if ct1 == ct2 => args(ct1.classDef.selectorID2Index(sel))
         case _ => throw EvalError(typeErrorMsg(le, ct1))
       }
 
     case Plus(l,r) =>
-      (se(l), se(r)) match {
+      (e(l), e(r)) match {
         case (IntLiteral(i1), IntLiteral(i2)) => IntLiteral(i1 + i2)
         case (le,re) => throw EvalError(typeErrorMsg(le, Int32Type))
       }
 
     case Minus(l,r) =>
-      (se(l), se(r)) match {
+      (e(l), e(r)) match {
         case (IntLiteral(i1), IntLiteral(i2)) => IntLiteral(i1 - i2)
         case (le,re) => throw EvalError(typeErrorMsg(le, Int32Type))
       }
 
-    case UMinus(e) =>
-      se(e) match {
+    case UMinus(ex) =>
+      e(ex) match {
         case IntLiteral(i) => IntLiteral(-i)
         case re => throw EvalError(typeErrorMsg(re, Int32Type))
       }
 
     case Times(l,r) =>
-      (se(l), se(r)) match {
+      (e(l), e(r)) match {
         case (IntLiteral(i1), IntLiteral(i2)) => IntLiteral(i1 * i2)
         case (le,re) => throw EvalError(typeErrorMsg(le, Int32Type))
       }
 
     case Division(l,r) =>
-      (se(l), se(r)) match {
+      (e(l), e(r)) match {
         case (IntLiteral(i1), IntLiteral(i2)) =>
           if(i2 != 0) IntLiteral(i1 / i2) else throw RuntimeError("Division by 0.")
         case (le,re) => throw EvalError(typeErrorMsg(le, Int32Type))
       }
 
     case Modulo(l,r) =>
-      (se(l), se(r)) match {
+      (e(l), e(r)) match {
         case (IntLiteral(i1), IntLiteral(i2)) => 
           if(i2 != 0) IntLiteral(i1 % i2) else throw RuntimeError("Modulo by 0.")
         case (le,re) => throw EvalError(typeErrorMsg(le, Int32Type))
       }
     case LessThan(l,r) =>
-      (se(l), se(r)) match {
+      (e(l), e(r)) match {
         case (IntLiteral(i1), IntLiteral(i2)) => BooleanLiteral(i1 < i2)
         case (le,re) => throw EvalError(typeErrorMsg(le, Int32Type))
       }
 
     case GreaterThan(l,r) =>
-      (se(l), se(r)) match {
+      (e(l), e(r)) match {
         case (IntLiteral(i1), IntLiteral(i2)) => BooleanLiteral(i1 > i2)
         case (le,re) => throw EvalError(typeErrorMsg(le, Int32Type))
       }
 
     case LessEquals(l,r) =>
-      (se(l), se(r)) match {
+      (e(l), e(r)) match {
         case (IntLiteral(i1), IntLiteral(i2)) => BooleanLiteral(i1 <= i2)
         case (le,re) => throw EvalError(typeErrorMsg(le, Int32Type))
       }
 
     case GreaterEquals(l,r) =>
-      (se(l), se(r)) match {
+      (e(l), e(r)) match {
         case (IntLiteral(i1), IntLiteral(i2)) => BooleanLiteral(i1 >= i2)
         case (le,re) => throw EvalError(typeErrorMsg(le, Int32Type))
       }
 
     case SetUnion(s1,s2) =>
-      (se(s1), se(s2)) match {
+      (e(s1), e(s2)) match {
         case (f@FiniteSet(els1),FiniteSet(els2)) => FiniteSet((els1 ++ els2).distinct).setType(f.getType)
         case (le,re) => throw EvalError(typeErrorMsg(le, s1.getType))
       }
 
     case SetIntersection(s1,s2) =>
-      (se(s1), se(s2)) match {
+      (e(s1), e(s2)) match {
         case (f @ FiniteSet(els1), FiniteSet(els2)) => {
           val newElems = (els1.toSet intersect els2.toSet).toSeq
           val baseType = f.getType.asInstanceOf[SetType].base
@@ -271,7 +269,7 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program) extends Evalu
       }
 
     case SetDifference(s1,s2) =>
-      (se(s1), se(s2)) match {
+      (e(s1), e(s2)) match {
         case (f @ FiniteSet(els1),FiniteSet(els2)) => {
           val newElems = (els1.toSet -- els2.toSet).toSeq
           val baseType = f.getType.asInstanceOf[SetType].base
@@ -280,70 +278,68 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program) extends Evalu
         case (le,re) => throw EvalError(typeErrorMsg(le, s1.getType))
       }
 
-    case ElementOfSet(el,s) => (se(el), se(s)) match {
+    case ElementOfSet(el,s) => (e(el), e(s)) match {
       case (e, f @ FiniteSet(els)) => BooleanLiteral(els.contains(e))
       case (l,r) => throw EvalError(typeErrorMsg(r, SetType(l.getType)))
     }
-    case SubsetOf(s1,s2) => (se(s1), se(s2)) match {
+    case SubsetOf(s1,s2) => (e(s1), e(s2)) match {
       case (f@FiniteSet(els1),FiniteSet(els2)) => BooleanLiteral(els1.toSet.subsetOf(els2.toSet))
       case (le,re) => throw EvalError(typeErrorMsg(le, s1.getType))
     }
-    case SetCardinality(s) => {
-      val sr = se(s)
+    case SetCardinality(s) =>
+      val sr = e(s)
       sr match {
         case FiniteSet(els) => IntLiteral(els.size)
         case _ => throw EvalError(typeErrorMsg(sr, SetType(AnyType)))
       }
-    }
 
-    case f @ FiniteSet(els) => FiniteSet(els.map(se(_)).distinct).setType(f.getType)
+    case f @ FiniteSet(els) => FiniteSet(els.map(e(_)).distinct).setType(f.getType)
     case i @ IntLiteral(_) => i
     case b @ BooleanLiteral(_) => b
     case u @ UnitLiteral() => u
 
-    case f @ ArrayFill(length, default) => {
-      val rDefault = se(default)
-      val rLength = se(length)
+    case f @ ArrayFill(length, default) =>
+      val rDefault = e(default)
+      val rLength = e(length)
       val IntLiteral(iLength) = rLength
       FiniteArray((1 to iLength).map(_ => rDefault).toSeq)
-    }
-    case ArrayLength(a) => {
-      var ra = se(a)
+
+    case ArrayLength(a) =>
+      var ra = e(a)
       while(!ra.isInstanceOf[FiniteArray])
         ra = ra.asInstanceOf[ArrayUpdated].array
       IntLiteral(ra.asInstanceOf[FiniteArray].exprs.size)
-    }
-    case ArrayUpdated(a, i, v) => {
-      val ra = se(a)
-      val ri = se(i)
-      val rv = se(v)
+
+    case ArrayUpdated(a, i, v) =>
+      val ra = e(a)
+      val ri = e(i)
+      val rv = e(v)
 
       val IntLiteral(index) = ri
       val FiniteArray(exprs) = ra
       FiniteArray(exprs.updated(index, rv))
-    }
-    case ArraySelect(a, i) => {
-      val IntLiteral(index) = se(i)
-      val FiniteArray(exprs) = se(a)
+
+    case ArraySelect(a, i) =>
+      val IntLiteral(index) = e(i)
+      val FiniteArray(exprs) = e(a)
       try {
         exprs(index)
       } catch {
         case e : IndexOutOfBoundsException => throw RuntimeError(e.getMessage)
       }
-    }
-    case FiniteArray(exprs) => {
-      FiniteArray(exprs.map(e => se(e)))
-    }
 
-    case f @ FiniteMap(ss) => FiniteMap(ss.map{ case (k, v) => (se(k), se(v)) }.distinct).setType(f.getType)
-    case g @ MapGet(m,k) => (se(m), se(k)) match {
+    case FiniteArray(exprs) =>
+      FiniteArray(exprs.map(ex => e(ex)))
+
+    case f @ FiniteMap(ss) => FiniteMap(ss.map{ case (k, v) => (e(k), e(v)) }.distinct).setType(f.getType)
+    case g @ MapGet(m,k) => (e(m), e(k)) match {
       case (FiniteMap(ss), e) => ss.find(_._1 == e) match {
         case Some((_, v0)) => v0
         case None => throw RuntimeError("Key not found: " + e)
       }
       case (l,r) => throw EvalError(typeErrorMsg(l, MapType(r.getType, g.getType)))
     }
-    case u @ MapUnion(m1,m2) => (se(m1), se(m2)) match {
+    case u @ MapUnion(m1,m2) => (e(m1), e(m2)) match {
       case (f1@FiniteMap(ss1), FiniteMap(ss2)) => {
         val filtered1 = ss1.filterNot(s1 => ss2.exists(s2 => s2._1 == s1._1))
         val newSs = filtered1 ++ ss2
@@ -351,14 +347,13 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program) extends Evalu
       }
       case (l, r) => throw EvalError(typeErrorMsg(l, m1.getType))
     }
-    case i @ MapIsDefinedAt(m,k) => (se(m), se(k)) match {
+    case i @ MapIsDefinedAt(m,k) => (e(m), e(k)) match {
       case (FiniteMap(ss), e) => BooleanLiteral(ss.exists(_._1 == e))
       case (l, r) => throw EvalError(typeErrorMsg(l, m.getType))
     }
-    case Distinct(args) => {
-      val newArgs = args.map(se(_))
+    case Distinct(args) =>
+      val newArgs = args.map(e(_))
       BooleanLiteral(newArgs.distinct.size == newArgs.size)
-    }
 
     case gv: GenericValue =>
       gv
@@ -413,10 +408,9 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program) extends Evalu
         solver.free()
       }
 
-    case other => {
+    case other =>
       context.reporter.error("Error: don't know how to handle " + other + " in Evaluator.")
       throw EvalError("Unhandled case in Evaluator : " + other) 
-    }
   }
 
   def typeErrorMsg(tree : Expr, expected : TypeTree) : String = "Type error : expected %s, found %s.".format(expected, tree)
diff --git a/src/main/scala/leon/evaluators/TracingEvaluator.scala b/src/main/scala/leon/evaluators/TracingEvaluator.scala
index 69a173afb..2548abe07 100644
--- a/src/main/scala/leon/evaluators/TracingEvaluator.scala
+++ b/src/main/scala/leon/evaluators/TracingEvaluator.scala
@@ -9,7 +9,7 @@ import purescala.Definitions._
 import purescala.TreeOps._
 import purescala.TypeTrees._
 
-class TracingEvaluator(ctx: LeonContext, prog: Program) extends RecursiveEvaluator(ctx, prog) {
+class TracingEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int = 1000) extends RecursiveEvaluator(ctx, prog) {
   type RC = TracingRecContext
   type GC = TracingGlobalContext
 
@@ -20,7 +20,7 @@ class TracingEvaluator(ctx: LeonContext, prog: Program) extends RecursiveEvaluat
   }
 
   def initGC = {
-    val gc = new TracingGlobalContext(stepsLeft = 50000, Nil)
+    val gc = new TracingGlobalContext(stepsLeft = maxSteps, Nil)
     lastGlobalContext = Some(gc)
     gc
   }
@@ -36,21 +36,25 @@ class TracingEvaluator(ctx: LeonContext, prog: Program) extends RecursiveEvaluat
   override def e(expr: Expr)(implicit rctx: RC, gctx: GC): Expr = {
     try {
       val (res, recordedRes) = expr match {
-        case Let(i,e,b) =>
+        case Let(i,ex,b) =>
           // We record the value of the val at the position of Let, not the value of the body.
-          val first = se(e)
-          val res = se(b)(rctx.withNewVar(i, first), gctx)
+          val first = e(ex)
+          val res = e(b)(rctx.withNewVar(i, first), gctx)
           (res, first)
 
         case fi @ FunctionInvocation(tfd, args) =>
+          if (gctx.stepsLeft < 0) {
+            throw RuntimeError("Exceeded number of allocated methods calls ("+gctx.maxSteps+")")
+          }
+          gctx.stepsLeft -= 1
 
-          val evArgs = args.map(a => se(a))
+          val evArgs = args.map(a => e(a))
 
           // build a mapping for the function...
           val frame = new TracingRecContext((tfd.params.map(_.id) zip evArgs).toMap, rctx.tracingFrames-1)
 
           if(tfd.hasPrecondition) {
-            se(matchToIfThenElse(tfd.precondition.get))(frame, gctx) match {
+            e(matchToIfThenElse(tfd.precondition.get))(frame, gctx) match {
               case BooleanLiteral(true) =>
               case BooleanLiteral(false) =>
                 throw RuntimeError("Precondition violation for " + tfd.id.name + " reached in evaluation.: " + tfd.precondition.get)
@@ -63,7 +67,7 @@ class TracingEvaluator(ctx: LeonContext, prog: Program) extends RecursiveEvaluat
           }
 
           val body = tfd.body.getOrElse(rctx.mappings(tfd.id))
-          val callResult = se(matchToIfThenElse(body))(frame, gctx)
+          val callResult = e(matchToIfThenElse(body))(frame, gctx)
 
           if(tfd.hasPostcondition) {
             val (id, post) = tfd.postcondition.get
@@ -71,7 +75,7 @@ class TracingEvaluator(ctx: LeonContext, prog: Program) extends RecursiveEvaluat
             val freshResID = FreshIdentifier("result").setType(tfd.returnType)
             val postBody = replace(Map(Variable(id) -> Variable(freshResID)), matchToIfThenElse(post))
 
-            se(matchToIfThenElse(post))(frame.withNewVar(id, callResult), gctx) match {
+            e(matchToIfThenElse(post))(frame.withNewVar(id, callResult), gctx) match {
               case BooleanLiteral(true) =>
               case BooleanLiteral(false) => throw RuntimeError("Postcondition violation for " + tfd.id.name + " reached in evaluation.")
               case other => throw EvalError(typeErrorMsg(other, BooleanType))
diff --git a/testcases/case-studies/Compiler.scala b/testcases/case-studies/Compiler.scala
index 6062dff35..f0d70e347 100644
--- a/testcases/case-studies/Compiler.scala
+++ b/testcases/case-studies/Compiler.scala
@@ -10,6 +10,8 @@ object Tokens {
   case object TLT extends Token
   case object TIf extends Token
   case object TElse extends Token
+  case object TLAnd extends Token
+  case object TLOr extends Token
   case object TLeftBrace extends Token
   case object TRightBrace extends Token
   case object TLeftPar extends Token
@@ -22,6 +24,8 @@ object Trees {
   abstract class Expr
   case class Times(lhs: Expr, rhs: Expr) extends Expr
   case class Plus(lhs: Expr, rhs: Expr) extends Expr
+  case class And(lhs: Expr, rhs: Expr) extends Expr
+  case class Or(lhs: Expr, rhs: Expr) extends Expr
   case class Var(id: Int) extends Expr
   case class IntLiteral(v: Int) extends Expr
   case class LessThan(lhs: Expr, rhs: Expr) extends Expr
@@ -46,7 +50,29 @@ object Parser {
   }
 
   def parseGoal(ts: List[Token]): Option[(Expr, List[Token])] = {
-    parseLT(ts)
+    parseOr(ts)
+  }
+
+  def parseOr(ts: List[Token]): Option[(Expr, List[Token])] = {
+    parseAnd(ts) match {
+      case Some((lhs, Cons(TLOr, r))) =>
+        parseAnd(r) match {
+          case Some((rhs, ts2)) => Some((Or(lhs, rhs), ts2))
+          case None() => None()
+        }
+      case r => r
+    }
+  }
+
+  def parseAnd(ts: List[Token]): Option[(Expr, List[Token])] = {
+    parseLT(ts) match {
+      case Some((lhs, Cons(TLAnd, r))) =>
+        parseLT(r) match {
+          case Some((rhs, ts2)) => Some((And(lhs, rhs), ts2))
+          case None() => None()
+        }
+      case r => r
+    }
   }
 
   def parseLT(ts: List[Token]): Option[(Expr, List[Token])] = {
@@ -110,38 +136,36 @@ object Parser {
   }
 }
 
+object TypeChecker {
+  import Trees._
+  import Types._
+
+  def typeChecks(e: Expr, exp: Option[Type]): Boolean = e match {
+    case Times(l, r)    => (exp.getOrElse(IntType) == IntType)   && typeChecks(l, Some(IntType))  && typeChecks(r, Some(IntType))
+    case Plus(l, r)     => (exp.getOrElse(IntType) == IntType)   && typeChecks(l, Some(IntType))  && typeChecks(r, Some(IntType))
+    case And(l, r)      => (exp.getOrElse(BoolType) == BoolType) && typeChecks(l, Some(BoolType)) && typeChecks(r, Some(BoolType))
+    case Or(l, r)       => (exp.getOrElse(BoolType) == BoolType) && typeChecks(l, Some(BoolType)) && typeChecks(r, Some(BoolType))
+    case LessThan(l, r) => (exp.getOrElse(BoolType) == BoolType) && typeChecks(l, Some(IntType))  && typeChecks(r, Some(IntType))
+    case Ite(c, th, el) => typeChecks(c, Some(BoolType)) && typeChecks(th, exp) && typeChecks(el, exp)
+    case IntLiteral(_)  => exp.getOrElse(IntType) == IntType
+    case Var(_)         => exp.getOrElse(IntType) == IntType
+  }
+
+  def typeChecks(e: Expr): Boolean = typeChecks(e, None())
+}
+
 object Compiler {
   import Tokens._
   import Trees._
   import Types._
   import Parser._
+  import TypeChecker._
 
-  @proxy
-  def tokenize(s: String): List[Token] = {
-    Cons(TInt(12), Cons(TLT, Cons(TInt(32), Nil())))
-  }
 
   def parse(ts: List[Token]): Option[Expr] = {
     parsePhrase(ts)
-  } ensuring { res => res match {
-    case Some(res) => typeChecks(res, BoolType)
-    case None() => true
+  } ensuring { _ match {
+    case Some(tree) => typeChecks(tree)
+    case None()     => true
   }}
-
-  def typeChecks(e: Expr, t: Type): Boolean = e match {
-    case Times(l, r) => (t == IntType) && typeChecks(l, IntType) && typeChecks(r, IntType)
-    case Plus(l, r) => (t == IntType) && typeChecks(l, IntType) && typeChecks(r, IntType)
-    case LessThan(l, r) => (t == BoolType) && typeChecks(l, IntType) && typeChecks(r, IntType)
-    case Ite(c, th, el) => typeChecks(c, BoolType) && typeChecks(th, t) && typeChecks(el, t)
-    case IntLiteral(_) => t == IntType
-    case Var(_) => t == IntType
-  }
-
-  @proxy
-  def run(s: String) = {
-    val ts = tokenize(s)
-    val e = parse(ts)
-    e
-  }
-
 }
-- 
GitLab