From 9154d71e9038c56b5d6aa5dca201ebaecd6e3fab Mon Sep 17 00:00:00 2001
From: Etienne Kneuss <ekneuss@gmail.com>
Date: Thu, 25 Oct 2012 01:47:20 +0200
Subject: [PATCH] Decoupling of trees and tree operations

---
 .../scala/leon/Z3ModelReconstruction.scala    |   1 +
 src/main/scala/leon/isabelle/Main.scala       |   1 +
 src/main/scala/leon/purescala/TreeOps.scala   | 223 ++++++++++++++++++
 src/main/scala/leon/purescala/Trees.scala     | 223 ------------------
 .../scala/leon/testgen/TestGeneration.scala   |   1 +
 5 files changed, 226 insertions(+), 223 deletions(-)

diff --git a/src/main/scala/leon/Z3ModelReconstruction.scala b/src/main/scala/leon/Z3ModelReconstruction.scala
index 913da2fa3..87c75ddd5 100644
--- a/src/main/scala/leon/Z3ModelReconstruction.scala
+++ b/src/main/scala/leon/Z3ModelReconstruction.scala
@@ -4,6 +4,7 @@ import z3.scala._
 import purescala.Common._
 import purescala.Definitions._
 import purescala.Trees._
+import purescala.TreeOps._
 import purescala.TypeTrees._
 import Extensions._
 
diff --git a/src/main/scala/leon/isabelle/Main.scala b/src/main/scala/leon/isabelle/Main.scala
index 75e1c86fb..ac95c8095 100644
--- a/src/main/scala/leon/isabelle/Main.scala
+++ b/src/main/scala/leon/isabelle/Main.scala
@@ -8,6 +8,7 @@ import leon.purescala.Common.Identifier
 import leon.purescala.Definitions._
 import leon.purescala.PrettyPrinter
 import leon.purescala.Trees._
+import leon.purescala.TreeOps._
 import leon.purescala.Extractors._
 import leon.purescala.TypeTrees._
 
diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala
index 2949fa036..a0494a8e4 100644
--- a/src/main/scala/leon/purescala/TreeOps.scala
+++ b/src/main/scala/leon/purescala/TreeOps.scala
@@ -727,4 +727,227 @@ object TreeOps {
 
     rec(expr, Map.empty)
   }
+
+  private var matchConverterCache = new scala.collection.mutable.HashMap[Expr,Expr]()
+  /** Rewrites all pattern-matching expressions into if-then-else expressions,
+   * with additional error conditions. Does not introduce additional variables.
+   * We use a cache because we can. */
+  def matchToIfThenElse(expr: Expr) : Expr = {
+    val toRet = if(matchConverterCache.isDefinedAt(expr)) {
+      matchConverterCache(expr)
+    } else {
+      val converted = convertMatchToIfThenElse(expr)
+      matchConverterCache(expr) = converted
+      converted
+    }
+
+    toRet
+  }
+
+  def conditionForPattern(in: Expr, pattern: Pattern) : Expr = pattern match {
+    case WildcardPattern(_) => BooleanLiteral(true)
+    case InstanceOfPattern(_,_) => scala.sys.error("InstanceOfPattern not yet supported.")
+    case CaseClassPattern(_, ccd, subps) => {
+      assert(ccd.fields.size == subps.size)
+      val pairs = ccd.fields.map(_.id).toList zip subps.toList
+      val subTests = pairs.map(p => conditionForPattern(CaseClassSelector(ccd, in, p._1), p._2))
+      val together = And(subTests)
+      And(CaseClassInstanceOf(ccd, in), together)
+    }
+    case TuplePattern(_, subps) => {
+      val TupleType(tpes) = in.getType
+      assert(tpes.size == subps.size)
+      val subTests = subps.zipWithIndex.map{case (p, i) => conditionForPattern(TupleSelect(in, i+1).setType(tpes(i)), p)}
+      And(subTests)
+    }
+  }
+
+  private def convertMatchToIfThenElse(expr: Expr) : Expr = {
+    def mapForPattern(in: Expr, pattern: Pattern) : Map[Identifier,Expr] = pattern match {
+      case WildcardPattern(None) => Map.empty
+      case WildcardPattern(Some(id)) => Map(id -> in)
+      case InstanceOfPattern(None, _) => Map.empty
+      case InstanceOfPattern(Some(id), _) => Map(id -> in)
+      case CaseClassPattern(b, ccd, subps) => {
+        assert(ccd.fields.size == subps.size)
+        val pairs = ccd.fields.map(_.id).toList zip subps.toList
+        val subMaps = pairs.map(p => mapForPattern(CaseClassSelector(ccd, in, p._1), p._2))
+        val together = subMaps.foldLeft(Map.empty[Identifier,Expr])(_ ++ _)
+        b match {
+          case Some(id) => Map(id -> in) ++ together
+          case None => together
+        }
+      }
+      case TuplePattern(b, subps) => {
+        val TupleType(tpes) = in.getType
+        assert(tpes.size == subps.size)
+
+        val maps = subps.zipWithIndex.map{case (p, i) => mapForPattern(TupleSelect(in, i+1).setType(tpes(i)), p)}
+        val map = maps.foldLeft(Map.empty[Identifier,Expr])(_ ++ _)
+        b match {
+          case Some(id) => map + (id -> in)
+          case None => map
+        }
+      }
+    }
+
+    def rewritePM(e: Expr) : Option[Expr] = e match {
+      case m @ MatchExpr(scrut, cases) => {
+        // println("Rewriting the following PM: " + e)
+
+        val condsAndRhs = for(cse <- cases) yield {
+          val map = mapForPattern(scrut, cse.pattern)
+          val patCond = conditionForPattern(scrut, cse.pattern)
+          val realCond = cse.theGuard match {
+            case Some(g) => And(patCond, replaceFromIDs(map, g))
+            case None => patCond
+          }
+          val newRhs = replaceFromIDs(map, cse.rhs)
+          (realCond, newRhs)
+        } 
+
+        val optCondsAndRhs = if(SimplePatternMatching.isSimple(m)) {
+          // this is a hackish optimization: because we know all cases are covered, we replace the last condition by true (and that drops the check)
+          val lastExpr = condsAndRhs.last._2
+
+          condsAndRhs.dropRight(1) ++ Seq((BooleanLiteral(true),lastExpr))
+        } else {
+          condsAndRhs
+        }
+
+        val bigIte = optCondsAndRhs.foldRight[Expr](Error("non-exhaustive match").setType(bestRealType(m.getType)).setPosInfo(m))((p1, ex) => {
+          if(p1._1 == BooleanLiteral(true)) {
+            p1._2
+          } else {
+            IfExpr(p1._1, p1._2, ex).setType(m.getType)
+          }
+        })
+
+        Some(bigIte)
+      }
+      case _ => None
+    }
+    
+    searchAndReplaceDFS(rewritePM)(expr)
+  }
+
+  private var mapGetConverterCache = new scala.collection.mutable.HashMap[Expr,Expr]()
+  /** Rewrites all map accesses with additional error conditions. */
+  def mapGetWithChecks(expr: Expr) : Expr = {
+    val toRet = if (mapGetConverterCache.isDefinedAt(expr)) {
+      mapGetConverterCache(expr)
+    } else {
+      val converted = convertMapGet(expr)
+      mapGetConverterCache(expr) = converted
+      converted
+    }
+
+    toRet
+  }
+
+  private def convertMapGet(expr: Expr) : Expr = {
+    def rewriteMapGet(e: Expr) : Option[Expr] = e match {
+      case mg @ MapGet(m,k) => 
+        val ida = MapIsDefinedAt(m, k)
+        Some(IfExpr(ida, mg, Error("key not found for map access").setType(mg.getType).setPosInfo(mg)).setType(mg.getType))
+      case _ => None
+    }
+
+    searchAndReplaceDFS(rewriteMapGet)(expr)
+  }
+
+  // prec: expression does not contain match expressions
+  def measureADTChildrenDepth(expression: Expr) : Int = {
+    import scala.math.max
+
+    def rec(ex: Expr, lm: Map[Identifier,Int]) : Int = ex match {
+      case Let(i,e,b) => rec(b,lm + (i -> rec(e,lm)))
+      case Variable(id) => lm.getOrElse(id, 0)
+      case CaseClassSelector(_, e, _) => rec(e,lm) + 1
+      case NAryOperator(args, _) => if(args.isEmpty) 0 else args.map(rec(_,lm)).max
+      case BinaryOperator(e1,e2,_) => max(rec(e1,lm), rec(e2,lm))
+      case UnaryOperator(e,_) => rec(e,lm)
+      case IfExpr(c,t,e) => max(max(rec(c,lm),rec(t,lm)),rec(e,lm))
+      case t: Terminal => 0
+      case _ => scala.sys.error("Not handled in measureChildrenDepth : " + ex)
+    }
+    
+    rec(expression,Map.empty)
+  }
+
+  private val random = new scala.util.Random()
+
+  def randomValue(v: Variable) : Expr = randomValue(v.getType)
+  def simplestValue(v: Variable) : Expr = simplestValue(v.getType)
+
+  def randomValue(tpe: TypeTree) : Expr = tpe match {
+    case Int32Type => IntLiteral(random.nextInt(42))
+    case BooleanType => BooleanLiteral(random.nextBoolean())
+    case AbstractClassType(acd) =>
+      val children = acd.knownChildren
+      randomValue(classDefToClassType(children(random.nextInt(children.size))))
+    case CaseClassType(cd) =>
+      val fields = cd.fields
+      CaseClass(cd, fields.map(f => randomValue(f.getType)))
+    case _ => throw new Exception("I can't choose random value for type " + tpe)
+  }
+
+  def simplestValue(tpe: TypeTree) : Expr = tpe match {
+    case Int32Type => IntLiteral(0)
+    case BooleanType => BooleanLiteral(false)
+    case AbstractClassType(acd) => {
+      val children = acd.knownChildren
+      val simplerChildren = children.filter{
+        case ccd @ CaseClassDef(id, Some(parent), fields) =>
+          !fields.exists(vd => vd.getType match {
+            case AbstractClassType(fieldAcd) => acd == fieldAcd
+            case CaseClassType(fieldCcd) => ccd == fieldCcd
+            case _ => false
+          })
+        case _ => false
+      }
+      def orderByNumberOfFields(fst: ClassTypeDef, snd: ClassTypeDef) : Boolean = (fst, snd) match {
+        case (CaseClassDef(_, _, flds1), CaseClassDef(_, _, flds2)) => flds1.size <= flds2.size
+        case _ => true
+      }
+      val orderedChildren = simplerChildren.sortWith(orderByNumberOfFields)
+      simplestValue(classDefToClassType(orderedChildren.head))
+    }
+    case CaseClassType(ccd) =>
+      val fields = ccd.fields
+      CaseClass(ccd, fields.map(f => simplestValue(f.getType)))
+    case SetType(baseType) => EmptySet(baseType).setType(tpe)
+    case MapType(fromType, toType) => EmptyMap(fromType, toType).setType(tpe)
+    case FunctionType(fromTypes, toType) => AnonymousFunction(Seq.empty, simplestValue(toType)).setType(tpe)
+    case _ => throw new Exception("I can't choose simplest value for type " + tpe)
+  }
+
+  //guarentee that all IfExpr will be at the top level and as soon as you encounter a non-IfExpr, then no more IfExpr can be find in the sub-expressions
+  //require no-match, no-ets and only pure code
+  def hoistIte(expr: Expr): Expr = {
+    def transform(expr: Expr): Option[Expr] = expr match {
+      case uop@UnaryOperator(IfExpr(c, t, e), op) => Some(IfExpr(c, op(t).setType(uop.getType), op(e).setType(uop.getType)).setType(uop.getType))
+      case bop@BinaryOperator(IfExpr(c, t, e), t2, op) => Some(IfExpr(c, op(t, t2).setType(bop.getType), op(e, t2).setType(bop.getType)).setType(bop.getType))
+      case bop@BinaryOperator(t1, IfExpr(c, t, e), op) => Some(IfExpr(c, op(t1, t).setType(bop.getType), op(t1, e).setType(bop.getType)).setType(bop.getType))
+      case nop@NAryOperator(ts, op) => {
+        val iteIndex = ts.indexWhere{ case IfExpr(_, _, _) => true case _ => false }
+        if(iteIndex == -1) None else {
+          val (beforeIte, startIte) = ts.splitAt(iteIndex)
+          val afterIte = startIte.tail
+          val IfExpr(c, t, e) = startIte.head
+          Some(IfExpr(c,
+            op(beforeIte ++ Seq(t) ++ afterIte).setType(nop.getType),
+            op(beforeIte ++ Seq(e) ++ afterIte).setType(nop.getType)
+          ).setType(nop.getType))
+        }
+      }
+      case _ => None
+    }
+
+    def fix[A](f: (A) => A, a: A): A = {
+      val na = f(a)
+      if(a == na) a else fix(f, na)
+    }
+    fix(searchAndReplaceDFS(transform), expr)
+  }
 }
diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala
index b985b2ac3..c8e1dd18f 100644
--- a/src/main/scala/leon/purescala/Trees.scala
+++ b/src/main/scala/leon/purescala/Trees.scala
@@ -441,227 +441,4 @@ object Trees {
     val fixedType = BooleanType
   }
 
-  private var matchConverterCache = new scala.collection.mutable.HashMap[Expr,Expr]()
-  /** Rewrites all pattern-matching expressions into if-then-else expressions,
-   * with additional error conditions. Does not introduce additional variables.
-   * We use a cache because we can. */
-  def matchToIfThenElse(expr: Expr) : Expr = {
-    val toRet = if(matchConverterCache.isDefinedAt(expr)) {
-      matchConverterCache(expr)
-    } else {
-      val converted = convertMatchToIfThenElse(expr)
-      matchConverterCache(expr) = converted
-      converted
-    }
-
-    toRet
-  }
-
-  def conditionForPattern(in: Expr, pattern: Pattern) : Expr = pattern match {
-    case WildcardPattern(_) => BooleanLiteral(true)
-    case InstanceOfPattern(_,_) => scala.sys.error("InstanceOfPattern not yet supported.")
-    case CaseClassPattern(_, ccd, subps) => {
-      assert(ccd.fields.size == subps.size)
-      val pairs = ccd.fields.map(_.id).toList zip subps.toList
-      val subTests = pairs.map(p => conditionForPattern(CaseClassSelector(ccd, in, p._1), p._2))
-      val together = And(subTests)
-      And(CaseClassInstanceOf(ccd, in), together)
-    }
-    case TuplePattern(_, subps) => {
-      val TupleType(tpes) = in.getType
-      assert(tpes.size == subps.size)
-      val subTests = subps.zipWithIndex.map{case (p, i) => conditionForPattern(TupleSelect(in, i+1).setType(tpes(i)), p)}
-      And(subTests)
-    }
-  }
-
-  private def convertMatchToIfThenElse(expr: Expr) : Expr = {
-    def mapForPattern(in: Expr, pattern: Pattern) : Map[Identifier,Expr] = pattern match {
-      case WildcardPattern(None) => Map.empty
-      case WildcardPattern(Some(id)) => Map(id -> in)
-      case InstanceOfPattern(None, _) => Map.empty
-      case InstanceOfPattern(Some(id), _) => Map(id -> in)
-      case CaseClassPattern(b, ccd, subps) => {
-        assert(ccd.fields.size == subps.size)
-        val pairs = ccd.fields.map(_.id).toList zip subps.toList
-        val subMaps = pairs.map(p => mapForPattern(CaseClassSelector(ccd, in, p._1), p._2))
-        val together = subMaps.foldLeft(Map.empty[Identifier,Expr])(_ ++ _)
-        b match {
-          case Some(id) => Map(id -> in) ++ together
-          case None => together
-        }
-      }
-      case TuplePattern(b, subps) => {
-        val TupleType(tpes) = in.getType
-        assert(tpes.size == subps.size)
-
-        val maps = subps.zipWithIndex.map{case (p, i) => mapForPattern(TupleSelect(in, i+1).setType(tpes(i)), p)}
-        val map = maps.foldLeft(Map.empty[Identifier,Expr])(_ ++ _)
-        b match {
-          case Some(id) => map + (id -> in)
-          case None => map
-        }
-      }
-    }
-
-    def rewritePM(e: Expr) : Option[Expr] = e match {
-      case m @ MatchExpr(scrut, cases) => {
-        // println("Rewriting the following PM: " + e)
-
-        val condsAndRhs = for(cse <- cases) yield {
-          val map = mapForPattern(scrut, cse.pattern)
-          val patCond = conditionForPattern(scrut, cse.pattern)
-          val realCond = cse.theGuard match {
-            case Some(g) => And(patCond, replaceFromIDs(map, g))
-            case None => patCond
-          }
-          val newRhs = replaceFromIDs(map, cse.rhs)
-          (realCond, newRhs)
-        } 
-
-        val optCondsAndRhs = if(SimplePatternMatching.isSimple(m)) {
-          // this is a hackish optimization: because we know all cases are covered, we replace the last condition by true (and that drops the check)
-          val lastExpr = condsAndRhs.last._2
-
-          condsAndRhs.dropRight(1) ++ Seq((BooleanLiteral(true),lastExpr))
-        } else {
-          condsAndRhs
-        }
-
-        val bigIte = optCondsAndRhs.foldRight[Expr](Error("non-exhaustive match").setType(bestRealType(m.getType)).setPosInfo(m))((p1, ex) => {
-          if(p1._1 == BooleanLiteral(true)) {
-            p1._2
-          } else {
-            IfExpr(p1._1, p1._2, ex).setType(m.getType)
-          }
-        })
-
-        Some(bigIte)
-      }
-      case _ => None
-    }
-    
-    searchAndReplaceDFS(rewritePM)(expr)
-  }
-
-  private var mapGetConverterCache = new scala.collection.mutable.HashMap[Expr,Expr]()
-  /** Rewrites all map accesses with additional error conditions. */
-  def mapGetWithChecks(expr: Expr) : Expr = {
-    val toRet = if (mapGetConverterCache.isDefinedAt(expr)) {
-      mapGetConverterCache(expr)
-    } else {
-      val converted = convertMapGet(expr)
-      mapGetConverterCache(expr) = converted
-      converted
-    }
-
-    toRet
-  }
-
-  private def convertMapGet(expr: Expr) : Expr = {
-    def rewriteMapGet(e: Expr) : Option[Expr] = e match {
-      case mg @ MapGet(m,k) => 
-        val ida = MapIsDefinedAt(m, k)
-        Some(IfExpr(ida, mg, Error("key not found for map access").setType(mg.getType).setPosInfo(mg)).setType(mg.getType))
-      case _ => None
-    }
-
-    searchAndReplaceDFS(rewriteMapGet)(expr)
-  }
-
-  // prec: expression does not contain match expressions
-  def measureADTChildrenDepth(expression: Expr) : Int = {
-    import scala.math.max
-
-    def rec(ex: Expr, lm: Map[Identifier,Int]) : Int = ex match {
-      case Let(i,e,b) => rec(b,lm + (i -> rec(e,lm)))
-      case Variable(id) => lm.getOrElse(id, 0)
-      case CaseClassSelector(_, e, _) => rec(e,lm) + 1
-      case NAryOperator(args, _) => if(args.isEmpty) 0 else args.map(rec(_,lm)).max
-      case BinaryOperator(e1,e2,_) => max(rec(e1,lm), rec(e2,lm))
-      case UnaryOperator(e,_) => rec(e,lm)
-      case IfExpr(c,t,e) => max(max(rec(c,lm),rec(t,lm)),rec(e,lm))
-      case t: Terminal => 0
-      case _ => scala.sys.error("Not handled in measureChildrenDepth : " + ex)
-    }
-    
-    rec(expression,Map.empty)
-  }
-
-  private val random = new scala.util.Random()
-
-  def randomValue(v: Variable) : Expr = randomValue(v.getType)
-  def simplestValue(v: Variable) : Expr = simplestValue(v.getType)
-
-  def randomValue(tpe: TypeTree) : Expr = tpe match {
-    case Int32Type => IntLiteral(random.nextInt(42))
-    case BooleanType => BooleanLiteral(random.nextBoolean())
-    case AbstractClassType(acd) =>
-      val children = acd.knownChildren
-      randomValue(classDefToClassType(children(random.nextInt(children.size))))
-    case CaseClassType(cd) =>
-      val fields = cd.fields
-      CaseClass(cd, fields.map(f => randomValue(f.getType)))
-    case _ => throw new Exception("I can't choose random value for type " + tpe)
-  }
-
-  def simplestValue(tpe: TypeTree) : Expr = tpe match {
-    case Int32Type => IntLiteral(0)
-    case BooleanType => BooleanLiteral(false)
-    case AbstractClassType(acd) => {
-      val children = acd.knownChildren
-      val simplerChildren = children.filter{
-        case ccd @ CaseClassDef(id, Some(parent), fields) =>
-          !fields.exists(vd => vd.getType match {
-            case AbstractClassType(fieldAcd) => acd == fieldAcd
-            case CaseClassType(fieldCcd) => ccd == fieldCcd
-            case _ => false
-          })
-        case _ => false
-      }
-      def orderByNumberOfFields(fst: ClassTypeDef, snd: ClassTypeDef) : Boolean = (fst, snd) match {
-        case (CaseClassDef(_, _, flds1), CaseClassDef(_, _, flds2)) => flds1.size <= flds2.size
-        case _ => true
-      }
-      val orderedChildren = simplerChildren.sortWith(orderByNumberOfFields)
-      simplestValue(classDefToClassType(orderedChildren.head))
-    }
-    case CaseClassType(ccd) =>
-      val fields = ccd.fields
-      CaseClass(ccd, fields.map(f => simplestValue(f.getType)))
-    case SetType(baseType) => EmptySet(baseType).setType(tpe)
-    case MapType(fromType, toType) => EmptyMap(fromType, toType).setType(tpe)
-    case FunctionType(fromTypes, toType) => AnonymousFunction(Seq.empty, simplestValue(toType)).setType(tpe)
-    case _ => throw new Exception("I can't choose simplest value for type " + tpe)
-  }
-
-  //guarentee that all IfExpr will be at the top level and as soon as you encounter a non-IfExpr, then no more IfExpr can be find in the sub-expressions
-  //require no-match, no-ets and only pure code
-  def hoistIte(expr: Expr): Expr = {
-    def transform(expr: Expr): Option[Expr] = expr match {
-      case uop@UnaryOperator(IfExpr(c, t, e), op) => Some(IfExpr(c, op(t).setType(uop.getType), op(e).setType(uop.getType)).setType(uop.getType))
-      case bop@BinaryOperator(IfExpr(c, t, e), t2, op) => Some(IfExpr(c, op(t, t2).setType(bop.getType), op(e, t2).setType(bop.getType)).setType(bop.getType))
-      case bop@BinaryOperator(t1, IfExpr(c, t, e), op) => Some(IfExpr(c, op(t1, t).setType(bop.getType), op(t1, e).setType(bop.getType)).setType(bop.getType))
-      case nop@NAryOperator(ts, op) => {
-        val iteIndex = ts.indexWhere{ case IfExpr(_, _, _) => true case _ => false }
-        if(iteIndex == -1) None else {
-          val (beforeIte, startIte) = ts.splitAt(iteIndex)
-          val afterIte = startIte.tail
-          val IfExpr(c, t, e) = startIte.head
-          Some(IfExpr(c,
-            op(beforeIte ++ Seq(t) ++ afterIte).setType(nop.getType),
-            op(beforeIte ++ Seq(e) ++ afterIte).setType(nop.getType)
-          ).setType(nop.getType))
-        }
-      }
-      case _ => None
-    }
-
-    def fix[A](f: (A) => A, a: A): A = {
-      val na = f(a)
-      if(a == na) a else fix(f, na)
-    }
-    fix(searchAndReplaceDFS(transform), expr)
-  }
-
 }
diff --git a/src/main/scala/leon/testgen/TestGeneration.scala b/src/main/scala/leon/testgen/TestGeneration.scala
index 1b48d7257..8404294ee 100644
--- a/src/main/scala/leon/testgen/TestGeneration.scala
+++ b/src/main/scala/leon/testgen/TestGeneration.scala
@@ -3,6 +3,7 @@ package leon.testgen
 import leon.purescala.Common._
 import leon.purescala.Definitions._
 import leon.purescala.Trees._
+import leon.purescala.TreeOps._
 import leon.purescala.TypeTrees._
 import leon.purescala.ScalaPrinter
 import leon.Extensions._
-- 
GitLab