diff --git a/CHANGES.md b/CHANGES.md index 9e00084aefc0fa88d3b085b37aa98eff324f253b..b71a20ae617b283b2021c8b7a59e2543de53cd74 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,8 @@ # Change List +## 2024-04-12 +Addition of the Congruence tactic, solving sequents by congruence closure using egraphs. + ## 2024-04-12 Addition of simply typed lambda calculus with top level polymorphism and inductive poylmorphic algebraic data types. Addition of tactics for typechecking and structural induction over ADTs. diff --git a/lisa-examples/src/main/scala/Example.scala b/lisa-examples/src/main/scala/Example.scala index c942aff6549e5dad4a610134951fcc7893771fbd..87ec3821c18d03483ef0a5012345966c85eb1771 100644 --- a/lisa-examples/src/main/scala/Example.scala +++ b/lisa-examples/src/main/scala/Example.scala @@ -1,8 +1,10 @@ import lisa.automation.Substitution.{ApplyRules as Substitute} import lisa.automation.Tableau import lisa.automation.atp.Goeland +import lisa.automation.Congruence object Example extends lisa.Main { + draft() val x = variable val y = variable @@ -61,7 +63,31 @@ object Example extends lisa.Main { } val buveurs2 = Theorem(exists(x, P(x) ==> forall(y, P(y)))) { - have(thesis) by Goeland("goeland/Example.buveurs2_sol") + have(thesis) by Goeland//("goeland/Example.buveurs2_sol") + } + + + val a = variable + val one = variable + val two = variable + val * = SchematicFunctionLabel("*", 2) + val << = SchematicFunctionLabel("<<", 2) + val / = SchematicFunctionLabel("/", 2) + private val star: SchematicFunctionLabel[2] = * + private val shift: SchematicFunctionLabel[2] = << + private val divide: SchematicFunctionLabel[2] = / + + while (true) { + () + } + + extension (t:Term) { + def *(u:Term) = star(t, u) + def <<(u:Term) = shift(t, u) + def /(u:Term) = divide(t, u) + } + val congruence = Theorem(((a*two) === (a<<one), a*(two/two) === (a*two)/two, (two/two) === one, (a*one) === a) |- ((a<<one)/two) === a) { + have(thesis) by Congruence } /* @@ -78,40 +104,43 @@ object Example extends lisa.Main { */ - /* + // Simple tactic definition for LISA DSL +/* import lisa.automation.kernel.OLPropositionalSolver.* -// object SimpleTautology extends ProofTactic { -// def solveFormula(using proof: library.Proof)(f: Formula, decisionsPos: List[Formula], decisionsNeg: List[Formula]): proof.ProofTacticJudgement = { -// val redF = reducedForm(f) -// if (redF == ⊤) { -// Restate(decisionsPos |- f :: decisionsNeg) -// } else if (redF == ⊥) { -// proof.InvalidProofTactic("Sequent is not a propositional tautology") -// } else { -// val atom = findBestAtom(redF).get -// def substInRedF(f: Formula) = redF.substituted(atom -> f) -// TacticSubproof { -// have(solveFormula(substInRedF(⊤), atom :: decisionsPos, decisionsNeg)) -// val step2 = thenHave(atom :: decisionsPos |- redF :: decisionsNeg) by Substitution2(⊤ <=> atom) -// have(solveFormula(substInRedF(⊥), decisionsPos, atom :: decisionsNeg)) -// val step4 = thenHave(decisionsPos |- redF :: atom :: decisionsNeg) by Substitution2(⊥ <=> atom) -// have(decisionsPos |- redF :: decisionsNeg) by Cut(step4, step2) -// thenHave(decisionsPos |- f :: decisionsNeg) by Restate -// } -// } -// } -// def solveSequent(using proof: library.Proof)(bot: Sequent) = -// TacticSubproof { // Since the tactic above works on formulas, we need an extra step to convert an arbitrary sequent to an equivalent formula -// have(solveFormula(sequentToFormula(bot), Nil, Nil)) -// thenHave(bot) by Restate.from -// } -// } - */ +object SimpleTautology extends ProofTactic { + def solveFormula(using proof: library.Proof)(f: Formula, decisionsPos: List[Formula], decisionsNeg: List[Formula]): proof.ProofTacticJudgement = { + val redF = reducedForm(f) + if (redF == ⊤) { + Restate(decisionsPos |- f :: decisionsNeg) + } else if (redF == ⊥) { + proof.InvalidProofTactic("Sequent is not a propositional tautology") + } else { + val atom = findBestAtom(redF).get + def substInRedF(f: Formula) = redF.substituted(atom -> f) + TacticSubproof { + have(solveFormula(substInRedF(⊤), atom :: decisionsPos, decisionsNeg)) + val step2 = thenHave(atom :: decisionsPos |- redF :: decisionsNeg) by Substitution2(⊤ <=> atom) + have(solveFormula(substInRedF(⊥), decisionsPos, atom :: decisionsNeg)) + val step4 = thenHave(decisionsPos |- redF :: atom :: decisionsNeg) by Substitution2(⊥ <=> atom) + have(decisionsPos |- redF :: decisionsNeg) by Cut(step4, step2) + thenHave(decisionsPos |- f :: decisionsNeg) by Restate + } + } + } + + def solveSequent(using proof: library.Proof)(bot: Sequent) = + TacticSubproof { // Since the tactic above works on formulas, we need an extra step to convert an arbitrary sequent to an equivalent formula + have(solveFormula(sequentToFormula(bot), Nil, Nil)) + thenHave(bot) by Restate.from + } +} +*/ + // val a = formulaVariable() // val b = formulaVariable() // val c = formulaVariable() diff --git a/lisa-sets/src/main/scala/lisa/automation/Congruence.scala b/lisa-sets/src/main/scala/lisa/automation/Congruence.scala new file mode 100644 index 0000000000000000000000000000000000000000..ccf25295f6803d41a06a86e59177f6f596a79a03 --- /dev/null +++ b/lisa-sets/src/main/scala/lisa/automation/Congruence.scala @@ -0,0 +1,624 @@ +package lisa.automation +import lisa.fol.FOL.{*, given} +import lisa.prooflib.BasicStepTactic.* +import lisa.prooflib.ProofTacticLib.* +import lisa.prooflib.SimpleDeducedSteps.* +import lisa.prooflib.* +import lisa.utils.parsing.UnreachableException +import leo.datastructures.TPTP.CNF.AtomicFormula + +/** + * This tactic tries to prove a sequent by congruence. + * Consider the congruence closure of all terms and formulas in the sequent, with respect to all === and <=> left of the sequent. + * The sequent is provable by congruence if one of the following conditions is met: + * - The right side contains an equality s === t or equivalence a <=> b provable in the congruence closure. + * - The left side contains an negated equality !(s === t) or equivalence !(a <=> b) provable in the congruence closure. + * - There is a formula a on the left and b on the right such that a and b are congruent. + * - There are two formulas a and !b on the left such that a and b are congruent. + * - There are two formulas a and !b on the right such that a and b are congruent. + * - The sequent is Ol-valid without equality reasoning + * Note that complete congruence closure modulo OL is an open problem. + * + * The tactic uses an egraph datastructure to compute the congruence closure. + * The egraph itselfs relies on two underlying union-find datastructure, one for terms and one for formulas. + * The union-finds are equiped with an `explain` method that produces a path between any two elements in the same equivalence class. + * Each edge of the path can come from an external equality, or be the consequence of congruence. + * The tactic uses uses this path to produce needed proofs. + * + */ +object Congruence extends ProofTactic with ProofSequentTactic { + def apply(using lib: Library, proof: lib.Proof)(bot: Sequent): proof.ProofTacticJudgement = TacticSubproof { + import lib.* + + val egraph = new EGraphTerms() + egraph.addAll(bot.left) + egraph.addAll(bot.right) + + bot.left.foreach{ + case (left === right) => egraph.merge(left, right) + case (left <=> right) => egraph.merge(left, right) + case _ => () + } + + if isSameSequent(bot, ⊤) then + have(bot) by Restate + else if bot.left.exists { lf => + bot.right.exists { rf => + if egraph.idEq(lf, rf) then + val base = have(bot.left |- (bot.right + lf) ) by Restate + val eq = have(egraph.proveFormula(lf, rf, bot)) + val a = formulaVariable + have((bot.left + (lf <=> rf)) |- (bot.right) ) by RightSubstIff.withParametersSimple(List((lf, rf)), lambda(a, a))(base) + have(bot) by Cut(eq, lastStep) + true + else false + } || + bot.left.exists{ + case rf2 @ Neg(rf) if egraph.idEq(lf, rf)=> + val base = have((bot.left + !lf) |- bot.right ) by Restate + val eq = have(egraph.proveFormula(lf, rf, bot)) + val a = formulaVariable + have((bot.left + (lf <=> rf)) |- (bot.right) ) by LeftSubstIff.withParametersSimple(List((lf, rf)), lambda(a, !a))(base) + have(bot) by Cut(eq, lastStep) + true + case _ => false + } || { + lf match + case !(a === b) if egraph.idEq(a, b) => + have(egraph.proveTerm(a, b, bot)) + true + case !(a <=> b) if egraph.idEq(a, b) => + have(egraph.proveFormula(a, b, bot)) + true + case _ => false + } + + } then () + else if bot.right.exists { rf => + bot.right.exists{ + case lf2 @ Neg(lf) if egraph.idEq(lf, rf)=> + val base = have((bot.left) |- (bot.right + !rf) ) by Restate + val eq = have(egraph.proveFormula(lf, rf, bot)) + val a = formulaVariable + have((bot.left + (lf <=> rf)) |- (bot.right) ) by RightSubstIff.withParametersSimple(List((lf, rf)), lambda(a, !a))(base) + have(bot) by Cut(eq, lastStep) + true + case _ => false + } || { + rf match + case (a === b) if egraph.idEq(a, b) => + have(egraph.proveTerm(a, b, bot)) + true + case (a <=> b) if egraph.idEq(a, b) => + have(egraph.proveFormula(a, b, bot)) + true + case _ => false + } + } then () + else + return proof.InvalidProofTactic(s"No congruence found to show sequent\n $bot") + } + + +} + + +class UnionFind[T] { + // parent of each element, leading to its root. Uses path compression + val parent = scala.collection.mutable.Map[T, T]() + // original parent of each element, leading to its root. Does not use path compression. Used for explain. + val realParent = scala.collection.mutable.Map[T, (T, ((T, T), Boolean, Int))]() + //keep track of the rank (i.e. number of elements bellow it) of each element. Necessary to optimize union. + val rank = scala.collection.mutable.Map[T, Int]() + //tracks order of ancientness of unions. + var unionCounter = 0 + + /** + * add a new element to the union-find. + */ + def add(x: T): Unit = { + parent(x) = x + realParent(x) = (x, ((x, x), true, 0)) + rank(x) = 0 + } + + /** + * + * @param x the element whose parent we want to find + * @return the root of x + */ + def find(x: T): T = { + if parent(x) == x then + x + else + var root = x + while parent(root) != root do + root = parent(root) + var y = x + while parent(y) != root do + parent(y) = root + y = parent(y) + root + } + + /** + * Merges the classes of x and y + */ + def union(x: T, y: T): Unit = { + unionCounter += 1 + val xRoot = find(x) + val yRoot = find(y) + if (xRoot == yRoot) return + if (rank(xRoot) < rank(yRoot)) { + parent(xRoot) = yRoot + realParent(xRoot) = (yRoot, ((x, y), true, unionCounter)) + } else if (rank(xRoot) > rank(yRoot)) { + parent(yRoot) = xRoot + realParent(yRoot) = (xRoot, ((x, y), false, unionCounter)) + } else { + parent(yRoot) = xRoot + realParent(yRoot) = (xRoot, ((x, y), false, unionCounter)) + rank(xRoot) = rank(xRoot) + 1 + } + } + + private def getPathToRoot(x: T): List[T] = { + if x == find(x) then + List(x) + else + val next = realParent(x) + x :: getPathToRoot(next._1) + + } + + private def getExplanationFromTo(x:T, c: T): List[(T, ((T, T), Boolean, Int))] = { + if x == c then + List() + else + val next = realParent(x) + next :: getExplanationFromTo(next._1, c)} + + private def lowestCommonAncestor(x: T, y: T): Option[T] = { + val pathX = getPathToRoot(x) + val pathY = getPathToRoot(y) + pathX.find(pathY.contains) + } + + /** + * Returns a path from x to y made of pairs of elements (u, v) + * such that union(u, v) was called. + */ + def explain(x:T, y:T): Option[List[(T, T)]]= { + + if (x == y) then return Some(List()) + val lca = lowestCommonAncestor(x, y) + lca match + case None => None + case Some(lca) => + var max :((T, T), Boolean, Int) = ((x, x), true, 0) + var itX = x + while itX != lca do + val (next, ((u1, u2), b, c)) = realParent(itX) + if c > max._3 then + max = ((u1, u2), b, c) + itX = next + + var itY = y + while itY != lca do + val (next, ((u1, u2), b, c)) = realParent(itY) + if c > max._3 then + max = ((u1, u2), !b, c) + itY = next + + val u1 = max._1._1 + val u2 = max._1._2 + if max._2 then + Some(explain(x, u1).get ++ List((u1, u2)) ++ explain(u2, y).get) + else + Some(explain(x, u2).get ++ List((u1, u2)) ++ explain(u1, y).get) + } + + + /** + * Returns the set of all roots of all classes + */ + def getClasses: Set[T] = parent.keys.map(find).toSet + + /** + * Add all elements in the collection to the union-find + */ + def addAll(xs: Iterable[T]): Unit = xs.foreach(add) + +} + + +/////////////////////////////// +///////// E-graph ///////////// +/////////////////////////////// + +import scala.collection.mutable + +class EGraphTerms() { + + type ENode = Term | Formula + + + + val termMap = mutable.Map[Term, Set[Term]]() + val termParents = mutable.Map[Term, mutable.Set[AppliedFunctional | AppliedPredicate]]() + var termWorklist = List[Term]() + val termUF = new UnionFind[Term]() + + + + + val formulaMap = mutable.Map[Formula, Set[Formula]]() + val formulaParents = mutable.Map[Formula, mutable.Set[AppliedConnector]]() + var formulaWorklist = List[Formula]() + val formulaUF = new UnionFind[Formula]() + + + + + trait TermStep + case class TermExternal(between: (Term, Term)) extends TermStep + case class TermCongruence(between: (Term, Term)) extends TermStep + + trait FormulaStep + case class FormulaExternal(between: (Formula, Formula)) extends FormulaStep + case class FormulaCongruence(between: (Formula, Formula)) extends FormulaStep + + val termProofMap = mutable.Map[(Term, Term), TermStep]() + val formulaProofMap = mutable.Map[(Formula, Formula), FormulaStep]() + + def explain(id1: Term, id2: Term): Option[List[TermStep]] = { + val steps = termUF.explain(id1, id2) + steps.map(_.foldLeft((id1, List[TermStep]())) { + case ((prev, acc), step) => + termProofMap(step) match + case s @ TermExternal((l, r)) => + if l == prev then + (r, s :: acc) + else if r == prev then + (l, TermExternal(r, l) :: acc) + else throw new Exception("Invalid proof recovered: It is not a chain") + case s @ TermCongruence((l, r)) => + if l == prev then + (r, s :: acc) + else if r == prev then + (l, TermCongruence(r, l) :: acc) + else throw new Exception("Invalid proof recovered: It is not a chain") + + }._2.reverse) + } + + def explain(id1: Formula, id2: Formula): Option[List[FormulaStep]] = { + val steps = formulaUF.explain(id1, id2) + steps.map(_.foldLeft((id1, List[FormulaStep]())) { + case ((prev, acc), step) => + formulaProofMap(step) match + case s @ FormulaExternal((l, r)) => + if l == prev then + (r, s :: acc) + else if r == prev then + (l, FormulaExternal(r, l) :: acc) + else throw new Exception("Invalid proof recovered: It is not a chain") + case s @ FormulaCongruence((l, r)) => + if l == prev then + (r, s :: acc) + else if r == prev then + (l, FormulaCongruence(r, l) :: acc) + else throw new Exception("Invalid proof recovered: It is not a chain") + + }._2.reverse) + } + + + def makeSingletonEClass(node:Term): Term = { + termUF.add(node) + termMap(node) = Set(node) + termParents(node) = mutable.Set() + node + } + def makeSingletonEClass(node:Formula): Formula = { + formulaUF.add(node) + formulaMap(node) = Set(node) + formulaParents(node) = mutable.Set() + node + } + + def classOf(id: Term): Set[Term] = termMap(id) + def classOf(id: Formula): Set[Formula] = formulaMap(id) + + def idEq(id1: Term, id2: Term): Boolean = termUF.find(id1) == termUF.find(id2) + def idEq(id1: Formula, id2: Formula): Boolean = formulaUF.find(id1) == formulaUF.find(id2) + + def canonicalize(node: Term): Term = node match + case AppliedFunctional(label, args) => + AppliedFunctional(label, args.map(termUF.find.asInstanceOf)) + case _ => node + + + def canonicalize(node: Formula): Formula = { + node match + case AppliedPredicate(label, args) => AppliedPredicate(label, args.map(termUF.find)) + case AppliedConnector(label, args) => AppliedConnector(label, args.map(formulaUF.find)) + case node => node + } + + def add(node: Term): Term = + if termMap.contains(node) then return node + makeSingletonEClass(node) + node match + case node @ AppliedFunctional(_, args) => + args.foreach(child => + add(child) + termParents(child).add(node) + ) + node + case _ => node + + def add(node: Formula): Formula = + if formulaMap.contains(node) then return node + makeSingletonEClass(node) + node match + case node @ AppliedPredicate(_, args) => + args.foreach(child => + add(child) + termParents(child).add(node) + ) + node + case node @ AppliedConnector(_, args) => + args.foreach(child => + add(child) + formulaParents(child).add(node) + ) + node + case _ => node + + def addAll(nodes: Iterable[Term|Formula]): Unit = + nodes.foreach{ + case node: Term => add(node) + case node: Formula => add(node) + } + + + + + def merge(id1: Term, id2: Term): Unit = { + mergeWithStep(id1, id2, TermExternal((id1, id2))) + } + def merge(id1: Formula, id2: Formula): Unit = { + mergeWithStep(id1, id2, FormulaExternal((id1, id2))) + } + + protected def mergeWithStep(id1: Term, id2: Term, step: TermStep): Unit = { + if termUF.find(id1) == termUF.find(id2) then () + else + termProofMap((id1, id2)) = step + val newSet = termMap(termUF.find(id1)) ++ termMap(termUF.find(id2)) + val newparents = termParents(termUF.find(id1)) ++ termParents(termUF.find(id2)) + termUF.union(id1, id2) + val newId1 = termUF.find(id1) + val newId2 = termUF.find(id2) + termMap(newId1) = newSet + termMap(newId2) = newSet + termParents(newId1) = newparents + termParents(newId2) = newparents + + val id = termUF.find(id2) + termWorklist = id :: termWorklist + val cause = (id1, id2) + val termSeen = mutable.Map[Term, AppliedFunctional]() + val formulaSeen = mutable.Map[Formula, AppliedPredicate]() + newparents.foreach { + case pTerm: AppliedFunctional => + val canonicalPTerm = canonicalize(pTerm) + if termSeen.contains(canonicalPTerm) then + val qTerm = termSeen(canonicalPTerm) + Some((pTerm, qTerm, cause)) + mergeWithStep(pTerm, qTerm, TermCongruence((pTerm, qTerm))) + else + termSeen(canonicalPTerm) = pTerm + case pFormula: AppliedPredicate => + val canonicalPFormula = canonicalize(pFormula) + if formulaSeen.contains(canonicalPFormula) then + val qFormula = formulaSeen(canonicalPFormula) + + Some((pFormula, qFormula, cause)) + mergeWithStep(pFormula, qFormula, FormulaCongruence((pFormula, qFormula))) + else + formulaSeen(canonicalPFormula) = pFormula + } + termParents(id) = (termSeen.values.to(mutable.Set): mutable.Set[AppliedFunctional | AppliedPredicate]) ++ formulaSeen.values.to(mutable.Set) + } + + protected def mergeWithStep(id1: Formula, id2: Formula, step: FormulaStep): Unit = + if formulaUF.find(id1) == formulaUF.find(id2) then () + else + formulaProofMap((id1, id2)) = step + val newSet = formulaMap(formulaUF.find(id1)) ++ formulaMap(formulaUF.find(id2)) + val newparents = formulaParents(formulaUF.find(id1)) ++ formulaParents(formulaUF.find(id2)) + formulaUF.union(id1, id2) + val newId1 = formulaUF.find(id1) + val newId2 = formulaUF.find(id2) + formulaMap(newId1) = newSet + formulaMap(newId2) = newSet + formulaParents(newId1) = newparents + formulaParents(newId2) = newparents + val id = formulaUF.find(id2) + formulaWorklist = id :: formulaWorklist + val cause = (id1, id2) + val formulaSeen = mutable.Map[Formula, AppliedConnector]() + newparents.foreach { + case pFormula: AppliedConnector => + val canonicalPFormula = canonicalize(pFormula) + if formulaSeen.contains(canonicalPFormula) then + val qFormula = formulaSeen(canonicalPFormula) + Some((pFormula, qFormula, cause)) + mergeWithStep(pFormula, qFormula, FormulaCongruence((pFormula, qFormula))) + else + formulaSeen(canonicalPFormula) = pFormula + } + formulaParents(id) = formulaSeen.values.to(mutable.Set) + + + def proveTerm(using lib: Library, proof: lib.Proof)(id1: Term, id2:Term, base: Sequent): proof.ProofTacticJudgement = + TacticSubproof { proveInnerTerm(id1, id2, base) } + + def proveInnerTerm(using lib: Library, proof: lib.Proof)(id1: Term, id2:Term, base: Sequent): Unit = { + import lib.* + val steps = explain(id1, id2) + steps match { + case None => throw new Exception("No proof found in the egraph") + case Some(steps) => + if steps.isEmpty then have(base.left |- (base.right + (id1 === id2))) by Restate + steps.foreach { + case TermExternal((l, r)) => + val goalSequent = base.left |- (base.right + (id1 === r)) + if l == id1 then + have(goalSequent) by Restate + else + val x = freshVariable(id1) + have(goalSequent) by RightSubstEq.withParametersSimple(List((l, r)), lambda(x, id1 === x))(lastStep) + case TermCongruence((l, r)) => + val prev = if id1 != l then lastStep else null + val leqr = have(base.left |- (base.right + (l === r))) subproof { sp ?=> + (l, r) match + case (AppliedFunctional(labell, argsl), AppliedFunctional(labelr, argsr)) if labell == labelr && argsl.size == argsr.size => + var freshn = freshId((l.freeVariables ++ r.freeVariables).map(_.id), "n").no + val ziped = (argsl zip argsr) + var zip = List[(Term, Term)]() + var children = List[Term]() + var vars = List[Variable]() + var steps = List[(Formula, sp.ProofStep)]() + ziped.reverse.foreach { (al, ar) => + if al == ar then children = al :: children + else { + val x = Variable(Identifier("n", freshn)) + freshn = freshn + 1 + children = x :: children + vars = x :: vars + steps = (al === ar, have(proveTerm(al, ar, base))) :: steps + zip = (al, ar) :: zip + } + } + have(base.left |- (base.right + (l === l))) by Restate + val eqs = zip.map((l, r) => l === r) + val goal = have((base.left ++ eqs) |- (base.right + (l === r))).by.bot + have((base.left ++ eqs) |- (base.right + (l === r))) by RightSubstEq.withParametersSimple(zip, lambda(vars, l === labelr.applyUnsafe(children)))(lastStep) + steps.foreach { s => + have( + if s._2.bot.left.contains(s._1) then lastStep.bot else lastStep.bot -<< s._1 + ) by Cut(s._2, lastStep) + } + case _ => + println(s"l: $l") + println(s"r: $r") + throw UnreachableException + + } + if id1 != l then + val goalSequent = base.left |- (base.right + (id1 === r)) + val x = freshVariable(id1) + have(goalSequent +<< (l === r)) by RightSubstEq.withParametersSimple(List((l, r)), lambda(x, id1 === x))(prev) + have(goalSequent) by Cut(leqr, lastStep) + } + } + } + + def proveFormula(using lib: Library, proof: lib.Proof)(id1: Formula, id2:Formula, base: Sequent): proof.ProofTacticJudgement = + TacticSubproof { proveInnerFormula(id1, id2, base) } + + def proveInnerFormula(using lib: Library, proof: lib.Proof)(id1: Formula, id2:Formula, base: Sequent): Unit = { + import lib.* + val steps = explain(id1, id2) + steps match { + case None => throw new Exception("No proof found in the egraph") + case Some(steps) => + if steps.isEmpty then have(base.left |- (base.right + (id1 <=> id2))) by Restate + steps.foreach { + case FormulaExternal((l, r)) => + val goalSequent = base.left |- (base.right + (id1 <=> r)) + if l == id1 then + have(goalSequent) by Restate + else + val x = freshVariableFormula(id1) + have(goalSequent) by RightSubstIff.withParametersSimple(List((l, r)), lambda(x, id1 <=> x))(lastStep) + case FormulaCongruence((l, r)) => + val prev = if id1 != l then lastStep else null + val leqr = have(base.left |- (base.right + (l <=> r))) subproof { sp ?=> + (l, r) match + case (AppliedConnector(labell, argsl), AppliedConnector(labelr, argsr)) if labell == labelr && argsl.size == argsr.size => + var freshn = freshId((l.freeVariableFormulas ++ r.freeVariableFormulas).map(_.id), "n").no + val ziped = (argsl zip argsr) + var zip = List[(Formula, Formula)]() + var children = List[Formula]() + var vars = List[VariableFormula]() + var steps = List[(Formula, sp.ProofStep)]() + ziped.reverse.foreach { (al, ar) => + if al == ar then children = al :: children + else { + val x = VariableFormula(Identifier("n", freshn)) + freshn = freshn + 1 + children = x :: children + vars = x :: vars + steps = (al <=> ar, have(proveFormula(al, ar, base))) :: steps + zip = (al, ar) :: zip + } + } + have(base.left |- (base.right + (l <=> l))) by Restate + val eqs = zip.map((l, r) => l <=> r) + val goal = have((base.left ++ eqs) |- (base.right + (l <=> r))).by.bot + have((base.left ++ eqs) |- (base.right + (l <=> r))) by RightSubstIff.withParametersSimple(zip, lambda(vars, l <=> labelr.applyUnsafe(children)))(lastStep) + steps.foreach { s => + have( + if s._2.bot.left.contains(s._1) then lastStep.bot else lastStep.bot -<< s._1 + ) by Cut(s._2, lastStep) + } + + case (AppliedPredicate(labell, argsl), AppliedPredicate(labelr, argsr)) if labell == labelr && argsl.size == argsr.size => + var freshn = freshId((l.freeVariableFormulas ++ r.freeVariableFormulas).map(_.id), "n").no + val ziped = (argsl zip argsr) + var zip = List[(Term, Term)]() + var children = List[Term]() + var vars = List[Variable]() + var steps = List[(Formula, sp.ProofStep)]() + ziped.reverse.foreach { (al, ar) => + if al == ar then children = al :: children + else { + val x = Variable(Identifier("n", freshn)) + freshn = freshn + 1 + children = x :: children + vars = x :: vars + steps = (al === ar, have(proveTerm(al, ar, base))) :: steps + zip = (al, ar) :: zip + } + } + have(base.left |- (base.right + (l <=> l))) by Restate + val eqs = zip.map((l, r) => l === r) + val goal = have((base.left ++ eqs) |- (base.right + (l <=> r))).by.bot + have((base.left ++ eqs) |- (base.right + (l <=> r))) by RightSubstEq.withParametersSimple(zip, lambda(vars, l <=> labelr.applyUnsafe(children)))(lastStep) + steps.foreach { s => + have( + if s._2.bot.left.contains(s._1) then lastStep.bot else lastStep.bot -<< s._1 + ) by Cut(s._2, lastStep) + } + case _ => + println(s"l: $l") + println(s"r: $r") + throw UnreachableException + + } + if id1 != l then + val goalSequent = base.left |- (base.right + (id1 <=> r)) + val x = freshVariableFormula(id1) + have(goalSequent +<< (l <=> r)) by RightSubstIff.withParametersSimple(List((l, r)), lambda(x, id1 <=> x))(prev) + have(goalSequent) by Cut(leqr, lastStep) + + } + } + } + + +} \ No newline at end of file diff --git a/lisa-sets/src/main/scala/lisa/maths/settheory/SetTheory2.scala b/lisa-sets/src/main/scala/lisa/maths/settheory/SetTheory2.scala index b3d69dcc607dc99de300992e3a9df47510825929..f9599302ef52de66df30855cd9f7d20f515997f9 100644 --- a/lisa-sets/src/main/scala/lisa/maths/settheory/SetTheory2.scala +++ b/lisa-sets/src/main/scala/lisa/maths/settheory/SetTheory2.scala @@ -36,7 +36,7 @@ object SetTheory2 extends lisa.Main { have(thesis) by Tableau } - val functionalIsFunctional = Lemma( + val functionalIsFunctional = Theorem( ∀(x, in(x, A) ==> ∀(y, ∀(z, (P(x, y) /\ P(x, z)) ==> (y === z)))).substitute(P := lambda((A, B), Filter(A) /\ (B === Map(A)))) <=> top ) { @@ -45,7 +45,7 @@ object SetTheory2 extends lisa.Main { thenHave(in(x, A) |- ((Filter(x) /\ (y === Map(x)) /\ (z === Map(x))) ==> (y === z))) by Weakening thenHave(in(x, A) |- ∀(z, ((Filter(x) /\ (y === Map(x)) /\ (z === Map(x))) ==> (y === z)))) by RightForall thenHave(in(x, A) |- ∀(y, ∀(z, ((Filter(x) /\ (y === Map(x)) /\ (z === Map(x))) ==> (y === z))))) by RightForall - thenHave(in(x, A) ==> ∀(y, ∀(z, ((Filter(x) /\ (y === Map(x)) /\ (z === Map(x))) ==> (y === z))))) by Restate + //thenHave(in(x, A) ==> ∀(y, ∀(z, ((Filter(x) /\ (y === Map(x)) /\ (z === Map(x))) ==> (y === z))))) by Restate thenHave(∀(x, in(x, A) ==> ∀(y, ∀(z, ((Filter(x) /\ (y === Map(x)) /\ (z === Map(x))) ==> (y === z)))))) by RightForall thenHave(thesis) by Restate diff --git a/lisa-sets/src/test/scala/lisa/automation/CongruenceTest.scala b/lisa-sets/src/test/scala/lisa/automation/CongruenceTest.scala new file mode 100644 index 0000000000000000000000000000000000000000..5e50502d583a20cbc0d84d61752801a1e197f72d --- /dev/null +++ b/lisa-sets/src/test/scala/lisa/automation/CongruenceTest.scala @@ -0,0 +1,914 @@ +package lisa.automation +import lisa.fol.FOL.{*, given} +import lisa.automation.Congruence.* +import lisa.automation.Congruence +import org.scalatest.funsuite.AnyFunSuite + + +class CongruenceTest extends AnyFunSuite with lisa.TestMain { + + given lib: lisa.SetTheoryLibrary.type = lisa.SetTheoryLibrary + + val a = variable + val b = variable + val c = variable + val d = variable + val e = variable + val f = variable + val g = variable + val h = variable + val i = variable + val j = variable + val k = variable + val l = variable + val m = variable + val n = variable + val o = variable + + val x = variable + + val F = function[1] + + val one = variable + val two = variable + val * = SchematicFunctionLabel("*", 2) + val << = SchematicFunctionLabel("<<", 2) + val / = SchematicFunctionLabel("/", 2) + + + val af = formulaVariable + val bf = formulaVariable + val cf = formulaVariable + val df = formulaVariable + val ef = formulaVariable + val ff = formulaVariable + val gf = formulaVariable + val hf = formulaVariable + val if_ = formulaVariable + val jf = formulaVariable + val kf = formulaVariable + val lf = formulaVariable + val mf = formulaVariable + val nf = formulaVariable + val of = formulaVariable + + val xf = formulaVariable + + val Ff = SchematicConnectorLabel("Ff", 1) + val Fp = SchematicPredicateLabel("Fp", 1) + + val onef = formulaVariable + val twof = formulaVariable + val `*f` = SchematicConnectorLabel("*f", 2) + val `<<f` = SchematicConnectorLabel("<<f", 2) + val `/f` = SchematicConnectorLabel("/f", 2) + + + test("3 terms no congruence egraph test") { + val egraph = new EGraphTerms() + egraph.add(a) + egraph.add(b) + egraph.add(c) + egraph.merge(a, b) + assert(egraph.idEq(a, b)) + assert(!egraph.idEq(a, c)) + + } + + test("8 terms no congruence egraph test") { + val egraph = new EGraphTerms() + egraph.add(a) + egraph.add(b) + egraph.add(c) + egraph.add(d) + egraph.add(e) + egraph.add(f) + egraph.add(g) + egraph.add(h) + egraph.merge(a, b) + egraph.merge(c, d) + egraph.merge(e, f) + egraph.merge(g, h) + egraph.merge(a, c) + egraph.merge(f, h) + egraph.merge(a, f) + assert(egraph.idEq(a, h)) + + } + + test("15 terms no congruence egraph test") { + val egraph = new EGraphTerms() + egraph.add(a) + egraph.add(b) + egraph.add(c) + egraph.add(d) + egraph.add(e) + egraph.add(f) + egraph.add(g) + egraph.add(h) + egraph.add(i) + egraph.add(j) + egraph.add(k) + egraph.add(l) + egraph.add(m) + egraph.add(n) + egraph.add(o) + egraph.merge(a, c) + egraph.merge(e, f) + egraph.merge(i, k) + egraph.merge(m, n) + egraph.merge(a, b) + egraph.merge(o, m) + egraph.merge(i, m) + egraph.merge(g, h) + egraph.merge(l, k) + egraph.merge(c, d) + egraph.merge(a, e) + egraph.merge(a, i) + egraph.merge(g, e) + egraph.merge(i, j) + assert(egraph.idEq(a, o)) + + } + + test("15 terms no congruence egraph test with redundant merges") { + val egraph = new EGraphTerms() + egraph.add(a) + egraph.add(b) + egraph.add(c) + egraph.add(d) + egraph.add(e) + egraph.add(f) + egraph.add(g) + egraph.add(h) + egraph.add(i) + egraph.add(j) + egraph.add(k) + egraph.add(l) + egraph.add(m) + egraph.add(n) + egraph.add(o) + egraph.merge(a, c) + egraph.merge(e, f) + egraph.merge(i, k) + egraph.merge(m, n) + egraph.merge(a, b) + egraph.merge(o, m) + egraph.merge(i, m) + egraph.merge(g, h) + egraph.merge(l, k) + egraph.merge(b, c) + egraph.merge(f, e) + egraph.merge(o, i) + egraph.merge(g, e) + egraph.merge(i, j) + + assert(egraph.idEq(b, c)) + assert(egraph.idEq(f, h)) + assert(egraph.idEq(i, o)) + assert(!egraph.idEq(a, d)) + assert(!egraph.idEq(b, g)) + assert(!egraph.idEq(f, i)) + assert(!egraph.idEq(n, c)) + assert(egraph.termUF.getClasses.size == 4) + + } + + test("4 terms withcongruence egraph test") { + val egraph = new EGraphTerms() + egraph.add(F(a)) + egraph.add(F(b)) + egraph.merge(a, b) + assert(egraph.idEq(a, b)) + assert(egraph.idEq(F(a), F(b))) + assert(!egraph.idEq(a, F(a))) + assert(!egraph.idEq(a, F(b))) + assert(!egraph.idEq(b, F(a))) + assert(!egraph.idEq(b, F(b))) + + assert(egraph.explain(F(a), F(b)) == Some(List(egraph.TermCongruence((F(a), F(b))))) ) + + } + + + + test("divide-mult-shift in terms by 2 egraph test") { + + val egraph = new EGraphTerms() + egraph.add(one) + egraph.add(two) + egraph.add(a) + val ax2 = egraph.add(*(a, two)) + val ax2_d2 = egraph.add(/(*(a, two), two)) + val `2d2` = egraph.add(/(two, two)) + val ax_2d2 = egraph.add(*(a, /(two, two))) + val ax1 = egraph.add(*(a, one)) + val as1 = egraph.add(<<(a, one)) + + egraph.merge(ax2, as1) + egraph.merge(ax2_d2, ax_2d2) + egraph.merge(`2d2`, one) + egraph.merge(ax1, a) + + + assert(egraph.idEq(one, `2d2`)) + assert(egraph.idEq(ax2, as1)) + assert(egraph.idEq(ax2_d2, ax_2d2)) + assert(egraph.idEq(ax_2d2, ax1)) + assert(egraph.idEq(ax_2d2, a)) + + assert(!egraph.idEq(ax2, ax2_d2)) + assert(!egraph.idEq(ax2, `2d2`)) + assert(!egraph.idEq(ax2, ax_2d2)) + assert(!egraph.idEq(ax2, ax1)) + assert(!egraph.idEq(ax2, a)) + assert(!egraph.idEq(ax2_d2, `2d2`)) + + assert(egraph.explain(one, `2d2`) == Some(List(egraph.TermExternal((one, `2d2`)))) ) + assert(egraph.explain(ax2, as1) == Some(List(egraph.TermExternal((ax2, as1)))) ) + assert(egraph.explain(ax2_d2, ax_2d2) == Some(List(egraph.TermExternal((ax2_d2, ax_2d2)))) ) + + assert(egraph.explain(ax_2d2, ax1) == Some(List(egraph.TermCongruence((ax_2d2, ax1)))) ) + assert(egraph.explain(ax_2d2, a) == Some(List(egraph.TermCongruence((ax_2d2, ax1)), egraph.TermExternal((ax1, a))) )) + + + } + + test("long chain of terms congruence eGraph") { + val egraph = new EGraphTerms() + egraph.add(x) + val fx = egraph.add(F(x)) + val ffx = egraph.add(F(fx)) + val fffx = egraph.add(F(ffx)) + val ffffx = egraph.add(F(fffx)) + val fffffx = egraph.add(F(ffffx)) + val ffffffx = egraph.add(F(fffffx)) + val fffffffx = egraph.add(F(ffffffx)) + val ffffffffx = egraph.add(F(fffffffx)) + + + egraph.merge(ffffffffx, x) + egraph.merge(fffffx, x) + assert(egraph.idEq(fffx, x)) + assert(egraph.idEq(ffx, x)) + assert(egraph.idEq(fx, x)) + assert(egraph.idEq(x, fx)) + + assert(egraph.explain(fx, x) == Some(List(egraph.TermCongruence(fx, fffx), egraph.TermCongruence(fffx, ffffffffx), egraph.TermExternal(ffffffffx, x)))) + + } + + + test("3 formulas no congruence egraph test") { + val egraph = new EGraphTerms() + egraph.add(af) + egraph.add(bf) + egraph.add(cf) + egraph.merge(af, bf) + assert(egraph.idEq(af, bf)) + assert(!egraph.idEq(af, cf)) + + } + + test("8 formulas no congruence egraph test") { + val egraph = new EGraphTerms() + egraph.add(af) + egraph.add(bf) + egraph.add(cf) + egraph.add(df) + egraph.add(ef) + egraph.add(ff) + egraph.add(gf) + egraph.add(hf) + egraph.merge(af, bf) + egraph.merge(cf, df) + egraph.merge(ef, ff) + egraph.merge(gf, hf) + egraph.merge(af, cf) + egraph.merge(ff, hf) + egraph.merge(af, ff) + assert(egraph.idEq(af, hf)) + + } + + test("15 formulas no congruence egraph test") { + val egraph = new EGraphTerms() + egraph.add(af) + egraph.add(bf) + egraph.add(cf) + egraph.add(df) + egraph.add(ef) + egraph.add(ff) + egraph.add(gf) + egraph.add(hf) + egraph.add(if_) + egraph.add(jf) + egraph.add(kf) + egraph.add(lf) + egraph.add(mf) + egraph.add(nf) + egraph.add(of) + egraph.merge(af, cf) + egraph.merge(ef, ff) + egraph.merge(if_, kf) + egraph.merge(mf, nf) + egraph.merge(af, bf) + egraph.merge(of, mf) + egraph.merge(if_, mf) + egraph.merge(gf, hf) + egraph.merge(lf, kf) + egraph.merge(cf, df) + egraph.merge(af, ef) + egraph.merge(af, if_) + egraph.merge(gf, ef) + egraph.merge(if_, jf) + assert(egraph.idEq(af, of)) + + } + + test("15 formulas no congruence egraph test with redundant merges") { + val egraph = new EGraphTerms() + egraph.add(af) + egraph.add(bf) + egraph.add(cf) + egraph.add(df) + egraph.add(ef) + egraph.add(ff) + egraph.add(gf) + egraph.add(hf) + egraph.add(if_) + egraph.add(jf) + egraph.add(kf) + egraph.add(lf) + egraph.add(mf) + egraph.add(nf) + egraph.add(of) + egraph.merge(af, cf) + egraph.merge(ef, ff) + egraph.merge(if_, kf) + egraph.merge(mf, nf) + egraph.merge(af, bf) + egraph.merge(of, mf) + egraph.merge(if_, mf) + egraph.merge(gf, hf) + egraph.merge(lf, kf) + egraph.merge(bf, cf) + egraph.merge(ff, ef) + egraph.merge(of, if_) + egraph.merge(gf, ef) + egraph.merge(if_, jf) + + assert(egraph.idEq(bf, cf)) + assert(egraph.idEq(ff, hf)) + assert(egraph.idEq(if_, of)) + assert(!egraph.idEq(af, df)) + assert(!egraph.idEq(bf, gf)) + assert(!egraph.idEq(ff, if_)) + assert(!egraph.idEq(nf, cf)) + assert(egraph.formulaUF.getClasses.size == 4) + + } + + test("4 formulas withcongruence egraph test") { + val egraph = new EGraphTerms() + egraph.add(Ff(af)) + egraph.add(Ff(bf)) + egraph.merge(af, bf) + assert(egraph.idEq(af, bf)) + assert(egraph.idEq(Ff(af), Ff(bf))) + assert(!egraph.idEq(af, Ff(af))) + assert(!egraph.idEq(af, Ff(bf))) + assert(!egraph.idEq(bf, Ff(af))) + assert(!egraph.idEq(bf, Ff(bf))) + + assert(egraph.explain(Ff(af), Ff(bf)) == Some(List(egraph.FormulaCongruence((Ff(af), Ff(bf))))) ) + + } + + test("divide-mult-shift in formulas by 2 egraph test") { + + val egraph = new EGraphTerms() + egraph.add(onef) + egraph.add(twof) + egraph.add(af) + val ax2 = egraph.add(`*f`(af, twof)) + val ax2_d2 = egraph.add(`/f`(`*f`(af, twof), twof)) + val `2d2` = egraph.add(`/f`(twof, twof)) + val ax_2d2 = egraph.add(`*f`(af, `/f`(twof, twof))) + val ax1 = egraph.add(`*f`(af, onef)) + val as1 = egraph.add(`<<f`(af, onef)) + + egraph.merge(ax2, as1) + egraph.merge(ax2_d2, ax_2d2) + egraph.merge(`2d2`, onef) + egraph.merge(ax1, af) + + + assert(egraph.idEq(onef, `2d2`)) + assert(egraph.idEq(ax2, as1)) + assert(egraph.idEq(ax2_d2, ax_2d2)) + assert(egraph.idEq(ax_2d2, ax1)) + assert(egraph.idEq(ax_2d2, af)) + + assert(!egraph.idEq(ax2, ax2_d2)) + assert(!egraph.idEq(ax2, `2d2`)) + assert(!egraph.idEq(ax2, ax_2d2)) + assert(!egraph.idEq(ax2, ax1)) + assert(!egraph.idEq(ax2, af)) + assert(!egraph.idEq(ax2_d2, `2d2`)) + + assert(egraph.explain(onef, `2d2`) == Some(List(egraph.FormulaExternal((onef, `2d2`)))) ) + assert(egraph.explain(ax2, as1) == Some(List(egraph.FormulaExternal((ax2, as1)))) ) + assert(egraph.explain(ax2_d2, ax_2d2) == Some(List(egraph.FormulaExternal((ax2_d2, ax_2d2)))) ) + + assert(egraph.explain(ax_2d2, ax1) == Some(List(egraph.FormulaCongruence((ax_2d2, ax1)))) ) + assert(egraph.explain(ax_2d2, af) == Some(List(egraph.FormulaCongruence((ax_2d2, ax1)), egraph.FormulaExternal((ax1, af))) )) + + + + + } + + test("long chain of formulas congruence eGraph") { + val egraph = new EGraphTerms() + egraph.add(xf) + val fx = egraph.add(Ff(xf)) + val ffx = egraph.add(Ff(fx)) + val fffx = egraph.add(Ff(ffx)) + val ffffx = egraph.add(Ff(fffx)) + val fffffx = egraph.add(Ff(ffffx)) + val ffffffx = egraph.add(Ff(fffffx)) + val fffffffx = egraph.add(Ff(ffffffx)) + val ffffffffx = egraph.add(Ff(fffffffx)) + + + egraph.merge(ffffffffx, xf) + egraph.merge(fffffx, xf) + assert(egraph.idEq(fffx, xf)) + assert(egraph.idEq(ffx, xf)) + assert(egraph.idEq(fx, xf)) + assert(egraph.idEq(xf, fx)) + assert(egraph.explain(fx, xf) == Some(List(egraph.FormulaCongruence(fx, ffffffffx), egraph.FormulaExternal(ffffffffx, xf)))) + + } + + ////////////////////////////////////// + //// With both terms and formulas //// + ////////////////////////////////////// + + test("2 terms 6 predicates with congruence egraph test") { + val egraph = new EGraphTerms() + egraph.add(Ff(Ff(Fp(a)))) + egraph.add(Ff(Ff(Fp(b)))) + egraph.merge(a, b) + assert(egraph.idEq(a, b)) + assert(egraph.idEq(Fp(a), Fp(b))) + assert(egraph.idEq(Ff(Fp(a)), Ff(Fp(b)))) + assert(egraph.idEq(Ff(Ff(Fp(a))), Ff(Ff(Fp(b))))) + assert(!egraph.idEq(Fp(a), Ff(Fp(a)))) + assert(!egraph.idEq(Fp(a), Ff(Fp(b)))) + assert(!egraph.idEq(Fp(b), Ff(Fp(a)))) + assert(!egraph.idEq(Fp(b), Ff(Ff(Fp(b))))) + assert(!egraph.idEq(Ff(Fp(a)), Ff(Ff(Fp(b))))) + assert(egraph.formulaUF.getClasses.size == 3) + + egraph.merge(Fp(a), Ff(Fp(a))) + assert(egraph.idEq(Fp(a), Ff(Fp(b)))) + assert(egraph.idEq(Fp(b), Ff(Fp(a)))) + assert(egraph.idEq(Ff(Fp(a)), Ff(Ff(Fp(a))))) + assert(egraph.idEq(Fp(b), Ff(Ff(Fp(a))))) + assert(egraph.formulaUF.getClasses.size == 1) + + } + + test("6 terms 6 predicates with congruence egraph test") { + val egraph = new EGraphTerms() + egraph.add(Ff(Ff(Fp(F(F(a)))))) + egraph.add(Ff(Ff(Fp(F(F(b)))))) + egraph.merge(a, b) + assert(egraph.idEq(a, b)) + assert(egraph.idEq(F(a), F(b))) + assert(egraph.idEq(Fp(F(F(a))), Fp(F(F(b))))) + assert(egraph.idEq(Ff(Ff(Fp(F(F(a))))), Ff(Ff(Fp(F(F(b))))))) + assert(egraph.formulaUF.getClasses.size == 3) + assert(egraph.termUF.getClasses.size == 3) + + egraph.merge(Fp(F(F(b))), Ff(Fp(F(F(a))))) + assert(egraph.formulaUF.getClasses.size == 1) + + } + + + test("15 terms no congruence with redundant merges test with proofs") { + val egraph = new EGraphTerms() + egraph.add(a) + egraph.add(b) + egraph.add(c) + egraph.add(d) + egraph.add(e) + egraph.add(f) + egraph.add(g) + egraph.add(h) + egraph.add(i) + egraph.add(j) + egraph.add(k) + egraph.add(l) + egraph.add(m) + egraph.add(n) + egraph.add(o) + egraph.merge(a, c) + egraph.merge(e, f) + egraph.merge(i, k) + egraph.merge(m, n) + egraph.merge(a, b) + egraph.merge(o, m) + egraph.merge(i, m) + egraph.merge(g, h) + egraph.merge(l, k) + egraph.merge(b, c) + egraph.merge(f, e) + egraph.merge(o, i) + egraph.merge(g, e) + egraph.merge(i, j) + val base = List(a === c, e === f, i === k, m === n, a === b, o === m, i === m, g === h, l === k, b === c, f === e, o === i, g === e, i === j) + + val test1 = Theorem(base |- (b === c)) { + egraph.proveInnerTerm(b, c, base |- ()) + } + + val test2 = Theorem(base |- (f === h)) { + egraph.proveInnerTerm(f, h, base |- ()) + } + + val test3 = Theorem(base |- (i === o)) { + egraph.proveInnerTerm(i, o, base |- ()) + } + + val test4 = Theorem(base |- (o === i)) { + egraph.proveInnerTerm(o, i, base |- ()) + } + + } + + + test("4 elements with congruence test with proofs") { + val egraph = new EGraphTerms() + egraph.add(F(a)) + egraph.add(F(b)) + egraph.merge(a, b) + val test5 = Theorem(a===b |- F(a) === F(b)) { + egraph.proveInnerTerm(F(a), F(b), (a === b) |- ()) + } + } + + + test("divide-mult-shift by 2 in terms egraph test with proofs") { + val egraph = new EGraphTerms() + egraph.add(one) + egraph.add(two) + egraph.add(a) + val ax2 = egraph.add(`*`(a, two)) + val ax2_d2 = egraph.add(`/`(`*`(a, two), two)) + val `2d2` = egraph.add(`/`(two, two)) + val ax_2d2 = egraph.add(`*`(a, `/`(two, two))) + val ax1 = egraph.add(`*`(a, one)) + val as1 = egraph.add(`<<`(a, one)) + + egraph.merge(ax2, as1) + egraph.merge(ax2_d2, ax_2d2) + egraph.merge(`2d2`, one) + egraph.merge(ax1, a) + + val base = List[Formula](ax2 === as1, ax2_d2 === ax_2d2, `2d2` === one, ax1 === a) + + val one_2d2 = Theorem(base |- (one === `2d2`)) { + egraph.proveInnerTerm(one, `2d2`, base |- ()) + } + + val ax2_as1 = Theorem(base |- (ax2 === as1)) { + egraph.proveInnerTerm(ax2, as1, base |- ()) + } + + val ax2_d2_ax_2d2 = Theorem(base |- (ax2_d2 === ax_2d2)) { + egraph.proveInnerTerm(ax2_d2, ax_2d2, base |- ()) + } + + val ax_2d2_ax1 = Theorem(base |- (ax_2d2 === ax1)) { + egraph.proveInnerTerm(ax_2d2, ax1, base |- ()) + } + + val ax_2d2_a = Theorem(base |- (ax_2d2 === a)) { + egraph.proveInnerTerm(ax_2d2, a, base |- ()) + } + + } + + test("long chain of termscongruence eGraph with proofs") { + val egraph = new EGraphTerms() + egraph.add(x) + val fx = egraph.add(F(x)) + val ffx = egraph.add(F(fx)) + val fffx = egraph.add(F(ffx)) + val ffffx = egraph.add(F(fffx)) + val fffffx = egraph.add(F(ffffx)) + val ffffffx = egraph.add(F(fffffx)) + val fffffffx = egraph.add(F(ffffffx)) + val ffffffffx = egraph.add(F(fffffffx)) + + egraph.merge(ffffffffx, x) + egraph.merge(fffffx, x) + + + val base = List(ffffffffx === x, fffffx === x) + + + val test2 = Theorem(base |- fffx === x) { + egraph.proveInnerTerm(fffx, x, base |- ()) + } + val test3 = Theorem(base |- ffx === x) { + egraph.proveInnerTerm(ffx, x, base |- ()) + } + val test4 = Theorem(base |- fx === x) { + egraph.proveInnerTerm(fx, x, base |- ()) + } + + } + + + test("15 formulas no congruence proofs with redundant merges test with proofs") { + val egraph = new EGraphTerms() + egraph.add(af) + egraph.add(bf) + egraph.add(cf) + egraph.add(df) + egraph.add(ef) + egraph.add(ff) + egraph.add(gf) + egraph.add(hf) + egraph.add(if_) + egraph.add(jf) + egraph.add(kf) + egraph.add(lf) + egraph.add(mf) + egraph.add(nf) + egraph.add(of) + egraph.merge(af, cf) + egraph.merge(ef, ff) + egraph.merge(if_, kf) + egraph.merge(mf, nf) + egraph.merge(af, bf) + egraph.merge(of, mf) + egraph.merge(if_, mf) + egraph.merge(gf, hf) + egraph.merge(lf, kf) + egraph.merge(bf, cf) + egraph.merge(ff, ef) + egraph.merge(of, if_) + egraph.merge(gf, ef) + egraph.merge(if_, jf) + + val base = List(af <=> cf, ef <=> ff, if_ <=> kf, mf <=> nf, af <=> bf, + of <=> mf, if_ <=> mf, gf <=> hf, lf <=> kf, bf <=> cf, ff <=> ef, of <=> if_, gf <=> ef, if_ <=> jf) + + val test1 = Theorem(base |- bf <=> cf) { + egraph.proveInnerFormula(bf, cf, base |- ()) + } + + val test2 = Theorem(base |- ff <=> hf) { + egraph.proveInnerFormula(ff, hf, base |- ()) + } + + val test3 = Theorem(base |- if_ <=> of) { + egraph.proveInnerFormula(if_, of, base |- ()) + } + + val test4 = Theorem(base |- of <=> if_) { + egraph.proveInnerFormula(of, if_, base |- ()) + } + + } + + test("4 formulas with congruence test with proofs") { + val egraph = new EGraphTerms() + egraph.add(Ff(af)) + egraph.add(Ff(bf)) + egraph.merge(af, bf) + val test5 = Theorem(af <=> bf |- Ff(af) <=> Ff(bf)) { + egraph.proveInnerFormula(Ff(af), Ff(bf), (af <=> bf) |- ()) + } + } + + test("divide-mult-shift by 2 in formulas egraph test with proofs") { + val egraph = new EGraphTerms() + egraph.add(onef) + egraph.add(twof) + egraph.add(af) + val ax2 = egraph.add(`*f`(af, twof)) + val ax2_d2 = egraph.add(`/f`(`*f`(af, twof), twof)) + val `2d2` = egraph.add(`/f`(twof, twof)) + val ax_2d2 = egraph.add(`*f`(af, `/f`(twof, twof))) + val ax1 = egraph.add(`*f`(af, onef)) + val as1 = egraph.add(`<<f`(af, onef)) + + egraph.merge(ax2, as1) + egraph.merge(ax2_d2, ax_2d2) + egraph.merge(`2d2`, onef) + egraph.merge(ax1, af) + + val base = List[Formula](ax2 <=> as1, ax2_d2 <=> ax_2d2, `2d2` <=> onef, ax1 <=> af) + + val one_2d2 = Theorem(base |- onef <=> `2d2`) { + egraph.proveInnerFormula(onef, `2d2`, base |- ()) + } + + val ax2_as1 = Theorem(base |- ax2 <=> as1) { + egraph.proveInnerFormula(ax2, as1, base |- ()) + } + + val ax2_d2_ax_2d2 = Theorem(base |- ax2_d2 <=> ax_2d2) { + egraph.proveInnerFormula(ax2_d2, ax_2d2, base |- ()) + } + + val ax_2d2_ax1 = Theorem(base |- ax_2d2 <=> ax1) { + egraph.proveInnerFormula(ax_2d2, ax1, base |- ()) + } + + val ax_2d2_a = Theorem(base |- ax_2d2 <=> af) { + egraph.proveInnerFormula(ax_2d2, af, base |- ()) + } + + } + + test("long chain of formulas congruence eGraph with proofs") { + val egraph = new EGraphTerms() + egraph.add(xf) + val fx = egraph.add(Ff(xf)) + val ffx = egraph.add(Ff(fx)) + val fffx = egraph.add(Ff(ffx)) + val ffffx = egraph.add(Ff(fffx)) + val fffffx = egraph.add(Ff(ffffx)) + val ffffffx = egraph.add(Ff(fffffx)) + val fffffffx = egraph.add(Ff(ffffffx)) + val ffffffffx = egraph.add(Ff(fffffffx)) + + egraph.merge(ffffffffx, xf) + egraph.merge(fffffx, xf) + + val base = List(ffffffffx <=> xf, fffffx <=> xf) + + val test2 = Theorem(base |- fffx <=> xf) { + egraph.proveInnerFormula(fffx, xf, base |- ()) + } + val test3 = Theorem(base |- ffx <=> xf) { + egraph.proveInnerFormula(ffx, xf, base |- ()) + } + val test4 = Theorem(base |- fx <=> xf) { + egraph.proveInnerFormula(fx, xf, base |- ()) + } + } + + + test("2 terms 6 predicates with congruence egraph test with proofs") { + val egraph = new EGraphTerms() + egraph.add(Ff(Ff(Fp(a)))) + egraph.add(Ff(Ff(Fp(b)))) + egraph.merge(a, b) + + val test5 = Theorem((a === b) |- Fp(a) <=> Fp(b)) { + egraph.proveInnerFormula(Fp(a), Fp(b), (a === b) |- ()) + } + + val test6 = Theorem((a === b) |- Ff(Fp(a)) <=> Ff(Fp(b))) { + egraph.proveInnerFormula(Ff(Fp(a)), Ff(Fp(b)), (a === b) |- ()) + } + + val test7 = Theorem((a === b) |- Ff(Ff(Fp(a))) <=> Ff(Ff(Fp(b))) ) { + egraph.proveInnerFormula(Ff(Ff(Fp(a))), Ff(Ff(Fp(b))), (a === b) |- ()) + } + + } + + test("6 terms 6 predicates with congruence egraph test with proofs") { + val egraph = new EGraphTerms() + egraph.add(Ff(Ff(Fp(F(F(a)))))) + egraph.add(Ff(Ff(Fp(F(F(b)))))) + egraph.merge(a, b) + + val test5 = Theorem((a === b) |- (F(a) === F(b))) { + egraph.proveInnerTerm(F(a), F(b), (a === b) |- ()) + } + + val test6 = Theorem((a === b) |- Fp(F(F(a))) <=> Fp(F(F(b))) ) { + egraph.proveInnerFormula(Fp(F(F(a))), Fp(F(F(b))), (a === b) |- ()) + } + + val test7 = Theorem((a === b) |- Ff(Ff(Fp(F(F(a))))) <=> Ff(Ff(Fp(F(F(b))))) ) { + egraph.proveInnerFormula(Ff(Ff(Fp(F(F(a))))), Ff(Ff(Fp(F(F(b))))), (a === b) |- ()) + } + + egraph.merge(Fp(F(F(b))), Ff(Fp(F(F(a))))) + + val test8 = Theorem(((a === b), Fp(F(F(b))) <=> Ff(Fp(F(F(a)))) ) |- Ff(Ff(Fp(F(F(a))))) <=> Ff(Ff(Fp(F(F(b))))) ) { + egraph.proveInnerFormula(Ff(Ff(Fp(F(F(a))))), Ff(Ff(Fp(F(F(b))))), (a === b, Fp(F(F(b))) <=> Ff(Fp(F(F(a)))) ) |- ()) + } + + } + + test("Full congruence tactic tests") { + println("Full congruence tactic tests\n") + + val base1 = List(a === c, e === f, i === k, m === n, a === b, o === m, i === m, g === h, l === k, b === c, f === e, o === i, g === e, i === j) + + val test1 = Theorem(base1 |- (b === c)) { + have(thesis) by Congruence + } + + val test2 = Theorem(base1 |- (f === h)) { + have(thesis) by Congruence + } + + val test3 = Theorem(base1 |- (i === o)) { + have(thesis) by Congruence + } + + + val ax2 = `*`(a, two) + val ax2_d2 = `/`(`*`(a, two), two) + val `2d2` = `/`(two, two) + val ax_2d2 = `*`(a, `/`(two, two)) + val ax1 = `*`(a, one) + val as1 = `<<`(a, one) + + val base2 = List[Formula](ax2 === as1, ax2_d2 === ax_2d2, `2d2` === one, ax1 === a) + + + val one_2d2 = Theorem(base2 |- (one === `2d2`)) { + have(thesis) by Congruence + } + + val ax2_as1 = Theorem(base2 |- (ax2 === as1)) { + have(thesis) by Congruence + } + + val ax2_d2_ax_2d2 = Theorem(base2 |- (ax2_d2 === ax_2d2)) { + have(thesis) by Congruence + } + + val ax_2d2_ax1 = Theorem(base2 |- (ax_2d2 === ax1)) { + have(thesis) by Congruence + } + + val ax_2d2_a = Theorem(base2 |- (ax_2d2 === a)) { + have(thesis) by Congruence + } + + val ax_2d2_a_2 = Theorem(base2 |- (Fp(ax_2d2) <=> Fp(a))) { + have(thesis) by Congruence + } + + val ax_2d2_a_1 = Theorem((Fp(a) :: base2) |- Fp(ax_2d2)) { + have(thesis) by Congruence + } + + val ax_2d2_a_3 = Theorem((base2 :+ Fp(ax_2d2) :+ !Fp(a)) |- () ) { + have(thesis) by Congruence + } + + val test5 = Theorem(a===b |- F(a) === F(b)) { + have(thesis) by Congruence + } + + val test6 = Theorem(a === b |- F(a) === F(b)) { + have(thesis) by Congruence + } + + val test7 = Theorem((Ff(Ff(Ff(Ff(Ff(Ff(Ff(xf))))))) <=> xf, Ff(Ff(Ff(Ff(Ff(xf))))) <=> xf) |- Ff(Ff(Ff(xf))) <=> xf) { + have(thesis) by Congruence + } + + val test8 = Theorem((Ff(Ff(Ff(Ff(Ff(Ff(Ff(xf))))))) <=> xf, Ff(Ff(Ff(Ff(Ff(xf))))) <=> xf) |- Ff(xf) <=> xf) { + have(thesis) by Congruence + } + + val test9 = Theorem((a === b) |- (Fp(F(F(a))), !Fp(F(F(b)))) ) { + have(thesis) by Congruence + } + + val test10 = Theorem((a === b) |- Fp(F(F(a))) <=> Fp(F(F(b))) ) { + have(thesis) by Congruence + } + + + val test11 = Theorem((a === b) |- Ff(Ff(Fp(F(F(a))))) <=> Ff(Ff(Fp(F(F(b))))) ) { + have(thesis) by Congruence + } + + val test12 = Theorem(((a === b), Fp(F(F(b))) <=> Ff(Fp(F(F(a)))), Ff(Ff(Fp(F(F(a))))) ) |- Ff(Ff(Fp(F(F(b))))) ) { + have(thesis) by Congruence + } + + + } + + +} \ No newline at end of file diff --git a/lisa-sets/src/test/scala/lisa/utilities/TestMain.scala b/lisa-sets/src/test/scala/lisa/utilities/TestMain.scala new file mode 100644 index 0000000000000000000000000000000000000000..6d93e5d056ae495e3b7c257f4c45d96ecfd0e6e6 --- /dev/null +++ b/lisa-sets/src/test/scala/lisa/utilities/TestMain.scala @@ -0,0 +1,16 @@ +package lisa + +import lisa.prooflib.* + +trait TestMain extends lisa.Main { + + override val om: OutputManager = new OutputManager { + def finishOutput(exception: Exception): Nothing = { + log(exception) + main(Array[String]()) + throw exception + } + val stringWriter: java.io.StringWriter = new java.io.StringWriter() + } + +} diff --git a/lisa-utils/src/main/scala/lisa/fol/FOLHelpers.scala b/lisa-utils/src/main/scala/lisa/fol/FOLHelpers.scala index 402ad81ba096ca8f90d4f2a850645534c4a567b5..d9232d3bafc552283a712d6d10db9a1d87069e8c 100644 --- a/lisa-utils/src/main/scala/lisa/fol/FOLHelpers.scala +++ b/lisa-utils/src/main/scala/lisa/fol/FOLHelpers.scala @@ -50,6 +50,9 @@ object FOLHelpers { def predicate[N <: Arity: ValueOf](using name: sourcecode.Name): SchematicPredicateLabel[N] = SchematicPredicateLabel[N](name.value, valueOf[N]) def connector[N <: Arity: ValueOf](using name: sourcecode.Name): SchematicConnectorLabel[N] = SchematicConnectorLabel[N](name.value, valueOf[N]) + def freshVariable(using name: sourcecode.Name)(elems: LisaObject[?]*): Variable = Variable(freshId(elems.flatMap(_.freeVariables).map(_.id), name.value)) + def freshVariableFormula(using name: sourcecode.Name)(elems: LisaObject[?]*): VariableFormula = VariableFormula(freshId(elems.flatMap(_.freeVariables).map(_.id), name.value)) + //////////////////////////////////////// // Kernel to Front transformers // //////////////////////////////////////// diff --git a/lisa-utils/src/main/scala/lisa/prooflib/BasicMain.scala b/lisa-utils/src/main/scala/lisa/prooflib/BasicMain.scala index 34f8878fd37e4cbb4b616acd77bf819abcc4bac1..748f58d5151de5b0fd66d226342e6f9c421bb018 100644 --- a/lisa-utils/src/main/scala/lisa/prooflib/BasicMain.scala +++ b/lisa-utils/src/main/scala/lisa/prooflib/BasicMain.scala @@ -7,7 +7,7 @@ trait BasicMain { private val realOutput: String => Unit = println - given om: OutputManager = new OutputManager { + val om: OutputManager = new OutputManager { def finishOutput(exception: Exception): Nothing = { log(exception) main(Array[String]()) @@ -24,4 +24,6 @@ trait BasicMain { realOutput(om.stringWriter.toString) } + given om.type = om + } diff --git a/lisa-utils/src/main/scala/lisa/prooflib/ProofsHelpers.scala b/lisa-utils/src/main/scala/lisa/prooflib/ProofsHelpers.scala index adda8986c435984a2a7e9d8b70411f827e8abe9b..946a16f2e1177992731d81c08adf5197c240ac1d 100644 --- a/lisa-utils/src/main/scala/lisa/prooflib/ProofsHelpers.scala +++ b/lisa-utils/src/main/scala/lisa/prooflib/ProofsHelpers.scala @@ -21,8 +21,8 @@ trait ProofsHelpers { given Library = library - class HaveSequent /*private[ProofsHelpers]*/ (val bot: Sequent) { - // val x: lisa.fol.FOL.Sequent = bot + class HaveSequent(val bot: Sequent) { + inline infix def by(using proof: library.Proof, line: sourcecode.Line, file: sourcecode.File): By { val _proof: proof.type } = By(proof, line, file).asInstanceOf class By(val _proof: library.Proof, line: sourcecode.Line, file: sourcecode.File) { diff --git a/refman/lisa.pdf b/refman/lisa.pdf index 80036345874145be30b7dc8a7edc9bd947f413f0..d22d7da5a928e5881571fad1be52a1ad76fd4a2c 100644 Binary files a/refman/lisa.pdf and b/refman/lisa.pdf differ diff --git a/refman/tactics.tex b/refman/tactics.tex index d3c92691744ae5590d8e1518f379be04c8247ae7..c7030feb64140adf56603f65b8252fa62427b108 100644 --- a/refman/tactics.tex +++ b/refman/tactics.tex @@ -1,6 +1,59 @@ \chapter{Tactics: Specifications and Use} \label{chapt:tactics} +\subsection*{Congruence} +The \lstinline|Congruence| tactic is used to prove sequents whose validity directly follow from the congruence closure of all equalities and formula equivalences given left of the sequent. +Specifically, it works in the following cases: +\begin{itemize} + \item The right side contains an equality s === t or equivalence a <=> b provable in the congruence closure. + \item The left side contains an negated equality !(s === t) or equivalence !(a <=> b) provable in the congruence closure. + \item There is a formula a on the left and b on the right such that a and b are congruent. + \item There are two formulas a and !b on the left such that a and b are congruent. + \item There are two formulas a and !b on the right such that a and b are congruent. + \item The sequent is Ol-valid without equality reasoning +\end{itemize} +Note that congruence closure modulo OL is an open problem. + +\begin{example} + The following statements are provable by \lstinline|Congruence|: +\newline\begin{lstlisting}[language=lisa, frame=single] +val congruence1 = Theorem ((a === b, b === c) |- f(a) === f(c)) { + have(thesis) by Congruence +} +\end{lstlisting} + +\begin{lstlisting}[language=lisa, frame=single] +val congruence2 = Theorem ( + (F(F(F(F(F(F(F(x))))))) === x, F(F(F(F(F(x))))) === x) + |- (F(x) === x) +) { + have(thesis) by Congruence +} +\end{lstlisting} + +\begin{lstlisting}[language=lisa, frame=single] +val congruence3 = Theorem ( + (a === b, b === c, P(f(c)) <=> Q, P(f(a))) + |- Q +) { + have(thesis) by Congruence +} +\end{lstlisting} + +\end{example} + +The tactic computes the congruence closure of all terms and formulas, with respect to the given equalities and equivalences, using an egraph datastructure \cite{willseyEggFastExtensible2021, nelsonFastDecisionProcedures1980}. The egraph contains two union-find datastructure which maintain equivalence classes of formulas and terms, respectively. The union-finds are equiped with an explain method, which can output a path of equalities between any two points in the same equivalence class, as in \cite{nelsonFastDecisionProcedures1980}. Each such equality can come from the left hand-side of the sequent being proven (we call those \textit{external equalities}), or be consequences of congruence. For an equality labelled by a congruence, the equalities between all children terms can recursively be explained. + +\begin{example} + Consider again the sequent + $$ + a = b, b = c \vdash f(a) = f(c) + $$ + the domain of our egraph is$\lbrace a, b, c, f(a), f(c) \rbrace$. When $a$ and $b$ are merged and then $b$ and $c$ are emrged, the egraph detects that $f(a)$ and $f(c)$ are congruent and should also be merged. The explanation of $f(a) = f(c)$ is then \lstinline|Congruence($f(a)$, $f(c)$)|, and the explanation of $a = c$ is \lstinline|External($a$, $b$), External($b$, $c$)|. +\end{example} + +Once the congruence closure is computed, the tactic checks if the sequent is satisfies any of the above conditions and returns a proof if it does (and otherwise fails). + \subsection*{Goeland} Goeland\cite{DBLP:conf/cade/CaillerRDRB22} is an Automated Theorem prover for first order logic. The Goeland tactic exports a statement in SC-TPTP format, and call Goeland to prove it. Goeland produce a proof file in the SC-TPTP format, from which Lisa rebuilds a kernel proof. \paragraph*{Usage}.