Skip to content
Snippets Groups Projects
Commit b54935fc authored by Philippe Suter's avatar Philippe Suter
Browse files

completing the move

parent 96141ec3
Branches
Tags
No related merge requests found
Showing
with 0 additions and 1401 deletions
package funcheck.lib
object Specs {
/**
* this is used to annotate the (unique) class method
* that will be used by our funcheck plugin to
* automagically define a class generator that can be used
* by ScalaCheck to create test cases.
*/
class generator extends StaticAnnotation
implicit def extendedBoolean(b: => Boolean) = new {
def ==>(p: => Boolean) = Specs ==> (b,p)
}
def forAll[A](f: A => Boolean): Boolean = {
Console.err.println("Warning: ignored forAll. Are you using the funcheck plugin?")
true
// error("\"forAll\" combinator is currently unsupported by plugin.")
}
/** Implication */
def ==>(ifz: => Boolean, then: => Boolean): Boolean = !ifz || then
}
package funcheck.scalacheck
import scala.tools.nsc.Global
/**
* A tree traverser which filter the trees elements that contain the
* <code>@generator</code> annotation, defined in <code>funcheck.lib.Specs</code>
* module.
*
* Note: For the moment this is working only for <code>ClassDef</code> and
* <code>DefDef</code> tree elements.
*
* This trait is meant to be used with the <code>FilterTreeTraverser</code>
* class, available in the <code>scala.tools.nsc.ast.Trees</code> trait.
*
*
* Usage Example:
*
* new FilterTreeTraverser(filterTreesWithGeneratorAnnotation(unit))
*
* where <code>unit</code> is the current Compilation Unit.
*/
trait FilterGeneratorAnnotations {
val global: Global
// [[INFO]] "hasAttribute" is "hasAnnotation" in future compiler release 2.8
import global._
/** Funcheck <code>@generator</code> annotation. */
private lazy val generator: Symbol = definitions.getClass("funcheck.lib.Specs.generator")
/**
* Check for <code>@generator</code> annotation only for class and method
* definitions. A class is considered to be annotated if either the class itself
* has the annotation, or if the class inherit from an annotated abstract class.
*/
def filterTreesWithGeneratorAnnotation(Unit: CompilationUnit)(tree: Tree): Boolean = {
lazy val sym = tree.symbol
tree match {
case cd: ClassDef => isAbstractClass(sym) ||
sym.hasAttribute(generator) ||
abstractSuperClassHasGeneratorAnnotation(sym.superClass)
case d: DefDef => sym.hasAttribute(generator)
case _ => false
}
}
/** Return true if the class (or superclass) symbol is flagged as being ABSTRACT and contains
* the <code>@generator</code> annotation.*/
private def abstractSuperClassHasGeneratorAnnotation(superclass: Symbol): Boolean = {
//require(superclass.isInstanceOf[ClassSymbol], "expected ClassSymbol, found "+superclass)
superclass match {
case NoSymbol => false
case cs: ClassSymbol =>
(isAbstractClass(cs) && cs.hasAttribute(generator)) ||
abstractSuperClassHasGeneratorAnnotation(cs.superClass)
case _ =>
assert(false, "expected ClassSymbol, found "+superclass)
false
}
}
private def isAbstractClass(s: Symbol): Boolean = s match {
case cs: ClassSymbol => cs.hasFlag(scala.tools.nsc.symtab.Flags.ABSTRACT)
case _ => false
}
}
\ No newline at end of file
package funcheck.scalacheck
import scala.tools.nsc.transform.TypingTransformers
import scala.tools.nsc.util.NoPosition
import funcheck.util.FreshNameCreator
/** Takes care of mapping Specs.forAll methods calls to
* ScalaCheck org.scalacheck.Prop.forAll.
*/
trait ForAllTransformer extends TypingTransformers
with ScalaCheck
with FreshNameCreator
{
import global._
private lazy val specsModule: Symbol = definitions.getModule("funcheck.lib.Specs")
def forAllTransform(unit: CompilationUnit): Unit =
unit.body = new ForAllTransformer(unit).transform(unit.body)
class ForAllTransformer(unit: CompilationUnit)
extends TypingTransformer(unit)
{
override def transform(tree: Tree): Tree = {
curTree = tree
tree match {
/* XXX: This only works for top-level forAll. Nested forAll are not handled by the current version*/
case Apply(TypeApply(s: Select, _), rhs @ List(f @ Function(vparams,body))) if isSelectOfSpecsMethod(s.symbol, "forAll") =>
atOwner(currentOwner) {
assert(vparams.size == 1, "funcheck.Specs.forAll properties are expected to take a single (tuple) parameter")
val v @ ValDef(mods,name,vtpt,vinit) = vparams.head
vtpt.tpe match {
// the type of the (single, by the above assumption) function parameter
// will tell us what are the generators needed. In fact, we need to manually
// provide the generators since despite the generators that we create are
// implicit definitions, funcheck is hooking after the typechecking phase
// and implicit definition are solved at typechecking. Therefore the need
// of manually provide every single parameter to the org.scalacheck.Prop.forAll
// method.
// This is actually one of the major limitations of this plugin since it is not
// really quite flexible. For a future work it might be a good idea to rethink
// how this problem can be fixed (an idea could be to inject the code and actuate
// the forall conversion and then typecheck the whole program from zero).
case tpe @ TypeRef(_,value,vtpes) =>
var fun: Function = {
if(vtpes.size <= 1) {
// if there is less than one parameter then the function tree can be injected
// without (almost) no modificcation because it matches what Scalacheck Prop.forAll
// expects
f
}
else {
// Transforming a pair into a list of arguments (this is what ScalaCheck Prop.forAll expects)
// create a fresh name for each parameter declared parametric type
val freshNames = vtpes.map(i => fresh.newName(NoPosition,"v"))
val funSym = tree.symbol
val subst = for { i <- 0 to vtpes.size-1} yield {
val toSym = funSym.newValueParameter(funSym.pos, freshNames(i)).setInfo(vtpes(i))
val from = Select(v, v.symbol.tpe.decl("_"+(i+1)))
val to = ValDef(toSym, EmptyTree) setPos (tree.pos)
(from, to)
}
val valdefs = subst.map(_._2).toList
val fun = localTyper.typed {
val newBody = new MyTreeSubstituter(subst.map(p => p._1.symbol).toList, valdefs.map(v => Ident(v.symbol)).toList).transform(resetAttrs(body))
Function(valdefs, newBody)
}.asInstanceOf[Function]
new ChangeOwnerTraverser(funSym, fun.symbol).traverse(fun);
new ForeachTreeTraverser({t: Tree => t setPos tree.pos}).traverse(fun)
fun
}
}
// Prop.forall(function , where function is of the form (v1,v2,...,vn) => expr(v1,v2,..,vn))
val prop = Prop.forAll(List(transform(fun)))
// the following are the list of (implicit) parameters that need to be provided
// when calling Prop.forall
var buf = new collection.mutable.ListBuffer[Tree]()
val blockValSym = newSyntheticValueParam(fun.symbol, definitions.BooleanClass.typeConstructor)
val fun2 = localTyper.typed {
val body = Prop.propBoolean(resetAttrs(Ident(blockValSym)))
Function(List(ValDef(blockValSym, EmptyTree)), body)
}.asInstanceOf[Function]
new ChangeOwnerTraverser(fun.symbol, fun2.symbol).traverse(fun2);
new ForeachTreeTraverser({t: Tree => t setPos tree.pos}).traverse(fun2)
buf += Block(Nil,fun2)
if(vtpes.size <= 1) {
buf += resetAttrs(Arbitrary.arbitrary(tpe))
buf += resetAttrs(Shrink.shrinker(tpe))
} else {
for { tpe <- vtpes } {
buf += resetAttrs(Arbitrary.arbitrary(tpe))
buf += resetAttrs(Shrink.shrinker(tpe))
}
}
import posAssigner.atPos // for filling in tree positions
val property = localTyper.typed {
atPos(tree.pos) {
Apply(prop, buf.toList)
}
}
localTyper.typed {
atPos(tree.pos) {
Test.isPassed(Test.check(property))
}
}
case t =>
assert(false, "expected ValDef of type TypeRef, found "+t)
tree
}
}
/** Delegates the recursive traversal of the tree. */
case _ => super.transform(tree)
}
}
class ChangeOwnerTraverser(val oldowner: Symbol, val newowner: Symbol) extends Traverser {
override def traverse(tree: Tree) {
if (tree != null && tree.symbol != null && tree.symbol != NoSymbol && tree.symbol.owner == oldowner)
tree.symbol.owner = newowner;
super.traverse(tree)
}
}
/** Synthetic value parameters when parameter symbols are not available
*/
final def newSyntheticValueParam(owner: Symbol, argtype: Type): Symbol = {
var cnt = 0
def freshName() = { cnt += 1; newTermName("x$" + cnt) }
def param(tp: Type) =
owner.newValueParameter(owner.pos, freshName()).setFlag(scala.tools.nsc.symtab.Flags.SYNTHETIC).setInfo(tp)
param(argtype)
}
private def isSelectOfSpecsMethod(s: Symbol, method: String): Boolean =
s == specsModule.tpe.decl(method)
/** Quick (and dirty) hack for enabling tree substitution for pair elements.
* Specifically, this allow to map pair accesses such as p._1 to a new variable 'x'
* ([p._1 |-> x, p._2 |-> y, ..., p._n |-> z])
*/
class MyTreeSubstituter(from: List[Symbol], to: List[Tree]) extends TreeSubstituter(from,to) {
override def transform(tree: Tree): Tree = tree match {
// Useful for substutite values like p._1 where 'p' is a pair
// Inherithed 'TreeSubstituter' cannot handle this
case Select(Ident(_), name) =>
def subst(from: List[Symbol], to: List[Tree]): Tree =
if (from.isEmpty) tree
else if (tree.symbol == from.head) to.head
else subst(from.tail, to.tail);
subst(from, to)
case _ =>
super.transform(tree)
}
}
}
}
package funcheck.scalacheck
import scala.tools.nsc.transform.TypingTransformers
trait GeneratorDefDefInjector extends TypingTransformers {
import global._
def injectGenDefDefs(injecting: List[DefDef], unit: CompilationUnit): Unit =
unit.body = new GenDefDefTransformer(injecting, unit).transform(unit.body)
class GenDefDefTransformer(injecting: List[DefDef], unit: CompilationUnit)
extends /*Code Injection*/ TypingTransformer(unit)
{
override def transform(tree: Tree): Tree = {
curTree = tree
tree match {
case impl @ Template(parents, self, body) =>
atOwner(currentOwner) {
val newBody: List[Tree] = body ::: (injecting.map(localTyper.typed(_)))
val cd = copy.Template(impl, parents, self, newBody)
cd
}
/** Delegates the recursive traversal of the tree. */
case _ => super.transform(tree)
}
}
}
}
\ No newline at end of file
package funcheck.scalacheck
import scala.tools.nsc.Global
import scala.tools.nsc.util.NoPosition
import funcheck.util.FreshNameCreator
/**
* Utilitarity class that is used as a factory for creating Tree nodes for method
* calls of classes and modules in the <code>org.scalacheck</code> package.
*/
trait ScalaCheck extends FreshNameCreator {
val global: Global
import global._
trait GenericScalaCheckModule {
/** Symbol for a module definition. */
protected val moduleSym: Symbol
/** Symbol for the module's companion class definition. */
protected lazy val classSym = moduleSym.linkedClassOfModule
/**
* <p>
* Take a <code>symbol</code> and method <code>name</code> and return the associated
* method's symbol.
* </p>
* <p>
* Note: if <code>symbol</code> does not contain a method named </code>name</code>, the
* returned symbol will be a <code>NoSymbol</code>.
* </p>
*
* @param symbol A module/class symbol,
* @param name A name of the method that should be part of the declared members of the <code>symbol</code>.
* @return The symbol for the method 'symbol.name' or <code>NoSymbol</code> if the <code>symbol</code> does
* not have a member named <code>name</code>.
*/
private def symDecl(symbol: Symbol, name: String) = symbol.tpe.decl(name)
/** Identical to symDecl(symbol: Symbol, name: String), but uses a Name object
* instead of a String for the <code>name</code>.*/
private def symDecl(symbol: Symbol, name: Name) = symbol.tpe.decl(name)
/** Retrieve the constructor Symbol for the passes <code>cs</code> ClassSymbol. */
private def constructorDecl(cs: ClassSymbol) = symDecl(cs, nme.CONSTRUCTOR)
/**
* <p>
* Take a method <code>name</code> and return the associated module method's symbol.
* </p>
* <p>
* Note: if module does not contain a method named </code>name</code>, the
* returned symbol will be a <code>NoSymbol</code>.
* </p>
*
* @param name A name of the method that should be part of the declared members of the module.
* @return The symbol for the method 'module.name' or <code>NoSymbol</code> if the module does
* not have a member named <code>name</code>.
*/
protected def modDecl(method: String) = symDecl(moduleSym, method)
/**
* <p>
* Take a method <code>name</code> and return the associated (module's) companion class method's symbol.
* </p>
* <p>
* Note: if class does not contain a method named </code>name</code>, the
* returned symbol will be a <code>NoSymbol</code>.
* </p>
*
* @param name A name of the method that should be part of the declared members of the class.
* @return The symbol for the method 'class.name' or <code>NoSymbol</code> if the class does
* not have a member named <code>name</code>.
*/
protected def classDecl(method: String) = symDecl(classSym, method)
/**
* <p>
* Take an <code>instance</code> symbol and a <code>method</code> name and return
* valid Scala Tree node of the form 'instance.method'.
* </p>
* <p>
* The precondition for this method to execute is that the <code>instance</code> Symbol
* contains in its members the selected <code>method</code>, otherwise calling this routine
* will throw an exception.
* </p>
*
* @param instance The Symbol for the instance whose <code>method</code> is selected.
* @param method The name of the selected method.
* @return A Scala valid Select Tree node of the form 'instance.method'.
*
*/
protected def select(instance: Symbol, method: String): Select = {
require(instance.tpe.decl(method) != NoSymbol)
Select(Ident(instance), symDecl(instance,method))
}
/**
* <p>
* Apply <code>arg</code> to the passed <code>method</code> contained in the
* <code>moduleSym</code> module and return a valid Scala Tree node of the
* form 'module.method(arg)'.
* </p>
* <p>
* The precondition for this method to execute is that the module Symbol
* contains in its members the passed <code>method</code>, otherwise calling
* this routine will throw an exception.
* </p>
*
* @param method The name of the selected method.
* @args The arguments to which the <code>method</code> is applied to.
* @return A Scala valid Apply Tree node of the form 'moduleSym.method(arg)'.
*
*/
protected def moduleApply(method: String, args: List[Tree]): Apply =
apply(select(moduleSym,method), args)
/** Calls <code>moduleApply</code> and wraps the passed <code>arg</code> into a
* List, i.e., moduleApply(method, List(arg).*/
protected def moduleApply(method: String, arg: Tree): Apply = moduleApply(method,List(arg))
/**
* <p>
* Generic apply. Applies <code>select</code> to the passed list of <code>arguments</code>.
* and return a valid Scala Tree Apply Node of the form 'select(arg1,arg2,...,argN)'.
* </p>
* <p>
* Note: No check is performed to ensure that <code>select</code> can accept
* the passed list of <code>arguments</code>
* </p>
*
* @param select The left hand side of the application.
* @param arguments The arguments of the application.
* @return A Scala valid Apply Tree node of the form 'select(arg1, arg32, ..., argN)'
*/
protected def apply(select: Tree, arguments: List[Tree]): Apply =
Apply(select, arguments)
/** Calls <code>apply</code> and wraps the passed <code>arg</code> into a List,
* i.e., apply(select, List(arg)). */
protected def apply(select: Tree, argument: Tree): Apply =
apply(select, List(argument))
}
/** Module for creating scalac Tree nodes for calling methods of the
* <code>org.scalacheck.Gen</code> class and module.*/
object Gen extends GenericScalaCheckModule {
/** Symbol for the <code>org.scalacheck.Gen</code> module definition. */
override protected lazy val moduleSym = definitions.getModule("org.scalacheck.Gen")
/**
* Apply <code>polyTpe</code> to the polymorphic type <code>org.scalacheck.Gen</code>.
*
* @param polyTpe the type to be applied to <code>org.scalacheck.Gen</code>.
* @return The polymorphic type resulting from applying <code>polyTpe</code>
* to the polymorphic type <code>org.scalacheck.Gen</code>, i.e.,
* <code>Gen[polyTpe]</code>.
*/
private def applyType(polyTpe: Type) = appliedType(classSym.tpe, List(polyTpe))
/**
* This creates a Tree node for the call <code>org.scalacheck.Gen.value[T](rhs)</code>,
* where the polymorphic type <code>T</code> will be inferred during the next
* typer phase (this usually means that the typer has to be called explictly,
* so it is the developer duty to ensure that this happen at some point).
*/
def value(rhs: Tree): Tree = moduleApply("value", rhs)
/**
* This creates a Tree node for the call <code>org.scalacheck.Gen.oneOf[T](generators)</code>,
* where the polymorphic type <code>T</code> will be inferred during the next
* typer phase (this usually means that the typer has to be called explictly,
* so it is the developer duty to ensure that this happen at some point).
*/
def oneOf(generators: List[Symbol]): Tree =
moduleApply("oneOf", generators.map(Ident(_)))
def lzy(generator: Tree): Tree = moduleApply("lzy", generator)
/**
* This creates a Tree node for the call <code>org.scalacheck.Gen.flatMap[T](body)</code>,
* where the polymorphic type <code>T</code> will be inferred during the next
* typer phase (this usually means that the typer has to be called explictly,
* so it is the developer duty to ensure that this happen at some point).
*/
def flatMap(qualifier: Tree, body: Tree): Tree =
apply(Select(qualifier, classDecl("flatMap")), body)
/**
* This creates a Tree node for the call <code>org.scalacheck.Gen.map[T](rhs)</code>,
* where the polymorphic type <code>T</code> will be inferred during the next
* typer phase (this usually means that the typer has to be called explictly,
* so it is the developer duty to ensure that this happen at some point).
*/
def map(qualifier: Tree, body: Tree): Tree =
apply(Select(qualifier, classDecl("map")), body)
/**
* Utilitary method for creating a method symbol for a <code>org.scalacheck.Gen</codee>
* generator method.
*
* @param owner The owner of the method (DefDef) which will use the returned method symbol.
* @param genName The name of the method symbol (which will also be the name of the method).
* @param retTpe The method's returning type.
* @return The method symbol for a generator method.
*/
def createGenDefSymbol(owner: Symbol, genName: String, retTpe: Type): Symbol = {
// returning type of the new method, i.e., Gen["retTpe"]
val genDefRetTpe = applyType(retTpe)
// create a symbol for the generator method that will be created next
owner.newMethod(owner.pos,genName).setInfo(PolyType(List(), genDefRetTpe))
}
/**
* Map that stores for each @generator annotated ClassDef or DefDef the automatically
* generated DefDef for creating instances of the <code>org.scalacheck.Gen</code> class.
* The <code>Type</code> that is associated to the DefDef is either the type of the
* ClassDef or the returning type of the DefDef.
*/
private val tpe2listGen = scala.collection.mutable.Map.empty[Type, List[DefDef]]
private val tpe2listGenSym = scala.collection.mutable.Map.empty[Type, List[Symbol]]
/**
* Add the <code>gen</code> generator DefDef declaration to the list of
* generators for <code>tpe</code>.
*
* @param tpe The type of elements generated by the <code>gen</code>.
* @param gen The DefDef declaration for the generator method.
*/
def +[T](map: collection.mutable.Map[Type, List[T]], key: Type, value: T): Unit = map.get(key) match {
case None => map += key -> List(value)
case Some(values) => map += key -> (value :: values)
}
/** List of generator DefDef symbols for a given type <code>tpe</code>*/
def genSymbolsForType(tpe: Type): List[Symbol] = tpe2listGenSym.get(tpe) match {
case None => Nil
case Some(symbols) => symbols
}
/**
* Second Pass: Create symbols for the generator DefDef that will be created
* durind the Third Pass.
*/
def createGenDefDef(klasses: List[ClassDef], defs: List[DefDef]): List[DefDef] = {
val generable: List[(Symbol,Tree)] = createGenDefSyms(klasses, defs)
for { (genSym, genTree) <- generable } genTree match {
case cd: ClassDef =>
val tpe = cd.symbol.tpe
Gen + (tpe2listGen, tpe, Gen.createGenDef(cd, genSym))
case d: DefDef =>
val tpe = d.tpt.symbol.tpe
val generated = DefDef(genSym, Modifiers(0), List(), rhsGenDef(Ident(d.name))(d)(genSym))
Gen + (tpe2listGen, tpe, generated)
}
// flatten into single list values of Gen.tpe2listGen
(List[DefDef]() /: Gen.tpe2listGen.values) {
case (xs, xss) => xs ::: xss
}
}
/**
* Create method symbols for each <code>@generator</code> annotated ClassDef
* and DefDef.
*/
private def createGenDefSyms(klasses: List[ClassDef], defs: List[DefDef]): List[(Symbol, Tree)] = {
val genKlasses: List[(Symbol, ClassDef)] = for(klass <- klasses) yield {
val genName = fresh.newName(NoPosition, "gen"+klass.name)
val tpe = klass.symbol.tpe
val genSym = createGenDefSymbol(klass.symbol.enclClass.owner, genName, tpe)
Gen + (tpe2listGenSym, tpe, genSym)
(genSym, klass)
}
val genDefs: List[(Symbol, DefDef)] = for(d <- defs) yield {
val genName = fresh.newName(NoPosition, "gen"+d.name)
val tpe = d.tpt.symbol.tpe
val genSym = createGenDefSymbol(d.symbol.owner, genName, tpe)
Gen + (tpe2listGenSym, tpe, genSym)
(genSym, d)
}
genKlasses ::: genDefs
}
def createGenDef(cd: ClassDef, genDef: Symbol): DefDef = {
val d: DefDef = getConstructorOf(cd)
val DefDef(_,_,_,vparamss,retTpe,_) = d
assert(vparamss.size <= 1, "currying is not supported. Change signature of "+cd.symbol)
if(cd.symbol.hasFlag(scala.tools.nsc.symtab.Flags.ABSTRACT)) {
val generators = retTpe.symbol.children.toList.map(s => genSymbolsForType(s.tpe)).flatMap(v=>v)
DefDef(genDef, Modifiers(0), List(), Gen.lzy(Gen.oneOf(generators)))
}
else {
val constrObj = resetAttrs(d.tpt.duplicate)
val instance = Select(New(constrObj), nme.CONSTRUCTOR)
assert(d.tpt.isInstanceOf[TypeTree])
val body = rhsGenDef(instance)(d)(genDef)
DefDef(genDef, Modifiers(0), List(), body)
}
}
/** <code>base</code> is either
* - Select(New(tpe),constructor) [constructor]
* - Ident(name) [method call]
*/
private def rhsGenDef(base: Tree)(d: DefDef)(extOwner: Symbol): Tree = {
val DefDef(_,name,_,vparamss,retTpe,_) = d
assert(vparamss.size <= 1, "currying is not supported. Change signature of "+d.symbol)
// XXX: quick fix to force creation of arbitrary objects for each data type, this should be refactored!!
Arbitrary.arbitrary(retTpe.asInstanceOf[TypeTree])
if(vparamss.head.isEmpty)
Gen.value(Apply(base, Nil))
else {
var owner = extOwner
val paramssTpe: List[ValDef] = vparamss.flatMap(v=>v).map(p =>
ValDef(Modifiers(0), fresh.newName(NoPosition, "v"), resetAttrs(p.tpt.duplicate), EmptyTree))
var last = true
val z :Tree = Apply(base, paramssTpe.map(p => Ident(p.name)))
val body = (paramssTpe :\ z) {
case (param,apply) => {
val body = Function(List(param), apply)
body.symbol.owner = owner
owner = body.symbol
//XXX: it is not flatMap in general. fix this!!
if(last) {
last = false
Gen.map(Arbitrary.arbitrary(param.tpt.asInstanceOf[TypeTree]), body)
} else
Gen.flatMap(Arbitrary.arbitrary(param.tpt.asInstanceOf[TypeTree]), body)
}
}
Gen.lzy(body)
}
}
private def getConstructorOf(cd: ClassDef): DefDef = {
val Template(parents, self, body) = cd.impl
var dd: DefDef = null
for { b <- body } b match {
case d @ DefDef(_, nme.CONSTRUCTOR, _, _, _, _) => dd = d
case _ => ;
}
dd
} ensuring (res => res != null)
}
/** Module for creating scalac Tree nodes for calling methods of the
* <code>org.scalacheck.Arbitrary</code> class and module.*/
object Arbitrary extends GenericScalaCheckModule {
/** Symbol for the <code>org.scalacheck.Arbitrary</code> module definition. */
override protected lazy val moduleSym = definitions.getModule("org.scalacheck.Arbitrary")
/** Symbol for the <code>org.scalacheck.Arbitrary.arbInt</code> method definition. */
private val arbInt = select(moduleSym, "arbInt")
/** Symbol for the <code>org.scalacheck.Arbitrary.arbBool</code> method definition. */
private val arbBool = select(moduleSym, "arbBool")
/** Symbol for the <code>org.scalacheck.Arbitrary.arbLong</code> method definition. */
private val arbLong = select(moduleSym, "arbLong")
/** Symbol for the <code>org.scalacheck.Arbitrary.arbThrowable</code> method definition. */
private val arbThrowable = select(moduleSym, "arbThrowable")
/** Symbol for the <code>org.scalacheck.Arbitrary.arbDouble</code> method definition. */
private val arbDouble = select(moduleSym, "arbDouble")
/** Symbol for the <code>org.scalacheck.Arbitrary.arbChar</code> method definition. */
private val arbChar = select(moduleSym, "arbChar")
/** Symbol for the <code>org.scalacheck.Arbitrary.arbString</code> method definition. */
private val arbString = select(moduleSym, "arbString")
/** Symbol for the <code>org.scalacheck.Arbitrary.arbOption</code> method definition. */
private val arbOption = select(moduleSym, "arbOption")
/** Symbol for the <code>org.scalacheck.Arbitrary.arbImmutableMap</code> method definition. */
private val arbImmutableMap = select(moduleSym, "arbImmutableMap")
/** Symbol for the <code>org.scalacheck.Arbitrary.arbList</code> method definition. */
private val arbList = select(moduleSym, "arbList")
/** Symbol for the <code>org.scalacheck.Arbitrary.arbArray</code> method definition. */
private val arbArray = select(moduleSym, "arbArray")
/** Symbol for the <code>org.scalacheck.Arbitrary.arbSet</code> method definition. */
private val arbSet = select(moduleSym, "arbSet")
/** Symbol for the <code>org.scalacheck.Arbitrary.arbTuple2</code> method definition. */
private val arbTuple2 = select(moduleSym, "arbTuple2")
//[[TODO]]
//lazy val arbMultiSet = Select(Ident(arbitraryModule), arbitraryModule.tpe.decl("arbMultiSet"))
/** Map that stores <code>org.scalacheck.Arbitrary.arbitrary[Type]</code> calls. */
protected val tpe2arbApp = scala.collection.mutable.Map.empty[Type,Tree]
// initialize map with ScalaCheck built-in types that are part of our PureScala language
import definitions._
tpe2arbApp += IntClass.typeConstructor -> arbInt
tpe2arbApp += BooleanClass.typeConstructor -> arbBool
tpe2arbApp += LongClass.typeConstructor -> arbLong
tpe2arbApp += ThrowableClass.typeConstructor -> arbThrowable
tpe2arbApp += DoubleClass.typeConstructor -> arbDouble
tpe2arbApp += CharClass.typeConstructor -> arbChar
tpe2arbApp += StringClass.typeConstructor -> arbString // XXX: NOT WORKING
tpe2arbApp += OptionClass.typeConstructor -> arbOption
//lazy val ImmutableMapClass: Symbol = definitions.getClass(newTypeName("scala.collection.immutable.Map"))
//lazy val ImmutableSetClass: Symbol = definitions.getClass(newTypeName("scala.collection.immutable.Set"))
//tpe2arbApp += ImmutableMapClass.typeConstructor -> arbImmutableMap
tpe2arbApp += ListClass.typeConstructor -> arbList
tpe2arbApp += ArrayClass.typeConstructor -> arbArray
//tpe2arbApp += ImmutableSetClass.typeConstructor -> arbSet
tpe2arbApp += TupleClass(2).typeConstructor -> arbTuple2
/**
* Apply <code>polyTpe</code> to the polymorphic type <code>org.scalacheck.Arbitrary</code>.
*
* @param polyTpe the type to be applied to <code>org.scalacheck.Arbitrary</code>.
* @return The polymorphic type resulting from applying <code>polyTpe</code>
* to the polymorphic type <code>org.scalacheck.Arbitrary</code>, i.e.,
* <code>Arbitrary[polyTpe]</code>.
*/
private def applyType(tpe: Type) = appliedType(classSym.tpe, List(tpe))
/**
* Creates a Tree node for the call <code>org.scalacheck.Arbitrary.apply[T](generator)</code>,
* where the polymorphic type <code>T</code> will be inferred during the next
* typer phase (this usually means that the typer has to be called explictly,
* so it is the developer duty to ensure that this happen at some point).
*/
def apply(generator: Tree): Tree = moduleApply("apply", generator)
def arbitrary(tpe: Type): Tree = tpe2arbApp.get(tpe) match {
case Some(arbTree) => arbTree
case None =>
val TypeRef(_,sym,params) = tpe
apply(arbitrary(sym.typeConstructor), params.map(arbitrary(_)))
}
/**
*
*/
def arbitrary(polyTpe: TypeTree): Apply = {
val symbol = polyTpe.symbol
val tpe = symbol.tpe
tpe2arbApp.get(tpe) match {
case Some(arb) => applyArbitrary(arb)
case None => arbitrary(symbol)
}
}
/** Map that stores not built-in <code>org.scalacheck.Arbitrary</code> DefDef definitions. */
private val tpe2arbDefDef = scala.collection.mutable.Map.empty[Type,DefDef]
def getArbitraryDefDefs: List[DefDef] = tpe2arbDefDef.values.toList
def arbitrary(tpeSym: Symbol): Apply = {
require(tpe2arbApp.get(tpeSym.tpe).isEmpty, "Arbitrary.arbitrary["+tpeSym.tpe+"] is already in the map")
val owner = tpeSym.toplevelClass
val arbName = fresh.newName(NoPosition,"arb"+tpeSym.name)
val tpe = tpeSym.tpe
val arbDef = createArbitraryDefSymbol(owner, arbName, tpe)
val genNames = Gen.genSymbolsForType(tpe)
val generated = DefDef(arbDef, Modifiers(0), List(), Arbitrary(Gen.oneOf(genNames)))
tpe2arbDefDef += tpe -> generated
val result = applyArbitrary(Ident(arbDef))
tpe2arbApp += tpe -> Ident(arbDef)
result
}
protected def applyArbitrary(param: Tree): Apply =
apply(select(moduleSym, "arbitrary"), param)
/**
* Utilitary method for creating a method symbol for a <code>org.scalacheck.Arbitrary</codee>
* generator method.
*
* @param owner The owner of the method (DefDef) which will use the returned method symbol.
* @param arbName The name of the method symbol (which will also be the name of the method).
* @param retTpe The method's returning type.
* @return The method symbol for a generator method.
*/
def createArbitraryDefSymbol(owner: Symbol, arbName: String, retTpe: Type): Symbol = {
// returning type of the new method, i.e., Arbitrary["retTpe"]
val arbRetTpe = applyType(retTpe)
// Create the DefDef for the new Arbitrary object
val arbDef = owner.newMethod(owner.pos, arbName).setInfo(PolyType(List(), arbRetTpe))
// Implicit only because of ScalaCheck rational (not really needed since we are injecting code)
arbDef.setFlag(scala.tools.nsc.symtab.Flags.IMPLICIT)
arbDef
}
}
object Prop extends GenericScalaCheckModule {
/** Symbol for the <code>org.scalacheck.Prop</code> module definition. */
override protected lazy val moduleSym = definitions.getModule("org.scalacheck.Prop")
def forAll(props: List[Tree]): Apply =
moduleApply("forAll", props)
def forAll(prop: Tree): Apply = forAll(List(prop))
def ==>(ifz: Tree, then: Tree): Apply = moduleApply("==>", List(ifz,propBoolean(then)))
def propBoolean(prop: Tree): Apply = moduleApply("propBoolean", List(prop))
}
object Shrink extends GenericScalaCheckModule {
/** Symbol for the <code>org.scalacheck.Shrink</code> module definition. */
override protected lazy val moduleSym = definitions.getModule("org.scalacheck.Shrink")
/** Symbol for the <code>org.scalacheck.Shrink.shrinkInt</code> method definition. */
private val shrinkInt = select(moduleSym, "shrinkInt")
/** Symbol for the <code>org.scalacheck.Shrink.shrinkString</code> method definition. */
private val shrinkString = select(moduleSym, "shrinkString")
/** Symbol for the <code>org.scalacheck.Shrink.shrinkOption</code> method definition. */
private val shrinkOption = select(moduleSym, "shrinkOption")
/** Symbol for the <code>org.scalacheck.Shrink.shrinkList</code> method definition. */
private val shrinkList = select(moduleSym, "shrinkList")
/** Symbol for the <code>org.scalacheck.Shrink.shrinkSet</code> method definition. */
private val shrinkArray = select(moduleSym, "shrinkArray")
/** Symbol for the <code>org.scalacheck.Shrink.shrinkSet</code> method definition. */
private val shrinkSet = select(moduleSym, "shrinkSet")
/** Symbol for the <code>org.scalacheck.Shrink.shrinkTuple2</code> method definition. */
private val shrinkTuple2 = select(moduleSym, "shrinkTuple2")
/** Symbol for the <code>org.scalacheck.Shrink.shrinkTuple3</code> method definition. */
private val shrinkTuple3 = select(moduleSym, "shrinkTuple3")
/** Symbol for the <code>org.scalacheck.Shrink.shrinkTuple4</code> method definition. */
private val shrinkTuple4 = select(moduleSym, "shrinkTuple4")
/** Symbol for the <code>org.scalacheck.Shrink.shrinkTuple5</code> method definition. */
private val shrinkTuple5 = select(moduleSym, "shrinkTuple5")
/** Symbol for the <code>org.scalacheck.Shrink.shrinkTuple6</code> method definition. */
private val shrinkTuple6 = select(moduleSym, "shrinkTuple6")
/** Symbol for the <code>org.scalacheck.Shrink.shrinkTuple7</code> method definition. */
private val shrinkTuple7 = select(moduleSym, "shrinkTuple7")
/** Symbol for the <code>org.scalacheck.Shrink.shrinkTuple8</code> method definition. */
private val shrinkTuple8 = select(moduleSym, "shrinkTuple8")
/** Symbol for the <code>org.scalacheck.Shrink.shrinkTuple9</code> method definition. */
private val shrinkTuple9 = select(moduleSym, "shrinkTuple9")
/** Symbol for the <code>org.scalacheck.Shrink.shrinkIntList</code> method definition. */
private val shrinkIntList = select(moduleSym, "shrinkIntList")
/** Symbol for the <code>org.scalacheck.Shrink.shrinkBooleanList</code> method definition. */
private val shrinkBooleanList = select(moduleSym, "shrinkBooleanList")
/** Symbol for the <code>org.scalacheck.Shrink.shrinkDoubleList</code> method definition. */
private val shrinkDoubleList = select(moduleSym, "shrinkDoubleList")
/** Symbol for the <code>org.scalacheck.Shrink.shrinkStringList</code> method definition. */
private val shrinkStringList = select(moduleSym, "shrinkStringList")
/** Symbol for the <code>org.scalacheck.Shrink.shrinkIntArray</code> method definition. */
private val shrinkIntArray = select(moduleSym, "shrinkIntArray")
/** Symbol for the <code>org.scalacheck.Shrink.shrinkBooleanArray</code> method definition. */
private val shrinkBooleanArray = select(moduleSym, "shrinkBooleanArray")
/** Symbol for the <code>org.scalacheck.Shrink.shrinkDoubleArray</code> method definition. */
private val shrinkDoubleArray = select(moduleSym, "shrinkDoubleArray")
/** Symbol for the <code>org.scalacheck.Shrink.shrinkStringArray</code> method definition. */
private val shrinkStringArray = select(moduleSym, "shrinkStringArray")
/** Symbol for the <code>org.scalacheck.Shrink.shrinkIntSet</code> method definition. */
private val shrinkIntSet = select(moduleSym, "shrinkIntSet")
/** Symbol for the <code>org.scalacheck.Shrink.shrinkBooleanSet</code> method definition. */
private val shrinkBooleanSet = select(moduleSym, "shrinkBooleanSet")
/** Symbol for the <code>org.scalacheck.Shrink.shrinkDoubleSet</code> method definition. */
private val shrinkDoubleSet = select(moduleSym, "shrinkDoubleSet")
/** Symbol for the <code>org.scalacheck.Shrink.shrinkStringSet</code> method definition. */
private val shrinkStringSet = select(moduleSym, "shrinkStringSet")
/** Symbol for the <code>org.scalacheck.Shrink.shrinkIntOption</code> method definition. */
private val shrinkIntOption = select(moduleSym, "shrinkIntOption")
/** Symbol for the <code>org.scalacheck.Shrink.shrinkBooleanOption</code> method definition. */
private val shrinkBooleanOption= select(moduleSym, "shrinkBooleanOption")
/** Symbol for the <code>org.scalacheck.Shrink.shrinkDoubleOption</code> method definition. */
private val shrinkDoubleOption = select(moduleSym, "shrinkDoubleOption")
/** Symbol for the <code>org.scalacheck.Shrink.shrinkStringOption</code> method definition. */
private val shrinkStringOption = select(moduleSym, "shrinkStringOption")
/** Symbol for the <code>org.scalacheck.Shrink.shrinkIntTuple2</code> method definition. */
private val shrinkIntTuple2 = select(moduleSym, "shrinkIntTuple2")
/** Symbol for the <code>org.scalacheck.Shrink.shrinkBooleanTuple2</code> method definition. */
private val shrinkBooleanTuple2= select(moduleSym, "shrinkBooleanTuple2")
/** Symbol for the <code>org.scalacheck.Shrink.shrinkDoubleTuple2</code> method definition. */
private val shrinkDoubleTuple2 = select(moduleSym, "shrinkDoubleTuple2")
/** Symbol for the <code>org.scalacheck.Shrink.shrinkStringTuple2</code> method definition. */
private val shrinkStringTuple2 = select(moduleSym, "shrinkStringTuple2")
/** Symbol for the <code>org.scalacheck.Shrink.shrinkIntTuple3</code> method definition. */
private val shrinkIntTuple3 = select(moduleSym, "shrinkIntTuple3")
/** Symbol for the <code>org.scalacheck.Shrink.shrinkBooleanTuple3</code> method definition. */
private val shrinkBooleanTuple3= select(moduleSym, "shrinkBooleanTuple3")
/** Symbol for the <code>org.scalacheck.Shrink.shrinkDoubleTuple3</code> method definition. */
private val shrinkDoubleTuple3 = select(moduleSym, "shrinkDoubleTuple3")
/** Symbol for the <code>org.scalacheck.Shrink.shrinkStringTuple3</code> method definition. */
private val shrinkStringTuple3 = select(moduleSym, "shrinkStringTuple3")
/** Symbol for the <code>org.scalacheck.Shrink.shrinkIntTuple4</code> method definition. */
private val shrinkIntTuple4 = select(moduleSym, "shrinkIntTuple4")
/** Symbol for the <code>org.scalacheck.Shrink.shrinkBooleanTuple4</code> method definition. */
private val shrinkBooleanTuple4= select(moduleSym, "shrinkBooleanTuple4")
/** Symbol for the <code>org.scalacheck.Shrink.shrinkDoubleTuple4</code> method definition. */
private val shrinkDoubleTuple4 = select(moduleSym, "shrinkDoubleTuple4")
/** Symbol for the <code>org.scalacheck.Shrink.shrinkStringTuple4</code> method definition. */
private val shrinkStringTuple4 = select(moduleSym, "shrinkStringTuple4")
/** Symbol for the <code>org.scalacheck.Shrink.shrinkAny</code> method definition.
* This is a generic shrinker which does not shrink whatever object is passed to it.
*/
private val shrinkAny = select(moduleSym, "shrinkAny")
def shrinker(tpe: Type): Select = tpe2shrinker.getOrElse(tpe, shrinkAny)
private val tpe2shrinker: Map[Type, Select] = {
import definitions._
val SetClass: Symbol = definitions.getClass("scala.collection.immutable.Set")
def apply(container: Type)(parametric: Type): Type =
appliedType(container, List(parametric))
def listOf(tpe: Type): Type = apply(ListClass.typeConstructor)(tpe)
def arrayOf(tpe: Type): Type = apply(ArrayClass.typeConstructor)(tpe)
def setOf(tpe: Type): Type = apply(SetClass.typeConstructor)(tpe)
def optionOf(tpe: Type): Type = apply(OptionClass.typeConstructor)(tpe)
def tupleOf(arity: Int, tpe: Type): Type = apply(TupleClass(arity).typeConstructor)(tpe)
val IntListTpe = listOf(IntClass.typeConstructor)
val BooleanListTpe = listOf(BooleanClass.typeConstructor)
val DoubleListTpe = listOf(DoubleClass.typeConstructor)
val StringListTpe = listOf(StringClass.typeConstructor)
val IntArrayTpe = arrayOf(IntClass.typeConstructor)
val BooleanArrayTpe = arrayOf(BooleanClass.typeConstructor)
val DoubleArrayTpe = arrayOf(DoubleClass.typeConstructor)
val StringArrayTpe = arrayOf(StringClass.typeConstructor)
val IntSetTpe = setOf(IntClass.typeConstructor)
val BooleanSetTpe = setOf(BooleanClass.typeConstructor)
val DoubleSetTpe = setOf(DoubleClass.typeConstructor)
val StringSetTpe = setOf(StringClass.typeConstructor)
val IntOptionTpe = optionOf(IntClass.typeConstructor)
val BooleanOptionTpe = optionOf(BooleanClass.typeConstructor)
val DoubleOptionTpe = optionOf(DoubleClass.typeConstructor)
val StringOptionTpe = optionOf(StringClass.typeConstructor)
val IntTuple2Tpe = tupleOf(2, IntClass.typeConstructor)
val BooleanTuple2Tpe = tupleOf(2, BooleanClass.typeConstructor)
val DoubleTuple2Tpe = tupleOf(2, DoubleClass.typeConstructor)
val StringTuple2Tpe = tupleOf(2, StringClass.typeConstructor)
val IntTuple3Tpe = tupleOf(3, IntClass.typeConstructor)
val BooleanTuple3Tpe = tupleOf(3, BooleanClass.typeConstructor)
val DoubleTuple3Tpe = tupleOf(3, DoubleClass.typeConstructor)
val StringTuple3Tpe = tupleOf(3, StringClass.typeConstructor)
val IntTuple4Tpe = tupleOf(4, IntClass.typeConstructor)
val BooleanTuple4Tpe = tupleOf(4, BooleanClass.typeConstructor)
val DoubleTuple4Tpe = tupleOf(4, DoubleClass.typeConstructor)
val StringTuple4Tpe = tupleOf(4, StringClass.typeConstructor)
Map(
IntClass.typeConstructor -> shrinkInt,
StringClass.typeConstructor -> shrinkString,
OptionClass.typeConstructor -> shrinkOption,
ListClass.typeConstructor -> shrinkList,
ArrayClass.typeConstructor -> shrinkArray,
SetClass.typeConstructor -> shrinkSet,
TupleClass(2).typeConstructor -> shrinkTuple2,
TupleClass(3).typeConstructor -> shrinkTuple3,
TupleClass(4).typeConstructor -> shrinkTuple4,
TupleClass(5).typeConstructor -> shrinkTuple5,
TupleClass(6).typeConstructor -> shrinkTuple6,
TupleClass(7).typeConstructor -> shrinkTuple7,
TupleClass(8).typeConstructor -> shrinkTuple8,
TupleClass(9).typeConstructor -> shrinkTuple9,
IntListTpe -> shrinkIntList,
BooleanListTpe -> shrinkBooleanList,
DoubleListTpe -> shrinkDoubleList,
StringListTpe -> shrinkStringList,
IntArrayTpe -> shrinkIntArray,
BooleanArrayTpe -> shrinkBooleanArray,
DoubleArrayTpe -> shrinkDoubleArray,
StringArrayTpe -> shrinkStringArray,
IntSetTpe -> shrinkIntSet,
BooleanSetTpe -> shrinkBooleanSet,
DoubleSetTpe -> shrinkDoubleSet,
StringSetTpe -> shrinkStringSet,
IntOptionTpe -> shrinkIntOption,
BooleanOptionTpe -> shrinkBooleanOption,
DoubleOptionTpe -> shrinkDoubleOption,
StringOptionTpe -> shrinkStringOption,
IntTuple2Tpe -> shrinkIntTuple2,
BooleanTuple2Tpe -> shrinkBooleanTuple2,
DoubleTuple2Tpe -> shrinkDoubleTuple2,
StringTuple2Tpe -> shrinkStringTuple2,
IntTuple3Tpe -> shrinkIntTuple3,
BooleanTuple3Tpe -> shrinkBooleanTuple3,
DoubleTuple3Tpe -> shrinkDoubleTuple3,
StringTuple3Tpe -> shrinkStringTuple3,
IntTuple4Tpe -> shrinkIntTuple4,
BooleanTuple4Tpe -> shrinkBooleanTuple4,
DoubleTuple4Tpe -> shrinkDoubleTuple4,
StringTuple4Tpe -> shrinkStringTuple4
)
}
}
object ConsoleReporter extends GenericScalaCheckModule {
/** Symbol for the <code>org.scalacheck.ConsoleReporter</code> module definition. */
override protected lazy val moduleSym = definitions.getModule("org.scalacheck.ConsoleReporter")
def testStatsEx(testRes: Tree): Tree = testStatsEx("", testRes)
def testStatsEx(msg: String, testRes: Tree): Tree =
Apply(select(moduleSym, "testStatsEx"), List(Literal(msg), testRes))
}
object Test extends GenericScalaCheckModule {
/** Symbol for the <code>org.scalacheck.Test</code> module definition. */
override protected lazy val moduleSym = definitions.getModule("org.scalacheck.Test")
def check(prop: Tree): Tree = moduleApply("check", prop)
def isPassed(res: Tree): Tree = Select(res, "passed")
}
}
package funcheck.scalacheck
import scala.tools.nsc.{Global, SubComponent}
trait ScalaCheckIntegrator extends ScalaCheck
with FilterGeneratorAnnotations
with GeneratorDefDefInjector
with ForAllTransformer
{
val global: Global
import global._
def createGeneratorDefDefs(unit: CompilationUnit): (List[DefDef], List[DefDef]) = {
val filteredGenTree = new FilterTreeTraverser(filterTreesWithGeneratorAnnotation(unit))
filteredGenTree.traverse(unit.body)
val klasses = collection.mutable.Set.empty[ClassDef]
val defs = collection.mutable.Set.empty[DefDef]
for {tree <- filteredGenTree.hits} tree match {
case c: ClassDef => klasses + c
case d: DefDef => defs + d
}
(Gen.createGenDefDef(klasses.toList,defs.toList), Arbitrary.getArbitraryDefDefs)
}
}
package funcheck.util
trait FreshNameCreator {
var fresh: scala.tools.nsc.util.FreshNameCreator
}
package scala.collection
trait Multiset[A] extends (A => Int) with Collection[A]{
/** Returns the number of elements in this multiset.
*
* @return number of multiset elements.
*/
def size: Int
/** This method allows multisets to be interpreted as predicates.
* It returns <code>0</code>, iff this multiset does not contain
* element <code>elem</code>, or <code>N</code> where <code>N</code>
* is the number of occurences of <code>elem</code> in this multiset.
*
* @param elem the element to check for membership.
* @return <code>0</code> iff <code>elem</code> is not contained in
* this multiset, or <code>N</code> where <code>N</code>
* is the number of occurences of <code>elem</code> in this
* multiset.
*/
def apply(elem: A): Int
/** Checks if this set contains element <code>elem</code>.
*
* @param elem the element to check for membership.
* @return <code>true</code> iff <code>elem</code> is not contained in
* this multiset.
*/
def contains(elem: A): Boolean = apply(elem) > 0
/** Checks if this multiset is empty.
*
* @return <code>true</code> iff there is no element in the multiset.
*/
override def isEmpty: Boolean = size == 0
/** Checks if this multiset is a subset of set <code>that</code>.
*
* @param that another multiset.
* @return <code>true</code> iff the other multiset is a superset of
* this multiset.
* todo: rename to isSubsetOf
*/
def subsetOf(that: Multiset[A]): Boolean =
forall(e => this(e) <= that(e))
/**
* This method is an alias for <code>intersect</code>. It computes an
* intersection with set that. It removes all the elements
* <code>that</code> are not present in that.
*/
def ** (that: Multiset[A]): Multiset[A]
/** @return this multiset as set. */
def asSet: Set[A]
//structural equality
/** Compares this multiset with another object and returns true, iff the
* other object is also a multiset which contains the same elements as
* this multiset, with the same cardinality.
*
* @param that the other object
* @note not necessarily run-time type safe.
* @return <code>true</code> iff this multiset and the other multiset
* contain the same elements, with same cardinality.
*/
override def equals(that: Any): Boolean = that match {
case other: Multiset[_] => other.size == this.size && subsetOf(other.asInstanceOf[Multiset[A]])
case _ => false
}
/** Defines the prefix of this object's <code>toString</code> representation.
*/
override protected def stringPrefix : String = "Multiset"
override def toString = elements.mkString(stringPrefix + "(", ", ", ")")
}
package scala.collection.immutable
class EmptyMultiset[A] extends Multiset[A] with Helper[A]{
def empty[C]: Multiset[C] = new EmptyMultiset[C]
override def size: Int = 0
override def apply(elem: A): Int = 0
override def contains(elem: A): Boolean = false
override def intersect (that: collection.Multiset[A]): Multiset[A] = empty
override def ++ (elems: Iterable[A]): Multiset[A] = iterable2multiset(elems)
override def +++ (elems: Iterable[A]): Multiset[A] = this ++ elems
override def --(elems: Iterable[A]): Multiset[A] = empty
override def elements: Iterator[A] = Iterator.empty
override def asSet: Set[A] = new EmptySet[A]
}
package scala.collection.immutable
object HashMultiset {
/** The empty multiset of this type. */
def empty[A]: Multiset[A] = new EmptyMultiset[A]
/** The canonical factory for this type */
def apply[A](elems: A*) = empty[A] ++ elems
}
class HashMultiset[A] private[immutable] (private val map: Map[A,Int]) extends Multiset[A] with Helper[A] {
def empty[C]: Multiset[C] = new EmptyMultiset[C]
override def size: Int = map.values.foldLeft(0)((a,b) => a+b)
override def apply(elem: A): Int = map.getOrElse(elem,0)
override def intersect (that: collection.Multiset[A]): Multiset[A] = {
def inner(entries: List[A], result: Map[A,Int]): Map[A,Int] = entries match {
case Nil => result
case elem :: rest => inner(rest, result.update(elem, min(this(elem),that(elem))))
}
new HashMultiset[A](inner(asSet.toList,new HashMap[A,Int].empty))
}
override def ++ (elems: Iterable[A]): Multiset[A] = {
val that = iterable2multiset(elems)
def inner(entries: List[A], result: Map[A,Int]): Map[A,Int] = entries match {
case Nil => result
case elem :: rest =>
inner(rest, result.update(elem,max(result.getOrElse(elem,0),that(elem))))
}
new HashMultiset[A](inner(that.asSet.toList, map))
}
override def +++ (elems: Iterable[A]): Multiset[A] = {
def inner(entries: List[A], result: Map[A,Int]): Map[A,Int] = entries match {
case Nil => result
case elem :: rest =>
inner(rest, result.update(elem,result.getOrElse(elem,0)+1))
}
new HashMultiset[A](inner(elems.toList,map))
}
override def --(elems: Iterable[A]): Multiset[A] = {
val that = iterable2multiset(elems)
def inner(entries: List[A], result: Map[A,Int]): Map[A,Int] = entries match {
case Nil => result
case elem :: rest =>
val diff = result.getOrElse(elem,0) - that(elem)
if(diff > 0)
inner(rest, result.update(elem,diff))
else
inner(rest, result - elem)
}
new HashMultiset[A](inner(that.asSet.toList,map))
}
override def elements: Iterator[A] = {
def inner(entries: List[A], result: List[A]): List[A] = entries match {
case Nil => result
case elem :: rest =>
inner(rest, result ::: int2list(elem, this(elem)))
}
inner(map.keys.toList, Nil).elements
}
override def asSet: Set[A] = Set.empty[A] ++ map.keys
}
package scala.collection.immutable
private[immutable] trait Helper[A] {
protected def int2list[C](elem: C, times: Int): List[C] = {
require(times >= 0)
if(times == 0)
Nil
else
elem :: int2list(elem,times-1)
} ensuring (res => res.size == times)
protected def iterable2multiset(elems: Iterable[A]): Multiset[A] = {
def inner(elems: List[A], result: Map[A,Int]): Map[A,Int] = elems match {
case Nil => result
case elem :: tail => inner(tail, result.update(elem, result.getOrElse(elem,0) + 1))
}
new HashMultiset[A](inner(elems.toList,new scala.collection.immutable.HashMap[A,Int].empty))
}
protected def min(a: Int, b: Int): Int = if(a <= b) a else b
protected def max(a: Int, b: Int): Int = if(a < b) b else a
}
package scala.collection.immutable
object Multiset {
/** The empty set of this type */
def empty[A]: Multiset[A] = new EmptyMultiset[A]
/** The canonical factory for this type */
def apply[A](elems: A*): Multiset[A] = empty[A] +++ elems
}
trait Multiset[A] extends AnyRef with collection.Multiset[A]{
/** This method is an alias for <code>intersect</code>.
* It computes an intersection with multiset <code>that</code>.
* It removes all the elements that are not present in <code>that</code>.
*
* @param that the multiset to intersect with
*/
final override def ** (that: collection.Multiset[A]): Multiset[A] = intersect(that)
/**
* This method computes an intersection with multiset that. It removes all
* the elements that are not present in that.
*/
def intersect (that: collection.Multiset[A]): Multiset[A]
// A U elems (max)
def ++ (elems: Iterable[A]): Multiset[A]
// A U elems (sum)
def +++ (elems: Iterable[A]): Multiset[A]
// A \ {elems}
def --(elems: Iterable[A]): Multiset[A]
// A U {elems}
final def + (elems: A*): Multiset[A] = this ++ elems
// A \ {elems}
final def - (elems: A*): Multiset[A] = this -- elems
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment