From 2b7c3661c508fee4431211c58d33be7b08ba6461 Mon Sep 17 00:00:00 2001 From: Rishi Sharma <rishi.sharma@epfl.ch> Date: Wed, 9 Mar 2022 15:09:25 +0100 Subject: [PATCH] Default value for random_seed (backward compatibility) and indices to int32 --- src/decentralizepy/node/Node.py | 5 ++++- src/decentralizepy/sharing/PartialModel.py | 5 +++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/decentralizepy/node/Node.py b/src/decentralizepy/node/Node.py index 48bb1ac..e5764ae 100644 --- a/src/decentralizepy/node/Node.py +++ b/src/decentralizepy/node/Node.py @@ -124,7 +124,10 @@ class Node: """ dataset_module = importlib.import_module(dataset_configs["dataset_package"]) self.dataset_class = getattr(dataset_module, dataset_configs["dataset_class"]) - torch.manual_seed(dataset_configs["random_seed"]) + random_seed = ( + dataset_configs["random_seed"] if "random_seed" in dataset_configs else 97 + ) + torch.manual_seed(random_seed) self.dataset_params = utils.remove_keys( dataset_configs, ["dataset_package", "dataset_class", "model_class", "random_seed"], diff --git a/src/decentralizepy/sharing/PartialModel.py b/src/decentralizepy/sharing/PartialModel.py index 529f673..6a8f0cb 100644 --- a/src/decentralizepy/sharing/PartialModel.py +++ b/src/decentralizepy/sharing/PartialModel.py @@ -3,6 +3,7 @@ import logging import os from pathlib import Path +import numpy as np import torch from decentralizepy.sharing.Sharing import Sharing @@ -155,7 +156,7 @@ class PartialModel(Sharing): if not self.dict_ordered: raise NotImplementedError - m["indices"] = G_topk.numpy() + m["indices"] = G_topk.numpy().astype(np.int32) m["params"] = T_topk.numpy() @@ -206,7 +207,7 @@ class PartialModel(Sharing): tensors_to_cat.append(t) T = torch.cat(tensors_to_cat, dim=0) - index_tensor = torch.tensor(m["indices"]) + index_tensor = torch.tensor(m["indices"], dtype=torch.long) logging.debug("Original tensor: {}".format(T[index_tensor])) T[index_tensor] = torch.tensor(m["params"]) logging.debug("Final tensor: {}".format(T[index_tensor])) -- GitLab