Skip to content
Snippets Groups Projects
Commit 30be81b3 authored by Régis Blanc's avatar Régis Blanc
Browse files

No commit message

No commit message
parent 3ff6d8bb
No related branches found
No related tags found
No related merge requests found
package setconstraints
import purescala.Definitions._
import purescala.TypeTrees.ClassType
import setconstraints.Trees._
import scala.collection.mutable.HashMap
object ADTExtractor {
def apply(pgm: Program): HashMap[ClassTypeDef, SetType] = {
val hm = new HashMap[ClassTypeDef, SetType]
var dcls = pgm.definedClasses
while(!dcls.isEmpty) {
val curr = dcls.head
if(curr.isInstanceOf[AbstractClassDef]) {
hm.put(curr, freshVar(curr.id.name))
dcls = dcls.filterNot(_ == curr)
} else if(curr.isInstanceOf[CaseClassDef]) {
val name = curr.id.name
val fields = curr.asInstanceOf[CaseClassDef].fields
try {
val l = fields.map(vd => hm(vd.tpe.asInstanceOf[ClassType].classDef)).toList
hm.put(curr, ConstructorType(name, l))
dcls = dcls.filterNot(_ == curr)
} catch {
case _: NoSuchElementException => {
dcls = dcls.tail ++ List(dcls.head)
}
}
} else error("Found a class which is neither an AbstractClassDef nor a CaseClassDef")
}
hm
}
}
package setconstraints
import scala.collection.mutable.{Map, HashMap, ListBuffer}
import purescala.Definitions._
import purescala.Trees.{And => _, _}
import purescala.Common.Identifier
import Trees._
object CnstrtGen {
def apply(pgm: Program,
typeVars: Map[ClassTypeDef, VariableType],
funVars: Map[FunDef, (Seq[VariableType], VariableType)],
cl2adt: Map[ClassTypeDef, SetType]
): Formula = {
val funCallsCnstr: ListBuffer[Include] = new ListBuffer[Include]()
val patternCnstr: ListBuffer[Include] = new ListBuffer[Include]()
def addFunCallCnst(fi: FunctionInvocation) {
val (args,_) = funVars(fi.funDef)
args.zip(fi.args).foreach{case (v, expr) => {
val (newT, newCnstr) = cnstrExpr(expr, Map())
funCallsCnstr ++= newCnstr
funCallsCnstr += Include(v, newT)
}
}
}
def cnstrExpr(expr: Expr, context: Map[Identifier, VariableType]): (SetType, Seq[Include]) = expr match {
case Variable(id) => {
(context(id), Seq())
}
case IfExpr(cond, then, elze) => {
val (tType, tCnstrs) = cnstrExpr(then, context)
val (eType, eCnstrs) = cnstrExpr(elze, context)
(UnionType(Seq(tType, eType)), tCnstrs ++ eCnstrs)
}
case MatchExpr(scrut, cases) => {
val (sType, sCnstrs) = cnstrExpr(scrut, context)
val (cType, cCnstrs) = cases.map(mc => {
//val theGuard = mc.theGuard
val rhs = mc.rhs
val (pt, pvc) = pattern2Type(mc.pattern)
cnstrExpr(rhs, context ++ pvc)
}).unzip
val mType = freshVar("match")
val mCnstrs = cType.map(t => Include(t, mType))
(mType, mCnstrs ++ cCnstrs.flatMap(x => x))
}
case FunctionInvocation(fd, args) => {
val rt = funVars(fd)._2
(rt, Seq())
}
case CaseClass(ccd, args) => {
val (argsType, cnstrts) = args.map(e => cnstrExpr(e, context)).unzip
(ConstructorType(ccd.id.name, argsType), cnstrts.flatMap(x => x))
}
case _ => error("Not yet supported: " + expr)
}
def pattern2Type(pattern: Pattern): (SetType, Map[Identifier, VariableType]) = pattern match {
case InstanceOfPattern(binder, ctd) => error("not yet supported")
case WildcardPattern(binder) => {
val v = freshVar(binder match {case Some(id) => id.name case None => "x"})
(v, binder match {case Some(id) => Map(id -> v) case None => Map()})
}
case CaseClassPattern(binder, ccd, sps) => {
val (subConsType, subVarType) = sps.map(p => pattern2Type(p)).unzip
val newMap = subVarType.foldLeft(Map[Identifier, VariableType]())((acc, el) => acc ++ el)
(ConstructorType(ccd.id.name, subConsType), newMap)
}
}
def cnstrFun(fd: FunDef): Seq[Include] = {
val argsT = funVars(fd)._1
val argsID = fd.args.map(vd => vd.id)
val context = argsID.zip(argsT).foldLeft(Map[Identifier, VariableType]())((acc, el) => acc + el)
val (bodyType, cnstrts) = cnstrExpr(fd.body.get, context)
cnstrts :+ Include(funVars(fd)._2, bodyType)
}
def cnstrTypeHierarchy(pgm: Program): Seq[Include] = {
val caseClasses = pgm.definedClasses.filter(_.isInstanceOf[CaseClassDef])
caseClasses.map(cc => Include(cl2adt(cc), cl2adt(cc.parent.get)))
}
val cnstrtsTypes = cnstrTypeHierarchy(pgm)
val funs = pgm.definedFunctions
val cnstrtsFunctions = funs.flatMap(cnstrFun)
And(cnstrtsTypes ++ cnstrtsFunctions)
}
}
package setconstraints
import scala.collection.mutable.{Map, HashMap}
import purescala.Definitions._
import setconstraints.Trees._
object LabelProgram {
def apply(pgm: Program): (Map[ClassTypeDef, VariableType],
Map[FunDef, (Seq[VariableType], VariableType)]) =
(labelTypeHierarchy(pgm), labelFunction(pgm))
private def labelFunction(pgm: Program): Map[FunDef, (Seq[VariableType], VariableType)] = {
val hm = new HashMap[FunDef, (Seq[VariableType], VariableType)]
pgm.definedFunctions.foreach(fd => {
val varTypes = (fd.args.map(vd => freshVar(fd.id.name + "_arg_" + vd.id.name)), freshVar(fd.id.name + "_return"))
hm.put(fd, varTypes)
})
hm
}
private def labelTypeHierarchy(pgm: Program): Map[ClassTypeDef, VariableType] = {
val hm = new HashMap[ClassTypeDef, VariableType]
pgm.definedClasses.foreach(clDef => hm.put(clDef, freshVar(clDef.id.name)))
hm
}
}
package setconstraints package setconstraints
import purescala.Extensions._ import purescala.Definitions.Program
import purescala.Definitions._
import purescala.Trees._
import purescala.Reporter import purescala.Reporter
import purescala.Extensions.Analyser
class Main(reporter: Reporter) extends Analyser(reporter) { class Main(reporter: Reporter) extends Analyser(reporter) {
val description: String = "Analyser for advanced type inference based on set constraints" val description: String = "Analyser for advanced type inference based on set constraints"
override val shortDescription = "Set constraints" override val shortDescription = "Set constraints"
def analyse(program: Program) : Unit = { def analyse(pgm: Program) : Unit = {
reporter.info("Nothing to do in this analysis.")
val (tpeVars, funVars) = LabelProgram(pgm)
val cl2adt = ADTExtractor(pgm)
val cnstr = CnstrtGen(pgm, tpeVars, funVars, cl2adt)
reporter.info(tpeVars.toString)
reporter.info(funVars.toString)
reporter.info(cl2adt.toString)
reporter.info("THE CONSTRAINTS")
reporter.info(PrettyPrinter(cnstr))
} }
} }
package setconstraints
import setconstraints.Trees._
object PrettyPrinter {
def apply(f: Formula): String = ppFormula(f)
def apply(st: SetType): String = ppSetType(st)
private def ppFormula(f: Formula): String = f match {
case And(fs) => fs.map(ppFormula).mkString("(", " \u2227 ", ")")
case Include(s1, s2) => ppSetType(s1) + " \u2282 " + ppSetType(s2)
}
private def ppSetType(st: SetType): String = st match {
case ConstructorType(name, Seq()) => name
case ConstructorType(name, sts) => name + sts.map(ppSetType).mkString("(", ", ", ")")
case UnionType(sts) => sts.map(ppSetType).mkString("(", " \u222A ", ")")
case IntersectionType(sts) => sts.map(ppSetType).mkString("(", " \u2229 ", ")")
case FunctionType(s1, s2) => "(" + ppSetType(s1) + " --> " + ppSetType(s2) + ")"
case TupleType(sts) => sts.map(ppSetType).mkString("(", ", ", ")")
case VariableType(name) => name
}
}
package setconstraints
import purescala.Definitions._
object Tools {
def childOf(root: ClassTypeDef, classes: Seq[ClassTypeDef]) = classes.filter(_.parent == root)
def toCaseClasses(classes: Seq[ClassTypeDef]): Seq[CaseClassDef] =
classes.filter(_.isInstanceOf[CaseClassDef]).map(_.asInstanceOf[CaseClassDef])
}
package setconstraints
object Trees {
sealed trait Formula
case class And(fs: Seq[Formula]) extends Formula
sealed abstract trait Relation extends Formula
sealed abstract trait SetType
case class Include(s1: SetType, s2: SetType) extends Relation
object Equals {
def apply(s1: SetType, s2: SetType) = And(List(Include(s1, s2), Include(s2, s1)))
def unapply(f: Formula): Boolean = f match {
case And(List(Include(s1, s2), Include(s3, s4))) if s1 == s4 && s2 == s3 => true
case _ => false
}
}
case class UnionType(sets: Seq[SetType]) extends SetType
case class IntersectionType(sets: Seq[SetType]) extends SetType
case class FunctionType(s1: SetType, s2: SetType) extends SetType
case class TupleType(sets: Seq[SetType]) extends SetType
case class ConstructorType(name: String, sets: Seq[SetType]) extends SetType
case class VariableType(name: String) extends SetType
private var varCounter = -1
def freshVar(name: String) = {
varCounter += 1
VariableType(name + "_" + varCounter)
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment