diff --git a/src/main/scala/leon/purescala/Expressions.scala b/src/main/scala/leon/purescala/Expressions.scala
index 7550b0304dd8b03d665b64e9e11ec6b9f2dd59bd..04d91cb80b0c0e2b67a21482a2d18c9df003530e 100644
--- a/src/main/scala/leon/purescala/Expressions.scala
+++ b/src/main/scala/leon/purescala/Expressions.scala
@@ -791,7 +791,7 @@ object Expressions {
     *
     * @param exprs The expressions in the tuple
     */
-  case class Tuple (exprs: Seq[Expr]) extends Expr {
+  case class Tuple(exprs: Seq[Expr]) extends Expr {
     require(exprs.size >= 2)
     val getType = TupleType(exprs.map(_.getType)).unveilUntyped
   }
diff --git a/src/main/scala/leon/purescala/Types.scala b/src/main/scala/leon/purescala/Types.scala
index 0b02ca9e3daea630aafb4000a7b2d768c1b6dc5a..1edf989ae8d85f2fd70d874ee09c5bb74a87f4e9 100644
--- a/src/main/scala/leon/purescala/Types.scala
+++ b/src/main/scala/leon/purescala/Types.scala
@@ -72,7 +72,7 @@ object Types {
    * If you are not sure about the requirement, 
    * you should use tupleTypeWrap in purescala.Constructors
    */
-  case class TupleType (bases: Seq[TypeTree]) extends TypeTree {
+  case class TupleType(bases: Seq[TypeTree]) extends TypeTree {
     val dimension: Int = bases.length
     require(dimension >= 2)
   }
diff --git a/src/main/scala/leon/utils/PreprocessingPhase.scala b/src/main/scala/leon/utils/PreprocessingPhase.scala
index 76f54169a7ab1548ded751abe6ef1b583f8928eb..558f4a32c1c19441126bdaa28e63823da8a49650 100644
--- a/src/main/scala/leon/utils/PreprocessingPhase.scala
+++ b/src/main/scala/leon/utils/PreprocessingPhase.scala
@@ -7,7 +7,7 @@ import leon.purescala._
 import leon.purescala.Definitions.Program
 import leon.solvers.isabelle.AdaptationPhase
 import leon.verification.InjectAsserts
-import leon.xlang.{NoXLangFeaturesChecking, XLangDesugaringPhase}
+import leon.xlang.{NoXLangFeaturesChecking, XLangDesugaringPhase, XLangCleanupPhase}
 
 class PreprocessingPhase(genc: Boolean = false) extends LeonPhase[Program, Program] {
 
@@ -35,6 +35,7 @@ class PreprocessingPhase(genc: Boolean = false) extends LeonPhase[Program, Progr
     def pipeEnd = (
       InjectAsserts  andThen
       FunctionClosure andThen
+      XLangCleanupPhase andThen
       AdaptationPhase
     ) when (!genc)
 
diff --git a/src/main/scala/leon/xlang/XLangCleanupPhase.scala b/src/main/scala/leon/xlang/XLangCleanupPhase.scala
new file mode 100644
index 0000000000000000000000000000000000000000..10feb89fe68ad880300c0b0f88387d280ee2dd8a
--- /dev/null
+++ b/src/main/scala/leon/xlang/XLangCleanupPhase.scala
@@ -0,0 +1,204 @@
+/* Copyright 2009-2016 EPFL, Lausanne */
+
+package leon
+package xlang
+
+import purescala.Common._
+import purescala.Definitions._
+import purescala.DefinitionTransformer
+import purescala.DefOps._
+import purescala.Expressions._
+import purescala.Extractors._
+import purescala.Constructors._
+import purescala.Types._
+
+/** Cleanup the program after running XLang desugaring.
+  *
+  * This functions simplifies away typical pattern of expressions
+  * that can be generated during xlang desugaring phase. The most
+  * common case is the generation of function returning tuple with
+  * Unit in it, which can be safely eliminated.
+  */
+object XLangCleanupPhase extends TransformationPhase {
+
+  val name = "xlang cleanup"
+  val description = "Cleanup program after running xlang desugaring"
+
+  //private var fun2FreshFun: Map[FunDef, FunDef] = Map()
+  //private var id2FreshId: Map[Identifier, Identifier] = Map()
+
+  override def apply(ctx: LeonContext, program: Program): Program = {
+
+    val transformer = new DefinitionTransformer {
+      override def transformType(tpe: TypeTree): Option[TypeTree] = tpe match {
+        case (tt: TupleType) if tt.bases.exists(_ == UnitType) => 
+          Some(tupleTypeWrap(tt.bases.filterNot(_ == UnitType)))
+        case _ => None
+      }
+
+      override def transformExpr(expr: Expr)(implicit bindings: Map[Identifier, Identifier]): Option[Expr] = expr match {
+        case sel@TupleSelect(IsTyped(t, TupleType(bases)), index) =>
+          if(bases(index-1) == UnitType) 
+            Some(UnitLiteral())
+          else {
+            val nbUnitsUntilIndex = bases.take(index).count(_ == UnitType)
+            if(nbUnitsUntilIndex == 0)
+              None
+            else if(bases.count(_ != UnitType) == 1)
+              Some(t)
+            else
+              Some(TupleSelect(t, index - nbUnitsUntilIndex).copiedFrom(sel))
+          }
+        case tu@Tuple(es) if es.exists(_.getType == UnitType) => 
+          Some(tupleWrap(es.filterNot(_.getType == UnitType)).copiedFrom(tu))
+        case let@Let(id, IsTyped(t, tt@TupleType(bases)), rest) if bases.exists(_.getType == UnitType) =>
+          val ntt = tupleTypeFilterUnits(tt)
+          val nid = id.duplicate(tpe=ntt)
+          Some(Let(nid, t, transform(rest)(bindings + (id -> nid))).copiedFrom(let))
+
+        case _ => None
+      }
+    }
+
+    val cdsMap = program.definedClasses.map(cd => cd -> transformer.transform(cd)).toMap
+    val fdsMap = program.definedFunctions.map(fd => fd -> transformer.transform(fd)).toMap
+    val pgm = replaceDefsInProgram(program)(fdsMap, cdsMap)
+    pgm
+  }
+
+  private def tupleTypeFilterUnits(tt: TupleType): TypeTree = tupleTypeWrap(tt.bases.filterNot(_ == UnitType))
+}
+
+//    val newUnits = pgm.units map { u => u.copy(defs = u.defs.map { 
+//      case m: ModuleDef =>
+//        fun2FreshFun = Map()
+//        val allFuns = m.definedFunctions
+//        //first introduce new signatures without Unit parameters
+//        allFuns.foreach(fd => {
+//          if(fd.returnType != UnitType && fd.params.exists(vd => vd.getType == UnitType)) {
+//            val freshFunDef = fd.duplicate(params = fd.params.filterNot(vd => vd.getType == UnitType))
+//            fun2FreshFun += (fd -> freshFunDef)
+//          } else {
+//            fun2FreshFun += (fd -> fd) //this will make the next step simpler
+//          }
+//        })
+//
+//        //then apply recursively to the bodies
+//        val newFuns = allFuns.collect{ case fd if fd.returnType != UnitType =>
+//          val newFd = fun2FreshFun(fd)
+//          newFd.fullBody = removeUnit(fd.fullBody)
+//          newFd
+//        }
+//
+//        ModuleDef(m.id, m.definedClasses ++ newFuns, m.isPackageObject )
+//      case d =>
+//        d
+//    })}
+//
+//
+//    Program(newUnits)
+//  }
+//
+//  private def simplifyType(tpe: TypeTree): TypeTree = tpe match {
+//    case TupleType(tpes) => tupleTypeWrap(tpes.map(simplifyType).filterNot{ _ == UnitType })
+//    case t => t
+//  }
+//
+//  //remove unit value as soon as possible, so expr should never be equal to a unit
+//  private def removeUnit(expr: Expr): Expr = {
+//    assert(expr.getType != UnitType)
+//    expr match {
+//      case fi@FunctionInvocation(tfd, args) =>
+//        val newArgs = args.filterNot(arg => arg.getType == UnitType)
+//        FunctionInvocation(fun2FreshFun(tfd.fd).typed(tfd.tps), newArgs).setPos(fi)
+//
+//      case IsTyped(Tuple(args), TupleType(tpes)) =>
+//        val newArgs = tpes.zip(args).collect {
+//          case (tp, arg) if tp != UnitType => arg
+//        }
+//        tupleWrap(newArgs.map(removeUnit)) // @mk: FIXME this may actually return a Unit, is that cool?
+//
+//      case ts@TupleSelect(t, index) =>
+//        val TupleType(tpes) = t.getType
+//        val simpleTypes = tpes map simplifyType
+//        val newArity = tpes.count(_ != UnitType)
+//        val newIndex = simpleTypes.take(index).count(_ != UnitType)
+//        tupleSelect(removeUnit(t), newIndex, newArity)
+//
+//      case Let(id, e, b) =>
+//        if(id.getType == UnitType)
+//          removeUnit(b)
+//        else {
+//          id.getType match {
+//            case TupleType(tpes) if tpes.contains(UnitType) => {
+//              val newTupleType = tupleTypeWrap(tpes.filterNot(_ == UnitType))
+//              val freshId = FreshIdentifier(id.name, newTupleType)
+//              id2FreshId += (id -> freshId)
+//              val newBody = removeUnit(b)
+//              id2FreshId -= id
+//              Let(freshId, removeUnit(e), newBody)
+//            }
+//            case _ => Let(id, removeUnit(e), removeUnit(b))
+//          }
+//        }
+//
+//      case LetDef(fds, b) =>
+//        val nonUnits = fds.filter(fd => fd.returnType != UnitType)
+//        if(nonUnits.isEmpty) {
+//          removeUnit(b)
+//        } else {
+//          val fdtoFreshFd = for(fd <- nonUnits) yield {
+//            val m = if(fd.params.exists(vd => vd.getType == UnitType)) {
+//              val freshFunDef = fd.duplicate(params = fd.params.filterNot(vd => vd.getType == UnitType))
+//              fd -> freshFunDef
+//            } else {
+//              fd -> fd
+//            }
+//            fun2FreshFun += m
+//            m
+//          }
+//          for((fd, freshFunDef) <- fdtoFreshFd) {
+//            if(fd.params.exists(vd => vd.getType == UnitType)) {
+//              freshFunDef.fullBody = removeUnit(fd.fullBody)
+//            } else {
+//              fd.body = fd.body.map(b => removeUnit(b))
+//            }
+//          }
+//          val rest = removeUnit(b)
+//          val newFds = for((fd, freshFunDef) <- fdtoFreshFd) yield {
+//            fun2FreshFun -= fd
+//            if(fd.params.exists(vd => vd.getType == UnitType)) {
+//              freshFunDef
+//            } else {
+//              fd
+//            }
+//          }
+//          
+//          letDef(newFds, rest)
+//        }
+//
+//      case ite@IfExpr(cond, tExpr, eExpr) =>
+//        val thenRec = removeUnit(tExpr)
+//        val elseRec = removeUnit(eExpr)
+//        IfExpr(removeUnit(cond), thenRec, elseRec)
+//
+//      case v @ Variable(id) =>
+//        if(id2FreshId.isDefinedAt(id))
+//          Variable(id2FreshId(id))
+//        else v
+//
+//      case m @ MatchExpr(scrut, cses) =>
+//        val scrutRec = removeUnit(scrut)
+//        val csesRec = cses.map{ cse =>
+//          MatchCase(cse.pattern, cse.optGuard map removeUnit, removeUnit(cse.rhs))
+//        }
+//        matchExpr(scrutRec, csesRec).setPos(m)
+//
+//      case Operator(args, recons) =>
+//        recons(args.map(removeUnit))
+//
+//      case _ => sys.error("not supported: " + expr)
+//    }
+//  }
+//
+//}