From 46f88e9f5dbd33c5aa78f66da12a2732c75cd8da Mon Sep 17 00:00:00 2001
From: Rishi Sharma <rishi.sharma@epfl.ch>
Date: Thu, 6 Jan 2022 20:57:49 +0100
Subject: [PATCH] Deserialize later, add more comments

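Received neighbor models are now queued in their serialized (json dict)
form and only converted back to tensors when they are popped for
averaging, instead of being deserialized as soon as they arrive.
log_dir moves from Sharing to PartialModel, which is where shared_params
are written. Docstrings are added across the sharing, training and
utils modules.

A rough sketch of the per-key round trip done by Sharing.serialized_model
and Sharing.deserialized_model (the tensor below is just a stand-in):

    import json
    import numpy
    import torch

    val = torch.zeros(2, 3)                                  # one state_dict entry
    wire = json.dumps(val.numpy().tolist())                  # what gets sent to neighbors
    back = torch.from_numpy(numpy.array(json.loads(wire)))   # rebuilt at averaging time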
---
 src/decentralizepy/sharing/GrowingAlpha.py    |  1 -
 src/decentralizepy/sharing/PartialModel.py    | 24 ++++++-
 src/decentralizepy/sharing/Sharing.py         | 63 ++++++++++++++++--
 .../training/GradientAccumulator.py           | 20 +++++-
 src/decentralizepy/training/Training.py       | 13 +++-
 src/decentralizepy/utils.py                   | 44 +++++++++++++
 6 files changed, 155 insertions(+), 10 deletions(-)

diff --git a/src/decentralizepy/sharing/GrowingAlpha.py b/src/decentralizepy/sharing/GrowingAlpha.py
index afe8676..f587cf5 100644
--- a/src/decentralizepy/sharing/GrowingAlpha.py
+++ b/src/decentralizepy/sharing/GrowingAlpha.py
@@ -60,7 +60,6 @@ class GrowingAlpha(PartialModel):
                     self.graph,
                     self.model,
                     self.dataset,
-                    self.log_dir,
                 )
                 self.base.communication_round = self.communication_round
             self.base.step()
diff --git a/src/decentralizepy/sharing/PartialModel.py b/src/decentralizepy/sharing/PartialModel.py
index ed91e39..1df11c4 100644
--- a/src/decentralizepy/sharing/PartialModel.py
+++ b/src/decentralizepy/sharing/PartialModel.py
@@ -23,9 +23,31 @@ class PartialModel(Sharing):
         dict_ordered=True,
         save_shared=False,
     ):
+        """
+        Constructor
+        Parameters
+        ----------
+        rank : int
+            Local rank
+        machine_id : int
+            Global machine id
+        communication : decentralizepy.communication.Communication
+            Communication module used to send and receive messages
+        mapping : decentralizepy.mappings.Mapping
+            Mapping (rank, machine_id) -> uid
+        graph : decentralizepy.graphs.Graph
+            Graph representing neighbors
+        model : decentralizepy.models.Model
+            Model to train
+        dataset : decentralizepy.datasets.Dataset
+            Dataset for sharing data. Not implemented yet! TODO
+        log_dir : str
+            Location to write shared_params (only written for 2 procs per machine)
+        """
         super().__init__(
-            rank, machine_id, communication, mapping, graph, model, dataset, log_dir
+            rank, machine_id, communication, mapping, graph, model, dataset
         )
+        self.log_dir = log_dir
         self.alpha = alpha
         self.dict_ordered = dict_ordered
         self.save_shared = save_shared
diff --git a/src/decentralizepy/sharing/Sharing.py b/src/decentralizepy/sharing/Sharing.py
index 91d25d0..648e481 100644
--- a/src/decentralizepy/sharing/Sharing.py
+++ b/src/decentralizepy/sharing/Sharing.py
@@ -11,9 +11,26 @@ class Sharing:
     API defining who to share with and what, and what to do on receiving
     """
 
-    def __init__(
-        self, rank, machine_id, communication, mapping, graph, model, dataset, log_dir
-    ):
+    def __init__(self, rank, machine_id, communication, mapping, graph, model, dataset):
+        """
+        Constructor
+        Parameters
+        ----------
+        rank : int
+            Local rank
+        machine_id : int
+            Global machine id
+        communication : decentralizepy.communication.Communication
+            Communication module used to send and receive messages
+        mapping : decentralizepy.mappings.Mapping
+            Mapping (rank, machine_id) -> uid
+        graph : decentralizepy.graphs.Graph
+            Graph representing neighbors
+        model : decentralizepy.models.Model
+            Model to train
+        dataset : decentralizepy.datasets.Dataset
+            Dataset for sharing data. Not implemented yet! TODO
+        """
         self.rank = rank
         self.machine_id = machine_id
         self.uid = mapping.get_uid(rank, machine_id)
@@ -22,7 +39,6 @@ class Sharing:
         self.graph = graph
         self.model = model
         self.dataset = dataset
-        self.log_dir = log_dir
         self.communication_round = 0
 
         self.peer_deques = dict()
@@ -31,22 +47,58 @@
             self.peer_deques[n] = deque()
 
     def received_from_all(self):
+        """
+        Check if all neighbors have sent the current iteration
+        Returns
+        -------
+        bool
+            True if required data has been received, False otherwise
+        """
         for _, i in self.peer_deques.items():
             if len(i) == 0:
                 return False
         return True
 
     def get_neighbors(self, neighbors):
+        """
+        Choose which neighbors to share with
+        Parameters
+        ----------
+        neighbors : list(int)
+            List of all neighbors
+        Returns
+        -------
+        list(int)
+            Neighbors to share with
+        """
         # modify neighbors here
         return neighbors
 
     def serialized_model(self):
+        """
+        Convert the model to a json dict. Here we can choose how much to share
+        Returns
+        -------
+        dict
+            Model converted to json dict
+        """
         m = dict()
         for key, val in self.model.state_dict().items():
             m[key] = json.dumps(val.numpy().tolist())
         return m
 
     def deserialized_model(self, m):
+        """
+        Convert received json dict to state_dict.
+        Parameters
+        ----------
+        m : dict
+            json dict received
+        Returns
+        -------
+        state_dict
+            state_dict of the received model
+        """
         state_dict = dict()
         for key, value in m.items():
             state_dict[key] = torch.from_numpy(numpy.array(json.loads(value)))
@@ -67,7 +119,7 @@
             logging.debug("Received model from {}".format(sender))
             degree = data["degree"]
             del data["degree"]
-            self.peer_deques[sender].append((degree, self.deserialized_model(data)))
+            self.peer_deques[sender].append((degree, data))
             logging.debug("Deserialized received model from {}".format(sender))
 
         logging.info("Starting model averaging after receiving from all neighbors")
@@ -76,6 +128,7 @@
         for i, n in enumerate(self.peer_deques):
             logging.debug("Averaging model from neighbor {}".format(i))
             degree, data = self.peer_deques[n].popleft()
+            data = self.deserialized_model(data)
             weight = 1 / (max(len(self.peer_deques), degree) + 1)  # Metro-Hastings
             weight_total += weight
             for key, value in data.items():
diff --git a/src/decentralizepy/training/GradientAccumulator.py b/src/decentralizepy/training/GradientAccumulator.py
index 9ea42e6..66e1e5b 100644
--- a/src/decentralizepy/training/GradientAccumulator.py
+++ b/src/decentralizepy/training/GradientAccumulator.py
@@ -7,11 +7,29 @@ class GradientAccumulator(Training):
     def __init__(
         self, model, optimizer, loss, epochs_per_round="", batch_size="", shuffle=""
     ):
+        """
+        Constructor
+        Parameters
+        ----------
+        model : torch.nn.Module
+            Neural Network for training
+        optimizer : torch.optim
+            Optimizer to learn parameters
+        loss : function
+            Loss function
+        epochs_per_round : int, optional
+            Number of epochs per training call
+        batch_size : int, optional
+            Number of items to learn over, in one batch
+        shuffle : bool
+            True if the dataset should be shuffled before training. Not implemented yet! TODO
+        """
         super().__init__(model, optimizer, loss, epochs_per_round, batch_size, shuffle)
 
     def train(self, dataset):
         """
-        One training iteration with accumulation of gradients in model.accumulated_gradients
+        One training iteration with accumulation of gradients in model.accumulated_gradients.
+        Goes through the entire dataset.
         Parameters
         ----------
         dataset : decentralizepy.datasets.Dataset
diff --git a/src/decentralizepy/training/Training.py b/src/decentralizepy/training/Training.py
index 2b55820..6129b34 100644
--- a/src/decentralizepy/training/Training.py
+++ b/src/decentralizepy/training/Training.py
@@ -27,6 +27,8 @@ class Training:
             Number of epochs per training call
         batch_size : int, optional
             Number of items to learn over, in one batch
+        shuffle : bool
+            True if the dataset should be shuffled before training. Not implemented yet! TODO
         """
         self.model = model
         self.optimizer = optimizer
@@ -36,11 +38,18 @@ class Training:
         self.shuffle = utils.conditional_value(shuffle, "", False)
 
     def reset_optimizer(self, optimizer):
+        """
+        Replace the current optimizer with a new one
+        Parameters
+        ----------
+        optimizer : torch.optim
+            A new optimizer
+        """
         self.optimizer = optimizer
 
     def eval_loss(self, dataset):
         """
-        Evaluate the loss
+        Evaluate the loss on the training set
         Parameters
         ----------
         dataset : decentralizepy.datasets.Dataset
@@ -61,7 +70,7 @@ class Training:
 
     def train(self, dataset):
         """
-        One training iteration
+        One training iteration, going through the entire dataset
         Parameters
         ----------
         dataset : decentralizepy.datasets.Dataset
diff --git a/src/decentralizepy/utils.py b/src/decentralizepy/utils.py
index a6bab0d..76eec06 100644
--- a/src/decentralizepy/utils.py
+++ b/src/decentralizepy/utils.py
@@ -5,6 +5,21 @@ import os
 
 
 def conditional_value(var, nul, default):
+    """
+    Set the value to default if nul.
+    Parameters
+    ----------
+    var : any
+        The value
+    nul : any
+        The null value. Assigns default if var == nul
+    default : any
+        The default value
+    Returns
+    -------
+    type(var)
+        The final value
+    """
     if var != nul:
         return var
     else:
@@ -12,10 +27,30 @@ def conditional_value(var, nul, default):
 
 
 def remove_keys(d, keys_to_remove):
+    """
+    Removes given keys from the dict. Returns a new dict.
+    Parameters
+    ----------
+    d : dict
+        The initial dictionary
+    keys_to_remove : list
+        List of keys to remove from dict
+    Returns
+    -------
+    dict
+        A new dictionary with the given keys removed.
+    """
     return {key: d[key] for key in d if key not in keys_to_remove}
 
 
 def get_args():
+    """
+    Utility to parse arguments.
+    Returns
+    -------
+    args
+        Command line arguments
+    """
     parser = argparse.ArgumentParser()
     parser.add_argument("-mid", "--machine_id", type=int, default=0)
     parser.add_argument("-ps", "--procs_per_machine", type=int, default=1)
@@ -38,6 +73,15 @@ def get_args():
 
 
 def write_args(args, path):
+    """
+    Write arguments to a json file
+    Parameters
+    ----------
+    args : args
+        Command line args
+    path : str
+        Location of the file to write to
+    """
     data = {
         "machine_id": args.machine_id,
         "procs_per_machine": args.procs_per_machine,
-- 
GitLab