From b2098c44875d94023469b439fd830bdfdb8ad6c7 Mon Sep 17 00:00:00 2001
From: Rishi Sharma <rishi.sharma@epfl.ch>
Date: Wed, 11 Jan 2023 14:22:03 +0100
Subject: [PATCH] Reformatting and ZMQ fix for packet drop and timeout option
 for receive

---
 eval/testingPeerSampler.py              |  1 -
 src/decentralizepy/communication/TCP.py | 23 ++++--
 src/decentralizepy/node/Node.py         | 15 +++-
 src/decentralizepy/sharing/Choco.py     | 96 ++++++++++++++-----------
 4 files changed, 83 insertions(+), 52 deletions(-)

diff --git a/eval/testingPeerSampler.py b/eval/testingPeerSampler.py
index 1e0b39a..d0b7c3b 100644
--- a/eval/testingPeerSampler.py
+++ b/eval/testingPeerSampler.py
@@ -10,7 +10,6 @@ from decentralizepy.graphs.Graph import Graph
 from decentralizepy.mappings.Linear import Linear
 from decentralizepy.node.DPSGDWithPeerSampler import DPSGDWithPeerSampler
 from decentralizepy.node.PeerSampler import PeerSampler
-# from decentralizepy.node.PeerSamplerDynamic import PeerSamplerDynamic
 
 
 def read_ini(file_path):
diff --git a/src/decentralizepy/communication/TCP.py b/src/decentralizepy/communication/TCP.py
index 16de517..c5c7e92 100644
--- a/src/decentralizepy/communication/TCP.py
+++ b/src/decentralizepy/communication/TCP.py
@@ -47,6 +47,7 @@ class TCP(Communication):
         total_procs,
         addresses_filepath,
         offset=9000,
+        recv_timeout=50,
     ):
         """
         Constructor
@@ -79,11 +80,14 @@ class TCP(Communication):
         self.machine_id = machine_id
         self.mapping = mapping
         self.offset = offset
+        self.recv_timeout = recv_timeout
         self.uid = mapping.get_uid(rank, machine_id)
         self.identity = str(self.uid).encode()
         self.context = zmq.Context()
         self.router = self.context.socket(zmq.ROUTER)
         self.router.setsockopt(zmq.IDENTITY, self.identity)
+        self.router.setsockopt(zmq.RCVTIMEO, self.recv_timeout)
+        self.router.setsockopt(zmq.ROUTER_MANDATORY, 1)
         self.router.bind(self.addr(rank, machine_id))
 
         self.total_data = 0
@@ -170,7 +174,7 @@ class TCP(Communication):
         id = str(neighbor).encode()
         return id in self.peer_sockets
 
-    def receive(self):
+    def receive(self, block=True):
         """
         Returns ONE message received.
 
@@ -185,10 +189,19 @@ class TCP(Communication):
             If received HELLO
 
         """
-
-        sender, recv = self.router.recv_multipart()
-        s, r = self.decrypt(sender, recv)
-        return s, r
+        while True:
+            try:
+                sender, recv = self.router.recv_multipart()
+                s, r = self.decrypt(sender, recv)
+                return s, r
+            except zmq.ZMQError as exc:
+                if exc.errno == zmq.EAGAIN:
+                    if not block:
+                        return None
+                    else:
+                        continue
+                else:
+                    raise
 
     def send(self, uid, data, encrypt=True):
         """
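Note: the retry loop in receive() relies on two pyzmq behaviors configured in the
constructor: with zmq.RCVTIMEO set, recv_multipart() raises a zmq.ZMQError whose
errno is zmq.EAGAIN once the timeout expires instead of blocking forever, and
zmq.ROUTER_MANDATORY makes a send to an unroutable peer fail loudly rather than
being dropped silently. A minimal standalone sketch of the timeout path (not part
of the patch; the endpoint is arbitrary):

    import zmq

    ctx = zmq.Context()
    sock = ctx.socket(zmq.ROUTER)
    sock.setsockopt(zmq.RCVTIMEO, 50)  # give up on recv after 50 ms
    sock.setsockopt(zmq.ROUTER_MANDATORY, 1)  # raise on unroutable sends
    sock.bind("tcp://127.0.0.1:9000")  # arbitrary local endpoint
    try:
        frames = sock.recv_multipart()  # no peer connected: times out
    except zmq.ZMQError as exc:
        assert exc.errno == zmq.EAGAIN  # the same check receive() performs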
diff --git a/src/decentralizepy/node/Node.py b/src/decentralizepy/node/Node.py
index 145b362..ede7c37 100644
--- a/src/decentralizepy/node/Node.py
+++ b/src/decentralizepy/node/Node.py
@@ -26,14 +26,19 @@ class Node:
             self.communication.init_connection(neighbor)
             self.communication.send(neighbor, {"HELLO": self.uid, "CHANNEL": "CONNECT"})
 
-    def receive_channel(self, channel):
+    def receive_channel(self, channel, block=True):
         if channel not in self.message_queue:
             self.message_queue[channel] = deque()
 
         if len(self.message_queue[channel]) > 0:
             return self.message_queue[channel].popleft()
         else:
-            sender, recv = self.communication.receive()
+            x = self.communication.receive(block=block)
+            if x is None:
+                assert not block
+                return None
+            sender, recv = x
+
             logging.info(
                 "Received some message from {} with CHANNEL: {}".format(
                     sender, recv["CHANNEL"]
@@ -44,7 +49,11 @@ class Node:
                 if recv["CHANNEL"] not in self.message_queue:
                     self.message_queue[recv["CHANNEL"]] = deque()
                 self.message_queue[recv["CHANNEL"]].append((sender, recv))
-                sender, recv = self.communication.receive()
+                x = self.communication.receive(block=block)
+                if x is None:
+                    assert not block
+                    return None
+                sender, recv = x
                 logging.info(
                     "Received some message from {} with CHANNEL: {}".format(
                         sender, recv["CHANNEL"]
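With block=False, receive_channel() can now return None when nothing is queued
for the channel and the underlying recv times out. A hypothetical helper (not in
the patch) showing how a caller might drain a channel without blocking:

    def drain_channel(node, channel):
        # Collect every (sender, data) pair currently available on
        # `channel`; returns as soon as the non-blocking receive yields
        # None instead of waiting for more traffic to arrive.
        messages = []
        while True:
            msg = node.receive_channel(channel, block=False)
            if msg is None:
                return messages
            messages.append(msg)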
diff --git a/src/decentralizepy/sharing/Choco.py b/src/decentralizepy/sharing/Choco.py
index 05de209..ecea36a 100644
--- a/src/decentralizepy/sharing/Choco.py
+++ b/src/decentralizepy/sharing/Choco.py
@@ -1,13 +1,13 @@
 import logging
+from collections import OrderedDict
 
 import torch
-from collections import OrderedDict
-
 from decentralizepy.sharing.Sharing import Sharing
 
+
 def zeros_like_state_dict(state_dict):
-    """ 
+    """
     Creates a new state dictionary such that it has the same layers
     (name and size) as the input state dictionary, but all values
     are zero
 
@@ -22,11 +22,12 @@ def zeros_like_state_dict(state_dict):
         result_dict[tensor_name] = torch.zeros_like(tensor_values)
     return result_dict
 
+
 def get_dict_keys_and_check_matching(dict_1, dict_2):
-    """ 
+    """
     Checks if keys of the two dictionaries match and returns them
     if they do, otherwise raises ValueError
-    
+
     Parameters
     ----------
     dict_1: dict
@@ -40,11 +41,12 @@ def get_dict_keys_and_check_matching(dict_1, dict_2):
     """
     keys = dict_1.keys()
     if set(keys).difference(set(dict_2.keys())):
-        raise ValueError('Dictionaries must have matching keys')
+        raise ValueError("Dictionaries must have matching keys")
     return keys
 
+
 def subtract_state_dicts(_1, _2):
-    """ 
+    """
     Subtracts one state dictionary from another
 
     Parameters
@@ -67,12 +69,13 @@ def subtract_state_dicts(_1, _2):
             result_dict[key] = _1[key] - _2[key]
     return result_dict
 
-def self_add_state_dict(_1, _2, constant=1.):
+
+def self_add_state_dict(_1, _2, constant=1.0):
     """
     Scales one state dictionary by a constant and adds it
     directly to another, minimizing the copies created.
 
     Equivalent to the operation `_1 += constant * _2`
-    
+
     Parameters
     ----------
     _1: dict[str, torch.Tensor]
@@ -93,11 +96,12 @@ def self_add_state_dict(_1, _2, constant=1.):
         # Size checking is done by torch during the subtraction
         _1[key] += constant * _2[key]
 
+
 def flatten_state_dict(state_dict):
     """
     Transforms state dictionary into a flat tensor
-    by flattening and concatenating tensors of the 
-    state dictionary. 
+    by flattening and concatenating tensors of the
+    state dictionary.
 
     Note: changes made to the result won't affect state dictionary
 
@@ -107,10 +111,8 @@ def flatten_state_dict(state_dict):
         A state dictionary to flatten
 
     """
-    return torch.cat([
-        tensor.flatten()\
-        for tensor in state_dict.values()
-    ], axis=0)
+    return torch.cat([tensor.flatten() for tensor in state_dict.values()], axis=0)
+
 
 def unflatten_state_dict(flat_tensor, reference_state_dict):
     """
@@ -138,11 +140,11 @@ def unflatten_state_dict(flat_tensor, reference_state_dict):
     start_index = 0
     for tensor_name, tensor in reference_state_dict.items():
         end_index = start_index + tensor.numel()
-        result[tensor_name] = flat_tensor[start_index:end_index].reshape(
-            tensor.shape)
+        result[tensor_name] = flat_tensor[start_index:end_index].reshape(tensor.shape)
         start_index = end_index
     return result
 
+
 def serialize_sparse_tensor(tensor):
     """
     Serializes sparse tensor by flattening it and
@@ -158,6 +160,7 @@ def serialize_sparse_tensor(tensor):
     values = flat[indices]
     return values, indices
 
+
 def deserialize_sparse_tensor(values, indices, shape):
     """
     Deserializes tensor from its non-zero values and indices
@@ -171,12 +174,12 @@ def deserialize_sparse_tensor(values, indices, shape):
         Respective indices of non-zero entries of flattened original tensor
     shape: torch.Size or tuple[*int]
         Shape of the original tensor
-    
+
     """
     result = torch.zeros(size=shape)
     if len(indices):
-      flat_result = result.flatten()
-      flat_result[indices] = values
+        flat_result = result.flatten()
+        flat_result[indices] = values
     return result
 
 
@@ -203,6 +206,7 @@ def topk_sparsification_tensor(tensor, alpha):
     tensor[tensor_abs < -cutoff_value] = 0
     return tensor
 
+
 def topk_sparsification(state_dict, alpha):
     """
     Performs topk sparsification of a state_dict
@@ -221,17 +225,18 @@ def topk_sparsification(state_dict, alpha):
     """
     flat_tensor = flatten_state_dict(state_dict)
     return unflatten_state_dict(
-        topk_sparsification_tensor(flat_tensor, alpha),
-        state_dict)
+        topk_sparsification_tensor(flat_tensor, alpha), state_dict
+    )
+
 
 def serialize_sparse_state_dict(state_dict):
     with torch.no_grad():
-        concatted_tensors = torch.cat([
-            tensor.flatten()\
-            for tensor in state_dict.values()
-        ], axis=0)
+        concatted_tensors = torch.cat(
+            [tensor.flatten() for tensor in state_dict.values()], axis=0
+        )
         return serialize_sparse_tensor(concatted_tensors)
 
+
 def deserialize_sparse_state_dict(values, indices, reference_state_dict):
     with torch.no_grad():
         keys = []
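Taken together, the helpers above give Choco a top-k compress/serialize/deserialize
round trip. An illustrative sketch with made-up shapes and alpha, assuming the
helpers behave as their docstrings describe:

    import torch

    from decentralizepy.sharing.Choco import (
        deserialize_sparse_state_dict,
        serialize_sparse_state_dict,
        topk_sparsification,
    )

    state = {"weight": torch.randn(3, 4), "bias": torch.randn(4)}
    sparse = topk_sparsification(state, alpha=0.25)  # keep ~25% of entries
    values, indices = serialize_sparse_state_dict(sparse)  # non-zeros + positions
    restored = deserialize_sparse_state_dict(values, indices, state)
    assert all(torch.equal(sparse[k], restored[k]) for k in state)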
@@ -310,16 +315,20 @@ class Choco(Sharing):
             model,
             dataset,
             log_dir,
-            compress=False,
-            compression_package=None,
-            compression_class=None
+            compress,
+            compression_package,
+            compression_class,
         )
         self.step_size = step_size
         self.alpha = alpha
-        logging.info("type(step_size): %s, value: %s",
-                     str(type(self.step_size)), str(self.step_size))
-        logging.info("type(alpha): %s, value: %s",
-                     str(type(self.alpha)), str(self.alpha))
+        logging.info(
+            "type(step_size): %s, value: %s",
+            str(type(self.step_size)),
+            str(self.step_size),
+        )
+        logging.info(
+            "type(alpha): %s, value: %s", str(type(self.alpha)), str(self.alpha)
+        )
         model_state_dict = model.state_dict()
         self.model_hat = zeros_like_state_dict(model_state_dict)
         self.s = zeros_like_state_dict(model_state_dict)
@@ -351,10 +360,10 @@ class Choco(Sharing):
 
         """
         with torch.no_grad():
-            self.my_q = self._compress(subtract_state_dicts(
-                self.model.state_dict(), self.model_hat
-            ))
-
+            self.my_q = self._compress(
+                subtract_state_dicts(self.model.state_dict(), self.model_hat)
+            )
+
     def serialized_model(self):
         """
         Convert self q to a dictionary. Here we can choose how much to share
@@ -395,15 +404,16 @@ class Choco(Sharing):
         indices = torch.tensor(m["indices"], dtype=torch.long)
         values = torch.tensor(m["params"])
         return deserialize_sparse_state_dict(
-            values, indices, self.model.state_dict())
-
+            values, indices, self.model.state_dict()
+        )
+
     def _averaging(self, peer_deques):
         """
         Averages the received model with the local model
 
         """
         with torch.no_grad():
-            self_add_state_dict(self.model_hat, self.my_q) # x_hat = q_self + x_hat
+            self_add_state_dict(self.model_hat, self.my_q)  # x_hat = q_self + x_hat
             weight_total = 0
             for i, n in enumerate(peer_deques):
                 data = peer_deques[n].popleft()
@@ -433,7 +443,8 @@ class Choco(Sharing):
             self_add_state_dict(
                 total,
                 subtract_state_dicts(self.s, self.model_hat),
-                constant=self.step_size) # x = x + gamma * (s - x_hat)
+                constant=self.step_size,
+            )  # x = x + gamma * (s - x_hat)
 
         self.model.load_state_dict(total)
         self._post_step()
@@ -444,5 +455,4 @@ class Choco(Sharing):
         Averages the received models of all working nodes
 
         """
-        raise NotImplementedError()
-        
\ No newline at end of file
+        raise NotImplementedError()
-- 
GitLab
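For reference, the reformatted _averaging() leaves the CHOCO update rule
x = x + gamma * (s - x_hat) untouched; only the call layout changed. An
illustrative check of that final step with made-up numbers, using the module's
own helpers:

    import torch

    from decentralizepy.sharing.Choco import (
        self_add_state_dict,
        subtract_state_dicts,
    )

    total = {"w": torch.tensor([1.0, 2.0])}  # x, the gossip average (made up)
    s = {"w": torch.tensor([1.5, 2.5])}  # s, the accumulator (made up)
    x_hat = {"w": torch.tensor([1.0, 2.0])}  # compressed estimate (made up)
    self_add_state_dict(total, subtract_state_dicts(s, x_hat), constant=0.5)
    assert torch.equal(total["w"], torch.tensor([1.25, 2.25]))  # x + 0.5 * (s - x_hat)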