From 786dcb98dfa7ad5d080561c4c4fab77447f4bd4f Mon Sep 17 00:00:00 2001
From: Jeffrey Wigger <jeffrey.wigger@epfl.ch>
Date: Mon, 9 May 2022 16:17:50 +0200
Subject: [PATCH] removing total_data and total_meta from sharing classes

Fixing subsampling with compression
---
 src/decentralizepy/communication/TCP.py         | 2 ++
 src/decentralizepy/sharing/FFT.py               | 6 ------
 src/decentralizepy/sharing/PartialModel.py      | 1 -
 src/decentralizepy/sharing/Sharing.py           | 8 ++++----
 src/decentralizepy/sharing/SharingCentrality.py | 2 --
 src/decentralizepy/sharing/SubSampling.py       | 7 -------
 src/decentralizepy/sharing/Synchronous.py       | 3 ---
 src/decentralizepy/sharing/TopKParams.py        | 5 -----
 src/decentralizepy/sharing/Wavelet.py           | 1 -
 9 files changed, 6 insertions(+), 29 deletions(-)

diff --git a/src/decentralizepy/communication/TCP.py b/src/decentralizepy/communication/TCP.py
index c609699..58cc55d 100644
--- a/src/decentralizepy/communication/TCP.py
+++ b/src/decentralizepy/communication/TCP.py
@@ -134,6 +134,8 @@ class TCP(Communication):
                 meta_len = len(
                     pickle.dumps(data["indices"])
                 )  # ONLY necessary for the statistics
+            else:
+                meta_len = 0
             if "params" in data:
                 data["params"] = self.compressor.compress_float(data["params"])
             output = pickle.dumps(data)
diff --git a/src/decentralizepy/sharing/FFT.py b/src/decentralizepy/sharing/FFT.py
index ba7b841..17650c1 100644
--- a/src/decentralizepy/sharing/FFT.py
+++ b/src/decentralizepy/sharing/FFT.py
@@ -159,7 +159,6 @@ class FFT(PartialModel):
         if self.alpha >= self.metadata_cap:  # Share fully
             data = self.pre_share_model_transformed
             m["params"] = data.numpy()
-            self.total_data += len(self.communication.encrypt(m["params"]))
             if self.model.accumulated_changes is not None:
                 self.model.accumulated_changes = torch.zeros_like(
                     self.model.accumulated_changes
@@ -200,11 +199,6 @@ class FFT(PartialModel):
         m["indices"] = indices.numpy().astype(np.int32)
         m["send_partial"] = True
 
-        self.total_data += len(self.communication.encrypt(m["params"]))
-        self.total_meta += len(self.communication.encrypt(m["indices"])) + len(
-            self.communication.encrypt(m["alpha"])
-        )
-
         return m
 
     def deserialized_model(self, m):
diff --git a/src/decentralizepy/sharing/PartialModel.py b/src/decentralizepy/sharing/PartialModel.py
index 1f1feca..3111e82 100644
--- a/src/decentralizepy/sharing/PartialModel.py
+++ b/src/decentralizepy/sharing/PartialModel.py
@@ -82,7 +82,6 @@ class PartialModel(Sharing):
         self.dict_ordered = dict_ordered
         self.save_shared = save_shared
         self.metadata_cap = metadata_cap
-        self.total_meta = 0
         self.accumulation = accumulation
         self.save_accumulated = conditional_value(save_accumulated, "", False)
         self.change_transformer = change_transformer
diff --git a/src/decentralizepy/sharing/Sharing.py b/src/decentralizepy/sharing/Sharing.py
index 22b4de9..7dc8852 100644
--- a/src/decentralizepy/sharing/Sharing.py
+++ b/src/decentralizepy/sharing/Sharing.py
@@ -46,7 +46,6 @@ class Sharing:
         self.dataset = dataset
         self.communication_round = 0
         self.log_dir = log_dir
-        self.total_data = 0
 
         self.peer_deques = dict()
         self.my_neighbors = self.graph.neighbors(self.uid)
@@ -99,8 +98,9 @@ class Sharing:
         m = dict()
         for key, val in self.model.state_dict().items():
             m[key] = val.numpy()
-            self.total_data += len(self.communication.encrypt(m[key]))
-        return m
+        data = dict()
+        data["params"] = m
+        return data
 
     def deserialized_model(self, m):
         """
@@ -118,7 +118,7 @@
         """
 
         state_dict = dict()
-        for key, value in m.items():
+        for key, value in m["params"].items():
             state_dict[key] = torch.from_numpy(value)
 
         return state_dict
diff --git a/src/decentralizepy/sharing/SharingCentrality.py b/src/decentralizepy/sharing/SharingCentrality.py
index 580ce2a..f933a0e 100644
--- a/src/decentralizepy/sharing/SharingCentrality.py
+++ b/src/decentralizepy/sharing/SharingCentrality.py
@@ -46,7 +46,6 @@ class Sharing:
         self.dataset = dataset
         self.communication_round = 0
         self.log_dir = log_dir
-        self.total_data = 0
 
         self.peer_deques = dict()
         my_neighbors = self.graph.neighbors(self.uid)
@@ -101,7 +100,6 @@ class Sharing:
         m = dict()
         for key, val in self.model.state_dict().items():
             m[key] = val.numpy()
-            self.total_data += len(self.communication.encrypt(m[key]))
         return m
 
     def deserialized_model(self, m):
diff --git a/src/decentralizepy/sharing/SubSampling.py b/src/decentralizepy/sharing/SubSampling.py
index f8c8f50..b51cb07 100644
--- a/src/decentralizepy/sharing/SubSampling.py
+++ b/src/decentralizepy/sharing/SubSampling.py
@@ -72,7 +72,6 @@ class SubSampling(Sharing):
         self.dict_ordered = dict_ordered
         self.save_shared = save_shared
         self.metadata_cap = metadata_cap
-        self.total_meta = 0
 
         # self.random_seed_generator = torch.Generator()
         # # Will use the random device if supported by CPU, else uses the system time
@@ -216,12 +215,6 @@ class SubSampling(Sharing):
         m["alpha"] = alpha
         m["params"] = subsample.numpy()
 
-        # logging.info("Converted dictionary to json")
-        self.total_data += len(self.communication.encrypt(m["params"]))
-        self.total_meta += len(self.communication.encrypt(m["seed"])) + len(
-            self.communication.encrypt(m["alpha"])
-        )
-
         return m
 
     def deserialized_model(self, m):
diff --git a/src/decentralizepy/sharing/Synchronous.py b/src/decentralizepy/sharing/Synchronous.py
index 29d7f62..2c2d5e7 100644
--- a/src/decentralizepy/sharing/Synchronous.py
+++ b/src/decentralizepy/sharing/Synchronous.py
@@ -46,7 +46,6 @@ class Synchronous:
         self.dataset = dataset
         self.communication_round = 0
         self.log_dir = log_dir
-        self.total_data = 0
 
         self.peer_deques = dict()
         self.my_neighbors = self.graph.neighbors(self.uid)
@@ -104,7 +103,6 @@ class Synchronous:
         m = dict()
         for key, val in self.model.state_dict().items():
            m[key] = val - self.init_model[key]  # this is -lr*gradient
-        self.total_data += len(self.communication.encrypt(m))
         return m
 
     def serialized_model(self):
@@ -120,7 +118,6 @@ class Synchronous:
         m = dict()
         for key, val in self.model.state_dict().items():
             m[key] = val.clone().detach()
-        self.total_data += len(self.communication.encrypt(m))
         return m
 
     def deserialized_model(self, m):
diff --git a/src/decentralizepy/sharing/TopKParams.py b/src/decentralizepy/sharing/TopKParams.py
index 02531f1..f188179 100644
--- a/src/decentralizepy/sharing/TopKParams.py
+++ b/src/decentralizepy/sharing/TopKParams.py
@@ -68,7 +68,6 @@ class TopKParams(Sharing):
         self.dict_ordered = dict_ordered
         self.save_shared = save_shared
         self.metadata_cap = metadata_cap
-        self.total_meta = 0
 
         if self.save_shared:
             # Only save for 2 procs: Save space
@@ -171,10 +170,6 @@ class TopKParams(Sharing):
             # m[key] = json.dumps(m[key])
 
         logging.info("Converted dictionary to json")
-        self.total_data += len(self.communication.encrypt(m["params"]))
-        self.total_meta += len(self.communication.encrypt(m["indices"])) + len(
-            self.communication.encrypt(m["offsets"])
-        )
 
         return m
diff --git a/src/decentralizepy/sharing/Wavelet.py b/src/decentralizepy/sharing/Wavelet.py
index b864f1f..91c97d0 100644
--- a/src/decentralizepy/sharing/Wavelet.py
+++ b/src/decentralizepy/sharing/Wavelet.py
@@ -181,7 +181,6 @@ class Wavelet(PartialModel):
         if self.alpha >= self.metadata_cap:  # Share fully
             data = self.pre_share_model_transformed
             m["params"] = data.numpy()
-            self.total_data += len(self.communication.encrypt(m["params"]))
             if self.model.accumulated_changes is not None:
                 self.model.accumulated_changes = torch.zeros_like(
                     self.model.accumulated_changes
--
GitLab