From ba235ac16d188bb171e4139ffa1d61b75a7cdfc1 Mon Sep 17 00:00:00 2001
From: Etienne Kneuss <ekneuss@gmail.com>
Date: Tue, 23 Oct 2012 19:59:29 +0200
Subject: [PATCH] Simplify Lets, fix types

---
 src/main/scala/leon/purescala/Trees.scala     | 59 ++++++++++++++++++-
 src/main/scala/leon/synthesis/Rules.scala     |  2 +-
 .../scala/leon/synthesis/Synthesizer.scala    |  2 +-
 3 files changed, 58 insertions(+), 5 deletions(-)

diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala
index 58e4f8162..808ab52e0 100644
--- a/src/main/scala/leon/purescala/Trees.scala
+++ b/src/main/scala/leon/purescala/Trees.scala
@@ -580,6 +580,16 @@ object Trees {
           fd.postcondition = fd.postcondition.map(rec(_))
           LetDef(fd, rec(b)).setType(l.getType)
         }
+
+        case lt @ LetTuple(ids, expr, body) => {
+          val re = rec(expr)
+          val rb = rec(body)
+          if (re != expr || rb != body) {
+            LetTuple(ids, re, rb).setType(lt.getType)
+          } else {
+            lt
+          }
+        }
         case n @ NAryOperator(args, recons) => {
           var change = false
           val rargs = args.map(a => {
@@ -673,6 +683,15 @@ object Trees {
           l
         })
       }
+      case l @ LetTuple(ids,e,b) => {
+        val re = rec(e)
+        val rb = rec(b)
+        applySubst(if(re != e || rb != b) {
+          LetTuple(ids,re,rb).setType(l.getType)
+        } else {
+          l
+        })
+      }
       case l @ LetVar(i,e,b) => {
         val re = rec(e)
         val rb = rec(b)
@@ -1059,11 +1078,45 @@ object Trees {
           None
         }
       }
-      case letTuple @ LetTuple(ids, e, body) =>
-      None
+      case letTuple @ LetTuple(ids, expr, body) if ids.size == 1 =>
+        simplerLet(Let(ids.head, TupleSelect(expr, 0).setType(ids.head.getType), body))
+
+      case letTuple @ LetTuple(ids, Tuple(exprs), body) =>
+
+        var newBody = body
+
+        val (remIds, remExprs) = (ids zip exprs).filter { 
+          case (id, value: Terminal) =>
+            newBody = replace(Map((Variable(id) -> value)), newBody)
+            //we replace, so we drop old
+            false
+          case (id, value) =>
+            val occurences = treeCatamorphism[Int]((e:Expr) => e match {
+              case Variable(x) if x == id => 1
+              case _ => 0
+            }, (x:Int,y:Int)=>x+y, body)
+
+            if(occurences == 0) {
+              false
+            } else if(occurences == 1) {
+              newBody = replace(Map((Variable(id) -> value)), newBody)
+              false
+            } else {
+              true
+            }
+        }.unzip
+
+
+        if (remIds.isEmpty) {
+          Some(newBody)
+        } else if (remIds.tail.isEmpty) {
+          Some(Let(remIds.head, remExprs.head, newBody))
+        } else {
+          Some(LetTuple(remIds, Tuple(remExprs), newBody))
+        }
       case _ => None 
     }
-    searchAndReplace(simplerLet)(expr)
+    searchAndReplaceDFS(simplerLet)(expr)
   }
 
   // Pulls out all let constructs to the top level, and makes sure they're
diff --git a/src/main/scala/leon/synthesis/Rules.scala b/src/main/scala/leon/synthesis/Rules.scala
index e1dab45ee..1765c44f3 100644
--- a/src/main/scala/leon/synthesis/Rules.scala
+++ b/src/main/scala/leon/synthesis/Rules.scala
@@ -58,7 +58,7 @@ class OnePoint(synth: Synthesizer) extends Rule("One-point", synth) {
           val onSuccess: List[Solution] => Solution = { 
             case List(Solution(pre, term, sc)) =>
               if (oxs.isEmpty) {
-                Solution(pre, e, sc) 
+                Solution(pre, Tuple(e :: Nil), sc) 
               } else {
                 Solution(pre, LetTuple(oxs, term, subst(x -> e, Tuple(p.xs.map(Variable(_))))), sc) 
               }
diff --git a/src/main/scala/leon/synthesis/Synthesizer.scala b/src/main/scala/leon/synthesis/Synthesizer.scala
index 8a2426f3b..daebf0f41 100644
--- a/src/main/scala/leon/synthesis/Synthesizer.scala
+++ b/src/main/scala/leon/synthesis/Synthesizer.scala
@@ -96,7 +96,7 @@ class Synthesizer(val r: Reporter, val solvers: List[Solver]) {
         val sol = synthesize(Problem(as, phi, xs), rules)
 
         info("Scala code:")
-        info(ScalaPrinter(sol.toExpr))
+        info(ScalaPrinter(simplifyLets(sol.toExpr)))
 
         a
       case _ =>
-- 
GitLab