diff --git a/eval/step_configs/config_celeba_wavelet.ini b/eval/step_configs/config_celeba_wavelet.ini index 70e9f155d303d397d42e4a48dff75ab5477a912e..1c97eb9f7e4c5220fd65f231e4423b9b49edc723 100644 --- a/eval/step_configs/config_celeba_wavelet.ini +++ b/eval/step_configs/config_celeba_wavelet.ini @@ -34,5 +34,5 @@ sharing_class = Wavelet change_based_selection = True alpha = 0.1 wavelet=sym2 -level= None +level= 4 accumulation = True diff --git a/eval/step_configs/config_femnist.ini b/eval/step_configs/config_femnist.ini index 8063181b132ba8a862041f38aa66e8dd99b33fbb..de4f1cef82d2a90665228ef403b02a32fff3188d 100644 --- a/eval/step_configs/config_femnist.ini +++ b/eval/step_configs/config_femnist.ini @@ -31,3 +31,4 @@ addresses_filepath = ip_addr_6Machines.json [SHARING] sharing_package = decentralizepy.sharing.PartialModel sharing_class = PartialModel +alpha=0.1 diff --git a/eval/step_configs/config_femnist_wavelet.ini b/eval/step_configs/config_femnist_wavelet.ini index 68704a3f3c08f3e0293434a303ce04685fb30d92..b6ff27856b3b7f43fd9ef1bd2c321003bda80a38 100644 --- a/eval/step_configs/config_femnist_wavelet.ini +++ b/eval/step_configs/config_femnist_wavelet.ini @@ -35,5 +35,5 @@ sharing_class = Wavelet change_based_selection = True alpha = 0.1 wavelet=sym2 -level= None +level= 4 accumulation = True diff --git a/src/decentralizepy/sharing/FFT.py b/src/decentralizepy/sharing/FFT.py index 6af75e390489ceef91faf4d99919c7c546b40aa5..e0e67fd0db45fc8f77212b8494f11e38b10d8093 100644 --- a/src/decentralizepy/sharing/FFT.py +++ b/src/decentralizepy/sharing/FFT.py @@ -224,8 +224,11 @@ class FFT(PartialModel): with torch.no_grad(): total = None weight_total = 0 - - flat_fft = self.change_transformer(self.init_model) + tensors_to_cat = [ + v.data.flatten() for _, v in self.model.state_dict().items() + ] + pre_share_model = torch.cat(tensors_to_cat, dim=0) + flat_fft = self.change_transformer(pre_share_model) for i, n in enumerate(self.peer_deques): degree, iteration, data = self.peer_deques[n].popleft() diff --git a/src/decentralizepy/sharing/PartialModel.py b/src/decentralizepy/sharing/PartialModel.py index 935f302143433a4eb2d99e738a0dd2f1738b2153..dca5c75a95f2e5332fbe0534b6e949568df9b6a0 100644 --- a/src/decentralizepy/sharing/PartialModel.py +++ b/src/decentralizepy/sharing/PartialModel.py @@ -155,7 +155,7 @@ class PartialModel(Sharing): Model converted to a dict """ - if self.alpha > self.metadata_cap: # Share fully + if self.alpha >= self.metadata_cap: # Share fully return super().serialized_model() with torch.no_grad(): diff --git a/src/decentralizepy/sharing/Wavelet.py b/src/decentralizepy/sharing/Wavelet.py index cd039f8223ba7f58bbbd24404e669139e0a664b8..e41bd24ab0f40076ddb091806b2d81aa382c7766 100644 --- a/src/decentralizepy/sharing/Wavelet.py +++ b/src/decentralizepy/sharing/Wavelet.py @@ -257,7 +257,11 @@ class Wavelet(PartialModel): with torch.no_grad(): total = None weight_total = 0 - wt_params = self.change_transformer(self.init_model) + tensors_to_cat = [ + v.data.flatten() for _, v in self.model.state_dict().items() + ] + pre_share_model = torch.cat(tensors_to_cat, dim=0) + wt_params = self.change_transformer(pre_share_model) for i, n in enumerate(self.peer_deques): degree, iteration, data = self.peer_deques[n].popleft() logging.debug(