Skip to content
Snippets Groups Projects
Commit 67e34977 authored by Robin Steiger's avatar Robin Steiger
Browse files

Added an extension for unification

parent 4c4cd9fc
No related branches found
No related tags found
No related merge requests found
...@@ -5,7 +5,6 @@ import purescala.Extensions.Solver ...@@ -5,7 +5,6 @@ import purescala.Extensions.Solver
import Reconstruction.Model import Reconstruction.Model
class Main(reporter: Reporter) extends Solver(reporter) { class Main(reporter: Reporter) extends Solver(reporter) {
import ExprToASTConverter.ConversionException
import purescala.Trees.Expr import purescala.Trees.Expr
import AST.Formula import AST.Formula
val description = "BAPA with ordering" val description = "BAPA with ordering"
...@@ -90,14 +89,16 @@ case class SatException(model: Model) extends Exception("A model was found") ...@@ -90,14 +89,16 @@ case class SatException(model: Model) extends Exception("A model was found")
// Thrown when a contradiction was derived during guessing // Thrown when a contradiction was derived during guessing
case class UnsatException(msg: String) extends Exception(msg) case class UnsatException(msg: String) extends Exception(msg)
case class ConversionException(expr: purescala.Trees.Expr, msg: String) extends RuntimeException(msg)
// Convert PureScala expressions to OrdBAPA AST's // Convert PureScala expressions to OrdBAPA AST's
object ExprToASTConverter { object ExprToASTConverter {
import purescala.TypeTrees._ import purescala.TypeTrees._
import purescala.Trees._ import purescala.Trees._
import Primitives._ import Primitives._
case class ConversionException(expr: Expr, msg: String) extends RuntimeException(msg)
private def isSetType(_type: TypeTree) = _type match { private def isSetType(_type: TypeTree) = _type match {
case SetType(_) => true case SetType(_) => true
case _ => false case _ => false
......
...@@ -64,33 +64,144 @@ object Example extends Unifier2[String, String] { ...@@ -64,33 +64,144 @@ object Example extends Unifier2[String, String] {
case RawVar(name) => Var(name) case RawVar(name) => Var(name)
case RawFun(name, args) => Fun(name, args map raw2term) case RawFun(name, args) => Fun(name, args map raw2term)
} }
def pv(str: String) = str
def pf(str: String) = str
} }
import scala.collection.mutable.ArrayBuffer
import purescala.Common._
import purescala.Trees._
import purescala.TypeTrees._
import purescala.Definitions.CaseClassDef
object PureScalaUnifier extends Unifier2[Variable,CaseClassDef] {
def freshVar(typed: Typed) = Var(Variable(FreshIdentifier("UnifVar", true)) /*setType typed.getType*/)
def pv(v: Variable) = v.id.toString
def pf(cc: CaseClassDef) = cc.id.toString
def unify(and: And) {
val equalities = new ArrayBuffer[(Term,Term)]()
val inequalities = new ArrayBuffer[(Var,Var)]()
def extractConstraint(expr: Expr) { expr match {
case Equals(t1, t2) =>
equalities += ((convert(t1), convert(t2)))
case Not(Equals(t1, t2)) =>
val x1 = freshVar(t1)
val x2 = freshVar(t2)
equalities += ((x1, convert(t1)))
equalities += ((x2, convert(t2)))
inequalities += ((x1, x2))
case _ =>
}}
def convert(expr: Expr): Term = expr match {
case v@Variable(id) => Var(v)
case CaseClass(ccdef, args) => Fun(ccdef, args map convert)
case CaseClassSelector(ex, sel) =>
val CaseClassType(ccdef) = ex.getType
val args = ccdef.fields map freshVar
equalities += convert(ex) -> Fun(ccdef, args)
args(ccdef.fields findIndexOf {_.id == sel})
case _ => throw ConversionException(expr, "Cannot convert : ")
}
// extract constraints
and.exprs foreach extractConstraint
println
println("--- Input to the unifier ---")
for ((l,r) <- equalities) println(" " + pp(l) + " = " + pp(r))
if (!inequalities.isEmpty) {
println("and")
for ((l,r) <- inequalities) println(" " + pp(l) + " != " + pp(r))
}
println
val mgu = unify(equalities.toList)
val subst = blowUp(mgu)
def byName(entry1: (Variable,Term), entry2: (Variable,Term)) =
pv(entry1._1) < pv(entry2._1)
//println
println("--- Output of the unifier (MGU) ---")
for ((x, t) <- mgu.toList sortWith byName)
println(" " + x + " = " + pp(t))
println
// check inequalities
for ((Var(x1), Var(x2)) <- inequalities) {
val t1 = subst(x1)
val t2 = subst(x2)
if (t1 == t2)
throw UnificationFailure("Inequality '" + x1.id + " != " + x2.id + "' does not hold")
}
if (!inequalities.isEmpty)
println("Inequalities were checked to hold\n")
println("--- Output of the unifier (Substitution table) ---")
val subst1 = subst.filterKeys{_.getType != NoType}
for ((x, t) <- subst1.toList sortWith byName)
println(" " + x + " = " + pp(t))
if (subst1.isEmpty) println(" (empty table)")
println
}
}
import scala.collection.mutable.{ArrayBuffer => Seq, Map, Set, Stack} import scala.collection.mutable.{ArrayBuffer => Seq, Map, Set, Stack}
trait Unifier2[VarName >: Null, FunName >: Null] { trait Unifier2[VarName >: Null, FunName >: Null] {
type MGU = Seq[(VarName, Term)]
type Subst = Map[VarName, Term]
// transitive closure for the mapping - the smart way
def blowUp(mgu: MGU): Subst = {
val map = Map.empty[VarName, Term]
def subst(term: Term): Term = term match {
case Var(v) => map get v match {
case Some(t) => t
case None => term
}
case Fun(f, args) => Fun(f, args map subst)
}
for ((v, t) <- mgu.reverse) {
map(v) = subst(t)
}
map
}
/* Interface */ /* Interface */
// The AST to be unified // The AST to be unified
sealed abstract class Term sealed abstract class Term
case class Var(name: VarName) extends Term case class Var(name: VarName) extends Term
case class Fun(name: FunName, args: List[Term]) extends Term case class Fun(name: FunName, args: scala.collection.Seq[Term]) extends Term
case class UnificationFailure(msg: String) extends Exception(msg) case class UnificationFailure(msg: String) extends Exception(msg)
def pv(s: VarName): String
def pf(f: FunName): String
def _pv(s: VarName): String = if (s == null) "<null>" else pv(s)
def _pf(f: FunName): String = if (f == null) "<null>" else pf(f)
def pp(t: Term): String = t match { def pp(t: Term): String = t match {
case Var(s) => "" + s case Var(s) => _pv(s)
case Fun(f, ts) => "" + f + (ts map pp).mkString("(", ", ", ")") case Fun(f, ts) => _pf(f) + (ts map pp).mkString("(", ", ", ")")
} }
def unify(term1: Term, term2: Term): Map[VarName, Term] = def unify(term1: Term, term2: Term): MGU =
unify(List((term1, term2))) unify(List((term1, term2)))
def unify(terms: List[(Term, Term)]): Map[VarName, Term] = { def unify(terms: List[(Term, Term)]): MGU = {
val variableMap = Map[VarName, Variable]() val variableMap = Map[VarName, Variable]()
def convertTerm(term: Term): Equation = term match { def convertTerm(term: Term): Equation = term match {
case Var(name) => variableMap get name match { case Var(name) => variableMap get name match {
...@@ -109,23 +220,23 @@ trait Unifier2[VarName >: Null, FunName >: Null] { ...@@ -109,23 +220,23 @@ trait Unifier2[VarName >: Null, FunName >: Null] {
dummyVariable.eqclass.eqn.fun = Some(new Function(null, Seq(frontier: _*))) dummyVariable.eqclass.eqn.fun = Some(new Function(null, Seq(frontier: _*)))
val allVariables = Seq(dummyVariable) ++ variableMap.values val allVariables = Seq(dummyVariable) ++ variableMap.values
unify(allVariables map {_.eqclass}) - null unify(allVariables map {_.eqclass}) filter {_._1 != null}
} }
/* Implementation */ /* Data structures */
private case class Variable(name: VarName) { private case class Variable(name: VarName) {
// The equivalence class for that variable // The equivalence class for that variable
var eqclass: Equivalence = new Equivalence(this) var eqclass: Equivalence = new Equivalence(this)
override def toString = "" + name override def toString = _pv(name)
} }
private case class Function(val name: FunName, val eqns: Seq[Equation]) { private case class Function(val name: FunName, val eqns: Seq[Equation]) {
override def toString = name + eqns.mkString("(", ",", ")") override def toString = _pf(name) + eqns.mkString("(", ",", ")")
} }
private class Equation(val vars: Set[Variable] = Set(), private class Equation(val vars: Seq[Variable] = Seq(),
var fun: Option[Function] = None) { var fun: Option[Function] = None) {
def this(v: Variable) = this (vars = Set(v)) def this(v: Variable) = this (vars = Seq(v))
def this(f: Function) = this (fun = Some(f)) def this(f: Function) = this (fun = Some(f))
...@@ -145,11 +256,18 @@ trait Unifier2[VarName >: Null, FunName >: Null] { ...@@ -145,11 +256,18 @@ trait Unifier2[VarName >: Null, FunName >: Null] {
override def toString = "[" + refCounter + "] " + eqn override def toString = "[" + refCounter + "] " + eqn
} }
private def unify(equivalences: Seq[Equivalence]): Map[VarName, Term] = { /* Implementation */
private def unify(equivalences: Seq[Equivalence]): MGU = {
var numberOfClasses = equivalences.size var numberOfClasses = equivalences.size
val substitutions = Map[VarName, Term]() val substitutions = Seq[(VarName, Term)]()
val freeClasses = Stack[Equivalence]() // Equations with a zero ref counter val freeClasses = Stack[Equivalence]() // Equations with a zero ref counter
/*
val vars = equivalences map {_.eqn.vars.head}
val fvars = Seq[Variable]()
*/
// Initialize reference counters // Initialize reference counters
def countRefs(fun: Function) { def countRefs(fun: Function) {
for (eqn <- fun.eqns) { for (eqn <- fun.eqns) {
...@@ -170,30 +288,55 @@ trait Unifier2[VarName >: Null, FunName >: Null] { ...@@ -170,30 +288,55 @@ trait Unifier2[VarName >: Null, FunName >: Null] {
// Main loop // Main loop
while (numberOfClasses > 0) { while (numberOfClasses > 0) {
/*
println()
println("U:")
println(" vars : " + vars.mkString(", "))
val classes = (vars map {_.eqclass}).toSet
println(classes.size + " / " + numberOfClasses)
for (cl <- classes) println(cl)
println("T: ")
println(" vars : " + fvars.mkString(", "))
for ((v,t) <- substitutions) println(" " + v + " -> " + pp(t))
*/
// Select multi equation // Select multi equation
if (freeClasses.isEmpty) throw UnificationFailure("cycle") if (freeClasses.isEmpty) throw UnificationFailure("cycle")
val currentClass = freeClasses.pop val currentClass = freeClasses.pop
val currentVars = currentClass.eqn.vars val currentVars = currentClass.eqn.vars
val representative = Var(currentVars.head.name)
for (v <- currentVars.tail)
substitutions += (v.name -> representative)
currentClass.eqn.fun match { currentClass.eqn.fun match {
case Some(function) => case Some(function) =>
val (commonPart, frontier) = reduce(function) val (commonPart, frontier) = reduce(function)
substitutions += (representative.name -> commonPart)
for (v <- currentVars)
substitutions(v.name) = commonPart
// Compact equations (i.e. merge equivalence classes) // Compact equations (i.e. merge equivalence classes)
for (eqn <- frontier) { for (eqn <- frontier) {
/*
println(eqn)
*/
val eqclass = (eqn.vars map {_.eqclass}) reduceLeft compact val eqclass = (eqn.vars map {_.eqclass}) reduceLeft compact
eqclass.refCounter -= eqn.vars.size eqclass.refCounter -= eqn.vars.size
eqn.vars.clear
merge(eqclass.eqn, eqn) merge(eqclass.eqn, eqn)
if (eqclass.refCounter == 0) freeClasses push eqclass if (eqclass.refCounter == 0) freeClasses push eqclass
/*
println(" " + eqclass)
*/
} }
case None => case None =>
val representative = Var(currentVars.head.name)
for (v <- currentVars.tail) substitutions(v.name) = representative
} }
numberOfClasses -= 1 numberOfClasses -= 1
/*
vars --= currentVars
fvars ++= currentVars
*/
} }
substitutions substitutions
} }
......
package orderedsets
import purescala.Reporter
import purescala.Extensions.Solver
import Reconstruction.Model
class UnifierMain(reporter: Reporter) extends Solver(reporter) {
import purescala.Trees.{Expr, And, Not, Equals, negate, expandLets}
import DNF._
val description = "Unifier testbench"
override val shortDescription = "Unifier"
// checks for V-A-L-I-D-I-T-Y !
// Some(true) means formula is valid (negation is unsat)
// Some(false) means formula is not valid (negation is sat)
// None means you don't know.
//
def solve(exprWithLets: Expr): Option[Boolean] = {
val expr = expandLets(exprWithLets)
try {
reporter.info("")
expr match {
case and @ And(_) =>
PureScalaUnifier.unify(and)
Some(true)
case Equals(_, _) | Not(Equals(_, _)) =>
PureScalaUnifier.unify(And(List(expr)))
Some(true)
//None
case _ => throw ConversionException(expr, "Neither a conjunction nor a (in)equality")
}
} catch {
case PureScalaUnifier.UnificationFailure(msg) =>
reporter.info("Unification impossible : " + msg)
Some(false)
case ConversionException(badExpr, msg) =>
reporter.info(msg + " : " + badExpr.getClass.toString)
// reporter.info(DNF.pp(badExpr))
None
case e =>
reporter.error("Unifier just crashed.\n exception = " + e.toString)
e.printStackTrace
None
} finally {
}
}
}
package testcases
object UnificationTest {
/*
sealed abstract class Value
case class X extends Value
case class Y extends Value
case class Z extends Value
*/
sealed abstract class Tree
case class Leaf() extends Tree
case class Node(left: Tree, value: Int, right: Tree) extends Tree
def mkTree(a: Int, b: Int, c: Int) = {
Node(Node(Leaf(), a, Leaf()), b, Node(Leaf(), c, Leaf()))
//Node(Leaf(), b, Node(Leaf(), c, Leaf()))
} ensuring ( res => {
res.left != Leaf() &&
res.value == b &&
res.right == Node(Leaf(), c, Leaf())
})
sealed abstract class Term
case class F(t1: Term, t2: Term, t3: Term, t4: Term) extends Term
case class G(s1: Term, s2: Term) extends Term
case class H(r1: Term, r2: Term) extends Term
case class A extends Term
case class B extends Term
def examplePage268(x1: Term, x2: Term, x3: Term, x4: Term, x5: Term) = {
F(G(H(A(), x5), x2), x1, H(A(), x4), x4)
} ensuring ( _ == F(x1, G(x2, x3), x2, B()) )
case class Tuple3(_1: Term, _2: Term, _3: Term)
def examplePage269(x1: Term, x2: Term, x3: Term, x4: Term) = {
Tuple3(H(x1, x1), H(x2, x2), H(x3, x3))
} ensuring ( res => {
x2 == res._1 &&
x3 == res._2 &&
x4 == res._3
})
// Not working yet
/*
def mkInfiniteTree(x: Int): Node = {
Node(mkInfiniteTree(x), x, mkInfiniteTree(x))
} ensuring (res =>
res.left != Leaf() && res.right != Leaf()
)
def insert(tree: Tree, value: Int) : Node = (tree match {
case Leaf() => Node(Leaf(), value, Leaf())
case n @ Node(l, v, r) => if(v < value) {
Node(l, v, insert(r, value))
} else if(v > value) {
Node(insert(l, value), v, r)
} else {
n
}
}) ensuring(_ != Leaf())
*/
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment