Skip to content
Snippets Groups Projects
Commit d0656117 authored by Jeffrey Wigger's avatar Jeffrey Wigger
Browse files

subsampling fix

parent a6abfa47
No related branches found
No related tags found
1 merge request!8updated configs and run files
This commit is part of merge request !8. Comments created here will be created in the context of that merge request.
......@@ -101,6 +101,17 @@ class SubSampling(Sharing):
)
Path(self.folder_path).mkdir(parents=True, exist_ok=True)
with torch.no_grad():
tensors_to_cat = []
for _, v in self.model.state_dict().items():
t = v.flatten()
tensors_to_cat.append(t)
self.init_model = torch.cat(tensors_to_cat, dim=0)
self.model.shared_parameters_counter = torch.zeros(
self.init_model.shape[0], dtype=torch.int32
)
def apply_subsampling(self):
"""
Creates a random binary mask that is used to subsample the parameters that will be shared
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment