diff --git a/eval/testingPeerSampler.py b/eval/testingPeerSampler.py
index 1e0b39a838254bd078f0ae709f68b967f97d3caa..d0b7c3b74a149a90177a504efbff86641492ecee 100644
--- a/eval/testingPeerSampler.py
+++ b/eval/testingPeerSampler.py
@@ -10,7 +10,6 @@ from decentralizepy.graphs.Graph import Graph
 from decentralizepy.mappings.Linear import Linear
 from decentralizepy.node.DPSGDWithPeerSampler import DPSGDWithPeerSampler
 from decentralizepy.node.PeerSampler import PeerSampler
-# from decentralizepy.node.PeerSamplerDynamic import PeerSamplerDynamic
 
 
 def read_ini(file_path):
diff --git a/src/decentralizepy/communication/TCP.py b/src/decentralizepy/communication/TCP.py
index 16de517dc2c1c88dc951ed544bece2799fd426a9..c5c7e92f8ccd4fe213b91c9e805f1deef798a7f8 100644
--- a/src/decentralizepy/communication/TCP.py
+++ b/src/decentralizepy/communication/TCP.py
@@ -47,6 +47,7 @@ class TCP(Communication):
         total_procs,
         addresses_filepath,
         offset=9000,
+        recv_timeout=50,
     ):
         """
         Constructor
@@ -79,11 +80,14 @@ class TCP(Communication):
         self.machine_id = machine_id
         self.mapping = mapping
         self.offset = offset
+        self.recv_timeout = recv_timeout
         self.uid = mapping.get_uid(rank, machine_id)
         self.identity = str(self.uid).encode()
         self.context = zmq.Context()
         self.router = self.context.socket(zmq.ROUTER)
         self.router.setsockopt(zmq.IDENTITY, self.identity)
+        self.router.setsockopt(zmq.RCVTIMEO, self.recv_timeout)
+        self.router.setsockopt(zmq.ROUTER_MANDATORY, 1)
         self.router.bind(self.addr(rank, machine_id))
 
         self.total_data = 0
@@ -170,7 +174,7 @@ class TCP(Communication):
         id = str(neighbor).encode()
         return id in self.peer_sockets
 
-    def receive(self):
+    def receive(self, block=True):
         """
         Returns ONE message received.
 
@@ -185,10 +189,19 @@ class TCP(Communication):
             If received HELLO
 
         """
-
-        sender, recv = self.router.recv_multipart()
-        s, r = self.decrypt(sender, recv)
-        return s, r
+        while True:
+            try:
+                sender, recv = self.router.recv_multipart()
+                s, r = self.decrypt(sender, recv)
+                return s, r
+            except zmq.ZMQError as exc:
+                # zmq raises EAGAIN when RCVTIMEO expires with no
+                # message pending.
+                if exc.errno != zmq.EAGAIN:
+                    raise
+                if not block:
+                    return None
+                # Blocking call: loop back and retry the recv.
 
     def send(self, uid, data, encrypt=True):
         """
diff --git a/src/decentralizepy/node/Node.py b/src/decentralizepy/node/Node.py
index 145b36297909314c663cc3402e9cd4c0205c1dc5..ede7c372b9b7fc3f8fc7f0fb8b4f907624d7e392 100644
--- a/src/decentralizepy/node/Node.py
+++ b/src/decentralizepy/node/Node.py
@@ -26,14 +26,19 @@ class Node:
         self.communication.init_connection(neighbor)
         self.communication.send(neighbor, {"HELLO": self.uid, "CHANNEL": "CONNECT"})
 
-    def receive_channel(self, channel):
+    def receive_channel(self, channel, block=True):
         if channel not in self.message_queue:
             self.message_queue[channel] = deque()
 
         if len(self.message_queue[channel]) > 0:
             return self.message_queue[channel].popleft()
         else:
-            sender, recv = self.communication.receive()
+            x = self.communication.receive(block=block)
+            if x is None:
+                assert not block
+                return None
+            sender, recv = x
+
             logging.info(
                 "Received some message from {} with CHANNEL: {}".format(
                     sender, recv["CHANNEL"]
@@ -44,7 +49,11 @@ class Node:
                 if recv["CHANNEL"] not in self.message_queue:
                     self.message_queue[recv["CHANNEL"]] = deque()
                 self.message_queue[recv["CHANNEL"]].append((sender, recv))
-                sender, recv = self.communication.receive()
+                x = self.communication.receive(block=block)
+                if x is None:
+                    assert not block
+                    return None
+                sender, recv = x
                 logging.info(
                     "Received some message from {} with CHANNEL: {}".format(
                         sender, recv["CHANNEL"]
diff --git a/src/decentralizepy/sharing/Choco.py b/src/decentralizepy/sharing/Choco.py
index 05de209d3cd80e1e0e8d8c2d2bed8f60c74b6d5a..ecea36a1a764f77c834bccde5a522bc838dae5ca 100644
--- a/src/decentralizepy/sharing/Choco.py
+++ b/src/decentralizepy/sharing/Choco.py
@@ -1,13 +1,13 @@
 import logging
+from collections import OrderedDict
 
 import torch
 
-from collections import OrderedDict
-
 from decentralizepy.sharing.Sharing import Sharing
 
+
 def zeros_like_state_dict(state_dict):
-    """ 
+    """
     Creates a new state dictionary such that it has the same
     layers (name and size) as the input state dictionary, but all values
     are zero
@@ -22,11 +22,12 @@ def zeros_like_state_dict(state_dict):
         result_dict[tensor_name] = torch.zeros_like(tensor_values)
     return result_dict
 
+
 def get_dict_keys_and_check_matching(dict_1, dict_2):
-    """ 
+    """
     Checks if keys of the two dictionaries match and
     returns them if they do, otherwise raises ValueError
-    
+
     Parameters
     ----------
     dict_1: dict
@@ -40,11 +41,12 @@ def get_dict_keys_and_check_matching(dict_1, dict_2):
     """
     keys = dict_1.keys()
     if set(keys).difference(set(dict_2.keys())):
-        raise ValueError('Dictionaries must have matching keys')
+        raise ValueError("Dictionaries must have matching keys")
     return keys
 
+
 def subtract_state_dicts(_1, _2):
-    """ 
+    """
     Subtracts one state dictionary from another
 
     Parameters
@@ -67,12 +69,13 @@ def subtract_state_dicts(_1, _2):
         result_dict[key] = _1[key] - _2[key]
     return result_dict
 
-def self_add_state_dict(_1, _2, constant=1.):
+
+def self_add_state_dict(_1, _2, constant=1.0):
     """
     Scales one state dictionary by a constant and
     adds it in place to another, minimizing the copies
     created. Equivalent to the operation `_1 += constant * _2`
-    
+
     Parameters
     ----------
     _1: dict[str, torch.Tensor]
@@ -93,11 +96,12 @@ def self_add_state_dict(_1, _2, constant=1.):
         # Size checking is done by torch during the subtraction
         _1[key] += constant * _2[key]
 
+
 def flatten_state_dict(state_dict):
     """
     Transforms state dictionary into a flat tensor
-    by flattening and concatenating tensors of the 
-    state dictionary. 
+    by flattening and concatenating tensors of the
+    state dictionary.
 
     Note: changes made to the result won't affect state dictionary
 
@@ -107,10 +111,8 @@ def flatten_state_dict(state_dict):
         A state dictionary to flatten
 
     """
-    return torch.cat([
-        tensor.flatten()\
-        for tensor in state_dict.values()
-    ], axis=0)
+    return torch.cat([tensor.flatten() for tensor in state_dict.values()], axis=0)
+
 
 def unflatten_state_dict(flat_tensor, reference_state_dict):
     """
@@ -138,11 +140,11 @@ def unflatten_state_dict(flat_tensor, reference_state_dict):
     start_index = 0
     for tensor_name, tensor in reference_state_dict.items():
         end_index = start_index + tensor.numel()
-        result[tensor_name] = flat_tensor[start_index:end_index].reshape(
-            tensor.shape)
+        result[tensor_name] = flat_tensor[start_index:end_index].reshape(tensor.shape)
         start_index = end_index
     return result
 
+
 def serialize_sparse_tensor(tensor):
     """
     Serializes sparse tensor by flattening it and
@@ -158,6 +160,7 @@ def serialize_sparse_tensor(tensor):
     values = flat[indices]
     return values, indices
 
+
 def deserialize_sparse_tensor(values, indices, shape):
     """
     Deserializes tensor from its non-zero values and indices
@@ -171,12 +174,12 @@ def deserialize_sparse_tensor(values, indices, shape):
         Respective indices of non-zero entries of flattened original tensor
     shape: torch.Size or tuple[*int]
         Shape of the original tensor
-        
+
     """
     result = torch.zeros(size=shape)
     if len(indices):
-      flat_result = result.flatten()
-      flat_result[indices] = values
+        flat_result = result.flatten()
+        flat_result[indices] = values
     return result
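# A quick round-trip check (not part of the patch) for the two helpers
# above, on arbitrary test data: serialize_sparse_tensor keeps only the
# non-zero values and their flat indices, and deserialize_sparse_tensor
# scatters them back into a zero tensor of the original shape.
import torch

original = torch.tensor([[0.0, 1.5], [0.0, -2.0]])
values, indices = serialize_sparse_tensor(original)
restored = deserialize_sparse_tensor(values, indices, original.shape)
assert torch.equal(restored, original)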
 
 
@@ -203,6 +206,7 @@ def topk_sparsification_tensor(tensor, alpha):
         tensor[tensor_abs < -cutoff_value] = 0
     return tensor
 
+
 def topk_sparsification(state_dict, alpha):
     """
     Performs topk sparsification of a state_dict
@@ -221,17 +225,18 @@ def topk_sparsification(state_dict, alpha):
     """
     flat_tensor = flatten_state_dict(state_dict)
     return unflatten_state_dict(
-        topk_sparsification_tensor(flat_tensor, alpha), 
-        state_dict)
+        topk_sparsification_tensor(flat_tensor, alpha), state_dict
+    )
+
 
 def serialize_sparse_state_dict(state_dict):
     with torch.no_grad():
-        concatted_tensors = torch.cat([
-            tensor.flatten()\
-            for tensor in state_dict.values()
-        ], axis=0)
+        concatted_tensors = torch.cat(
+            [tensor.flatten() for tensor in state_dict.values()], axis=0
+        )
         return serialize_sparse_tensor(concatted_tensors)
 
+
 def deserialize_sparse_state_dict(values, indices, reference_state_dict):
     with torch.no_grad():
         keys = []
@@ -310,16 +315,20 @@ class Choco(Sharing):
             model,
             dataset,
             log_dir,
-            compress=False,
-            compression_package=None,
-            compression_class=None
+            compress,
+            compression_package,
+            compression_class,
         )
         self.step_size = step_size
         self.alpha = alpha
-        logging.info("type(step_size): %s, value: %s", 
-            str(type(self.step_size)), str(self.step_size))
-        logging.info("type(alpha): %s, value: %s", 
-            str(type(self.alpha)), str(self.alpha))
+        logging.info(
+            "type(step_size): %s, value: %s",
+            str(type(self.step_size)),
+            str(self.step_size),
+        )
+        logging.info(
+            "type(alpha): %s, value: %s", str(type(self.alpha)), str(self.alpha)
+        )
         model_state_dict = model.state_dict()
         self.model_hat = zeros_like_state_dict(model_state_dict)
         self.s = zeros_like_state_dict(model_state_dict)
@@ -351,10 +360,10 @@ class Choco(Sharing):
 
         """
         with torch.no_grad():
-            self.my_q = self._compress(subtract_state_dicts(
-                self.model.state_dict(), self.model_hat
-            ))
-    
+            self.my_q = self._compress(
+                subtract_state_dicts(self.model.state_dict(), self.model_hat)
+            )
+
     def serialized_model(self):
         """
         Convert self q to a dictionary. Here we can choose how much to share
@@ -395,15 +404,16 @@ class Choco(Sharing):
             indices = torch.tensor(m["indices"], dtype=torch.long)
             values = torch.tensor(m["params"])
             return deserialize_sparse_state_dict(
-                values, indices, self.model.state_dict())
-            
+                values, indices, self.model.state_dict()
+            )
+
     def _averaging(self, peer_deques):
         """
         Averages the received model with the local model
 
         """
         with torch.no_grad():
-            self_add_state_dict(self.model_hat, self.my_q) # x_hat = q_self + x_hat
+            self_add_state_dict(self.model_hat, self.my_q)  # x_hat = q_self + x_hat
             weight_total = 0
             for i, n in enumerate(peer_deques):
                 data = peer_deques[n].popleft()
@@ -433,7 +443,8 @@ class Choco(Sharing):
             self_add_state_dict(
                 total,
                 subtract_state_dicts(self.s, self.model_hat),
-                constant=self.step_size) # x = x + gamma * (s - x_hat)
+                constant=self.step_size,
+            )  # x = x + gamma * (s - x_hat)
 
         self.model.load_state_dict(total)
         self._post_step()
@@ -444,5 +455,4 @@ class Choco(Sharing):
         Averages the received models of all working nodes
 
         """
-        raise NotImplementedError()        
-    
\ No newline at end of file
+        raise NotImplementedError()
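# A minimal single-node sketch (not part of the patch) of the CHOCO
# bookkeeping these hunks reformat, on a toy model with illustrative
# alpha and step_size values; neighbor contributions to s are omitted.
# Per round: q = TopK(x - x_hat), x_hat += q, then x += gamma * (s - x_hat).
import torch

model = torch.nn.Linear(4, 2)
x = model.state_dict()             # detached views of the parameters
x_hat = zeros_like_state_dict(x)   # local estimate of the compressed model
s = zeros_like_state_dict(x)       # running neighborhood average

q = topk_sparsification(subtract_state_dicts(x, x_hat), alpha=0.1)
self_add_state_dict(x_hat, q)      # x_hat = x_hat + q
self_add_state_dict(s, q)          # own share; neighbors' q's would add here
self_add_state_dict(
    x, subtract_state_dicts(s, x_hat), constant=0.5
)                                  # x = x + gamma * (s - x_hat)
model.load_state_dict(x)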