diff --git a/src/main/scala/leon/UnitElimination.scala b/src/main/scala/leon/UnitElimination.scala
index 9417f7bfb97e69c43da0b9ca940af2f069d21159..005544d9be2e8fff0e86f106464c29358383d772 100644
--- a/src/main/scala/leon/UnitElimination.scala
+++ b/src/main/scala/leon/UnitElimination.scala
@@ -41,6 +41,15 @@ object UnitElimination extends Pass {
     Program(id, ObjectDef(objId, allClasses ++ newFuns, invariants))
   }
 
+  private def simplifyType(tpe: TypeTree): TypeTree = tpe match {
+    case TupleType(tpes) => tpes.map(simplifyType).filterNot{ case UnitType => true case _ => false } match {
+      case Seq() => UnitType
+      case Seq(tpe) => tpe
+      case tpes => TupleType(tpes)
+    }
+    case t => t
+  }
+
   //remove unit value as soon as possible, so expr should never be equal to a unit
   private def removeUnit(expr: Expr): Expr = {
     assert(expr.getType != UnitType)