diff --git a/src/decentralizepy/sharing/GrowingAlpha.py b/src/decentralizepy/sharing/GrowingAlpha.py
index afe86768ededc073de60518373192513184ec03d..f587cf5f73a667b3d66ad19d0282a7eaaef8c9ef 100644
--- a/src/decentralizepy/sharing/GrowingAlpha.py
+++ b/src/decentralizepy/sharing/GrowingAlpha.py
@@ -60,7 +60,7 @@ class GrowingAlpha(PartialModel):
                     self.graph,
                     self.model,
                     self.dataset,
-                    self.log_dir,
                 )
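+                # log_dir is no longer passed; Sharing's constructor does not take it anymore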
                 self.base.communication_round = self.communication_round
             self.base.step()
diff --git a/src/decentralizepy/sharing/PartialModel.py b/src/decentralizepy/sharing/PartialModel.py
index ed91e39630aa1be2e6bec6874041ac226443c20a..1df11c4871ebeedb0ef4cb805e2c5637e996cb0c 100644
--- a/src/decentralizepy/sharing/PartialModel.py
+++ b/src/decentralizepy/sharing/PartialModel.py
@@ -23,9 +23,38 @@ class PartialModel(Sharing):
         dict_ordered=True,
         save_shared=False,
     ):
+        """
+        Constructor
+        Parameters
+        ----------
+        rank : int
+            Local rank
+        machine_id : int
+            Global machine id
+        communication : decentralizepy.communication.Communication
+            Communication module used to send and receive messages
+        mapping : decentralizepy.mappings.Mapping
+            Mapping (rank, machine_id) -> uid
+        graph : decentralizepy.graphs.Graph
+            Graph representing neighbors
+        model : decentralizepy.models.Model
+            Model to train
+        dataset : decentralizepy.datasets.Dataset
+            Dataset for sharing data. Not implemented yet! TODO
+        log_dir : str
+            Location to write shared_params (only written for 2 procs per machine)
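+        alpha : float
+            Percentage of model to share
+        dict_ordered : bool
+            Specifies if the python dict maintains the order of insertion
+        save_shared : bool
+            Specifies if the shared parameters should be saved to log_dir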
+        """
         super().__init__(
-            rank, machine_id, communication, mapping, graph, model, dataset, log_dir
+            rank, machine_id, communication, mapping, graph, model, dataset
         )
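+        # log_dir moved here from Sharing; only PartialModel writes shared_params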
+        self.log_dir = log_dir
         self.alpha = alpha
         self.dict_ordered = dict_ordered
         self.save_shared = save_shared
diff --git a/src/decentralizepy/sharing/Sharing.py b/src/decentralizepy/sharing/Sharing.py
index 91d25d06badb20500e1dd8c797f4c1d84cca1b8b..648e481d9e2439a6048b95aa6b9760cb64f91c1e 100644
--- a/src/decentralizepy/sharing/Sharing.py
+++ b/src/decentralizepy/sharing/Sharing.py
@@ -11,9 +11,26 @@ class Sharing:
     API defining who to share with and what, and what to do on receiving
     """
 
-    def __init__(
-        self, rank, machine_id, communication, mapping, graph, model, dataset, log_dir
-    ):
+    def __init__(self, rank, machine_id, communication, mapping, graph, model, dataset):
+        """
+        Constructor
+        Parameters
+        ----------
+        rank : int
+            Local rank
+        machine_id : int
+            Global machine id
+        communication : decentralizepy.communication.Communication
+            Communication module used to send and receive messages
+        mapping : decentralizepy.mappings.Mapping
+            Mapping (rank, machine_id) -> uid
+        graph : decentralizepy.graphs.Graph
+            Graph representing neighbors
+        model : decentralizepy.models.Model
+            Model to train
+        dataset : decentralizepy.datasets.Dataset
+            Dataset for sharing data. Not implemented yet! TODO
+        """
         self.rank = rank
         self.machine_id = machine_id
         self.uid = mapping.get_uid(rank, machine_id)
@@ -22,7 +39,6 @@
         self.graph = graph
         self.model = model
         self.dataset = dataset
-        self.log_dir = log_dir
         self.communication_round = 0
 
         self.peer_deques = dict()
@@ -31,22 +47,60 @@
             self.peer_deques[n] = deque()
 
     def received_from_all(self):
+        """
+        Check if all neighbors have sent the current iteration
+        Returns
+        -------
+        bool
+            True if required data has been received, False otherwise
+        """
         for _, i in self.peer_deques.items():
             if len(i) == 0:
                 return False
         return True
 
     def get_neighbors(self, neighbors):
+        """
+        Choose which neighbors to share with
+        Parameters
+        ----------
+        neighbors : list(int)
+            List of all neighbors
+        Returns
+        -------
+        list(int)
+            Neighbors to share with
+        """
         # modify neighbors here
         return neighbors
 
     def serialized_model(self):
+        """
+        Convert the model to a json dict. Here we can choose how much to share
+        Returns
+        -------
+        dict
+            Model converted to json dict
+        """
         m = dict()
         for key, val in self.model.state_dict().items():
             m[key] = json.dumps(val.numpy().tolist())
         return m
 
     def deserialized_model(self, m):
+        """
+        Convert received json dict to state_dict.
+        Parameters
+        ----------
+        m : dict
+            json dict received
+        Returns
+        -------
+        state_dict
+            state_dict of the received model
+        """
         state_dict = dict()
         for key, value in m.items():
             state_dict[key] = torch.from_numpy(numpy.array(json.loads(value)))
@@ -67,7 +121,8 @@ class Sharing:
             logging.debug("Received model from {}".format(sender))
             degree = data["degree"]
             del data["degree"]
-            self.peer_deques[sender].append((degree, self.deserialized_model(data)))
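+            # Store the serialized model; deserialization is deferred until averaging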
+            self.peer_deques[sender].append((degree, data))
             logging.debug("Deserialized received model from {}".format(sender))
 
         logging.info("Starting model averaging after receiving from all neighbors")
@@ -76,6 +131,8 @@ class Sharing:
         for i, n in enumerate(self.peer_deques):
             logging.debug("Averaging model from neighbor {}".format(i))
             degree, data = self.peer_deques[n].popleft()
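+            # Deserialize only now, once per neighbor, at averaging time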
+            data = self.deserialized_model(data)
             weight = 1 / (max(len(self.peer_deques), degree) + 1)  # Metro-Hastings
             weight_total += weight
             for key, value in data.items():
diff --git a/src/decentralizepy/training/GradientAccumulator.py b/src/decentralizepy/training/GradientAccumulator.py
index 9ea42e62b68c01c1c6a620d7a62c28f46a0dc2eb..66e1e5bf6efdb81126b5edd788ed72e5ba829b48 100644
--- a/src/decentralizepy/training/GradientAccumulator.py
+++ b/src/decentralizepy/training/GradientAccumulator.py
@@ -7,11 +7,29 @@ class GradientAccumulator(Training):
     def __init__(
         self, model, optimizer, loss, epochs_per_round="", batch_size="", shuffle=""
     ):
+        """
+        Constructor
+        Parameters
+        ----------
+        model : torch.nn.Module
+            Neural Network for training
+        optimizer : torch.optim
+            Optimizer to learn parameters
+        loss : function
+            Loss function
+        epochs_per_round : int, optional
+            Number of epochs per training call
+        batch_size : int, optional
+            Number of items to learn over, in one batch
+        shuffle : bool, optional
+            True if the dataset should be shuffled before training. Not implemented yet! TODO
+        """
         super().__init__(model, optimizer, loss, epochs_per_round, batch_size, shuffle)
 
     def train(self, dataset):
         """
-        One training iteration with accumulation of gradients in model.accumulated_gradients
+        One training iteration with accumulation of gradients in model.accumulated_gradients.
+        Goes through the entire dataset.
         Parameters
         ----------
         dataset : decentralizepy.datasets.Dataset
diff --git a/src/decentralizepy/training/Training.py b/src/decentralizepy/training/Training.py
index 2b55820b9c6770cfeb6889b9efe30f520fc87c95..6129b34b90e6ae9abbf7e94d7d93dab668cf5b2e 100644
--- a/src/decentralizepy/training/Training.py
+++ b/src/decentralizepy/training/Training.py
@@ -27,6 +27,8 @@ class Training:
             Number of epochs per training call
         batch_size : int, optional
             Number of items to learn over, in one batch
+        shuffle : bool, optional
+            True if the dataset should be shuffled before training. Not implemented yet! TODO
         """
         self.model = model
         self.optimizer = optimizer
@@ -36,11 +38,18 @@ class Training:
         self.shuffle = utils.conditional_value(shuffle, "", False)
 
     def reset_optimizer(self, optimizer):
+        """
+        Replace the current optimizer with a new one
+        Parameters
+        ----------
+        optimizer : torch.optim
+            A new optimizer
+        """
         self.optimizer = optimizer
 
     def eval_loss(self, dataset):
         """
-        Evaluate the loss
+        Evaluate the loss on the training set
         Parameters
         ----------
         dataset : decentralizepy.datasets.Dataset
@@ -61,7 +70,7 @@ class Training:
 
     def train(self, dataset):
         """
-        One training iteration
+        One training iteration, going through the entire dataset
         Parameters
         ----------
         dataset : decentralizepy.datasets.Dataset
diff --git a/src/decentralizepy/utils.py b/src/decentralizepy/utils.py
index a6bab0ddf4959363ad597c4ba0a160fa313ebc18..76eec069b0399000a3744931baca8590b3e94fe1 100644
--- a/src/decentralizepy/utils.py
+++ b/src/decentralizepy/utils.py
@@ -5,6 +5,22 @@
 
 
 def conditional_value(var, nul, default):
+    """
+    Return the default value if var equals the null value, otherwise return var.
+    Parameters
+    ----------
+    var : any
+        The value
+    nul : any
+        The null value. Assigns default if var == nul
+    default : any
+        The default value
+    Returns
+    -------
+    any
+        var if var != nul, otherwise default
+    """
     if var != nul:
         return var
     else:
@@ -12,10 +28,30 @@ def conditional_value(var, nul, default):
 
 
 def remove_keys(d, keys_to_remove):
+    """
+    Removes given keys from the dict. Returns a new dict.
+    Parameters
+    ----------
+    d : dict
+        The initial dictionary
+    keys_to_remove : list
+        List of keys to remove from dict
+    Returns
+    -------
+    dict
+        A new dictionary with the given keys removed.
+    """
     return {key: d[key] for key in d if key not in keys_to_remove}
 
 
 def get_args():
+    """
+    Utility to parse arguments.
+    Returns
+    -------
+    argparse.Namespace
+        Parsed command line arguments
+    """
     parser = argparse.ArgumentParser()
     parser.add_argument("-mid", "--machine_id", type=int, default=0)
     parser.add_argument("-ps", "--procs_per_machine", type=int, default=1)
@@ -38,6 +74,15 @@
 
 
 def write_args(args, path):
+    """
+    Write arguments to a json file
+    Parameters
+    ----------
+    args : argparse.Namespace
+        Command line arguments
+    path : str
+        Location of the file to write to
+    """
     data = {
         "machine_id": args.machine_id,
         "procs_per_machine": args.procs_per_machine,