Skip to content
Snippets Groups Projects
Commit 7548f04c authored by Régis Blanc's avatar Régis Blanc
Browse files

No commit message

No commit message
parent 4abf29f0
Branches
Tags
No related merge requests found
......@@ -3,8 +3,9 @@ package setconstraints
import scala.collection.mutable.{Map, HashMap, ListBuffer}
import purescala.Definitions._
import purescala.Trees.{And => _, _}
import purescala.Trees.{And => _, Equals => _, _}
import purescala.Common.Identifier
import purescala.TypeTrees.ClassType
import Trees._
......@@ -17,20 +18,10 @@ object CnstrtGen {
cl2adt: Map[ClassTypeDef, SetType]
): Formula = {
val funCallsCnstr: ListBuffer[Include] = new ListBuffer[Include]()
val patternCnstr: ListBuffer[Include] = new ListBuffer[Include]()
val funCallsCnstr: ListBuffer[Relation] = new ListBuffer[Relation]()
val patternCnstr: ListBuffer[Relation] = new ListBuffer[Relation]()
def addFunCallCnst(fi: FunctionInvocation) {
val (args,_) = funVars(fi.funDef)
args.zip(fi.args).foreach{case (v, expr) => {
val (newT, newCnstr) = cnstrExpr(expr, Map())
funCallsCnstr ++= newCnstr
funCallsCnstr += Include(v, newT)
}
}
}
def cnstrExpr(expr: Expr, context: Map[Identifier, VariableType]): (SetType, Seq[Include]) = expr match {
def cnstrExpr(expr: Expr, context: Map[Identifier, VariableType]): (SetType, Seq[Relation]) = expr match {
case Variable(id) => {
(context(id), Seq())
}
......@@ -52,7 +43,13 @@ object CnstrtGen {
(mType, mCnstrs ++ cCnstrs.flatMap(x => x))
}
case FunctionInvocation(fd, args) => {
val rt = funVars(fd)._2
val (tArgs,rt) = funVars(fd)
tArgs.zip(args).foreach{case (v, expr) => {
val (newT, newCnstr) = cnstrExpr(expr, context)
funCallsCnstr ++= newCnstr
funCallsCnstr += Include(newT, v)
}
}
(rt, Seq())
}
case CaseClass(ccd, args) => {
......@@ -71,29 +68,33 @@ object CnstrtGen {
case CaseClassPattern(binder, ccd, sps) => {
val (subConsType, subVarType) = sps.map(p => pattern2Type(p)).unzip
val newMap = subVarType.foldLeft(Map[Identifier, VariableType]())((acc, el) => acc ++ el)
subConsType.zip(ccd.fields)foreach{case (t, vd) => patternCnstr += Equals(t, cl2adt(vd.tpe.asInstanceOf[ClassType].classDef))} //TODO bug if there are nested pattern
(ConstructorType(ccd.id.name, subConsType), newMap)
}
}
def cnstrFun(fd: FunDef): Seq[Include] = {
def cnstrFun(fd: FunDef): Seq[Relation] = {
val argsT = funVars(fd)._1
val argsID = fd.args.map(vd => vd.id)
val context = argsID.zip(argsT).foldLeft(Map[Identifier, VariableType]())((acc, el) => acc + el)
val (bodyType, cnstrts) = cnstrExpr(fd.body.get, context)
cnstrts :+ Include(funVars(fd)._2, bodyType)
cnstrts :+ Include(bodyType, funVars(fd)._2)
}
def cnstrTypeHierarchy(pgm: Program): Seq[Include] = {
def cnstrTypeHierarchy(pgm: Program): Seq[Relation] = {
val caseClasses = pgm.definedClasses.filter(_.isInstanceOf[CaseClassDef])
caseClasses.map(cc => Include(cl2adt(cc), cl2adt(cc.parent.get)))
}
val cnstrtsTypes = cnstrTypeHierarchy(pgm)
println(typeVars)
println(cnstrtsTypes)
val funs = pgm.definedFunctions
val cnstrtsFunctions = funs.flatMap(cnstrFun)
And(cnstrtsTypes ++ cnstrtsFunctions)
And(cnstrtsTypes ++ cnstrtsFunctions ++ funCallsCnstr ++ patternCnstr)
}
}
package setconstraints
import Trees._
import purescala.Definitions.Program
import purescala.Definitions.AbstractClassDef
import purescala.Reporter
import purescala.Extensions.Analyser
......@@ -15,13 +17,23 @@ class Main(reporter: Reporter) extends Analyser(reporter) {
val cnstr = CnstrtGen(pgm, tpeVars, funVars, cl2adt)
reporter.info(tpeVars.toString)
reporter.info(funVars.toString)
reporter.info(cl2adt.toString)
reporter.info("THE CONSTRAINTS")
reporter.info("The constraints are:")
reporter.info(PrettyPrinter(cnstr))
val classes = pgm.definedClasses
val form = classes.find(_.isInstanceOf[AbstractClassDef])
val (Seq(fa), fr) = funVars(pgm.definedFunctions(0))
val fixpoints = Seq(
FixPoint(fa, cl2adt(form.get)),
FixPoint(fr, UnionType(Seq(
ConstructorType("Or", Seq(fr, fr)),
ConstructorType("Not", Seq(fr))))))
reporter.info("The least fixpoint is:")
reporter.info(fixpoints.map(PrettyPrinter.apply))
MatchAnalyzer(pgm, fixpoints, reporter)
}
}
......@@ -4,17 +4,24 @@ import setconstraints.Trees._
object Manip {
def flatten(f: Formula): Formula = {
def flatten0(form: Formula): Formula = form match {
case And(fs) => {
And(fs.foldLeft(Nil: Seq[Formula])((acc, f) => f match {
case And(fs2) => acc ++ fs2.map(flatten0)
case f2 => acc :+ flatten0(f2)
}))
}
case f => f
}
Tools.fix(flatten0, f)
def flatten(f: Formula): Formula = f match {
case And(fs) => And(fs.flatMap(f => flatten(f) match {
case And(fs2) => fs2
case f => List(f)
}))
case f => f
}
def includes(f: Formula): Seq[Include] = flatten(f) match {
case And(fs) if fs.forall(isRelation) => fs.flatMap(f => removeEquals(f.asInstanceOf[Relation]))
case f@_ => error("unexpected formula :" + f)
}
private def removeEquals(r: Relation): Seq[Include] = r match {
case Equals(s1, s2) => Seq(Include(s1, s2), Include(s2, s1))
case i@Include(_,_) => Seq(i)
}
private def isRelation(f: Formula): Boolean = f.isInstanceOf[Relation]
}
package setconstraints
import Trees._
import purescala.Definitions.Program
import purescala.Reporter
object MatchAnalyzer {
def apply(pgm: Program, fixPoints: Seq[FixPoint], reporter: Reporter) {
}
}
......@@ -8,9 +8,12 @@ object PrettyPrinter {
def apply(st: SetType): String = ppSetType(st)
def apply(fp: FixPoint): String = ppFixPoint(fp)
private def ppFormula(f: Formula): String = f match {
case And(fs) => fs.map(ppFormula).mkString("(", " \u2227 ", ")")
case Include(s1, s2) => ppSetType(s1) + " \u2282 " + ppSetType(s2)
case Equals(s1, s2) => ppSetType(s1) + " = " + ppSetType(s2)
}
private def ppSetType(st: SetType): String = st match {
......@@ -22,4 +25,8 @@ object PrettyPrinter {
case TupleType(sts) => sts.map(ppSetType).mkString("(", ", ", ")")
case VariableType(name) => name
}
private def ppFixPoint(fp: FixPoint): String = fp match {
case FixPoint(t, s) => ppSetType(t) + " = " + ppSetType(s)
}
}
......@@ -11,14 +11,7 @@ object Trees {
sealed abstract trait SetType
case class Include(s1: SetType, s2: SetType) extends Relation
object Equals {
def apply(s1: SetType, s2: SetType) = And(List(Include(s1, s2), Include(s2, s1)))
def unapply(f: Formula): Boolean = f match {
case And(List(Include(s1, s2), Include(s3, s4))) if s1 == s4 && s2 == s3 => true
case _ => false
}
}
case class Equals(s1: SetType, s2: SetType) extends Relation
case class UnionType(sets: Seq[SetType]) extends SetType
case class IntersectionType(sets: Seq[SetType]) extends SetType
......@@ -33,4 +26,6 @@ object Trees {
VariableType(name + "_" + varCounter)
}
case class FixPoint(t: VariableType, s: SetType)
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment