From 446d44d4f3560f9a7d3a3c83257c9346e3f2c061 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9gis=20Blanc?= <regwblanc@gmail.com>
Date: Mon, 2 Aug 2010 22:10:12 +0000
Subject: [PATCH]

---
 src/setconstraints/CnstrtGen.scala     | 118 ++++++++++++++-----------
 src/setconstraints/Main.scala          |   2 +-
 src/setconstraints/Manip.scala         |  56 ++++++++++--
 src/setconstraints/PrettyPrinter.scala |   7 +-
 src/setconstraints/Solver.scala        |  58 ++++++++++++
 src/setconstraints/SolverSuite.scala   |   9 ++
 src/setconstraints/Trees.scala         |   3 +
 7 files changed, 192 insertions(+), 61 deletions(-)
 create mode 100644 src/setconstraints/Solver.scala
 create mode 100644 src/setconstraints/SolverSuite.scala

diff --git a/src/setconstraints/CnstrtGen.scala b/src/setconstraints/CnstrtGen.scala
index d18f2a35f..975a70a5b 100644
--- a/src/setconstraints/CnstrtGen.scala
+++ b/src/setconstraints/CnstrtGen.scala
@@ -1,13 +1,10 @@
 package setconstraints
 
-import scala.collection.mutable.{Map, HashMap, ListBuffer}
-
 import purescala.Definitions._
 import purescala.Trees.{And => _, Equals => _, _}
 import purescala.Common.Identifier
 import purescala.TypeTrees.ClassType
 
-
 import Trees._
 
 object CnstrtGen {
@@ -18,67 +15,85 @@ object CnstrtGen {
                 cl2adt: Map[ClassTypeDef, SetType]
               ): Formula = {
 
-    val funCallsCnstr: ListBuffer[Relation] = new ListBuffer[Relation]()
-    val patternCnstr: ListBuffer[Relation] = new ListBuffer[Relation]()
+    def unzip3[A,B,C](seqs: Seq[(A,B,C)]): (Seq[A],Seq[B],Seq[C]) = 
+      seqs.foldLeft((Seq[A](), Seq[B](), Seq[C]()))((a, t) => (t._1 +: a._1, t._2 +: a._2, t._3 +: a._3))
 
-    def cnstrExpr(expr: Expr, context: Map[Identifier, VariableType]): (SetType, Seq[Relation]) = expr match {
-      case Variable(id) => {
-        (context(id), Seq())
-      }
-      case IfExpr(cond, then, elze) => {
-        val (tType, tCnstrs) = cnstrExpr(then, context)
-        val (eType, eCnstrs) = cnstrExpr(elze, context)
-        (UnionType(Seq(tType, eType)), tCnstrs ++ eCnstrs)
-      }
-      case MatchExpr(scrut, cases) => {
-        val (sType, sCnstrs) = cnstrExpr(scrut, context)
-        val (cType, cCnstrs) = cases.map(mc => {
-          //val theGuard = mc.theGuard
-          val rhs = mc.rhs
-          val (pt, pvc) = pattern2Type(mc.pattern)
-          cnstrExpr(rhs, context ++ pvc)
-        }).unzip
-        val mType = freshVar("match")
-        val mCnstrs = cType.map(t => Include(t, mType))
-        (mType, mCnstrs ++ cCnstrs.flatMap(x => x))
-      }
-      case FunctionInvocation(fd, args) => {
-        val (tArgs,rt) = funVars(fd)
-        tArgs.zip(args).foreach{case (v, expr) => {
-            val (newT, newCnstr) = cnstrExpr(expr, context)
-            funCallsCnstr ++= newCnstr
-            funCallsCnstr += Include(newT, v)
+    def cnstrExpr(expr: Expr, context: Map[Identifier, VariableType]): (VariableType, Seq[Relation], Map[Expr, VariableType]) = {
+      val exprVarType = freshVar("expr")
+      val (rels, e2t) = expr match {
+        case Variable(id) => {
+          (Seq(Equals(context(id), exprVarType)), Map[Expr, VariableType]())
+        }
+        case IfExpr(cond, then, elze) => {
+          val (tType, tCnstrs, tMap) = cnstrExpr(then, context)
+          val (eType, eCnstrs, eMap) = cnstrExpr(elze, context)
+          val newCnstrs = Equals(UnionType(Seq(tType, eType)), exprVarType) +: (tCnstrs ++ eCnstrs)
+          (newCnstrs, (tMap ++ eMap)) 
+        }
+        case MatchExpr(scrut, cases) => {
+          val (sType, sCnstrs, sMap) = cnstrExpr(scrut, context)
+          val (pts, ptexpcnstr) = cases.map(mc => {
+            val (pt, cnstrs, pvc) = pattern2Type(mc.pattern)
+            val (expT, expC, expM) = cnstrExpr(mc.rhs, context ++ pvc)
+            (pt, (expT, expC ++ cnstrs, expM))
+          }).unzip
+          val (cTypes, cCnstrs, cMaps) = unzip3(ptexpcnstr)
+          val mCnstrs = cTypes.map(t => Include(t, exprVarType))
+          val scrutPatternCnstr = Include(sType, UnionType(pts))
+          val fMap: Map[Expr, VariableType] = cMaps.foldLeft(sMap)((a, m) => a ++ m)
+          val finalCnstrs = scrutPatternCnstr +: (mCnstrs ++ cCnstrs.flatMap(x => x) ++ sCnstrs)
+          (finalCnstrs, fMap)
+        }
+        case FunctionInvocation(fd, args) => {
+          val (tArgs,rt) = funVars(fd)
+          /*
+          tArgs.zip(args).foreach{case (v, expr) => {
+              val (newT, newCnstr) = cnstrExpr(expr, context)
+              funCallsCnstr ++= newCnstr
+              funCallsCnstr += Include(newT, v)
+            }
           }
+          */
+          (Seq(Equals(rt, exprVarType)), Map[Expr, VariableType]())
         }
-        (rt, Seq())
-      }
-      case CaseClass(ccd, args) => {
-        val (argsType, cnstrts) = args.map(e => cnstrExpr(e, context)).unzip
-        (ConstructorType(ccd.id.name, argsType), cnstrts.flatMap(x => x))
+        case CaseClass(ccd, args) => {
+          val (argsType, cnstrts, maps) = unzip3(args.map(e => cnstrExpr(e, context)))
+          val fMap = maps.foldLeft(Map[Expr, VariableType]())((a, m) => a ++ m)
+          val fcnstrts = Equals(ConstructorType(ccd.id.name, argsType), exprVarType) +: cnstrts.flatMap(x => x)
+          (fcnstrts, fMap)
+        }
+        case _ => error("Not yet supported: " + expr)
       }
-      case _ => error("Not yet supported: " + expr)
+      (exprVarType, rels, (e2t: Map[Expr, VariableType]) + (expr -> exprVarType))
     }
 
-    def pattern2Type(pattern: Pattern): (SetType, Map[Identifier, VariableType]) = pattern match {
+    def pattern2Type(pattern: Pattern): (VariableType, Seq[Relation], Map[Identifier, VariableType]) = pattern match {
       case InstanceOfPattern(binder, ctd) => error("not yet supported")
       case WildcardPattern(binder) => {
         val v = freshVar(binder match {case Some(id) => id.name case None => "x"})
-        (v, binder match {case Some(id) => Map(id -> v) case None => Map()})
+        (v, Seq[Relation](), binder match {case Some(id) => Map(id -> v) case None => Map()})
       }
       case CaseClassPattern(binder, ccd, sps) => {
-        val (subConsType, subVarType) = sps.map(p => pattern2Type(p)).unzip
+        val cvt = freshVar(ccd.id.name)
+        val (subConsType, cnstrs, subVarType) = unzip3(sps.map(p => pattern2Type(p)))
         val newMap = subVarType.foldLeft(Map[Identifier, VariableType]())((acc, el) => acc ++ el)
-        subConsType.zip(ccd.fields)foreach{case (t, vd) => patternCnstr += Equals(t, cl2adt(vd.tpe.asInstanceOf[ClassType].classDef))} //TODO bug if there are nested pattern
-        (ConstructorType(ccd.id.name, subConsType), newMap)
+        val nCnstrs: Seq[Relation] = subConsType.zip(ccd.fields).zip(sps).foldLeft(cnstrs.flatMap(x => x))((a, el) => el match {
+          case ((t, vd), sp) => sp match {
+            case WildcardPattern(_) => a :+ Equals(t, cl2adt(vd.tpe.asInstanceOf[ClassType].classDef))
+            case _ => a
+          }
+        })
+        val ccnstr = Equals(ConstructorType(ccd.id.name, subConsType), cvt)
+        (cvt, ccnstr +: nCnstrs, newMap)
       }
     }
 
-    def cnstrFun(fd: FunDef): Seq[Relation] = {
+    def cnstrFun(fd: FunDef): (Seq[Relation], Map[Expr, VariableType]) = {
       val argsT = funVars(fd)._1
       val argsID = fd.args.map(vd => vd.id)
       val context = argsID.zip(argsT).foldLeft(Map[Identifier, VariableType]())((acc, el) => acc + el)
-      val (bodyType, cnstrts) = cnstrExpr(fd.body.get, context)
-      cnstrts :+ Include(bodyType, funVars(fd)._2)
+      val (bodyType, cnstrts, map) = cnstrExpr(fd.body.get, context)
+      (cnstrts :+ Include(bodyType, funVars(fd)._2), map)
     }
 
     def cnstrTypeHierarchy(pgm: Program): Seq[Relation] = {
@@ -88,13 +103,12 @@ object CnstrtGen {
 
     val cnstrtsTypes = cnstrTypeHierarchy(pgm)
 
-    println(typeVars)
-    println(cnstrtsTypes)
-
     val funs = pgm.definedFunctions
-    val cnstrtsFunctions = funs.flatMap(cnstrFun)
-
-    And(cnstrtsTypes ++ cnstrtsFunctions ++ funCallsCnstr ++ patternCnstr)
+    val (cnstrtsFunctions, map) = funs.foldLeft(Seq[Relation](), Map[Expr, VariableType]())((a, f) => {
+      val (rels, m) = cnstrFun(f)
+      (a._1 ++ rels, a._2 ++ m)
+    })
+    And(cnstrtsTypes ++ cnstrtsFunctions)
   }
 
 }
diff --git a/src/setconstraints/Main.scala b/src/setconstraints/Main.scala
index f501ef2d4..d843f8383 100644
--- a/src/setconstraints/Main.scala
+++ b/src/setconstraints/Main.scala
@@ -15,7 +15,7 @@ class Main(reporter: Reporter) extends Analyser(reporter) {
     val (tpeVars, funVars) = LabelProgram(pgm)
     val cl2adt = ADTExtractor(pgm)
 
-    val cnstr = CnstrtGen(pgm, tpeVars, funVars, cl2adt)
+    val cnstr = CnstrtGen(pgm, Map(tpeVars.toList: _*), Map(funVars.toList: _*), Map(cl2adt.toList: _*))
 
     reporter.info("The constraints are:")
     reporter.info(PrettyPrinter(cnstr))
diff --git a/src/setconstraints/Manip.scala b/src/setconstraints/Manip.scala
index 329a11ce2..64e64fcae 100644
--- a/src/setconstraints/Manip.scala
+++ b/src/setconstraints/Manip.scala
@@ -4,12 +4,56 @@ import setconstraints.Trees._
 
 object Manip {
 
-  def flatten(f: Formula): Formula = f match {
-    case And(fs) => And(fs.flatMap(f => flatten(f) match {
-        case And(fs2) => fs2
-        case f => List(f)
-      }))
-    case f => f
+  def map(s: SetType, f: (SetType) => SetType): SetType = s match {
+    case EmptyType | UniversalType | VariableType(_) => f(s)
+    case UnionType(sts) => f(UnionType(sts.map(s => map(s, f))))
+    case IntersectionType(sts) => f(IntersectionType(sts.map(s => map(s, f))))
+    case ComplementType(s) => f(ComplementType(map(s, f)))
+    case ConstructorType(n@_, sts) => f(ConstructorType(n, sts.map(s => map(s, f))))
+    case FunctionType(s1, s2) => {
+      val ns1 = map(s1, f)
+      val ns2 = map(s2, f)
+      f(FunctionType(ns1, ns2))
+    }
+    case TupleType(sts) => f(TupleType(sts.map(s => f(s))))
+  }
+  def map(f: Formula, ff: (Formula) => Formula, ft: (SetType) => SetType): Formula = f match {
+    case And(fs) => ff(And(fs.map(f => map(f, ff, ft))))
+    case Include(s1, s2) => {
+      val ns1 = map(s1, ft)
+      val ns2 = map(s2, ft)
+      ff(Include(ns1, ns2))
+    }
+    case Equals(s1, s2) => {
+      val ns1 = map(s1, ft)
+      val ns2 = map(s2, ft)
+      ff(Equals(ns1, ns2))
+    }
+  }
+
+  def flatten(formula: Formula): Formula = {
+    def flatten0(f: Formula) = f match {
+      case And(fs) => And(fs.flatMap{
+          case And(fs2) => fs2
+          case f => List(f)
+        })
+      case f => f
+    }
+    map(formula, flatten0, s => s)
+  }
+  def flatten(setType: SetType): SetType = {
+    def flatten0(s: SetType) = s match {
+      case UnionType(sts) => UnionType(sts.flatMap{
+          case UnionType(sts2) => sts2
+          case s => List(s)
+        })
+      case IntersectionType(sts) => IntersectionType(sts.flatMap{
+          case IntersectionType(sts2) => sts2
+          case s => List(s)
+        })
+      case s => s
+    }
+    map(setType, flatten0)
   }
 
   def includes(f: Formula): Seq[Include] = flatten(f) match {
diff --git a/src/setconstraints/PrettyPrinter.scala b/src/setconstraints/PrettyPrinter.scala
index aee0378b9..d1b5bc4f5 100644
--- a/src/setconstraints/PrettyPrinter.scala
+++ b/src/setconstraints/PrettyPrinter.scala
@@ -11,8 +11,8 @@ object PrettyPrinter {
   def apply(fp: FixPoint): String = ppFixPoint(fp)
 
   private def ppFormula(f: Formula): String = f match {
-    case And(fs) => fs.map(ppFormula).mkString("(", " \u2227 ", ")")
-    case Include(s1, s2) => ppSetType(s1) + " \u2282 " + ppSetType(s2)
+    case And(fs) => fs.map(ppFormula).mkString("(  ", "\n \u2227 ", ")")
+    case Include(s1, s2) => ppSetType(s1) + " \u2286 " + ppSetType(s2)
     case Equals(s1, s2) => ppSetType(s1) + " = " + ppSetType(s2)
   }
 
@@ -21,9 +21,12 @@ object PrettyPrinter {
     case ConstructorType(name, sts) => name + sts.map(ppSetType).mkString("(", ", ", ")")
     case UnionType(sts) => sts.map(ppSetType).mkString("(", " \u222A ", ")")
     case IntersectionType(sts) => sts.map(ppSetType).mkString("(", " \u2229 ", ")")
+    case ComplementType(s) => "\u00AC" + ppSetType(s)
     case FunctionType(s1, s2) => "(" + ppSetType(s1) + " --> " + ppSetType(s2) + ")"
     case TupleType(sts) => sts.map(ppSetType).mkString("(", ", ", ")")
     case VariableType(name) => name
+    case EmptyType => "0"
+    case UniversalType => "1"
   }
 
   private def ppFixPoint(fp: FixPoint): String = fp match {
diff --git a/src/setconstraints/Solver.scala b/src/setconstraints/Solver.scala
new file mode 100644
index 000000000..b7e5947ea
--- /dev/null
+++ b/src/setconstraints/Solver.scala
@@ -0,0 +1,58 @@
+package setconstraints
+
+import Trees._
+import Manip._
+
+object Solver {
+
+  def apply(system: List[Relation]): Option[List[FixPoint]] = {
+    error("TODO")
+  }
+
+  def solve(system: List[Relation]): Option[List[Equals]] = {
+    error("TODO")
+  }
+
+  def oneLevel(system: List[Include]): List[Include] = {
+
+    val emptyRightSystem = system.map{
+      case Include(s1, s2) if s2 != EmptyType => Include(IntersectionType(List(s1, ComplementType(s2))), EmptyType)
+      case incl => incl
+    }
+
+
+    error("TODO")
+  }
+
+  def isConstructor(s: SetType): Boolean = s match {
+    case ConstructorType(_, _) => true
+    case _ => false
+  }
+
+  def isLiteral(s: SetType): Boolean = s match {
+    case VariableType(_) => true
+    case ComplementType(VariableType(_)) => true
+    case _ => false
+  }
+  def isConjunctionLit(s: SetType): Boolean = flatten(s) match {
+    case IntersectionType(sts) if sts.foldLeft(true)((b, st) => b && isLiteral(st) && sts.forall(l => l != ComplementType(st))) => false
+    case _ => false
+  }
+  def isConjunctionLitWithUniversal(s: SetType): Boolean = flatten(s) match {
+    case IntersectionType(sts) if sts.last == UniversalType && isConjunctionLit(IntersectionType(sts.init)) => true
+    case _ => false
+  }
+  def isOneLevel(s: SetType): Boolean = flatten(s) match {
+    case EmptyType => true
+    case IntersectionType(sts) if isConstructor(sts.last) && isConjunctionLit(IntersectionType(sts.init)) => {
+      val ConstructorType(_, args) = sts.last 
+      args.forall(isConjunctionLitWithUniversal)
+    }
+    case s => isConjunctionLitWithUniversal(s)
+  }
+  def isOneLevel(r: Relation): Boolean = r match {
+    case Include(s1, EmptyType) if isOneLevel(s1) => true
+    case _ => false
+  }
+  def isOneLevel(system: List[Relation]): Boolean = system.forall(isOneLevel)
+}
diff --git a/src/setconstraints/SolverSuite.scala b/src/setconstraints/SolverSuite.scala
new file mode 100644
index 000000000..b5325561f
--- /dev/null
+++ b/src/setconstraints/SolverSuite.scala
@@ -0,0 +1,9 @@
+package setconstraints
+
+/*
+import org.scalatest.FunSuite
+
+class SolverSuite extends FunSuite {
+
+}
+*/
diff --git a/src/setconstraints/Trees.scala b/src/setconstraints/Trees.scala
index 075c2f207..21cd03ac9 100644
--- a/src/setconstraints/Trees.scala
+++ b/src/setconstraints/Trees.scala
@@ -15,10 +15,13 @@ object Trees {
 
   case class UnionType(sets: Seq[SetType]) extends SetType
   case class IntersectionType(sets: Seq[SetType]) extends SetType
+  case class ComplementType(st: SetType) extends SetType
   case class FunctionType(s1: SetType, s2: SetType) extends SetType
   case class TupleType(sets: Seq[SetType]) extends SetType
   case class ConstructorType(name: String, sets: Seq[SetType]) extends SetType
   case class VariableType(name: String) extends SetType
+  case object EmptyType extends SetType
+  case object UniversalType extends SetType
 
   private var varCounter = -1
   def freshVar(name: String) = {
-- 
GitLab