diff --git a/src/decentralizepy/node/DPSGDNodeFederated.py b/src/decentralizepy/node/DPSGDNodeFederated.py index d77e35d409d86fa66ffc210e6e0ff0bd62f9f865..7d8f5b1efdd4977413610cf114f27fa06cd83bab 100644 --- a/src/decentralizepy/node/DPSGDNodeFederated.py +++ b/src/decentralizepy/node/DPSGDNodeFederated.py @@ -35,10 +35,6 @@ class DPSGDNodeFederated(Node): del data["iteration"] del data["CHANNEL"] - if iteration == 0: - del data["degree"] - data = self.sharing.deserialized_model(data) - self.model.load_state_dict(data) self.sharing._post_step() self.sharing.communication_round += 1 diff --git a/src/decentralizepy/node/FederatedParameterServer.py b/src/decentralizepy/node/FederatedParameterServer.py index 312fa6d79187e7cec8b1433d46fb84fc7d296554..ee919b73abf34934029b9cfdac874ee183def6ab 100644 --- a/src/decentralizepy/node/FederatedParameterServer.py +++ b/src/decentralizepy/node/FederatedParameterServer.py @@ -285,9 +285,7 @@ class FederatedParameterServer(Node): self.current_workers = self.get_working_nodes() # Params to send to workers - # if this is the first iteration, use the init parameters, else use averaged params from last iteration - if iteration == 0: - to_send = self.sharing.get_data_to_send() + to_send = self.model.state_dict() to_send["CHANNEL"] = "WORKER_REQUEST" to_send["iteration"] = iteration @@ -309,28 +307,11 @@ class FederatedParameterServer(Node): # Average received updates averaging_deque = dict() - total = dict() for worker in self.current_workers: averaging_deque[worker] = self.peer_deques[worker] - for i, n in enumerate(averaging_deque): - data = averaging_deque[n].popleft() - del data["degree"] - del data["iteration"] - del data["CHANNEL"] - data = self.sharing.deserialized_model(data) - for key, value in data.items(): - if key in total: - total[key] += value - else: - total[key] = value - - for key, value in total.items(): - total[key] = total[key] / len(averaging_deque) - - self.model.load_state_dict(total) - - to_send = total + self.sharing._pre_step() + self.sharing._averaging_server(averaging_deque) if iteration: with open( diff --git a/src/decentralizepy/sharing/Sharing.py b/src/decentralizepy/sharing/Sharing.py index a41de33203d72cafd4aecca10a8b14b347625d72..49419214a36260574d8c617d187273dc0501fd25 100644 --- a/src/decentralizepy/sharing/Sharing.py +++ b/src/decentralizepy/sharing/Sharing.py @@ -80,13 +80,15 @@ class Sharing: result = dict(data) if self.compress: if "params" in result: - result["params"] = self.compressor.compress_float(result["params"]) + result["params"] = self.compressor.compress_float( + result["params"]) return result def decompress_data(self, data): if self.compress: if "params" in data: - data["params"] = self.compressor.decompress_float(data["params"]) + data["params"] = self.compressor.decompress_float( + data["params"]) return data def serialized_model(self): @@ -171,7 +173,8 @@ class Sharing: ) ) data = self.deserialized_model(data) - weight = 1 / (max(len(peer_deques), degree) + 1) # Metro-Hastings + # Metro-Hastings + weight = 1 / (max(len(peer_deques), degree) + 1) weight_total += weight for key, value in data.items(): if key in total: @@ -194,3 +197,34 @@ class Sharing: data["degree"] = len(all_neighbors) data["iteration"] = self.communication_round return data + + def _averaging_server(self, peer_deques): + """ + Averages the received models of all working nodes + + """ + with torch.no_grad(): + total = dict() + for i, n in enumerate(peer_deques): + data = peer_deques[n].popleft() + degree, iteration = data["degree"], data["iteration"] + del data["degree"] + del data["iteration"] + del data["CHANNEL"] + logging.debug( + "Averaging model from neighbor {} of iteration {}".format( + n, iteration + ) + ) + data = self.deserialized_model(data) + weight = 1 / len(peer_deques) + for key, value in data.items(): + if key in total: + total[key] += weight * value + else: + total[key] = weight * value + + self.model.load_state_dict(total) + self._post_step() + self.communication_round += 1 + return total diff --git a/src/decentralizepy/sharing/Wavelet.py b/src/decentralizepy/sharing/Wavelet.py index 24db77afea84b7a8424303f2239121423f5f054c..b0a4796fa909cd7cfc3d333bfb4bab45236851a9 100644 --- a/src/decentralizepy/sharing/Wavelet.py +++ b/src/decentralizepy/sharing/Wavelet.py @@ -134,7 +134,8 @@ class Wavelet(PartialModel): self.change_based_selection = change_based_selection # Do a dummy transform to get the shape and coefficents slices - coeff = pywt.wavedec(self.init_model.numpy(), self.wavelet, level=self.level) + coeff = pywt.wavedec(self.init_model.numpy(), + self.wavelet, level=self.level) data, coeff_slices = pywt.coeffs_to_array(coeff) self.wt_shape = data.shape self.coeff_slices = coeff_slices @@ -203,14 +204,16 @@ class Wavelet(PartialModel): shapes[k] = list(v.shape) shared_params["shapes"] = shapes - shared_params[self.communication_round] = indices.tolist() # is slow + # is slow + shared_params[self.communication_round] = indices.tolist() shared_params["alpha"] = self.alpha with open( os.path.join( self.folder_path, - "{}_shared_params.json".format(self.communication_round + 1), + "{}_shared_params.json".format( + self.communication_round + 1), ), "w", ) as of: @@ -296,7 +299,8 @@ class Wavelet(PartialModel): else: topkwf = params.reshape(self.wt_shape) - weight = 1 / (max(len(peer_deques), degree) + 1) # Metro-Hastings + # Metro-Hastings + weight = 1 / (max(len(peer_deques), degree) + 1) weight_total += weight if total is None: total = weight * topkwf @@ -325,3 +329,59 @@ class Wavelet(PartialModel): self.model.load_state_dict(std_dict) self._post_step() self.communication_round += 1 + + def _averaging_server(self, peer_deques): + """ + Averages the received models of all working nodes + + """ + with torch.no_grad(): + total = None + wt_params = self.pre_share_model_transformed + for i, n in enumerate(peer_deques): + data = peer_deques[n].popleft() + degree, iteration = data["degree"], data["iteration"] + del data["degree"] + del data["iteration"] + del data["CHANNEL"] + logging.debug( + "Averaging model from neighbor {} of iteration {}".format( + n, iteration + ) + ) + data = self.deserialized_model(data) + params = data["params"] + if "indices" in data: + indices = data["indices"] + # use local data to complement + topkwf = wt_params.clone().detach() + topkwf[indices] = params + topkwf = topkwf.reshape(self.wt_shape) + else: + topkwf = params.reshape(self.wt_shape) + + weight = 1 / len(peer_deques) + if total is None: + total = weight * topkwf + else: + total += weight * topkwf + + avg_wf_params = pywt.array_to_coeffs( + total.numpy(), self.coeff_slices, output_format="wavedec" + ) + reverse_total = torch.from_numpy( + pywt.waverec(avg_wf_params, wavelet=self.wavelet) + ) + + start_index = 0 + std_dict = {} + for i, key in enumerate(self.model.state_dict()): + end_index = start_index + self.lens[i] + std_dict[key] = reverse_total[start_index:end_index].reshape( + self.shapes[i] + ) + start_index = end_index + + self.model.load_state_dict(std_dict) + self._post_step() + self.communication_round += 1