diff --git a/src/orderedsets/Main.scala b/src/orderedsets/Main.scala index 27e684f3d2a49553e32982c1fab0cc276781943e..4a01636218aba3e0a13eda6d00e2f7d8101c574b 100644 --- a/src/orderedsets/Main.scala +++ b/src/orderedsets/Main.scala @@ -5,7 +5,6 @@ import purescala.Extensions.Solver import Reconstruction.Model class Main(reporter: Reporter) extends Solver(reporter) { - import ExprToASTConverter.ConversionException import purescala.Trees.Expr import AST.Formula val description = "BAPA with ordering" @@ -90,14 +89,16 @@ case class SatException(model: Model) extends Exception("A model was found") // Thrown when a contradiction was derived during guessing case class UnsatException(msg: String) extends Exception(msg) +case class ConversionException(expr: purescala.Trees.Expr, msg: String) extends RuntimeException(msg) + + // Convert PureScala expressions to OrdBAPA AST's object ExprToASTConverter { import purescala.TypeTrees._ import purescala.Trees._ import Primitives._ - case class ConversionException(expr: Expr, msg: String) extends RuntimeException(msg) - + private def isSetType(_type: TypeTree) = _type match { case SetType(_) => true case _ => false diff --git a/src/orderedsets/Unifier2.scala b/src/orderedsets/Unifier2.scala index c98d4a6707573ff65a981bae7c27824ddbb227ef..c2d2722ad6330b887a803ce3f044294d6157adfd 100644 --- a/src/orderedsets/Unifier2.scala +++ b/src/orderedsets/Unifier2.scala @@ -64,33 +64,144 @@ object Example extends Unifier2[String, String] { case RawVar(name) => Var(name) case RawFun(name, args) => Fun(name, args map raw2term) } + + def pv(str: String) = str + def pf(str: String) = str } +import scala.collection.mutable.ArrayBuffer +import purescala.Common._ +import purescala.Trees._ +import purescala.TypeTrees._ +import purescala.Definitions.CaseClassDef + +object PureScalaUnifier extends Unifier2[Variable,CaseClassDef] { + + def freshVar(typed: Typed) = Var(Variable(FreshIdentifier("UnifVar", true)) /*setType typed.getType*/) + + def pv(v: Variable) = v.id.toString + def pf(cc: CaseClassDef) = cc.id.toString + + def unify(and: And) { + val equalities = new ArrayBuffer[(Term,Term)]() + val inequalities = new ArrayBuffer[(Var,Var)]() + + def extractConstraint(expr: Expr) { expr match { + case Equals(t1, t2) => + equalities += ((convert(t1), convert(t2))) + case Not(Equals(t1, t2)) => + val x1 = freshVar(t1) + val x2 = freshVar(t2) + equalities += ((x1, convert(t1))) + equalities += ((x2, convert(t2))) + inequalities += ((x1, x2)) + case _ => + }} + def convert(expr: Expr): Term = expr match { + case v@Variable(id) => Var(v) + case CaseClass(ccdef, args) => Fun(ccdef, args map convert) + case CaseClassSelector(ex, sel) => + val CaseClassType(ccdef) = ex.getType + val args = ccdef.fields map freshVar + equalities += convert(ex) -> Fun(ccdef, args) + args(ccdef.fields findIndexOf {_.id == sel}) + case _ => throw ConversionException(expr, "Cannot convert : ") + } + // extract constraints + and.exprs foreach extractConstraint + + println + println("--- Input to the unifier ---") + for ((l,r) <- equalities) println(" " + pp(l) + " = " + pp(r)) + if (!inequalities.isEmpty) { + println("and") + for ((l,r) <- inequalities) println(" " + pp(l) + " != " + pp(r)) + } + println + + val mgu = unify(equalities.toList) + val subst = blowUp(mgu) + + def byName(entry1: (Variable,Term), entry2: (Variable,Term)) = + pv(entry1._1) < pv(entry2._1) + + //println + println("--- Output of the unifier (MGU) ---") + for ((x, t) <- mgu.toList sortWith byName) + println(" " + x + " = " + pp(t)) + println + + // check inequalities + for ((Var(x1), Var(x2)) <- inequalities) { + val t1 = subst(x1) + val t2 = subst(x2) + if (t1 == t2) + throw UnificationFailure("Inequality '" + x1.id + " != " + x2.id + "' does not hold") + } + if (!inequalities.isEmpty) + println("Inequalities were checked to hold\n") + + println("--- Output of the unifier (Substitution table) ---") + val subst1 = subst.filterKeys{_.getType != NoType} + for ((x, t) <- subst1.toList sortWith byName) + println(" " + x + " = " + pp(t)) + if (subst1.isEmpty) println(" (empty table)") + println + } + + +} + + import scala.collection.mutable.{ArrayBuffer => Seq, Map, Set, Stack} trait Unifier2[VarName >: Null, FunName >: Null] { + + type MGU = Seq[(VarName, Term)] + type Subst = Map[VarName, Term] + + // transitive closure for the mapping - the smart way + def blowUp(mgu: MGU): Subst = { + val map = Map.empty[VarName, Term] + def subst(term: Term): Term = term match { + case Var(v) => map get v match { + case Some(t) => t + case None => term + } + case Fun(f, args) => Fun(f, args map subst) + } + for ((v, t) <- mgu.reverse) { + map(v) = subst(t) + } + map + } /* Interface */ // The AST to be unified sealed abstract class Term case class Var(name: VarName) extends Term - case class Fun(name: FunName, args: List[Term]) extends Term + case class Fun(name: FunName, args: scala.collection.Seq[Term]) extends Term case class UnificationFailure(msg: String) extends Exception(msg) + + def pv(s: VarName): String + def pf(f: FunName): String + def _pv(s: VarName): String = if (s == null) "<null>" else pv(s) + def _pf(f: FunName): String = if (f == null) "<null>" else pf(f) def pp(t: Term): String = t match { - case Var(s) => "" + s - case Fun(f, ts) => "" + f + (ts map pp).mkString("(", ", ", ")") + case Var(s) => _pv(s) + case Fun(f, ts) => _pf(f) + (ts map pp).mkString("(", ", ", ")") } - def unify(term1: Term, term2: Term): Map[VarName, Term] = + def unify(term1: Term, term2: Term): MGU = unify(List((term1, term2))) - def unify(terms: List[(Term, Term)]): Map[VarName, Term] = { + def unify(terms: List[(Term, Term)]): MGU = { val variableMap = Map[VarName, Variable]() def convertTerm(term: Term): Equation = term match { case Var(name) => variableMap get name match { @@ -109,23 +220,23 @@ trait Unifier2[VarName >: Null, FunName >: Null] { dummyVariable.eqclass.eqn.fun = Some(new Function(null, Seq(frontier: _*))) val allVariables = Seq(dummyVariable) ++ variableMap.values - unify(allVariables map {_.eqclass}) - null + unify(allVariables map {_.eqclass}) filter {_._1 != null} } - /* Implementation */ + /* Data structures */ private case class Variable(name: VarName) { // The equivalence class for that variable var eqclass: Equivalence = new Equivalence(this) - override def toString = "" + name + override def toString = _pv(name) } private case class Function(val name: FunName, val eqns: Seq[Equation]) { - override def toString = name + eqns.mkString("(", ",", ")") + override def toString = _pf(name) + eqns.mkString("(", ",", ")") } - private class Equation(val vars: Set[Variable] = Set(), + private class Equation(val vars: Seq[Variable] = Seq(), var fun: Option[Function] = None) { - def this(v: Variable) = this (vars = Set(v)) + def this(v: Variable) = this (vars = Seq(v)) def this(f: Function) = this (fun = Some(f)) @@ -145,11 +256,18 @@ trait Unifier2[VarName >: Null, FunName >: Null] { override def toString = "[" + refCounter + "] " + eqn } - private def unify(equivalences: Seq[Equivalence]): Map[VarName, Term] = { + /* Implementation */ + + private def unify(equivalences: Seq[Equivalence]): MGU = { var numberOfClasses = equivalences.size - val substitutions = Map[VarName, Term]() + val substitutions = Seq[(VarName, Term)]() val freeClasses = Stack[Equivalence]() // Equations with a zero ref counter + /* + val vars = equivalences map {_.eqn.vars.head} + val fvars = Seq[Variable]() + */ + // Initialize reference counters def countRefs(fun: Function) { for (eqn <- fun.eqns) { @@ -170,30 +288,55 @@ trait Unifier2[VarName >: Null, FunName >: Null] { // Main loop while (numberOfClasses > 0) { + /* + println() + println("U:") + println(" vars : " + vars.mkString(", ")) + val classes = (vars map {_.eqclass}).toSet + println(classes.size + " / " + numberOfClasses) + for (cl <- classes) println(cl) + println("T: ") + println(" vars : " + fvars.mkString(", ")) + for ((v,t) <- substitutions) println(" " + v + " -> " + pp(t)) + */ + // Select multi equation if (freeClasses.isEmpty) throw UnificationFailure("cycle") val currentClass = freeClasses.pop val currentVars = currentClass.eqn.vars + val representative = Var(currentVars.head.name) + for (v <- currentVars.tail) + substitutions += (v.name -> representative) + currentClass.eqn.fun match { case Some(function) => val (commonPart, frontier) = reduce(function) - - for (v <- currentVars) - substitutions(v.name) = commonPart + substitutions += (representative.name -> commonPart) // Compact equations (i.e. merge equivalence classes) for (eqn <- frontier) { + /* + println(eqn) + */ val eqclass = (eqn.vars map {_.eqclass}) reduceLeft compact eqclass.refCounter -= eqn.vars.size + eqn.vars.clear merge(eqclass.eqn, eqn) if (eqclass.refCounter == 0) freeClasses push eqclass + + /* + println(" " + eqclass) + */ } case None => - val representative = Var(currentVars.head.name) - for (v <- currentVars.tail) substitutions(v.name) = representative } numberOfClasses -= 1 + + /* + vars --= currentVars + fvars ++= currentVars + */ } substitutions } diff --git a/src/orderedsets/UnifierMain.scala b/src/orderedsets/UnifierMain.scala new file mode 100644 index 0000000000000000000000000000000000000000..bc872cca4923a02c325824e1c39f694970b06760 --- /dev/null +++ b/src/orderedsets/UnifierMain.scala @@ -0,0 +1,51 @@ +package orderedsets + +import purescala.Reporter +import purescala.Extensions.Solver +import Reconstruction.Model + +class UnifierMain(reporter: Reporter) extends Solver(reporter) { + import purescala.Trees.{Expr, And, Not, Equals, negate, expandLets} + import DNF._ + + val description = "Unifier testbench" + override val shortDescription = "Unifier" + + // checks for V-A-L-I-D-I-T-Y ! + // Some(true) means formula is valid (negation is unsat) + // Some(false) means formula is not valid (negation is sat) + // None means you don't know. + // + def solve(exprWithLets: Expr): Option[Boolean] = { + val expr = expandLets(exprWithLets) + try { + reporter.info("") + expr match { + case and @ And(_) => + PureScalaUnifier.unify(and) + Some(true) + case Equals(_, _) | Not(Equals(_, _)) => + PureScalaUnifier.unify(And(List(expr))) + Some(true) + //None + case _ => throw ConversionException(expr, "Neither a conjunction nor a (in)equality") + } + + } catch { + case PureScalaUnifier.UnificationFailure(msg) => + reporter.info("Unification impossible : " + msg) + Some(false) + case ConversionException(badExpr, msg) => + reporter.info(msg + " : " + badExpr.getClass.toString) +// reporter.info(DNF.pp(badExpr)) + None + case e => + reporter.error("Unifier just crashed.\n exception = " + e.toString) + e.printStackTrace + None + } finally { + + } + } + +} diff --git a/testcases/UnificationTest.scala b/testcases/UnificationTest.scala new file mode 100644 index 0000000000000000000000000000000000000000..f308087fe5377128cb4b28099268b9e1c7ed568b --- /dev/null +++ b/testcases/UnificationTest.scala @@ -0,0 +1,73 @@ +package testcases + + +object UnificationTest { + + /* + sealed abstract class Value + case class X extends Value + case class Y extends Value + case class Z extends Value + */ + + sealed abstract class Tree + case class Leaf() extends Tree + case class Node(left: Tree, value: Int, right: Tree) extends Tree + + def mkTree(a: Int, b: Int, c: Int) = { + Node(Node(Leaf(), a, Leaf()), b, Node(Leaf(), c, Leaf())) + //Node(Leaf(), b, Node(Leaf(), c, Leaf())) + } ensuring ( res => { + res.left != Leaf() && + res.value == b && + res.right == Node(Leaf(), c, Leaf()) + }) + + + + sealed abstract class Term + case class F(t1: Term, t2: Term, t3: Term, t4: Term) extends Term + case class G(s1: Term, s2: Term) extends Term + case class H(r1: Term, r2: Term) extends Term + case class A extends Term + case class B extends Term + + def examplePage268(x1: Term, x2: Term, x3: Term, x4: Term, x5: Term) = { + F(G(H(A(), x5), x2), x1, H(A(), x4), x4) + } ensuring ( _ == F(x1, G(x2, x3), x2, B()) ) + + + + case class Tuple3(_1: Term, _2: Term, _3: Term) + + def examplePage269(x1: Term, x2: Term, x3: Term, x4: Term) = { + Tuple3(H(x1, x1), H(x2, x2), H(x3, x3)) + } ensuring ( res => { + x2 == res._1 && + x3 == res._2 && + x4 == res._3 + }) + + + + // Not working yet + + /* + def mkInfiniteTree(x: Int): Node = { + Node(mkInfiniteTree(x), x, mkInfiniteTree(x)) + } ensuring (res => + res.left != Leaf() && res.right != Leaf() + ) + + def insert(tree: Tree, value: Int) : Node = (tree match { + case Leaf() => Node(Leaf(), value, Leaf()) + case n @ Node(l, v, r) => if(v < value) { + Node(l, v, insert(r, value)) + } else if(v > value) { + Node(insert(l, value), v, r) + } else { + n + } + }) ensuring(_ != Leaf()) + */ +} \ No newline at end of file