From 3c5619a5eb50570283b0e6e915843f10326aa15b Mon Sep 17 00:00:00 2001
From: Robin Steiger <robin.steiger@epfl.ch>
Date: Sat, 10 Jul 2010 11:54:33 +0000
Subject: [PATCH] The unifier now translates terms (internal representation)
 back to expressions (pure Scala) and returns the substitution table.

---
 src/orderedsets/Unifier.scala     | 32 ++++++++++++++++++-------------
 src/orderedsets/UnifierMain.scala |  7 ++++++-
 2 files changed, 25 insertions(+), 14 deletions(-)

diff --git a/src/orderedsets/Unifier.scala b/src/orderedsets/Unifier.scala
index 6245ffcda..b569b425a 100644
--- a/src/orderedsets/Unifier.scala
+++ b/src/orderedsets/Unifier.scala
@@ -71,6 +71,7 @@ object ExampleUnifier extends Unifier[String, String] {
 }
 
 import scala.collection.mutable.ArrayBuffer
+import scala.collection.Map
 import purescala.Common._
 import purescala.Trees._
 import purescala.TypeTrees._
@@ -85,7 +86,7 @@ object ADTUnifier extends Unifier[Variable,CaseClassDef] {
   
   def freshVar(prefix: String)(typed: Typed) = Var(Variable(FreshIdentifier(prefix, true) setType typed.getType))
   
-  def unify(conjunction: Seq[Expr]) {
+  def unify(conjunction: Seq[Expr]): Map[Variable,Expr] = {
     val equalities = new ArrayBuffer[(Term,Term)]()
     val inequalities = new ArrayBuffer[(Var,Var)]()
   
@@ -128,8 +129,8 @@ object ADTUnifier extends Unifier[Variable,CaseClassDef] {
     */
     
     val mgu = unify(equalities.toList)
-    val map = blowUp(mgu)
-    def subst(v: Variable) = map getOrElse (v, Var(v))
+    val table = blowUp(mgu)
+    def subst(v: Variable) = table getOrElse (v, Var(v))
     
     /*
     def byName(entry1: (Variable,Term), entry2: (Variable,Term)) =
@@ -140,8 +141,14 @@ object ADTUnifier extends Unifier[Variable,CaseClassDef] {
     for ((x, t) <- mgu.toList sortWith byName)
       println("  " + x + "  =  " + pp(t))
     println
+  
+    val substTable = table mapValues term2expr
+    println("--- Output of the unifier (Substitution table) ---")
+    for ((x, t) <- substTable.toList sortWith {_._1.id.name < _._1.id.name})
+      println("  " + x + "  =  " + t)
     */
   
+  
     // check inequalities
     for ((Var(x1), Var(x2)) <- inequalities) {
       val t1 = subst(x1)
@@ -161,30 +168,29 @@ object ADTUnifier extends Unifier[Variable,CaseClassDef] {
     if (map1.isEmpty) println("  (empty table)")
     println 
     */
-    ()
+    table mapValues term2expr
   }
   
-  
-    def term2expr(term: Term): Expr = term match {
-      case Var(v) => v
-      case Fun(cd, args) => CaseClass(cd, args map term2expr)
-    }
+  def term2expr(term: Term): Expr = term match {
+    case Var(v) => v
+    case Fun(cd, args) => CaseClass(cd, args map term2expr)
+  }
 }
 
 
 
-import scala.collection.mutable.{ArrayBuffer => Seq, Map, Set, Stack}
+import scala.collection.mutable.{ArrayBuffer => Seq, Map => MutableMap, Set, Stack}
 
 case class UnificationImpossible(msg: String) extends Exception(msg)
   
 trait Unifier[VarName >: Null, FunName >: Null] {
   
   type MGU = Seq[(VarName, Term)]
-  type Subst = Map[VarName, Term]
+  type Subst = MutableMap[VarName, Term]
 
   // transitive closure for the mapping - the smart way (in only one iteration)
   def blowUp(mgu: MGU): Subst = {
-    val map = Map.empty[VarName, Term]
+    val map = MutableMap[VarName, Term]()
     def subst(term: Term): Term = term match {
       case Var(v) => map get v match {
         case Some(t) => t
@@ -222,7 +228,7 @@ trait Unifier[VarName >: Null, FunName >: Null] {
     unify(List((term1, term2)))
 
   def unify(terms: List[(Term, Term)]): MGU = {
-    val variableMap = Map[VarName, Variable]()
+    val variableMap = MutableMap[VarName, Variable]()
     def convertTerm(term: Term): Equation = term match {
       case Var(name) => variableMap get name match {
         case Some(v) =>
diff --git a/src/orderedsets/UnifierMain.scala b/src/orderedsets/UnifierMain.scala
index fa425637b..8f56d5c3d 100644
--- a/src/orderedsets/UnifierMain.scala
+++ b/src/orderedsets/UnifierMain.scala
@@ -1,5 +1,6 @@
 package orderedsets
 
+import scala.collection.Map
 import purescala.Reporter
 import purescala.Extensions.Solver
 import Reconstruction.Model
@@ -77,7 +78,11 @@ class UnifierMain(reporter: Reporter) extends Solver(reporter) {
     rest foreach println
     */
     
-    val subst = ADTUnifier.unify(treeEquations)
+    // The substitution table
+    val substTable = ADTUnifier.unify(treeEquations)
+    
+    // The substitution function (returns identity if unmapped)
+    def subst(v: Variable): Expr = substTable getOrElse (v, v)
     
     throw IncompleteException(null)
     
-- 
GitLab