From d64511f21b2f9c298c7d37005be1134211acee5a Mon Sep 17 00:00:00 2001 From: Robin Steiger <robin.steiger@epfl.ch> Date: Mon, 12 Jul 2010 19:37:30 +0000 Subject: [PATCH] Inferred (dis)equalities on element variables and other changes. --- src/orderedsets/DNF.scala | 48 --------- src/orderedsets/Main.scala | 44 +++++---- src/orderedsets/NormalForms.scala | 2 +- src/orderedsets/TreeOperations.scala | 65 ++++++++---- src/orderedsets/Unifier.scala | 29 +++++- src/orderedsets/UnifierMain.scala | 60 +++++------ src/orderedsets/getAlpha.scala | 47 --------- testcases/BinarySearchTree.scala | 142 +++++++++++++++------------ 8 files changed, 208 insertions(+), 229 deletions(-) delete mode 100644 src/orderedsets/DNF.scala delete mode 100644 src/orderedsets/getAlpha.scala diff --git a/src/orderedsets/DNF.scala b/src/orderedsets/DNF.scala deleted file mode 100644 index 7a87d55ac..000000000 --- a/src/orderedsets/DNF.scala +++ /dev/null @@ -1,48 +0,0 @@ -package orderedsets - -object DNF { - import purescala.Trees._ - import TreeOperations.dnf - - - // Printer (both && and || are printed as ? on windows..) - def pp(expr: Expr): String = expr match { - case And(es) => (es map pp).mkString("( ", " & ", " )") - case Or(es) => (es map pp).mkString("( ", " | ", " )") - case Not(e) => "!(" + pp(e) + ")" - case _ => expr.toString - } - - // Tests - import purescala.Common._ - implicit def str2id(str: String): Identifier = FreshIdentifier(str) - - val a = Variable("a") - val b = Variable("b") - val c = Variable("c") - val x = Variable("x") - val y = Variable("y") - val z = Variable("z") - - val t1 = And(Or(a, b), Or(x, y)) - val t2 = Implies(x, Iff(y, z)) - - def main(args: Array[String]) { - - test(t1) - test(t2) - test(Not(t1)) - test(Not(t2)) - } - - def test(before: Expr) { - val after = Or((dnf(before) map And.apply).toSeq) - println - println("Before dnf : " + pp(before)) - //println("After dnf : " + pp(after)) - println("After dnf : ") - for (and <- dnf(before)) println(" " + pp(And(and))) - } - - -} \ No newline at end of file diff --git a/src/orderedsets/Main.scala b/src/orderedsets/Main.scala index e20188413..40a29f15c 100644 --- a/src/orderedsets/Main.scala +++ b/src/orderedsets/Main.scala @@ -29,7 +29,7 @@ class Main(reporter: Reporter) extends Solver(reporter) { //reporter.info("Sets: " + ExprToASTConverter.getSetTypes(expr)) try { // Negate formula - (Some(!solve(!ExprToASTConverter(expr))), { + (Some(!solve(!ExprToASTConverter(expr, reporter))), { val sets = ExprToASTConverter.getSetTypes(expr) if (sets.size > 1) @@ -38,7 +38,7 @@ class Main(reporter: Reporter) extends Solver(reporter) { })._1 } catch { case ConversionException(badExpr, msg) => - reporter.info(badExpr, msg) + reporter.info(badExpr, msg + " in " + badExpr.getClass) None case IncompleteException(msg) => reporter.info(msg) @@ -121,7 +121,7 @@ object ExprToASTConverter { def makeEq(v: Variable, t: Expr) = v.getType match { case Int32Type => Equals(v, t) case tpe if isSetType(tpe) => SetEquals(v, t) - case _ => throw (new ConversionException(v, "is of type " + v.getType + " and cannot be handled by OrdBapa")) + case _ => throw (new ConversionException(v, "type " + v.getType)) } private def toSetTerm(expr: Expr): AST.Term = expr match { @@ -133,8 +133,8 @@ object ExprToASTConverter { case SetIntersection(set1, set2) => toSetTerm(set1) ** toSetTerm(set2) case SetUnion(set1, set2) => toSetTerm(set1) ++ toSetTerm(set2) case SetDifference(set1, set2) => toSetTerm(set1) -- toSetTerm(set2) - case Variable(_) => throw ConversionException(expr, "is a variable of type " + expr.getType + " and cannot be converted to bapa< set variable") - case _ => throw ConversionException(expr, "Cannot convert to bapa< set term") + case Variable(_) => throw ConversionException(expr, "type " + expr.getType) + case _ => throw ConversionException(expr, "bad set term") } private def toIntTerm(expr: Expr): AST.Term = expr match { @@ -148,7 +148,7 @@ object ExprToASTConverter { case SetCardinality(e) => toSetTerm(e).card case SetMin(set) if set.getType == SetType(Int32Type) => toSetTerm(set).inf case SetMax(set) if set.getType == SetType(Int32Type) => toSetTerm(set).sup - case _ => throw ConversionException(expr, "Cannot convert to bapa< int term") + case _ => throw ConversionException(expr, "bad int term") } private def toFormula(expr: Expr): AST.Formula = expr match { @@ -159,6 +159,10 @@ object ExprToASTConverter { case And(exprs) => AST.And((exprs map toFormula).toList) case Not(expr) => !toFormula(expr) case Implies(expr1, expr2) => !(toFormula(expr1)) || toFormula(expr2) + case Iff(expr1, expr2) => + val f1 = toFormula(expr1) + val f2 = toFormula(expr2) + (f1 && f2) || (!f1 && !f2) // Set Formulas case ElementOfSet(elem, set) => toIntTerm(elem) selem toSetTerm(set) @@ -171,10 +175,13 @@ object ExprToASTConverter { case LessEquals(lhs, rhs) => toIntTerm(lhs) <= toIntTerm(rhs) case GreaterThan(lhs, rhs) => toIntTerm(lhs) > toIntTerm(rhs) case GreaterEquals(lhs, rhs) => toIntTerm(lhs) >= toIntTerm(rhs) - case Equals(lhs, rhs) if lhs.getType == Int32Type && rhs.getType == Int32Type => toIntTerm(lhs) === toIntTerm(rhs) + case Equals(lhs, rhs) => (lhs.getType, rhs.getType) match { + case (Int32Type, Int32Type) => toIntTerm(lhs) === toIntTerm(rhs) + case types => throw ConversionException(expr, "types " + types) + } // Assuming the formula to be True - case _ => throw ConversionException(expr, "Cannot convert to bapa< formula") + case _ => throw ConversionException(expr, "bad formula") } def getSetTypes(expr: Expr): Set[TypeTree] = expr match { @@ -204,7 +211,16 @@ object ExprToASTConverter { case _ => Set.empty[TypeTree] } - def apply(expr: Expr) = { + def apply(expr: Expr, reporter: Reporter) = { + def toRelaxedFormula(expr: Expr): AST.Formula = + try { + toFormula(expr) + } catch { + case ConversionException(badExpr, msg) => + //reporter.warning("BAPA was relaxed : " + msg + " in " + badExpr.getClass + "\n" + rpp(badExpr)) + formulaRelaxed = true + AST.True // Assuming the formula to be True + } formulaRelaxed = false; expr match { case And(exprs) => AST.And((exprs map toRelaxedFormula).toList) @@ -212,13 +228,5 @@ object ExprToASTConverter { } } - private def toRelaxedFormula(expr: Expr): AST.Formula = - try { - toFormula(expr) - } catch { - case ConversionException(_, _) => - formulaRelaxed = true - // Assuming the formula to be True - AST.True - } + } diff --git a/src/orderedsets/NormalForms.scala b/src/orderedsets/NormalForms.scala index e4bcb5275..b8918434e 100644 --- a/src/orderedsets/NormalForms.scala +++ b/src/orderedsets/NormalForms.scala @@ -186,7 +186,7 @@ object NormalForms { case Not(Predicate(comp, terms)) => rewriteNonPure_*(terms, ts => Not(Predicate(comp, ts)) :: Nil) - case And(_) | Or(_) if isAtom(f) => + case And(_) | Or(_) if !isAtom(f) => error("A simplified conjunction cannot contain " + f) case _ => List(f) diff --git a/src/orderedsets/TreeOperations.scala b/src/orderedsets/TreeOperations.scala index 5d8005963..64d8a01f5 100644 --- a/src/orderedsets/TreeOperations.scala +++ b/src/orderedsets/TreeOperations.scala @@ -8,6 +8,8 @@ import Common._ import TypeTrees._ import Definitions._ +import RPrettyPrinter.rpp + object TreeOperations { def dnf(expr: Expr): Stream[Seq[Expr]] = expr match { case And(Nil) => Stream(Nil) @@ -35,7 +37,6 @@ object TreeOperations { searchAndReplace({case FunctionInvocation(_, Seq(Variable(id))) => varSet += id; None; case _ => None})(l._4) varSet.subsetOf(l._3.toSet) } - val c = program.callees(f) if (f.hasImplementation && f.args.size == 1 && c.size == 1 && c.head == f) f.body.get match { case SimplePatternMatching(scrut, _, lstMatches) @@ -45,7 +46,7 @@ object TreeOperations { None } } - + // 'Lazy' rewriter // // Hoists if expressions to the top level and @@ -53,15 +54,18 @@ object TreeOperations { // // The implementation is totally brain-teasing def rewrite(expr: Expr): Expr = - rewrite(expr, ex => ex) - + Simplifier(rewrite(expr, ex => ex)) + private def rewrite(expr: Expr, context: Expr => Expr): Expr = expr match { + // Convert to nnf + case Not(e@(And(_) | Or(_) | Iff(_, _) | Implies(_, _) | IfExpr(_, _, _))) => + rewrite(negate(e), context) case IfExpr(_c, _t, _e) => rewrite(_c, c => rewrite(_t, t => rewrite(_e, e => Or(And(c, context(t)), And(negate(c), context(e))) - ))) + ))) case And(_exs) => rewrite_*(_exs, exs => context(And(exs))) @@ -69,7 +73,7 @@ object TreeOperations { rewrite_*(_exs, exs => context(Or(exs))) case Not(_ex) => - rewrite(_ex, ex => + rewrite(_ex, ex => context(Not(ex))) case f@FunctionInvocation(fd, _args) => rewrite_*(_args, args => @@ -78,8 +82,8 @@ object TreeOperations { rewrite(_t, t => context(recons(t) setType u.getType)) case b@BinaryOperator(_t1, _t2, recons) => - rewrite(_t1, t1 => - rewrite(_t2, t2 => + rewrite(_t1, t1 => + rewrite(_t2, t2 => context(recons(t1, t2) setType b.getType))) case c@CaseClass(cd, _args) => rewrite_*(_args, args => @@ -90,29 +94,48 @@ object TreeOperations { case f@FiniteSet(_elems) => rewrite_*(_elems, elems => context(FiniteSet(elems) setType f.getType)) - case Terminal() => + case _: Terminal => context(expr) case _ => // Missed case error("Unsupported case in rewrite : " + expr.getClass) } - + private def rewrite_*(exprs: Seq[Expr], context: Seq[Expr] => Expr): Expr = exprs match { case Nil => context(Nil) case _t :: _ts => rewrite(_t, t => rewrite_*(_ts, ts => context(t +: ts))) } - - // This should rather be a marker interface, but I don't want - // to change Trees.scala without Philippe's permission. - object Terminal { - def unapply(expr: Expr): Boolean = expr match { - case Variable(_) | ResultVariable() | OptionNone(_) | EmptySet(_) | EmptyMultiset(_) | EmptyMap(_, _) | NilList(_) => true - case _: Literal[_] => true - case _ => false + + + object Simplifier { + private val True = BooleanLiteral(true) + private val False = BooleanLiteral(false) + + def apply(expr: Expr) = simplify(expr) + + def simplify(expr: Expr): Expr = expr match { + case Not(ex) => negate(ex) + case And(exs) => And(simplify(exs, True, False) flatMap flatAnd) + case Or(exs) => Or(simplify(exs, False, True) flatMap flatOr) + case _ => expr + } + + private def simplify(exprs: Seq[Expr], neutral: Expr, absorbing: Expr): Seq[Expr] = { + val exs = (exprs map simplify) filterNot {_ == neutral} + if (exs contains absorbing) Seq(absorbing) + else exs + } + + private def flatAnd(f: Expr) = f match { + case And(fs) => fs + case _ => Seq(f) + } + + private def flatOr(f: Expr) = f match { + case Or(fs) => fs + case _ => Seq(f) } } - - - + } diff --git a/src/orderedsets/Unifier.scala b/src/orderedsets/Unifier.scala index 6f3b608a2..f78ede598 100644 --- a/src/orderedsets/Unifier.scala +++ b/src/orderedsets/Unifier.scala @@ -86,7 +86,7 @@ object ADTUnifier extends Unifier[Variable, CaseClassDef] { def freshVar(prefix: String)(typed: Typed) = Var(Variable(FreshIdentifier(prefix, true) setType typed.getType)) - def unify(conjunction: Seq[Expr]): Map[Variable, Expr] = { + def unify(conjunction: Seq[Expr]): (Seq[Expr], Map[Variable, Expr]) = { val equalities = new ArrayBuffer[(Term, Term)]() val inequalities = new ArrayBuffer[(Var, Var)]() @@ -168,13 +168,38 @@ object ADTUnifier extends Unifier[Variable, CaseClassDef] { if (map1.isEmpty) println(" (empty table)") println */ - table mapValues term2expr + + // Extract element equalities and disequalities + val elementFormula = new ArrayBuffer[Expr]() + for ((e1, term) <- table; if isElementType(e1)) { + term2expr(term) match { + case e2@Variable(_) => + //println(" " + e1 + ": " + e1.getType + + // " -> " + e2 + ": " + e2.getType) + elementFormula += Equals(e1, e2) + case expr => + //println(" " + e1 + ": " + e1.getType + + // " -> " + expr + ": " + expr.getType) + //println("UNEXPECTED: " + term) + error("Unexpected " + expr) + } + } + for ((Var(e1), Var(e2)) <- inequalities; if isElementType(e1)) + elementFormula += Not(Equals(e1, e2)) + + (elementFormula.toSeq, table mapValues term2expr) } def term2expr(term: Term): Expr = term match { case Var(v) => v case Fun(cd, args) => CaseClass(cd, args map term2expr) } + + def isElementType(typed: Typed) = typed.getType match { + case AbstractClassType(_) | CaseClassType(_) => false + case _ => true + } + } diff --git a/src/orderedsets/UnifierMain.scala b/src/orderedsets/UnifierMain.scala index 099d306ac..91ac3df33 100644 --- a/src/orderedsets/UnifierMain.scala +++ b/src/orderedsets/UnifierMain.scala @@ -52,7 +52,7 @@ class UnifierMain(reporter: Reporter) extends Solver(reporter) { val noAlphas = And(restFormula flatMap expandAlphas(varMap)) reporter.info("The resulting formula is\n" + rpp(noAlphas)) tryAllSolvers(noAlphas) - } catch { + } catch { case ex@ConversionException(badExpr, msg) => reporter.info("Conjunction " + counter + " is UNKNOWN, could not be parsed") throw ex @@ -88,36 +88,38 @@ class UnifierMain(reporter: Reporter) extends Solver(reporter) { } } - def tryAllSolvers(f : Expr): Unit = { - for(solver <- Extensions.loadedSolverExtensions; if solver != this) { - reporter.info("Trying solver: " + solver.shortDescription + " from inside the unifier.") - solver.isUnsat(f) match { - case Some(true) => - reporter.info("Solver: " + solver.shortDescription + " proved the formula unsatisfiable") - return - case Some(false) => - reporter.warning("Solver: " + solver.shortDescription + " proved the formula satisfiable") - throw (new SatException(null)) - case None => - reporter.info("Solver: " + solver.shortDescription + " was unable to conclusively determine the correctness of the formula") - } - }; throw IncompleteException("All the solvers were unable to prove the formula unsatisfiable, giving up.") } + def tryAllSolvers(f: Expr): Unit = { + for (solver <- Extensions.loadedSolverExtensions; if solver != this) { + reporter.info("Trying solver: " + solver.shortDescription + " from inside the unifier.") + solver.isUnsat(f) match { + case Some(true) => + reporter.info("Solver: " + solver.shortDescription + " proved the formula unsatisfiable") + return + case Some(false) => + reporter.warning("Solver: " + solver.shortDescription + " proved the formula satisfiable") + throw (new SatException(null)) + case None => + reporter.info("Solver: " + solver.shortDescription + " was unable to conclusively determine the correctness of the formula") + } + }; + throw IncompleteException("All the solvers were unable to prove the formula unsatisfiable, giving up.") + } def checkIsSupported(expr: Expr) { def check(ex: Expr): Option[Expr] = ex match { case Let(_, _, _) | MatchExpr(_, _) => throw ConversionException(ex, "Unifier does not support this expression") case IfExpr(_, _, _) => - - println - println("--- BEFORE ---") - println(rpp(expr)) - println - println("--- AFTER ---") - println(rpp(rewrite(expr))) - println - - + + println + println("--- BEFORE ---") + println(rpp(expr)) + println + println("--- AFTER ---") + println(rpp(rewrite(expr))) + println + + throw ConversionException(ex, "Unifier does not support this expression") case _ => None } @@ -139,12 +141,12 @@ class UnifierMain(reporter: Reporter) extends Solver(reporter) { */ // The substitution table - val substTable = ADTUnifier.unify(treeEquations) + val (elementFormula, substTable) = ADTUnifier.unify(treeEquations) // The substitution function (returns identity if unmapped) def subst(v: Variable): Expr = substTable getOrElse (v, v) - (subst, rest) + (subst, elementFormula ++ rest) } @@ -239,7 +241,9 @@ class UnifierMain(reporter: Reporter) extends Solver(reporter) { //reporter.warning("Result:\n" + rpp(res)) Some(res) } - case _ => error("Bad argument/substitution to catamorphism: " + substArg(arg)) + case badArg => + println(rpp(badArg)) + error("Bad argument/substitution to catamorphism") } case None => // Not a catamorphism warning("Function " + fd.id + " is not a catamorphism.") diff --git a/src/orderedsets/getAlpha.scala b/src/orderedsets/getAlpha.scala deleted file mode 100644 index cad144df1..000000000 --- a/src/orderedsets/getAlpha.scala +++ /dev/null @@ -1,47 +0,0 @@ -package orderedsets - -import TreeOperations._ - -import purescala._ -import Trees._ -import Common._ -import TypeTrees._ -import Definitions._ - -object getAlpha { - var program: Program = null - - def setProgram(p: Program) = program = p - - def isAlpha(varMap: Variable => Expr)(t: Expr): Option[Expr] = t match { - case FunctionInvocation(fd, Seq(v@Variable(_))) => asCatamorphism(program, fd) match { - case None => None - case Some(lstMatch) => varMap(v) match { - case CaseClass(cd, args) => { - val (_, _, ids, rhs) = lstMatch.find(_._1 == cd).get - val repMap = Map(ids.map(id => Variable(id): Expr).zip(args): _*) - Some(searchAndReplace(repMap.get)(rhs)) - } - case u@Variable(_) => { - val c = Variable(FreshIdentifier("Coll", true)).setType(t.getType) - // TODO: Keep track of these variables for M1(t, c) - Some(c) - } - case _ => error("Bad substitution") - } - case _ => None - } - case _ => None - } - - def apply(t: Expr, varMap: Variable => Expr): Expr = { - searchAndReplace(isAlpha(varMap))(t) - } - - // def solve(e : Expr): Option[Boolean] = { - // searchAndReplace(isAlpha(x => x))(e) - // None - // } -} - - diff --git a/testcases/BinarySearchTree.scala b/testcases/BinarySearchTree.scala index 859a5eb29..1db8dcd10 100644 --- a/testcases/BinarySearchTree.scala +++ b/testcases/BinarySearchTree.scala @@ -1,5 +1,5 @@ import scala.collection.immutable.Set -// import scala.collection.immutable.Multiset +//import scala.collection.immutable.Multiset object BinarySearchTree { sealed abstract class Tree @@ -15,56 +15,56 @@ object BinarySearchTree { sealed abstract class Triple case class SortedTriple(min: Option, max: Option, sorted: Boolean) extends Triple - def isSorted(tree: Tree): SortedTriple = tree match { + def isSorted(tree: Tree): SortedTriple = tree match { case Leaf() => SortedTriple(None(), None(), true) - case Node(l,v,r) => isSorted(l) match { - case SortedTriple(minl,maxl,sortl) => if (!sortl) SortedTriple(None(), None(), false) - else minl match { - case None() => maxl match { - case None() => isSorted(r) match { - case SortedTriple(minr,maxr,sortr) => if (!sortr) SortedTriple(None(), None(), false) - else minr match { - case None() => maxr match { - case None() => SortedTriple(Some(v),Some(v),true) - case Some(maxrv) => SortedTriple(None(),None(),false) - } - case Some(minrv) => maxr match { - case Some(maxrv) => if (minrv > v) SortedTriple(Some(v),Some(maxrv),true) else SortedTriple(None(),None(),false) - case None() => SortedTriple(None(),None(),false) - } - } - } - case Some(maxlv) => SortedTriple(None(),None(),false) - } - case Some(minlv) => maxl match { - case Some(maxlv) => isSorted(r) match { - case SortedTriple(minr,maxr,sortr) => if (!sortr) SortedTriple(None(), None(), false) - else minr match { - case None() => maxr match { - case None() => if (maxlv <= v) SortedTriple(Some(minlv),Some(v),true) else SortedTriple(None(),None(),false) - case Some(maxrv) => SortedTriple(None(),None(),false) - } - case Some(minrv) => maxr match { - case Some(maxrv) => if (maxlv <= v && minrv > v) SortedTriple(Some(minlv),Some(maxrv),true) else SortedTriple(None(),None(),false) - case None() => SortedTriple(None(),None(),false) - } - } - } - case None() => SortedTriple(None(),None(),false) - } - } + case Node(l, v, r) => isSorted(l) match { + case SortedTriple(minl, maxl, sortl) => if (!sortl) SortedTriple(None(), None(), false) + else minl match { + case None() => maxl match { + case None() => isSorted(r) match { + case SortedTriple(minr, maxr, sortr) => if (!sortr) SortedTriple(None(), None(), false) + else minr match { + case None() => maxr match { + case None() => SortedTriple(Some(v), Some(v), true) + case Some(maxrv) => SortedTriple(None(), None(), false) + } + case Some(minrv) => maxr match { + case Some(maxrv) => if (minrv > v) SortedTriple(Some(v), Some(maxrv), true) else SortedTriple(None(), None(), false) + case None() => SortedTriple(None(), None(), false) + } + } + } + case Some(maxlv) => SortedTriple(None(), None(), false) + } + case Some(minlv) => maxl match { + case Some(maxlv) => isSorted(r) match { + case SortedTriple(minr, maxr, sortr) => if (!sortr) SortedTriple(None(), None(), false) + else minr match { + case None() => maxr match { + case None() => if (maxlv <= v) SortedTriple(Some(minlv), Some(v), true) else SortedTriple(None(), None(), false) + case Some(maxrv) => SortedTriple(None(), None(), false) + } + case Some(minrv) => maxr match { + case Some(maxrv) => if (maxlv <= v && minrv > v) SortedTriple(Some(minlv), Some(maxrv), true) else SortedTriple(None(), None(), false) + case None() => SortedTriple(None(), None(), false) + } + } + } + case None() => SortedTriple(None(), None(), false) } + } } - + } def treeMin(tree: Node): Int = { require(isSorted(tree).sorted) tree match { - case Node(left, v, _) => left match { - case Leaf() => v - case n@Node(_, _, _) => treeMin(n) + case Node(left, v, _) => left match { + case Leaf() => v + case n@Node(_, _, _) => treeMin(n) + } } - }} ensuring (_ == contents(tree).min) + } ensuring (_ == contents(tree).min) def treeMax(tree: Node): Int = { require(isSorted(tree).sorted) @@ -77,6 +77,19 @@ object BinarySearchTree { } ensuring (_ == contents(tree).max) def insert(tree: Tree, value: Int): Node = { + tree match { + case Leaf() => Node(Leaf(), value, Leaf()) + case n@Node(l, v, r) => if (v < value) { + Node(l, v, insert(r, value)) + } else if (v > value) { + Node(insert(l, value), v, r) + } else { + n + } + } + } ensuring (contents(_) == contents(tree) ++ Set(value)) + + def insertSorted(tree: Tree, value: Int): Node = { require(isSorted(tree).sorted) tree match { case Leaf() => Node(Leaf(), value, Leaf()) @@ -97,18 +110,18 @@ object BinarySearchTree { } } ensuring (contents(_) == contents(tree) ++ Set(0)) -/* - def remove(tree: Tree, value: Int) : Node = (tree match { - case Leaf() => Node(Leaf(), value, Leaf()) - case n @ Node(l, v, r) => if(v < value) { - Node(l, v, insert(r, value)) - } else if(v > value) { - Node(insert(l, value), v, r) - } else { - n - } - }) ensuring (contents(_) == contents(tree) -- Set(value)) -*/ + /* + def remove(tree: Tree, value: Int) : Node = (tree match { + case Leaf() => Node(Leaf(), value, Leaf()) + case n @ Node(l, v, r) => if(v < value) { + Node(l, v, insert(r, value)) + } else if(v > value) { + Node(insert(l, value), v, r) + } else { + n + } + }) ensuring (contents(_) == contents(tree) -- Set(value)) + */ def dumbInsertWithOrder(tree: Tree): Node = { tree match { @@ -139,20 +152,21 @@ object BinarySearchTree { Node(mkInfiniteTree(x), x, mkInfiniteTree(x)) } ensuring (res => res.left != Leaf() && res.right != Leaf() - ) + ) def contains(tree: Tree, value: Int): Boolean = { require(isSorted(tree).sorted) tree match { - case Leaf() => false - case n@Node(l, v, r) => if (v < value) { - contains(r, value) - } else if (v > value) { - contains(l, value) - } else { - true + case Leaf() => false + case n@Node(l, v, r) => if (v < value) { + contains(r, value) + } else if (v > value) { + contains(l, value) + } else { + true + } } - } } ensuring( _ || !(contents(tree) == contents(tree) ++ Set(value))) + } ensuring (_ || !(contents(tree) == contents(tree) ++ Set(value))) def contents(tree: Tree): Set[Int] = tree match { case Leaf() => Set.empty[Int] -- GitLab