From 480fe17548bc2c18bcdb5822dd1848f356cae4be Mon Sep 17 00:00:00 2001 From: Rishi Sharma <rishi.sharma@epfl.ch> Date: Wed, 11 Jan 2023 14:23:12 +0100 Subject: [PATCH] Add label distribution option --- src/decentralizepy/datasets/CIFAR10.py | 2 ++ src/decentralizepy/datasets/Celeba.py | 2 ++ src/decentralizepy/datasets/Dataset.py | 15 +++++++++++++++ src/decentralizepy/datasets/Femnist.py | 2 ++ 4 files changed, 21 insertions(+) diff --git a/src/decentralizepy/datasets/CIFAR10.py b/src/decentralizepy/datasets/CIFAR10.py index 22d7464..9b0029e 100644 --- a/src/decentralizepy/datasets/CIFAR10.py +++ b/src/decentralizepy/datasets/CIFAR10.py @@ -114,6 +114,8 @@ class CIFAR10(Dataset): test_batch_size, ) + self.num_classes = NUM_CLASSES + self.partition_niid = partition_niid self.shards = shards self.transform = transforms.Compose( diff --git a/src/decentralizepy/datasets/Celeba.py b/src/decentralizepy/datasets/Celeba.py index ab077b3..02b2571 100644 --- a/src/decentralizepy/datasets/Celeba.py +++ b/src/decentralizepy/datasets/Celeba.py @@ -230,6 +230,8 @@ class Celeba(Dataset): self.IMAGES_DIR = utils.conditional_value(images_dir, "", None) assert self.IMAGES_DIR != None + self.num_classes = NUM_CLASSES + if self.__training__: self.load_trainset() diff --git a/src/decentralizepy/datasets/Dataset.py b/src/decentralizepy/datasets/Dataset.py index 18f736c..468b447 100644 --- a/src/decentralizepy/datasets/Dataset.py +++ b/src/decentralizepy/datasets/Dataset.py @@ -52,6 +52,7 @@ class Dataset: self.test_dir = utils.conditional_value(test_dir, "", None) self.sizes = utils.conditional_value(sizes, "", None) self.test_batch_size = utils.conditional_value(test_batch_size, "", 64) + self.num_classes = None if self.sizes: if type(self.sizes) == str: self.sizes = eval(self.sizes) @@ -66,6 +67,20 @@ class Dataset: else: self.__testing__ = False + self.label_distribution = None + + def get_label_distribution(self): + # Only supported for classification + if self.label_distribution == None: + self.label_distribution = [0 for _ in range(self.num_classes)] + tr_set = self.get_trainset() + for _, ys in tr_set: + for y in ys: + y_val = y.item() + self.label_distribution[y_val] += 1 + + return self.label_distribution + def get_trainset(self): """ Function to get the training set diff --git a/src/decentralizepy/datasets/Femnist.py b/src/decentralizepy/datasets/Femnist.py index a7b2677..38c10bb 100644 --- a/src/decentralizepy/datasets/Femnist.py +++ b/src/decentralizepy/datasets/Femnist.py @@ -223,6 +223,8 @@ class Femnist(Dataset): test_batch_size, ) + self.num_classes = NUM_CLASSES + if self.__training__: self.load_trainset() -- GitLab