diff --git a/src/decentralizepy/sharing/PartialModel.py b/src/decentralizepy/sharing/PartialModel.py index 7f27d4a10f568a6ad4d59e48d9ec025a7044d79d..49083aaf03957aefe4cfd83c822800dba6758d87 100644 --- a/src/decentralizepy/sharing/PartialModel.py +++ b/src/decentralizepy/sharing/PartialModel.py @@ -61,7 +61,7 @@ class PartialModel(Sharing): shared_params["order"] = self.model.state_dict().keys() shapes = dict() for k, v in self.model.state_dict().items(): - shapes[k] = v.shape.tolist() + shapes[k] = list(v.shape) shared_params["shapes"] = shapes shared_params[self.communication_round] = G_topk.tolist()