From 2b7c3661c508fee4431211c58d33be7b08ba6461 Mon Sep 17 00:00:00 2001
From: Rishi Sharma <rishi.sharma@epfl.ch>
Date: Wed, 9 Mar 2022 15:09:25 +0100
Subject: [PATCH] Default value for random_seed (backward compatibility) and
 indices to int32

---
 src/decentralizepy/node/Node.py            | 5 ++++-
 src/decentralizepy/sharing/PartialModel.py | 5 +++--
 2 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/src/decentralizepy/node/Node.py b/src/decentralizepy/node/Node.py
index 48bb1ac..e5764ae 100644
--- a/src/decentralizepy/node/Node.py
+++ b/src/decentralizepy/node/Node.py
@@ -124,7 +124,10 @@ class Node:
         """
         dataset_module = importlib.import_module(dataset_configs["dataset_package"])
         self.dataset_class = getattr(dataset_module, dataset_configs["dataset_class"])
-        torch.manual_seed(dataset_configs["random_seed"])
+        random_seed = (
+            dataset_configs["random_seed"] if "random_seed" in dataset_configs else 97
+        )
+        torch.manual_seed(random_seed)
         self.dataset_params = utils.remove_keys(
             dataset_configs,
             ["dataset_package", "dataset_class", "model_class", "random_seed"],
diff --git a/src/decentralizepy/sharing/PartialModel.py b/src/decentralizepy/sharing/PartialModel.py
index 529f673..6a8f0cb 100644
--- a/src/decentralizepy/sharing/PartialModel.py
+++ b/src/decentralizepy/sharing/PartialModel.py
@@ -3,6 +3,7 @@ import logging
 import os
 from pathlib import Path
 
+import numpy as np
 import torch
 
 from decentralizepy.sharing.Sharing import Sharing
@@ -155,7 +156,7 @@ class PartialModel(Sharing):
             if not self.dict_ordered:
                 raise NotImplementedError
 
-            m["indices"] = G_topk.numpy()
+            m["indices"] = G_topk.numpy().astype(np.int32)
 
             m["params"] = T_topk.numpy()
 
@@ -206,7 +207,7 @@ class PartialModel(Sharing):
                 tensors_to_cat.append(t)
 
             T = torch.cat(tensors_to_cat, dim=0)
-            index_tensor = torch.tensor(m["indices"])
+            index_tensor = torch.tensor(m["indices"], dtype=torch.long)
             logging.debug("Original tensor: {}".format(T[index_tensor]))
             T[index_tensor] = torch.tensor(m["params"])
             logging.debug("Final tensor: {}".format(T[index_tensor]))
-- 
GitLab