diff --git a/src/setconstraints/ADTExtractor.scala b/src/setconstraints/ADTExtractor.scala new file mode 100644 index 0000000000000000000000000000000000000000..45647561a34e1d70d9b44972fca1625c961b2ff4 --- /dev/null +++ b/src/setconstraints/ADTExtractor.scala @@ -0,0 +1,36 @@ +package setconstraints + +import purescala.Definitions._ +import purescala.TypeTrees.ClassType +import setconstraints.Trees._ + +import scala.collection.mutable.HashMap + +object ADTExtractor { + + def apply(pgm: Program): HashMap[ClassTypeDef, SetType] = { + val hm = new HashMap[ClassTypeDef, SetType] + var dcls = pgm.definedClasses + while(!dcls.isEmpty) { + val curr = dcls.head + if(curr.isInstanceOf[AbstractClassDef]) { + hm.put(curr, freshVar(curr.id.name)) + dcls = dcls.filterNot(_ == curr) + } else if(curr.isInstanceOf[CaseClassDef]) { + val name = curr.id.name + val fields = curr.asInstanceOf[CaseClassDef].fields + try { + val l = fields.map(vd => hm(vd.tpe.asInstanceOf[ClassType].classDef)).toList + hm.put(curr, ConstructorType(name, l)) + dcls = dcls.filterNot(_ == curr) + } catch { + case _: NoSuchElementException => { + dcls = dcls.tail ++ List(dcls.head) + } + } + } else error("Found a class which is neither an AbstractClassDef nor a CaseClassDef") + } + hm + } + +} diff --git a/src/setconstraints/CnstrtGen.scala b/src/setconstraints/CnstrtGen.scala new file mode 100644 index 0000000000000000000000000000000000000000..7ca661d094eaae5a9423cb8479e20591d1fdf16a --- /dev/null +++ b/src/setconstraints/CnstrtGen.scala @@ -0,0 +1,99 @@ +package setconstraints + +import scala.collection.mutable.{Map, HashMap, ListBuffer} + +import purescala.Definitions._ +import purescala.Trees.{And => _, _} +import purescala.Common.Identifier + + +import Trees._ + +object CnstrtGen { + + def apply(pgm: Program, + typeVars: Map[ClassTypeDef, VariableType], + funVars: Map[FunDef, (Seq[VariableType], VariableType)], + cl2adt: Map[ClassTypeDef, SetType] + ): Formula = { + + val funCallsCnstr: ListBuffer[Include] = new ListBuffer[Include]() + val patternCnstr: ListBuffer[Include] = new ListBuffer[Include]() + + def addFunCallCnst(fi: FunctionInvocation) { + val (args,_) = funVars(fi.funDef) + args.zip(fi.args).foreach{case (v, expr) => { + val (newT, newCnstr) = cnstrExpr(expr, Map()) + funCallsCnstr ++= newCnstr + funCallsCnstr += Include(v, newT) + } + } + } + + def cnstrExpr(expr: Expr, context: Map[Identifier, VariableType]): (SetType, Seq[Include]) = expr match { + case Variable(id) => { + (context(id), Seq()) + } + case IfExpr(cond, then, elze) => { + val (tType, tCnstrs) = cnstrExpr(then, context) + val (eType, eCnstrs) = cnstrExpr(elze, context) + (UnionType(Seq(tType, eType)), tCnstrs ++ eCnstrs) + } + case MatchExpr(scrut, cases) => { + val (sType, sCnstrs) = cnstrExpr(scrut, context) + val (cType, cCnstrs) = cases.map(mc => { + //val theGuard = mc.theGuard + val rhs = mc.rhs + val (pt, pvc) = pattern2Type(mc.pattern) + cnstrExpr(rhs, context ++ pvc) + }).unzip + val mType = freshVar("match") + val mCnstrs = cType.map(t => Include(t, mType)) + (mType, mCnstrs ++ cCnstrs.flatMap(x => x)) + } + case FunctionInvocation(fd, args) => { + val rt = funVars(fd)._2 + (rt, Seq()) + } + case CaseClass(ccd, args) => { + val (argsType, cnstrts) = args.map(e => cnstrExpr(e, context)).unzip + (ConstructorType(ccd.id.name, argsType), cnstrts.flatMap(x => x)) + } + case _ => error("Not yet supported: " + expr) + } + + def pattern2Type(pattern: Pattern): (SetType, Map[Identifier, VariableType]) = pattern match { + case InstanceOfPattern(binder, ctd) => error("not yet supported") + case WildcardPattern(binder) => { + val v = freshVar(binder match {case Some(id) => id.name case None => "x"}) + (v, binder match {case Some(id) => Map(id -> v) case None => Map()}) + } + case CaseClassPattern(binder, ccd, sps) => { + val (subConsType, subVarType) = sps.map(p => pattern2Type(p)).unzip + val newMap = subVarType.foldLeft(Map[Identifier, VariableType]())((acc, el) => acc ++ el) + (ConstructorType(ccd.id.name, subConsType), newMap) + } + } + + def cnstrFun(fd: FunDef): Seq[Include] = { + val argsT = funVars(fd)._1 + val argsID = fd.args.map(vd => vd.id) + val context = argsID.zip(argsT).foldLeft(Map[Identifier, VariableType]())((acc, el) => acc + el) + val (bodyType, cnstrts) = cnstrExpr(fd.body.get, context) + cnstrts :+ Include(funVars(fd)._2, bodyType) + } + + def cnstrTypeHierarchy(pgm: Program): Seq[Include] = { + val caseClasses = pgm.definedClasses.filter(_.isInstanceOf[CaseClassDef]) + caseClasses.map(cc => Include(cl2adt(cc), cl2adt(cc.parent.get))) + } + + val cnstrtsTypes = cnstrTypeHierarchy(pgm) + + val funs = pgm.definedFunctions + val cnstrtsFunctions = funs.flatMap(cnstrFun) + + And(cnstrtsTypes ++ cnstrtsFunctions) + } + +} diff --git a/src/setconstraints/LabelProgram.scala b/src/setconstraints/LabelProgram.scala new file mode 100644 index 0000000000000000000000000000000000000000..0abf494096d9a856666743e97ca809a895eb038c --- /dev/null +++ b/src/setconstraints/LabelProgram.scala @@ -0,0 +1,29 @@ +package setconstraints + +import scala.collection.mutable.{Map, HashMap} + +import purescala.Definitions._ +import setconstraints.Trees._ + +object LabelProgram { + + def apply(pgm: Program): (Map[ClassTypeDef, VariableType], + Map[FunDef, (Seq[VariableType], VariableType)]) = + (labelTypeHierarchy(pgm), labelFunction(pgm)) + + + private def labelFunction(pgm: Program): Map[FunDef, (Seq[VariableType], VariableType)] = { + val hm = new HashMap[FunDef, (Seq[VariableType], VariableType)] + pgm.definedFunctions.foreach(fd => { + val varTypes = (fd.args.map(vd => freshVar(fd.id.name + "_arg_" + vd.id.name)), freshVar(fd.id.name + "_return")) + hm.put(fd, varTypes) + }) + hm + } + + private def labelTypeHierarchy(pgm: Program): Map[ClassTypeDef, VariableType] = { + val hm = new HashMap[ClassTypeDef, VariableType] + pgm.definedClasses.foreach(clDef => hm.put(clDef, freshVar(clDef.id.name))) + hm + } +} diff --git a/src/setconstraints/Main.scala b/src/setconstraints/Main.scala index d7deb70f116c0ac7d9ec05edc00a99b8cd7e1898..4351317faf890a16605946f3fb0dc600845617bb 100644 --- a/src/setconstraints/Main.scala +++ b/src/setconstraints/Main.scala @@ -1,15 +1,27 @@ package setconstraints -import purescala.Extensions._ -import purescala.Definitions._ -import purescala.Trees._ +import purescala.Definitions.Program import purescala.Reporter +import purescala.Extensions.Analyser class Main(reporter: Reporter) extends Analyser(reporter) { val description: String = "Analyser for advanced type inference based on set constraints" override val shortDescription = "Set constraints" - def analyse(program: Program) : Unit = { - reporter.info("Nothing to do in this analysis.") + def analyse(pgm: Program) : Unit = { + + val (tpeVars, funVars) = LabelProgram(pgm) + val cl2adt = ADTExtractor(pgm) + + val cnstr = CnstrtGen(pgm, tpeVars, funVars, cl2adt) + + reporter.info(tpeVars.toString) + reporter.info(funVars.toString) + reporter.info(cl2adt.toString) + + reporter.info("THE CONSTRAINTS") + reporter.info(PrettyPrinter(cnstr)) + } + } diff --git a/src/setconstraints/PrettyPrinter.scala b/src/setconstraints/PrettyPrinter.scala new file mode 100644 index 0000000000000000000000000000000000000000..e7fc6cb89de9b98995a04534c670254e99f38f98 --- /dev/null +++ b/src/setconstraints/PrettyPrinter.scala @@ -0,0 +1,25 @@ +package setconstraints + +import setconstraints.Trees._ + +object PrettyPrinter { + + def apply(f: Formula): String = ppFormula(f) + + def apply(st: SetType): String = ppSetType(st) + + private def ppFormula(f: Formula): String = f match { + case And(fs) => fs.map(ppFormula).mkString("(", " \u2227 ", ")") + case Include(s1, s2) => ppSetType(s1) + " \u2282 " + ppSetType(s2) + } + + private def ppSetType(st: SetType): String = st match { + case ConstructorType(name, Seq()) => name + case ConstructorType(name, sts) => name + sts.map(ppSetType).mkString("(", ", ", ")") + case UnionType(sts) => sts.map(ppSetType).mkString("(", " \u222A ", ")") + case IntersectionType(sts) => sts.map(ppSetType).mkString("(", " \u2229 ", ")") + case FunctionType(s1, s2) => "(" + ppSetType(s1) + " --> " + ppSetType(s2) + ")" + case TupleType(sts) => sts.map(ppSetType).mkString("(", ", ", ")") + case VariableType(name) => name + } +} diff --git a/src/setconstraints/Tools.scala b/src/setconstraints/Tools.scala new file mode 100644 index 0000000000000000000000000000000000000000..98257a867a301508829e650727c1e22b1428565d --- /dev/null +++ b/src/setconstraints/Tools.scala @@ -0,0 +1,11 @@ +package setconstraints + +import purescala.Definitions._ + +object Tools { + + def childOf(root: ClassTypeDef, classes: Seq[ClassTypeDef]) = classes.filter(_.parent == root) + + def toCaseClasses(classes: Seq[ClassTypeDef]): Seq[CaseClassDef] = + classes.filter(_.isInstanceOf[CaseClassDef]).map(_.asInstanceOf[CaseClassDef]) +} diff --git a/src/setconstraints/Trees.scala b/src/setconstraints/Trees.scala new file mode 100644 index 0000000000000000000000000000000000000000..df124255e077c60c454d28c768174a4d3ed3d710 --- /dev/null +++ b/src/setconstraints/Trees.scala @@ -0,0 +1,36 @@ +package setconstraints + +object Trees { + + sealed trait Formula + + case class And(fs: Seq[Formula]) extends Formula + + sealed abstract trait Relation extends Formula + + sealed abstract trait SetType + + case class Include(s1: SetType, s2: SetType) extends Relation + + object Equals { + def apply(s1: SetType, s2: SetType) = And(List(Include(s1, s2), Include(s2, s1))) + def unapply(f: Formula): Boolean = f match { + case And(List(Include(s1, s2), Include(s3, s4))) if s1 == s4 && s2 == s3 => true + case _ => false + } + } + + case class UnionType(sets: Seq[SetType]) extends SetType + case class IntersectionType(sets: Seq[SetType]) extends SetType + case class FunctionType(s1: SetType, s2: SetType) extends SetType + case class TupleType(sets: Seq[SetType]) extends SetType + case class ConstructorType(name: String, sets: Seq[SetType]) extends SetType + case class VariableType(name: String) extends SetType + + private var varCounter = -1 + def freshVar(name: String) = { + varCounter += 1 + VariableType(name + "_" + varCounter) + } + +}