From 480fe17548bc2c18bcdb5822dd1848f356cae4be Mon Sep 17 00:00:00 2001
From: Rishi Sharma <rishi.sharma@epfl.ch>
Date: Wed, 11 Jan 2023 14:23:12 +0100
Subject: [PATCH] Add label distribution option

---
 src/decentralizepy/datasets/CIFAR10.py |  2 ++
 src/decentralizepy/datasets/Celeba.py  |  2 ++
 src/decentralizepy/datasets/Dataset.py | 15 +++++++++++++++
 src/decentralizepy/datasets/Femnist.py |  2 ++
 4 files changed, 21 insertions(+)

diff --git a/src/decentralizepy/datasets/CIFAR10.py b/src/decentralizepy/datasets/CIFAR10.py
index 22d7464..9b0029e 100644
--- a/src/decentralizepy/datasets/CIFAR10.py
+++ b/src/decentralizepy/datasets/CIFAR10.py
@@ -114,6 +114,8 @@ class CIFAR10(Dataset):
             test_batch_size,
         )
 
+        self.num_classes = NUM_CLASSES
+
         self.partition_niid = partition_niid
         self.shards = shards
         self.transform = transforms.Compose(
diff --git a/src/decentralizepy/datasets/Celeba.py b/src/decentralizepy/datasets/Celeba.py
index ab077b3..02b2571 100644
--- a/src/decentralizepy/datasets/Celeba.py
+++ b/src/decentralizepy/datasets/Celeba.py
@@ -230,6 +230,8 @@ class Celeba(Dataset):
         self.IMAGES_DIR = utils.conditional_value(images_dir, "", None)
         assert self.IMAGES_DIR != None
 
+        self.num_classes = NUM_CLASSES
+
         if self.__training__:
             self.load_trainset()
 
diff --git a/src/decentralizepy/datasets/Dataset.py b/src/decentralizepy/datasets/Dataset.py
index 18f736c..468b447 100644
--- a/src/decentralizepy/datasets/Dataset.py
+++ b/src/decentralizepy/datasets/Dataset.py
@@ -52,6 +52,7 @@ class Dataset:
         self.test_dir = utils.conditional_value(test_dir, "", None)
         self.sizes = utils.conditional_value(sizes, "", None)
         self.test_batch_size = utils.conditional_value(test_batch_size, "", 64)
+        self.num_classes = None
         if self.sizes:
             if type(self.sizes) == str:
                 self.sizes = eval(self.sizes)
@@ -66,6 +67,20 @@ class Dataset:
         else:
             self.__testing__ = False
 
+        self.label_distribution = None
+
+    def get_label_distribution(self):
+        """Return cached per-class training-label counts (classification only)."""
+        if self.label_distribution is None:
+            # Count into a local so a failure mid-pass never caches partial data.
+            counts = [0] * self.num_classes
+            for _, ys in self.get_trainset():
+                # Each label must expose .item() giving an index < num_classes.
+                for y in ys:
+                    counts[y.item()] += 1
+            self.label_distribution = counts
+        return self.label_distribution
+
     def get_trainset(self):
         """
         Function to get the training set
diff --git a/src/decentralizepy/datasets/Femnist.py b/src/decentralizepy/datasets/Femnist.py
index a7b2677..38c10bb 100644
--- a/src/decentralizepy/datasets/Femnist.py
+++ b/src/decentralizepy/datasets/Femnist.py
@@ -223,6 +223,8 @@ class Femnist(Dataset):
             test_batch_size,
         )
 
+        self.num_classes = NUM_CLASSES
+
         if self.__training__:
             self.load_trainset()
 
-- 
GitLab