From 22798b6d9854faccc36248e46f355df3c7c77c37 Mon Sep 17 00:00:00 2001 From: Robin Steiger <robin.steiger@epfl.ch> Date: Sat, 10 Jul 2010 10:43:46 +0000 Subject: [PATCH] orderedsets.UnifierMain now is a decision procedure. It performs: - DNF transformation - Separating ADT formulas from non-ADT formulas (purifying terms as needed) - Unification on ADT part - returns either VALID or UNKNOWN (sound, but incomplete) --- src/orderedsets/DNF.scala | 30 +--- .../{Unifier2.scala => Unifier.scala} | 95 ++++++----- src/orderedsets/UnifierMain.scala | 147 +++++++++++++++--- src/orderedsets/getAlpha.scala | 108 +------------ testcases/UnificationTest.scala | 21 +-- 5 files changed, 205 insertions(+), 196 deletions(-) rename src/orderedsets/{Unifier2.scala => Unifier.scala} (80%) diff --git a/src/orderedsets/DNF.scala b/src/orderedsets/DNF.scala index 0d47fae1e..28c124c51 100644 --- a/src/orderedsets/DNF.scala +++ b/src/orderedsets/DNF.scala @@ -3,28 +3,8 @@ package orderedsets object DNF { import purescala.Trees._ - - def dnf(expr: Expr): Stream[Expr] = _dnf(expr) map And.apply - - private def _dnf(expr: Expr): Stream[Seq[Expr]] = expr match { - case And(Nil) => Stream(Nil) - case And(c :: Nil) => _dnf(c) - case And(c :: cs) => - for (conj1 <- _dnf(c); conj2 <- _dnf(And(cs))) - yield conj1 ++ conj2 - case Or(Nil) => Stream(Seq(BooleanLiteral(false))) - case Or(d :: Nil) => _dnf(d) - case Or(d :: ds) => _dnf(d) append _dnf(Or(ds)) - // Rewrite Iff and Implies - case Iff(p, q) => - _dnf(Or(And(p, q), And(negate(p), negate(q)))) - case Implies(p, q) => - _dnf(Or(negate(p), q)) - // Convert to nnf - case Not(e@(And(_) | Or(_) | Iff(_, _) | Implies(_, _))) => - _dnf(negate(e)) - case _ => Stream(expr :: Nil) - } + import TreeOperations.dnf + // Printer (both && and || are printed as ? on windows..) def pp(expr: Expr): String = expr match { @@ -57,11 +37,13 @@ object DNF { } def test(before: Expr) { - val after = Or(dnf(before).toSeq) + val after = Or((dnf(before) map And.apply).toSeq) println println("Before dnf : " + pp(before)) //println("After dnf : " + pp(after)) println("After dnf : ") - for (and <- dnf(before)) println(" " + pp(and)) + for (and <- dnf(before)) println(" " + pp(And(and))) } + + } \ No newline at end of file diff --git a/src/orderedsets/Unifier2.scala b/src/orderedsets/Unifier.scala similarity index 80% rename from src/orderedsets/Unifier2.scala rename to src/orderedsets/Unifier.scala index 786892665..6245ffcda 100644 --- a/src/orderedsets/Unifier2.scala +++ b/src/orderedsets/Unifier.scala @@ -2,7 +2,7 @@ package orderedsets import scala.{Symbol => ScalaSymbol} -object Example extends Unifier2[String, String] { +object ExampleUnifier extends Unifier[String, String] { // Tests and Examples val examplePage262 = List( @@ -43,7 +43,7 @@ object Example extends Unifier2[String, String] { for ((v, t) <- unify(terms)) println(" " + v + " -> " + pp(t)) } catch { - case UnificationFailure(msg) => + case UnificationImpossible(msg) => println("Unification failed: " + msg) } } @@ -76,44 +76,47 @@ import purescala.Trees._ import purescala.TypeTrees._ import purescala.Definitions.CaseClassDef -object PureScalaUnifier extends Unifier2[Variable,CaseClassDef] { - - def freshVar(typed: Typed) = Var(Variable(FreshIdentifier("UnifVar", true)) /*setType typed.getType*/) - + + +object ADTUnifier extends Unifier[Variable,CaseClassDef] { + def pv(v: Variable) = v.id.toString def pf(cc: CaseClassDef) = cc.id.toString - def unify(and: Expr) { + def freshVar(prefix: String)(typed: Typed) = Var(Variable(FreshIdentifier(prefix, true) setType typed.getType)) + + def unify(conjunction: Seq[Expr]) { val equalities = new ArrayBuffer[(Term,Term)]() val inequalities = new ArrayBuffer[(Var,Var)]() - def extractConstraint(expr: Expr) { expr match { + def extractEquality(expr: Expr): Unit = expr match { case Equals(t1, t2) => - equalities += ((convert(t1), convert(t2))) + equalities += expr2term(t1) -> expr2term(t2) case Not(Equals(t1, t2)) => - val x1 = freshVar(t1) - val x2 = freshVar(t2) - equalities += ((x1, convert(t1))) - equalities += ((x2, convert(t2))) - inequalities += ((x1, x2)) - case _ => - }} - def convert(expr: Expr): Term = expr match { + inequalities += toPureTerm(t1) -> toPureTerm(t2) + case _ => error("Should not happen after separating the formula.") + } + def toPureTerm(expr: Expr) = expr2term(expr) match { + case v@Var(_) => v + case term => + val v = freshVar("Diseq")(expr) + equalities += v -> term + v + } + def expr2term(expr: Expr): Term = expr match { case v@Variable(id) => Var(v) - case CaseClass(ccdef, args) => Fun(ccdef, args map convert) + case CaseClass(ccdef, args) => Fun(ccdef, args map expr2term) case CaseClassSelector(ex, sel) => val CaseClassType(ccdef) = ex.getType - val args = ccdef.fields map freshVar - equalities += convert(ex) -> Fun(ccdef, args) + val args = ccdef.fields map freshVar("Sel") + equalities += expr2term(ex) -> Fun(ccdef, args) args(ccdef.fields findIndexOf {_.id == sel}) - case _ => throw ConversionException(expr, "Cannot convert : ") - } - // extract constraints - and match { - case And(exprs) => exprs foreach extractConstraint - case _ => extractConstraint(and) + case _ => error("Should not happen after separating the formula.") } + // extract equality constraints + conjunction foreach extractEquality + /* println println("--- Input to the unifier ---") for ((l,r) <- equalities) println(" " + pp(l) + " = " + pp(r)) @@ -122,10 +125,13 @@ object PureScalaUnifier extends Unifier2[Variable,CaseClassDef] { for ((l,r) <- inequalities) println(" " + pp(l) + " != " + pp(r)) } println - + */ + val mgu = unify(equalities.toList) - val subst = blowUp(mgu) + val map = blowUp(mgu) + def subst(v: Variable) = map getOrElse (v, Var(v)) + /* def byName(entry1: (Variable,Term), entry2: (Variable,Term)) = pv(entry1._1) < pv(entry2._1) @@ -134,38 +140,49 @@ object PureScalaUnifier extends Unifier2[Variable,CaseClassDef] { for ((x, t) <- mgu.toList sortWith byName) println(" " + x + " = " + pp(t)) println + */ // check inequalities for ((Var(x1), Var(x2)) <- inequalities) { val t1 = subst(x1) val t2 = subst(x2) if (t1 == t2) - throw UnificationFailure("Inequality '" + x1.id + " != " + x2.id + "' does not hold") + throw UnificationImpossible("Inequality '" + x1.id + " != " + x2.id + "' is violated (both reduce to " + pp(t1) + ")") } + + /* if (!inequalities.isEmpty) println("Inequalities were checked to hold\n") println("--- Output of the unifier (Substitution table) ---") - val subst1 = subst.filterKeys{_.getType != NoType} - for ((x, t) <- subst1.toList sortWith byName) + val map1 = map.filterKeys{_.getType != NoType} + for ((x, t) <- map1.toList sortWith byName) println(" " + x + " = " + pp(t)) - if (subst1.isEmpty) println(" (empty table)") - println + if (map1.isEmpty) println(" (empty table)") + println + */ + () } + def term2expr(term: Term): Expr = term match { + case Var(v) => v + case Fun(cd, args) => CaseClass(cd, args map term2expr) + } } import scala.collection.mutable.{ArrayBuffer => Seq, Map, Set, Stack} -trait Unifier2[VarName >: Null, FunName >: Null] { +case class UnificationImpossible(msg: String) extends Exception(msg) + +trait Unifier[VarName >: Null, FunName >: Null] { type MGU = Seq[(VarName, Term)] type Subst = Map[VarName, Term] - // transitive closure for the mapping - the smart way + // transitive closure for the mapping - the smart way (in only one iteration) def blowUp(mgu: MGU): Subst = { val map = Map.empty[VarName, Term] def subst(term: Term): Term = term match { @@ -188,7 +205,7 @@ trait Unifier2[VarName >: Null, FunName >: Null] { case class Var(name: VarName) extends Term case class Fun(name: FunName, args: scala.collection.Seq[Term]) extends Term - case class UnificationFailure(msg: String) extends Exception(msg) + def pv(s: VarName): String @@ -304,7 +321,7 @@ trait Unifier2[VarName >: Null, FunName >: Null] { */ // Select multi equation - if (freeClasses.isEmpty) throw UnificationFailure("cycle") + if (freeClasses.isEmpty) throw UnificationImpossible("cycle") val currentClass = freeClasses.pop val currentVars = currentClass.eqn.vars @@ -364,8 +381,8 @@ trait Unifier2[VarName >: Null, FunName >: Null] { case None => equation1.fun = equation2.fun case Some(Function(name1, args1)) => equation2.fun match { case Some(Function(name2, args2)) => - if (name1 != name2) throw UnificationFailure("clash") - if (args1.size != args2.size) throw UnificationFailure("arity") + if (name1 != name2) throw UnificationImpossible("clash") + if (args1.size != args2.size) throw UnificationImpossible("arity") val args = for ((eqn1, eqn2) <- args1 zip args2) yield merge(eqn1, eqn2) equation1.fun = Some(Function(name1, args)) case None => diff --git a/src/orderedsets/UnifierMain.scala b/src/orderedsets/UnifierMain.scala index aef220675..fa425637b 100644 --- a/src/orderedsets/UnifierMain.scala +++ b/src/orderedsets/UnifierMain.scala @@ -4,11 +4,13 @@ import purescala.Reporter import purescala.Extensions.Solver import Reconstruction.Model +case class IncompleteException(msg: String) extends Exception(msg) + class UnifierMain(reporter: Reporter) extends Solver(reporter) { - import purescala.Trees.{Expr, And, Not, Equals, negate, expandLets} - import DNF._ + import purescala.Trees._ + import TreeOperations._ - val description = "Unifier testbench" + val description = "Unifier for ADTs with abstractions" override val shortDescription = "Unifier" // checks for V-A-L-I-D-I-T-Y ! @@ -19,33 +21,142 @@ class UnifierMain(reporter: Reporter) extends Solver(reporter) { def solve(exprWithLets: Expr): Option[Boolean] = { val expr = expandLets(exprWithLets) try { - reporter.info("") - expr match { - case and @ And(_) => - PureScalaUnifier.unify(and) - Some(true) - case Equals(_, _) | Not(Equals(_, _)) => - PureScalaUnifier.unify(expr) - Some(true) - //None - case _ => throw ConversionException(expr, "Neither a conjunction nor a (in)equality") + var counter = 0 + for (conjunction <- dnf(negate(expr))) { + counter += 1 + reporter.info("Solving conjunction " + counter) + //conjunction foreach println + conjunction foreach checkIsSupported + try { + solve(conjunction) + } catch { + case UnificationImpossible(msg) => + reporter.info("Conjunction " + counter + " is UNSAT, unification impossible : " + msg) + } } - + // All conjunctions were UNSAT + Some(true) } catch { - case PureScalaUnifier.UnificationFailure(msg) => - reporter.info("Unification impossible : " + msg) - Some(false) case ConversionException(badExpr, msg) => reporter.info(msg + " : " + badExpr.getClass.toString) // reporter.info(DNF.pp(badExpr)) + //error("should not happen") + None + case IncompleteException(msg) => + reporter.info("Unifier cannot disprove this because it is incomplete") + if (msg != null) reporter.info(msg) None + case SatException(_) => + Some(false) case e => - reporter.error("Unifier just crashed.\n exception = " + e.toString) + reporter.error("Component 'Unifier' just crashed.\n Exception = " + e.toString) e.printStackTrace None } finally { } } + + def checkIsSupported(expr: Expr) { + def check(ex: Expr): Option[Expr] = ex match { + case IfExpr(_, _, _) | Let(_, _, _) | MatchExpr(_, _) => + throw ConversionException(ex, "Not supported") + case _ => None + } + searchAndReplace(check)(expr) + } + + def solve(conjunction: Seq[Expr]) { + val (treeEquations, rest) = separateADT(conjunction) + + /* + reporter.info("Fc") + treeEquations foreach println + reporter.info("Rest") + rest foreach println + */ + + val subst = ADTUnifier.unify(treeEquations) + + throw IncompleteException(null) + + () + } + + /* Step 1 : Do DNF transformation (done elsewhere) */ + + /* Step 2 : Split conjunction into (FT, Rest) purifying terms if needed. + * FT are equations over ADT trees. + * We allow element variables to appear in FT. + * Later, we will also allow element variables to appear in FT, + * but this has not been implemented yet. + */ + + import scala.collection.mutable.{Stack,ArrayBuffer} + import purescala._ + import Common.FreshIdentifier + import TypeTrees.Typed + + def freshVar(prefix: String, typed: Typed) = Variable(FreshIdentifier(prefix, true) setType typed.getType) + + + def separateADT(conjunction: Seq[Expr]) = { + val workStack = Stack(conjunction.reverse: _*) + val good = new ArrayBuffer[Expr]() // Formulas over ADTs + val bad = new ArrayBuffer[Expr]() // Formulas of unknown logic + // TODO: Allow literals in unifier ? + def isGood(expr: Expr) = expr match { + case Variable(_) | CaseClass(_, _) | CaseClassSelector(_, _) => true + case _ => false + } + def isBad(expr: Expr) = expr match { + case CaseClass(_, _) | CaseClassSelector(_, _) => false + case _ => true + } + def purifyGood(expr: Expr) = if (isGood(expr)) None else { + val fresh = freshVar("col", expr) + workStack push Equals(fresh, expr) // will be bad +// println("PUSH bad : " + isBad(expr) + " " + expr) + Some(fresh) + } + def purifyBad(expr: Expr) = if (isBad(expr)) None else { + val fresh = freshVar("adt", expr) + workStack push Equals(fresh, expr) // will be good +// println("PUSH good : " + isGood(expr) + " " + expr) + Some(fresh) + } + def process(expr: Expr): Unit = expr match { + case Equals(t1, t2) if isGood(t1) && isGood(t2) => +// println("POP good : " + expr) + val g1 = searchAndReplace(purifyGood)(t1) + val g2 = searchAndReplace(purifyGood)(t2) + good += Equals(g1, g2) +// println("ADD good : " + Equals(g1, g2)) + case Not(Equals(t1, t2)) if isGood(t1) && isGood(t2) => +// println("POP good2 : " + expr) + val g1 = searchAndReplace(purifyGood)(t1) + val g2 = searchAndReplace(purifyGood)(t2) + good += Not(Equals(g1, g2)) +// println("ADD good2 : " + Not(Equals(g1, g2))) + case Not(Not(ex)) => + process(ex) + case _ => +// println("POP bad : " + expr) + val t = searchAndReplace(purifyBad)(expr) + bad += t +// println("ADD bad : " + t) + } + while (!workStack.isEmpty) { + val expr = workStack.pop + process(expr) + } + (good.toSeq, bad.toSeq) + } + + /* Step 3 : Perform unifcation on equations over ADTs. + * Obtain a substitution u = T(t) and + * disequalites N(u,t) over ADT variables, and + * get implied (dis)equalities FE over element variables. + */ } diff --git a/src/orderedsets/getAlpha.scala b/src/orderedsets/getAlpha.scala index 91f81970c..993058e4d 100644 --- a/src/orderedsets/getAlpha.scala +++ b/src/orderedsets/getAlpha.scala @@ -1,120 +1,26 @@ package orderedsets +import TreeOperations._ + import purescala._ import Trees._ import Common._ import TypeTrees._ import Definitions._ -import scala.collection.mutable.{Set => MutableSet} object getAlpha { var program: Program = null def setProgram(p: Program) = program = p - def searchAndApply(subst: Expr => Option[Expr], recursive: Boolean = true)(expr: Expr) = { - def rec(ex: Expr, skip: Expr = null): Expr = (if(ex == skip) None else subst(ex)) match { - case Some(newExpr) => { - if(newExpr.getType == NoType) { - Settings.reporter.warning("REPLACING IN EXPRESSION WITH AN UNTYPED TREE ! " + ex + " --to--> " + newExpr) - } - if(ex == newExpr) - if(recursive) rec(ex, ex) else ex - else - if(recursive) rec(newExpr) else newExpr - } - case None => ex match { - case l @ Let(i,e,b) => { - val re = rec(e) - val rb = rec(b) - if(re != e || rb != b) - Let(i, re, rb).setType(l.getType) - else - l - } - case f @ FunctionInvocation(fd, args) => { - var change = false - val rargs = args.map(a => { - val ra = rec(a) - if(ra != a) { - change = true - ra - } else { - a - } - }) - if(change) - FunctionInvocation(fd, rargs).setType(f.getType) - else - f - } - case i @ IfExpr(t1,t2,t3) => IfExpr(rec(t1),rec(t2),rec(t3)).setType(i.getType) - case m @ MatchExpr(scrut,cses) => MatchExpr(rec(scrut), cses.map(inCase(_))).setType(m.getType) - case And(exs) => And(exs.map(rec(_))) - case Or(exs) => Or(exs.map(rec(_))) - case Not(e) => Not(rec(e)) - case u @ UnaryOperator(t,recons) => { - val r = rec(t) - if(r != t) - recons(r).setType(u.getType) - else - u - } - case b @ BinaryOperator(t1,t2,recons) => { - val r1 = rec(t1) - val r2 = rec(t2) - if(r1 != t1 || r2 != t2) - recons(r1,r2).setType(b.getType) - else - b - } - case c @ CaseClass(cd, args) => { - CaseClass(cd, args.map(rec(_))).setType(c.getType) - } - case c @ CaseClassSelector(cc, sel) => { - val rc = rec(cc) - if(rc != cc) - CaseClassSelector(rc, sel).setType(c.getType) - else - c - } - case _ => ex - } - } - - def inCase(cse: MatchCase) : MatchCase = cse match { - case SimpleCase(pat, rhs) => SimpleCase(pat, rec(rhs)) - case GuardedCase(pat, guard, rhs) => GuardedCase(pat, rec(guard), rec(rhs)) - } - - rec(expr) - } - - def asCatamorphism(f : FunDef) : Option[Seq[(CaseClassDef,Identifier,Seq[Identifier],Expr)]] = { - def recCallsOnMatchedVars(l: (CaseClassDef,Identifier,Seq[Identifier],Expr)) = { - var varSet = MutableSet.empty[Identifier] - searchAndApply({ case FunctionInvocation(_, Seq(Variable(id))) => varSet += id; None; case _ => None})(l._4) - varSet.subsetOf(l._3.toSet) - } - - val c = this.program.callees(f) - if(f.hasImplementation && f.args.size == 1 && c.size == 1 && c.head == f) f.body.get match { - case SimplePatternMatching(scrut, _, lstMatches) - if (scrut == f.args.head.toVariable) && lstMatches.forall(recCallsOnMatchedVars) => Some(lstMatches) - case _ => None - } else { - None - } - } - def isAlpha(varMap: Variable => Expr)(t: Expr): Option[Expr] = t match { - case FunctionInvocation(fd, Seq(v@ Variable(_))) => asCatamorphism(fd) match { + case FunctionInvocation(fd, Seq(v@ Variable(_))) => asCatamorphism(program, fd) match { case None => None case Some(lstMatch) => varMap(v) match { case CaseClass(cd, args) => { val (_, _, ids, rhs) = lstMatch.find( _._1 == cd).get - val repMap = Map( ids.map(id => Variable(id):Expr).zip(args):_* ) - Some(searchAndApply(repMap.get)(rhs)) + val repMap = Map( ids.map(id => Variable(id):Expr).zip(args): _* ) + Some(searchAndReplace(repMap.get)(rhs)) } case u @ Variable(_) => { val c = Variable(FreshIdentifier("Coll", true)).setType(t.getType) @@ -129,11 +35,11 @@ object getAlpha { } def apply(t: Expr, varMap: Variable => Expr): Expr = { - searchAndApply(isAlpha(varMap))(t) + searchAndReplace(isAlpha(varMap))(t) } // def solve(e : Expr): Option[Boolean] = { -// searchAndApply(isAlpha(x => x))(e) +// searchAndReplace(isAlpha(x => x))(e) // None // } } diff --git a/testcases/UnificationTest.scala b/testcases/UnificationTest.scala index f308087fe..60f06c4e4 100644 --- a/testcases/UnificationTest.scala +++ b/testcases/UnificationTest.scala @@ -3,17 +3,11 @@ package testcases object UnificationTest { - /* - sealed abstract class Value - case class X extends Value - case class Y extends Value - case class Z extends Value - */ - sealed abstract class Tree case class Leaf() extends Tree case class Node(left: Tree, value: Int, right: Tree) extends Tree + // Proved by unifier def mkTree(a: Int, b: Int, c: Int) = { Node(Node(Leaf(), a, Leaf()), b, Node(Leaf(), c, Leaf())) //Node(Leaf(), b, Node(Leaf(), c, Leaf())) @@ -34,7 +28,7 @@ object UnificationTest { def examplePage268(x1: Term, x2: Term, x3: Term, x4: Term, x5: Term) = { F(G(H(A(), x5), x2), x1, H(A(), x4), x4) - } ensuring ( _ == F(x1, G(x2, x3), x2, B()) ) + } //ensuring ( _ == F(x1, G(x2, x3), x2, B()) ) @@ -42,23 +36,22 @@ object UnificationTest { def examplePage269(x1: Term, x2: Term, x3: Term, x4: Term) = { Tuple3(H(x1, x1), H(x2, x2), H(x3, x3)) - } ensuring ( res => { + } /*ensuring ( res => { x2 == res._1 && x3 == res._2 && x4 == res._3 - }) - + })*/ - // Not working yet - /* + // Proved by unifier def mkInfiniteTree(x: Int): Node = { Node(mkInfiniteTree(x), x, mkInfiniteTree(x)) } ensuring (res => res.left != Leaf() && res.right != Leaf() ) + // Cannot be solved yet, due to the presence of an if expression def insert(tree: Tree, value: Int) : Node = (tree match { case Leaf() => Node(Leaf(), value, Leaf()) case n @ Node(l, v, r) => if(v < value) { @@ -69,5 +62,5 @@ object UnificationTest { n } }) ensuring(_ != Leaf()) - */ + } \ No newline at end of file -- GitLab