From f1d5035f3e83c1c11571216f6b304155ecbcb9a4 Mon Sep 17 00:00:00 2001 From: Rishi Sharma <rishi.sharma@epfl.ch> Date: Fri, 22 Jul 2022 16:55:43 +0200 Subject: [PATCH] Add peer sampler, refactor everything --- eval/main.ipynb | 50 +- eval/plot.py | 2 +- eval/plotting_from_csv.py | 20 +- eval/testing.py | 4 +- setup.cfg | 2 +- src/decentralizepy/communication/TCP.py | 151 +----- src/decentralizepy/compression/Compression.py | 3 - src/decentralizepy/compression/Elias.py | 1 - src/decentralizepy/compression/EliasFpzip.py | 1 - .../compression/EliasFpzipLossy.py | 1 - src/decentralizepy/datasets/MovieLens.py | 1 - src/decentralizepy/datasets/Shakespeare.py | 1 - src/decentralizepy/graphs/Graph.py | 3 + src/decentralizepy/mappings/Linear.py | 7 +- src/decentralizepy/node/DPSGDNode.py | 513 ++++++++++++++++++ src/decentralizepy/node/Node.py | 384 ++++--------- src/decentralizepy/node/PeerSampler.py | 221 ++++++++ src/decentralizepy/sharing/FFT.py | 27 +- src/decentralizepy/sharing/GrowingAlpha.py | 11 +- src/decentralizepy/sharing/LowerBoundTopK.py | 20 +- src/decentralizepy/sharing/ManualAdapt.py | 13 +- src/decentralizepy/sharing/PartialModel.py | 36 +- src/decentralizepy/sharing/RandomAlpha.py | 10 +- .../sharing/RandomAlphaIncremental.py | 11 +- .../sharing/RandomAlphaWavelet.py | 10 +- .../sharing/RoundRobinPartial.py | 19 +- src/decentralizepy/sharing/Sharing.py | 127 ++--- .../sharing/SharingCentrality.py | 1 + src/decentralizepy/sharing/SubSampling.py | 19 +- src/decentralizepy/sharing/Synchronous.py | 1 + src/decentralizepy/sharing/TopKNormalized.py | 6 + src/decentralizepy/sharing/TopKParams.py | 19 +- src/decentralizepy/sharing/TopKPlusRandom.py | 6 + src/decentralizepy/sharing/Wavelet.py | 26 +- src/decentralizepy/train_test_evaluation.py | 1 - 35 files changed, 1183 insertions(+), 545 deletions(-) create mode 100644 src/decentralizepy/node/DPSGDNode.py create mode 100644 src/decentralizepy/node/PeerSampler.py diff --git a/eval/main.ipynb b/eval/main.ipynb index 80daae6..0873005 100644 --- a/eval/main.ipynb +++ b/eval/main.ipynb @@ -5709,6 +5709,41 @@ "print(i)" ] }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from torch import multiprocessing as mp\n", + "from decentralizepy.node.PeerSampler import PeerSampler\n", + "from decentralizepy.node.Node import Node\n", + "from decentralizepy.mappings.Linear import Linear\n", + "from decentralizepy.graphs.Regular import Regular\n", + "\n", + "l = Linear(1, 6)\n", + "g = Regular(6, 2)\n", + "processes = [mp.Process(target = PeerSampler, args=[-1, 0, l, g, None]),\n", + " mp.Process(target = Node, args=[1, 0, l, g, None]),\n", + " mp.Process(target = Node, args=[2, 0, l, g, None]),\n", + " mp.Process(target = Node, args=[3, 0, l, g, None]),\n", + " mp.Process(target = Node, args=[4, 0, l, g, None]),\n", + " mp.Process(target = Node, args=[5, 0, l, g, None]),\n", + " mp.Process(target = Node, args=[6, 0, l, g, None]),\n", + " ]\n", + "\n", + "for p in processes:\n", + " p.start()\n", + "\n", + "for p in processes:\n", + " p.join()\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ] + }, { "cell_type": "code", "execution_count": null, @@ -5718,11 +5753,9 @@ } ], "metadata": { - "interpreter": { - "hash": "996934296aa9d79be6c3d800a38d8fdb7dfa8fe7bb07df178f1397cde2cb8742" - }, "kernelspec": { - "display_name": "Python 3.9.7 64-bit ('tff': conda)", + "display_name": "Python 3.9.7 ('decpy')", + "language": "python", "name": "python3" }, "language_info": { @@ -5735,9 +5768,14 @@ "name": "python", 
"nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.11" + "version": "3.9.12" }, - "orig_nbformat": 4 + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "70be49349d3cda3718db277e01495433e35b5db6f514174958763e3b43682235" + } + } }, "nbformat": 4, "nbformat_minor": 2 diff --git a/eval/plot.py b/eval/plot.py index 601d8e8..8a9bcc1 100644 --- a/eval/plot.py +++ b/eval/plot.py @@ -38,7 +38,7 @@ def plot(means, stdevs, mins, maxs, title, label, loc): def plot_results(path, centralized, data_machine="machine0", data_node=0): folders = os.listdir(path) - if centralized.lower() in ['true', '1', 't', 'y', 'yes']: + if centralized.lower() in ["true", "1", "t", "y", "yes"]: centralized = True print("Centralized") else: diff --git a/eval/plotting_from_csv.py b/eval/plotting_from_csv.py index b8d4320..dbd1c1a 100644 --- a/eval/plotting_from_csv.py +++ b/eval/plotting_from_csv.py @@ -23,7 +23,7 @@ def plot(x_axis, means, stdevs, pos, nb_plots, title, label, loc, xlabel): def plot_results(path, epochs, global_epochs="True"): - if global_epochs.lower() in ['true', '1', 't', 'y', 'yes']: + if global_epochs.lower() in ["true", "1", "t", "y", "yes"]: global_epochs = True else: global_epochs = False @@ -52,10 +52,12 @@ def plot_results(path, epochs, global_epochs="True"): if global_epochs: rounds = results_csv["rounds"].iloc[0] print("Rounds: ", rounds) - results_cr = results_csv[results_csv.rounds <= epochs*rounds] + results_cr = results_csv[results_csv.rounds <= epochs * rounds] means = results_cr["mean"].to_numpy() stdevs = results_cr["std"].to_numpy() - x_axis = results_cr["rounds"].to_numpy() / rounds # list(np.arange(0, len(means), 1)) + x_axis = ( + results_cr["rounds"].to_numpy() / rounds + ) # list(np.arange(0, len(means), 1)) x_label = "global epochs" else: results_cr = results_csv[results_csv.rounds <= epochs] @@ -85,10 +87,12 @@ def plot_results(path, epochs, global_epochs="True"): if global_epochs: rounds = results_csv["rounds"].iloc[0] print("Rounds: ", rounds) - results_cr = results_csv[results_csv.rounds <= epochs*rounds] + results_cr = results_csv[results_csv.rounds <= epochs * rounds] means = results_cr["mean"].to_numpy() stdevs = results_cr["std"].to_numpy() - x_axis = results_cr["rounds"].to_numpy() / rounds # list(np.arange(0, len(means), 1)) + x_axis = ( + results_cr["rounds"].to_numpy() / rounds + ) # list(np.arange(0, len(means), 1)) x_label = "global epochs" else: results_cr = results_csv[results_csv.rounds <= epochs] @@ -120,10 +124,12 @@ def plot_results(path, epochs, global_epochs="True"): if global_epochs: rounds = results_csv["rounds"].iloc[0] print("Rounds: ", rounds) - results_cr = results_csv[results_csv.rounds <= epochs*rounds] + results_cr = results_csv[results_csv.rounds <= epochs * rounds] means = results_cr["mean"].to_numpy() stdevs = results_cr["std"].to_numpy() - x_axis = results_cr["rounds"].to_numpy() / rounds # list(np.arange(0, len(means), 1)) + x_axis = ( + results_cr["rounds"].to_numpy() / rounds + ) # list(np.arange(0, len(means), 1)) x_label = "global epochs" else: results_cr = results_csv[results_csv.rounds <= epochs] diff --git a/eval/testing.py b/eval/testing.py index 9125828..9d67b28 100644 --- a/eval/testing.py +++ b/eval/testing.py @@ -8,7 +8,7 @@ from torch import multiprocessing as mp from decentralizepy import utils from decentralizepy.graphs.Graph import Graph from decentralizepy.mappings.Linear import Linear -from decentralizepy.node.Node import Node +from decentralizepy.node.DPSGDNode import DPSGDNode def 
read_ini(file_path): @@ -51,7 +51,7 @@ if __name__ == "__main__": m_id = args.machine_id mp.spawn( - fn=Node, + fn=DPSGDNode, nprocs=procs_per_machine, args=[ m_id, diff --git a/setup.cfg b/setup.cfg index 1b3f6c7..2df457a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -54,7 +54,7 @@ python_requires = >=3.6 where = src [options.extras_require] dev = - black + black>22.3.0 coverage isort pytest diff --git a/src/decentralizepy/communication/TCP.py b/src/decentralizepy/communication/TCP.py index 6d72e1c..8c8587b 100644 --- a/src/decentralizepy/communication/TCP.py +++ b/src/decentralizepy/communication/TCP.py @@ -1,4 +1,3 @@ -import importlib import json import logging import pickle @@ -36,7 +35,8 @@ class TCP(Communication): """ machine_addr = self.ip_addrs[str(machine_id)] - port = rank + self.offset + port = (2 * rank + 1) + self.offset + assert port > 0 return "tcp://{}:{}".format(machine_addr, port) def __init__( @@ -46,10 +46,7 @@ class TCP(Communication): mapping, total_procs, addresses_filepath, - compress=False, - offset=20000, - compression_package=None, - compression_class=None, + offset=9000, ): """ Constructor @@ -81,30 +78,19 @@ class TCP(Communication): self.rank = rank self.machine_id = machine_id self.mapping = mapping - self.offset = 20000 + offset + self.offset = offset self.uid = mapping.get_uid(rank, machine_id) self.identity = str(self.uid).encode() self.context = zmq.Context() self.router = self.context.socket(zmq.ROUTER) self.router.setsockopt(zmq.IDENTITY, self.identity) self.router.bind(self.addr(rank, machine_id)) - self.sent_disconnections = False - self.compress = compress - - if compression_package and compression_class: - compressor_module = importlib.import_module(compression_package) - compressor_class = getattr(compressor_module, compression_class) - self.compressor = compressor_class() - logging.info(f"Using the {compressor_class} to compress the data") - else: - assert not self.compress self.total_data = 0 self.total_meta = 0 self.peer_deque = deque() self.peer_sockets = dict() - self.barrier = set() def __del__(self): """ @@ -128,26 +114,12 @@ class TCP(Communication): Encoded data """ - if self.compress: - if "indices" in data: - data["indices"] = self.compressor.compress(data["indices"]) - - assert "params" in data - data["params"] = self.compressor.compress_float(data["params"]) + data_len = 0 + if "params" in data: data_len = len(pickle.dumps(data["params"])) - output = pickle.dumps(data) - - # the compressed meta data gets only a few bytes smaller after pickling - self.total_meta += len(output) - data_len - self.total_data += data_len - else: - output = pickle.dumps(data) - # centralized testing uses its own instance - if type(data) == dict: - assert "params" in data - data_len = len(pickle.dumps(data["params"])) - self.total_meta += len(output) - data_len - self.total_data += data_len + output = pickle.dumps(data) + self.total_meta += len(output) - data_len + self.total_data += data_len return output def decrypt(self, sender, data): @@ -168,63 +140,25 @@ class TCP(Communication): """ sender = int(sender.decode()) - if self.compress: - data = pickle.loads(data) - if "indices" in data: - data["indices"] = self.compressor.decompress(data["indices"]) - if "params" in data: - data["params"] = self.compressor.decompress_float(data["params"]) - else: - data = pickle.loads(data) + data = pickle.loads(data) return sender, data - def connect_neighbors(self, neighbors): + def init_connection(self, neighbor): """ - Connects all neighbors. Sends HELLO. 
Waits for HELLO. - Caches any data received while waiting for HELLOs. + Initiates a socket to a given node. Parameters ---------- - neighbors : list(int) - List of neighbors - - Raises - ------ - RuntimeError - If received BYE while waiting for HELLO + neighbor : int + neighbor to connect to """ - logging.info("Sending connection request to neighbors") - for uid in neighbors: - logging.debug("Connecting to my neighbour: {}".format(uid)) - id = str(uid).encode() - req = self.context.socket(zmq.DEALER) - req.setsockopt(zmq.IDENTITY, self.identity) - req.connect(self.addr(*self.mapping.get_machine_and_rank(uid))) - self.peer_sockets[id] = req - req.send(HELLO) - - num_neighbors = len(neighbors) - while len(self.barrier) < num_neighbors: - sender, recv = self.router.recv_multipart() - - if recv == HELLO: - logging.debug("Received {} from {}".format(HELLO, sender)) - self.barrier.add(sender) - elif recv == BYE: - logging.debug("Received {} from {}".format(BYE, sender)) - raise RuntimeError( - "A neighbour wants to disconnect before training started!" - ) - else: - logging.debug( - "Received message from {} @ connect_neighbors".format(sender) - ) - - self.peer_deque.append(self.decrypt(sender, recv)) - - logging.info("Connected to all neighbors") - self.initialized = True + logging.debug("Connecting to my neighbour: {}".format(neighbor)) + id = str(neighbor).encode() + req = self.context.socket(zmq.DEALER) + req.setsockopt(zmq.IDENTITY, self.identity) + req.connect(self.addr(*self.mapping.get_machine_and_rank(neighbor))) + self.peer_sockets[id] = req def receive(self): """ @@ -241,25 +175,9 @@ class TCP(Communication): If received HELLO """ - assert self.initialized == True - if len(self.peer_deque) != 0: - resp = self.peer_deque.popleft() - return resp sender, recv = self.router.recv_multipart() - - if recv == HELLO: - logging.debug("Received {} from {}".format(HELLO, sender)) - raise RuntimeError( - "A neighbour wants to connect when everyone is connected!" - ) - elif recv == BYE: - logging.debug("Received {} from {}".format(BYE, sender)) - self.barrier.remove(sender) - return self.receive() - else: - logging.debug("Received message from {}".format(sender)) - return self.decrypt(sender, recv) + return self.decrypt(sender, recv) def send(self, uid, data, encrypt=True): """ @@ -273,7 +191,6 @@ class TCP(Communication): Message as a Python dictionary """ - assert self.initialized == True if encrypt: to_send = self.encrypt(data) else: @@ -283,28 +200,4 @@ class TCP(Communication): id = str(uid).encode() self.peer_sockets[id].send(to_send) logging.debug("{} sent the message to {}.".format(self.uid, uid)) - logging.info("Sent this round: {}".format(data_size)) - - def disconnect_neighbors(self): - """ - Disconnects all neighbors. 
- - """ - assert self.initialized == True - if not self.sent_disconnections: - logging.info("Disconnecting neighbors") - for sock in self.peer_sockets.values(): - sock.send(BYE) - self.sent_disconnections = True - while len(self.barrier): - sender, recv = self.router.recv_multipart() - if recv == BYE: - logging.debug("Received {} from {}".format(BYE, sender)) - self.barrier.remove(sender) - else: - logging.critical( - "Received unexpected {} from {}".format(recv, sender) - ) - raise RuntimeError( - "Received a message when expecting BYE from {}".format(sender) - ) + logging.info("Sent message size: {}".format(data_size)) diff --git a/src/decentralizepy/compression/Compression.py b/src/decentralizepy/compression/Compression.py index 0924caf..b45e641 100644 --- a/src/decentralizepy/compression/Compression.py +++ b/src/decentralizepy/compression/Compression.py @@ -1,6 +1,3 @@ -import numpy as np - - class Compression: """ Compression API diff --git a/src/decentralizepy/compression/Elias.py b/src/decentralizepy/compression/Elias.py index 235cf00..0d408d8 100644 --- a/src/decentralizepy/compression/Elias.py +++ b/src/decentralizepy/compression/Elias.py @@ -1,6 +1,5 @@ # elias implementation: taken from this stack overflow post: # https://stackoverflow.com/questions/62843156/python-fast-compression-of-large-amount-of-numbers-with-elias-gamma -import fpzip import numpy as np from decentralizepy.compression.Compression import Compression diff --git a/src/decentralizepy/compression/EliasFpzip.py b/src/decentralizepy/compression/EliasFpzip.py index 0c82560..0142dd9 100644 --- a/src/decentralizepy/compression/EliasFpzip.py +++ b/src/decentralizepy/compression/EliasFpzip.py @@ -1,7 +1,6 @@ # elias implementation: taken from this stack overflow post: # https://stackoverflow.com/questions/62843156/python-fast-compression-of-large-amount-of-numbers-with-elias-gamma import fpzip -import numpy as np from decentralizepy.compression.Elias import Elias diff --git a/src/decentralizepy/compression/EliasFpzipLossy.py b/src/decentralizepy/compression/EliasFpzipLossy.py index 617a78b..0b60307 100644 --- a/src/decentralizepy/compression/EliasFpzipLossy.py +++ b/src/decentralizepy/compression/EliasFpzipLossy.py @@ -1,7 +1,6 @@ # elias implementation: taken from this stack overflow post: # https://stackoverflow.com/questions/62843156/python-fast-compression-of-large-amount-of-numbers-with-elias-gamma import fpzip -import numpy as np from decentralizepy.compression.Elias import Elias diff --git a/src/decentralizepy/datasets/MovieLens.py b/src/decentralizepy/datasets/MovieLens.py index dafb4ce..95e55cc 100644 --- a/src/decentralizepy/datasets/MovieLens.py +++ b/src/decentralizepy/datasets/MovieLens.py @@ -3,7 +3,6 @@ import math import os import zipfile -import numpy as np import pandas as pd import requests import torch diff --git a/src/decentralizepy/datasets/Shakespeare.py b/src/decentralizepy/datasets/Shakespeare.py index 0c02932..c7ede74 100644 --- a/src/decentralizepy/datasets/Shakespeare.py +++ b/src/decentralizepy/datasets/Shakespeare.py @@ -1,7 +1,6 @@ import json import logging import os -import re from collections import defaultdict import numpy as np diff --git a/src/decentralizepy/graphs/Graph.py b/src/decentralizepy/graphs/Graph.py index 689d2dc..dc66eef 100644 --- a/src/decentralizepy/graphs/Graph.py +++ b/src/decentralizepy/graphs/Graph.py @@ -22,6 +22,9 @@ class Graph: self.n_procs = n_procs self.adj_list = [set() for i in range(self.n_procs)] + def get_all_nodes(self): + return [i for i in 
range(self.n_procs)] + def __insert_adj__(self, node, neighbours): """ Inserts `neighbours` into the adjacency list of `node` diff --git a/src/decentralizepy/mappings/Linear.py b/src/decentralizepy/mappings/Linear.py index 9419fbd..f166dc9 100644 --- a/src/decentralizepy/mappings/Linear.py +++ b/src/decentralizepy/mappings/Linear.py @@ -8,7 +8,7 @@ class Linear(Mapping): """ - def __init__(self, n_machines, procs_per_machine): + def __init__(self, n_machines, procs_per_machine, global_service_machine=0): """ Constructor @@ -23,6 +23,7 @@ class Linear(Mapping): super().__init__(n_machines * procs_per_machine) self.n_machines = n_machines self.procs_per_machine = procs_per_machine + self.global_service_machine = global_service_machine def get_uid(self, rank: int, machine_id: int): """ @@ -41,6 +42,8 @@ class Linear(Mapping): the unique identifier """ + if rank < 0: + return rank return machine_id * self.procs_per_machine + rank def get_machine_and_rank(self, uid: int): @@ -58,6 +61,8 @@ class Linear(Mapping): a tuple of rank and machine_id """ + if uid < 0: + return uid, self.global_service_machine return (uid % self.procs_per_machine), (uid // self.procs_per_machine) def get_local_procs_count(self): diff --git a/src/decentralizepy/node/DPSGDNode.py b/src/decentralizepy/node/DPSGDNode.py new file mode 100644 index 0000000..964f103 --- /dev/null +++ b/src/decentralizepy/node/DPSGDNode.py @@ -0,0 +1,513 @@ +import importlib +import json +import logging +import math +import os +from collections import deque + +import torch +from matplotlib import pyplot as plt + +from decentralizepy import utils +from decentralizepy.communication.TCP import TCP +from decentralizepy.graphs.Graph import Graph +from decentralizepy.graphs.Star import Star +from decentralizepy.mappings.Mapping import Mapping +from decentralizepy.node.Node import Node +from decentralizepy.train_test_evaluation import TrainTestHelper + + +class DPSGDNode(Node): + """ + This class defines the node for DPSGD + + """ + + def save_plot(self, l, label, title, xlabel, filename): + """ + Save Matplotlib plot. Clears previous plots. + + Parameters + ---------- + l : dict + dict of x -> y. `x` must be castable to int. + label : str + label of the plot. Used for legend. + title : str + Header + xlabel : str + x-axis label + filename : str + Name of file to save the plot as. 
+ + """ + plt.clf() + y_axis = [l[key] for key in l.keys()] + x_axis = list(map(int, l.keys())) + plt.plot(x_axis, y_axis, label=label) + plt.xlabel(xlabel) + plt.title(title) + plt.savefig(filename) + + def run(self): + """ + Start the decentralized learning + + """ + self.testset = self.dataset.get_testset() + rounds_to_test = self.test_after + rounds_to_train_evaluate = self.train_evaluate_after + global_epoch = 1 + change = 1 + if self.uid == 0: + dataset = self.dataset + if self.centralized_train_eval: + dataset_params_copy = self.dataset_params.copy() + if "sizes" in dataset_params_copy: + del dataset_params_copy["sizes"] + self.whole_dataset = self.dataset_class( + self.rank, + self.machine_id, + self.mapping, + sizes=[1.0], + **dataset_params_copy + ) + dataset = self.whole_dataset + if self.centralized_test_eval: + tthelper = TrainTestHelper( + dataset, # self.whole_dataset, + # self.model_test, # todo: this only works if eval_train is set to false + self.model, + self.loss, + self.weights_store_dir, + self.mapping.get_n_procs(), + self.trainer, + self.testing_comm, + self.star, + self.threads_per_proc, + eval_train=self.centralized_train_eval, + ) + + for iteration in range(self.iterations): + logging.info("Starting training iteration: %d", iteration) + self.trainer.train(self.dataset) + to_send = self.sharing.get_data_to_send() + + for neighbor in self.my_neighbors: + self.communication.send(neighbor, to_send) + + while not self.received_from_all(): + sender, data = self.receive() + + if "HELLO" in data: + logging.critical( + "Received unexpected {} from {}".format("HELLO", sender) + ) + raise RuntimeError("A neighbour wants to connect during training!") + elif "BYE" in data: + logging.debug("Received {} from {}".format("BYE", sender)) + self.barrier.remove(sender) + else: + logging.debug("Received message from {}".format(sender)) + self.peer_deques[sender].append(data) + + self.sharing._averaging(self.peer_deques) + + if self.reset_optimizer: + self.optimizer = self.optimizer_class( + self.model.parameters(), **self.optimizer_params + ) # Reset optimizer state + self.trainer.reset_optimizer(self.optimizer) + + if iteration: + with open( + os.path.join(self.log_dir, "{}_results.json".format(self.rank)), + "r", + ) as inf: + results_dict = json.load(inf) + else: + results_dict = { + "train_loss": {}, + "test_loss": {}, + "test_acc": {}, + "total_bytes": {}, + "total_meta": {}, + "total_data_per_n": {}, + "grad_mean": {}, + "grad_std": {}, + } + + results_dict["total_bytes"][iteration + 1] = self.communication.total_bytes + + if hasattr(self.communication, "total_meta"): + results_dict["total_meta"][ + iteration + 1 + ] = self.communication.total_meta + if hasattr(self.communication, "total_data"): + results_dict["total_data_per_n"][ + iteration + 1 + ] = self.communication.total_data + if hasattr(self.sharing, "mean"): + results_dict["grad_mean"][iteration + 1] = self.sharing.mean + if hasattr(self.sharing, "std"): + results_dict["grad_std"][iteration + 1] = self.sharing.std + + rounds_to_train_evaluate -= 1 + + if rounds_to_train_evaluate == 0 and not self.centralized_train_eval: + logging.info("Evaluating on train set.") + rounds_to_train_evaluate = self.train_evaluate_after * change + loss_after_sharing = self.trainer.eval_loss(self.dataset) + results_dict["train_loss"][iteration + 1] = loss_after_sharing + self.save_plot( + results_dict["train_loss"], + "train_loss", + "Training Loss", + "Communication Rounds", + os.path.join(self.log_dir, 
"{}_train_loss.png".format(self.rank)), + ) + + rounds_to_test -= 1 + + if self.dataset.__testing__ and rounds_to_test == 0: + rounds_to_test = self.test_after * change + if self.centralized_test_eval: + if self.uid == 0: + ta, tl, trl = tthelper.train_test_evaluation(iteration) + results_dict["test_acc"][iteration + 1] = ta + results_dict["test_loss"][iteration + 1] = tl + if trl is not None: + results_dict["train_loss"][iteration + 1] = trl + else: + self.testing_comm.send(0, self.model.get_weights()) + sender, data = self.testing_comm.receive() + assert sender == 0 and data == "finished" + else: + logging.info("Evaluating on test set.") + ta, tl = self.dataset.test(self.model, self.loss) + results_dict["test_acc"][iteration + 1] = ta + results_dict["test_loss"][iteration + 1] = tl + + if global_epoch == 49: + change *= 2 + + global_epoch += change + + with open( + os.path.join(self.log_dir, "{}_results.json".format(self.rank)), "w" + ) as of: + json.dump(results_dict, of) + if self.model.shared_parameters_counter is not None: + logging.info("Saving the shared parameter counts") + with open( + os.path.join( + self.log_dir, "{}_shared_parameters.json".format(self.rank) + ), + "w", + ) as of: + json.dump(self.model.shared_parameters_counter.numpy().tolist(), of) + self.disconnect_neighbors() + logging.info("Storing final weight") + self.model.dump_weights(self.weights_store_dir, self.uid, iteration) + logging.info("All neighbors disconnected. Process complete!") + + def cache_fields( + self, + rank, + machine_id, + mapping, + graph, + iterations, + log_dir, + weights_store_dir, + test_after, + train_evaluate_after, + reset_optimizer, + centralized_train_eval, + centralized_test_eval, + ): + """ + Instantiate object field with arguments. + + Parameters + ---------- + rank : int + Rank of process local to the machine + machine_id : int + Machine ID on which the process in running + mapping : decentralizepy.mappings + The object containing the mapping rank <--> uid + graph : decentralizepy.graphs + The object containing the global graph + iterations : int + Number of iterations (communication steps) for which the model should be trained + log_dir : str + Logging directory + weights_store_dir : str + Directory in which to store model weights + test_after : int + Number of iterations after which the test loss and accuracy arecalculated + train_evaluate_after : int + Number of iterations after which the train loss is calculated + reset_optimizer : int + 1 if optimizer should be reset every communication round, else 0 + centralized_train_eval : bool + If set the train set evaluation happens at the node with uid 0 + centralized_test_eval : bool + If set the train set evaluation happens at the node with uid 0 + """ + self.rank = rank + self.machine_id = machine_id + self.graph = graph + self.mapping = mapping + self.uid = self.mapping.get_uid(rank, machine_id) + self.log_dir = log_dir + self.weights_store_dir = weights_store_dir + self.iterations = iterations + self.test_after = test_after + self.train_evaluate_after = train_evaluate_after + self.reset_optimizer = reset_optimizer + self.centralized_train_eval = centralized_train_eval + self.centralized_test_eval = centralized_test_eval + self.sent_disconnections = False + + logging.info("Rank: %d", self.rank) + logging.info("type(graph): %s", str(type(self.rank))) + logging.info("type(mapping): %s", str(type(self.mapping))) + + if centralized_test_eval or centralized_train_eval: + self.star = Star(self.mapping.get_n_procs()) + + def init_comm(self, 
comm_configs): + """ + Instantiate communication module from config. + + Parameters + ---------- + comm_configs : dict + Python dict containing communication config params + + """ + comm_module = importlib.import_module(comm_configs["comm_package"]) + comm_class = getattr(comm_module, comm_configs["comm_class"]) + comm_params = utils.remove_keys(comm_configs, ["comm_package", "comm_class"]) + self.addresses_filepath = comm_params.get("addresses_filepath", None) + if self.centralized_test_eval: + self.testing_comm = TCP( + self.rank, + self.machine_id, + self.mapping, + self.star.n_procs, + self.addresses_filepath, + offset=self.star.n_procs, + ) + self.testing_comm.connect_neighbors(self.star.neighbors(self.uid)) + + self.communication = comm_class( + self.rank, self.machine_id, self.mapping, self.graph.n_procs, **comm_params + ) + + def instantiate( + self, + rank: int, + machine_id: int, + mapping: Mapping, + graph: Graph, + config, + iterations=1, + log_dir=".", + weights_store_dir=".", + log_level=logging.INFO, + test_after=5, + train_evaluate_after=1, + reset_optimizer=1, + centralized_train_eval=False, + centralized_test_eval=True, + *args + ): + """ + Construct objects. + + Parameters + ---------- + rank : int + Rank of process local to the machine + machine_id : int + Machine ID on which the process in running + mapping : decentralizepy.mappings + The object containing the mapping rank <--> uid + graph : decentralizepy.graphs + The object containing the global graph + config : dict + A dictionary of configurations. + iterations : int + Number of iterations (communication steps) for which the model should be trained + log_dir : str + Logging directory + weights_store_dir : str + Directory in which to store model weights + log_level : logging.Level + One of DEBUG, INFO, WARNING, ERROR, CRITICAL + test_after : int + Number of iterations after which the test loss and accuracy arecalculated + train_evaluate_after : int + Number of iterations after which the train loss is calculated + reset_optimizer : int + 1 if optimizer should be reset every communication round, else 0 + centralized_train_eval : bool + If set the train set evaluation happens at the node with uid 0 + centralized_test_eval : bool + If set the train set evaluation happens at the node with uid 0 + args : optional + Other arguments + + """ + logging.info("Started process.") + + self.init_log(log_dir, rank, log_level) + + self.cache_fields( + rank, + machine_id, + mapping, + graph, + iterations, + log_dir, + weights_store_dir, + test_after, + train_evaluate_after, + reset_optimizer, + centralized_train_eval, + centralized_test_eval, + ) + self.init_dataset_model(config["DATASET"]) + self.init_optimizer(config["OPTIMIZER_PARAMS"]) + self.init_trainer(config["TRAIN_PARAMS"]) + self.init_comm(config["COMMUNICATION"]) + + self.message_queue = deque() + self.barrier = set() + self.my_neighbors = self.graph.neighbors(self.uid) + + self.init_sharing(config["SHARING"]) + self.peer_deques = dict() + for n in self.my_neighbors: + self.peer_deques[n] = deque() + + self.connect_neighbors() + + def received_from_all(self): + """ + Check if all neighbors have sent the current iteration + + Returns + ------- + bool + True if required data has been received, False otherwise + + """ + for k in self.my_neighbors: + if len(self.peer_deques[k]) == 0: + return False + return True + + def __init__( + self, + rank: int, + machine_id: int, + mapping: Mapping, + graph: Graph, + config, + iterations=1, + log_dir=".", + weights_store_dir=".", + 
log_level=logging.INFO, + test_after=5, + train_evaluate_after=1, + reset_optimizer=1, + centralized_train_eval=0, + centralized_test_eval=1, + *args + ): + """ + Constructor + + Parameters + ---------- + rank : int + Rank of process local to the machine + machine_id : int + Machine ID on which the process in running + mapping : decentralizepy.mappings + The object containing the mapping rank <--> uid + graph : decentralizepy.graphs + The object containing the global graph + config : dict + A dictionary of configurations. Must contain the following: + [DATASET] + dataset_package + dataset_class + model_class + [OPTIMIZER_PARAMS] + optimizer_package + optimizer_class + [TRAIN_PARAMS] + training_package = decentralizepy.training.Training + training_class = Training + epochs_per_round = 25 + batch_size = 64 + iterations : int + Number of iterations (communication steps) for which the model should be trained + log_dir : str + Logging directory + weights_store_dir : str + Directory in which to store model weights + log_level : logging.Level + One of DEBUG, INFO, WARNING, ERROR, CRITICAL + test_after : int + Number of iterations after which the test loss and accuracy arecalculated + train_evaluate_after : int + Number of iterations after which the train loss is calculated + reset_optimizer : int + 1 if optimizer should be reset every communication round, else 0 + centralized_train_eval : int + If set then the train set evaluation happens at the node with uid 0. + Note: If it is True then centralized_test_eval needs to be true as well! + centralized_test_eval : int + If set then the trainset evaluation happens at the node with uid 0 + args : optional + Other arguments + + """ + centralized_train_eval = centralized_train_eval == 1 + centralized_test_eval = centralized_test_eval == 1 + # If centralized_train_eval is True then centralized_test_eval needs to be true as well! + assert not centralized_train_eval or centralized_test_eval + + total_threads = os.cpu_count() + self.threads_per_proc = max( + math.floor(total_threads / mapping.procs_per_machine), 1 + ) + torch.set_num_threads(self.threads_per_proc) + torch.set_num_interop_threads(1) + self.instantiate( + rank, + machine_id, + mapping, + graph, + config, + iterations, + log_dir, + weights_store_dir, + log_level, + test_after, + train_evaluate_after, + reset_optimizer, + centralized_train_eval == 1, + centralized_test_eval == 1, + *args + ) + logging.info( + "Each proc uses %d threads out of %d.", self.threads_per_proc, total_threads + ) + self.run() diff --git a/src/decentralizepy/node/Node.py b/src/decentralizepy/node/Node.py index 91f34e5..67ee659 100644 --- a/src/decentralizepy/node/Node.py +++ b/src/decentralizepy/node/Node.py @@ -1,18 +1,14 @@ import importlib -import json import logging import math import os +from collections import deque import torch -from matplotlib import pyplot as plt from decentralizepy import utils -from decentralizepy.communication.TCP import TCP from decentralizepy.graphs.Graph import Graph -from decentralizepy.graphs.Star import Star from decentralizepy.mappings.Mapping import Mapping -from decentralizepy.train_test_evaluation import TrainTestHelper class Node: @@ -21,31 +17,96 @@ class Node: """ - def save_plot(self, l, label, title, xlabel, filename): + def connect_neighbor(self, neighbor): """ - Save Matplotlib plot. Clears previous plots. + Connects given neighbor. Sends HELLO. - Parameters - ---------- - l : dict - dict of x -> y. `x` must be castable to int. - label : str - label of the plot. 
Used for legend. - title : str - Header - xlabel : str - x-axis label - filename : str - Name of file to save the plot as. + """ + logging.info("Sending connection request to {}".format(neighbor)) + self.communication.init_connection(neighbor) + self.communication.send(neighbor, {"HELLO": self.uid}) + + def wait_for_hello(self, neighbor): + """ + Waits for HELLO. + Caches any data received while waiting for HELLOs. + + Raises + ------ + RuntimeError + If received BYE while waiting for HELLO + + """ + + while neighbor not in self.barrier: + sender, recv = self.communication.receive() + + if "HELLO" in recv: + logging.debug("Received {} from {}".format("HELLO", sender)) + self.barrier.add(sender) + elif "BYE" in recv: + logging.debug("Received {} from {}".format("BYE", sender)) + raise RuntimeError( + "A neighbour wants to disconnect before training started!" + ) + else: + logging.debug( + "Received message from {} @ connect_neighbors".format(sender) + ) + self.message_queue.append((sender, recv)) + def receive(self): + if len(self.message_queue) > 0: + resp = self.message_queue.popleft() + else: + resp = self.communication.receive() + return resp + + def connect_neighbors(self): """ - plt.clf() - y_axis = [l[key] for key in l.keys()] - x_axis = list(map(int, l.keys())) - plt.plot(x_axis, y_axis, label=label) - plt.xlabel(xlabel) - plt.title(title) - plt.savefig(filename) + Connects all neighbors. Sends HELLO. Waits for HELLO. + Caches any data received while waiting for HELLOs. + + Raises + ------ + RuntimeError + If received BYE while waiting for HELLO + + """ + logging.info("Sending connection request to all neighbors") + for neighbor in self.my_neighbors: + self.connect_neighbor(neighbor) + + for neighbor in self.my_neighbors: + self.wait_for_hello(neighbor) + + def disconnect_neighbors(self): + """ + Disconnects all neighbors. + + Raises + ------ + RuntimeError + If received another message while waiting for BYEs + + """ + if not self.sent_disconnections: + logging.info("Disconnecting neighbors") + for uid in self.my_neighbors: + self.communication.send(uid, {"BYE": self.uid}) + self.sent_disconnections = True + while len(self.barrier): + sender, recv = self.receive() + if "BYE" in recv: + logging.debug("Received {} from {}".format("BYE", sender)) + self.barrier.remove(sender) + else: + logging.critical( + "Received unexpected {} from {}".format(recv, sender) + ) + raise RuntimeError( + "Received a message when expecting BYE from {}".format(sender) + ) def init_log(self, log_dir, rank, log_level, force=True): """ @@ -68,7 +129,7 @@ class Node: filename=log_file, format="[%(asctime)s][%(module)s][%(levelname)s] %(message)s", level=log_level, - force=True, + force=force, ) def cache_fields( @@ -79,12 +140,6 @@ class Node: graph, iterations, log_dir, - weights_store_dir, - test_after, - train_evaluate_after, - reset_optimizer, - centralized_train_eval, - centralized_test_eval, ): """ Instantiate object field with arguments. 
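
The HELLO/BYE handshake that used to live in TCP.connect_neighbors() is now driven by Node itself: connect_neighbor() opens the socket and sends HELLO, wait_for_hello() blocks until the peer's HELLO arrives, and any data message that races ahead of the handshake is parked in message_queue so that receive() replays it later. A minimal, self-contained sketch of that barrier logic, assuming a hypothetical `channel` object with the same receive() contract as the TCP class (illustrative only, not part of this patch):

    from collections import deque

    barrier, message_queue = set(), deque()

    def wait_for_hello(neighbor, channel):
        # Mirrors Node.wait_for_hello above: block until `neighbor` says HELLO,
        # parking any early data messages for replay by Node.receive().
        while neighbor not in barrier:
            sender, recv = channel.receive()
            if "HELLO" in recv:
                barrier.add(sender)
            elif "BYE" in recv:
                raise RuntimeError("A neighbour wants to disconnect before training started!")
            else:
                message_queue.append((sender, recv))
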
@@ -103,18 +158,6 @@ class Node: Number of iterations (communication steps) for which the model should be trained log_dir : str Logging directory - weights_store_dir : str - Directory in which to store model weights - test_after : int - Number of iterations after which the test loss and accuracy arecalculated - train_evaluate_after : int - Number of iterations after which the train loss is calculated - reset_optimizer : int - 1 if optimizer should be reset every communication round, else 0 - centralized_train_eval : bool - If set the train set evaluation happens at the node with uid 0 - centralized_test_eval : bool - If set the train set evaluation happens at the node with uid 0 """ self.rank = rank self.machine_id = machine_id @@ -122,19 +165,12 @@ class Node: self.mapping = mapping self.uid = self.mapping.get_uid(rank, machine_id) self.log_dir = log_dir - self.weights_store_dir = weights_store_dir self.iterations = iterations - self.test_after = test_after - self.train_evaluate_after = train_evaluate_after - self.reset_optimizer = reset_optimizer - self.centralized_train_eval = centralized_train_eval - self.centralized_test_eval = centralized_test_eval + self.sent_disconnections = False - logging.debug("Rank: %d", self.rank) - logging.debug("type(graph): %s", str(type(self.rank))) - logging.debug("type(mapping): %s", str(type(self.mapping))) - - self.star = Star(self.mapping.get_n_procs()) + logging.info("Rank: %d", self.rank) + logging.info("type(graph): %s", str(type(self.rank))) + logging.info("type(mapping): %s", str(type(self.mapping))) def init_dataset_model(self, dataset_configs): """ @@ -243,17 +279,6 @@ class Node: comm_class = getattr(comm_module, comm_configs["comm_class"]) comm_params = utils.remove_keys(comm_configs, ["comm_package", "comm_class"]) self.addresses_filepath = comm_params.get("addresses_filepath", None) - if self.centralized_test_eval: - self.testing_comm = TCP( - self.rank, - self.machine_id, - self.mapping, - self.star.n_procs, - self.addresses_filepath, - offset=self.star.n_procs, - ) - self.testing_comm.connect_neighbors(self.star.neighbors(self.uid)) - self.communication = comm_class( self.rank, self.machine_id, self.mapping, self.graph.n_procs, **comm_params ) @@ -294,13 +319,7 @@ class Node: config, iterations=1, log_dir=".", - weights_store_dir=".", log_level=logging.INFO, - test_after=5, - train_evaluate_after=1, - reset_optimizer=1, - centralized_train_eval=False, - centralized_test_eval=True, *args ): """ @@ -322,26 +341,16 @@ class Node: Number of iterations (communication steps) for which the model should be trained log_dir : str Logging directory - weights_store_dir : str - Directory in which to store model weights log_level : logging.Level One of DEBUG, INFO, WARNING, ERROR, CRITICAL - test_after : int - Number of iterations after which the test loss and accuracy arecalculated - train_evaluate_after : int - Number of iterations after which the train loss is calculated - reset_optimizer : int - 1 if optimizer should be reset every communication round, else 0 - centralized_train_eval : bool - If set the train set evaluation happens at the node with uid 0 - centralized_test_eval : bool - If set the train set evaluation happens at the node with uid 0 args : optional Other arguments """ logging.info("Started process.") + self.init_log(log_dir, rank, log_level) + self.cache_fields( rank, machine_id, @@ -349,18 +358,16 @@ class Node: graph, iterations, log_dir, - weights_store_dir, - test_after, - train_evaluate_after, - reset_optimizer, - 
centralized_train_eval, - centralized_test_eval, ) - self.init_log(log_dir, rank, log_level) self.init_dataset_model(config["DATASET"]) self.init_optimizer(config["OPTIMIZER_PARAMS"]) self.init_trainer(config["TRAIN_PARAMS"]) self.init_comm(config["COMMUNICATION"]) + + self.message_queue = deque() + self.barrier = set() + self.my_neighbors = self.graph.neighbors(self.uid) + self.init_sharing(config["SHARING"]) def run(self): @@ -368,146 +375,7 @@ class Node: Start the decentralized learning """ - self.testset = self.dataset.get_testset() - self.communication.connect_neighbors(self.graph.neighbors(self.uid)) - rounds_to_test = self.test_after - rounds_to_train_evaluate = self.train_evaluate_after - global_epoch = 1 - change = 1 - if self.uid == 0: - dataset = self.dataset - if self.centralized_train_eval: - dataset_params_copy = self.dataset_params.copy() - if "sizes" in dataset_params_copy: - del dataset_params_copy["sizes"] - self.whole_dataset = self.dataset_class( - self.rank, - self.machine_id, - self.mapping, - sizes=[1.0], - **dataset_params_copy - ) - dataset = self.whole_dataset - if self.centralized_test_eval: - tthelper = TrainTestHelper( - dataset, # self.whole_dataset, - # self.model_test, # todo: this only works if eval_train is set to false - self.model, - self.loss, - self.weights_store_dir, - self.mapping.get_n_procs(), - self.trainer, - self.testing_comm, - self.star, - self.threads_per_proc, - eval_train=self.centralized_train_eval, - ) - - for iteration in range(self.iterations): - logging.info("Starting training iteration: %d", iteration) - self.trainer.train(self.dataset) - - self.sharing.step() - - if self.reset_optimizer: - self.optimizer = self.optimizer_class( - self.model.parameters(), **self.optimizer_params - ) # Reset optimizer state - self.trainer.reset_optimizer(self.optimizer) - - if iteration: - with open( - os.path.join(self.log_dir, "{}_results.json".format(self.rank)), - "r", - ) as inf: - results_dict = json.load(inf) - else: - results_dict = { - "train_loss": {}, - "test_loss": {}, - "test_acc": {}, - "total_bytes": {}, - "total_meta": {}, - "total_data_per_n": {}, - "grad_mean": {}, - "grad_std": {}, - } - - results_dict["total_bytes"][iteration + 1] = self.communication.total_bytes - - if hasattr(self.communication, "total_meta"): - results_dict["total_meta"][ - iteration + 1 - ] = self.communication.total_meta - if hasattr(self.communication, "total_data"): - results_dict["total_data_per_n"][ - iteration + 1 - ] = self.communication.total_data - if hasattr(self.sharing, "mean"): - results_dict["grad_mean"][iteration + 1] = self.sharing.mean - if hasattr(self.sharing, "std"): - results_dict["grad_std"][iteration + 1] = self.sharing.std - - rounds_to_train_evaluate -= 1 - - if rounds_to_train_evaluate == 0 and not self.centralized_train_eval: - logging.info("Evaluating on train set.") - rounds_to_train_evaluate = self.train_evaluate_after * change - loss_after_sharing = self.trainer.eval_loss(self.dataset) - results_dict["train_loss"][iteration + 1] = loss_after_sharing - self.save_plot( - results_dict["train_loss"], - "train_loss", - "Training Loss", - "Communication Rounds", - os.path.join(self.log_dir, "{}_train_loss.png".format(self.rank)), - ) - - rounds_to_test -= 1 - - if self.dataset.__testing__ and rounds_to_test == 0: - rounds_to_test = self.test_after * change - # ta, tl = self.dataset.test(self.model, self.loss) - # self.model.dump_weights(self.weights_store_dir, self.uid, iteration) - if self.centralized_test_eval: - if self.uid == 0: - 
ta, tl, trl = tthelper.train_test_evaluation(iteration) - results_dict["test_acc"][iteration + 1] = ta - results_dict["test_loss"][iteration + 1] = tl - if trl is not None: - results_dict["train_loss"][iteration + 1] = trl - else: - self.testing_comm.send(0, self.model.get_weights()) - sender, data = self.testing_comm.receive() - assert sender == 0 and data == "finished" - else: - logging.info("Evaluating on test set.") - ta, tl = self.dataset.test(self.model, self.loss) - results_dict["test_acc"][iteration + 1] = ta - results_dict["test_loss"][iteration + 1] = tl - - if global_epoch == 49: - change *= 2 - - global_epoch += change - - with open( - os.path.join(self.log_dir, "{}_results.json".format(self.rank)), "w" - ) as of: - json.dump(results_dict, of) - if self.model.shared_parameters_counter is not None: - logging.info("Saving the shared parameter counts") - with open( - os.path.join( - self.log_dir, "{}_shared_parameters.json".format(self.rank) - ), - "w", - ) as of: - json.dump(self.model.shared_parameters_counter.numpy().tolist(), of) - self.communication.disconnect_neighbors() - logging.info("Storing final weight") - self.model.dump_weights(self.weights_store_dir, self.uid, iteration) - logging.info("All neighbors disconnected. Process complete!") + raise NotImplementedError def __init__( self, @@ -518,13 +386,7 @@ class Node: config, iterations=1, log_dir=".", - weights_store_dir=".", log_level=logging.INFO, - test_after=5, - train_evaluate_after=1, - reset_optimizer=1, - centralized_train_eval=0, - centralized_test_eval=1, *args ): """ @@ -559,28 +421,12 @@ class Node: log_dir : str Logging directory weights_store_dir : str - Directory in which to store model weights log_level : logging.Level One of DEBUG, INFO, WARNING, ERROR, CRITICAL - test_after : int - Number of iterations after which the test loss and accuracy arecalculated - train_evaluate_after : int - Number of iterations after which the train loss is calculated - reset_optimizer : int - 1 if optimizer should be reset every communication round, else 0 - centralized_train_eval : int - If set then the train set evaluation happens at the node with uid 0. - Note: If it is True then centralized_test_eval needs to be true as well! - centralized_test_eval : int - If set then the trainset evaluation happens at the node with uid 0 args : optional Other arguments """ - centralized_train_eval = centralized_train_eval == 1 - centralized_test_eval = centralized_test_eval == 1 - # If centralized_train_eval is True then centralized_test_eval needs to be true as well! 
-        assert not centralized_train_eval or centralized_test_eval
 
         total_threads = os.cpu_count()
         self.threads_per_proc = max(
@@ -588,25 +434,17 @@ class Node:
         )
         torch.set_num_threads(self.threads_per_proc)
         torch.set_num_interop_threads(1)
-        self.instantiate(
-            rank,
-            machine_id,
-            mapping,
-            graph,
-            config,
-            iterations,
-            log_dir,
-            weights_store_dir,
-            log_level,
-            test_after,
-            train_evaluate_after,
-            reset_optimizer,
-            centralized_train_eval == 1,
-            centralized_test_eval == 1,
-            *args
-        )
+        # self.instantiate(
+        #     rank,
+        #     machine_id,
+        #     mapping,
+        #     graph,
+        #     config,
+        #     iterations,
+        #     log_dir,
+        #     log_level,
+        #     *args
+        # )
         logging.info(
             "Each proc uses %d threads out of %d.", self.threads_per_proc, total_threads
         )
-
-        self.run()
diff --git a/src/decentralizepy/node/PeerSampler.py b/src/decentralizepy/node/PeerSampler.py
new file mode 100644
index 0000000..8f8db6f
--- /dev/null
+++ b/src/decentralizepy/node/PeerSampler.py
@@ -0,0 +1,216 @@
+import importlib
+import logging
+from collections import deque
+
+from decentralizepy import utils
+from decentralizepy.graphs.Graph import Graph
+from decentralizepy.mappings.Mapping import Mapping
+from decentralizepy.node.Node import Node
+
+
+class PeerSampler(Node):
+    """
+    This class defines the peer sampling service
+
+    """
+
+    def cache_fields(
+        self,
+        rank,
+        machine_id,
+        mapping,
+        graph,
+        iterations,
+        log_dir,
+    ):
+        """
+        Instantiate object field with arguments.
+
+        Parameters
+        ----------
+        rank : int
+            Rank of process local to the machine
+        machine_id : int
+            Machine ID on which the process is running
+        mapping : decentralizepy.mappings
+            The object containing the mapping rank <--> uid
+        graph : decentralizepy.graphs
+            The object containing the global graph
+        iterations : int
+            Number of iterations (communication steps) for which the model should be trained
+        log_dir : str
+            Logging directory
+
+        """
+        self.rank = rank
+        self.machine_id = machine_id
+        self.graph = graph
+        self.mapping = mapping
+        self.uid = self.mapping.get_uid(rank, machine_id)
+        self.log_dir = log_dir
+        self.iterations = iterations
+        self.sent_disconnections = False
+
+        logging.info("Rank: %d", self.rank)
+        logging.info("type(graph): %s", str(type(self.graph)))
+        logging.info("type(mapping): %s", str(type(self.mapping)))
+
+    def init_comm(self, comm_configs):
+        """
+        Instantiate communication module from config.
+
+        Parameters
+        ----------
+        comm_configs : dict
+            Python dict containing communication config params
+
+        """
+        comm_module = importlib.import_module(comm_configs["comm_package"])
+        comm_class = getattr(comm_module, comm_configs["comm_class"])
+        comm_params = utils.remove_keys(comm_configs, ["comm_package", "comm_class"])
+        self.addresses_filepath = comm_params.get("addresses_filepath", None)
+        self.communication = comm_class(
+            self.rank, self.machine_id, self.mapping, self.graph.n_procs, **comm_params
+        )
+
+    def instantiate(
+        self,
+        rank: int,
+        machine_id: int,
+        mapping: Mapping,
+        graph: Graph,
+        config,
+        iterations=1,
+        log_dir=".",
+        log_level=logging.INFO,
+        *args
+    ):
+        """
+        Construct objects.
+
+        Parameters
+        ----------
+        rank : int
+            Rank of process local to the machine
+        machine_id : int
+            Machine ID on which the process is running
+        mapping : decentralizepy.mappings
+            The object containing the mapping rank <--> uid
+        graph : decentralizepy.graphs
+            The object containing the global graph
+        config : dict
+            A dictionary of configurations.
+        iterations : int
+            Number of iterations (communication steps) for which the model should be trained
+        log_dir : str
+            Logging directory
+        log_level : logging.Level
+            One of DEBUG, INFO, WARNING, ERROR, CRITICAL
+        args : optional
+            Other arguments
+
+        """
+        logging.info("Started process.")
+
+        self.init_log(log_dir, rank, log_level)
+
+        self.cache_fields(
+            rank,
+            machine_id,
+            mapping,
+            graph,
+            iterations,
+            log_dir,
+        )
+
+        self.message_queue = deque()
+        self.barrier = set()
+
+        self.init_comm(config["COMMUNICATION"])
+        self.my_neighbors = self.graph.get_all_nodes()
+        self.connect_neighbors()
+
+    def get_neighbors(self, node):
+        # Static peer sampling: serve the node's neighbourhood in the global graph.
+        return self.graph.neighbors(node)
+
+    def run(self):
+        """
+        Start the peer-sampling service.
+
+        """
+        while len(self.barrier) > 0:
+            sender, data = self.receive()
+            if "BYE" in data:
+                logging.debug("Received {} from {}".format("BYE", sender))
+                self.barrier.remove(sender)
+            else:
+                logging.debug("Received {} from {}".format("Request", sender))
+                resp = {"neighbors": self.get_neighbors(sender)}
+                self.communication.send(sender, resp)
+
+    def __init__(
+        self,
+        rank: int,
+        machine_id: int,
+        mapping: Mapping,
+        graph: Graph,
+        config,
+        iterations=1,
+        log_dir=".",
+        log_level=logging.INFO,
+        *args
+    ):
+        """
+        Constructor
+
+        Parameters
+        ----------
+        rank : int
+            Rank of process local to the machine
+        machine_id : int
+            Machine ID on which the process is running
+        mapping : decentralizepy.mappings
+            The object containing the mapping rank <--> uid
+        graph : decentralizepy.graphs
+            The object containing the global graph
+        config : dict
+            A dictionary of configurations. Must contain the following:
+            [COMMUNICATION]
+                comm_package
+                comm_class
+        iterations : int
+            Number of iterations (communication steps) for which the model should be trained
+        log_dir : str
+            Logging directory
+        log_level : logging.Level
+            One of DEBUG, INFO, WARNING, ERROR, CRITICAL
+        args : optional
+            Other arguments
+
+        """
+        super().__init__(
+            rank,
+            machine_id,
+            mapping,
+            graph,
+            config,
+            iterations,
+            log_dir,
+            log_level,
+            *args
+        )
+
+        self.instantiate(
+            rank,
+            machine_id,
+            mapping,
+            graph,
+            config,
+            iterations,
+            log_dir,
+            log_level,
+            *args
+        )
+
+        self.run()
diff --git a/src/decentralizepy/sharing/FFT.py b/src/decentralizepy/sharing/FFT.py
index 17650c1..60c60ec 100644
--- a/src/decentralizepy/sharing/FFT.py
+++ b/src/decentralizepy/sharing/FFT.py
@@ -1,8 +1,6 @@
 import json
 import logging
 import os
-from pathlib import Path
-from time import time
 
 import numpy as np
 import torch
@@ -53,6 +51,9 @@ class FFT(PartialModel):
         save_accumulated="",
         accumulation=True,
         accumulate_averaging_changes=False,
+        compress=False,
+        compression_package=None,
+        compression_class=None,
     ):
         """
         Constructor
@@ -111,6 +112,9 @@ class FFT(PartialModel):
             save_accumulated,
             change_transformer_fft,
             accumulate_averaging_changes,
+            compress,
+            compression_package,
+            compression_class,
         )
 
         self.change_based_selection = change_based_selection
@@ -163,7 +167,7 @@ class FFT(PartialModel):
                     self.model.accumulated_changes = torch.zeros_like(
                         self.model.accumulated_changes
                    )
-                return m
+                return self.compress_data(m)
 
         with torch.no_grad():
             topk, indices = self.apply_fft()
@@ -199,7 +203,7 @@ class FFT(PartialModel):
             m["indices"] = indices.numpy().astype(np.int32)
 
             m["send_partial"] = True
-        return m
+        return self.compress_data(m)
 
     def deserialized_model(self, m):
         """
 @@ 
-216,6 +220,8 @@ class FFT(PartialModel): state_dict of received """ + m = self.decompress_data(m) + ret = dict() if "send_partial" not in m: params = m["params"] @@ -237,7 +243,7 @@ class FFT(PartialModel): ret["send_partial"] = True return ret - def _averaging(self): + def _averaging(self, peer_deques): """ Averages the received model with the local model @@ -251,8 +257,11 @@ class FFT(PartialModel): pre_share_model = torch.cat(tensors_to_cat, dim=0) flat_fft = self.change_transformer(pre_share_model) - for i, n in enumerate(self.peer_deques): - degree, iteration, data = self.peer_deques[n].popleft() + for i, n in enumerate(peer_deques): + data = peer_deques[n].popleft() + degree, iteration = data["degree"], data["iteration"] + del data["degree"] + del data["iteration"] logging.debug( "Averaging model from neighbor {} of iteration {}".format( n, iteration @@ -268,7 +277,7 @@ class FFT(PartialModel): else: topkf = params - weight = 1 / (max(len(self.peer_deques), degree) + 1) # Metro-Hastings + weight = 1 / (max(len(peer_deques), degree) + 1) # Metro-Hastings weight_total += weight if total is None: total = weight * topkf @@ -289,3 +298,5 @@ class FFT(PartialModel): start_index = end_index self.model.load_state_dict(std_dict) + self._post_step() + self.communication_round += 1 diff --git a/src/decentralizepy/sharing/GrowingAlpha.py b/src/decentralizepy/sharing/GrowingAlpha.py index 7fe7bf5..a13a869 100644 --- a/src/decentralizepy/sharing/GrowingAlpha.py +++ b/src/decentralizepy/sharing/GrowingAlpha.py @@ -1,3 +1,4 @@ +# Deprecated import logging from decentralizepy.sharing.PartialModel import PartialModel @@ -25,6 +26,9 @@ class GrowingAlpha(PartialModel): dict_ordered=True, save_shared=False, metadata_cap=1.0, + compress=False, + compression_package=None, + compression_class=None, ): """ Constructor @@ -74,12 +78,15 @@ class GrowingAlpha(PartialModel): dict_ordered, save_shared, metadata_cap, + compress, + compression_package, + compression_class, ) self.init_alpha = init_alpha self.max_alpha = max_alpha self.k = k - def step(self): + def get_data_to_send(self): """ Perform a sharing step. Implements D-PSGD with alpha increasing as a linear function. 
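
The weight computed in the reworked _averaging() methods above, weight = 1 / (max(len(peer_deques), degree) + 1), is the Metropolis-Hastings rule: the edge between nodes i and j gets weight 1 / (1 + max(d_i, d_j)), and the local model keeps whatever weight remains, so every row of the mixing matrix sums to 1. A scalar sketch of the same computation (illustrative only; the real code applies these weights to flattened model tensors):

    def metro_hastings_average(own_value, own_degree, received):
        # received: list of (value, sender_degree) pairs, one per neighbour
        total, weight_sum = 0.0, 0.0
        for value, sender_degree in received:
            w = 1.0 / (max(own_degree, sender_degree) + 1)
            total += w * value
            weight_sum += w
        # the local model keeps the leftover weight
        return total + (1.0 - weight_sum) * own_value
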
@@ -93,4 +100,4 @@ class GrowingAlpha(PartialModel):
             self.communication_round += 1
-            return
+            return dict()
 
-        super().step()
+        return super().get_data_to_send()
diff --git a/src/decentralizepy/sharing/LowerBoundTopK.py b/src/decentralizepy/sharing/LowerBoundTopK.py
index 6ac5329..86b9c3b 100644
--- a/src/decentralizepy/sharing/LowerBoundTopK.py
+++ b/src/decentralizepy/sharing/LowerBoundTopK.py
@@ -24,6 +24,9 @@ class LowerBoundTopK(PartialModel):
         log_dir,
         lower_bound=0.1,
         metro_hastings=True,
+        compress=False,
+        compression_package=None,
+        compression_class=None,
         **kwargs,
     ):
         """
@@ -81,7 +84,10 @@ class LowerBoundTopK(PartialModel):
             model,
             dataset,
             log_dir,
-            **kwargs,
+            compress=compress,
+            compression_package=compression_package,
+            compression_class=compression_class,
+            **kwargs,
         )
         self.lower_bound = lower_bound
         self.metro_hastings = metro_hastings
@@ -154,6 +159,8 @@ class LowerBoundTopK(PartialModel):
         if "send_partial" not in m:
             return super().deserialized_model(m)
 
+        m = self.decompress_data(m)
+
         with torch.no_grad():
             state_dict = self.model.state_dict()
 
@@ -169,7 +176,7 @@ class LowerBoundTopK(PartialModel):
 
         return T, index_tensor
 
-    def _averaging(self):
+    def _averaging(self, peer_deques):
         """
         Averages the received model with the local model
 
@@ -187,8 +194,11 @@ class LowerBoundTopK(PartialModel):
             weight_total = 0
             weight_vector = torch.ones_like(self.init_model)
             datas = []
-            for i, n in enumerate(self.peer_deques):
-                degree, iteration, data = self.peer_deques[n].popleft()
+            for i, n in enumerate(peer_deques):
+                data = peer_deques[n].popleft()
+                degree, iteration = data["degree"], data["iteration"]
+                del data["degree"]
+                del data["iteration"]
                 logging.debug(
                     "Averaging model from neighbor {} of iteration {}".format(
                         n, iteration
@@ -215,3 +225,5 @@ class LowerBoundTopK(PartialModel):
 
         logging.info("new averaging")
         self.model.load_state_dict(total)
+        self._post_step()
+        self.communication_round += 1
diff --git a/src/decentralizepy/sharing/ManualAdapt.py b/src/decentralizepy/sharing/ManualAdapt.py
index dcb94cf..9a54eb7 100644
--- a/src/decentralizepy/sharing/ManualAdapt.py
+++ b/src/decentralizepy/sharing/ManualAdapt.py
@@ -1,3 +1,4 @@
+# Deprecated
 import logging
 
 from decentralizepy.sharing.PartialModel import PartialModel
@@ -24,6 +25,9 @@ class ManualAdapt(PartialModel):
         dict_ordered=True,
         save_shared=False,
         metadata_cap=1.0,
+        compress=False,
+        compression_package=None,
+        compression_class=None,
     ):
         """
         Constructor
@@ -81,11 +85,14 @@ class ManualAdapt(PartialModel):
             dict_ordered,
             save_shared,
             metadata_cap,
+            compress=compress,
+            compression_package=compression_package,
+            compression_class=compression_class,
         )
         self.change_alpha = change_alpha[1:]
         self.change_rounds = change_rounds
 
-    def step(self):
+    def get_data_to_send(self):
         """
         Perform a sharing step. Implements D-PSGD with alpha manually given. 
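
With step() renamed to get_data_to_send() across the sharing classes, serialization and communication are decoupled: a sharing object only produces and consumes payloads, while the node owns the network round. The per-iteration exchange now reduces to roughly the following (a condensed restatement of the loop in DPSGDNode.run() above, not a new API):

    to_send = self.sharing.get_data_to_send()       # serialize, possibly compressed
    for neighbor in self.my_neighbors:
        self.communication.send(neighbor, to_send)
    while not self.received_from_all():             # buffer arrivals per sender
        sender, data = self.receive()
        self.peer_deques[sender].append(data)
    self.sharing._averaging(self.peer_deques)       # merge models, then _post_step()
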
@@ -101,6 +108,6 @@ class ManualAdapt(PartialModel): if self.alpha == 0.0: logging.info("Not sending/receiving data (alpha=0.0)") self.communication_round += 1 - return + return dict() - super().step() + return super().get_data_to_send() diff --git a/src/decentralizepy/sharing/PartialModel.py b/src/decentralizepy/sharing/PartialModel.py index 3111e82..b302f5d 100644 --- a/src/decentralizepy/sharing/PartialModel.py +++ b/src/decentralizepy/sharing/PartialModel.py @@ -34,6 +34,9 @@ class PartialModel(Sharing): save_accumulated="", change_transformer=identity, accumulate_averaging_changes=False, + compress=False, + compression_package=None, + compression_class=None, ): """ Constructor @@ -76,7 +79,17 @@ class PartialModel(Sharing): """ super().__init__( - rank, machine_id, communication, mapping, graph, model, dataset, log_dir + rank, + machine_id, + communication, + mapping, + graph, + model, + dataset, + log_dir, + compress, + compression_package, + compression_class, ) self.alpha = alpha self.dict_ordered = dict_ordered @@ -129,6 +142,23 @@ class PartialModel(Sharing): self.change_transformer(self.init_model).shape[0], dtype=torch.int32 ) + def compress_data(self, data): + result = dict(data) + if self.compress: + if "indices" in result: + result["indices"] = self.compressor.compress(result["indices"]) + if "params" in result: + result["params"] = self.compressor.compress_float(result["params"]) + return result + + def decompress_data(self, data): + if self.compress: + if "indices" in data: + data["indices"] = self.compressor.decompress(data["indices"]) + if "params" in data: + data["params"] = self.compressor.decompress_float(data["params"]) + return data + def extract_top_gradients(self): """ Extract the indices and values of the topK gradients. @@ -220,7 +250,7 @@ class PartialModel(Sharing): logging.info("Converted dictionary to pickle") - return m + return self.compress_data(m) def deserialized_model(self, m): """ @@ -241,6 +271,8 @@ class PartialModel(Sharing): return super().deserialized_model(m) with torch.no_grad(): + m = self.decompress_data(m) + state_dict = self.model.state_dict() if not self.dict_ordered: diff --git a/src/decentralizepy/sharing/RandomAlpha.py b/src/decentralizepy/sharing/RandomAlpha.py index 1956c29..3bac634 100644 --- a/src/decentralizepy/sharing/RandomAlpha.py +++ b/src/decentralizepy/sharing/RandomAlpha.py @@ -28,6 +28,9 @@ class RandomAlpha(PartialModel): save_accumulated="", change_transformer=identity, accumulate_averaging_changes=False, + compress=False, + compression_package=None, + compression_class=None, ): """ Constructor @@ -75,14 +78,17 @@ class RandomAlpha(PartialModel): save_accumulated, change_transformer, accumulate_averaging_changes, + compress, + compression_package, + compression_class, ) self.alpha_list = eval(alpha_list) random.seed(self.mapping.get_uid(self.rank, self.machine_id)) - def step(self): + def get_data_to_send(self): """ Perform a sharing step. Implements D-PSGD with alpha randomly chosen. 
""" self.alpha = random.choice(self.alpha_list) - super().step() + return super().get_data_to_send() diff --git a/src/decentralizepy/sharing/RandomAlphaIncremental.py b/src/decentralizepy/sharing/RandomAlphaIncremental.py index c3b7c0d..96ead3d 100644 --- a/src/decentralizepy/sharing/RandomAlphaIncremental.py +++ b/src/decentralizepy/sharing/RandomAlphaIncremental.py @@ -1,3 +1,4 @@ +# Deprecated import random from decentralizepy.sharing.PartialModel import PartialModel @@ -24,6 +25,9 @@ class RandomAlphaIncremental(PartialModel): metadata_cap=1.0, range_start=0.1, range_end=0.2, + compress=False, + compression_package=None, + compression_class=None, ): """ Constructor @@ -67,16 +71,19 @@ class RandomAlphaIncremental(PartialModel): dict_ordered, save_shared, metadata_cap, + compress, + compression_package, + compression_class, ) random.seed(self.mapping.get_uid(self.rank, self.machine_id)) self.range_start = range_start self.range_end = range_end - def step(self): + def get_data_to_send(self): """ Perform a sharing step. Implements D-PSGD with alpha randomly chosen from an increasing range. """ self.alpha = round(random.uniform(self.range_start, self.range_end), 2) self.range_end = min(1.0, self.range_end + round(random.uniform(0.0, 0.1), 2)) - super().step() + return super().get_data_to_send() diff --git a/src/decentralizepy/sharing/RandomAlphaWavelet.py b/src/decentralizepy/sharing/RandomAlphaWavelet.py index 44ea336..de2a5e6 100644 --- a/src/decentralizepy/sharing/RandomAlphaWavelet.py +++ b/src/decentralizepy/sharing/RandomAlphaWavelet.py @@ -29,6 +29,9 @@ class RandomAlpha(Wavelet): save_accumulated="", accumulation=False, accumulate_averaging_changes=False, + compress=False, + compression_package=None, + compression_class=None, ): """ Constructor @@ -78,14 +81,17 @@ class RandomAlpha(Wavelet): save_accumulated, accumulation, accumulate_averaging_changes, + compress, + compression_package, + compression_class, ) self.alpha_list = eval(alpha_list) random.seed(self.mapping.get_uid(self.rank, self.machine_id)) - def step(self): + def get_data_to_send(self): """ Perform a sharing step. Implements D-PSGD with alpha randomly chosen. 
""" self.alpha = random.choice(self.alpha_list) - super().step() + return super().get_data_to_send() diff --git a/src/decentralizepy/sharing/RoundRobinPartial.py b/src/decentralizepy/sharing/RoundRobinPartial.py index c5288a5..fbe0179 100644 --- a/src/decentralizepy/sharing/RoundRobinPartial.py +++ b/src/decentralizepy/sharing/RoundRobinPartial.py @@ -25,6 +25,9 @@ class RoundRobinPartial(Sharing): dataset, log_dir, alpha=1.0, + compress=False, + compression_package=None, + compression_class=None, ): """ Constructor @@ -52,7 +55,17 @@ class RoundRobinPartial(Sharing): """ super().__init__( - rank, machine_id, communication, mapping, graph, model, dataset, log_dir + rank, + machine_id, + communication, + mapping, + graph, + model, + dataset, + log_dir, + compress, + compression_package, + compression_class, ) self.alpha = alpha random.seed(self.mapping.get_uid(rank, machine_id)) @@ -104,7 +117,7 @@ class RoundRobinPartial(Sharing): logging.info("Converted dictionary to json") self.total_data += len(self.communication.encrypt(m["params"])) - return m + return self.compress_data(m) def deserialized_model(self, m): """ @@ -121,9 +134,9 @@ class RoundRobinPartial(Sharing): state_dict of received """ + m = self.decompress_data(m) with torch.no_grad(): state_dict = self.model.state_dict() - shapes = [] lens = [] tensors_to_cat = [] diff --git a/src/decentralizepy/sharing/Sharing.py b/src/decentralizepy/sharing/Sharing.py index 0ad3927..7ef18cc 100644 --- a/src/decentralizepy/sharing/Sharing.py +++ b/src/decentralizepy/sharing/Sharing.py @@ -1,5 +1,5 @@ +import importlib import logging -from collections import deque import torch @@ -11,7 +11,18 @@ class Sharing: """ def __init__( - self, rank, machine_id, communication, mapping, graph, model, dataset, log_dir + self, + rank, + machine_id, + communication, + mapping, + graph, + model, + dataset, + log_dir, + compress=False, + compression_package=None, + compression_class=None, ): """ Constructor @@ -47,11 +58,6 @@ class Sharing: self.communication_round = 0 self.log_dir = log_dir - self.peer_deques = dict() - self.my_neighbors = self.graph.neighbors(self.uid) - for n in self.my_neighbors: - self.peer_deques[n] = deque() - self.shapes = [] self.lens = [] with torch.no_grad(): @@ -60,38 +66,28 @@ class Sharing: t = v.flatten().numpy() self.lens.append(t.shape[0]) - def received_from_all(self): - """ - Check if all neighbors have sent the current iteration - - Returns - ------- - bool - True if required data has been received, False otherwise - - """ - for _, i in self.peer_deques.items(): - if len(i) == 0: - return False - return True - - def get_neighbors(self, neighbors): - """ - Choose which neighbors to share with - - Parameters - ---------- - neighbors : list(int) - List of all neighbors - - Returns - ------- - list(int) - Neighbors to share with - - """ - # modify neighbors here - return neighbors + self.compress = compress + + if compression_package and compression_class: + compressor_module = importlib.import_module(compression_package) + compressor_class = getattr(compressor_module, compression_class) + self.compressor = compressor_class() + logging.info(f"Using the {compressor_class} to compress the data") + else: + assert not self.compress + + def compress_data(self, data): + result = dict(data) + if self.compress: + if "params" in result: + result["params"] = self.compressor.compress_float(result["params"]) + return result + + def decompress_data(self, data): + if self.compress: + if "params" in data: + data["params"] = 
self.compressor.decompress_float(data["params"]) + return data def serialized_model(self): """ @@ -111,7 +107,7 @@ class Sharing: flat = torch.cat(to_cat) data = dict() data["params"] = flat.numpy() - return data + return self.compress_data(data) def deserialized_model(self, m): """ @@ -129,11 +125,14 @@ class Sharing: """ state_dict = dict() + m = self.decompress_data(m) T = m["params"] start_index = 0 for i, key in enumerate(self.model.state_dict()): end_index = start_index + self.lens[i] - state_dict[key] = torch.from_numpy(T[start_index:end_index].reshape(self.shapes[i])) + state_dict[key] = torch.from_numpy( + T[start_index:end_index].reshape(self.shapes[i]) + ) start_index = end_index return state_dict @@ -152,7 +151,7 @@ class Sharing: """ pass - def _averaging(self): + def _averaging(self, peer_deques): """ Averages the received model with the local model @@ -160,15 +159,18 @@ class Sharing: with torch.no_grad(): total = dict() weight_total = 0 - for i, n in enumerate(self.peer_deques): - degree, iteration, data = self.peer_deques[n].popleft() + for i, n in enumerate(peer_deques): + data = peer_deques[n].popleft() + degree, iteration = data["degree"], data["iteration"] + del data["degree"] + del data["iteration"] logging.debug( "Averaging model from neighbor {} of iteration {}".format( n, iteration ) ) data = self.deserialized_model(data) - weight = 1 / (max(len(self.peer_deques), degree) + 1) # Metro-Hastings + weight = 1 / (max(len(peer_deques), degree) + 1) # Metro-Hastings weight_total += weight for key, value in data.items(): if key in total: @@ -180,41 +182,14 @@ class Sharing: total[key] += (1 - weight_total) * value # Metro-Hastings self.model.load_state_dict(total) + self._post_step() + self.communication_round += 1 - def step(self): - """ - Perform a sharing step. Implements D-PSGD. 
- - """ + def get_data_to_send(self): self._pre_step() data = self.serialized_model() my_uid = self.mapping.get_uid(self.rank, self.machine_id) all_neighbors = self.graph.neighbors(my_uid) - iter_neighbors = self.get_neighbors(all_neighbors) data["degree"] = len(all_neighbors) data["iteration"] = self.communication_round - encrypted = self.communication.encrypt(data) - for neighbor in iter_neighbors: - self.communication.send(neighbor, encrypted, encrypt=False) - - logging.info("Waiting for messages from neighbors") - while not self.received_from_all(): - sender, data = self.communication.receive() - logging.debug("Received model from {}".format(sender)) - degree = data["degree"] - iteration = data["iteration"] - del data["degree"] - del data["iteration"] - self.peer_deques[sender].append((degree, iteration, data)) - logging.info( - "Deserialized received model from {} of iteration {}".format( - sender, iteration - ) - ) - - logging.info("Starting model averaging after receiving from all neighbors") - self._averaging() - logging.info("Model averaging complete") - - self.communication_round += 1 - self._post_step() + return data diff --git a/src/decentralizepy/sharing/SharingCentrality.py b/src/decentralizepy/sharing/SharingCentrality.py index f933a0e..05986ac 100644 --- a/src/decentralizepy/sharing/SharingCentrality.py +++ b/src/decentralizepy/sharing/SharingCentrality.py @@ -1,3 +1,4 @@ +# Deprecated import logging from collections import deque diff --git a/src/decentralizepy/sharing/SubSampling.py b/src/decentralizepy/sharing/SubSampling.py index b51cb07..7201d33 100644 --- a/src/decentralizepy/sharing/SubSampling.py +++ b/src/decentralizepy/sharing/SubSampling.py @@ -31,6 +31,9 @@ class SubSampling(Sharing): metadata_cap=1.0, pickle=True, layerwise=False, + compress=False, + compression_package=None, + compression_class=None, ): """ Constructor @@ -66,7 +69,17 @@ class SubSampling(Sharing): """ super().__init__( - rank, machine_id, communication, mapping, graph, model, dataset, log_dir + rank, + machine_id, + communication, + mapping, + graph, + model, + dataset, + log_dir, + compress, + compression_package, + compression_class, ) self.alpha = alpha self.dict_ordered = dict_ordered @@ -215,7 +228,7 @@ class SubSampling(Sharing): m["alpha"] = alpha m["params"] = subsample.numpy() - return m + return self.compress_data(m) def deserialized_model(self, m): """ @@ -235,6 +248,8 @@ class SubSampling(Sharing): if self.alpha > self.metadata_cap: # Share fully return super().deserialized_model(m) + m = self.decompress_data(m) + with torch.no_grad(): state_dict = self.model.state_dict() diff --git a/src/decentralizepy/sharing/Synchronous.py b/src/decentralizepy/sharing/Synchronous.py index 2c2d5e7..7fc1c35 100644 --- a/src/decentralizepy/sharing/Synchronous.py +++ b/src/decentralizepy/sharing/Synchronous.py @@ -1,3 +1,4 @@ +# Deprecated import logging from collections import deque diff --git a/src/decentralizepy/sharing/TopKNormalized.py b/src/decentralizepy/sharing/TopKNormalized.py index 15a3caf..b281294 100644 --- a/src/decentralizepy/sharing/TopKNormalized.py +++ b/src/decentralizepy/sharing/TopKNormalized.py @@ -31,6 +31,9 @@ class TopKNormalized(PartialModel): change_transformer=identity, accumulate_averaging_changes=False, epsilon=0.01, + compress=False, + compression_package=None, + compression_class=None, ): """ Constructor @@ -91,6 +94,9 @@ class TopKNormalized(PartialModel): save_accumulated, change_transformer, accumulate_averaging_changes, + compress, + compression_package, + 
compression_class, ) self.epsilon = epsilon diff --git a/src/decentralizepy/sharing/TopKParams.py b/src/decentralizepy/sharing/TopKParams.py index f188179..c2b0e3f 100644 --- a/src/decentralizepy/sharing/TopKParams.py +++ b/src/decentralizepy/sharing/TopKParams.py @@ -29,6 +29,9 @@ class TopKParams(Sharing): dict_ordered=True, save_shared=False, metadata_cap=1.0, + compress=False, + compression_package=None, + compression_class=None, ): """ Constructor @@ -62,7 +65,17 @@ class TopKParams(Sharing): """ super().__init__( - rank, machine_id, communication, mapping, graph, model, dataset, log_dir + rank, + machine_id, + communication, + mapping, + graph, + model, + dataset, + log_dir, + compress, + compression_package, + compression_class, ) self.alpha = alpha self.dict_ordered = dict_ordered @@ -171,7 +184,7 @@ class TopKParams(Sharing): logging.info("Converted dictionary to json") - return m + return self.compress_data(m) def deserialized_model(self, m): """ @@ -191,6 +204,8 @@ class TopKParams(Sharing): if self.alpha > self.metadata_cap: # Share fully return super().deserialized_model(m) + m = self.decompress_data(m) + with torch.no_grad(): state_dict = self.model.state_dict() diff --git a/src/decentralizepy/sharing/TopKPlusRandom.py b/src/decentralizepy/sharing/TopKPlusRandom.py index 728d5bf..8962933 100644 --- a/src/decentralizepy/sharing/TopKPlusRandom.py +++ b/src/decentralizepy/sharing/TopKPlusRandom.py @@ -26,6 +26,9 @@ class TopKPlusRandom(PartialModel): dict_ordered=True, save_shared=False, metadata_cap=1.0, + compress=False, + compression_package=None, + compression_class=None, ): """ Constructor @@ -71,6 +74,9 @@ class TopKPlusRandom(PartialModel): dict_ordered, save_shared, metadata_cap, + compress, + compression_package, + compression_class, ) def extract_top_gradients(self): diff --git a/src/decentralizepy/sharing/Wavelet.py b/src/decentralizepy/sharing/Wavelet.py index 91c97d0..ffc2558 100644 --- a/src/decentralizepy/sharing/Wavelet.py +++ b/src/decentralizepy/sharing/Wavelet.py @@ -1,8 +1,6 @@ import json import logging import os -from pathlib import Path -from time import time import numpy as np import pywt @@ -61,6 +59,9 @@ class Wavelet(PartialModel): save_accumulated="", accumulation=False, accumulate_averaging_changes=False, + compress=False, + compression_package=None, + compression_class=None, ): """ Constructor @@ -125,6 +126,9 @@ class Wavelet(PartialModel): save_accumulated, lambda x: change_transformer_wavelet(x, wavelet, level), accumulate_averaging_changes, + compress, + compression_package, + compression_class, ) self.change_based_selection = change_based_selection @@ -185,7 +189,7 @@ class Wavelet(PartialModel): self.model.accumulated_changes = torch.zeros_like( self.model.accumulated_changes ) - return m + return self.compress_data(m) with torch.no_grad(): topk, indices = self.apply_wavelet() @@ -223,7 +227,7 @@ class Wavelet(PartialModel): m["send_partial"] = True - return m + return self.compress_data(m) def deserialized_model(self, m): """ @@ -240,6 +244,7 @@ class Wavelet(PartialModel): state_dict of received """ + m = self.decompress_data(m) ret = dict() if "send_partial" not in m: params = m["params"] @@ -260,7 +265,7 @@ class Wavelet(PartialModel): ret["send_partial"] = True return ret - def _averaging(self): + def _averaging(self, peer_deques): """ Averages the received model with the local model @@ -269,8 +274,11 @@ class Wavelet(PartialModel): total = None weight_total = 0 wt_params = self.pre_share_model_transformed - for i, n in 
enumerate(self.peer_deques): - degree, iteration, data = self.peer_deques[n].popleft() + for i, n in enumerate(peer_deques): + data = peer_deques[n].popleft() + degree, iteration = data["degree"], data["iteration"] + del data["degree"] + del data["iteration"] logging.debug( "Averaging model from neighbor {} of iteration {}".format( n, iteration @@ -287,7 +295,7 @@ class Wavelet(PartialModel): else: topkwf = params.reshape(self.wt_shape) - weight = 1 / (max(len(self.peer_deques), degree) + 1) # Metro-Hastings + weight = 1 / (max(len(peer_deques), degree) + 1) # Metro-Hastings weight_total += weight if total is None: total = weight * topkwf @@ -314,3 +322,5 @@ class Wavelet(PartialModel): start_index = end_index self.model.load_state_dict(std_dict) + self._post_step() + self.communication_round += 1 diff --git a/src/decentralizepy/train_test_evaluation.py b/src/decentralizepy/train_test_evaluation.py index 319d308..95f407c 100644 --- a/src/decentralizepy/train_test_evaluation.py +++ b/src/decentralizepy/train_test_evaluation.py @@ -1,6 +1,5 @@ import logging import os -import pickle from pathlib import Path import numpy as np -- GitLab
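For reference, Wavelet above passes `lambda x: change_transformer_wavelet(x, wavelet, level)` as the change transformer. A sketch of what such a transformer plausibly does, assuming the intent is a 1-D discrete wavelet transform of the flattened model so that top-k selection happens in coefficient space; this is an illustration, not the actual decentralizepy implementation:

import pywt
import torch

# Assumed semantics: flatten the model, run a 1-D DWT, and return the
# concatenated coefficients so callers can rank and subsample them.
def change_transformer_wavelet(x, wavelet, level):
    coeffs = pywt.wavedec(x.numpy(), wavelet, level=level)
    data, _ = pywt.coeffs_to_array(coeffs)
    return torch.from_numpy(data.ravel())

flat_model = torch.randn(1024)
coefficients = change_transformer_wavelet(flat_model, "haar", 4)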
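Taken together, the sharing-class changes in this patch reduce `step()` to `get_data_to_send()`: serialization and compression stay in the sharing classes, while sending, receiving, and the blocking wait move into the node (the new DPSGDNode). A schematic of the resulting per-round control flow; `send` and `receive` are placeholders rather than the exact decentralizepy.communication API:

# Schematic of the per-round control flow this patch moves out of
# Sharing.step() and into the node classes.
def training_round(sharing, communication, neighbors, peer_deques):
    data = sharing.get_data_to_send()  # serialized and, if enabled, compressed
    for neighbor in neighbors:
        communication.send(neighbor, data)

    # Block until one message from every neighbor has been queued.
    while not all(len(q) > 0 for q in peer_deques.values()):
        sender, message = communication.receive()
        peer_deques[sender].append(message)  # dict carrying degree/iteration

    # _averaging() now also runs _post_step() and advances communication_round.
    sharing._averaging(peer_deques)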