From ff8587c36fa6df6d54087369ce094efeb4780d8a Mon Sep 17 00:00:00 2001 From: Jeffrey Wigger <jeffrey.wigger@epfl.ch> Date: Wed, 23 Mar 2022 14:25:03 +0000 Subject: [PATCH] Reddit --- eval/step_configs/config_reddit_sharing.ini | 33 ++ split_into_files.py | 6 +- src/decentralizepy/datasets/CIFAR10.py | 41 +- src/decentralizepy/datasets/Partitioner.py | 1 + src/decentralizepy/datasets/Reddit.py | 553 ++++++++++++++++++++ src/decentralizepy/node/Node.py | 5 +- 6 files changed, 615 insertions(+), 24 deletions(-) create mode 100644 eval/step_configs/config_reddit_sharing.ini create mode 100644 src/decentralizepy/datasets/Reddit.py diff --git a/eval/step_configs/config_reddit_sharing.ini b/eval/step_configs/config_reddit_sharing.ini new file mode 100644 index 0000000..0aa4af0 --- /dev/null +++ b/eval/step_configs/config_reddit_sharing.ini @@ -0,0 +1,33 @@ +[DATASET] +dataset_package = decentralizepy.datasets.Reddit +dataset_class = Reddit +random_seed = 97 +model_class = RNN +train_dir = /mnt/nfs/shared/leaf/data/reddit_new/per_user_data/train +test_dir = /mnt/nfs/shared/leaf/data/reddit_new/new_small_data/test +; python list of fractions below +sizes = + +[OPTIMIZER_PARAMS] +optimizer_package = torch.optim +optimizer_class = Adam +lr = 0.001 + +[TRAIN_PARAMS] +training_package = decentralizepy.training.Training +training_class = Training +rounds = 47 +full_epochs = False +batch_size = 16 +shuffle = True +loss_package = torch.nn +loss_class = CrossEntropyLoss + +[COMMUNICATION] +comm_package = decentralizepy.communication.TCP +comm_class = TCP +addresses_filepath = ip_addr_6Machines.json + +[SHARING] +sharing_package = decentralizepy.sharing.Sharing +sharing_class = Sharing diff --git a/split_into_files.py b/split_into_files.py index fbc29d9..fe489f7 100644 --- a/split_into_files.py +++ b/split_into_files.py @@ -1,9 +1,11 @@ import sys -from decentralizepy.datasets.Femnist import Femnist +from decentralizepy.datasets.Reddit import Reddit +from decentralizepy.mappings import Linear if __name__ == "__main__": - f = Femnist(None, None, None) + mapping = Linear(6, 16) + f = Reddit(0, 0, mapping) assert len(sys.argv) == 3 frm = sys.argv[1] to = sys.argv[2] diff --git a/src/decentralizepy/datasets/CIFAR10.py b/src/decentralizepy/datasets/CIFAR10.py index 5de71cb..5519d9b 100644 --- a/src/decentralizepy/datasets/CIFAR10.py +++ b/src/decentralizepy/datasets/CIFAR10.py @@ -3,9 +3,9 @@ import os import numpy as np import torch +import torch.nn.functional as F import torchvision import torchvision.transforms as transforms -import torch.nn.functional as F from torch import nn from torch.utils.data import DataLoader @@ -17,6 +17,7 @@ from decentralizepy.models.Model import Model NUM_CLASSES = 10 + class CIFAR10(Dataset): """ Class for the FEMNIST dataset @@ -29,11 +30,11 @@ class CIFAR10(Dataset): """ logging.info("Loading training set.") - trainset = torchvision.datasets.CIFAR10(root=self.train_dir, train=True, - download=True, transform=self.transform) + trainset = torchvision.datasets.CIFAR10( + root=self.train_dir, train=True, download=True, transform=self.transform + ) c_len = len(trainset) - if self.sizes == None: # Equal distribution of data among processes e = c_len // self.n_procs frac = e / c_len @@ -45,14 +46,16 @@ class CIFAR10(Dataset): if not self.partition_niid: self.trainset = DataPartitioner(trainset, self.sizes).use(self.uid) - else: + else: train_data = {key: [] for key in range(10)} for x, y in trainset: train_data[y].append(x) all_trainset = [] for y, x in train_data.items(): 
all_trainset.extend([(a, y) for a in x]) - self.trainset = SimpleDataPartitioner(all_trainset, self.sizes).use(self.uid) + self.trainset = SimpleDataPartitioner(all_trainset, self.sizes).use( + self.uid + ) def load_testset(self): """ @@ -60,10 +63,10 @@ class CIFAR10(Dataset): """ logging.info("Loading testing set.") - - self.testset = torchvision.datasets.CIFAR10(root=self.test_dir, train=False, - download=True, transform=self.transform) - + + self.testset = torchvision.datasets.CIFAR10( + root=self.test_dir, train=False, download=True, transform=self.transform + ) def __init__( self, @@ -75,7 +78,7 @@ class CIFAR10(Dataset): test_dir="", sizes="", test_batch_size=1024, - partition_niid=False + partition_niid=False, ): """ Constructor which reads the data files, instantiates and partitions the dataset @@ -115,8 +118,11 @@ class CIFAR10(Dataset): self.partition_niid = partition_niid self.transform = transforms.Compose( - [transforms.ToTensor(), - transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) + [ + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ] + ) if self.__training__: self.load_trainset() @@ -146,9 +152,7 @@ class CIFAR10(Dataset): """ if self.__training__: - return DataLoader( - self.trainset, batch_size=batch_size, shuffle=shuffle - ) + return DataLoader(self.trainset, batch_size=batch_size, shuffle=shuffle) raise RuntimeError("Training set not initialized!") def get_testset(self): @@ -166,9 +170,7 @@ class CIFAR10(Dataset): """ if self.__testing__: - return DataLoader( - self.testset , batch_size=self.test_batch_size - ) + return DataLoader(self.testset, batch_size=self.test_batch_size) raise RuntimeError("Test set not initialized!") def test(self, model, loss): @@ -228,6 +230,7 @@ class CIFAR10(Dataset): logging.info("Overall accuracy is: {:.1f} %".format(accuracy)) return accuracy, loss_val + class CNN(Model): """ Class for a CNN Model for CIFAR10 diff --git a/src/decentralizepy/datasets/Partitioner.py b/src/decentralizepy/datasets/Partitioner.py index 80c3a80..0d9710f 100644 --- a/src/decentralizepy/datasets/Partitioner.py +++ b/src/decentralizepy/datasets/Partitioner.py @@ -103,6 +103,7 @@ class DataPartitioner(object): """ return Partition(self.data, self.partitions[rank]) + class SimpleDataPartitioner(DataPartitioner): """ Class to partition the dataset diff --git a/src/decentralizepy/datasets/Reddit.py b/src/decentralizepy/datasets/Reddit.py new file mode 100644 index 0000000..4bc3e2f --- /dev/null +++ b/src/decentralizepy/datasets/Reddit.py @@ -0,0 +1,553 @@ +import collections +import json +import logging +import os +import pickle +from collections import defaultdict +from pathlib import Path + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from torch.utils.data import DataLoader + +from decentralizepy.datasets.Data import Data +from decentralizepy.datasets.Dataset import Dataset +from decentralizepy.datasets.Partitioner import DataPartitioner +from decentralizepy.mappings.Mapping import Mapping +from decentralizepy.models.Model import Model + +VOCAB_LEN = 9999 # 10000 was used as it needed to be +1 due to using mask_zero in the tf embedding +SEQ_LEN = 10 +EMBEDDING_DIM = 200 + + +class Reddit(Dataset): + """ + Class for the Reddit dataset + -- Based on https://gitlab.epfl.ch/sacs/efficient-federated-learning/-/blob/master/grad_guessing/data_utils.py + and Femnist.py + """ + + def __read_file__(self, file_path): + """ + Read data from the given json file + + Parameters + ---------- + 
file_path : str + The file path + + Returns + ------- + tuple + (users, num_samples, data) + + """ + with open(file_path, "r") as inf: + client_data = json.load(inf) + return ( + client_data["users"], + client_data["num_samples"], + client_data["user_data"], + ) + + def __read_dir__(self, data_dir): + """ + Function to read all the Reddit data files in the directory + + Parameters + ---------- + data_dir : str + Path to the folder containing the data files + + Returns + ------- + 3-tuple + A tuple containing list of users, number of samples per client, + and the data items per client + + """ + users = [] + num_samples = [] + data = defaultdict(lambda: None) + + files = os.listdir(data_dir) + files = [f for f in files if f.endswith(".json")] + for f in files: + file_path = os.path.join(data_dir, f) + u, n, d = self.__read_file__(file_path) + users.extend(u) + num_samples.extend(n) + data.update(d) + return users, num_samples, data + + def file_per_user(self, dir, write_dir): + """ + Function to read all the Reddit data files and write one file per user + + Parameters + ---------- + dir : str + Path to the folder containing the data files + write_dir : str + Path to the folder to write the files + + """ + clients, num_samples, train_data = self.__read_dir__(dir) + for index, client in enumerate(clients): + my_data = dict() + my_data["users"] = [client] + my_data["num_samples"] = num_samples[index] + my_samples = {"x": train_data[client]["x"], "y": train_data[client]["y"]} + my_data["user_data"] = {client: my_samples} + with open(os.path.join(write_dir, client + ".json"), "w") as of: + json.dump(my_data, of) + print("Created File: ", client + ".json") + + def load_trainset(self): + """ + Loads the training set. Partitions it if needed. + + """ + logging.info("Loading training set.") + files = os.listdir(self.train_dir) + files = [f for f in files if f.endswith(".json")] + files.sort() + c_len = len(files) + + # clients, num_samples, train_data = self.__read_dir__(self.train_dir) + + if self.sizes == None: # Equal distribution of data among processes + e = c_len // self.n_procs + frac = e / c_len + self.sizes = [frac] * self.n_procs + self.sizes[-1] += 1.0 - frac * self.n_procs + logging.debug("Size fractions: {}".format(self.sizes)) + + self.uid = self.mapping.get_uid(self.rank, self.machine_id) + my_clients = DataPartitioner(files, self.sizes).use(self.uid) + my_train_data = {"x": [], "y": []} + self.clients = [] + self.num_samples = [] + logging.debug("Clients Length: %d", c_len) + logging.debug("My_clients_len: %d", my_clients.__len__()) + for i in range(my_clients.__len__()): + cur_file = my_clients.__getitem__(i) + + clients, _, train_data = self.__read_file__( + os.path.join(self.train_dir, cur_file) + ) + for cur_client in clients: + self.clients.append(cur_client) + processed_x, processed_y = self.prepare_data(train_data[cur_client]) + # processed_x is an list of fixed size word id arrays that represent a phrase + # processed_y is a list of word ids that each represent the next word of a phrase + my_train_data["x"].extend(processed_x) + my_train_data["y"].extend(processed_y) + self.num_samples.append(len(processed_y)) + # turns the list of lists into a single list + self.train_y = np.array(my_train_data["y"], dtype=np.dtype("int64")).reshape(-1) + self.train_x = np.array( + my_train_data["x"], dtype=np.dtype("int64") + ) # .reshape(-1) + logging.info("train_x.shape: %s", str(self.train_x.shape)) + logging.info("train_y.shape: %s", str(self.train_y.shape)) + assert 
self.train_x.shape[0] == self.train_y.shape[0] + assert self.train_x.shape[0] > 0 + + def load_testset(self): + """ + Loads the testing set. + + """ + logging.info("Loading testing set.") + _, _, d = self.__read_dir__(self.test_dir) + test_x = [] + test_y = [] + for test_data in d.values(): + processed_x, processed_y = self.prepare_data(test_data) + # processed_x is an list of fixed size word id arrays that represent a phrase + # processed_y is a list of word ids that each represent the next word of a phrase + test_x.extend(processed_x) + test_y.extend(processed_y) + self.test_y = np.array(test_y, dtype=np.dtype("int64")).reshape(-1) + self.test_x = np.array(test_x, dtype=np.dtype("int64")) + logging.info("test_x.shape: %s", str(self.test_x.shape)) + logging.info("test_y.shape: %s", str(self.test_y.shape)) + assert self.test_x.shape[0] == self.test_y.shape[0] + assert self.test_x.shape[0] > 0 + + def __init__( + self, + rank: int, + machine_id: int, + mapping: Mapping, + n_procs="", + train_dir="", + test_dir="", + sizes="", + test_batch_size=1024, + ): + """ + Constructor which reads the data files, instantiates and partitions the dataset + + Parameters + ---------- + rank : int + Rank of the current process (to get the partition). + machine_id : int + Machine ID + mapping : decentralizepy.mappings.Mapping + Mapping to convert rank, machine_id -> uid for data partitioning + It also provides the total number of global processes + train_dir : str, optional + Path to the training data files. Required to instantiate the training set + The training set is partitioned according to the number of global processes and sizes + test_dir : str. optional + Path to the testing data files Required to instantiate the testing set + sizes : list(int), optional + A list of fractions specifying how much data to alot each process. Sum of fractions should be 1.0 + By default, each process gets an equal amount. + test_batch_size : int, optional + Batch size during testing. 
Default value is 64 + + """ + super().__init__( + rank, + machine_id, + mapping, + train_dir, + test_dir, + sizes, + test_batch_size, + ) + if self.train_dir and Path(self.train_dir).exists(): + vocab_path = os.path.join(self.train_dir, "../../vocab/reddit_vocab.pck") + ( + self.vocab, + self.vocab_size, + self.unk_symbol, + self.pad_symbol, + ) = self._load_vocab(vocab_path) + logging.info("The reddit vocab has %i tokens.", len(self.vocab)) + if self.__training__: + self.load_trainset() + + if self.__testing__: + self.load_testset() + + # TODO: Add Validation + + def _load_vocab(self, VOCABULARY_PATH): + """ + loads the training vocabulary + copied from https://gitlab.epfl.ch/sacs/efficient-federated-learning/-/blob/master/grad_guessing/data_utils.py + Parameters + ---------- + VOCABULARY_PATH : str + Path to the pickled training vocabulary + Returns + ------- + Tuple + vocabulary, size, unk symbol, pad symbol + """ + vocab_file = pickle.load(open(VOCABULARY_PATH, "rb")) + vocab = collections.defaultdict(lambda: vocab_file["unk_symbol"]) + vocab.update(vocab_file["vocab"]) + + return ( + vocab, + vocab_file["size"], + vocab_file["unk_symbol"], + vocab_file["pad_symbol"], + ) + + def prepare_data(self, data): + """ + copied from https://gitlab.epfl.ch/sacs/efficient-federated-learning/-/blob/master/grad_guessing/data_utils.py + Parameters + ---------- + data + + Returns + ------- + + """ + data_x = data["x"] + data_y = data["y"] + + # flatten lists + def flatten_lists(data_x_by_comment, data_y_by_comment): + data_x_by_seq, data_y_by_seq = [], [] + for c, l in zip(data_x_by_comment, data_y_by_comment): + data_x_by_seq.extend(c) + data_y_by_seq.extend(l["target_tokens"]) + + return data_x_by_seq, data_y_by_seq + + data_x, data_y = flatten_lists(data_x, data_y) + + data_x_processed = self.process_x(data_x) + data_y_processed = self.process_y(data_y) + + filtered_x, filtered_y = [], [] + for i in range(len(data_x_processed)): + if np.sum(data_y_processed[i]) != 0: + filtered_x.append(data_x_processed[i]) + filtered_y.append(data_y_processed[i]) + + return (filtered_x, filtered_y) + + def _tokens_to_ids(self, raw_batch): + """ + Turns an list of list of tokens that are of the same size (with padding <PAD>) if needed + into a list of list of word ids + + [['<BOS>', 'do', 'you', 'have', 'proof', 'of', 'purchase', 'for', 'clay', 'play'], [ ...], ...] + turns into: + [[ 5 45 13 24 1153 11 1378 17 6817 165], ...] 
+ + copied from https://gitlab.epfl.ch/sacs/efficient-federated-learning/-/blob/master/grad_guessing/data_utils.py + Parameters + ---------- + raw_batch : list + list of fixed size token lists + + Returns + ------- + 2D array with the rows representing fixed size token_ids pharases + """ + + def tokens_to_word_ids(tokens, word2id): + return [word2id[word] for word in tokens] + + to_ret = [tokens_to_word_ids(seq, self.vocab) for seq in raw_batch] + return np.array(to_ret) + + def process_x(self, raw_x_batch): + """ + copied from https://gitlab.epfl.ch/sacs/efficient-federated-learning/-/blob/master/grad_guessing/data_utils.py + Parameters + ---------- + raw_x_batch + + Returns + ------- + + """ + tokens = self._tokens_to_ids([s for s in raw_x_batch]) + return tokens + + def process_y(self, raw_y_batch): + """ + copied from https://gitlab.epfl.ch/sacs/efficient-federated-learning/-/blob/master/grad_guessing/data_utils.py + Parameters + ---------- + raw_y_batch + + Returns + ------- + + """ + tokens = self._tokens_to_ids([s for s in raw_y_batch]) + + def getNextWord(token_ids): + n = len(token_ids) + for i in range(n): + # gets the word at the end of the phrase that should be predicted + # that is the last token that is not a pad. + if token_ids[n - i - 1] != self.pad_symbol: + return token_ids[n - i - 1] + return self.pad_symbol + + return [getNextWord(t) for t in tokens] + + def get_client_ids(self): + """ + Function to retrieve all the clients of the current process + + Returns + ------- + list(str) + A list of strings of the client ids. + + """ + return self.clients + + def get_client_id(self, i): + """ + Function to get the client id of the ith sample + + Parameters + ---------- + i : int + Index of the sample + + Returns + ------- + str + Client ID + + Raises + ------ + IndexError + If the sample index is out of bounds + + """ + lb = 0 + for j in range(len(self.clients)): + if i < lb + self.num_samples[j]: + return self.clients[j] + + raise IndexError("i is out of bounds!") + + def get_trainset(self, batch_size=1, shuffle=False): + """ + Function to get the training set + + Parameters + ---------- + batch_size : int, optional + Batch size for learning + + Returns + ------- + torch.utils.Dataset(decentralizepy.datasets.Data) + + Raises + ------ + RuntimeError + If the training set was not initialized + + """ + if self.__training__: + return DataLoader( + Data(self.train_x, self.train_y), batch_size=batch_size, shuffle=shuffle + ) + raise RuntimeError("Training set not initialized!") + + def get_testset(self): + """ + Function to get the test set + + Returns + ------- + torch.utils.Dataset(decentralizepy.datasets.Data) + + Raises + ------ + RuntimeError + If the test set was not initialized + + """ + if self.__testing__: + return DataLoader( + Data(self.test_x, self.test_y), batch_size=self.test_batch_size + ) + raise RuntimeError("Test set not initialized!") + + def test(self, model, loss): + """ + Function to evaluate model on the test dataset. 
+ + Parameters + ---------- + model : decentralizepy.models.Model + Model to evaluate + loss : torch.nn.loss + Loss function to evaluate + + Returns + ------- + tuple + (accuracy, loss_value) + + """ + testloader = self.get_testset() + + logging.debug("Test Loader instantiated.") + + correct_pred = [0 for _ in range(VOCAB_LEN)] + total_pred = [0 for _ in range(VOCAB_LEN)] + + total_correct = 0 + total_predicted = 0 + + with torch.no_grad(): + loss_val = 0.0 + count = 0 + for elems, labels in testloader: + outputs = model(elems) + loss_val += loss(outputs, labels).item() + count += 1 + _, predictions = torch.max(outputs, 1) + for label, prediction in zip(labels, predictions): + logging.debug("{} predicted as {}".format(label, prediction)) + if label == prediction: + correct_pred[label] += 1 + total_correct += 1 + total_pred[label] += 1 + total_predicted += 1 + + logging.debug("Predicted on the test set") + + for key, value in enumerate(correct_pred): + if total_pred[key] != 0: + accuracy = 100 * float(value) / total_pred[key] + else: + accuracy = 100.0 + logging.debug("Accuracy for class {} is: {:.1f} %".format(key, accuracy)) + + accuracy = 100 * float(total_correct) / total_predicted + loss_val = loss_val / count + logging.info("Overall accuracy is: {:.1f} %".format(accuracy)) + return accuracy, loss_val + + +class RNN(Model): + """ + Class for a RNN Model for Reddit + + """ + + def __init__(self): + """ + Constructor. Instantiates the RNN Model to predict the next word of a sequence of word. + Based on the TensorFlow model found here: https://gitlab.epfl.ch/sacs/efficient-federated-learning/-/blob/master/grad_guessing/data_utils.py + """ + super().__init__() + + # input_length does not exist + self.embedding = nn.Embedding(VOCAB_LEN, EMBEDDING_DIM, padding_idx=0) + self.rnn_cells = nn.LSTM(EMBEDDING_DIM, 256, batch_first=True, num_layers=2) + # activation function is added in the forward pass + # Note: the tensorflow implementation did not use any activation function in this step? + # should I use one. + self.l1 = nn.Linear(256, 128) + # the tf model used sofmax activation here + self.l2 = nn.Linear(128, VOCAB_LEN) + + def forward(self, x): + """ + Forward pass of the model + + Parameters + ---------- + x : torch.tensor + The input torch tensor + + Returns + ------- + torch.tensor + The output torch tensor + + """ + x = self.embedding(x) + x = self.rnn_cells(x) + last_layer_output = x[1][0][1, ...] + x = F.relu(self.l1(last_layer_output)) + x = self.l2(x) + # softmax is applied by the CrossEntropyLoss used during training + return x diff --git a/src/decentralizepy/node/Node.py b/src/decentralizepy/node/Node.py index 74de4e1..e85159c 100644 --- a/src/decentralizepy/node/Node.py +++ b/src/decentralizepy/node/Node.py @@ -268,7 +268,7 @@ class Node: log_dir=".", log_level=logging.INFO, test_after=5, - train_evaluate_after = 1, + train_evaluate_after=1, reset_optimizer=1, *args ): @@ -345,7 +345,6 @@ class Node: ) # Reset optimizer state self.trainer.reset_optimizer(self.optimizer) - if iteration: with open( os.path.join(self.log_dir, "{}_results.json".format(self.rank)), @@ -376,7 +375,7 @@ class Node: results_dict["grad_mean"][iteration + 1] = self.sharing.mean if hasattr(self.sharing, "std"): results_dict["grad_std"][iteration + 1] = self.sharing.std - + rounds_to_train_evaluate -= 1 if rounds_to_train_evaluate == 0: -- GitLab
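
A minimal sketch of the next-word target selection implemented by the new Reddit.process_y: from each fixed-size row of target token ids it keeps only the last id that is not padding. The pad id of 0 here is an assumption (it matches the embedding's padding_idx=0); in the patch the real pad_symbol is read from reddit_vocab.pck.

    def get_next_word(token_ids, pad_symbol=0):
        # Return the id of the last token in the row that is not a pad;
        # if the whole row is padding, return the pad id itself.
        for tid in reversed(token_ids):
            if tid != pad_symbol:
                return tid
        return pad_symbol

    # Fixed-size rows of word ids (second one is padded), as produced by _tokens_to_ids.
    rows = [
        [5, 45, 13, 24, 1153, 11, 1378, 17, 6817, 165],
        [5, 45, 13, 0, 0, 0, 0, 0, 0, 0],
    ]
    print([get_next_word(r) for r in rows])  # -> [165, 13]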
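
A self-contained sketch exercising the same architecture as the patch's RNN (embedding -> 2-layer LSTM -> l1 -> l2) to check tensor shapes and the CrossEntropyLoss hookup from config_reddit_sharing.ini. It subclasses torch.nn.Module instead of decentralizepy.models.Model so it runs outside the package, and the batch size of 16 mirrors the config; both are assumptions for the sketch only.

    import torch
    import torch.nn.functional as F
    from torch import nn

    VOCAB_LEN = 9999
    SEQ_LEN = 10
    EMBEDDING_DIM = 200

    class RNNSketch(nn.Module):
        # Mirrors the RNN added in src/decentralizepy/datasets/Reddit.py.
        def __init__(self):
            super().__init__()
            self.embedding = nn.Embedding(VOCAB_LEN, EMBEDDING_DIM, padding_idx=0)
            self.rnn_cells = nn.LSTM(EMBEDDING_DIM, 256, batch_first=True, num_layers=2)
            self.l1 = nn.Linear(256, 128)
            self.l2 = nn.Linear(128, VOCAB_LEN)

        def forward(self, x):
            x = self.embedding(x)             # (batch, SEQ_LEN, EMBEDDING_DIM)
            _, (h_n, _) = self.rnn_cells(x)   # h_n: (num_layers, batch, 256)
            last = h_n[1, ...]                # final hidden state of the top LSTM layer
            x = F.relu(self.l1(last))
            return self.l2(x)                 # logits over the vocabulary

    if __name__ == "__main__":
        model = RNNSketch()
        batch = torch.randint(0, VOCAB_LEN, (16, SEQ_LEN))  # fake word-id phrases
        target = torch.randint(0, VOCAB_LEN, (16,))         # fake next-word ids
        logits = model(batch)
        assert logits.shape == (16, VOCAB_LEN)
        loss = nn.CrossEntropyLoss()(logits, target)        # same loss_class as the config
        loss.backward()
        print("logits:", tuple(logits.shape), "loss:", float(loss))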