diff --git a/src/decentralizepy/communication/TCP.py b/src/decentralizepy/communication/TCP.py index c6096994b4fdc937649515b75d1479ef6cc97d89..6d72e1cfc61e6f9888cac3d6cdb52c81c84d24e4 100644 --- a/src/decentralizepy/communication/TCP.py +++ b/src/decentralizepy/communication/TCP.py @@ -131,25 +131,23 @@ class TCP(Communication): if self.compress: if "indices" in data: data["indices"] = self.compressor.compress(data["indices"]) - meta_len = len( - pickle.dumps(data["indices"]) - ) # ONLY necessary for the statistics - if "params" in data: - data["params"] = self.compressor.compress_float(data["params"]) + + assert "params" in data + data["params"] = self.compressor.compress_float(data["params"]) + data_len = len(pickle.dumps(data["params"])) output = pickle.dumps(data) + # the compressed meta data gets only a few bytes smaller after pickling - self.total_meta += meta_len - self.total_data += len(output) - meta_len + self.total_meta += len(output) - data_len + self.total_data += data_len else: output = pickle.dumps(data) # centralized testing uses its own instance if type(data) == dict: - if "indices" in data: - meta_len = len(pickle.dumps(data["indices"])) - else: - meta_len = 0 - self.total_meta += meta_len - self.total_data += len(output) - meta_len + assert "params" in data + data_len = len(pickle.dumps(data["params"])) + self.total_meta += len(output) - data_len + self.total_data += data_len return output def decrypt(self, sender, data): diff --git a/src/decentralizepy/compression/EliasFpzip.py b/src/decentralizepy/compression/EliasFpzip.py index dc1413a22f98f2378401ecf173bc0404a51bc4a5..0c82560aae28efb64bdba2a9cf0abf81d7fcdda7 100644 --- a/src/decentralizepy/compression/EliasFpzip.py +++ b/src/decentralizepy/compression/EliasFpzip.py @@ -49,4 +49,4 @@ class EliasFpzip(Elias): decompressed data as array """ - return fpzip.decompress(bytes, order="C") + return fpzip.decompress(bytes, order="C").squeeze() diff --git a/src/decentralizepy/compression/EliasFpzipLossy.py b/src/decentralizepy/compression/EliasFpzipLossy.py index 30e0111be49451386a760dfc379c0f2f2e1a7c7e..617a78b2b27ff88bd57db29a3a65d71ad1e0a843 100644 --- a/src/decentralizepy/compression/EliasFpzipLossy.py +++ b/src/decentralizepy/compression/EliasFpzipLossy.py @@ -49,4 +49,4 @@ class EliasFpzipLossy(Elias): decompressed data as array """ - return fpzip.decompress(bytes, order="C") + return fpzip.decompress(bytes, order="C").squeeze() diff --git a/src/decentralizepy/sharing/FFT.py b/src/decentralizepy/sharing/FFT.py index ba7b8418603ef1c8e119402dcfe982a2462a6a82..17650c1dd49aa220eeeeb9ff2526555b1797cf14 100644 --- a/src/decentralizepy/sharing/FFT.py +++ b/src/decentralizepy/sharing/FFT.py @@ -159,7 +159,6 @@ class FFT(PartialModel): if self.alpha >= self.metadata_cap: # Share fully data = self.pre_share_model_transformed m["params"] = data.numpy() - self.total_data += len(self.communication.encrypt(m["params"])) if self.model.accumulated_changes is not None: self.model.accumulated_changes = torch.zeros_like( self.model.accumulated_changes @@ -200,11 +199,6 @@ class FFT(PartialModel): m["indices"] = indices.numpy().astype(np.int32) m["send_partial"] = True - self.total_data += len(self.communication.encrypt(m["params"])) - self.total_meta += len(self.communication.encrypt(m["indices"])) + len( - self.communication.encrypt(m["alpha"]) - ) - return m def deserialized_model(self, m): diff --git a/src/decentralizepy/sharing/PartialModel.py b/src/decentralizepy/sharing/PartialModel.py index 1f1fecaf7736f344e0fc770e0940724f4bccc907..3111e82ce9af9ad8c6aba27311219c823df6e135 100644 --- a/src/decentralizepy/sharing/PartialModel.py +++ b/src/decentralizepy/sharing/PartialModel.py @@ -82,7 +82,6 @@ class PartialModel(Sharing): self.dict_ordered = dict_ordered self.save_shared = save_shared self.metadata_cap = metadata_cap - self.total_meta = 0 self.accumulation = accumulation self.save_accumulated = conditional_value(save_accumulated, "", False) self.change_transformer = change_transformer diff --git a/src/decentralizepy/sharing/RandomAlphaWavelet.py b/src/decentralizepy/sharing/RandomAlphaWavelet.py index 62bc51f4513012eb08925b98b582ecca38b12e94..44ea3364bc913042931583d22bbcba78f0fef5be 100644 --- a/src/decentralizepy/sharing/RandomAlphaWavelet.py +++ b/src/decentralizepy/sharing/RandomAlphaWavelet.py @@ -19,7 +19,7 @@ class RandomAlpha(Wavelet): model, dataset, log_dir, - alpha_list=[0.1, 0.2, 0.3, 0.4, 1.0], + alpha_list="[0.1, 0.2, 0.3, 0.4, 1.0]", dict_ordered=True, save_shared=False, metadata_cap=1.0, diff --git a/src/decentralizepy/sharing/Sharing.py b/src/decentralizepy/sharing/Sharing.py index 22b4de9025c8ebdbb034a29484db94a9bb621e2d..0ad3927a6d7fb80acfa01e120fde1ef8db4a8d77 100644 --- a/src/decentralizepy/sharing/Sharing.py +++ b/src/decentralizepy/sharing/Sharing.py @@ -46,13 +46,20 @@ class Sharing: self.dataset = dataset self.communication_round = 0 self.log_dir = log_dir - self.total_data = 0 self.peer_deques = dict() self.my_neighbors = self.graph.neighbors(self.uid) for n in self.my_neighbors: self.peer_deques[n] = deque() + self.shapes = [] + self.lens = [] + with torch.no_grad(): + for _, v in self.model.state_dict().items(): + self.shapes.append(v.shape) + t = v.flatten().numpy() + self.lens.append(t.shape[0]) + def received_from_all(self): """ Check if all neighbors have sent the current iteration @@ -96,11 +103,15 @@ class Sharing: Model converted to dict """ - m = dict() - for key, val in self.model.state_dict().items(): - m[key] = val.numpy() - self.total_data += len(self.communication.encrypt(m[key])) - return m + to_cat = [] + with torch.no_grad(): + for _, v in self.model.state_dict().items(): + t = v.flatten() + to_cat.append(t) + flat = torch.cat(to_cat) + data = dict() + data["params"] = flat.numpy() + return data def deserialized_model(self, m): """ @@ -118,8 +129,13 @@ class Sharing: """ state_dict = dict() - for key, value in m.items(): - state_dict[key] = torch.from_numpy(value) + T = m["params"] + start_index = 0 + for i, key in enumerate(self.model.state_dict()): + end_index = start_index + self.lens[i] + state_dict[key] = torch.from_numpy(T[start_index:end_index].reshape(self.shapes[i])) + start_index = end_index + return state_dict def _pre_step(self): diff --git a/src/decentralizepy/sharing/SharingCentrality.py b/src/decentralizepy/sharing/SharingCentrality.py index 580ce2aacc6505d7c713cfab763540c7484cd609..f933a0e6e002b7064eccbaa88f92280bdff3f488 100644 --- a/src/decentralizepy/sharing/SharingCentrality.py +++ b/src/decentralizepy/sharing/SharingCentrality.py @@ -46,7 +46,6 @@ class Sharing: self.dataset = dataset self.communication_round = 0 self.log_dir = log_dir - self.total_data = 0 self.peer_deques = dict() my_neighbors = self.graph.neighbors(self.uid) @@ -101,7 +100,6 @@ class Sharing: m = dict() for key, val in self.model.state_dict().items(): m[key] = val.numpy() - self.total_data += len(self.communication.encrypt(m[key])) return m def deserialized_model(self, m): diff --git a/src/decentralizepy/sharing/SubSampling.py b/src/decentralizepy/sharing/SubSampling.py index f8c8f50e7ac24fa3eca7d9a89199deb864c47594..b51cb07ce0345ee339be0fe2e338ffd9ab61b63e 100644 --- a/src/decentralizepy/sharing/SubSampling.py +++ b/src/decentralizepy/sharing/SubSampling.py @@ -72,7 +72,6 @@ class SubSampling(Sharing): self.dict_ordered = dict_ordered self.save_shared = save_shared self.metadata_cap = metadata_cap - self.total_meta = 0 # self.random_seed_generator = torch.Generator() # # Will use the random device if supported by CPU, else uses the system time @@ -216,12 +215,6 @@ class SubSampling(Sharing): m["alpha"] = alpha m["params"] = subsample.numpy() - # logging.info("Converted dictionary to json") - self.total_data += len(self.communication.encrypt(m["params"])) - self.total_meta += len(self.communication.encrypt(m["seed"])) + len( - self.communication.encrypt(m["alpha"]) - ) - return m def deserialized_model(self, m): diff --git a/src/decentralizepy/sharing/Synchronous.py b/src/decentralizepy/sharing/Synchronous.py index 29d7f62a7ea0872d4c3c0b5b8bdf3bf121a977b3..2c2d5e76cfa328260b14fcb9cbf2614e7101c751 100644 --- a/src/decentralizepy/sharing/Synchronous.py +++ b/src/decentralizepy/sharing/Synchronous.py @@ -46,7 +46,6 @@ class Synchronous: self.dataset = dataset self.communication_round = 0 self.log_dir = log_dir - self.total_data = 0 self.peer_deques = dict() self.my_neighbors = self.graph.neighbors(self.uid) @@ -104,7 +103,6 @@ class Synchronous: m = dict() for key, val in self.model.state_dict().items(): m[key] = val - self.init_model[key] # this is -lr*gradient - self.total_data += len(self.communication.encrypt(m)) return m def serialized_model(self): @@ -120,7 +118,6 @@ class Synchronous: m = dict() for key, val in self.model.state_dict().items(): m[key] = val.clone().detach() - self.total_data += len(self.communication.encrypt(m)) return m def deserialized_model(self, m): diff --git a/src/decentralizepy/sharing/TopKParams.py b/src/decentralizepy/sharing/TopKParams.py index 02531f164d37eb49158b839aed230be3beb17761..f1881798e91ff7cacb114f4071acc97ba81530e7 100644 --- a/src/decentralizepy/sharing/TopKParams.py +++ b/src/decentralizepy/sharing/TopKParams.py @@ -68,7 +68,6 @@ class TopKParams(Sharing): self.dict_ordered = dict_ordered self.save_shared = save_shared self.metadata_cap = metadata_cap - self.total_meta = 0 if self.save_shared: # Only save for 2 procs: Save space @@ -171,10 +170,6 @@ class TopKParams(Sharing): # m[key] = json.dumps(m[key]) logging.info("Converted dictionary to json") - self.total_data += len(self.communication.encrypt(m["params"])) - self.total_meta += len(self.communication.encrypt(m["indices"])) + len( - self.communication.encrypt(m["offsets"]) - ) return m diff --git a/src/decentralizepy/sharing/Wavelet.py b/src/decentralizepy/sharing/Wavelet.py index b864f1ff875268a2a6867b78ad774d38cf68a7f2..91c97d0a5d71f20c9eac79405b066f97f087f0bc 100644 --- a/src/decentralizepy/sharing/Wavelet.py +++ b/src/decentralizepy/sharing/Wavelet.py @@ -181,7 +181,6 @@ class Wavelet(PartialModel): if self.alpha >= self.metadata_cap: # Share fully data = self.pre_share_model_transformed m["params"] = data.numpy() - self.total_data += len(self.communication.encrypt(m["params"])) if self.model.accumulated_changes is not None: self.model.accumulated_changes = torch.zeros_like( self.model.accumulated_changes