From 4f05cc01f01dcfba63d6612845e30e05ed1c0810 Mon Sep 17 00:00:00 2001
From: Rishi Sharma <rishi.sharma@epfl.ch>
Date: Fri, 1 Apr 2022 20:33:43 +0200
Subject: [PATCH] Add LeNet, Add Sharding

---
 src/decentralizepy/datasets/CIFAR10.py     | 59 +++++++++++++++++++---
 src/decentralizepy/datasets/Partitioner.py | 47 +++++++++++++++++
 2 files changed, 99 insertions(+), 7 deletions(-)

diff --git a/src/decentralizepy/datasets/CIFAR10.py b/src/decentralizepy/datasets/CIFAR10.py
index 5519d9b..22d7464 100644
--- a/src/decentralizepy/datasets/CIFAR10.py
+++ b/src/decentralizepy/datasets/CIFAR10.py
@@ -1,7 +1,5 @@
 import logging
-import os
 
-import numpy as np
 import torch
 import torch.nn.functional as F
 import torchvision
@@ -9,9 +7,8 @@ import torchvision.transforms as transforms
 from torch import nn
 from torch.utils.data import DataLoader
 
-from decentralizepy.datasets.Data import Data
 from decentralizepy.datasets.Dataset import Dataset
-from decentralizepy.datasets.Partitioner import DataPartitioner, SimpleDataPartitioner
+from decentralizepy.datasets.Partitioner import DataPartitioner, KShardDataPartitioner
 from decentralizepy.mappings.Mapping import Mapping
 from decentralizepy.models.Model import Model
 
@@ -53,9 +50,9 @@ class CIFAR10(Dataset):
             all_trainset = []
             for y, x in train_data.items():
                 all_trainset.extend([(a, y) for a in x])
-            self.trainset = SimpleDataPartitioner(all_trainset, self.sizes).use(
-                self.uid
-            )
+            self.trainset = KShardDataPartitioner(
+                all_trainset, self.sizes, shards=self.shards
+            ).use(self.uid)
 
     def load_testset(self):
         """
@@ -79,6 +76,7 @@ class CIFAR10(Dataset):
         sizes="",
         test_batch_size=1024,
         partition_niid=False,
+        shards=1,
     ):
         """
         Constructor which reads the data files, instantiates and partitions the dataset
@@ -117,6 +115,7 @@ class CIFAR10(Dataset):
         )
 
         self.partition_niid = partition_niid
+        self.shards = shards
         self.transform = transforms.Compose(
             [
                 transforms.ToTensor(),
@@ -274,3 +273,49 @@ class CNN(Model):
         x = F.relu(self.fc2(x))
         x = self.fc3(x)
         return x
+
+
+class LeNet(Model):
+    """
+    Class for a LeNet Model for CIFAR10
+    Inspired by original LeNet network for MNIST: https://ieeexplore.ieee.org/abstract/document/726791
+
+    """
+
+    def __init__(self):
+        """
+        Constructor. Instantiates the LeNet Model
+            with NUM_CLASSES output classes
+
+        """
+        super().__init__()
+        self.conv1 = nn.Conv2d(3, 32, 5, padding="same")
+        self.pool = nn.MaxPool2d(2, 2)
+        self.gn1 = nn.GroupNorm(2, 32)
+        self.conv2 = nn.Conv2d(32, 32, 5, padding="same")
+        self.gn2 = nn.GroupNorm(2, 32)
+        self.conv3 = nn.Conv2d(32, 64, 5, padding="same")
+        self.gn3 = nn.GroupNorm(2, 64)
+        self.fc1 = nn.Linear(64 * 4 * 4, NUM_CLASSES)  # 4x4 maps: assumes 32x32 input
+
+    def forward(self, x):
+        """
+        Forward pass of the model
+
+        Parameters
+        ----------
+        x : torch.tensor
+            The input torch tensor
+
+        Returns
+        -------
+        torch.tensor
+            The output torch tensor
+
+        """
+        x = self.pool(F.relu(self.gn1(self.conv1(x))))
+        x = self.pool(F.relu(self.gn2(self.conv2(x))))
+        x = self.pool(F.relu(self.gn3(self.conv3(x))))
+        x = torch.flatten(x, 1)
+        x = self.fc1(x)
+        return x
diff --git a/src/decentralizepy/datasets/Partitioner.py b/src/decentralizepy/datasets/Partitioner.py
index 0d9710f..e93c456 100644
--- a/src/decentralizepy/datasets/Partitioner.py
+++ b/src/decentralizepy/datasets/Partitioner.py
@@ -131,3 +131,50 @@ class SimpleDataPartitioner(DataPartitioner):
             part_len = int(frac * data_len)
             self.partitions.append(indexes[0:part_len])
             indexes = indexes[part_len:]
+
+
+class KShardDataPartitioner(DataPartitioner):
+    """
+    Class to partition the dataset into shards
+    Each process gets `shards` contiguous slices of the unassigned indexes,
+    each starting at a seeded random offset and wrapping around the end
+
+    """
+
+    def __init__(self, data, sizes=[1.0], shards=1, seed=1234):
+        """
+        Constructor. Partitions the data according to the parameters
+
+        Parameters
+        ----------
+        data : indexable
+            An indexable list of data items
+        sizes : list(float)
+            A list of fractions for each process
+        shards : int
+            Number of shards to allot to each process
+        seed : int, optional
+            Seed for generating a random subset
+
+        """
+        self.data = data
+        self.partitions = []
+        data_len = len(data)
+        indexes = list(range(data_len))
+        rng = Random()
+        rng.seed(seed)
+
+        for frac in sizes:
+            self.partitions.append([])
+            for _ in range(shards):
+                start = rng.randint(0, len(indexes) - 1)
+                part_len = int(frac * data_len) // shards
+                end = start + part_len
+                # Items of a shard that runs past the end are taken from the
+                # front instead. Clamping to `start` keeps an oversized shard
+                # from picking up the same index twice.
+                wrap = min(max(end - len(indexes), 0), start)
+                self.partitions[-1].extend(indexes[start:end])
+                self.partitions[-1].extend(indexes[:wrap])
+                # Drop the used indexes: [0, wrap) and [start, end)
+                indexes = indexes[wrap:start] + indexes[end:]
-- 
GitLab