From 806a5638679b8a5a586adf7f69012faa4619b29e Mon Sep 17 00:00:00 2001
From: Rishi Sharma <33762894+rishi-s8@users.noreply.github.com>
Date: Thu, 10 Feb 2022 11:41:57 +0100
Subject: [PATCH] ChangeAccumulator

---
 README.rst                                    |   1 +
 eval/main.ipynb                               |  24 ++-
 src/decentralizepy/node/Node.py               |   2 +-
 .../training/ChangeAccumulator.py             | 148 ++++++++++++++++++
 .../training/GradientAccumulator.py           |   5 +-
 src/decentralizepy/training/Training.py       |   4 +
 6 files changed, 180 insertions(+), 4 deletions(-)
 create mode 100644 src/decentralizepy/training/ChangeAccumulator.py

diff --git a/README.rst b/README.rst
index 38cba02..5acb1c4 100644
--- a/README.rst
+++ b/README.rst
@@ -15,6 +15,7 @@ Setting up decentralizepy
     pip3 install --upgrade pip
     pip install --upgrade pip
 
+* On Mac M1, installing ``pyzmq`` fails with `pip`. Use ``conda``.
 * Install decentralizepy for development. ::
 
     pip3 install --editable .\[dev\]
diff --git a/eval/main.ipynb b/eval/main.ipynb
index db66e89..80daae6 100644
--- a/eval/main.ipynb
+++ b/eval/main.ipynb
@@ -268,7 +268,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 24,
+   "execution_count": 1,
    "metadata": {},
    "outputs": [
     {
@@ -300,6 +300,26 @@
     "\n"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "<generator object Module.parameters at 0x111ec1ba0>"
+      ]
+     },
+     "execution_count": 2,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "m1.parameters()"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 25,
@@ -5715,7 +5735,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.9.7"
+   "version": "3.8.11"
   },
   "orig_nbformat": 4
  },
diff --git a/src/decentralizepy/node/Node.py b/src/decentralizepy/node/Node.py
index ef66d17..24c3168 100644
--- a/src/decentralizepy/node/Node.py
+++ b/src/decentralizepy/node/Node.py
@@ -177,7 +177,7 @@ class Node:
             ],
         )
         self.trainer = train_class(
-            self.model, self.optimizer, self.loss, **train_params
+            self.model, self.optimizer, self.loss, self.log_dir, **train_params
         )
 
     def init_comm(self, comm_configs):
diff --git a/src/decentralizepy/training/ChangeAccumulator.py b/src/decentralizepy/training/ChangeAccumulator.py
new file mode 100644
index 0000000..cf9641f
--- /dev/null
+++ b/src/decentralizepy/training/ChangeAccumulator.py
@@ -0,0 +1,148 @@
+import json
+import os
+from pathlib import Path
+
+import torch
+
+from decentralizepy.training.Training import Training
+from decentralizepy.utils import conditional_value
+
+
+class ChangeAccumulator(Training):
+    """
+    This class implements the training module which also accumulates model change in a list.
+
+    """
+
+    def __init__(
+        self,
+        model,
+        optimizer,
+        loss,
+        log_dir,
+        rounds="",
+        full_epochs="",
+        batch_size="",
+        shuffle="",
+        save_accumulated="",
+    ):
+        """
+        Constructor
+
+        Parameters
+        ----------
+        model : torch.nn.Module
+            Neural Network for training
+        optimizer : torch.optim
+            Optimizer to learn parameters
+        loss : function
+            Loss function
+        log_dir : str
+            Directory to log the model change.
+        rounds : int, optional
+            Number of steps/epochs per training call
+        full_epochs: bool, optional
+            True if 1 round = 1 epoch. False if 1 round = 1 minibatch
+        batch_size : int, optional
+            Number of items to learn over, in one batch
+        shuffle : bool
+            True if the dataset should be shuffled before training.
+        save_accumulated : bool
+            True if accumulated weight change should be written to file
+
+        """
+        super().__init__(
+            model, optimizer, loss, log_dir, rounds, full_epochs, batch_size, shuffle
+        )
+        self.save_accumulated = conditional_value(save_accumulated, "", True)
+        self.communication_round = 0
+        if self.save_accumulated:
+            self.model_change_path = os.path.join(
+                self.log_dir, "model_change/{}".format(self.rank)
+            )
+            Path(self.model_change_path).mkdir(parents=True, exist_ok=True)
+
+            self.model_val_path = os.path.join(
+                self.log_dir, "model_val/{}".format(self.rank)
+            )
+            Path(self.model_val_path).mkdir(parents=True, exist_ok=True)
+
+    def save_vector(self, v, s):
+        """
+        Saves the given vector to the file.
+
+        Parameters
+        ----------
+        v : torch.tensor
+            The torch tensor to write to file
+        s : str
+            Path to folder to write to
+
+        """
+        output_dict = dict()
+        output_dict["order"] = list(self.model.state_dict().keys())
+        shapes = dict()
+        for k, v in self.model.state_dict().items():
+            shapes[k] = list(v.shape)
+        output_dict["shapes"] = shapes
+
+        output_dict[self.communication_round] = v.tolist()
+
+        with open(
+            os.path.join(
+                s,
+                "{}.json".format(self.communication_round + 1),
+            ),
+            "w",
+        ) as of:
+            json.dump(output_dict, of)
+
+    def save_change(self):
+        """
+        Saves the change and the gradient values for every iteration
+
+        """
+        tensors_to_cat = [
+            v.data.flatten() for _, v in self.model.accumulated_gradients[0].items()
+        ]
+        change = torch.abs(torch.cat(tensors_to_cat, dim=0))
+        self.save_vector(change, self.model_change_path)
+
+    def save_model_params(self):
+        """
+        Saves the change and the gradient values for every iteration
+
+        """
+        tensors_to_cat = [v.data.flatten() for _, v in self.model.items()]
+        params = torch.abs(torch.cat(tensors_to_cat, dim=0))
+        self.save_vector(params, self.model_val_path)
+
+    def train(self, dataset):
+        """
+        One training iteration with accumulation of model change in model.accumulated_gradients.
+        Goes through the entire dataset.
+
+        Parameters
+        ----------
+        dataset : decentralizepy.datasets.Dataset
+            The training dataset. Should implement get_trainset(batch_size, shuffle)
+
+        """
+        self.model.accumulated_gradients = []
+        self.init_model = {
+            k: v.data.clone().detach()
+            for k, v in zip(self.model.state_dict(), self.model.parameters())
+        }
+        super().train(dataset)
+        with torch.no_grad():
+            change = {
+                k: v.data.clone().detach() - self.init_model[k]
+                for k, v in zip(self.model.state_dict(), self.model.parameters())
+            }
+            self.model.accumulated_gradients.append(change)
+
+            if self.save_accumulated:
+                self.save_change()
+                self.save_model_params()
+
+        self.communication_round += 1
diff --git a/src/decentralizepy/training/GradientAccumulator.py b/src/decentralizepy/training/GradientAccumulator.py
index 3c16059..3171019 100644
--- a/src/decentralizepy/training/GradientAccumulator.py
+++ b/src/decentralizepy/training/GradientAccumulator.py
@@ -14,6 +14,7 @@ class GradientAccumulator(Training):
         model,
         optimizer,
         loss,
+        log_dir,
         rounds="",
         full_epochs="",
         batch_size="",
@@ -30,6 +31,8 @@ class GradientAccumulator(Training):
             Optimizer to learn parameters
         loss : function
             Loss function
+        log_dir : str
+            Directory to log the model change.
         rounds : int, optional
             Number of steps/epochs per training call
         full_epochs: bool, optional
@@ -41,7 +44,7 @@ class GradientAccumulator(Training):
 
         """
         super().__init__(
-            model, optimizer, loss, rounds, full_epochs, batch_size, shuffle
+            model, optimizer, loss, log_dir, rounds, full_epochs, batch_size, shuffle
         )
 
     def trainstep(self, data, target):
diff --git a/src/decentralizepy/training/Training.py b/src/decentralizepy/training/Training.py
index 6696f35..47c8f77 100644
--- a/src/decentralizepy/training/Training.py
+++ b/src/decentralizepy/training/Training.py
@@ -16,6 +16,7 @@ class Training:
         model,
         optimizer,
         loss,
+        log_dir,
         rounds="",
         full_epochs="",
         batch_size="",
@@ -32,6 +33,8 @@ class Training:
             Optimizer to learn parameters
         loss : function
             Loss function
+        log_dir : str
+            Directory to log the model change.
         rounds : int, optional
             Number of steps/epochs per training call
         full_epochs: bool, optional
@@ -45,6 +48,7 @@ class Training:
         self.model = model
         self.optimizer = optimizer
         self.loss = loss
+        self.log_dir = log_dir
         self.rounds = utils.conditional_value(rounds, "", int(1))
         self.full_epochs = utils.conditional_value(full_epochs, "", False)
         self.batch_size = utils.conditional_value(batch_size, "", int(1))
-- 
GitLab