diff --git a/src/decentralizepy/sharing/PartialModel.py b/src/decentralizepy/sharing/PartialModel.py index a87b86847b74b3b63f0a76757c1c8ab1852245bc..107e20d96d2dee01d9ea4ff33ed1c5e09d53d95d 100644 --- a/src/decentralizepy/sharing/PartialModel.py +++ b/src/decentralizepy/sharing/PartialModel.py @@ -159,6 +159,8 @@ class PartialModel(Sharing): m["params"] = T_topk.numpy() + m["send_partial"] = True + assert len(m["indices"]) == len(m["params"]) logging.info("Elements sending: {}".format(len(m["indices"]))) @@ -185,7 +187,7 @@ class PartialModel(Sharing): state_dict of received """ - if self.alpha > self.metadata_cap: # Share fully + if "send_partial" not in m: return super().deserialized_model(m) with torch.no_grad():