diff --git a/src/setconstraints/CnstrtGen.scala b/src/setconstraints/CnstrtGen.scala index 7ca661d094eaae5a9423cb8479e20591d1fdf16a..d18f2a35ffd5d87e58f7a7bff098ec3680345cf8 100644 --- a/src/setconstraints/CnstrtGen.scala +++ b/src/setconstraints/CnstrtGen.scala @@ -3,8 +3,9 @@ package setconstraints import scala.collection.mutable.{Map, HashMap, ListBuffer} import purescala.Definitions._ -import purescala.Trees.{And => _, _} +import purescala.Trees.{And => _, Equals => _, _} import purescala.Common.Identifier +import purescala.TypeTrees.ClassType import Trees._ @@ -17,20 +18,10 @@ object CnstrtGen { cl2adt: Map[ClassTypeDef, SetType] ): Formula = { - val funCallsCnstr: ListBuffer[Include] = new ListBuffer[Include]() - val patternCnstr: ListBuffer[Include] = new ListBuffer[Include]() + val funCallsCnstr: ListBuffer[Relation] = new ListBuffer[Relation]() + val patternCnstr: ListBuffer[Relation] = new ListBuffer[Relation]() - def addFunCallCnst(fi: FunctionInvocation) { - val (args,_) = funVars(fi.funDef) - args.zip(fi.args).foreach{case (v, expr) => { - val (newT, newCnstr) = cnstrExpr(expr, Map()) - funCallsCnstr ++= newCnstr - funCallsCnstr += Include(v, newT) - } - } - } - - def cnstrExpr(expr: Expr, context: Map[Identifier, VariableType]): (SetType, Seq[Include]) = expr match { + def cnstrExpr(expr: Expr, context: Map[Identifier, VariableType]): (SetType, Seq[Relation]) = expr match { case Variable(id) => { (context(id), Seq()) } @@ -52,7 +43,13 @@ object CnstrtGen { (mType, mCnstrs ++ cCnstrs.flatMap(x => x)) } case FunctionInvocation(fd, args) => { - val rt = funVars(fd)._2 + val (tArgs,rt) = funVars(fd) + tArgs.zip(args).foreach{case (v, expr) => { + val (newT, newCnstr) = cnstrExpr(expr, context) + funCallsCnstr ++= newCnstr + funCallsCnstr += Include(newT, v) + } + } (rt, Seq()) } case CaseClass(ccd, args) => { @@ -71,29 +68,33 @@ object CnstrtGen { case CaseClassPattern(binder, ccd, sps) => { val (subConsType, subVarType) = sps.map(p => pattern2Type(p)).unzip val newMap = subVarType.foldLeft(Map[Identifier, VariableType]())((acc, el) => acc ++ el) + subConsType.zip(ccd.fields)foreach{case (t, vd) => patternCnstr += Equals(t, cl2adt(vd.tpe.asInstanceOf[ClassType].classDef))} //TODO bug if there are nested pattern (ConstructorType(ccd.id.name, subConsType), newMap) } } - def cnstrFun(fd: FunDef): Seq[Include] = { + def cnstrFun(fd: FunDef): Seq[Relation] = { val argsT = funVars(fd)._1 val argsID = fd.args.map(vd => vd.id) val context = argsID.zip(argsT).foldLeft(Map[Identifier, VariableType]())((acc, el) => acc + el) val (bodyType, cnstrts) = cnstrExpr(fd.body.get, context) - cnstrts :+ Include(funVars(fd)._2, bodyType) + cnstrts :+ Include(bodyType, funVars(fd)._2) } - def cnstrTypeHierarchy(pgm: Program): Seq[Include] = { + def cnstrTypeHierarchy(pgm: Program): Seq[Relation] = { val caseClasses = pgm.definedClasses.filter(_.isInstanceOf[CaseClassDef]) caseClasses.map(cc => Include(cl2adt(cc), cl2adt(cc.parent.get))) } val cnstrtsTypes = cnstrTypeHierarchy(pgm) + println(typeVars) + println(cnstrtsTypes) + val funs = pgm.definedFunctions val cnstrtsFunctions = funs.flatMap(cnstrFun) - And(cnstrtsTypes ++ cnstrtsFunctions) + And(cnstrtsTypes ++ cnstrtsFunctions ++ funCallsCnstr ++ patternCnstr) } } diff --git a/src/setconstraints/Main.scala b/src/setconstraints/Main.scala index 4351317faf890a16605946f3fb0dc600845617bb..f501ef2d44f8e91063d99f95bb40390e790493bf 100644 --- a/src/setconstraints/Main.scala +++ b/src/setconstraints/Main.scala @@ -1,6 +1,8 @@ package setconstraints +import Trees._ import purescala.Definitions.Program +import purescala.Definitions.AbstractClassDef import purescala.Reporter import purescala.Extensions.Analyser @@ -15,13 +17,23 @@ class Main(reporter: Reporter) extends Analyser(reporter) { val cnstr = CnstrtGen(pgm, tpeVars, funVars, cl2adt) - reporter.info(tpeVars.toString) - reporter.info(funVars.toString) - reporter.info(cl2adt.toString) - - reporter.info("THE CONSTRAINTS") + reporter.info("The constraints are:") reporter.info(PrettyPrinter(cnstr)) + val classes = pgm.definedClasses + val form = classes.find(_.isInstanceOf[AbstractClassDef]) + + val (Seq(fa), fr) = funVars(pgm.definedFunctions(0)) + val fixpoints = Seq( + FixPoint(fa, cl2adt(form.get)), + FixPoint(fr, UnionType(Seq( + ConstructorType("Or", Seq(fr, fr)), + ConstructorType("Not", Seq(fr)))))) + reporter.info("The least fixpoint is:") + reporter.info(fixpoints.map(PrettyPrinter.apply)) + + MatchAnalyzer(pgm, fixpoints, reporter) + } } diff --git a/src/setconstraints/Manip.scala b/src/setconstraints/Manip.scala index 4d3a91a6c5989c9d6477e6294b1ca7f72994eb21..329a11ce2507c31f877bba9c4118817dc8ba2655 100644 --- a/src/setconstraints/Manip.scala +++ b/src/setconstraints/Manip.scala @@ -4,17 +4,24 @@ import setconstraints.Trees._ object Manip { - def flatten(f: Formula): Formula = { - def flatten0(form: Formula): Formula = form match { - case And(fs) => { - And(fs.foldLeft(Nil: Seq[Formula])((acc, f) => f match { - case And(fs2) => acc ++ fs2.map(flatten0) - case f2 => acc :+ flatten0(f2) - })) - } - case f => f - } - Tools.fix(flatten0, f) + def flatten(f: Formula): Formula = f match { + case And(fs) => And(fs.flatMap(f => flatten(f) match { + case And(fs2) => fs2 + case f => List(f) + })) + case f => f } + def includes(f: Formula): Seq[Include] = flatten(f) match { + case And(fs) if fs.forall(isRelation) => fs.flatMap(f => removeEquals(f.asInstanceOf[Relation])) + case f@_ => error("unexpected formula :" + f) + } + + private def removeEquals(r: Relation): Seq[Include] = r match { + case Equals(s1, s2) => Seq(Include(s1, s2), Include(s2, s1)) + case i@Include(_,_) => Seq(i) + } + + private def isRelation(f: Formula): Boolean = f.isInstanceOf[Relation] + } diff --git a/src/setconstraints/MatchAnalyzer.scala b/src/setconstraints/MatchAnalyzer.scala new file mode 100644 index 0000000000000000000000000000000000000000..768637d936241cdeeedb0644031f83ee889e9c39 --- /dev/null +++ b/src/setconstraints/MatchAnalyzer.scala @@ -0,0 +1,13 @@ +package setconstraints + +import Trees._ +import purescala.Definitions.Program +import purescala.Reporter + +object MatchAnalyzer { + + def apply(pgm: Program, fixPoints: Seq[FixPoint], reporter: Reporter) { + + } + +} diff --git a/src/setconstraints/PrettyPrinter.scala b/src/setconstraints/PrettyPrinter.scala index e7fc6cb89de9b98995a04534c670254e99f38f98..aee0378b98a42a64191a98719fff3da17b2b116f 100644 --- a/src/setconstraints/PrettyPrinter.scala +++ b/src/setconstraints/PrettyPrinter.scala @@ -8,9 +8,12 @@ object PrettyPrinter { def apply(st: SetType): String = ppSetType(st) + def apply(fp: FixPoint): String = ppFixPoint(fp) + private def ppFormula(f: Formula): String = f match { case And(fs) => fs.map(ppFormula).mkString("(", " \u2227 ", ")") case Include(s1, s2) => ppSetType(s1) + " \u2282 " + ppSetType(s2) + case Equals(s1, s2) => ppSetType(s1) + " = " + ppSetType(s2) } private def ppSetType(st: SetType): String = st match { @@ -22,4 +25,8 @@ object PrettyPrinter { case TupleType(sts) => sts.map(ppSetType).mkString("(", ", ", ")") case VariableType(name) => name } + + private def ppFixPoint(fp: FixPoint): String = fp match { + case FixPoint(t, s) => ppSetType(t) + " = " + ppSetType(s) + } } diff --git a/src/setconstraints/Trees.scala b/src/setconstraints/Trees.scala index df124255e077c60c454d28c768174a4d3ed3d710..075c2f2077e632c935732e4ffd812c76a17c81a5 100644 --- a/src/setconstraints/Trees.scala +++ b/src/setconstraints/Trees.scala @@ -11,14 +11,7 @@ object Trees { sealed abstract trait SetType case class Include(s1: SetType, s2: SetType) extends Relation - - object Equals { - def apply(s1: SetType, s2: SetType) = And(List(Include(s1, s2), Include(s2, s1))) - def unapply(f: Formula): Boolean = f match { - case And(List(Include(s1, s2), Include(s3, s4))) if s1 == s4 && s2 == s3 => true - case _ => false - } - } + case class Equals(s1: SetType, s2: SetType) extends Relation case class UnionType(sets: Seq[SetType]) extends SetType case class IntersectionType(sets: Seq[SetType]) extends SetType @@ -33,4 +26,6 @@ object Trees { VariableType(name + "_" + varCounter) } + case class FixPoint(t: VariableType, s: SetType) + }