diff --git a/src/decentralizepy/datasets/CIFAR10.py b/src/decentralizepy/datasets/CIFAR10.py index 22d7464efbd4a03d18bc581c114061af92f2c769..9b0029e04b10425782af2c78a347074b7d9c6a35 100644 --- a/src/decentralizepy/datasets/CIFAR10.py +++ b/src/decentralizepy/datasets/CIFAR10.py @@ -114,6 +114,8 @@ class CIFAR10(Dataset): test_batch_size, ) + self.num_classes = NUM_CLASSES + self.partition_niid = partition_niid self.shards = shards self.transform = transforms.Compose( diff --git a/src/decentralizepy/datasets/Celeba.py b/src/decentralizepy/datasets/Celeba.py index ab077b378422f58407e738e4367ab3877c183bd6..02b2571d521876562e12ebc4bdf3cf55bbd3ff1e 100644 --- a/src/decentralizepy/datasets/Celeba.py +++ b/src/decentralizepy/datasets/Celeba.py @@ -230,6 +230,8 @@ class Celeba(Dataset): self.IMAGES_DIR = utils.conditional_value(images_dir, "", None) assert self.IMAGES_DIR != None + self.num_classes = NUM_CLASSES + if self.__training__: self.load_trainset() diff --git a/src/decentralizepy/datasets/Dataset.py b/src/decentralizepy/datasets/Dataset.py index 18f736cffc887406207713851ec9c3a1dbab26a9..468b4474fc2008f71c26a7deb870a14dac599077 100644 --- a/src/decentralizepy/datasets/Dataset.py +++ b/src/decentralizepy/datasets/Dataset.py @@ -52,6 +52,7 @@ class Dataset: self.test_dir = utils.conditional_value(test_dir, "", None) self.sizes = utils.conditional_value(sizes, "", None) self.test_batch_size = utils.conditional_value(test_batch_size, "", 64) + self.num_classes = None if self.sizes: if type(self.sizes) == str: self.sizes = eval(self.sizes) @@ -66,6 +67,20 @@ class Dataset: else: self.__testing__ = False + self.label_distribution = None + + def get_label_distribution(self): + # Only supported for classification + if self.label_distribution == None: + self.label_distribution = [0 for _ in range(self.num_classes)] + tr_set = self.get_trainset() + for _, ys in tr_set: + for y in ys: + y_val = y.item() + self.label_distribution[y_val] += 1 + + return self.label_distribution + def get_trainset(self): """ Function to get the training set diff --git a/src/decentralizepy/datasets/Femnist.py b/src/decentralizepy/datasets/Femnist.py index a7b267752281dafda454a891be2e60145a98f8ad..38c10bb4a4514b58ea4989d0206a19a80d3b11fb 100644 --- a/src/decentralizepy/datasets/Femnist.py +++ b/src/decentralizepy/datasets/Femnist.py @@ -223,6 +223,8 @@ class Femnist(Dataset): test_batch_size, ) + self.num_classes = NUM_CLASSES + if self.__training__: self.load_trainset()