diff --git a/src/main/scala/leon/Evaluator.scala b/src/main/scala/leon/Evaluator.scala
index 47337af61e94395275fa76f349958f08856e00b9..8767458c2251b0d8cdeb8e2a87af9c6f47306a8f 100644
--- a/src/main/scala/leon/Evaluator.scala
+++ b/src/main/scala/leon/Evaluator.scala
@@ -247,6 +247,7 @@ object Evaluator {
         case e @ EmptySet(_) => e
         case i @ IntLiteral(_) => i
         case b @ BooleanLiteral(_) => b
+        case u @ UnitLiteral => u
 
         case f @ ArrayMake(default) => {
           val rDefault = rec(ctx, default)
diff --git a/src/main/scala/leon/FairZ3Solver.scala b/src/main/scala/leon/FairZ3Solver.scala
index 3ab173b7d855c76e1d614200cbab8dc3b9078eaf..ea2681cb4dd76c4f487b12fee6b7d94b83df3644 100644
--- a/src/main/scala/leon/FairZ3Solver.scala
+++ b/src/main/scala/leon/FairZ3Solver.scala
@@ -100,6 +100,9 @@ class FairZ3Solver(reporter: Reporter) extends Solver(reporter) with AbstractZ3S
   private var mapSorts: Map[TypeTree, Z3Sort] = Map.empty
   private var arraySorts: Map[TypeTree, Z3Sort] = Map.empty
 
+  private var unitSort: Z3Sort = null
+  private var unitValue: Z3AST = null
+
   protected[leon] var funSorts: Map[TypeTree, Z3Sort] = Map.empty
   protected[leon] var funDomainConstructors: Map[TypeTree, Z3FuncDecl] = Map.empty
   protected[leon] var funDomainSelectors: Map[TypeTree, Seq[Z3FuncDecl]] = Map.empty
@@ -191,6 +194,27 @@ class FairZ3Solver(reporter: Reporter) extends Solver(reporter) with AbstractZ3S
     setSorts = Map.empty
     setCardFuns = Map.empty
 
+    //unitSort = z3.mkUninterpretedSort("unit")
+    //unitValue = z3.mkFreshConst("Unit", unitSort)
+    //val bound = z3.mkBound(0, unitSort)
+    //val eq = z3.mkEq(bound, unitValue)
+    //val decls = Seq((z3.mkFreshStringSymbol("u"), unitSort))
+    //val unitAxiom = z3.mkForAll(0, Seq(), decls, eq)
+    //println(unitAxiom)
+    //println(unitValue)
+    //z3.assertCnstr(unitAxiom)
+    val Seq((us, Seq(unitCons), Seq(unitTester), _)) = z3.mkADTSorts(
+      Seq(
+        (
+          "Unit",
+          Seq("Unit"),
+          Seq(Seq())
+        )
+      )
+    )
+    unitSort = us
+    unitValue = unitCons()
+
     val intSetSort = typeToSort(SetType(Int32Type))
     intSetMinFun = z3.mkFreshFuncDecl("setMin", Seq(intSetSort), intSort)
     intSetMaxFun = z3.mkFreshFuncDecl("setMax", Seq(intSetSort), intSort)
@@ -360,6 +384,7 @@ class FairZ3Solver(reporter: Reporter) extends Solver(reporter) with AbstractZ3S
   def typeToSort(tt: TypeTree): Z3Sort = tt match {
     case Int32Type => intSort
     case BooleanType => boolSort
+    case UnitType => unitSort
     case AbstractClassType(cd) => adtSorts(cd)
     case CaseClassType(cd) => {
       if (cd.hasParent) {
@@ -967,6 +992,7 @@ class FairZ3Solver(reporter: Reporter) extends Solver(reporter) with AbstractZ3S
         case Not(e) => z3.mkNot(rec(e))
         case IntLiteral(v) => z3.mkInt(v, intSort)
         case BooleanLiteral(v) => if (v) z3.mkTrue() else z3.mkFalse()
+        case UnitLiteral => unitValue
         case Equals(l, r) => z3.mkEq(rec(l), rec(r))
         case Plus(l, r) => if(USEBV) z3.mkBVAdd(rec(l), rec(r)) else z3.mkAdd(rec(l), rec(r))
         case Minus(l, r) => if(USEBV) z3.mkBVSub(rec(l), rec(r)) else z3.mkSub(rec(l), rec(r))
@@ -1131,7 +1157,9 @@ class FairZ3Solver(reporter: Reporter) extends Solver(reporter) with AbstractZ3S
         Tuple(rargs)
       }
       case other => 
-        z3.getASTKind(t) match {
+        if(t == unitValue) 
+          UnitLiteral
+        else z3.getASTKind(t) match {
           case Z3AppAST(decl, args) => {
             val argsSize = args.size
             if(argsSize == 0 && z3IdToExpr.isDefinedAt(t)) {
diff --git a/src/main/scala/leon/Main.scala b/src/main/scala/leon/Main.scala
index b3b3c17b410bc1c0565c1b4fab5964edb914aed7..68161e9e4e93600f2df7be787b9db8f4af796e3c 100644
--- a/src/main/scala/leon/Main.scala
+++ b/src/main/scala/leon/Main.scala
@@ -32,7 +32,7 @@ object Main {
 
   private def defaultAction(program: Program, reporter: Reporter) : Unit = {
     Logger.debug("Default action on program: " + program, 3, "main")
-    val passManager = new PassManager(Seq(ArrayTransformation, EpsilonElimination, ImperativeCodeElimination, UnitElimination, FunctionClosure, FunctionHoisting, Simplificator))
+    val passManager = new PassManager(Seq(ArrayTransformation, EpsilonElimination, ImperativeCodeElimination, /*UnitElimination,*/ FunctionClosure, FunctionHoisting, Simplificator))
     val program2 = passManager.run(program)
     val analysis = new Analysis(program2, reporter)
     analysis.analyse
diff --git a/src/main/scala/leon/plugin/CodeExtraction.scala b/src/main/scala/leon/plugin/CodeExtraction.scala
index 6ac28303d874fa3cdbd35a928531b93be4c2af8d..6b9ef14e5d419082c8901fd2bef62121ec150a76 100644
--- a/src/main/scala/leon/plugin/CodeExtraction.scala
+++ b/src/main/scala/leon/plugin/CodeExtraction.scala
@@ -935,6 +935,7 @@ trait CodeExtraction extends Extractors {
     def rec(tr: Type): purescala.TypeTrees.TypeTree = tr match {
       case tpe if tpe == IntClass.tpe => Int32Type
       case tpe if tpe == BooleanClass.tpe => BooleanType
+      case tpe if tpe == UnitClass.tpe => UnitType
       case TypeRef(_, sym, btt :: Nil) if isSetTraitSym(sym) => SetType(rec(btt))
       case TypeRef(_, sym, btt :: Nil) if isMultisetTraitSym(sym) => MultisetType(rec(btt))
       case TypeRef(_, sym, btt :: Nil) if isOptionClassSym(sym) => OptionType(rec(btt))
diff --git a/testcases/regression/invalid/Unit1.scala b/testcases/regression/invalid/Unit1.scala
new file mode 100644
index 0000000000000000000000000000000000000000..789a8f058cd8145a0dd3b7bb0d72d44fb92ee94e
--- /dev/null
+++ b/testcases/regression/invalid/Unit1.scala
@@ -0,0 +1,7 @@
+object Unit1 {
+
+  def foo(u: Unit): Unit = ({
+    u
+  }) ensuring(_ != ())
+
+}
diff --git a/testcases/regression/valid/Unit1.scala b/testcases/regression/valid/Unit1.scala
new file mode 100644
index 0000000000000000000000000000000000000000..a7b890d762648cba480c817506c7eee22259860d
--- /dev/null
+++ b/testcases/regression/valid/Unit1.scala
@@ -0,0 +1,7 @@
+object Unit1 {
+
+  def foo(): Unit = ({
+    ()
+  }) ensuring(_ == ())
+
+}
diff --git a/testcases/regression/valid/Unit2.scala b/testcases/regression/valid/Unit2.scala
new file mode 100644
index 0000000000000000000000000000000000000000..ac659589af503a8a79f261f9eb05be095e2e0943
--- /dev/null
+++ b/testcases/regression/valid/Unit2.scala
@@ -0,0 +1,7 @@
+object Unit2 {
+
+  def foo(u: Unit): Unit = {
+    u
+  } ensuring(_ == ())
+
+}