From d0656117c4b79f53c554af3051f513e6df77dc63 Mon Sep 17 00:00:00 2001 From: Jeffrey Wigger <jeffrey.wigger@epfl.ch> Date: Thu, 7 Apr 2022 10:39:40 +0200 Subject: [PATCH] subsampling fix --- src/decentralizepy/sharing/SubSampling.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/decentralizepy/sharing/SubSampling.py b/src/decentralizepy/sharing/SubSampling.py index 6221714..f8c8f50 100644 --- a/src/decentralizepy/sharing/SubSampling.py +++ b/src/decentralizepy/sharing/SubSampling.py @@ -101,6 +101,17 @@ class SubSampling(Sharing): ) Path(self.folder_path).mkdir(parents=True, exist_ok=True) + with torch.no_grad(): + tensors_to_cat = [] + for _, v in self.model.state_dict().items(): + t = v.flatten() + tensors_to_cat.append(t) + self.init_model = torch.cat(tensors_to_cat, dim=0) + + self.model.shared_parameters_counter = torch.zeros( + self.init_model.shape[0], dtype=torch.int32 + ) + def apply_subsampling(self): """ Creates a random binary mask that is used to subsample the parameters that will be shared -- GitLab