diff --git a/src/setconstraints/CnstrtGen.scala b/src/setconstraints/CnstrtGen.scala index d18f2a35ffd5d87e58f7a7bff098ec3680345cf8..975a70a5ba8159abb6cb2a96e00f319297c3fe9f 100644 --- a/src/setconstraints/CnstrtGen.scala +++ b/src/setconstraints/CnstrtGen.scala @@ -1,13 +1,10 @@ package setconstraints -import scala.collection.mutable.{Map, HashMap, ListBuffer} - import purescala.Definitions._ import purescala.Trees.{And => _, Equals => _, _} import purescala.Common.Identifier import purescala.TypeTrees.ClassType - import Trees._ object CnstrtGen { @@ -18,67 +15,85 @@ object CnstrtGen { cl2adt: Map[ClassTypeDef, SetType] ): Formula = { - val funCallsCnstr: ListBuffer[Relation] = new ListBuffer[Relation]() - val patternCnstr: ListBuffer[Relation] = new ListBuffer[Relation]() + def unzip3[A,B,C](seqs: Seq[(A,B,C)]): (Seq[A],Seq[B],Seq[C]) = + seqs.foldLeft((Seq[A](), Seq[B](), Seq[C]()))((a, t) => (t._1 +: a._1, t._2 +: a._2, t._3 +: a._3)) - def cnstrExpr(expr: Expr, context: Map[Identifier, VariableType]): (SetType, Seq[Relation]) = expr match { - case Variable(id) => { - (context(id), Seq()) - } - case IfExpr(cond, then, elze) => { - val (tType, tCnstrs) = cnstrExpr(then, context) - val (eType, eCnstrs) = cnstrExpr(elze, context) - (UnionType(Seq(tType, eType)), tCnstrs ++ eCnstrs) - } - case MatchExpr(scrut, cases) => { - val (sType, sCnstrs) = cnstrExpr(scrut, context) - val (cType, cCnstrs) = cases.map(mc => { - //val theGuard = mc.theGuard - val rhs = mc.rhs - val (pt, pvc) = pattern2Type(mc.pattern) - cnstrExpr(rhs, context ++ pvc) - }).unzip - val mType = freshVar("match") - val mCnstrs = cType.map(t => Include(t, mType)) - (mType, mCnstrs ++ cCnstrs.flatMap(x => x)) - } - case FunctionInvocation(fd, args) => { - val (tArgs,rt) = funVars(fd) - tArgs.zip(args).foreach{case (v, expr) => { - val (newT, newCnstr) = cnstrExpr(expr, context) - funCallsCnstr ++= newCnstr - funCallsCnstr += Include(newT, v) + def cnstrExpr(expr: Expr, context: Map[Identifier, VariableType]): (VariableType, Seq[Relation], Map[Expr, VariableType]) = { + val exprVarType = freshVar("expr") + val (rels, e2t) = expr match { + case Variable(id) => { + (Seq(Equals(context(id), exprVarType)), Map[Expr, VariableType]()) + } + case IfExpr(cond, then, elze) => { + val (tType, tCnstrs, tMap) = cnstrExpr(then, context) + val (eType, eCnstrs, eMap) = cnstrExpr(elze, context) + val newCnstrs = Equals(UnionType(Seq(tType, eType)), exprVarType) +: (tCnstrs ++ eCnstrs) + (newCnstrs, (tMap ++ eMap)) + } + case MatchExpr(scrut, cases) => { + val (sType, sCnstrs, sMap) = cnstrExpr(scrut, context) + val (pts, ptexpcnstr) = cases.map(mc => { + val (pt, cnstrs, pvc) = pattern2Type(mc.pattern) + val (expT, expC, expM) = cnstrExpr(mc.rhs, context ++ pvc) + (pt, (expT, expC ++ cnstrs, expM)) + }).unzip + val (cTypes, cCnstrs, cMaps) = unzip3(ptexpcnstr) + val mCnstrs = cTypes.map(t => Include(t, exprVarType)) + val scrutPatternCnstr = Include(sType, UnionType(pts)) + val fMap: Map[Expr, VariableType] = cMaps.foldLeft(sMap)((a, m) => a ++ m) + val finalCnstrs = scrutPatternCnstr +: (mCnstrs ++ cCnstrs.flatMap(x => x) ++ sCnstrs) + (finalCnstrs, fMap) + } + case FunctionInvocation(fd, args) => { + val (tArgs,rt) = funVars(fd) + /* + tArgs.zip(args).foreach{case (v, expr) => { + val (newT, newCnstr) = cnstrExpr(expr, context) + funCallsCnstr ++= newCnstr + funCallsCnstr += Include(newT, v) + } } + */ + (Seq(Equals(rt, exprVarType)), Map[Expr, VariableType]()) } - (rt, Seq()) - } - case CaseClass(ccd, args) => { - val (argsType, cnstrts) = args.map(e => cnstrExpr(e, context)).unzip - (ConstructorType(ccd.id.name, argsType), cnstrts.flatMap(x => x)) + case CaseClass(ccd, args) => { + val (argsType, cnstrts, maps) = unzip3(args.map(e => cnstrExpr(e, context))) + val fMap = maps.foldLeft(Map[Expr, VariableType]())((a, m) => a ++ m) + val fcnstrts = Equals(ConstructorType(ccd.id.name, argsType), exprVarType) +: cnstrts.flatMap(x => x) + (fcnstrts, fMap) + } + case _ => error("Not yet supported: " + expr) } - case _ => error("Not yet supported: " + expr) + (exprVarType, rels, (e2t: Map[Expr, VariableType]) + (expr -> exprVarType)) } - def pattern2Type(pattern: Pattern): (SetType, Map[Identifier, VariableType]) = pattern match { + def pattern2Type(pattern: Pattern): (VariableType, Seq[Relation], Map[Identifier, VariableType]) = pattern match { case InstanceOfPattern(binder, ctd) => error("not yet supported") case WildcardPattern(binder) => { val v = freshVar(binder match {case Some(id) => id.name case None => "x"}) - (v, binder match {case Some(id) => Map(id -> v) case None => Map()}) + (v, Seq[Relation](), binder match {case Some(id) => Map(id -> v) case None => Map()}) } case CaseClassPattern(binder, ccd, sps) => { - val (subConsType, subVarType) = sps.map(p => pattern2Type(p)).unzip + val cvt = freshVar(ccd.id.name) + val (subConsType, cnstrs, subVarType) = unzip3(sps.map(p => pattern2Type(p))) val newMap = subVarType.foldLeft(Map[Identifier, VariableType]())((acc, el) => acc ++ el) - subConsType.zip(ccd.fields)foreach{case (t, vd) => patternCnstr += Equals(t, cl2adt(vd.tpe.asInstanceOf[ClassType].classDef))} //TODO bug if there are nested pattern - (ConstructorType(ccd.id.name, subConsType), newMap) + val nCnstrs: Seq[Relation] = subConsType.zip(ccd.fields).zip(sps).foldLeft(cnstrs.flatMap(x => x))((a, el) => el match { + case ((t, vd), sp) => sp match { + case WildcardPattern(_) => a :+ Equals(t, cl2adt(vd.tpe.asInstanceOf[ClassType].classDef)) + case _ => a + } + }) + val ccnstr = Equals(ConstructorType(ccd.id.name, subConsType), cvt) + (cvt, ccnstr +: nCnstrs, newMap) } } - def cnstrFun(fd: FunDef): Seq[Relation] = { + def cnstrFun(fd: FunDef): (Seq[Relation], Map[Expr, VariableType]) = { val argsT = funVars(fd)._1 val argsID = fd.args.map(vd => vd.id) val context = argsID.zip(argsT).foldLeft(Map[Identifier, VariableType]())((acc, el) => acc + el) - val (bodyType, cnstrts) = cnstrExpr(fd.body.get, context) - cnstrts :+ Include(bodyType, funVars(fd)._2) + val (bodyType, cnstrts, map) = cnstrExpr(fd.body.get, context) + (cnstrts :+ Include(bodyType, funVars(fd)._2), map) } def cnstrTypeHierarchy(pgm: Program): Seq[Relation] = { @@ -88,13 +103,12 @@ object CnstrtGen { val cnstrtsTypes = cnstrTypeHierarchy(pgm) - println(typeVars) - println(cnstrtsTypes) - val funs = pgm.definedFunctions - val cnstrtsFunctions = funs.flatMap(cnstrFun) - - And(cnstrtsTypes ++ cnstrtsFunctions ++ funCallsCnstr ++ patternCnstr) + val (cnstrtsFunctions, map) = funs.foldLeft(Seq[Relation](), Map[Expr, VariableType]())((a, f) => { + val (rels, m) = cnstrFun(f) + (a._1 ++ rels, a._2 ++ m) + }) + And(cnstrtsTypes ++ cnstrtsFunctions) } } diff --git a/src/setconstraints/Main.scala b/src/setconstraints/Main.scala index f501ef2d44f8e91063d99f95bb40390e790493bf..d843f83833bface3446373f2e4033892e13ea55b 100644 --- a/src/setconstraints/Main.scala +++ b/src/setconstraints/Main.scala @@ -15,7 +15,7 @@ class Main(reporter: Reporter) extends Analyser(reporter) { val (tpeVars, funVars) = LabelProgram(pgm) val cl2adt = ADTExtractor(pgm) - val cnstr = CnstrtGen(pgm, tpeVars, funVars, cl2adt) + val cnstr = CnstrtGen(pgm, Map(tpeVars.toList: _*), Map(funVars.toList: _*), Map(cl2adt.toList: _*)) reporter.info("The constraints are:") reporter.info(PrettyPrinter(cnstr)) diff --git a/src/setconstraints/Manip.scala b/src/setconstraints/Manip.scala index 329a11ce2507c31f877bba9c4118817dc8ba2655..64e64fcae3b3315c636231758bd0e3b181ed2f0f 100644 --- a/src/setconstraints/Manip.scala +++ b/src/setconstraints/Manip.scala @@ -4,12 +4,56 @@ import setconstraints.Trees._ object Manip { - def flatten(f: Formula): Formula = f match { - case And(fs) => And(fs.flatMap(f => flatten(f) match { - case And(fs2) => fs2 - case f => List(f) - })) - case f => f + def map(s: SetType, f: (SetType) => SetType): SetType = s match { + case EmptyType | UniversalType | VariableType(_) => f(s) + case UnionType(sts) => f(UnionType(sts.map(s => map(s, f)))) + case IntersectionType(sts) => f(IntersectionType(sts.map(s => map(s, f)))) + case ComplementType(s) => f(ComplementType(map(s, f))) + case ConstructorType(n@_, sts) => f(ConstructorType(n, sts.map(s => map(s, f)))) + case FunctionType(s1, s2) => { + val ns1 = map(s1, f) + val ns2 = map(s2, f) + f(FunctionType(ns1, ns2)) + } + case TupleType(sts) => f(TupleType(sts.map(s => f(s)))) + } + def map(f: Formula, ff: (Formula) => Formula, ft: (SetType) => SetType): Formula = f match { + case And(fs) => ff(And(fs.map(f => map(f, ff, ft)))) + case Include(s1, s2) => { + val ns1 = map(s1, ft) + val ns2 = map(s2, ft) + ff(Include(ns1, ns2)) + } + case Equals(s1, s2) => { + val ns1 = map(s1, ft) + val ns2 = map(s2, ft) + ff(Equals(ns1, ns2)) + } + } + + def flatten(formula: Formula): Formula = { + def flatten0(f: Formula) = f match { + case And(fs) => And(fs.flatMap{ + case And(fs2) => fs2 + case f => List(f) + }) + case f => f + } + map(formula, flatten0, s => s) + } + def flatten(setType: SetType): SetType = { + def flatten0(s: SetType) = s match { + case UnionType(sts) => UnionType(sts.flatMap{ + case UnionType(sts2) => sts2 + case s => List(s) + }) + case IntersectionType(sts) => IntersectionType(sts.flatMap{ + case IntersectionType(sts2) => sts2 + case s => List(s) + }) + case s => s + } + map(setType, flatten0) } def includes(f: Formula): Seq[Include] = flatten(f) match { diff --git a/src/setconstraints/PrettyPrinter.scala b/src/setconstraints/PrettyPrinter.scala index aee0378b98a42a64191a98719fff3da17b2b116f..d1b5bc4f56fa760acf2e1360f703aae353b25420 100644 --- a/src/setconstraints/PrettyPrinter.scala +++ b/src/setconstraints/PrettyPrinter.scala @@ -11,8 +11,8 @@ object PrettyPrinter { def apply(fp: FixPoint): String = ppFixPoint(fp) private def ppFormula(f: Formula): String = f match { - case And(fs) => fs.map(ppFormula).mkString("(", " \u2227 ", ")") - case Include(s1, s2) => ppSetType(s1) + " \u2282 " + ppSetType(s2) + case And(fs) => fs.map(ppFormula).mkString("( ", "\n \u2227 ", ")") + case Include(s1, s2) => ppSetType(s1) + " \u2286 " + ppSetType(s2) case Equals(s1, s2) => ppSetType(s1) + " = " + ppSetType(s2) } @@ -21,9 +21,12 @@ object PrettyPrinter { case ConstructorType(name, sts) => name + sts.map(ppSetType).mkString("(", ", ", ")") case UnionType(sts) => sts.map(ppSetType).mkString("(", " \u222A ", ")") case IntersectionType(sts) => sts.map(ppSetType).mkString("(", " \u2229 ", ")") + case ComplementType(s) => "\u00AC" + ppSetType(s) case FunctionType(s1, s2) => "(" + ppSetType(s1) + " --> " + ppSetType(s2) + ")" case TupleType(sts) => sts.map(ppSetType).mkString("(", ", ", ")") case VariableType(name) => name + case EmptyType => "0" + case UniversalType => "1" } private def ppFixPoint(fp: FixPoint): String = fp match { diff --git a/src/setconstraints/Solver.scala b/src/setconstraints/Solver.scala new file mode 100644 index 0000000000000000000000000000000000000000..b7e5947ea39e46841e618d89de8ea46e40b00334 --- /dev/null +++ b/src/setconstraints/Solver.scala @@ -0,0 +1,58 @@ +package setconstraints + +import Trees._ +import Manip._ + +object Solver { + + def apply(system: List[Relation]): Option[List[FixPoint]] = { + error("TODO") + } + + def solve(system: List[Relation]): Option[List[Equals]] = { + error("TODO") + } + + def oneLevel(system: List[Include]): List[Include] = { + + val emptyRightSystem = system.map{ + case Include(s1, s2) if s2 != EmptyType => Include(IntersectionType(List(s1, ComplementType(s2))), EmptyType) + case incl => incl + } + + + error("TODO") + } + + def isConstructor(s: SetType): Boolean = s match { + case ConstructorType(_, _) => true + case _ => false + } + + def isLiteral(s: SetType): Boolean = s match { + case VariableType(_) => true + case ComplementType(VariableType(_)) => true + case _ => false + } + def isConjunctionLit(s: SetType): Boolean = flatten(s) match { + case IntersectionType(sts) if sts.foldLeft(true)((b, st) => b && isLiteral(st) && sts.forall(l => l != ComplementType(st))) => false + case _ => false + } + def isConjunctionLitWithUniversal(s: SetType): Boolean = flatten(s) match { + case IntersectionType(sts) if sts.last == UniversalType && isConjunctionLit(IntersectionType(sts.init)) => true + case _ => false + } + def isOneLevel(s: SetType): Boolean = flatten(s) match { + case EmptyType => true + case IntersectionType(sts) if isConstructor(sts.last) && isConjunctionLit(IntersectionType(sts.init)) => { + val ConstructorType(_, args) = sts.last + args.forall(isConjunctionLitWithUniversal) + } + case s => isConjunctionLitWithUniversal(s) + } + def isOneLevel(r: Relation): Boolean = r match { + case Include(s1, EmptyType) if isOneLevel(s1) => true + case _ => false + } + def isOneLevel(system: List[Relation]): Boolean = system.forall(isOneLevel) +} diff --git a/src/setconstraints/SolverSuite.scala b/src/setconstraints/SolverSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..b5325561fd54e75444bd85b3357677bd8dc5eec9 --- /dev/null +++ b/src/setconstraints/SolverSuite.scala @@ -0,0 +1,9 @@ +package setconstraints + +/* +import org.scalatest.FunSuite + +class SolverSuite extends FunSuite { + +} +*/ diff --git a/src/setconstraints/Trees.scala b/src/setconstraints/Trees.scala index 075c2f2077e632c935732e4ffd812c76a17c81a5..21cd03ac985e8dcefed38a968ce106eac9a90f61 100644 --- a/src/setconstraints/Trees.scala +++ b/src/setconstraints/Trees.scala @@ -15,10 +15,13 @@ object Trees { case class UnionType(sets: Seq[SetType]) extends SetType case class IntersectionType(sets: Seq[SetType]) extends SetType + case class ComplementType(st: SetType) extends SetType case class FunctionType(s1: SetType, s2: SetType) extends SetType case class TupleType(sets: Seq[SetType]) extends SetType case class ConstructorType(name: String, sets: Seq[SetType]) extends SetType case class VariableType(name: String) extends SetType + case object EmptyType extends SetType + case object UniversalType extends SetType private var varCounter = -1 def freshVar(name: String) = {