diff --git a/src/decentralizepy/datasets/CIFAR10.py b/src/decentralizepy/datasets/CIFAR10.py
index 22d7464efbd4a03d18bc581c114061af92f2c769..9b0029e04b10425782af2c78a347074b7d9c6a35 100644
--- a/src/decentralizepy/datasets/CIFAR10.py
+++ b/src/decentralizepy/datasets/CIFAR10.py
@@ -114,6 +114,8 @@ class CIFAR10(Dataset):
             test_batch_size,
         )
 
+        self.num_classes = NUM_CLASSES
+
         self.partition_niid = partition_niid
         self.shards = shards
         self.transform = transforms.Compose(
diff --git a/src/decentralizepy/datasets/Celeba.py b/src/decentralizepy/datasets/Celeba.py
index ab077b378422f58407e738e4367ab3877c183bd6..02b2571d521876562e12ebc4bdf3cf55bbd3ff1e 100644
--- a/src/decentralizepy/datasets/Celeba.py
+++ b/src/decentralizepy/datasets/Celeba.py
@@ -230,6 +230,8 @@ class Celeba(Dataset):
         self.IMAGES_DIR = utils.conditional_value(images_dir, "", None)
         assert self.IMAGES_DIR != None
 
+        self.num_classes = NUM_CLASSES
+
         if self.__training__:
             self.load_trainset()
 
diff --git a/src/decentralizepy/datasets/Dataset.py b/src/decentralizepy/datasets/Dataset.py
index 18f736cffc887406207713851ec9c3a1dbab26a9..468b4474fc2008f71c26a7deb870a14dac599077 100644
--- a/src/decentralizepy/datasets/Dataset.py
+++ b/src/decentralizepy/datasets/Dataset.py
@@ -52,6 +52,7 @@ class Dataset:
         self.test_dir = utils.conditional_value(test_dir, "", None)
         self.sizes = utils.conditional_value(sizes, "", None)
         self.test_batch_size = utils.conditional_value(test_batch_size, "", 64)
+        self.num_classes = None
         if self.sizes:
             if type(self.sizes) == str:
                 self.sizes = eval(self.sizes)
@@ -66,6 +67,20 @@ class Dataset:
         else:
             self.__testing__ = False
 
+        self.label_distribution = None
+
+    def get_label_distribution(self):
+        # Only supported for classification
+        if self.label_distribution == None:
+            self.label_distribution = [0 for _ in range(self.num_classes)]
+            tr_set = self.get_trainset()
+            for _, ys in tr_set:
+                for y in ys:
+                    y_val = y.item()
+                    self.label_distribution[y_val] += 1
+
+        return self.label_distribution
+
     def get_trainset(self):
         """
         Function to get the training set
diff --git a/src/decentralizepy/datasets/Femnist.py b/src/decentralizepy/datasets/Femnist.py
index a7b267752281dafda454a891be2e60145a98f8ad..38c10bb4a4514b58ea4989d0206a19a80d3b11fb 100644
--- a/src/decentralizepy/datasets/Femnist.py
+++ b/src/decentralizepy/datasets/Femnist.py
@@ -223,6 +223,8 @@ class Femnist(Dataset):
             test_batch_size,
         )
 
+        self.num_classes = NUM_CLASSES
+
         if self.__training__:
             self.load_trainset()