From d64511f21b2f9c298c7d37005be1134211acee5a Mon Sep 17 00:00:00 2001
From: Robin Steiger <robin.steiger@epfl.ch>
Date: Mon, 12 Jul 2010 19:37:30 +0000
Subject: [PATCH] Inferred (dis)equalities on element variables and other
 changes.

---
 src/orderedsets/DNF.scala            |  48 ---------
 src/orderedsets/Main.scala           |  44 +++++----
 src/orderedsets/NormalForms.scala    |   2 +-
 src/orderedsets/TreeOperations.scala |  65 ++++++++----
 src/orderedsets/Unifier.scala        |  29 +++++-
 src/orderedsets/UnifierMain.scala    |  60 +++++------
 src/orderedsets/getAlpha.scala       |  47 ---------
 testcases/BinarySearchTree.scala     | 142 +++++++++++++++------------
 8 files changed, 208 insertions(+), 229 deletions(-)
 delete mode 100644 src/orderedsets/DNF.scala
 delete mode 100644 src/orderedsets/getAlpha.scala

diff --git a/src/orderedsets/DNF.scala b/src/orderedsets/DNF.scala
deleted file mode 100644
index 7a87d55ac..000000000
--- a/src/orderedsets/DNF.scala
+++ /dev/null
@@ -1,48 +0,0 @@
-package orderedsets
-
-object DNF {
-  import purescala.Trees._
-  import TreeOperations.dnf
-
-
-  // Printer (both && and || are printed as ? on windows..)
-  def pp(expr: Expr): String = expr match {
-    case And(es) => (es map pp).mkString("( ", " & ", " )")
-    case Or(es) => (es map pp).mkString("( ", " | ", " )")
-    case Not(e) => "!(" + pp(e) + ")"
-    case _ => expr.toString
-  }
-
-  // Tests
-  import purescala.Common._
-  implicit def str2id(str: String): Identifier = FreshIdentifier(str)
-
-  val a = Variable("a")
-  val b = Variable("b")
-  val c = Variable("c")
-  val x = Variable("x")
-  val y = Variable("y")
-  val z = Variable("z")
-
-  val t1 = And(Or(a, b), Or(x, y))
-  val t2 = Implies(x, Iff(y, z))
-
-  def main(args: Array[String]) {
-
-    test(t1)
-    test(t2)
-    test(Not(t1))
-    test(Not(t2))
-  }
-
-  def test(before: Expr) {
-    val after = Or((dnf(before) map And.apply).toSeq)
-    println
-    println("Before dnf : " + pp(before))
-    //println("After dnf  : " + pp(after))
-    println("After dnf  : ")
-    for (and <- dnf(before)) println("  " + pp(And(and)))
-  }
-
-
-}
\ No newline at end of file
diff --git a/src/orderedsets/Main.scala b/src/orderedsets/Main.scala
index e20188413..40a29f15c 100644
--- a/src/orderedsets/Main.scala
+++ b/src/orderedsets/Main.scala
@@ -29,7 +29,7 @@ class Main(reporter: Reporter) extends Solver(reporter) {
     //reporter.info("Sets: " + ExprToASTConverter.getSetTypes(expr))
     try {
       // Negate formula
-      (Some(!solve(!ExprToASTConverter(expr))), {
+      (Some(!solve(!ExprToASTConverter(expr, reporter))), {
 
         val sets = ExprToASTConverter.getSetTypes(expr)
         if (sets.size > 1)
@@ -38,7 +38,7 @@ class Main(reporter: Reporter) extends Solver(reporter) {
       })._1
     } catch {
       case ConversionException(badExpr, msg) =>
-        reporter.info(badExpr, msg)
+        reporter.info(badExpr, msg + " in " + badExpr.getClass)
         None
       case IncompleteException(msg) =>
         reporter.info(msg)
@@ -121,7 +121,7 @@ object ExprToASTConverter {
   def makeEq(v: Variable, t: Expr) = v.getType match {
     case Int32Type => Equals(v, t)
     case tpe if isSetType(tpe) => SetEquals(v, t)
-    case _ => throw (new ConversionException(v, "is of type " + v.getType + " and cannot be handled by OrdBapa"))
+    case _ => throw (new ConversionException(v, "type " + v.getType))
   }
 
   private def toSetTerm(expr: Expr): AST.Term = expr match {
@@ -133,8 +133,8 @@ object ExprToASTConverter {
     case SetIntersection(set1, set2) => toSetTerm(set1) ** toSetTerm(set2)
     case SetUnion(set1, set2) => toSetTerm(set1) ++ toSetTerm(set2)
     case SetDifference(set1, set2) => toSetTerm(set1) -- toSetTerm(set2)
-    case Variable(_) => throw ConversionException(expr, "is a variable of type " + expr.getType + " and cannot be converted to bapa< set variable")
-    case _ => throw ConversionException(expr, "Cannot convert to bapa< set term")
+    case Variable(_) => throw ConversionException(expr, "type " + expr.getType)
+    case _ => throw ConversionException(expr, "bad set term")
   }
 
   private def toIntTerm(expr: Expr): AST.Term = expr match {
@@ -148,7 +148,7 @@ object ExprToASTConverter {
     case SetCardinality(e) => toSetTerm(e).card
     case SetMin(set) if set.getType == SetType(Int32Type) => toSetTerm(set).inf
     case SetMax(set) if set.getType == SetType(Int32Type) => toSetTerm(set).sup
-    case _ => throw ConversionException(expr, "Cannot convert to bapa< int term")
+    case _ => throw ConversionException(expr, "bad int term")
   }
 
   private def toFormula(expr: Expr): AST.Formula = expr match {
@@ -159,6 +159,10 @@ object ExprToASTConverter {
     case And(exprs) => AST.And((exprs map toFormula).toList)
     case Not(expr) => !toFormula(expr)
     case Implies(expr1, expr2) => !(toFormula(expr1)) || toFormula(expr2)
+    case Iff(expr1, expr2) =>
+      val f1 = toFormula(expr1)
+      val f2 = toFormula(expr2)
+      (f1 && f2) || (!f1 && !f2)
 
     // Set Formulas
     case ElementOfSet(elem, set) => toIntTerm(elem) selem toSetTerm(set)
@@ -171,10 +175,13 @@ object ExprToASTConverter {
     case LessEquals(lhs, rhs) => toIntTerm(lhs) <= toIntTerm(rhs)
     case GreaterThan(lhs, rhs) => toIntTerm(lhs) > toIntTerm(rhs)
     case GreaterEquals(lhs, rhs) => toIntTerm(lhs) >= toIntTerm(rhs)
-    case Equals(lhs, rhs) if lhs.getType == Int32Type && rhs.getType == Int32Type => toIntTerm(lhs) === toIntTerm(rhs)
+    case Equals(lhs, rhs) => (lhs.getType, rhs.getType) match {
+      case (Int32Type, Int32Type) => toIntTerm(lhs) === toIntTerm(rhs)
+      case types => throw ConversionException(expr, "types " + types)
+    }
 
     // Assuming the formula to be True
-    case _ => throw ConversionException(expr, "Cannot convert to bapa< formula")
+    case _ => throw ConversionException(expr, "bad formula")
   }
 
   def getSetTypes(expr: Expr): Set[TypeTree] = expr match {
@@ -204,7 +211,16 @@ object ExprToASTConverter {
     case _ => Set.empty[TypeTree]
   }
 
-  def apply(expr: Expr) = {
+  def apply(expr: Expr, reporter: Reporter) = {
+    def toRelaxedFormula(expr: Expr): AST.Formula =
+      try {
+        toFormula(expr)
+      } catch {
+        case ConversionException(badExpr, msg) =>
+          //reporter.warning("BAPA was relaxed : " + msg + " in " + badExpr.getClass + "\n" + rpp(badExpr))
+          formulaRelaxed = true
+          AST.True // Assuming the formula to be True
+      }
     formulaRelaxed = false;
     expr match {
       case And(exprs) => AST.And((exprs map toRelaxedFormula).toList)
@@ -212,13 +228,5 @@ object ExprToASTConverter {
     }
   }
 
-  private def toRelaxedFormula(expr: Expr): AST.Formula =
-    try {
-      toFormula(expr)
-    } catch {
-      case ConversionException(_, _) =>
-        formulaRelaxed = true
-        // Assuming the formula to be True
-        AST.True
-    }
+
 }
diff --git a/src/orderedsets/NormalForms.scala b/src/orderedsets/NormalForms.scala
index e4bcb5275..b8918434e 100644
--- a/src/orderedsets/NormalForms.scala
+++ b/src/orderedsets/NormalForms.scala
@@ -186,7 +186,7 @@ object NormalForms {
     case Not(Predicate(comp, terms)) =>
       rewriteNonPure_*(terms, ts => Not(Predicate(comp, ts)) :: Nil)
 
-    case And(_) | Or(_) if isAtom(f) =>
+    case And(_) | Or(_) if !isAtom(f) =>
       error("A simplified conjunction cannot contain " + f)
 
     case _ => List(f)
diff --git a/src/orderedsets/TreeOperations.scala b/src/orderedsets/TreeOperations.scala
index 5d8005963..64d8a01f5 100644
--- a/src/orderedsets/TreeOperations.scala
+++ b/src/orderedsets/TreeOperations.scala
@@ -8,6 +8,8 @@ import Common._
 import TypeTrees._
 import Definitions._
 
+import RPrettyPrinter.rpp
+
 object TreeOperations {
   def dnf(expr: Expr): Stream[Seq[Expr]] = expr match {
     case And(Nil) => Stream(Nil)
@@ -35,7 +37,6 @@ object TreeOperations {
       searchAndReplace({case FunctionInvocation(_, Seq(Variable(id))) => varSet += id; None; case _ => None})(l._4)
       varSet.subsetOf(l._3.toSet)
     }
-
     val c = program.callees(f)
     if (f.hasImplementation && f.args.size == 1 && c.size == 1 && c.head == f) f.body.get match {
       case SimplePatternMatching(scrut, _, lstMatches)
@@ -45,7 +46,7 @@ object TreeOperations {
       None
     }
   }
-  
+
   // 'Lazy' rewriter
   // 
   // Hoists if expressions to the top level and
@@ -53,15 +54,18 @@ object TreeOperations {
   //
   // The implementation is totally brain-teasing
   def rewrite(expr: Expr): Expr =
-    rewrite(expr, ex => ex)
-  
+    Simplifier(rewrite(expr, ex => ex))
+
   private def rewrite(expr: Expr, context: Expr => Expr): Expr = expr match {
+  // Convert to nnf
+    case Not(e@(And(_) | Or(_) | Iff(_, _) | Implies(_, _) | IfExpr(_, _, _))) =>
+      rewrite(negate(e), context)
     case IfExpr(_c, _t, _e) =>
       rewrite(_c, c =>
         rewrite(_t, t =>
           rewrite(_e, e =>
             Or(And(c, context(t)), And(negate(c), context(e)))
-          )))
+            )))
     case And(_exs) =>
       rewrite_*(_exs, exs =>
         context(And(exs)))
@@ -69,7 +73,7 @@ object TreeOperations {
       rewrite_*(_exs, exs =>
         context(Or(exs)))
     case Not(_ex) =>
-      rewrite(_ex, ex => 
+      rewrite(_ex, ex =>
         context(Not(ex)))
     case f@FunctionInvocation(fd, _args) =>
       rewrite_*(_args, args =>
@@ -78,8 +82,8 @@ object TreeOperations {
       rewrite(_t, t =>
         context(recons(t) setType u.getType))
     case b@BinaryOperator(_t1, _t2, recons) =>
-      rewrite(_t1, t1 => 
-        rewrite(_t2, t2 => 
+      rewrite(_t1, t1 =>
+        rewrite(_t2, t2 =>
           context(recons(t1, t2) setType b.getType)))
     case c@CaseClass(cd, _args) =>
       rewrite_*(_args, args =>
@@ -90,29 +94,48 @@ object TreeOperations {
     case f@FiniteSet(_elems) =>
       rewrite_*(_elems, elems =>
         context(FiniteSet(elems) setType f.getType))
-    case Terminal() =>
+    case _: Terminal =>
       context(expr)
     case _ => // Missed case
       error("Unsupported case in rewrite : " + expr.getClass)
   }
-  
+
   private def rewrite_*(exprs: Seq[Expr], context: Seq[Expr] => Expr): Expr =
     exprs match {
       case Nil => context(Nil)
       case _t :: _ts =>
         rewrite(_t, t => rewrite_*(_ts, ts => context(t +: ts)))
     }
-  
-  // This should rather be a marker interface, but I don't want
-  // to change Trees.scala without Philippe's permission.
-  object Terminal {
-    def unapply(expr: Expr): Boolean = expr match {
-      case Variable(_) | ResultVariable() | OptionNone(_) | EmptySet(_) | EmptyMultiset(_) | EmptyMap(_, _) | NilList(_) => true
-      case _: Literal[_] => true
-      case _ => false
+
+
+  object Simplifier {
+    private val True = BooleanLiteral(true)
+    private val False = BooleanLiteral(false)
+
+    def apply(expr: Expr) = simplify(expr)
+
+    def simplify(expr: Expr): Expr = expr match {
+      case Not(ex) => negate(ex)
+      case And(exs) => And(simplify(exs, True, False) flatMap flatAnd)
+      case Or(exs) => Or(simplify(exs, False, True) flatMap flatOr)
+      case _ => expr
+    }
+
+    private def simplify(exprs: Seq[Expr], neutral: Expr, absorbing: Expr): Seq[Expr] = {
+      val exs = (exprs map simplify) filterNot {_ == neutral}
+      if (exs contains absorbing) Seq(absorbing)
+      else exs
+    }
+
+    private def flatAnd(f: Expr) = f match {
+      case And(fs) => fs
+      case _ => Seq(f)
+    }
+
+    private def flatOr(f: Expr) = f match {
+      case Or(fs) => fs
+      case _ => Seq(f)
     }
   }
-  
-  
-  
+
 }
diff --git a/src/orderedsets/Unifier.scala b/src/orderedsets/Unifier.scala
index 6f3b608a2..f78ede598 100644
--- a/src/orderedsets/Unifier.scala
+++ b/src/orderedsets/Unifier.scala
@@ -86,7 +86,7 @@ object ADTUnifier extends Unifier[Variable, CaseClassDef] {
 
   def freshVar(prefix: String)(typed: Typed) = Var(Variable(FreshIdentifier(prefix, true) setType typed.getType))
 
-  def unify(conjunction: Seq[Expr]): Map[Variable, Expr] = {
+  def unify(conjunction: Seq[Expr]): (Seq[Expr], Map[Variable, Expr]) = {
     val equalities = new ArrayBuffer[(Term, Term)]()
     val inequalities = new ArrayBuffer[(Var, Var)]()
 
@@ -168,13 +168,38 @@ object ADTUnifier extends Unifier[Variable, CaseClassDef] {
     if (map1.isEmpty) println("  (empty table)")
     println 
     */
-    table mapValues term2expr
+
+    // Extract element equalities and disequalities
+    val elementFormula = new ArrayBuffer[Expr]()
+    for ((e1, term) <- table; if isElementType(e1)) {
+      term2expr(term) match {
+        case e2@Variable(_) =>
+          //println("  " + e1 + ": " + e1.getType +
+          //    "  ->  " + e2 + ": " + e2.getType) 
+          elementFormula += Equals(e1, e2)
+        case expr =>
+          //println("  " + e1 + ": " + e1.getType +
+          //    "  ->  " + expr + ": " + expr.getType)
+          //println("UNEXPECTED: " + term)
+          error("Unexpected " + expr)
+      }
+    }
+    for ((Var(e1), Var(e2)) <- inequalities; if isElementType(e1))
+      elementFormula += Not(Equals(e1, e2))
+
+    (elementFormula.toSeq, table mapValues term2expr)
   }
 
   def term2expr(term: Term): Expr = term match {
     case Var(v) => v
     case Fun(cd, args) => CaseClass(cd, args map term2expr)
   }
+
+  def isElementType(typed: Typed) = typed.getType match {
+    case AbstractClassType(_) | CaseClassType(_) => false
+    case _ => true
+  }
+
 }
 
 
diff --git a/src/orderedsets/UnifierMain.scala b/src/orderedsets/UnifierMain.scala
index 099d306ac..91ac3df33 100644
--- a/src/orderedsets/UnifierMain.scala
+++ b/src/orderedsets/UnifierMain.scala
@@ -52,7 +52,7 @@ class UnifierMain(reporter: Reporter) extends Solver(reporter) {
           val noAlphas = And(restFormula flatMap expandAlphas(varMap))
           reporter.info("The resulting formula is\n" + rpp(noAlphas))
           tryAllSolvers(noAlphas)
-          } catch {
+        } catch {
           case ex@ConversionException(badExpr, msg) =>
             reporter.info("Conjunction " + counter + " is UNKNOWN, could not be parsed")
             throw ex
@@ -88,36 +88,38 @@ class UnifierMain(reporter: Reporter) extends Solver(reporter) {
     }
   }
 
-  def tryAllSolvers(f : Expr): Unit = {
-    for(solver <- Extensions.loadedSolverExtensions; if solver != this) {
-          reporter.info("Trying solver: " + solver.shortDescription + " from inside the unifier.")
-          solver.isUnsat(f) match {
-            case Some(true) => 
-                reporter.info("Solver: " + solver.shortDescription + " proved the formula unsatisfiable") 
-                return
-            case Some(false) => 
-                reporter.warning("Solver: " + solver.shortDescription + " proved the formula satisfiable")
-                throw (new SatException(null))
-            case None =>
-                reporter.info("Solver: " + solver.shortDescription + " was unable to conclusively determine the correctness of the formula") 
-          }
-    }; throw IncompleteException("All the solvers were unable to prove the formula unsatisfiable, giving up.") }
+  def tryAllSolvers(f: Expr): Unit = {
+    for (solver <- Extensions.loadedSolverExtensions; if solver != this) {
+      reporter.info("Trying solver: " + solver.shortDescription + " from inside the unifier.")
+      solver.isUnsat(f) match {
+        case Some(true) =>
+          reporter.info("Solver: " + solver.shortDescription + " proved the formula unsatisfiable")
+          return
+        case Some(false) =>
+          reporter.warning("Solver: " + solver.shortDescription + " proved the formula satisfiable")
+          throw (new SatException(null))
+        case None =>
+          reporter.info("Solver: " + solver.shortDescription + " was unable to conclusively determine the correctness of the formula")
+      }
+    };
+    throw IncompleteException("All the solvers were unable to prove the formula unsatisfiable, giving up.")
+  }
 
   def checkIsSupported(expr: Expr) {
     def check(ex: Expr): Option[Expr] = ex match {
       case Let(_, _, _) | MatchExpr(_, _) =>
         throw ConversionException(ex, "Unifier does not support this expression")
       case IfExpr(_, _, _) =>
-       
-       println
-       println("--- BEFORE ---")
-       println(rpp(expr))
-       println
-       println("--- AFTER ---")
-       println(rpp(rewrite(expr)))
-       println
-         
-      
+
+        println
+        println("--- BEFORE ---")
+        println(rpp(expr))
+        println
+        println("--- AFTER ---")
+        println(rpp(rewrite(expr)))
+        println
+
+
         throw ConversionException(ex, "Unifier does not support this expression")
       case _ => None
     }
@@ -139,12 +141,12 @@ class UnifierMain(reporter: Reporter) extends Solver(reporter) {
     */
 
     // The substitution table
-    val substTable = ADTUnifier.unify(treeEquations)
+    val (elementFormula, substTable) = ADTUnifier.unify(treeEquations)
 
     // The substitution function (returns identity if unmapped)
     def subst(v: Variable): Expr = substTable getOrElse (v, v)
 
-    (subst, rest)
+    (subst, elementFormula ++ rest)
 
   }
 
@@ -239,7 +241,9 @@ class UnifierMain(reporter: Reporter) extends Solver(reporter) {
           //reporter.warning("Result:\n" + rpp(res))
           Some(res)
         }
-        case _ => error("Bad argument/substitution to catamorphism: " + substArg(arg))
+        case badArg =>
+          println(rpp(badArg))
+          error("Bad argument/substitution to catamorphism")
       }
       case None => // Not a catamorphism
         warning("Function " + fd.id + " is not a catamorphism.")
diff --git a/src/orderedsets/getAlpha.scala b/src/orderedsets/getAlpha.scala
deleted file mode 100644
index cad144df1..000000000
--- a/src/orderedsets/getAlpha.scala
+++ /dev/null
@@ -1,47 +0,0 @@
-package orderedsets
-
-import TreeOperations._
-
-import purescala._
-import Trees._
-import Common._
-import TypeTrees._
-import Definitions._
-
-object getAlpha {
-  var program: Program = null
-
-  def setProgram(p: Program) = program = p
-
-  def isAlpha(varMap: Variable => Expr)(t: Expr): Option[Expr] = t match {
-    case FunctionInvocation(fd, Seq(v@Variable(_))) => asCatamorphism(program, fd) match {
-      case None => None
-      case Some(lstMatch) => varMap(v) match {
-        case CaseClass(cd, args) => {
-          val (_, _, ids, rhs) = lstMatch.find(_._1 == cd).get
-          val repMap = Map(ids.map(id => Variable(id): Expr).zip(args): _*)
-          Some(searchAndReplace(repMap.get)(rhs))
-        }
-        case u@Variable(_) => {
-          val c = Variable(FreshIdentifier("Coll", true)).setType(t.getType)
-          // TODO: Keep track of these variables for M1(t, c)
-          Some(c)
-        }
-        case _ => error("Bad substitution")
-      }
-      case _ => None
-    }
-    case _ => None
-  }
-
-  def apply(t: Expr, varMap: Variable => Expr): Expr = {
-    searchAndReplace(isAlpha(varMap))(t)
-  }
-
-  //  def solve(e : Expr): Option[Boolean] = {
-  //    searchAndReplace(isAlpha(x => x))(e)
-  //    None
-  //  }
-}
-
-
diff --git a/testcases/BinarySearchTree.scala b/testcases/BinarySearchTree.scala
index 859a5eb29..1db8dcd10 100644
--- a/testcases/BinarySearchTree.scala
+++ b/testcases/BinarySearchTree.scala
@@ -1,5 +1,5 @@
 import scala.collection.immutable.Set
-// import scala.collection.immutable.Multiset
+//import scala.collection.immutable.Multiset
 
 object BinarySearchTree {
   sealed abstract class Tree
@@ -15,56 +15,56 @@ object BinarySearchTree {
   sealed abstract class Triple
   case class SortedTriple(min: Option, max: Option, sorted: Boolean) extends Triple
 
-   def isSorted(tree: Tree): SortedTriple = tree match {
+  def isSorted(tree: Tree): SortedTriple = tree match {
     case Leaf() => SortedTriple(None(), None(), true)
-    case Node(l,v,r) => isSorted(l) match {
-      case SortedTriple(minl,maxl,sortl) => if (!sortl) SortedTriple(None(), None(), false)
-        else minl match {
-    	  case None() => maxl match {
-      	    case None() =>  isSorted(r) match {
-              case SortedTriple(minr,maxr,sortr) => if (!sortr) SortedTriple(None(), None(), false)
-        	else minr match {
-          	  case None() => maxr match {
-      		    case None() => SortedTriple(Some(v),Some(v),true)
-      		    case Some(maxrv) => SortedTriple(None(),None(),false)
-    		    }
-   		  case Some(minrv) => maxr match {
-      		    case Some(maxrv) => if (minrv > v) SortedTriple(Some(v),Some(maxrv),true) else SortedTriple(None(),None(),false)
-      		    case None() => SortedTriple(None(),None(),false)
-    		    }
-        	  }
-       	      }
-      	    case Some(maxlv) => SortedTriple(None(),None(),false)
-    	    }
-    	  case Some(minlv) => maxl match {
-      	    case Some(maxlv) => isSorted(r) match {
-              case SortedTriple(minr,maxr,sortr) => if (!sortr) SortedTriple(None(), None(), false)
-        	else minr match {
-          	  case None() => maxr match {
-      		    case None() => if (maxlv <= v) SortedTriple(Some(minlv),Some(v),true) else SortedTriple(None(),None(),false)
-      		    case Some(maxrv) => SortedTriple(None(),None(),false)
-    		    }
-    		  case Some(minrv) => maxr match {
-      		    case Some(maxrv) => if (maxlv <= v && minrv > v) SortedTriple(Some(minlv),Some(maxrv),true) else SortedTriple(None(),None(),false)
-      		    case None() => SortedTriple(None(),None(),false)
-    		    }
-        	  }
-     	      }
-      	    case None() => SortedTriple(None(),None(),false)
-    	    }
-  	  }
+    case Node(l, v, r) => isSorted(l) match {
+      case SortedTriple(minl, maxl, sortl) => if (!sortl) SortedTriple(None(), None(), false)
+      else minl match {
+        case None() => maxl match {
+          case None() => isSorted(r) match {
+            case SortedTriple(minr, maxr, sortr) => if (!sortr) SortedTriple(None(), None(), false)
+            else minr match {
+              case None() => maxr match {
+                case None() => SortedTriple(Some(v), Some(v), true)
+                case Some(maxrv) => SortedTriple(None(), None(), false)
+              }
+              case Some(minrv) => maxr match {
+                case Some(maxrv) => if (minrv > v) SortedTriple(Some(v), Some(maxrv), true) else SortedTriple(None(), None(), false)
+                case None() => SortedTriple(None(), None(), false)
+              }
+            }
+          }
+          case Some(maxlv) => SortedTriple(None(), None(), false)
+        }
+        case Some(minlv) => maxl match {
+          case Some(maxlv) => isSorted(r) match {
+            case SortedTriple(minr, maxr, sortr) => if (!sortr) SortedTriple(None(), None(), false)
+            else minr match {
+              case None() => maxr match {
+                case None() => if (maxlv <= v) SortedTriple(Some(minlv), Some(v), true) else SortedTriple(None(), None(), false)
+                case Some(maxrv) => SortedTriple(None(), None(), false)
+              }
+              case Some(minrv) => maxr match {
+                case Some(maxrv) => if (maxlv <= v && minrv > v) SortedTriple(Some(minlv), Some(maxrv), true) else SortedTriple(None(), None(), false)
+                case None() => SortedTriple(None(), None(), false)
+              }
+            }
+          }
+          case None() => SortedTriple(None(), None(), false)
         }
+      }
     }
-  
+  }
 
   def treeMin(tree: Node): Int = {
     require(isSorted(tree).sorted)
     tree match {
-    case Node(left, v, _) => left match {
-      case Leaf() => v
-      case n@Node(_, _, _) => treeMin(n)
+      case Node(left, v, _) => left match {
+        case Leaf() => v
+        case n@Node(_, _, _) => treeMin(n)
+      }
     }
-  }} ensuring (_ == contents(tree).min)
+  } ensuring (_ == contents(tree).min)
 
   def treeMax(tree: Node): Int = {
     require(isSorted(tree).sorted)
@@ -77,6 +77,19 @@ object BinarySearchTree {
   } ensuring (_ == contents(tree).max)
 
   def insert(tree: Tree, value: Int): Node = {
+    tree match {
+      case Leaf() => Node(Leaf(), value, Leaf())
+      case n@Node(l, v, r) => if (v < value) {
+        Node(l, v, insert(r, value))
+      } else if (v > value) {
+        Node(insert(l, value), v, r)
+      } else {
+        n
+      }
+    }
+  } ensuring (contents(_) == contents(tree) ++ Set(value))
+
+  def insertSorted(tree: Tree, value: Int): Node = {
     require(isSorted(tree).sorted)
     tree match {
       case Leaf() => Node(Leaf(), value, Leaf())
@@ -97,18 +110,18 @@ object BinarySearchTree {
     }
   } ensuring (contents(_) == contents(tree) ++ Set(0))
 
-/*
-    def remove(tree: Tree, value: Int) : Node = (tree match {
-        case Leaf() => Node(Leaf(), value, Leaf())
-        case n @ Node(l, v, r) => if(v < value) {
-          Node(l, v, insert(r, value))
-        } else if(v > value) {
-          Node(insert(l, value), v, r)
-        } else {
-          n
-        }
-    }) ensuring (contents(_) == contents(tree) -- Set(value))
-*/
+  /*
+      def remove(tree: Tree, value: Int) : Node = (tree match {
+          case Leaf() => Node(Leaf(), value, Leaf())
+          case n @ Node(l, v, r) => if(v < value) {
+            Node(l, v, insert(r, value))
+          } else if(v > value) {
+            Node(insert(l, value), v, r)
+          } else {
+            n
+          }
+      }) ensuring (contents(_) == contents(tree) -- Set(value))
+  */
 
   def dumbInsertWithOrder(tree: Tree): Node = {
     tree match {
@@ -139,20 +152,21 @@ object BinarySearchTree {
     Node(mkInfiniteTree(x), x, mkInfiniteTree(x))
   } ensuring (res =>
     res.left != Leaf() && res.right != Leaf()
-  )
+          )
 
   def contains(tree: Tree, value: Int): Boolean = {
     require(isSorted(tree).sorted)
     tree match {
-    case Leaf() => false
-    case n@Node(l, v, r) => if (v < value) {
-      contains(r, value)
-    } else if (v > value) {
-      contains(l, value)
-    } else {
-      true
+      case Leaf() => false
+      case n@Node(l, v, r) => if (v < value) {
+        contains(r, value)
+      } else if (v > value) {
+        contains(l, value)
+      } else {
+        true
+      }
     }
-  } } ensuring( _ || !(contents(tree) == contents(tree) ++ Set(value)))
+  } ensuring (_ || !(contents(tree) == contents(tree) ++ Set(value)))
 
   def contents(tree: Tree): Set[Int] = tree match {
     case Leaf() => Set.empty[Int]
-- 
GitLab