From 1b7936b126a572677ab459143bdb6a5e545f94aa Mon Sep 17 00:00:00 2001 From: Jeffrey Wigger <jeffrey.wigger@epfl.ch> Date: Sat, 19 Mar 2022 18:18:50 +0100 Subject: [PATCH] wavelet and fft fix --- eval/step_configs/config_celeba_wavelet.ini | 2 +- eval/step_configs/config_femnist.ini | 1 + eval/step_configs/config_femnist_wavelet.ini | 2 +- src/decentralizepy/sharing/FFT.py | 7 +++++-- src/decentralizepy/sharing/PartialModel.py | 2 +- src/decentralizepy/sharing/Wavelet.py | 6 +++++- 6 files changed, 14 insertions(+), 6 deletions(-) diff --git a/eval/step_configs/config_celeba_wavelet.ini b/eval/step_configs/config_celeba_wavelet.ini index 70e9f15..1c97eb9 100644 --- a/eval/step_configs/config_celeba_wavelet.ini +++ b/eval/step_configs/config_celeba_wavelet.ini @@ -34,5 +34,5 @@ sharing_class = Wavelet change_based_selection = True alpha = 0.1 wavelet=sym2 -level= None +level= 4 accumulation = True diff --git a/eval/step_configs/config_femnist.ini b/eval/step_configs/config_femnist.ini index 8063181..de4f1ce 100644 --- a/eval/step_configs/config_femnist.ini +++ b/eval/step_configs/config_femnist.ini @@ -31,3 +31,4 @@ addresses_filepath = ip_addr_6Machines.json [SHARING] sharing_package = decentralizepy.sharing.PartialModel sharing_class = PartialModel +alpha=0.1 diff --git a/eval/step_configs/config_femnist_wavelet.ini b/eval/step_configs/config_femnist_wavelet.ini index 68704a3..b6ff278 100644 --- a/eval/step_configs/config_femnist_wavelet.ini +++ b/eval/step_configs/config_femnist_wavelet.ini @@ -35,5 +35,5 @@ sharing_class = Wavelet change_based_selection = True alpha = 0.1 wavelet=sym2 -level= None +level= 4 accumulation = True diff --git a/src/decentralizepy/sharing/FFT.py b/src/decentralizepy/sharing/FFT.py index 6af75e3..e0e67fd 100644 --- a/src/decentralizepy/sharing/FFT.py +++ b/src/decentralizepy/sharing/FFT.py @@ -224,8 +224,11 @@ class FFT(PartialModel): with torch.no_grad(): total = None weight_total = 0 - - flat_fft = self.change_transformer(self.init_model) + tensors_to_cat = [ + v.data.flatten() for _, v in self.model.state_dict().items() + ] + pre_share_model = torch.cat(tensors_to_cat, dim=0) + flat_fft = self.change_transformer(pre_share_model) for i, n in enumerate(self.peer_deques): degree, iteration, data = self.peer_deques[n].popleft() diff --git a/src/decentralizepy/sharing/PartialModel.py b/src/decentralizepy/sharing/PartialModel.py index 935f302..dca5c75 100644 --- a/src/decentralizepy/sharing/PartialModel.py +++ b/src/decentralizepy/sharing/PartialModel.py @@ -155,7 +155,7 @@ class PartialModel(Sharing): Model converted to a dict """ - if self.alpha > self.metadata_cap: # Share fully + if self.alpha >= self.metadata_cap: # Share fully return super().serialized_model() with torch.no_grad(): diff --git a/src/decentralizepy/sharing/Wavelet.py b/src/decentralizepy/sharing/Wavelet.py index cd039f8..e41bd24 100644 --- a/src/decentralizepy/sharing/Wavelet.py +++ b/src/decentralizepy/sharing/Wavelet.py @@ -257,7 +257,11 @@ class Wavelet(PartialModel): with torch.no_grad(): total = None weight_total = 0 - wt_params = self.change_transformer(self.init_model) + tensors_to_cat = [ + v.data.flatten() for _, v in self.model.state_dict().items() + ] + pre_share_model = torch.cat(tensors_to_cat, dim=0) + wt_params = self.change_transformer(pre_share_model) for i, n in enumerate(self.peer_deques): degree, iteration, data = self.peer_deques[n].popleft() logging.debug( -- GitLab