From 1b7936b126a572677ab459143bdb6a5e545f94aa Mon Sep 17 00:00:00 2001
From: Jeffrey Wigger <jeffrey.wigger@epfl.ch>
Date: Sat, 19 Mar 2022 18:18:50 +0100
Subject: [PATCH] Fix FFT and Wavelet sharing averaging and configs

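Averaging in FFT and Wavelet sharing transformed the stale
self.init_model instead of the model that was actually trained;
rebuild the flattened parameter vector from model.state_dict() before
applying the change transformer.

Also:
- set an explicit wavelet level (4) instead of None in the celeba and
  femnist wavelet step configs,
- add the missing alpha to config_femnist.ini so PartialModel has a
  sharing fraction,
- in PartialModel, share the full model when alpha >= metadata_cap
  (previously only when strictly greater).
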
---
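Notes:
    The FFT and Wavelet hunks below apply the same fix; a minimal
    sketch of the pattern as a standalone helper (flat_transform is a
    hypothetical name, and change_transformer stands in for the
    instance's transform, e.g. an FFT or a wavelet decomposition):

        import torch

        def flat_transform(model, change_transformer):
            # Flatten the *current* parameters into one vector, then
            # apply the transform; previously the stale init_model
            # was transformed instead.
            tensors_to_cat = [
                v.data.flatten() for _, v in model.state_dict().items()
            ]
            return change_transformer(torch.cat(tensors_to_cat, dim=0))
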
 eval/step_configs/config_celeba_wavelet.ini  | 2 +-
 eval/step_configs/config_femnist.ini         | 1 +
 eval/step_configs/config_femnist_wavelet.ini | 2 +-
 src/decentralizepy/sharing/FFT.py            | 7 +++++--
 src/decentralizepy/sharing/PartialModel.py   | 2 +-
 src/decentralizepy/sharing/Wavelet.py        | 6 +++++-
 6 files changed, 14 insertions(+), 6 deletions(-)

diff --git a/eval/step_configs/config_celeba_wavelet.ini b/eval/step_configs/config_celeba_wavelet.ini
index 70e9f15..1c97eb9 100644
--- a/eval/step_configs/config_celeba_wavelet.ini
+++ b/eval/step_configs/config_celeba_wavelet.ini
@@ -34,5 +34,5 @@ sharing_class = Wavelet
 change_based_selection = True
 alpha = 0.1
 wavelet=sym2
-level= None
+level= 4
 accumulation = True
diff --git a/eval/step_configs/config_femnist.ini b/eval/step_configs/config_femnist.ini
index 8063181..de4f1ce 100644
--- a/eval/step_configs/config_femnist.ini
+++ b/eval/step_configs/config_femnist.ini
@@ -31,3 +31,4 @@ addresses_filepath = ip_addr_6Machines.json
 [SHARING]
 sharing_package = decentralizepy.sharing.PartialModel
 sharing_class = PartialModel
+alpha=0.1
diff --git a/eval/step_configs/config_femnist_wavelet.ini b/eval/step_configs/config_femnist_wavelet.ini
index 68704a3..b6ff278 100644
--- a/eval/step_configs/config_femnist_wavelet.ini
+++ b/eval/step_configs/config_femnist_wavelet.ini
@@ -35,5 +35,5 @@ sharing_class = Wavelet
 change_based_selection = True
 alpha = 0.1
 wavelet=sym2
-level= None
+level= 4
 accumulation = True
diff --git a/src/decentralizepy/sharing/FFT.py b/src/decentralizepy/sharing/FFT.py
index 6af75e3..e0e67fd 100644
--- a/src/decentralizepy/sharing/FFT.py
+++ b/src/decentralizepy/sharing/FFT.py
@@ -224,8 +224,11 @@ class FFT(PartialModel):
         with torch.no_grad():
             total = None
             weight_total = 0
-
-            flat_fft = self.change_transformer(self.init_model)
+            tensors_to_cat = [
+                v.data.flatten() for _, v in self.model.state_dict().items()
+            ]
+            pre_share_model = torch.cat(tensors_to_cat, dim=0)
+            flat_fft = self.change_transformer(pre_share_model)
 
             for i, n in enumerate(self.peer_deques):
                 degree, iteration, data = self.peer_deques[n].popleft()
diff --git a/src/decentralizepy/sharing/PartialModel.py b/src/decentralizepy/sharing/PartialModel.py
index 935f302..dca5c75 100644
--- a/src/decentralizepy/sharing/PartialModel.py
+++ b/src/decentralizepy/sharing/PartialModel.py
@@ -155,7 +155,7 @@ class PartialModel(Sharing):
             Model converted to a dict
 
         """
-        if self.alpha > self.metadata_cap:  # Share fully
+        if self.alpha >= self.metadata_cap:  # Share fully
             return super().serialized_model()
 
         with torch.no_grad():
diff --git a/src/decentralizepy/sharing/Wavelet.py b/src/decentralizepy/sharing/Wavelet.py
index cd039f8..e41bd24 100644
--- a/src/decentralizepy/sharing/Wavelet.py
+++ b/src/decentralizepy/sharing/Wavelet.py
@@ -257,7 +257,11 @@ class Wavelet(PartialModel):
         with torch.no_grad():
             total = None
             weight_total = 0
-            wt_params = self.change_transformer(self.init_model)
+            tensors_to_cat = [
+                v.data.flatten() for _, v in self.model.state_dict().items()
+            ]
+            pre_share_model = torch.cat(tensors_to_cat, dim=0)
+            wt_params = self.change_transformer(pre_share_model)
             for i, n in enumerate(self.peer_deques):
                 degree, iteration, data = self.peer_deques[n].popleft()
                 logging.debug(
-- 
GitLab