diff --git a/src/orderedsets/Unifier.scala b/src/orderedsets/Unifier.scala index 6245ffcdac7f9b5f40bb3a08834abc6d771d427b..b569b425abf157bbb27c81dab73171f5873047bd 100644 --- a/src/orderedsets/Unifier.scala +++ b/src/orderedsets/Unifier.scala @@ -71,6 +71,7 @@ object ExampleUnifier extends Unifier[String, String] { } import scala.collection.mutable.ArrayBuffer +import scala.collection.Map import purescala.Common._ import purescala.Trees._ import purescala.TypeTrees._ @@ -85,7 +86,7 @@ object ADTUnifier extends Unifier[Variable,CaseClassDef] { def freshVar(prefix: String)(typed: Typed) = Var(Variable(FreshIdentifier(prefix, true) setType typed.getType)) - def unify(conjunction: Seq[Expr]) { + def unify(conjunction: Seq[Expr]): Map[Variable,Expr] = { val equalities = new ArrayBuffer[(Term,Term)]() val inequalities = new ArrayBuffer[(Var,Var)]() @@ -128,8 +129,8 @@ object ADTUnifier extends Unifier[Variable,CaseClassDef] { */ val mgu = unify(equalities.toList) - val map = blowUp(mgu) - def subst(v: Variable) = map getOrElse (v, Var(v)) + val table = blowUp(mgu) + def subst(v: Variable) = table getOrElse (v, Var(v)) /* def byName(entry1: (Variable,Term), entry2: (Variable,Term)) = @@ -140,8 +141,14 @@ object ADTUnifier extends Unifier[Variable,CaseClassDef] { for ((x, t) <- mgu.toList sortWith byName) println(" " + x + " = " + pp(t)) println + + val substTable = table mapValues term2expr + println("--- Output of the unifier (Substitution table) ---") + for ((x, t) <- substTable.toList sortWith {_._1.id.name < _._1.id.name}) + println(" " + x + " = " + t) */ + // check inequalities for ((Var(x1), Var(x2)) <- inequalities) { val t1 = subst(x1) @@ -161,30 +168,29 @@ object ADTUnifier extends Unifier[Variable,CaseClassDef] { if (map1.isEmpty) println(" (empty table)") println */ - () + table mapValues term2expr } - - def term2expr(term: Term): Expr = term match { - case Var(v) => v - case Fun(cd, args) => CaseClass(cd, args map term2expr) - } + def term2expr(term: Term): Expr = term match { + case Var(v) => v + case Fun(cd, args) => CaseClass(cd, args map term2expr) + } } -import scala.collection.mutable.{ArrayBuffer => Seq, Map, Set, Stack} +import scala.collection.mutable.{ArrayBuffer => Seq, Map => MutableMap, Set, Stack} case class UnificationImpossible(msg: String) extends Exception(msg) trait Unifier[VarName >: Null, FunName >: Null] { type MGU = Seq[(VarName, Term)] - type Subst = Map[VarName, Term] + type Subst = MutableMap[VarName, Term] // transitive closure for the mapping - the smart way (in only one iteration) def blowUp(mgu: MGU): Subst = { - val map = Map.empty[VarName, Term] + val map = MutableMap[VarName, Term]() def subst(term: Term): Term = term match { case Var(v) => map get v match { case Some(t) => t @@ -222,7 +228,7 @@ trait Unifier[VarName >: Null, FunName >: Null] { unify(List((term1, term2))) def unify(terms: List[(Term, Term)]): MGU = { - val variableMap = Map[VarName, Variable]() + val variableMap = MutableMap[VarName, Variable]() def convertTerm(term: Term): Equation = term match { case Var(name) => variableMap get name match { case Some(v) => diff --git a/src/orderedsets/UnifierMain.scala b/src/orderedsets/UnifierMain.scala index fa425637bd97015c82bfd9a760653448e7e6982d..8f56d5c3d2f48f557ee996b4c649f2719d5d632c 100644 --- a/src/orderedsets/UnifierMain.scala +++ b/src/orderedsets/UnifierMain.scala @@ -1,5 +1,6 @@ package orderedsets +import scala.collection.Map import purescala.Reporter import purescala.Extensions.Solver import Reconstruction.Model @@ -77,7 +78,11 @@ class UnifierMain(reporter: Reporter) extends Solver(reporter) { rest foreach println */ - val subst = ADTUnifier.unify(treeEquations) + // The substitution table + val substTable = ADTUnifier.unify(treeEquations) + + // The substitution function (returns identity if unmapped) + def subst(v: Variable): Expr = substTable getOrElse (v, v) throw IncompleteException(null)