diff --git a/project/build/funcheck.scala b/project/build/funcheck.scala index 0ef9fb71aef6c1eef440e5334ade8f1db3f763e7..666816601e849f1305aeaf695613c802680f7c8f 100644 --- a/project/build/funcheck.scala +++ b/project/build/funcheck.scala @@ -6,12 +6,13 @@ class FunCheckProject(info: ProjectInfo) extends DefaultProject(info) with FileT override def shouldCheckOutputDirectories = false lazy val purescala = project(".", "PureScala Definitions", new PureScalaProject(_)) - lazy val plugin = project(".", "FunCheck Plugin", new PluginProject(_), purescala) - lazy val multisets = project(".", "Multiset Solver", new MultisetsProject(_), plugin, purescala) + lazy val plugin = project(".", "FunCheck Plugin", new PluginProject(_), purescala, multisetsLib) + lazy val multisetsLib = project(".", "Multiset Placeholder Library", new MultisetsLibProject(_)) + lazy val multisets = project(".", "Multiset Solver", new MultisetsProject(_), plugin, purescala, multisetsLib) lazy val orderedsets = project(".", "Ordered Sets Solver", new OrderedSetsProject(_), plugin, purescala) lazy val setconstraints = project(".", "Type inference with set constraints", new SetConstraintsProject(_), plugin, purescala) - lazy val extensionJars : List[Path] = multisets.jarPath :: orderedsets.jarPath :: setconstraints.jarPath :: Nil + lazy val extensionJars : List[Path] = multisetsLib.jarPath :: multisets.jarPath :: orderedsets.jarPath :: setconstraints.jarPath :: Nil val scriptPath: Path = "." / "scalac-funcheck" @@ -40,6 +41,7 @@ class FunCheckProject(info: ProjectInfo) extends DefaultProject(info) with FileT fw.write(" FUNCHECKCLASSPATH=${FUNCHECKCLASSPATH}:${f}" + nl) fw.write(" fi" + nl) fw.write("done" + nl + nl) + fw.write("SCALACCLASSPATH=\"" + (multisetsLib.jarPath.absolutePath) + "\"" + nl) fw.write("LD_LIBRARY_PATH=" + ("." / "lib-bin").absolutePath + " \\" + nl) fw.write("java \\" + nl) @@ -48,7 +50,7 @@ class FunCheckProject(info: ProjectInfo) extends DefaultProject(info) with FileT fw.write(" -Dscala.home=" + libStr.substring(0, libStr.length-21) + " \\" + nl) fw.write(" -classpath ${FUNCHECKCLASSPATH} \\" + nl) - fw.write(" scala.tools.nsc.Main -Xplugin:" + plugin.jarPath.absolutePath + " $@" + nl) + fw.write(" scala.tools.nsc.Main -Xplugin:" + plugin.jarPath.absolutePath + " -classpath ${SCALACCLASSPATH} $@" + nl) fw.close f.setExecutable(true) None @@ -77,9 +79,14 @@ class FunCheckProject(info: ProjectInfo) extends DefaultProject(info) with FileT class PluginProject(info: ProjectInfo) extends PersonalizedProject(info) { override def outputPath = "bin" / "funcheck" override def mainScalaSourcePath = "src" / "funcheck" - override def unmanagedClasspath = super.unmanagedClasspath +++ purescala.jarPath + override def unmanagedClasspath = super.unmanagedClasspath +++ purescala.jarPath +++ multisetsLib.jarPath override def mainResourcesPath = "resources" / "funcheck" } + class MultisetsLibProject(info: ProjectInfo) extends PersonalizedProject(info) { + override def outputPath = "bin" / "multisets-lib" + override def mainScalaSourcePath = "src" / "multisets-lib" + override def unmanagedClasspath = super.unmanagedClasspath + } class MultisetsProject(info: ProjectInfo) extends PersonalizedProject(info) { override def outputPath = "bin" / "multisets" override def mainScalaSourcePath = "src" / "multisets" diff --git a/src/funcheck/CodeExtraction.scala b/src/funcheck/CodeExtraction.scala index bd3ffe62d6a1c24e20f67b4441afca7a4944f5e7..b60cab38f4301e87e58f2df860387e67b9ddbd1d 100644 --- a/src/funcheck/CodeExtraction.scala +++ b/src/funcheck/CodeExtraction.scala @@ -17,6 +17,7 @@ trait CodeExtraction extends Extractors { import ExpressionExtractors._ private lazy val setTraitSym = definitions.getClass("scala.collection.immutable.Set") + private lazy val multisetTraitSym = definitions.getClass("scala.collection.immutable.Multiset") private val varSubsts: scala.collection.mutable.Map[Symbol,Function0[Expr]] = scala.collection.mutable.Map.empty[Symbol,Function0[Expr]] @@ -368,18 +369,24 @@ trait CodeExtraction extends Extractors { case ExLessEqThan(l, r) => LessEquals(rec(l), rec(r)).setType(BooleanType) case ExFiniteSet(tt, args) => { val underlying = scalaType2PureScala(unit, silent)(tt.tpe) - FiniteSet(args.map(rec(_))).setType(SetType(underlying)) + FiniteSet(args.map(rec(_))).setType(SetType(underlying)) + } + case ExFiniteMultiset(tt, args) => { + val underlying = scalaType2PureScala(unit, silent)(tt.tpe) + FiniteMultiset(args.map(rec(_))).setType(MultisetType(underlying)) } case ExEmptySet(tt) => { val underlying = scalaType2PureScala(unit, silent)(tt.tpe) EmptySet(underlying).setType(SetType(underlying)) } + case ExEmptyMultiset(tt) => { + val underlying = scalaType2PureScala(unit, silent)(tt.tpe) + EmptyMultiset(underlying).setType(MultisetType(underlying)) + } case ExSetMin(t) => { val set = rec(t) if(!set.getType.isInstanceOf[SetType]) { - if(!silent) { - unit.error(t.pos, "Min should be computed on a set.") - } + if(!silent) unit.error(t.pos, "Min should be computed on a set.") throw ImpureCodeEncounteredException(tree) } SetMin(set).setType(set.getType.asInstanceOf[SetType].base) @@ -387,9 +394,7 @@ trait CodeExtraction extends Extractors { case ExSetMax(t) => { val set = rec(t) if(!set.getType.isInstanceOf[SetType]) { - if(!silent) { - unit.error(t.pos, "Max should be computed on a set.") - } + if(!silent) unit.error(t.pos, "Max should be computed on a set.") throw ImpureCodeEncounteredException(tree) } SetMax(set).setType(set.getType.asInstanceOf[SetType].base) @@ -397,21 +402,65 @@ trait CodeExtraction extends Extractors { case ExUnion(t1,t2) => { val rl = rec(t1) val rr = rec(t2) - SetUnion(rl, rr).setType(rl.getType) // this is not entirely correct: should be a setype of LUB of underlying types of left and right. + rl.getType match { + case s @ SetType(_) => SetUnion(rl, rr).setType(s) + case m @ MultisetType(_) => MultisetUnion(rl, rr).setType(m) + case _ => { + if(!silent) unit.error(tree.pos, "Union of non set/multiset expressions.") + throw ImpureCodeEncounteredException(tree) + } + } } case ExIntersection(t1,t2) => { val rl = rec(t1) val rr = rec(t2) - SetIntersection(rl, rr).setType(rl.getType) // same as union + rl.getType match { + case s @ SetType(_) => SetIntersection(rl, rr).setType(s) + case m @ MultisetType(_) => MultisetIntersection(rl, rr).setType(m) + case _ => { + if(!silent) unit.error(tree.pos, "Intersection of non set/multiset expressions.") + throw ImpureCodeEncounteredException(tree) + } + } } case ExSetMinus(t1,t2) => { val rl = rec(t1) val rr = rec(t2) - SetDifference(rl, rr).setType(rl.getType) // same as union + rl.getType match { + case s @ SetType(_) => SetDifference(rl, rr).setType(s) + case m @ MultisetType(_) => MultisetDifference(rl, rr).setType(m) + case _ => { + if(!silent) unit.error(tree.pos, "Difference of non set/multiset expressions.") + throw ImpureCodeEncounteredException(tree) + } + } } case ExSetCard(t) => { val rt = rec(t) - SetCardinality(rt) + rt.getType match { + case s @ SetType(_) => SetCardinality(rt) + case m @ MultisetType(_) => MultisetCardinality(rt) + case _ => { + if(!silent) unit.error(tree.pos, "Cardinality of non set/multiset expressions.") + throw ImpureCodeEncounteredException(tree) + } + } + } + case ExMultisetToSet(t) => { + val rt = rec(t) + rt.getType match { + case m @ MultisetType(u) => MultisetToSet(rt).setType(SetType(u)) + case _ => { + if(!silent) unit.error(tree.pos, "toSet can only be applied to multisets.") + throw ImpureCodeEncounteredException(tree) + } + } + } + + case ExPlusPlusPlus(t1,t2) => { + val rl = rec(t1) + val rr = rec(t2) + MultisetPlus(rl, rr).setType(rl.getType) } case ExIfThenElse(t1,t2,t3) => { val r1 = rec(t1) @@ -479,6 +528,7 @@ trait CodeExtraction extends Extractors { case tpe if tpe == IntClass.tpe => Int32Type case tpe if tpe == BooleanClass.tpe => BooleanType case TypeRef(_, sym, btt :: Nil) if sym == setTraitSym => SetType(rec(btt)) + case TypeRef(_, sym, btt :: Nil) if sym == multisetTraitSym => MultisetType(rec(btt)) case TypeRef(_, sym, Nil) if classesToClasses.keySet.contains(sym) => classDefToClassType(classesToClasses(sym)) case _ => { diff --git a/src/funcheck/Extractors.scala b/src/funcheck/Extractors.scala index 1081ec2a54557e16294a410b8ab8abdf99e28304..62f4b0dae2882a4ab03d263c45adbf099bcf31b4 100644 --- a/src/funcheck/Extractors.scala +++ b/src/funcheck/Extractors.scala @@ -11,6 +11,7 @@ trait Extractors { import global.definitions._ private lazy val setTraitSym = definitions.getClass("scala.collection.immutable.Set") + private lazy val multisetTraitSym = definitions.getClass("scala.collection.immutable.Multiset") object StructuralExtractors { object ScalaPredef { @@ -338,6 +339,22 @@ trait Extractors { } } + object ExEmptyMultiset { + def unapply(tree: TypeApply): Option[Tree] = tree match { + case TypeApply( + Select( + Select( + Select( + Select(Ident(s), collectionName), + immutableName), + setName), + emptyName), theTypeTree :: Nil) if ( + collectionName.toString == "collection" && immutableName.toString == "immutable" && setName.toString == "Multiset" && emptyName.toString == "empty" + ) => Some(theTypeTree) + case _ => None + } + } + object ExFiniteSet { def unapply(tree: Apply): Option[(Tree,List[Tree])] = tree match { case Apply( @@ -355,12 +372,36 @@ trait Extractors { } } + object ExFiniteMultiset { + def unapply(tree: Apply): Option[(Tree,List[Tree])] = tree match { + case Apply( + TypeApply( + Select( + Select( + Select( + Select(Ident(s), collectionName), + immutableName), + setName), + emptyName), theTypeTree :: Nil), args) if ( + collectionName.toString == "collection" && immutableName.toString == "immutable" && setName.toString == "Multiset" && emptyName.toString == "apply" + )=> Some(theTypeTree, args) + case _ => None + } + } + object ExUnion { def unapply(tree: Apply): Option[(Tree,Tree)] = tree match { case Apply(Select(lhs, n), List(rhs)) if (n == nme.PLUSPLUS) => Some((lhs,rhs)) case _ => None } } + + object ExPlusPlusPlus { + def unapply(tree: Apply): Option[(Tree,Tree)] = tree match { + case Apply(Select(lhs, n), List(rhs)) if (n.toString == "$plus$plus$plus") => Some((lhs,rhs)) + case _ => None + } + } object ExIntersection { def unapply(tree: Apply): Option[(Tree,Tree)] = tree match { @@ -382,5 +423,12 @@ trait Extractors { case _ => None } } + + object ExMultisetToSet { + def unapply(tree: Select): Option[Tree] = tree match { + case Select(t, n) if (n.toString == "toSet") => Some(t) + case _ => None + } + } } } diff --git a/src/orderedsets/Main.scala b/src/orderedsets/Main.scala index 4a1ae233482f2f377e51b7c3e6d699371437a081..187d077f0b1e6809eddef639c59737b3df71f42e 100644 --- a/src/orderedsets/Main.scala +++ b/src/orderedsets/Main.scala @@ -163,7 +163,7 @@ object ExprToASTConverter { // Set Formulas case ElementOfSet(elem, set) => toIntTerm(elem) selem toSetTerm(set) case SetEquals(set1, set2) => toSetTerm(set1) seq toSetTerm(set2) - case IsEmptySet(set) => toSetTerm(set).card === 0 + // case IsEmptySet(set) => toSetTerm(set).card === 0 case SubsetOf(set1, set2) => toSetTerm(set1) subseteq toSetTerm(set2) // Integer Formulas @@ -185,7 +185,7 @@ object ExprToASTConverter { // Set formulas case ElementOfSet(_, set) => Set(set.getType, SetType(Int32Type)) case SetEquals(set1, set2) => Set(set1.getType, set2.getType) - case IsEmptySet(set) => Set(set.getType) + // case IsEmptySet(set) => Set(set.getType) case SubsetOf(set1, set2) => Set(set1.getType, set2.getType) // Integer formulas case LessThan(lhs, rhs) => getSetTypes(lhs) ++ getSetTypes(rhs) diff --git a/src/purescala/PrettyPrinter.scala b/src/purescala/PrettyPrinter.scala index b07d7766f735b02e0d633616e3e501ab2ee80b00..11daf5ebabc033fad337f7cd894daf7c53525fab 100644 --- a/src/purescala/PrettyPrinter.scala +++ b/src/purescala/PrettyPrinter.scala @@ -77,7 +77,6 @@ object PrettyPrinter { case Iff(l,r) => ppBinary(sb, l, r, " <=> ", lvl) case Implies(l,r) => ppBinary(sb, l, r, " ==> ", lvl) case UMinus(expr) => ppUnary(sb, expr, "-(", ")", lvl) - case SetEquals(l,r) => ppBinary(sb, l, r, " =S= ", lvl) case Equals(l,r) => ppBinary(sb, l, r, " == ", lvl) case IntLiteral(v) => sb.append(v) case BooleanLiteral(v) => sb.append(v) @@ -104,13 +103,21 @@ object PrettyPrinter { case LessEquals(l,r) => ppBinary(sb, l, r, " \u2264 ", lvl) // \leq case GreaterEquals(l,r) => ppBinary(sb, l, r, " \u2265 ", lvl) // \geq case FiniteSet(rs) => ppNary(sb, rs, "{", ", ", "}", lvl) + case FiniteMultiset(rs) => ppNary(sb, rs, "{|", ", ", "|}", lvl) case EmptySet(_) => sb.append("\u2205") // Ø + case EmptyMultiset(_) => sb.append("\u2205") // Ø case SetMin(s) => pp(s, sb, lvl).append(".min") case SetMax(s) => pp(s, sb, lvl).append(".max") case SetUnion(l,r) => ppBinary(sb, l, r, " \u222A ", lvl) // \cup + case MultisetUnion(l,r) => ppBinary(sb, l, r, " \u222A ", lvl) // \cup case SetDifference(l,r) => ppBinary(sb, l, r, " \\ ", lvl) + case MultisetDifference(l,r) => ppBinary(sb, l, r, " \\ ", lvl) case SetIntersection(l,r) => ppBinary(sb, l, r, " \u2229 ", lvl) // \cap + case MultisetIntersection(l,r) => ppBinary(sb, l, r, " \u2229 ", lvl) // \cap case SetCardinality(t) => ppUnary(sb, t, "|", "|", lvl) + case MultisetCardinality(t) => ppUnary(sb, t, "|", "|", lvl) + case MultisetPlus(l,r) => ppBinary(sb, l, r, " \u228E ", lvl) // U+ + case MultisetToSet(e) => pp(e, sb, lvl).append(".toSet") case IfExpr(c, t, e) => { var nsb = sb @@ -186,6 +193,7 @@ object PrettyPrinter { case Int32Type => sb.append("Int") case BooleanType => sb.append("Boolean") case SetType(bt) => pp(bt, sb.append("Set["), lvl).append("]") + case MultisetType(bt) => pp(bt, sb.append("Multiset["), lvl).append("]") case c: ClassType => sb.append(c.classDef.id) case _ => sb.append("Type?") } diff --git a/src/purescala/Trees.scala b/src/purescala/Trees.scala index 96418d14c5822672daec8205cf2a1765e427ace3..c920f9ee6ad6d1172740358489ba411d841eba6b 100644 --- a/src/purescala/Trees.scala +++ b/src/purescala/Trees.scala @@ -122,7 +122,28 @@ object Trees { } /* For all types that don't have their own XXXEquals */ - case class Equals(left: Expr, right: Expr) extends Expr with FixedType { + object Equals { + def apply(l : Expr, r : Expr) : Equals = new Equals(l,r) + def unapply(e : Equals) : Option[(Expr,Expr)] = if (e == null) None else Some((e.left, e.right)) + } + + object SetEquals { + def apply(l : Expr, r : Expr) : Equals = new Equals(l,r) + def unapply(e : Equals) : Option[(Expr,Expr)] = if(e == null) None else (e.left.getType, e.right.getType) match { + case (SetType(_), SetType(_)) => Some((e.left, e.right)) + case _ => None + } + } + + object MultisetEquals { + def apply(l : Expr, r : Expr) : Equals = new Equals(l,r) + def unapply(e : Equals) : Option[(Expr,Expr)] = if(e == null) None else (e.left.getType, e.right.getType) match { + case (MultisetType(_), MultisetType(_)) => Some((e.left, e.right)) + case _ => None + } + } + + class Equals(val left: Expr, val right: Expr) extends Expr with FixedType { val fixedType = BooleanType } @@ -191,10 +212,6 @@ object Trees { case class EmptySet(baseType: TypeTree) extends Expr with Terminal case class FiniteSet(elements: Seq[Expr]) extends Expr case class ElementOfSet(element: Expr, set: Expr) extends Expr - case class IsEmptySet(set: Expr) extends Expr - case class SetEquals(set1: Expr, set2: Expr) extends Expr with FixedType { - val fixedType = BooleanType - } case class SetCardinality(set: Expr) extends Expr with FixedType { val fixedType = Int32Type } @@ -209,9 +226,9 @@ object Trees { case class EmptyMultiset(baseType: TypeTree) extends Expr with Terminal case class FiniteMultiset(elements: Seq[Expr]) extends Expr case class Multiplicity(element: Expr, multiset: Expr) extends Expr - case class IsEmptyMultiset(multiset: Expr) extends Expr - case class MultisetEquals(multiset1: Expr, multiset2: Expr) extends Expr - case class MultisetCardinality(multiset: Expr) extends Expr + case class MultisetCardinality(multiset: Expr) extends Expr with FixedType { + val fixedType = Int32Type + } case class SubmultisetOf(multiset1: Expr, multiset2: Expr) extends Expr case class MultisetIntersection(multiset1: Expr, multiset2: Expr) extends Expr case class MultisetUnion(multiset1: Expr, multiset2: Expr) extends Expr @@ -239,8 +256,6 @@ object Trees { object UnaryOperator { def unapply(expr: Expr) : Option[(Expr,(Expr)=>Expr)] = expr match { case Not(t) => Some((t,Not(_))) - case IsEmptySet(t) => Some((t,IsEmptySet)) - case IsEmptyMultiset(t) => Some((t,IsEmptyMultiset)) case SetCardinality(t) => Some((t,SetCardinality)) case MultisetCardinality(t) => Some((t,MultisetCardinality)) case MultisetToSet(t) => Some((t,MultisetToSet)) @@ -255,7 +270,7 @@ object Trees { object BinaryOperator { def unapply(expr: Expr) : Option[(Expr,Expr,(Expr,Expr)=>Expr)] = expr match { - case Equals(t1,t2) => Some((t1,t2,Equals)) + case Equals(t1,t2) => Some((t1,t2,Equals(_,_))) case Iff(t1,t2) => Some((t1,t2,Iff)) case Implies(t1,t2) => Some((t1,t2, ((e1,e2) => Implies(e1,e2)))) case Plus(t1,t2) => Some((t1,t2,Plus)) @@ -267,13 +282,11 @@ object Trees { case LessEquals(t1,t2) => Some((t1,t2,LessEquals)) case GreaterEquals(t1,t2) => Some((t1,t2,GreaterEquals)) case ElementOfSet(t1,t2) => Some((t1,t2,ElementOfSet)) - case SetEquals(t1,t2) => Some((t1,t2,SetEquals)) case SubsetOf(t1,t2) => Some((t1,t2,SubsetOf)) case SetIntersection(t1,t2) => Some((t1,t2,SetIntersection)) case SetUnion(t1,t2) => Some((t1,t2,SetUnion)) case SetDifference(t1,t2) => Some((t1,t2,SetDifference)) case Multiplicity(t1,t2) => Some((t1,t2,Multiplicity)) - case MultisetEquals(t1,t2) => Some((t1,t2,MultisetEquals)) case SubmultisetOf(t1,t2) => Some((t1,t2,SubmultisetOf)) case MultisetIntersection(t1,t2) => Some((t1,t2,MultisetIntersection)) case MultisetUnion(t1,t2) => Some((t1,t2,MultisetUnion)) @@ -319,7 +332,7 @@ object Trees { // Warning ! This may loop forever if the substitutions are not // well-formed! def replace(substs: Map[Expr,Expr], expr: Expr) : Expr = { - searchAndApply(substs.isDefinedAt(_), substs(_), expr) + searchAndReplace(substs.get(_))(expr) } // the replacement map should be understood as follows: @@ -327,78 +340,78 @@ object Trees { // - repFun is applied is checkFun succeeded // - if the result of repFun is different from its argument and recursive // is set to true, search/replace is reapplied on the result. - def searchAndApply(checkFun: Expr=>Boolean, repFun: Expr=>Expr, expr: Expr, recursive: Boolean=true) : Expr = { - def rec(ex: Expr, skip: Expr = null) : Expr = ex match { - case _ if (ex != skip && checkFun(ex)) => { - val newExpr = repFun(ex) - if(newExpr.getType == NoType) { - Settings.reporter.warning("REPLACING IN EXPRESSION WITH AN UNTYPED TREE ! " + ex + " --to--> " + newExpr) - } - if(ex == newExpr) - if(recursive) rec(ex, ex) else ex - else - if(recursive) rec(newExpr) else newExpr - } - case l @ Let(i,e,b) => { - val re = rec(e) - val rb = rec(b) - if(re != e || rb != b) - Let(i, re, rb).setType(l.getType) - else - l - } - case n @ NAryOperator(args, recons) => { - var change = false - val rargs = args.map(a => { - val ra = rec(a) - if(ra != a) { - change = true - ra - } else { - a - } - }) - if(change) - recons(rargs).setType(n.getType) - else - n - } - case b @ BinaryOperator(t1,t2,recons) => { - val r1 = rec(t1) - val r2 = rec(t2) - if(r1 != t1 || r2 != t2) - recons(r1,r2).setType(b.getType) - else - b - } - case u @ UnaryOperator(t,recons) => { - val r = rec(t) - if(r != t) - recons(r).setType(u.getType) - else - u - } - case i @ IfExpr(t1,t2,t3) => { - val r1 = rec(t1) - val r2 = rec(t2) - val r3 = rec(t3) - if(r1 != t1 || r2 != t2 || r3 != t3) - IfExpr(rec(t1),rec(t2),rec(t3)).setType(i.getType) - else - i - } - case m @ MatchExpr(scrut,cses) => MatchExpr(rec(scrut), cses.map(inCase(_))).setType(m.getType) - case t if t.isInstanceOf[Terminal] => t - case unhandled => scala.Predef.error("Non-terminal case should be handled in searchAndApply: " + unhandled) - } - - def inCase(cse: MatchCase) : MatchCase = cse match { - case SimpleCase(pat, rhs) => SimpleCase(pat, rec(rhs)) - case GuardedCase(pat, guard, rhs) => GuardedCase(pat, rec(guard), rec(rhs)) - } - - rec(expr) - } + // def searchAndApply(checkFun: Expr=>Boolean, repFun: Expr=>Expr, expr: Expr, recursive: Boolean=true) : Expr = { + // def rec(ex: Expr, skip: Expr = null) : Expr = ex match { + // case _ if (ex != skip && checkFun(ex)) => { + // val newExpr = repFun(ex) + // if(newExpr.getType == NoType) { + // Settings.reporter.warning("REPLACING IN EXPRESSION WITH AN UNTYPED TREE ! " + ex + " --to--> " + newExpr) + // } + // if(ex == newExpr) + // if(recursive) rec(ex, ex) else ex + // else + // if(recursive) rec(newExpr) else newExpr + // } + // case l @ Let(i,e,b) => { + // val re = rec(e) + // val rb = rec(b) + // if(re != e || rb != b) + // Let(i, re, rb).setType(l.getType) + // else + // l + // } + // case n @ NAryOperator(args, recons) => { + // var change = false + // val rargs = args.map(a => { + // val ra = rec(a) + // if(ra != a) { + // change = true + // ra + // } else { + // a + // } + // }) + // if(change) + // recons(rargs).setType(n.getType) + // else + // n + // } + // case b @ BinaryOperator(t1,t2,recons) => { + // val r1 = rec(t1) + // val r2 = rec(t2) + // if(r1 != t1 || r2 != t2) + // recons(r1,r2).setType(b.getType) + // else + // b + // } + // case u @ UnaryOperator(t,recons) => { + // val r = rec(t) + // if(r != t) + // recons(r).setType(u.getType) + // else + // u + // } + // case i @ IfExpr(t1,t2,t3) => { + // val r1 = rec(t1) + // val r2 = rec(t2) + // val r3 = rec(t3) + // if(r1 != t1 || r2 != t2 || r3 != t3) + // IfExpr(rec(t1),rec(t2),rec(t3)).setType(i.getType) + // else + // i + // } + // case m @ MatchExpr(scrut,cses) => MatchExpr(rec(scrut), cses.map(inCase(_))).setType(m.getType) + // case t if t.isInstanceOf[Terminal] => t + // case unhandled => scala.Predef.error("Non-terminal case should be handled in searchAndApply: " + unhandled) + // } + + // def inCase(cse: MatchCase) : MatchCase = cse match { + // case SimpleCase(pat, rhs) => SimpleCase(pat, rec(rhs)) + // case GuardedCase(pat, guard, rhs) => GuardedCase(pat, rec(guard), rec(rhs)) + // } + + // rec(expr) + // } def searchAndReplace(subst: Expr=>Option[Expr], recursive: Boolean=true)(expr: Expr) : Expr = { def rec(ex: Expr, skip: Expr = null) : Expr = (if (ex == skip) None else subst(ex)) match { @@ -462,7 +475,7 @@ object Trees { } case m @ MatchExpr(scrut,cses) => MatchExpr(rec(scrut), cses.map(inCase(_))).setType(m.getType) case t if t.isInstanceOf[Terminal] => t - case unhandled => scala.Predef.error("Non-terminal case should be handled in searchAndApply: " + unhandled) + case unhandled => scala.Predef.error("Non-terminal case should be handled in searchAndReplace: " + unhandled) } } @@ -481,26 +494,27 @@ object Trees { * Note that the code is simple but far from optimal (many traversals...) */ def simplifyLets(expr: Expr) : Expr = { - val isLet = ((t: Expr) => t.isInstanceOf[Let]) - def simplerLet(t: Expr) : Expr = t match { - case letExpr @ Let(i, Variable(v), b) => replace(Map((Variable(i) -> Variable(v))), b) - case letExpr @ Let(i, l: Literal[_], b) => replace(Map((Variable(i) -> l)), b) + def simplerLet(t: Expr) : Option[Expr] = t match { + case letExpr @ Let(i, Variable(v), b) => Some(replace(Map((Variable(i) -> Variable(v))), b)) + case letExpr @ Let(i, l: Literal[_], b) => Some(replace(Map((Variable(i) -> l)), b)) case letExpr @ Let(i,e,b) => { var occurences = 0 - def isOcc(tr: Expr) = (occurences < 2 && tr == Variable(i)) - def incCount(tr: Expr) = { occurences = occurences + 1; tr } - searchAndApply(isOcc, incCount, b, false) + def incCount(tr: Expr) = tr match { + case Variable(x) if x == i => { occurences = occurences + 1; None } + case _ => None + } + searchAndReplace(incCount, false)(b) if(occurences == 0) { - b + Some(b) } else if(occurences == 1) { - replace(Map((Variable(i) -> e)), b) + Some(replace(Map((Variable(i) -> e)), b)) } else { - t + None } } - case o => o + case _ => None } - searchAndApply(isLet,simplerLet,expr) + searchAndReplace(simplerLet)(expr) } /* Rewrites the expression so that all lets are at the top levels. */ @@ -512,18 +526,17 @@ object Trees { def pulloutAndKeepLets(expr: Expr) : (Seq[(Identifier,Expr)], Expr) = { var storedLets: List[(Identifier,Expr)] = Nil - val isLet = ((t: Expr) => t.isInstanceOf[Let]) - def storeLet(t: Expr) : Expr = t match { - case l @ Let(i, e, b) => (storedLets = ((i,e)) :: storedLets); l - case _ => t + def storeLet(t: Expr) : Option[Expr] = t match { + case l @ Let(i, e, b) => (storedLets = ((i,e)) :: storedLets); None + case _ => None } - def killLet(t: Expr) : Expr = t match { - case l @ Let(i, e, b) => b - case _ => t + def killLet(t: Expr) : Option[Expr] = t match { + case l @ Let(i, e, b) => Some(b) + case _ => None } - searchAndApply(isLet, storeLet, expr) - val noLets = searchAndApply(isLet, killLet, expr) + searchAndReplace(storeLet)(expr) + val noLets = searchAndReplace(killLet)(expr) (storedLets, noLets) } diff --git a/src/purescala/TypeTrees.scala b/src/purescala/TypeTrees.scala index efb55f21f9ec86b5859b89c478f010295a52045d..46ecb5d2be0887b499746b81d70a1b0602568bd9 100644 --- a/src/purescala/TypeTrees.scala +++ b/src/purescala/TypeTrees.scala @@ -92,6 +92,7 @@ object TypeTrees { case InfiniteSize => InfiniteSize case FiniteSize(n) => FiniteSize(scala.math.pow(2, n).toInt) } + case MultisetType(_) => InfiniteSize case MapType(from,to) => (domainSize(from),domainSize(to)) match { case (InfiniteSize,_) => InfiniteSize case (_,InfiniteSize) => InfiniteSize @@ -113,7 +114,7 @@ object TypeTrees { case class ListType(base: TypeTree) extends TypeTree case class TupleType(bases: Seq[TypeTree]) extends TypeTree { lazy val dimension: Int = bases.length } case class SetType(base: TypeTree) extends TypeTree - // case class MultisetType(base: TypeTree) extends TypeTree + case class MultisetType(base: TypeTree) extends TypeTree case class MapType(from: TypeTree, to: TypeTree) extends TypeTree case class OptionType(base: TypeTree) extends TypeTree diff --git a/src/purescala/Z3Solver.scala b/src/purescala/Z3Solver.scala index 5debb2bd8aff0b1d903f12c4dd894051fbad0b99..efefcea57df0cb0c31b4e9ed51878243a1c7801c 100644 --- a/src/purescala/Z3Solver.scala +++ b/src/purescala/Z3Solver.scala @@ -34,9 +34,6 @@ class Z3Solver(reporter: Reporter) extends Solver(reporter) { } prepareSorts prepareFunctions - - //println(prog.callGraph.map(p => (p._1.id.name, p._2.id.name).toString)) - //println(prog.transitiveCallGraph.map(p => (p._1.id.name, p._2.id.name).toString)) } private object nextIntForSymbol { @@ -53,6 +50,7 @@ class Z3Solver(reporter: Reporter) extends Solver(reporter) { private var boolSort : Z3Sort = null private var setSorts : Map[TypeTree,Z3Sort] = Map.empty private var adtSorts : Map[ClassTypeDef, Z3Sort] = Map.empty + private var fallbackSorts : Map[TypeTree,Z3Sort] = Map.empty private var adtTesters : Map[CaseClassDef, Z3FuncDecl] = Map.empty private var adtConstructors : Map[CaseClassDef, Z3FuncDecl] = Map.empty private var adtFieldSelectors : Map[Identifier,Z3FuncDecl] = Map.empty @@ -133,8 +131,8 @@ class Z3Solver(reporter: Reporter) extends Solver(reporter) { def prepareFunctions : Unit = { for(funDef <- program.definedFunctions) { - val sortSeq = funDef.args.map(vd => typeToSort(vd.tpe).get) - functionDefToDef = functionDefToDef + (funDef -> z3.mkFreshFuncDecl(funDef.id.name, sortSeq, typeToSort(funDef.returnType).get)) + val sortSeq = funDef.args.map(vd => typeToSort(vd.tpe)) + functionDefToDef = functionDefToDef + (funDef -> z3.mkFreshFuncDecl(funDef.id.name, sortSeq, typeToSort(funDef.returnType))) } // universally quantifies all functions ! @@ -145,7 +143,7 @@ class Z3Solver(reporter: Reporter) extends Solver(reporter) { funDef.args.size == 1 && funDef.args(0).toVariable == scrutinee ) => infos.foreach(i => { val (ccd, pid, subids, rhs) = i - val argSorts: Seq[Z3Sort] = subids.map(id => typeToSort(id.getType).get) + val argSorts: Seq[Z3Sort] = subids.map(id => typeToSort(id.getType)) val boundVars = argSorts.zipWithIndex.map(p => z3.mkBound(p._2, p._1)) val matcher: Z3AST = adtConstructors(ccd)(boundVars: _*) val pattern: Z3Pattern = z3.mkPattern(functionDefToDef(funDef)(matcher)) @@ -173,7 +171,7 @@ class Z3Solver(reporter: Reporter) extends Solver(reporter) { case _ => { // Newly introduced variables should in fact // probably also be universally quantified. - val argSorts: Seq[Z3Sort] = funDef.args.map(vd => typeToSort(vd.getType).get) + val argSorts: Seq[Z3Sort] = funDef.args.map(vd => typeToSort(vd.getType)) val boundVars = argSorts.zipWithIndex.map(p => z3.mkBound(p._2, p._1)) val pattern: Z3Pattern = z3.mkPattern(functionDefToDef(funDef)(boundVars: _*)) val nameTypePairs = argSorts.map(s => (z3.mkIntSymbol(nextIntForSymbol()), s)) @@ -204,28 +202,34 @@ class Z3Solver(reporter: Reporter) extends Solver(reporter) { } // assumes prepareSorts has been called.... - def typeToSort(tt: TypeTree) : Option[Z3Sort] = tt match { - case Int32Type => Some(intSort) - case BooleanType => Some(boolSort) - case AbstractClassType(cd) => Some(adtSorts(cd)) + def typeToSort(tt: TypeTree) : Z3Sort = tt match { + case Int32Type => intSort + case BooleanType => boolSort + case AbstractClassType(cd) => adtSorts(cd) case CaseClassType(cd) => { - Some(if(cd.hasParent) { + if(cd.hasParent) { adtSorts(cd.parent.get) } else { adtSorts(cd) - }) + } } case SetType(base) => setSorts.get(base) match { - case s @ Some(_) => s - case None => typeToSort(base).map(s => { - val newSetSort = z3.mkSetSort(s) + case Some(s) => s + case None => { + val newSetSort = z3.mkSetSort(typeToSort(base)) setSorts = setSorts + (base -> newSetSort) newSetSort - }) + } } - case _ => { - reporter.warning("No sort for type " + tt) - None + case other => fallbackSorts.get(other) match { + case Some(s) => s + case None => { + reporter.warning("Resorting to uninterpreted type for : " + other) + val symbol = z3.mkIntSymbol(nextIntForSymbol()) + val newFBSort = z3.mkUninterpretedSort(symbol) + fallbackSorts = fallbackSorts + (other -> newFBSort) + newFBSort + } } } @@ -284,15 +288,7 @@ class Z3Solver(reporter: Reporter) extends Solver(reporter) { case v @ Variable(id) => z3Vars.get(id.uniqueName) match { case Some(ast) => ast case None => { - val newAST = typeToSort(v.getType) match { - case Some(s) => { - z3.mkFreshConst(id.name, s) - } - case None => { - reporter.warning("Unsupported type in Z3 transformation: " + v.getType) - throw new CantTranslateException - } - } + val newAST = z3.mkFreshConst(id.name, typeToSort(v.getType)) z3Vars = z3Vars + (id.uniqueName -> newAST) newAST } @@ -329,13 +325,13 @@ class Z3Solver(reporter: Reporter) extends Solver(reporter) { abstractedFormula = true z3.mkApp(functionDefToDef(fd), args.map(rec(_)): _*) } - case e @ EmptySet(_) => z3.mkEmptySet(typeToSort(e.getType.asInstanceOf[SetType].base).get) + case e @ EmptySet(_) => z3.mkEmptySet(typeToSort(e.getType.asInstanceOf[SetType].base)) case SetEquals(s1,s2) => z3.mkEq(rec(s1), rec(s2)) case SubsetOf(s1,s2) => z3.mkSetSubset(rec(s1), rec(s2)) case SetIntersection(s1,s2) => z3.mkSetIntersect(rec(s1), rec(s2)) case SetUnion(s1,s2) => z3.mkSetUnion(rec(s1), rec(s2)) case SetDifference(s1,s2) => z3.mkSetDifference(rec(s1), rec(s2)) - case f @ FiniteSet(elems) => elems.foldLeft(z3.mkEmptySet(typeToSort(f.getType.asInstanceOf[SetType].base).get))((ast,el) => z3.mkSetAdd(ast,rec(el))) + case f @ FiniteSet(elems) => elems.foldLeft(z3.mkEmptySet(typeToSort(f.getType.asInstanceOf[SetType].base)))((ast,el) => z3.mkSetAdd(ast,rec(el))) case _ => { reporter.warning("Can't handle this in translation to Z3: " + ex) throw new CantTranslateException diff --git a/testcases/MultisetOperations.scala b/testcases/MultisetOperations.scala new file mode 100644 index 0000000000000000000000000000000000000000..c59d2a85fea1909954b03c9b9b7ee9ab736a31a3 --- /dev/null +++ b/testcases/MultisetOperations.scala @@ -0,0 +1,13 @@ +import scala.collection.immutable.Set +import scala.collection.immutable.Multiset + +object MultisetOperations { + def preservedUnderToSet(a: Multiset[Int], b: Multiset[Int]) : Boolean = { + ((a ++ b).toSet.size == (a.toSet ++ b.toSet).size) && + ((a ** b).toSet.size == (a.toSet ** b.toSet).size) + } ensuring(res => res) + + def sumPreservesSizes(a: Multiset[Int], b: Multiset[Int]) : Boolean = { + ((a +++ b).size == a.size + b.size) + } ensuring(res => res) +}