diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala
index 829eaa6070f262ed6a01dab4d7e4f5f540ea5587..74925bce8cb7e9a8f12eee10c19be3591a3a7546 100644
--- a/src/main/scala/leon/purescala/Trees.scala
+++ b/src/main/scala/leon/purescala/Trees.scala
@@ -181,8 +181,11 @@ object Trees {
       case _ => new And(Seq(l,r))
     }
     def apply(exprs: Seq[Expr]) : Expr = {
-      val simpler = exprs.filter(_ != BooleanLiteral(true))
-      if(simpler.isEmpty) BooleanLiteral(true) else simpler.reduceRight(And(_,_))
+      val simpler = exprs.flatMap(_ match {
+        case And(es) => es
+        case o => Seq(o)
+      }).takeWhile(_ != BooleanLiteral(false)).filter(_ != BooleanLiteral(true))
+      if(simpler.isEmpty) BooleanLiteral(true) else new And(simpler)
     }
 
     def unapply(and: And) : Option[Seq[Expr]] = 
@@ -202,13 +205,18 @@ object Trees {
 
   object Or {
     def apply(l: Expr, r: Expr) : Expr = (l,r) match {
+      case (BooleanLiteral(true),_)  => BooleanLiteral(true)
       case (BooleanLiteral(false),_) => r
       case (_,BooleanLiteral(false)) => l
       case _ => new Or(Seq(l,r))
     }
     def apply(exprs: Seq[Expr]) : Expr = {
-      val simpler = exprs.filter(_ != BooleanLiteral(false))
-      if(simpler.isEmpty) BooleanLiteral(false) else simpler.reduceRight(Or(_,_))
+      val simpler = exprs.flatMap(_ match {
+        case Or(es) => es
+        case o => Seq(o)
+      }).takeWhile(_ != BooleanLiteral(true)).filter(_ != BooleanLiteral(false))
+
+      if(simpler.isEmpty) BooleanLiteral(false) else new Or(simpler)
     }
 
     def unapply(or: Or) : Option[Seq[Expr]] = 
diff --git a/src/main/scala/leon/synthesis/ArithmeticNormalization.scala b/src/main/scala/leon/synthesis/ArithmeticNormalization.scala
index 7e916afedee3d7bbdca77a18124f4dec23acda93..f43522762faf18c77cf3cd41956e617970c90ec5 100644
--- a/src/main/scala/leon/synthesis/ArithmeticNormalization.scala
+++ b/src/main/scala/leon/synthesis/ArithmeticNormalization.scala
@@ -107,21 +107,17 @@ object ArithmeticNormalization {
     simplePostTransform(simplify0)(expr)
   }
 
-  //assume the formula consist only of top level AND, find a top level
-  //Equals and extract it, return the remaining formula as well
-  //warning: also assume that And are always binary !
+  // Assume the formula consist only of top level AND, find a top level
+  // Equals and extract it, return the remaining formula as well
   def extractEquals(expr: Expr): (Option[Equals], Expr) = expr match {
-    case And(Seq(eq@Equals(_, _), f)) => (Some(eq), f)
-    case And(Seq(f, eq@Equals(_, _))) => (Some(eq), f)
-    case And(Seq(f1, f2)) => extractEquals(f1) match {
-      case (Some(eq), r) => (Some(eq), And(r, f2))
-      case (None, r1) => {
-        val (eq, r2) = extractEquals(f2)
-        (eq, And(r1, r2))
+    case And(es) =>
+      // OK now I'm just messing with you.
+      val (r, nes) = es.foldLeft[(Option[Equals],Seq[Expr])]((None, Seq())) {
+        case ((None, nes), eq @ Equals(_,_)) => (Some(eq), nes)
+        case ((o, nes), e) => (o, e +: nes)
       }
-    }
+      (r, And(nes.reverse))
+
     case e => (None, e)
   }
-
-
 }