Skip to content
Snippets Groups Projects
Commit 446d44d4 authored by Régis Blanc's avatar Régis Blanc
Browse files

No commit message

No commit message
parent 7548f04c
Branches
Tags
No related merge requests found
package setconstraints package setconstraints
import scala.collection.mutable.{Map, HashMap, ListBuffer}
import purescala.Definitions._ import purescala.Definitions._
import purescala.Trees.{And => _, Equals => _, _} import purescala.Trees.{And => _, Equals => _, _}
import purescala.Common.Identifier import purescala.Common.Identifier
import purescala.TypeTrees.ClassType import purescala.TypeTrees.ClassType
import Trees._ import Trees._
object CnstrtGen { object CnstrtGen {
...@@ -18,67 +15,85 @@ object CnstrtGen { ...@@ -18,67 +15,85 @@ object CnstrtGen {
cl2adt: Map[ClassTypeDef, SetType] cl2adt: Map[ClassTypeDef, SetType]
): Formula = { ): Formula = {
val funCallsCnstr: ListBuffer[Relation] = new ListBuffer[Relation]() def unzip3[A,B,C](seqs: Seq[(A,B,C)]): (Seq[A],Seq[B],Seq[C]) =
val patternCnstr: ListBuffer[Relation] = new ListBuffer[Relation]() seqs.foldLeft((Seq[A](), Seq[B](), Seq[C]()))((a, t) => (t._1 +: a._1, t._2 +: a._2, t._3 +: a._3))
def cnstrExpr(expr: Expr, context: Map[Identifier, VariableType]): (SetType, Seq[Relation]) = expr match { def cnstrExpr(expr: Expr, context: Map[Identifier, VariableType]): (VariableType, Seq[Relation], Map[Expr, VariableType]) = {
case Variable(id) => { val exprVarType = freshVar("expr")
(context(id), Seq()) val (rels, e2t) = expr match {
} case Variable(id) => {
case IfExpr(cond, then, elze) => { (Seq(Equals(context(id), exprVarType)), Map[Expr, VariableType]())
val (tType, tCnstrs) = cnstrExpr(then, context) }
val (eType, eCnstrs) = cnstrExpr(elze, context) case IfExpr(cond, then, elze) => {
(UnionType(Seq(tType, eType)), tCnstrs ++ eCnstrs) val (tType, tCnstrs, tMap) = cnstrExpr(then, context)
} val (eType, eCnstrs, eMap) = cnstrExpr(elze, context)
case MatchExpr(scrut, cases) => { val newCnstrs = Equals(UnionType(Seq(tType, eType)), exprVarType) +: (tCnstrs ++ eCnstrs)
val (sType, sCnstrs) = cnstrExpr(scrut, context) (newCnstrs, (tMap ++ eMap))
val (cType, cCnstrs) = cases.map(mc => { }
//val theGuard = mc.theGuard case MatchExpr(scrut, cases) => {
val rhs = mc.rhs val (sType, sCnstrs, sMap) = cnstrExpr(scrut, context)
val (pt, pvc) = pattern2Type(mc.pattern) val (pts, ptexpcnstr) = cases.map(mc => {
cnstrExpr(rhs, context ++ pvc) val (pt, cnstrs, pvc) = pattern2Type(mc.pattern)
}).unzip val (expT, expC, expM) = cnstrExpr(mc.rhs, context ++ pvc)
val mType = freshVar("match") (pt, (expT, expC ++ cnstrs, expM))
val mCnstrs = cType.map(t => Include(t, mType)) }).unzip
(mType, mCnstrs ++ cCnstrs.flatMap(x => x)) val (cTypes, cCnstrs, cMaps) = unzip3(ptexpcnstr)
} val mCnstrs = cTypes.map(t => Include(t, exprVarType))
case FunctionInvocation(fd, args) => { val scrutPatternCnstr = Include(sType, UnionType(pts))
val (tArgs,rt) = funVars(fd) val fMap: Map[Expr, VariableType] = cMaps.foldLeft(sMap)((a, m) => a ++ m)
tArgs.zip(args).foreach{case (v, expr) => { val finalCnstrs = scrutPatternCnstr +: (mCnstrs ++ cCnstrs.flatMap(x => x) ++ sCnstrs)
val (newT, newCnstr) = cnstrExpr(expr, context) (finalCnstrs, fMap)
funCallsCnstr ++= newCnstr }
funCallsCnstr += Include(newT, v) case FunctionInvocation(fd, args) => {
val (tArgs,rt) = funVars(fd)
/*
tArgs.zip(args).foreach{case (v, expr) => {
val (newT, newCnstr) = cnstrExpr(expr, context)
funCallsCnstr ++= newCnstr
funCallsCnstr += Include(newT, v)
}
} }
*/
(Seq(Equals(rt, exprVarType)), Map[Expr, VariableType]())
} }
(rt, Seq()) case CaseClass(ccd, args) => {
} val (argsType, cnstrts, maps) = unzip3(args.map(e => cnstrExpr(e, context)))
case CaseClass(ccd, args) => { val fMap = maps.foldLeft(Map[Expr, VariableType]())((a, m) => a ++ m)
val (argsType, cnstrts) = args.map(e => cnstrExpr(e, context)).unzip val fcnstrts = Equals(ConstructorType(ccd.id.name, argsType), exprVarType) +: cnstrts.flatMap(x => x)
(ConstructorType(ccd.id.name, argsType), cnstrts.flatMap(x => x)) (fcnstrts, fMap)
}
case _ => error("Not yet supported: " + expr)
} }
case _ => error("Not yet supported: " + expr) (exprVarType, rels, (e2t: Map[Expr, VariableType]) + (expr -> exprVarType))
} }
def pattern2Type(pattern: Pattern): (SetType, Map[Identifier, VariableType]) = pattern match { def pattern2Type(pattern: Pattern): (VariableType, Seq[Relation], Map[Identifier, VariableType]) = pattern match {
case InstanceOfPattern(binder, ctd) => error("not yet supported") case InstanceOfPattern(binder, ctd) => error("not yet supported")
case WildcardPattern(binder) => { case WildcardPattern(binder) => {
val v = freshVar(binder match {case Some(id) => id.name case None => "x"}) val v = freshVar(binder match {case Some(id) => id.name case None => "x"})
(v, binder match {case Some(id) => Map(id -> v) case None => Map()}) (v, Seq[Relation](), binder match {case Some(id) => Map(id -> v) case None => Map()})
} }
case CaseClassPattern(binder, ccd, sps) => { case CaseClassPattern(binder, ccd, sps) => {
val (subConsType, subVarType) = sps.map(p => pattern2Type(p)).unzip val cvt = freshVar(ccd.id.name)
val (subConsType, cnstrs, subVarType) = unzip3(sps.map(p => pattern2Type(p)))
val newMap = subVarType.foldLeft(Map[Identifier, VariableType]())((acc, el) => acc ++ el) val newMap = subVarType.foldLeft(Map[Identifier, VariableType]())((acc, el) => acc ++ el)
subConsType.zip(ccd.fields)foreach{case (t, vd) => patternCnstr += Equals(t, cl2adt(vd.tpe.asInstanceOf[ClassType].classDef))} //TODO bug if there are nested pattern val nCnstrs: Seq[Relation] = subConsType.zip(ccd.fields).zip(sps).foldLeft(cnstrs.flatMap(x => x))((a, el) => el match {
(ConstructorType(ccd.id.name, subConsType), newMap) case ((t, vd), sp) => sp match {
case WildcardPattern(_) => a :+ Equals(t, cl2adt(vd.tpe.asInstanceOf[ClassType].classDef))
case _ => a
}
})
val ccnstr = Equals(ConstructorType(ccd.id.name, subConsType), cvt)
(cvt, ccnstr +: nCnstrs, newMap)
} }
} }
def cnstrFun(fd: FunDef): Seq[Relation] = { def cnstrFun(fd: FunDef): (Seq[Relation], Map[Expr, VariableType]) = {
val argsT = funVars(fd)._1 val argsT = funVars(fd)._1
val argsID = fd.args.map(vd => vd.id) val argsID = fd.args.map(vd => vd.id)
val context = argsID.zip(argsT).foldLeft(Map[Identifier, VariableType]())((acc, el) => acc + el) val context = argsID.zip(argsT).foldLeft(Map[Identifier, VariableType]())((acc, el) => acc + el)
val (bodyType, cnstrts) = cnstrExpr(fd.body.get, context) val (bodyType, cnstrts, map) = cnstrExpr(fd.body.get, context)
cnstrts :+ Include(bodyType, funVars(fd)._2) (cnstrts :+ Include(bodyType, funVars(fd)._2), map)
} }
def cnstrTypeHierarchy(pgm: Program): Seq[Relation] = { def cnstrTypeHierarchy(pgm: Program): Seq[Relation] = {
...@@ -88,13 +103,12 @@ object CnstrtGen { ...@@ -88,13 +103,12 @@ object CnstrtGen {
val cnstrtsTypes = cnstrTypeHierarchy(pgm) val cnstrtsTypes = cnstrTypeHierarchy(pgm)
println(typeVars)
println(cnstrtsTypes)
val funs = pgm.definedFunctions val funs = pgm.definedFunctions
val cnstrtsFunctions = funs.flatMap(cnstrFun) val (cnstrtsFunctions, map) = funs.foldLeft(Seq[Relation](), Map[Expr, VariableType]())((a, f) => {
val (rels, m) = cnstrFun(f)
And(cnstrtsTypes ++ cnstrtsFunctions ++ funCallsCnstr ++ patternCnstr) (a._1 ++ rels, a._2 ++ m)
})
And(cnstrtsTypes ++ cnstrtsFunctions)
} }
} }
...@@ -15,7 +15,7 @@ class Main(reporter: Reporter) extends Analyser(reporter) { ...@@ -15,7 +15,7 @@ class Main(reporter: Reporter) extends Analyser(reporter) {
val (tpeVars, funVars) = LabelProgram(pgm) val (tpeVars, funVars) = LabelProgram(pgm)
val cl2adt = ADTExtractor(pgm) val cl2adt = ADTExtractor(pgm)
val cnstr = CnstrtGen(pgm, tpeVars, funVars, cl2adt) val cnstr = CnstrtGen(pgm, Map(tpeVars.toList: _*), Map(funVars.toList: _*), Map(cl2adt.toList: _*))
reporter.info("The constraints are:") reporter.info("The constraints are:")
reporter.info(PrettyPrinter(cnstr)) reporter.info(PrettyPrinter(cnstr))
......
...@@ -4,12 +4,56 @@ import setconstraints.Trees._ ...@@ -4,12 +4,56 @@ import setconstraints.Trees._
object Manip { object Manip {
def flatten(f: Formula): Formula = f match { def map(s: SetType, f: (SetType) => SetType): SetType = s match {
case And(fs) => And(fs.flatMap(f => flatten(f) match { case EmptyType | UniversalType | VariableType(_) => f(s)
case And(fs2) => fs2 case UnionType(sts) => f(UnionType(sts.map(s => map(s, f))))
case f => List(f) case IntersectionType(sts) => f(IntersectionType(sts.map(s => map(s, f))))
})) case ComplementType(s) => f(ComplementType(map(s, f)))
case f => f case ConstructorType(n@_, sts) => f(ConstructorType(n, sts.map(s => map(s, f))))
case FunctionType(s1, s2) => {
val ns1 = map(s1, f)
val ns2 = map(s2, f)
f(FunctionType(ns1, ns2))
}
case TupleType(sts) => f(TupleType(sts.map(s => f(s))))
}
def map(f: Formula, ff: (Formula) => Formula, ft: (SetType) => SetType): Formula = f match {
case And(fs) => ff(And(fs.map(f => map(f, ff, ft))))
case Include(s1, s2) => {
val ns1 = map(s1, ft)
val ns2 = map(s2, ft)
ff(Include(ns1, ns2))
}
case Equals(s1, s2) => {
val ns1 = map(s1, ft)
val ns2 = map(s2, ft)
ff(Equals(ns1, ns2))
}
}
def flatten(formula: Formula): Formula = {
def flatten0(f: Formula) = f match {
case And(fs) => And(fs.flatMap{
case And(fs2) => fs2
case f => List(f)
})
case f => f
}
map(formula, flatten0, s => s)
}
def flatten(setType: SetType): SetType = {
def flatten0(s: SetType) = s match {
case UnionType(sts) => UnionType(sts.flatMap{
case UnionType(sts2) => sts2
case s => List(s)
})
case IntersectionType(sts) => IntersectionType(sts.flatMap{
case IntersectionType(sts2) => sts2
case s => List(s)
})
case s => s
}
map(setType, flatten0)
} }
def includes(f: Formula): Seq[Include] = flatten(f) match { def includes(f: Formula): Seq[Include] = flatten(f) match {
......
...@@ -11,8 +11,8 @@ object PrettyPrinter { ...@@ -11,8 +11,8 @@ object PrettyPrinter {
def apply(fp: FixPoint): String = ppFixPoint(fp) def apply(fp: FixPoint): String = ppFixPoint(fp)
private def ppFormula(f: Formula): String = f match { private def ppFormula(f: Formula): String = f match {
case And(fs) => fs.map(ppFormula).mkString("(", " \u2227 ", ")") case And(fs) => fs.map(ppFormula).mkString("( ", "\n \u2227 ", ")")
case Include(s1, s2) => ppSetType(s1) + " \u2282 " + ppSetType(s2) case Include(s1, s2) => ppSetType(s1) + " \u2286 " + ppSetType(s2)
case Equals(s1, s2) => ppSetType(s1) + " = " + ppSetType(s2) case Equals(s1, s2) => ppSetType(s1) + " = " + ppSetType(s2)
} }
...@@ -21,9 +21,12 @@ object PrettyPrinter { ...@@ -21,9 +21,12 @@ object PrettyPrinter {
case ConstructorType(name, sts) => name + sts.map(ppSetType).mkString("(", ", ", ")") case ConstructorType(name, sts) => name + sts.map(ppSetType).mkString("(", ", ", ")")
case UnionType(sts) => sts.map(ppSetType).mkString("(", " \u222A ", ")") case UnionType(sts) => sts.map(ppSetType).mkString("(", " \u222A ", ")")
case IntersectionType(sts) => sts.map(ppSetType).mkString("(", " \u2229 ", ")") case IntersectionType(sts) => sts.map(ppSetType).mkString("(", " \u2229 ", ")")
case ComplementType(s) => "\u00AC" + ppSetType(s)
case FunctionType(s1, s2) => "(" + ppSetType(s1) + " --> " + ppSetType(s2) + ")" case FunctionType(s1, s2) => "(" + ppSetType(s1) + " --> " + ppSetType(s2) + ")"
case TupleType(sts) => sts.map(ppSetType).mkString("(", ", ", ")") case TupleType(sts) => sts.map(ppSetType).mkString("(", ", ", ")")
case VariableType(name) => name case VariableType(name) => name
case EmptyType => "0"
case UniversalType => "1"
} }
private def ppFixPoint(fp: FixPoint): String = fp match { private def ppFixPoint(fp: FixPoint): String = fp match {
......
package setconstraints
import Trees._
import Manip._
object Solver {
def apply(system: List[Relation]): Option[List[FixPoint]] = {
error("TODO")
}
def solve(system: List[Relation]): Option[List[Equals]] = {
error("TODO")
}
def oneLevel(system: List[Include]): List[Include] = {
val emptyRightSystem = system.map{
case Include(s1, s2) if s2 != EmptyType => Include(IntersectionType(List(s1, ComplementType(s2))), EmptyType)
case incl => incl
}
error("TODO")
}
def isConstructor(s: SetType): Boolean = s match {
case ConstructorType(_, _) => true
case _ => false
}
def isLiteral(s: SetType): Boolean = s match {
case VariableType(_) => true
case ComplementType(VariableType(_)) => true
case _ => false
}
def isConjunctionLit(s: SetType): Boolean = flatten(s) match {
case IntersectionType(sts) if sts.foldLeft(true)((b, st) => b && isLiteral(st) && sts.forall(l => l != ComplementType(st))) => false
case _ => false
}
def isConjunctionLitWithUniversal(s: SetType): Boolean = flatten(s) match {
case IntersectionType(sts) if sts.last == UniversalType && isConjunctionLit(IntersectionType(sts.init)) => true
case _ => false
}
def isOneLevel(s: SetType): Boolean = flatten(s) match {
case EmptyType => true
case IntersectionType(sts) if isConstructor(sts.last) && isConjunctionLit(IntersectionType(sts.init)) => {
val ConstructorType(_, args) = sts.last
args.forall(isConjunctionLitWithUniversal)
}
case s => isConjunctionLitWithUniversal(s)
}
def isOneLevel(r: Relation): Boolean = r match {
case Include(s1, EmptyType) if isOneLevel(s1) => true
case _ => false
}
def isOneLevel(system: List[Relation]): Boolean = system.forall(isOneLevel)
}
package setconstraints
/*
import org.scalatest.FunSuite
class SolverSuite extends FunSuite {
}
*/
...@@ -15,10 +15,13 @@ object Trees { ...@@ -15,10 +15,13 @@ object Trees {
case class UnionType(sets: Seq[SetType]) extends SetType case class UnionType(sets: Seq[SetType]) extends SetType
case class IntersectionType(sets: Seq[SetType]) extends SetType case class IntersectionType(sets: Seq[SetType]) extends SetType
case class ComplementType(st: SetType) extends SetType
case class FunctionType(s1: SetType, s2: SetType) extends SetType case class FunctionType(s1: SetType, s2: SetType) extends SetType
case class TupleType(sets: Seq[SetType]) extends SetType case class TupleType(sets: Seq[SetType]) extends SetType
case class ConstructorType(name: String, sets: Seq[SetType]) extends SetType case class ConstructorType(name: String, sets: Seq[SetType]) extends SetType
case class VariableType(name: String) extends SetType case class VariableType(name: String) extends SetType
case object EmptyType extends SetType
case object UniversalType extends SetType
private var varCounter = -1 private var varCounter = -1
def freshVar(name: String) = { def freshVar(name: String) = {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment