diff --git a/lib/z3.jar b/lib/z3.jar
index d99a37fc51a48beb3079d05fd603ae91706c4234..6416b0828d4ee9c354c0ad8d0cbf8b0586e36680 100644
Binary files a/lib/z3.jar and b/lib/z3.jar differ
diff --git a/src/purescala/Definitions.scala b/src/purescala/Definitions.scala
index ee59e388ad6ddb8981fa9c6d495f7bbdceb1b3dd..38196e50c1cb5afa8f07bfed087a8ac5237297d0 100644
--- a/src/purescala/Definitions.scala
+++ b/src/purescala/Definitions.scala
@@ -276,6 +276,7 @@ object Definitions {
   }
   @serializable class FunDef(val id: Identifier, val returnType: TypeTree, val args: VarDecls) extends Definition with ScalacPositional {
     var body: Option[Expr] = None
+    def implementation : Option[Expr] = body
     var precondition: Option[Expr] = None
     var postcondition: Option[Expr] = None
 
@@ -300,4 +301,68 @@ object Definitions {
 
     def isPrivate : Boolean = annots.contains("private")
   }
+  
+  object Catamorphism {
+    // If a function is a catamorphism, this deconstructs it into the cases. Eg:
+    // def size(l : List) : Int = ...
+    // should return:
+    // List,
+    // Seq(
+    //   (Nil(), 0)
+    //   (Cons(x, xs), 1 + size(xs)))
+    // ...where x and xs are fresh (and could be unused in the expr)
+    import scala.collection.mutable.{Map=>MutableMap}
+    type CataRepr = (AbstractClassDef,Seq[(CaseClass,Expr)])
+    private val unapplyCache : MutableMap[FunDef,CataRepr] = MutableMap.empty
+
+    def unapply(funDef : FunDef) : Option[CataRepr] = if(
+        funDef == null ||
+        funDef.args.size != 1 ||
+        funDef.hasPrecondition ||
+        !funDef.hasImplementation ||
+        (funDef.hasPostcondition && functionCallsOf(funDef.postcondition.get) != Set.empty)
+      ) {
+      None 
+    } else if(unapplyCache.isDefinedAt(funDef)) {
+      Some(unapplyCache(funDef))
+    } else {
+      var moreConditions = true
+      val argVar = funDef.args(0).toVariable
+      val argVarType = argVar.getType
+      val body = funDef.body.get
+      val iteized = matchToIfThenElse(body)
+      val invocations = functionCallsOf(iteized)
+      moreConditions = moreConditions && invocations.forall(_ match {
+        case FunctionInvocation(fd, Seq(CaseClassSelector(_, e, _))) if fd == funDef && e == argVar => true
+        case _ => false
+      })
+      moreConditions = moreConditions && argVarType.isInstanceOf[AbstractClassType]
+      var spmList : Seq[(CaseClassDef,Identifier,Seq[Identifier],Expr)] = Seq.empty
+      moreConditions = moreConditions && (body match {
+        case SimplePatternMatching(scrut, _, s) if scrut == argVar => spmList = s; true
+        case _ => false
+      })
+
+      val patternSeq : Seq[(CaseClass,Expr)] = if(moreConditions) {
+        spmList.map(tuple => {
+          val (ccd, id, ids, ex) = tuple
+          val ex2 = matchToIfThenElse(ex)
+          if(!(variablesOf(ex2) -- ids).isEmpty) {
+            moreConditions = false
+          }
+          (CaseClass(ccd, ids.map(Variable(_))), ex2)
+        })
+      } else {
+        Seq.empty
+      }
+
+      if(moreConditions) {
+        val finalResult = (argVarType.asInstanceOf[AbstractClassType].classDef, patternSeq)
+        unapplyCache(funDef) = finalResult
+        Some(finalResult)
+      } else {
+        None
+      }
+    }
+  }
 }
diff --git a/src/purescala/FairZ3Solver.scala b/src/purescala/FairZ3Solver.scala
index dd81fa3ea03917a8a308f9b92cfa30f4af06e86a..6371d0a309372832a1270b958ffe5e1e43f263ea 100644
--- a/src/purescala/FairZ3Solver.scala
+++ b/src/purescala/FairZ3Solver.scala
@@ -14,7 +14,7 @@ class FairZ3Solver(val reporter: Reporter) extends Solver(reporter) with Abstrac
   // have to comment this to use the solver for constraint solving...
   // assert(Settings.useFairInstantiator)
 
-  private final val UNKNOWNASSAT : Boolean = false
+  private final val UNKNOWNASSAT : Boolean = !Settings.noForallAxioms
 
   val description = "Fair Z3 Solver"
   override val shortDescription = "Z3-f"
@@ -49,7 +49,7 @@ class FairZ3Solver(val reporter: Reporter) extends Solver(reporter) with Abstrac
       z3.delete
     }
     z3 = new Z3Context(z3cfg)
-    //z3.traceToStdout
+    // z3.traceToStdout
 
     exprToZ3Id = Map.empty
     z3IdToExpr = Map.empty
@@ -185,6 +185,7 @@ class FairZ3Solver(val reporter: Reporter) extends Solver(reporter) with Abstrac
 
   private var functionMap: Map[FunDef, Z3FuncDecl] = Map.empty
   private var reverseFunctionMap: Map[Z3FuncDecl, FunDef] = Map.empty
+  private var axiomatizedFunctions : Set[FunDef] = Set.empty
 
   def prepareFunctions: Unit = {
     functionMap = Map.empty
@@ -197,6 +198,61 @@ class FairZ3Solver(val reporter: Reporter) extends Solver(reporter) with Abstrac
       functionMap = functionMap + (funDef -> z3Decl)
       reverseFunctionMap = reverseFunctionMap + (z3Decl -> funDef)
     }
+
+    if(!Settings.noForallAxioms) {
+      prepareAxioms
+    }
+  }
+
+  private def prepareAxioms : Unit = {
+    assert(!Settings.noForallAxioms)
+    program.definedFunctions.foreach(_ match {
+      case fd @ Catamorphism(acd, cases) => {
+        assert(!fd.hasPrecondition && fd.hasImplementation)
+        for(cse <- cases) {
+          val (cc @ CaseClass(ccd, args), expr) = cse
+          assert(args.forall(_.isInstanceOf[Variable]))
+          val argsAsIDs = args.map(_.asInstanceOf[Variable].id)
+          assert(variablesOf(expr) -- argsAsIDs.toSet == Set.empty)
+          val axiom : Z3AST = if(args.isEmpty) {
+            val eq = Equals(FunctionInvocation(fd, Seq(cc)), expr)
+            toZ3Formula(z3, eq).get
+          } else {
+            val z3ArgSorts = argsAsIDs.map(i => typeToSort(i.getType))
+            val boundVars = z3ArgSorts.zipWithIndex.map(p => z3.mkBound(p._2, p._1))
+            val map : Map[Identifier,Z3AST] = (argsAsIDs zip boundVars).toMap
+            val eq = Equals(FunctionInvocation(fd, Seq(cc)), expr)
+            val z3IzedEq = toZ3Formula(z3, eq, map).get
+            val z3IzedCC = toZ3Formula(z3, cc, map).get
+            val pattern = z3.mkPattern(functionDefToDecl(fd)(z3IzedCC))
+            val nameTypePairs = z3ArgSorts.map(s => (z3.mkFreshIntSymbol, s))
+            z3.mkForAll(0, List(pattern), nameTypePairs, z3IzedEq)
+          }
+          //println("I'll assert now an axiom: " + axiom)
+          //println("Case axiom:")
+          //println(axiom)
+          z3.assertCnstr(axiom)
+        }
+
+        if(fd.hasPostcondition) {
+          // we know it doesn't contain any function invocation
+          val cleaned = matchToIfThenElse(expandLets(fd.postcondition.get))
+          val argSort = typeToSort(fd.args(0).getType)
+          val bound = z3.mkBound(0, argSort)
+          val subst = replace(Map(ResultVariable() -> FunctionInvocation(fd, Seq(fd.args(0).toVariable))), cleaned)
+          val z3IzedPost = toZ3Formula(z3, subst, Map(fd.args(0).id -> bound)).get
+          val pattern = z3.mkPattern(functionDefToDecl(fd)(bound))
+          val nameTypePairs = Seq((z3.mkFreshIntSymbol, argSort))
+          val postAxiom = z3.mkForAll(0, List(pattern), nameTypePairs, z3IzedPost)
+          //println("Post axiom:")
+          //println(postAxiom)
+          z3.assertCnstr(postAxiom)
+        }
+
+        axiomatizedFunctions += fd
+      }
+      case _ => ;
+    })
   }
 
   // assumes prepareSorts has been called....
@@ -306,13 +362,17 @@ class FairZ3Solver(val reporter: Reporter) extends Solver(reporter) with Abstrac
     // println("Bunch of blocking bools: " + sib.map(_.toString).mkString(", "))
 
     // println("Basis : " + basis)
-    z3.assertCnstr(toZ3Formula(z3, basis).get)
+    val bb = toZ3Formula(z3, basis).get
+    // println("Base : " + bb)
+    z3.assertCnstr(bb)
     // println(toZ3Formula(z3, basis).get)
     // for(clause <- clauses) {
     //   println("we're getting a new clause " + clause)
     //   z3.assertCnstr(toZ3Formula(z3, clause).get)
     // }
-    z3.assertCnstr(toZ3Formula(z3, And(clauses)).get)
+    val cc = toZ3Formula(z3, And(clauses)).get
+    // println("CC : " + cc)
+    z3.assertCnstr(cc)
 
     blockingSet ++= Set(guards.map(p => if(p._2) Not(Variable(p._1)) else Variable(p._1)) : _*)
 
@@ -594,10 +654,10 @@ class FairZ3Solver(val reporter: Reporter) extends Solver(reporter) with Abstrac
     // exprToZ3Id = Map.empty
     // z3IdToExpr = Map.empty
 
-    for((k, v) <- initialMap) {
-      exprToZ3Id += (k.toVariable -> v)
-      z3IdToExpr += (v -> k.toVariable)
-    }
+    // for((k, v) <- initialMap) {
+    //   exprToZ3Id += (k.toVariable -> v)
+    //   z3IdToExpr += (v -> k.toVariable)
+    // }
 
     def rec(ex: Expr): Z3AST = { 
       //println("Stacking up call for:")
@@ -630,6 +690,10 @@ class FairZ3Solver(val reporter: Reporter) extends Solver(reporter) with Abstrac
             newAST
           }
         }
+
+        case ite @ IfExpr(c, t, e) => {
+          z3.mkITE(rec(c), rec(t), rec(e))
+        }
         // case ite @ IfExpr(c, t, e) => {
         //   val switch = z3.mkFreshConst("path", z3.mkBoolSort)
         //   val placeHolder = z3.mkFreshConst("ite", typeToSort(ite.getType))
@@ -821,9 +885,15 @@ class FairZ3Solver(val reporter: Reporter) extends Solver(reporter) with Abstrac
 
     // Returns whether some invocations were actually blocked in the end.
     private def registerBlocked(blockingAtom : Identifier, polarity : Boolean, invocations : Set[FunctionInvocation]) : Boolean = {
-      // TODO
-      // val filtered = invocations -- "those who are axiomatized"
-      val filtered = invocations
+      val filtered = invocations.filter(i => {
+        val FunctionInvocation(fd, _) = i
+        if(axiomatizedFunctions(fd)) {
+          reporter.info("I'm not registering " + i + " as blocked because it's axiomatized.")
+          false
+        } else {
+          true
+        }
+      })
 
       val pair = (blockingAtom, polarity)
       val alreadyBlocked = blocked.get(pair)
@@ -843,11 +913,6 @@ class FairZ3Solver(val reporter: Reporter) extends Solver(reporter) with Abstrac
       everythingEverUnrolled += functionInvocation
     }
 
-    def closeUnrollings2(formula : Expr) : (Expr, Seq[Expr], Seq[(Identifier,Boolean)]) = {
-
-      scala.Predef.error("wtf")
-    }
-
     def closeUnrollings(formula : Expr) : (Expr, Seq[Expr], Seq[(Identifier,Boolean)]) = {
       var (basis, clauses, ite2Bools) = clausifyITE(formula)
 
@@ -856,19 +921,17 @@ class FairZ3Solver(val reporter: Reporter) extends Solver(reporter) with Abstrac
       var treatedClauses : List[Expr] = Nil
       var blockers : List[(Identifier,Boolean)] = Nil
 
-      stillToUnroll = stillToUnroll ++ topLevelFunctionCallsOf(basis)
+      stillToUnroll = stillToUnroll ++ topLevelFunctionCallsOf(basis, axiomatizedFunctions)
       do {
         // We go through each clause and figure out what must be enrolled and
         // what must be blocked. We register everything.
         for(clause <- clauses) {
           clause match {
             case Iff(Variable(_), cond) => {
-              stillToUnroll = stillToUnroll ++ topLevelFunctionCallsOf(cond)
+              stillToUnroll = stillToUnroll ++ topLevelFunctionCallsOf(cond, axiomatizedFunctions)
             }
-            // TODO : sort out the functions that are not recursive and unroll
-            // them in any case
             case Implies(v @ Variable(id), then) => {
-              val calls = topLevelFunctionCallsOf(then)
+              val calls = topLevelFunctionCallsOf(then, axiomatizedFunctions)
               if(!calls.isEmpty) {
                 assert(!blocked.isDefinedAt((id,true)))
                 if(registerBlocked(id, true, calls)) //blocked((id,true)) = calls
@@ -876,7 +939,7 @@ class FairZ3Solver(val reporter: Reporter) extends Solver(reporter) with Abstrac
               }
             }
             case Implies(Not(v @ Variable(id)), elze) => {
-              val calls = topLevelFunctionCallsOf(elze)
+              val calls = topLevelFunctionCallsOf(elze, axiomatizedFunctions)
               if(!calls.isEmpty) {
                 assert(!blocked.isDefinedAt((id,false)))
                 if(registerBlocked(id, false, calls)) //blocked((id,false)) = calls
@@ -902,7 +965,7 @@ class FairZ3Solver(val reporter: Reporter) extends Solver(reporter) with Abstrac
             
             for(formula <- unrolled) {
               val (basis2, clauses2, _) = clausifyITE(formula)
-              stillToUnroll = stillToUnroll ++ topLevelFunctionCallsOf(basis2)
+              stillToUnroll = stillToUnroll ++ topLevelFunctionCallsOf(basis2, axiomatizedFunctions)
               clauses = clauses2 ::: clauses
               treatedClauses = basis2 :: treatedClauses
             }
diff --git a/src/purescala/Trees.scala b/src/purescala/Trees.scala
index 2ac55013ab184cf9bc025005b06889f3443ac348..ea891f76400fc2ed49e6ad353c153516ab3c33b1 100644
--- a/src/purescala/Trees.scala
+++ b/src/purescala/Trees.scala
@@ -681,14 +681,14 @@ object Trees {
     treeCatamorphism(convert, combine, compute, expr)
   }
 
-  def topLevelFunctionCallsOf(expr: Expr) : Set[FunctionInvocation] = {
+  def topLevelFunctionCallsOf(expr: Expr, barring : Set[FunDef] = Set.empty) : Set[FunctionInvocation] = {
     def convert(t: Expr) : Set[FunctionInvocation] = t match {
-      case f @ FunctionInvocation(_, _) => Set(f)
+      case f @ FunctionInvocation(fd, _) if(!barring(fd)) => Set(f)
       case _ => Set.empty
     }
     def combine(s1: Set[FunctionInvocation], s2: Set[FunctionInvocation]) = s1 ++ s2
     def compute(t: Expr, s: Set[FunctionInvocation]) = t match {
-      case f @ FunctionInvocation(_, _) => Set(f) // ++ s that's the difference with the one below
+      case f @ FunctionInvocation(fd,  _) if(!barring(fd)) => Set(f) // ++ s that's the difference with the one below
       case _ => s
     }
     treeCatamorphism(convert, combine, compute, expr)
diff --git a/testcases/sas2011-testcases/RedBlackTree.scala b/testcases/sas2011-testcases/RedBlackTree.scala
index 3a1d768b013ef495b69acb95582d61581344bfe1..da5dd22c0c71649e0dd6ec01a7ee30fa7490ef7f 100644
--- a/testcases/sas2011-testcases/RedBlackTree.scala
+++ b/testcases/sas2011-testcases/RedBlackTree.scala
@@ -20,10 +20,10 @@ object RedBlackTree {
     case Node(_, l, v, r) => content(l) ++ Set(v) ++ content(r)
   }
 
-  def size(t: Tree) : Int = t match {
+  def size(t: Tree) : Int = (t match {
     case Empty() => 0
     case Node(_, l, v, r) => size(l) + 1 + size(r)
-  }
+  }) ensuring(_ >= 0)
 
   /* We consider leaves to be black by definition */
   def isBlack(t: Tree) : Boolean = t match {