From d0656117c4b79f53c554af3051f513e6df77dc63 Mon Sep 17 00:00:00 2001
From: Jeffrey Wigger <jeffrey.wigger@epfl.ch>
Date: Thu, 7 Apr 2022 10:39:40 +0200
Subject: [PATCH] SubSampling: initialize flattened model snapshot and shared_parameters_counter

---
 src/decentralizepy/sharing/SubSampling.py | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/src/decentralizepy/sharing/SubSampling.py b/src/decentralizepy/sharing/SubSampling.py
index 6221714..f8c8f50 100644
--- a/src/decentralizepy/sharing/SubSampling.py
+++ b/src/decentralizepy/sharing/SubSampling.py
@@ -101,6 +101,21 @@ class SubSampling(Sharing):
             )
             Path(self.folder_path).mkdir(parents=True, exist_ok=True)
 
+        # Snapshot the current model as one flat tensor: flatten every
+        # entry of the state dict and concatenate the pieces.
+        with torch.no_grad():
+            tensors_to_cat = []
+            for _, v in self.model.state_dict().items():
+                t = v.flatten()
+                tensors_to_cat.append(t)
+            self.init_model = torch.cat(tensors_to_cat, dim=0)
+
+        # One int32 counter per flattened parameter, counting how many
+        # times that parameter has been shared; all start at zero.
+        self.model.shared_parameters_counter = torch.zeros(
+            self.init_model.shape[0], dtype=torch.int32
+        )
+
     def apply_subsampling(self):
         """
         Creates a random binary mask that is used to subsample the parameters that will be shared
-- 
GitLab
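
Below is a minimal, self-contained sketch of how the flat model snapshot
and the per-parameter counter initialized by this patch can feed the
random binary mask that apply_subsampling describes. The toy model, the
sharing fraction alpha, and the mask/update logic are assumptions made
for illustration only; they are not decentralizepy's actual
implementation.

    import torch

    # Toy stand-in for self.model (assumption; any torch.nn.Module works).
    model = torch.nn.Linear(4, 2)

    # Flatten the full state dict into one 1-D tensor, as the patch does.
    with torch.no_grad():
        init_model = torch.cat(
            [v.flatten() for v in model.state_dict().values()], dim=0
        )

    # One int32 counter per flattened parameter, all starting at zero.
    shared_parameters_counter = torch.zeros(
        init_model.shape[0], dtype=torch.int32
    )

    # One hypothetical sharing round: keep roughly an alpha fraction of
    # the parameters via a random binary mask and count what was shared.
    alpha = 0.5  # assumed sharing fraction, not taken from the patch
    mask = torch.rand(init_model.shape[0]) < alpha
    shared_parameters_counter[mask] += 1
    shared_values = init_model[mask]  # values that would be sent to peers

    print(shared_values.shape, shared_parameters_counter.tolist())

Sizing the counter from init_model.shape[0], as the patch does, ties its
length to the flattened model, so every parameter index has exactly one
counter slot regardless of the model's layer structure.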