diff --git a/eval/testingKNN.py b/eval/testingKNN.py
new file mode 100644
index 0000000000000000000000000000000000000000..a706afd99eafda36aca051538a15fd6446fe85de
--- /dev/null
+++ b/eval/testingKNN.py
@@ -0,0 +1,79 @@
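+"""
+Launcher for KNN overlay nodes: reads the INI config and the physical graph,
+then spawns one KNN process per local rank via torch.multiprocessing.
+
+A minimal sketch of an invocation (the exact flag names are defined by
+decentralizepy.utils.get_args and are assumed here):
+
+    python eval/testingKNN.py --config_file config.ini --graph_file graph.txt \
+        --machines 1 --procs_per_machine 16 --machine_id 0 --log_dir ./logs
+"""
+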
+import logging
+from pathlib import Path
+from shutil import copy
+
+from localconfig import LocalConfig
+from torch import multiprocessing as mp
+
+from decentralizepy import utils
+from decentralizepy.graphs.Graph import Graph
+from decentralizepy.mappings.Linear import Linear
+from decentralizepy.node.KNN import KNN
+
+
+def read_ini(file_path):
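+    """Read an INI file with LocalConfig, echoing every section for debugging."""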
+    config = LocalConfig(file_path)
+    for section in config:
+        print("Section: ", section)
+        for key, value in config.items(section):
+            print((key, value))
+    print(dict(config.items("DATASET")))
+    return config
+
+
+if __name__ == "__main__":
+    args = utils.get_args()
+
+    Path(args.log_dir).mkdir(parents=True, exist_ok=True)
+
+    log_level = {
+        "INFO": logging.INFO,
+        "DEBUG": logging.DEBUG,
+        "WARNING": logging.WARNING,
+        "ERROR": logging.ERROR,
+        "CRITICAL": logging.CRITICAL,
+    }
+
+    config = read_ini(args.config_file)
+    my_config = dict()
+    for section in config:
+        my_config[section] = dict(config.items(section))
+
+    copy(args.config_file, args.log_dir)
+    copy(args.graph_file, args.log_dir)
+    utils.write_args(args, args.log_dir)
+
+    g = Graph()
+    g.read_graph_from_file(args.graph_file, args.graph_type)
+    n_machines = args.machines
+    procs_per_machine = args.procs_per_machine
+    l = Linear(n_machines, procs_per_machine)
+    m_id = args.machine_id
+
+    processes = []
+    for r in range(procs_per_machine):
+        processes.append(
+            mp.Process(
+                target=KNN,
+                args=[
+                    r,
+                    m_id,
+                    l,
+                    g,
+                    my_config,
+                    args.iterations,
+                    args.log_dir,
+                    args.weights_store_dir,
+                    log_level[args.log_level],
+                    args.test_after,
+                    args.train_evaluate_after,
+                    args.reset_optimizer,
+                ],
+            )
+        )
+
+    for p in processes:
+        p.start()
+
+    for p in processes:
+        p.join()
diff --git a/src/decentralizepy/node/KNN.py b/src/decentralizepy/node/KNN.py
new file mode 100644
index 0000000000000000000000000000000000000000..6480aefd20e593de83fa7dac7f0f1972a80115ad
--- /dev/null
+++ b/src/decentralizepy/node/KNN.py
@@ -0,0 +1,407 @@
+import logging
+import math
+import os
+import queue
+from random import Random
+from threading import Lock, Thread
+
+import numpy as np
+import torch
+from numpy.linalg import norm
+
+from decentralizepy import utils
+from decentralizepy.graphs.Graph import Graph
+from decentralizepy.mappings.Mapping import Mapping
+from decentralizepy.node.OverlayNode import OverlayNode
+
+
+class KNN(OverlayNode):
+    """
+    This class defines the node for KNN Learning Node
+
+    """
+
+    def similarityMetric(self, candidate):
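+        """
+        Cosine similarity between this node's label distribution and the
+        candidate's, e.g. A = [1, 1, 0] and B = [1, 0, 1] give
+        (1 + 0 + 0) / (sqrt(2) * sqrt(2)) = 0.5. Assumes neither vector is
+        all zeros.
+        """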
+        logging.debug("A: {}".format(self.othersInfo[self.uid]))
+        logging.debug("B: {}".format(self.othersInfo[candidate]))
+        A = np.array(self.othersInfo[self.uid])
+        B = np.array(self.othersInfo[candidate])
+        return np.dot(A, B) / (norm(A) * norm(B))
+
+    def get_most_similar(self, candidates, to_keep=4):
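+        """
+        Return up to `to_keep` candidates with the highest similarity.
+        Candidates are bucketed by similarity rounded to 3 decimals; the
+        partially used bucket is broken up by random sampling.
+        """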
+        if len(candidates) <= to_keep:
+            return candidates
+
+        cur_candidates = dict()
+        for i in candidates:
+            simil = round(self.similarityMetric(i), 3)
+            if simil not in cur_candidates:
+                cur_candidates[simil] = []
+            cur_candidates[simil].append(i)
+
+        similarity_scores = list(cur_candidates.keys())
+        similarity_scores.sort(reverse=True)  # most similar first
+
+        left_to_keep = to_keep
+        return_result = set()
+        for i in similarity_scores:
+            if left_to_keep >= len(cur_candidates[i]):
+                return_result.update(cur_candidates[i])
+                left_to_keep -= len(cur_candidates[i])
+            elif left_to_keep > 0:
+                return_result.update(
+                    list(self.rng.sample(cur_candidates[i], left_to_keep))
+                )
+                left_to_keep = 0
+                break
+            else:
+                break
+
+        return return_result
+
+    def create_message_to_send(
+        self,
+        channel="KNNConstr",
+        boolean_flags=[],
+        add_my_info=False,
+        add_neighbor_info=False,
+    ):
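+        """
+        Build a KNNConstr protocol message, e.g.
+        {"CHANNEL": "KNNConstr", "KNNRound": 3, "INIT": True,
+         5: <label distribution of node 5>, ...},
+        where integer keys carry the label distributions of known nodes.
+        """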
+        message = {"CHANNEL": channel, "KNNRound": self.knn_round}
+        for x in boolean_flags:
+            message[x] = True
+        if add_my_info:
+            message[self.uid] = self.othersInfo[self.uid]
+        if add_neighbor_info:
+            for neighbors in self.out_edges:
+                if neighbors in self.othersInfo:
+                    message[neighbors] = self.othersInfo[neighbors]
+        return message
+
+    def receive_KNN_message(self):
+        return self.receive_channel("KNNConstr", block=False)
+
+    def process_init_receive(self, message):
+        self.mutex.acquire()
+        if "RESPONSE" in message[1]:
+            self.num_initializations += 1
+        else:
+            self.communication.send(
+                message[0],
+                self.create_message_to_send(
+                    boolean_flags=["INIT", "RESPONSE"], add_my_info=True
+                ),
+            )
+        x = self.remove_meta_from_message(message)
+        self.othersInfo.update(x[1])
+        self.mutex.release()
+
+    def remove_meta_from_message(self, message):
+        return (
+            message[0],
+            utils.remove_keys(message[1], ["RESPONSE", "INIT", "CHANNEL", "KNNRound"]),
+        )
+
+    def process_candidates_without_lock(self, current_candidates, message):
+        if not self.exit_receiver:
+            message = self.remove_meta_from_message(message)
+            self.othersInfo.update(message[1])
+            new_candidates = set(message[1].keys())
+            current_candidates = current_candidates.union(new_candidates)
+            current_candidates.discard(self.uid)
+            self.out_edges = self.get_most_similar(current_candidates)
+
+    def send_response(self, message, add_neighbor_info=False, process_candidates=False):
+        self.mutex.acquire()
+        logging.debug("Responding to {}".format(message[0]))
+        self.communication.send(
+            message[0],
+            self.create_message_to_send(
+                boolean_flags=["RESPONSE"],
+                add_my_info=True,
+                add_neighbor_info=add_neighbor_info,
+            ),
+        )
+        if process_candidates:
+            self.process_candidates_without_lock(set(self.out_edges), message)
+        self.mutex.release()
+
+    def receiver_thread(self):
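+        """
+        Serve the KNN construction protocol in the background: INIT messages
+        carry a peer's label distribution, RESPONSEs are handed to the main
+        thread via responseQueue, RANDOM_DISCOVERY and plain sharing messages
+        are answered with our own info, and KNNBYEs signal termination.
+        Non-INIT messages are parked in a local queue until enough INITs
+        have arrived.
+        """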
+        knnBYEs = set()
+        self.num_initializations = 0
+        waiting_queue = queue.Queue()
+        while True:
+            if len(knnBYEs) == self.mapping.get_n_procs() - 1:
+                self.mutex.acquire()
+                if self.exit_receiver:
+                    self.mutex.release()
+                    logging.debug("Exiting thread")
+                    return
+                self.mutex.release()
+
+            if self.num_initializations < self.initial_neighbors:
+                x = self.receive_KNN_message()
+                if x is None:
+                    continue
+                elif "INIT" in x[1]:
+                    self.process_init_receive(x)
+                else:
+                    waiting_queue.put(x)
+            else:
+                logging.debug("Waiting for messages")
+                if waiting_queue.empty():
+                    x = self.receive_KNN_message()
+                    if x is None:
+                        continue
+                else:
+                    x = waiting_queue.get()
+
+                if "INIT" in x[1]:
+                    logging.debug("A past INIT Message received from {}".format(x[0]))
+                    self.process_init_receive(x)
+                elif "RESPONSE" in x[1]:
+                    logging.debug(
+                        "A response message received from {} from KNNRound {}".format(
+                            x[0], x[1]["KNNRound"]
+                        )
+                    )
+                    x = self.remove_meta_from_message(x)
+                    self.responseQueue.put(x)
+                elif "RANDOM_DISCOVERY" in x[1]:
+                    logging.debug(
+                        "A Random Discovery message received from {} from KNNRound {}".format(
+                            x[0], x[1]["KNNRound"]
+                        )
+                    )
+                    self.send_response(
+                        x, add_neighbor_info=False, process_candidates=False
+                    )
+                elif "KNNBYE" in x[1]:
+                    self.mutex.acquire()
+                    knnBYEs.add(x[0])
+                    logging.debug("{} KNN Byes received".format(knnBYEs))
+                    if self.uid in x[1]["CLOSE"]:
+                        self.in_edges.add(x[0])
+                    self.mutex.release()
+                else:
+                    logging.debug(
+                        "A KNN sharing message received from {} from KNNRound {}".format(
+                            x[0], x[1]["KNNRound"]
+                        )
+                    )
+                    self.send_response(
+                        x, add_neighbor_info=True, process_candidates=True
+                    )
+
+    def build_topology(self, rounds=30, random_nodes=4):
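+        """
+        Build the overlay in `rounds` KNN rounds: push our info to the
+        initial neighbors, then each round gossip with one random out-edge,
+        probe `random_nodes` random nodes for discovery, and keep the most
+        similar peers as out_edges. Ends by broadcasting KNNBYE together
+        with our final out_edges.
+        """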
+        self.knn_round = 0
+        self.exit_receiver = False
+        t = Thread(target=self.receiver_thread)
+
+        t.start()
+
+        # Initializations : Send my dataset info to others
+
+        self.mutex.acquire()
+        initial_KNN_message = self.create_message_to_send(
+            boolean_flags=["INIT"], add_my_info=True
+        )
+        for x in self.out_edges:
+            self.communication.send(x, initial_KNN_message)
+        self.mutex.release()
+
+        for knn_round in range(rounds):
+            self.knn_round = knn_round
+            logging.info("Starting KNN Round {}".format(knn_round))
+            self.mutex.acquire()
+            rand_neighbor = self.rng.choice(list(self.out_edges))
+            logging.debug("Random neighbor: {}".format(rand_neighbor))
+            self.communication.send(
+                rand_neighbor,
+                self.create_message_to_send(add_my_info=True, add_neighbor_info=True),
+            )
+            self.mutex.release()
+
+            logging.debug("Waiting for knn response from {}".format(rand_neighbor))
+
+            response = self.responseQueue.get(block=True)
+
+            logging.debug("Got response from random neighbor")
+
+            self.mutex.acquire()
+            random_candidates = set(
+                self.rng.sample(list(range(self.mapping.get_n_procs())), random_nodes)
+            )
+
+            req_responses = 0
+            for rc in random_candidates:
+                logging.debug("Current random discovery: {}".format(rc))
+                if rc not in self.othersInfo and rc != self.uid:
+                    logging.debug("Sending discovery request to {}".format(rc))
+                    self.communication.send(
+                        rc,
+                        self.create_message_to_send(boolean_flags=["RANDOM_DISCOVERY"]),
+                    )
+                    req_responses += 1
+            self.mutex.release()
+
+            while req_responses > 0:
+                logging.debug(
+                    "Waiting for {} random discovery responses.".format(req_responses)
+                )
+                req_responses -= 1
+                random_discovery_response = self.responseQueue.get(block=True)
+                logging.debug(
+                    "Received discovery response from {}".format(
+                        random_discovery_response[0]
+                    )
+                )
+                self.mutex.acquire()
+                self.othersInfo.update(random_discovery_response[1])
+                self.mutex.release()
+
+            self.mutex.acquire()
+            self.process_candidates_without_lock(
+                random_candidates.union(self.out_edges), response
+            )
+            self.mutex.release()
+
+            logging.info("Completed KNN Round {}".format(round))
+
+            logging.debug("OutNodes: {}".format(self.out_edges))
+
+        # Send out_edges and BYE to all
+
+        to_send = self.create_message_to_send(boolean_flags=["KNNBYE"])
+        logging.info("Sending KNNByes")
+        self.mutex.acquire()
+        self.exit_receiver = True
+        to_send["CLOSE"] = list(self.out_edges)  # Optimize to only send Yes/No
+        for receiver in range(self.mapping.get_n_procs()):
+            if receiver != self.uid:
+                self.communication.send(receiver, to_send)
+        self.mutex.release()
+        logging.info("KNNByes Sent")
+        t.join()
+        logging.info("Receiver Thread Returned")
+
+    def __init__(
+        self,
+        rank: int,
+        machine_id: int,
+        mapping: Mapping,
+        graph: Graph,
+        config,
+        iterations=1,
+        log_dir=".",
+        weights_store_dir=".",
+        log_level=logging.INFO,
+        test_after=5,
+        train_evaluate_after=1,
+        reset_optimizer=1,
+        initial_neighbors=4,
+        *args
+    ):
+        """
+        Constructor
+
+        Parameters
+        ----------
+        rank : int
+            Rank of process local to the machine
+        machine_id : int
+            Machine ID on which the process is running
+        mapping : decentralizepy.mappings
+            The object containing the mapping rank <--> uid
+        graph : decentralizepy.graphs
+            The object containing the global graph
+        config : dict
+            A dictionary of configurations. Must contain the following:
+            [DATASET]
+                dataset_package
+                dataset_class
+                model_class
+            [OPTIMIZER_PARAMS]
+                optimizer_package
+                optimizer_class
+            [TRAIN_PARAMS]
+                training_package = decentralizepy.training.Training
+                training_class = Training
+                epochs_per_round = 25
+                batch_size = 64
+            [COMMUNICATION]
+                comm_package
+                comm_class
+            [SHARING]
+                sharing_package
+                sharing_class
+        iterations : int
+            Number of iterations (communication steps) for which the model should be trained
+        log_dir : str
+            Logging directory
+        weights_store_dir : str
+            Directory in which to store model weights
+        log_level : logging.Level
+            One of DEBUG, INFO, WARNING, ERROR, CRITICAL
+        test_after : int
+            Number of iterations after which the test loss and accuracy are calculated
+        train_evaluate_after : int
+            Number of iterations after which the train loss is calculated
+        reset_optimizer : int
+            1 if optimizer should be reset every communication round, else 0
+        args : optional
+            Other arguments
+
+        """
+
+        total_threads = os.cpu_count()
+        self.threads_per_proc = max(
+            math.floor(total_threads / mapping.procs_per_machine), 1
+        )
+        torch.set_num_threads(self.threads_per_proc)
+        torch.set_num_interop_threads(1)
+        self.instantiate(
+            rank,
+            machine_id,
+            mapping,
+            graph,
+            config,
+            iterations,
+            log_dir,
+            weights_store_dir,
+            log_level,
+            test_after,
+            train_evaluate_after,
+            reset_optimizer,
+            *args
+        )
+
+        self.rng = Random()
+        self.rng.seed(self.uid + 100)
+
+        self.initial_neighbors = initial_neighbors
+        self.in_edges = set()
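+        # NOTE: rng.sample raises ValueError if the physical graph gives this
+        # node fewer than initial_neighbors neighbors.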
+        self.out_edges = set(
+            self.rng.sample(
+                list(self.graph.neighbors(self.uid)), self.initial_neighbors
+            )
+        )
+        self.responseQueue = queue.Queue()
+        self.mutex = Lock()
+        self.othersInfo = {self.uid: list(self.dataset.get_label_distribution())}
+        # ld = self.dataset.get_label_distribution()
+        # ld_keys = sorted(list(ld.keys()))
+        # self.othersInfo = {self.uid: []}
+        # for key in range(max(ld_keys) + 1):
+        #     if key in ld:
+        #         self.othersInfo[self.uid].append(ld[key])
+        #     else:
+        #         self.othersInfo[self.uid].append(0)
+        logging.info("Label Distributions: {}".format(self.othersInfo))
+
+        logging.info(
+            "Each proc uses %d threads out of %d.", self.threads_per_proc, total_threads
+        )
+        self.run()
diff --git a/src/decentralizepy/node/OverlayNode.py b/src/decentralizepy/node/OverlayNode.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a0768154cd33617d6bc5d462494636e2419a4b0
--- /dev/null
+++ b/src/decentralizepy/node/OverlayNode.py
@@ -0,0 +1,454 @@
+import importlib
+import json
+import logging
+import math
+import os
+from collections import deque
+
+import torch
+from matplotlib import pyplot as plt
+
+from decentralizepy import utils
+from decentralizepy.graphs.Graph import Graph
+from decentralizepy.mappings.Mapping import Mapping
+from decentralizepy.node.Node import Node
+
+
+class OverlayNode(Node):
+    """
+    This class defines the node on overlay graph
+
+    """
+
+    def save_plot(self, l, label, title, xlabel, filename):
+        """
+        Save Matplotlib plot. Clears previous plots.
+
+        Parameters
+        ----------
+        l : dict
+            dict of x -> y. `x` must be castable to int.
+        label : str
+            label of the plot. Used for legend.
+        title : str
+            Header
+        xlabel : str
+            x-axis label
+        filename : str
+            Name of file to save the plot as.
+
+        """
+        plt.clf()
+        x_axis = list(map(int, l.keys()))
+        y_axis = list(l.values())
+        plt.plot(x_axis, y_axis, label=label)
+        plt.legend()
+        plt.xlabel(xlabel)
+        plt.title(title)
+        plt.savefig(filename)
+
+    def get_neighbors(self, node=None):
+        return self.my_neighbors
+
+    def receive_DPSGD(self):
+        return self.receive_channel("DPSGD")
+
+    def build_topology(self):
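+        """
+        Hook for subclasses to populate self.in_edges and self.out_edges;
+        run() asserts both are non-empty, so subclasses are expected to
+        override this.
+        """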
+        pass
+
+    def run(self):
+        """
+        Start the decentralized learning
+
+        """
+        self.testset = self.dataset.get_testset()
+        rounds_to_test = self.test_after
+        rounds_to_train_evaluate = self.train_evaluate_after
+        global_epoch = 1
+        change = 1
+
+        self.connect_neighbors()
+        logging.info("Connected to all neighbors")
+
+        self.build_topology()
+
+        logging.info("OutNodes: {}".format(self.out_edges))
+
+        logging.info("InNodes: {}".format(self.in_edges))
+
+        logging.info("Unifying edges")
+
+        # Make the topology undirected: every out-edge becomes an in-edge too.
+        self.out_edges = self.out_edges.union(self.in_edges)
+        self.in_edges = set(self.out_edges)
+        self.my_neighbors = set(self.out_edges)
+
+        logging.info("Total number of neighbors: {}".format(len(self.my_neighbors)))
+
+        for iteration in range(self.iterations):
+            logging.info("Starting training iteration: %d", iteration)
+            rounds_to_train_evaluate -= 1
+            rounds_to_test -= 1
+
+            self.iteration = iteration
+            self.trainer.train(self.dataset)
+
+            to_send = self.sharing.get_data_to_send()
+            to_send["CHANNEL"] = "DPSGD"
+            to_send["degree"] = len(self.in_edges)
+
+            assert len(self.out_edges) != 0
+            assert len(self.in_edges) != 0
+
+            for neighbor in self.out_edges:
+                self.communication.send(neighbor, to_send)
+
+            while not self.received_from_all():
+                sender, data = self.receive_DPSGD()
+                logging.info(
+                    "Received Model from {} of iteration {}".format(
+                        sender, data["iteration"]
+                    )
+                )
+                if sender not in self.peer_deques:
+                    self.peer_deques[sender] = deque()
+                self.peer_deques[sender].append(data)
+
+            averaging_deque = dict()
+            for neighbor in self.in_edges:
+                averaging_deque[neighbor] = self.peer_deques[neighbor]
+
+            self.sharing._averaging(averaging_deque)
+
+            if self.reset_optimizer:
+                self.optimizer = self.optimizer_class(
+                    self.model.parameters(), **self.optimizer_params
+                )  # Reset optimizer state
+                self.trainer.reset_optimizer(self.optimizer)
+
+            if iteration:
+                with open(
+                    os.path.join(self.log_dir, "{}_results.json".format(self.rank)),
+                    "r",
+                ) as inf:
+                    results_dict = json.load(inf)
+            else:
+                results_dict = {
+                    "train_loss": {},
+                    "test_loss": {},
+                    "test_acc": {},
+                    "total_bytes": {},
+                    "total_meta": {},
+                    "total_data_per_n": {},
+                }
+
+            results_dict["total_bytes"][iteration + 1] = self.communication.total_bytes
+
+            if hasattr(self.communication, "total_meta"):
+                results_dict["total_meta"][
+                    iteration + 1
+                ] = self.communication.total_meta
+            if hasattr(self.communication, "total_data"):
+                results_dict["total_data_per_n"][
+                    iteration + 1
+                ] = self.communication.total_data
+
+            if rounds_to_train_evaluate == 0:
+                logging.info("Evaluating on train set.")
+                rounds_to_train_evaluate = self.train_evaluate_after * change
+                loss_after_sharing = self.trainer.eval_loss(self.dataset)
+                results_dict["train_loss"][iteration + 1] = loss_after_sharing
+                self.save_plot(
+                    results_dict["train_loss"],
+                    "train_loss",
+                    "Training Loss",
+                    "Communication Rounds",
+                    os.path.join(self.log_dir, "{}_train_loss.png".format(self.rank)),
+                )
+
+            if self.dataset.__testing__ and rounds_to_test == 0:
+                rounds_to_test = self.test_after * change
+                logging.info("Evaluating on test set.")
+                ta, tl = self.dataset.test(self.model, self.loss)
+                results_dict["test_acc"][iteration + 1] = ta
+                results_dict["test_loss"][iteration + 1] = tl
+
+                if global_epoch == 49:
+                    change *= 2
+
+                global_epoch += change
+
+            with open(
+                os.path.join(self.log_dir, "{}_results.json".format(self.rank)), "w"
+            ) as of:
+                json.dump(results_dict, of)
+        # if self.model.shared_parameters_counter is not None:
+        #     logging.info("Saving the shared parameter counts")
+        #     with open(
+        #         os.path.join(
+        #             self.log_dir, "{}_shared_parameters.json".format(self.rank)
+        #         ),
+        #         "w",
+        #     ) as of:
+        #         json.dump(self.model.shared_parameters_counter.numpy().tolist(), of)
+        self.disconnect_neighbors()
+        # logging.info("Storing final weight")
+        # self.model.dump_weights(self.weights_store_dir, self.uid, iteration)
+        logging.info("All neighbors disconnected. Process complete!")
+
+    def cache_fields(
+        self,
+        rank,
+        machine_id,
+        mapping,
+        graph,
+        iterations,
+        log_dir,
+        weights_store_dir,
+        test_after,
+        train_evaluate_after,
+        reset_optimizer,
+    ):
+        """
+        Instantiate object field with arguments.
+
+        Parameters
+        ----------
+        rank : int
+            Rank of process local to the machine
+        machine_id : int
+            Machine ID on which the process is running
+        mapping : decentralizepy.mappings
+            The object containing the mapping rank <--> uid
+        graph : decentralizepy.graphs
+            The object containing the global graph
+        iterations : int
+            Number of iterations (communication steps) for which the model should be trained
+        log_dir : str
+            Logging directory
+        weights_store_dir : str
+            Directory in which to store model weights
+        test_after : int
+            Number of iterations after which the test loss and accuracy are calculated
+        train_evaluate_after : int
+            Number of iterations after which the train loss is calculated
+        reset_optimizer : int
+            1 if optimizer should be reset every communication round, else 0
+        """
+        self.rank = rank
+        self.machine_id = machine_id
+        self.graph = graph
+        self.mapping = mapping
+        self.uid = self.mapping.get_uid(rank, machine_id)
+        self.log_dir = log_dir
+        self.weights_store_dir = weights_store_dir
+        self.iterations = iterations
+        self.test_after = test_after
+        self.train_evaluate_after = train_evaluate_after
+        self.reset_optimizer = reset_optimizer
+        self.sent_disconnections = False
+
+        logging.info("Rank: %d", self.rank)
+        logging.info("type(graph): %s", str(type(self.rank)))
+        logging.info("type(mapping): %s", str(type(self.mapping)))
+
+    def init_comm(self, comm_configs):
+        """
+        Instantiate communication module from config.
+
+        Parameters
+        ----------
+        comm_configs : dict
+            Python dict containing communication config params
+
+        """
+        comm_module = importlib.import_module(comm_configs["comm_package"])
+        comm_class = getattr(comm_module, comm_configs["comm_class"])
+        comm_params = utils.remove_keys(comm_configs, ["comm_package", "comm_class"])
+        self.addresses_filepath = comm_params.get("addresses_filepath", None)
+        self.communication = comm_class(
+            self.rank, self.machine_id, self.mapping, self.graph.n_procs, **comm_params
+        )
+
+    def instantiate(
+        self,
+        rank: int,
+        machine_id: int,
+        mapping: Mapping,
+        graph: Graph,
+        config,
+        iterations=1,
+        log_dir=".",
+        weights_store_dir=".",
+        log_level=logging.INFO,
+        test_after=5,
+        train_evaluate_after=1,
+        reset_optimizer=1,
+        *args
+    ):
+        """
+        Construct objects.
+
+        Parameters
+        ----------
+        rank : int
+            Rank of process local to the machine
+        machine_id : int
+            Machine ID on which the process is running
+        mapping : decentralizepy.mappings
+            The object containing the mapping rank <--> uid
+        graph : decentralizepy.graphs
+            The object containing the global graph
+        config : dict
+            A dictionary of configurations.
+        iterations : int
+            Number of iterations (communication steps) for which the model should be trained
+        log_dir : str
+            Logging directory
+        weights_store_dir : str
+            Directory in which to store model weights
+        log_level : logging.Level
+            One of DEBUG, INFO, WARNING, ERROR, CRITICAL
+        test_after : int
+            Number of iterations after which the test loss and accuracy are calculated
+        train_evaluate_after : int
+            Number of iterations after which the train loss is calculated
+        reset_optimizer : int
+            1 if optimizer should be reset every communication round, else 0
+        args : optional
+            Other arguments
+
+        """
+        logging.info("Started process.")
+
+        self.init_log(log_dir, rank, log_level)
+
+        self.cache_fields(
+            rank,
+            machine_id,
+            mapping,
+            graph,
+            iterations,
+            log_dir,
+            weights_store_dir,
+            test_after,
+            train_evaluate_after,
+            reset_optimizer,
+        )
+        self.init_dataset_model(config["DATASET"])
+        self.init_optimizer(config["OPTIMIZER_PARAMS"])
+        self.init_trainer(config["TRAIN_PARAMS"])
+        self.init_comm(config["COMMUNICATION"])
+
+        self.message_queue = dict()
+
+        self.barrier = set()
+        self.my_neighbors = self.graph.neighbors(self.uid)
+
+        self.init_sharing(config["SHARING"])
+        self.peer_deques = dict()
+        self.connect_neighbors()
+
+    def received_from_all(self):
+        """
+        Check if all neighbors have sent the current iteration
+
+        Returns
+        -------
+        bool
+            True if required data has been received, False otherwise
+
+        """
+        for k in self.in_edges:
+            if (k not in self.peer_deques) or len(self.peer_deques[k]) == 0:
+                return False
+        return True
+
+    def __init__(
+        self,
+        rank: int,
+        machine_id: int,
+        mapping: Mapping,
+        graph: Graph,
+        config,
+        iterations=1,
+        log_dir=".",
+        weights_store_dir=".",
+        log_level=logging.INFO,
+        test_after=5,
+        train_evaluate_after=1,
+        reset_optimizer=1,
+        *args
+    ):
+        """
+        Constructor
+
+        Parameters
+        ----------
+        rank : int
+            Rank of process local to the machine
+        machine_id : int
+            Machine ID on which the process is running
+        mapping : decentralizepy.mappings
+            The object containing the mapping rank <--> uid
+        graph : decentralizepy.graphs
+            The object containing the global graph
+        config : dict
+            A dictionary of configurations. Must contain the following:
+            [DATASET]
+                dataset_package
+                dataset_class
+                model_class
+            [OPTIMIZER_PARAMS]
+                optimizer_package
+                optimizer_class
+            [TRAIN_PARAMS]
+                training_package = decentralizepy.training.Training
+                training_class = Training
+                epochs_per_round = 25
+                batch_size = 64
+            [COMMUNICATION]
+                comm_package
+                comm_class
+            [SHARING]
+                sharing_package
+                sharing_class
+        iterations : int
+            Number of iterations (communication steps) for which the model should be trained
+        log_dir : str
+            Logging directory
+        weights_store_dir : str
+            Directory in which to store model weights
+        log_level : logging.Level
+            One of DEBUG, INFO, WARNING, ERROR, CRITICAL
+        test_after : int
+            Number of iterations after which the test loss and accuracy are calculated
+        train_evaluate_after : int
+            Number of iterations after which the train loss is calculated
+        reset_optimizer : int
+            1 if optimizer should be reset every communication round, else 0
+        args : optional
+            Other arguments
+
+        """
+
+        total_threads = os.cpu_count()
+        self.threads_per_proc = max(
+            math.floor(total_threads / mapping.procs_per_machine), 1
+        )
+        torch.set_num_threads(self.threads_per_proc)
+        torch.set_num_interop_threads(1)
+        self.instantiate(
+            rank,
+            machine_id,
+            mapping,
+            graph,
+            config,
+            iterations,
+            log_dir,
+            weights_store_dir,
+            log_level,
+            test_after,
+            train_evaluate_after,
+            reset_optimizer,
+            *args
+        )
+
+        self.in_edges = set()
+        self.out_edges = set()
+
+        logging.info(
+            "Each proc uses %d threads out of %d.", self.threads_per_proc, total_threads
+        )
+        self.run()