From ba41c6b20ab2f0015b7f70b45e1aef75a4d3816f Mon Sep 17 00:00:00 2001
From: Philippe Suter <philippe.suter@gmail.com>
Date: Fri, 4 Jan 2013 19:28:21 +0100
Subject: [PATCH] Simplification of 1-uples and 0-uples.

This commit also fixes a serious bug (that apparently affected no one in
purescala/Extractors, namely the reconstruction function for Let and
LetTuple was broken).
---
 .../scala/leon/purescala/Extractors.scala     |  4 +-
 src/main/scala/leon/purescala/TreeOps.scala   | 96 +++++++++++++++++++
 src/main/scala/leon/purescala/Trees.scala     |  4 +-
 src/main/scala/leon/purescala/TypeTrees.scala | 11 ++-
 .../scala/leon/synthesis/SynthesisPhase.scala |  3 +-
 .../synthesis/heuristics/IntInduction.scala   |  2 +-
 6 files changed, 113 insertions(+), 7 deletions(-)

diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala
index 904e6e3d3..00ce83c66 100644
--- a/src/main/scala/leon/purescala/Extractors.scala
+++ b/src/main/scala/leon/purescala/Extractors.scala
@@ -69,8 +69,8 @@ object Extractors {
       case ArraySelect(t1, t2) => Some((t1, t2, ArraySelect))
       case Concat(t1,t2) => Some((t1,t2,Concat))
       case ListAt(t1,t2) => Some((t1,t2,ListAt))
-      case Let(binders, e, body) => Some((e, body, (e: Expr, b: Expr) => Let(binders, e, body))) //TODO: shouldn't be "b" instead of "body" ?
-      case LetTuple(binders, e, body) => Some((e, body, (e: Expr, b: Expr) => LetTuple(binders, e, body))) //TODO: shouldn't be "b" instead of "body" ?
+      case Let(binders, e, body) => Some((e, body, (e: Expr, b: Expr) => Let(binders, e, b)))
+      case LetTuple(binders, e, body) => Some((e, body, (e: Expr, b: Expr) => LetTuple(binders, e, b)))
       case (ex: BinaryExtractable) => ex.extract
       case _ => None
     }
diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala
index c3cc7809d..4a5ca356c 100644
--- a/src/main/scala/leon/purescala/TreeOps.scala
+++ b/src/main/scala/leon/purescala/TreeOps.scala
@@ -1254,6 +1254,102 @@ object TreeOps {
     rec(expr, Nil)
   }
 
+  // Eliminates tuples of arity 0 and 1. This function also affects types!
+  // Only rewrites local fundefs (i.e. LetDef's).
+  def rewriteTuples(expr: Expr) : Expr = {
+    def mapType(tt : TypeTree) : Option[TypeTree] = tt match {
+      case TupleType(ts) => ts.size match {
+        case 0 => Some(UnitType)
+        case 1 => Some(ts(0))
+        case _ =>
+          val tss = ts.map(mapType)
+          if(tss.exists(_.isDefined)) {
+            Some(TupleType((tss zip ts).map(p => p._1.getOrElse(p._2))))
+          } else {
+            None
+          }
+      }
+      case ListType(t)           => mapType(t).map(ListType(_))
+      case SetType(t)            => mapType(t).map(SetType(_))
+      case MultisetType(t)       => mapType(t).map(MultisetType(_))
+      case ArrayType(t)          => mapType(t).map(ArrayType(_))
+      case MapType(f,t)          => 
+        val (f2,t2) = (mapType(f),mapType(t))
+        if(f2.isDefined || t2.isDefined) {
+          Some(MapType(f2.getOrElse(f), t2.getOrElse(t)))
+        } else {
+          None
+        }
+      case a : AbstractClassType => None
+      case c : CaseClassType     =>
+        // This is really just one big assertion. We don't rewrite class defs.
+        val ccd = c.classDef
+        val fieldTypes = ccd.fields.map(_.tpe)
+        if(fieldTypes.exists(t => t match {
+          case TupleType(ts) if ts.size <= 1 => true
+          case _ => false
+        })) {
+          scala.sys.error("Cannot rewrite case class def that contains degenerate tuple types.")
+        } else {
+          None
+        }
+      case Untyped | AnyType | BottomType | BooleanType | Int32Type | UnitType => None  
+    }
+
+    var funDefMap = Map.empty[FunDef,FunDef]
+
+    def fd2fd(funDef : FunDef) : FunDef = funDefMap.get(funDef) match {
+      case Some(fd) => fd
+      case None =>
+        if(funDef.args.map(vd => mapType(vd.tpe)).exists(_.isDefined)) {
+          scala.sys.error("Cannot rewrite function def that takes degenerate tuple arguments,")
+        }
+        val newFD = mapType(funDef.returnType) match {
+          case None => funDef
+          case Some(rt) =>
+            val fd = new FunDef(FreshIdentifier(funDef.id.name, true), rt, funDef.args)
+            // These will be taken care of in the recursive traversal.
+            fd.body = funDef.body
+            fd.precondition = funDef.precondition
+            fd.postcondition = funDef.postcondition
+            fd
+        }
+        funDefMap = funDefMap.updated(funDef, newFD)
+        newFD
+    }
+
+    def pre(e : Expr) : Expr = e match {
+      case Tuple(Seq()) => UnitLiteral
+
+      case Tuple(Seq(s)) => pre(s)
+
+      case ts @ TupleSelect(t, 1) => t.getType match {
+        case TupleOneType(_) => pre(t)
+        case _ => ts
+      }
+
+      case LetTuple(bs, v, bdy) if bs.size == 1 =>
+        Let(bs(0), v, bdy)
+
+      case l @ LetDef(fd, bdy) =>
+        LetDef(fd2fd(fd), bdy)
+
+      case r @ ResultVariable() =>
+        mapType(r.getType).map { newType =>
+          ResultVariable().setType(newType)
+        } getOrElse {
+          r
+        }
+
+      case FunctionInvocation(fd, args) =>
+        FunctionInvocation(fd2fd(fd), args)
+
+      case _ => e
+    }
+
+    simplePreTransform(pre)(expr)
+  }
+
   def formulaSize(e: Expr): Int = e match {
     case t: Terminal =>
       1
diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala
index 4f88d3eb8..4e5e0fea6 100644
--- a/src/main/scala/leon/purescala/Trees.scala
+++ b/src/main/scala/leon/purescala/Trees.scala
@@ -47,7 +47,7 @@ object Trees {
   case class FunctionInvocation(funDef: FunDef, args: Seq[Expr]) extends Expr with FixedType with ScalacPositional {
     val fixedType = funDef.returnType
 
-    funDef.args.zip(args).foreach{ case (a, c) => typeCheck(c, a.tpe) }
+    funDef.args.zip(args).foreach { case (a, c) => typeCheck(c, a.tpe) }
   }
   case class IfExpr(cond: Expr, then: Expr, elze: Expr) extends Expr 
 
@@ -67,7 +67,7 @@ object Trees {
         assert(index <= ts.size)
         ts(index - 1)
 
-      case _ => assert(false); Untyped
+      case _ => scala.sys.error("Applying TupleSelect on a non-tuple tree [%s] of type [%s].".format(tuple, tuple.getType))
     }
   }
 
diff --git a/src/main/scala/leon/purescala/TypeTrees.scala b/src/main/scala/leon/purescala/TypeTrees.scala
index 6dd49f749..7d2897ad5 100644
--- a/src/main/scala/leon/purescala/TypeTrees.scala
+++ b/src/main/scala/leon/purescala/TypeTrees.scala
@@ -176,7 +176,16 @@ object TypeTrees {
     //  case (tt: TupleType) => Some(tt.bases)
     //  case _ => None
     //}
-    def unapply(tt: TupleType): Option[Seq[TypeTree]] = Some(tt.bases)
+    def unapply(tt: TupleType): Option[Seq[TypeTree]] = if(tt == null) None else Some(tt.bases)
+  }
+  object TupleOneType {
+    def unapply(tt : TupleType) : Option[TypeTree] = if(tt == null) None else {
+      if(tt.bases.size == 1) {
+        Some(tt.bases.head)
+      } else {
+        None
+      }
+    }
   }
 
   case class ListType(base: TypeTree) extends TypeTree
diff --git a/src/main/scala/leon/synthesis/SynthesisPhase.scala b/src/main/scala/leon/synthesis/SynthesisPhase.scala
index 0d2701dbe..7adb83366 100644
--- a/src/main/scala/leon/synthesis/SynthesisPhase.scala
+++ b/src/main/scala/leon/synthesis/SynthesisPhase.scala
@@ -122,7 +122,8 @@ object SynthesisPhase extends LeonPhase[Program, Program] {
       decomposeIfs _,
       patternMatchReconstruction _,
       simplifyTautologies(uninterpretedZ3)(_),
-      simplifyLets _
+      simplifyLets _,
+      rewriteTuples _
     )
 
     def simplify(e: Expr): Expr = simplifiers.foldLeft(e){ (x, sim) => sim(x) }
diff --git a/src/main/scala/leon/synthesis/heuristics/IntInduction.scala b/src/main/scala/leon/synthesis/heuristics/IntInduction.scala
index a882fe32b..8151341a3 100644
--- a/src/main/scala/leon/synthesis/heuristics/IntInduction.scala
+++ b/src/main/scala/leon/synthesis/heuristics/IntInduction.scala
@@ -38,7 +38,7 @@ case object IntInduction extends Rule("Int Induction") with Heuristic {
 
             val newFun = new FunDef(FreshIdentifier("rec", true), tpe, Seq(VarDecl(inductOn, inductOn.getType)))
             newFun.precondition = Some(preIn)
-            newFun.postcondition = Some(LetTuple(p.xs.toSeq, ResultVariable(), p.phi))
+            newFun.postcondition = Some(LetTuple(p.xs.toSeq, ResultVariable().setType(tpe), p.phi))
 
             newFun.body = Some(
               IfExpr(Equals(Variable(inductOn), IntLiteral(0)),
-- 
GitLab