Skip to content
Snippets Groups Projects
Commit 480fe175 authored by Rishi Sharma's avatar Rishi Sharma
Browse files

Add label distribution option

parent b2098c44
No related branches found
No related tags found
No related merge requests found
......@@ -114,6 +114,8 @@ class CIFAR10(Dataset):
test_batch_size,
)
self.num_classes = NUM_CLASSES
self.partition_niid = partition_niid
self.shards = shards
self.transform = transforms.Compose(
......
......@@ -230,6 +230,8 @@ class Celeba(Dataset):
self.IMAGES_DIR = utils.conditional_value(images_dir, "", None)
assert self.IMAGES_DIR != None
self.num_classes = NUM_CLASSES
if self.__training__:
self.load_trainset()
......
......@@ -52,6 +52,7 @@ class Dataset:
self.test_dir = utils.conditional_value(test_dir, "", None)
self.sizes = utils.conditional_value(sizes, "", None)
self.test_batch_size = utils.conditional_value(test_batch_size, "", 64)
self.num_classes = None
if self.sizes:
if type(self.sizes) == str:
self.sizes = eval(self.sizes)
......@@ -66,6 +67,20 @@ class Dataset:
else:
self.__testing__ = False
self.label_distribution = None
def get_label_distribution(self):
# Only supported for classification
if self.label_distribution == None:
self.label_distribution = [0 for _ in range(self.num_classes)]
tr_set = self.get_trainset()
for _, ys in tr_set:
for y in ys:
y_val = y.item()
self.label_distribution[y_val] += 1
return self.label_distribution
def get_trainset(self):
"""
Function to get the training set
......
......@@ -223,6 +223,8 @@ class Femnist(Dataset):
test_batch_size,
)
self.num_classes = NUM_CLASSES
if self.__training__:
self.load_trainset()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment