diff --git a/src/decentralizepy/sharing/SubSampling.py b/src/decentralizepy/sharing/SubSampling.py
index 6221714d5ae7bb4f24ea1a4bf66db5f3e9dffcc8..f8c8f50e7ac24fa3eca7d9a89199deb864c47594 100644
--- a/src/decentralizepy/sharing/SubSampling.py
+++ b/src/decentralizepy/sharing/SubSampling.py
@@ -101,6 +101,17 @@ class SubSampling(Sharing):
         )
         Path(self.folder_path).mkdir(parents=True, exist_ok=True)
 
+        with torch.no_grad():
+            tensors_to_cat = []
+            for _, v in self.model.state_dict().items():
+                t = v.flatten()
+                tensors_to_cat.append(t)
+            self.init_model = torch.cat(tensors_to_cat, dim=0)
+
+        self.model.shared_parameters_counter = torch.zeros(
+            self.init_model.shape[0], dtype=torch.int32
+        )
+
     def apply_subsampling(self):
         """
         Creates a random binary mask that is used to subsample the parameters that will be shared