From 0d50bc39e7eff29e0a244d5a3f5c7d41a603108a Mon Sep 17 00:00:00 2001
From: Rishi Sharma <rishi.sharma@epfl.ch>
Date: Thu, 20 Jan 2022 13:18:28 +0100
Subject: [PATCH] GradientAccumulator migration to steps

---
 .../training/GradientAccumulator.py           | 67 +++++++++++++------
 src/decentralizepy/training/Training.py       | 15 +++--
 2 files changed, 53 insertions(+), 29 deletions(-)

diff --git a/src/decentralizepy/training/GradientAccumulator.py b/src/decentralizepy/training/GradientAccumulator.py
index 66e1e5b..23d4fdc 100644
--- a/src/decentralizepy/training/GradientAccumulator.py
+++ b/src/decentralizepy/training/GradientAccumulator.py
@@ -26,35 +26,58 @@ class GradientAccumulator(Training):
         """
         super().__init__(model, optimizer, loss, epochs_per_round, batch_size, shuffle)
 
-    def train(self, dataset):
+    def trainstep(self, data, target):
         """
-        One training iteration with accumulation of gradients in model.accumulated_gradients.
-        Goes through the entire dataset.
+        One training step on a minibatch.
         Parameters
         ----------
-        dataset : decentralizepy.datasets.Dataset
-            The training dataset. Should implement get_trainset(batch_size, shuffle)
+        data : any
+            Data item
+        target : any
+            Label
+        Returns
+        -------
+        float
+            Loss value for the step
         """
-        trainset = dataset.get_trainset(self.batch_size, self.shuffle)
-        self.model.accumulated_gradients = []
+        self.model.zero_grad()
+        output = self.model(data)
+        loss_val = self.loss(output, target)
+        loss_val.backward()
+        logging.debug("Accumulating Gradients")
+        self.model.accumulated_gradients.append(
+            {
+                k: v.grad.clone().detach()
+                for k, v in zip(self.model.state_dict(), self.model.parameters())
+            }
+        )
+        self.optimizer.step()
+        return loss_val.item()
 
-        for epoch in range(self.epochs_per_round):
+    def train_full(self, trainset):
+        """
+        One training iteration, goes through the entire dataset
+        Parameters
+        ----------
+        trainset : torch.utils.data.DataLoader
+            The training dataset.
+        """
+        for epoch in range(self.rounds):
             epoch_loss = 0.0
             count = 0
             for data, target in trainset:
-                self.model.zero_grad()
-                output = self.model(data)
-                loss_val = self.loss(output, target)
-                epoch_loss += loss_val.item()
-                loss_val.backward()
-                self.model.accumulated_gradients.append(
-                    {
-                        k: v.grad.clone().detach()
-                        for k, v in zip(
-                            self.model.state_dict(), self.model.parameters()
-                        )
-                    }
-                )
-                self.optimizer.step()
+                epoch_loss += self.trainstep(data, target)
                 count += 1
             logging.info("Epoch: {} loss: {}".format(epoch, epoch_loss / count))
+
+    def train(self, dataset):
+        """
+        One training iteration with accumulation of gradients in model.accumulated_gradients.
+        Goes through the entire dataset.
+        Parameters
+        ----------
+        dataset : decentralizepy.datasets.Dataset
+            The training dataset. Should implement get_trainset(batch_size, shuffle)
+        """
+        self.model.accumulated_gradients = []
+        super().train(dataset)
diff --git a/src/decentralizepy/training/Training.py b/src/decentralizepy/training/Training.py
index a7fb3b1..d5b3e9c 100644
--- a/src/decentralizepy/training/Training.py
+++ b/src/decentralizepy/training/Training.py
@@ -107,7 +107,7 @@ class Training:
         trainset : torch.utils.data.Dataloader
             The training dataset.
         """
-        for epoch in range(self.epochs_per_round):
+        for epoch in range(self.rounds):
             epoch_loss = 0.0
             count = 0
             for data, target in trainset:
@@ -130,9 +130,10 @@ class Training:
         else:
             iter_loss = 0.0
             count = 0
-            for data, target in trainset:
-                iter_loss += self.trainstep(data, target)
-                count += 1
-                logging.info("Round: {} loss: {}".format(count, iter_loss / count))
-                if count >= self.rounds:
-                    break
+            while count < self.rounds:
+                for data, target in trainset:
+                    iter_loss += self.trainstep(data, target)
+                    count += 1
+                    logging.info("Round: {} loss: {}".format(count, iter_loss / count))
+                    if count >= self.rounds:
+                        break
-- 
GitLab