From c17a3ee180f4ad6f9c10571cd0d5500d8224baff Mon Sep 17 00:00:00 2001
From: Philippe Suter <philippe.suter@gmail.com>
Date: Sun, 11 Jul 2010 17:54:29 +0000
Subject: [PATCH] now with NAryOperator extractor

---
 src/purescala/Analysis.scala |  16 +++---
 src/purescala/Trees.scala    | 100 +++++++++++++++++++----------------
 2 files changed, 61 insertions(+), 55 deletions(-)

diff --git a/src/purescala/Analysis.scala b/src/purescala/Analysis.scala
index 497a2de06..bd8dc3252 100644
--- a/src/purescala/Analysis.scala
+++ b/src/purescala/Analysis.scala
@@ -117,17 +117,17 @@ class Analysis(val program: Program) {
       }
 
       import Analysis._
-      //reporter.info("Before unrolling:")
-      //reporter.info(expandLets(withPrec))
+      // reporter.info("Before unrolling:")
+      // reporter.info(expandLets(withPrec))
       val expr0 = unrollRecursiveFunctions(program, withPrec, Settings.unrollingLevel)
-      //reporter.info("Before inlining:")
-      //reporter.info(expandLets(expr0))
+      // reporter.info("Before inlining:")
+      // reporter.info(expandLets(expr0))
       val expr1 = inlineFunctionsAndContracts(program, expr0)
-      //reporter.info("Before PM-rewriting:")
-      //reporter.info(expandLets(expr1))
+      // reporter.info("Before PM-rewriting:")
+      // reporter.info(expandLets(expr1))
       val expr2 = rewriteSimplePatternMatching(expr1)
-      //reporter.info("After PM-rewriting:")
-      //reporter.info(expandLets(expr2))
+      // reporter.info("After PM-rewriting:")
+      // reporter.info(expandLets(expr2))
       expr2
     }
   }
diff --git a/src/purescala/Trees.scala b/src/purescala/Trees.scala
index 54d072efa..91d0c15de 100644
--- a/src/purescala/Trees.scala
+++ b/src/purescala/Trees.scala
@@ -238,6 +238,7 @@ object Trees {
 
   object UnaryOperator {
     def unapply(expr: Expr) : Option[(Expr,(Expr)=>Expr)] = expr match {
+      case Not(t) => Some((t,Not(_)))
       case IsEmptySet(t) => Some((t,IsEmptySet))
       case IsEmptyMultiset(t) => Some((t,IsEmptyMultiset))
       case SetCardinality(t) => Some((t,SetCardinality))
@@ -247,6 +248,7 @@ object Trees {
       case Cdr(t) => Some((t,Cdr))
       case SetMin(s) => Some((s,SetMin))
       case SetMax(s) => Some((s,SetMax))
+      case CaseClassSelector(e, sel) => Some((e, CaseClassSelector(_, sel)))
       case _ => None
     }
   }
@@ -287,6 +289,18 @@ object Trees {
     }
   }
 
+  object NAryOperator {
+    def unapply(expr: Expr) : Option[(Seq[Expr],(Seq[Expr])=>Expr)] = expr match {
+      case FunctionInvocation(fd, args) => Some((args, FunctionInvocation(fd, _)))
+      case CaseClass(cd, args) => Some((args, CaseClass(cd, _)))
+      case And(args) => Some((args, And(_)))
+      case Or(args) => Some((args, Or(_)))
+      case FiniteSet(args) => Some((args, FiniteSet))
+      case FiniteMultiset(args) => Some((args, FiniteMultiset))
+      case _ => None
+    }
+  }
+
   def negate(expr: Expr) : Expr = expr match {
     case Let(i,b,e) => Let(i,b,negate(e))
     case Not(e) => e
@@ -333,7 +347,7 @@ object Trees {
         else
           l
       }
-      case f @ FunctionInvocation(fd, args) => {
+      case n @ NAryOperator(args, recons) => {
         var change = false
         val rargs = args.map(a => {
           val ra = rec(a)
@@ -345,21 +359,9 @@ object Trees {
           }            
         })
         if(change)
-          FunctionInvocation(fd, rargs).setType(f.getType)
-        else
-          f
-      }
-      case i @ IfExpr(t1,t2,t3) => IfExpr(rec(t1),rec(t2),rec(t3)).setType(i.getType)
-      case m @ MatchExpr(scrut,cses) => MatchExpr(rec(scrut), cses.map(inCase(_))).setType(m.getType)
-      case And(exs) => And(exs.map(rec(_)))
-      case Or(exs) => Or(exs.map(rec(_)))
-      case Not(e) => Not(rec(e))
-      case u @ UnaryOperator(t,recons) => {
-        val r = rec(t)
-        if(r != t)
-          recons(r).setType(u.getType)
+          recons(rargs).setType(n.getType)
         else
-          u
+          n
       }
       case b @ BinaryOperator(t1,t2,recons) => {
         val r1 = rec(t1)
@@ -369,20 +371,24 @@ object Trees {
         else
           b
       }
-      case c @ CaseClass(cd, args) => {
-        CaseClass(cd, args.map(rec(_))).setType(c.getType)
-      }
-      case c @ CaseClassSelector(cc, sel) => {
-        val rc = rec(cc)
-        if(rc != cc)
-          CaseClassSelector(rc, sel).setType(c.getType)
+      case u @ UnaryOperator(t,recons) => {
+        val r = rec(t)
+        if(r != t)
+          recons(r).setType(u.getType)
         else
-          c
+          u
       }
-      case f @ FiniteSet(elems) => {
-        FiniteSet(elems.map(rec(_))).setType(f.getType)
+      case i @ IfExpr(t1,t2,t3) => {
+        val r1 = rec(t1)
+        val r2 = rec(t2)
+        val r3 = rec(t3)
+        if(r1 != t1 || r2 != t2 || r3 != t3)
+          IfExpr(rec(t1),rec(t2),rec(t3)).setType(i.getType)
+        else
+          i
       }
-      case t: Terminal => t
+      case m @ MatchExpr(scrut,cses) => MatchExpr(rec(scrut), cses.map(inCase(_))).setType(m.getType)
+      case t if t.isInstanceOf[Terminal] => t
       case unhandled => scala.Predef.error("Non-terminal case should be handled in searchAndApply: " + unhandled)
     }
 
@@ -456,18 +462,23 @@ object Trees {
     def rec(ex: Expr, s: Map[Identifier,Expr]) : Expr = ex match {
       case v @ Variable(id) if s.isDefinedAt(id) => rec(s(id), s)
       case l @ Let(i,e,b) => rec(b, s + (i -> rec(e, s)))
-      case f @ FunctionInvocation(fd, args) => FunctionInvocation(fd, args.map(rec(_, s))).setType(f.getType)
       case i @ IfExpr(t1,t2,t3) => IfExpr(rec(t1, s),rec(t2, s),rec(t3, s)).setType(i.getType)
       case m @ MatchExpr(scrut,cses) => MatchExpr(rec(scrut, s), cses.map(inCase(_, s))).setType(m.getType)
-      case And(exs) => And(exs.map(rec(_, s)))
-      case Or(exs) => Or(exs.map(rec(_, s)))
-      case Not(e) => Not(rec(e, s))
-      case u @ UnaryOperator(t,recons) => {
-        val r = rec(t, s)
-        if(r != t)
-          recons(r).setType(u.getType)
+      case n @ NAryOperator(args, recons) => {
+        var change = false
+        val rargs = args.map(a => {
+          val ra = rec(a, s)
+          if(ra != a) {
+            change = true  
+            ra
+          } else {
+            a
+          }            
+        })
+        if(change)
+          recons(rargs).setType(n.getType)
         else
-          u
+          n
       }
       case b @ BinaryOperator(t1,t2,recons) => {
         val r1 = rec(t1, s)
@@ -477,20 +488,15 @@ object Trees {
         else
           b
       }
-      case c @ CaseClass(cd, args) => {
-        CaseClass(cd, args.map(rec(_, s))).setType(c.getType)
-      }
-      case c @ CaseClassSelector(cc, sel) => {
-        val rc = rec(cc, s)
-        if(rc != cc)
-          CaseClassSelector(rc, sel).setType(c.getType)
+      case u @ UnaryOperator(t,recons) => {
+        val r = rec(t, s)
+        if(r != t)
+          recons(r).setType(u.getType)
         else
-          c
-      }
-      case f @ FiniteSet(elems) => {
-        FiniteSet(elems.map(rec(_, s))).setType(f.getType)
+          u
       }
-      case _ => ex
+      case t if t.isInstanceOf[Terminal] => t
+      case unhandled => scala.Predef.error("Unhandled case in expandLets: " + unhandled)
     }
 
     def inCase(cse: MatchCase, s: Map[Identifier,Expr]) : MatchCase = cse match {
-- 
GitLab