Commit 2b7c3661 authored by Rishi Sharma

Default value for random_seed (backward compatibility) and indices to int32

parent e8bfbe5b
@@ -124,7 +124,10 @@ class Node:
         """
         dataset_module = importlib.import_module(dataset_configs["dataset_package"])
         self.dataset_class = getattr(dataset_module, dataset_configs["dataset_class"])
-        torch.manual_seed(dataset_configs["random_seed"])
+        random_seed = (
+            dataset_configs["random_seed"] if "random_seed" in dataset_configs else 97
+        )
+        torch.manual_seed(random_seed)
         self.dataset_params = utils.remove_keys(
             dataset_configs,
             ["dataset_package", "dataset_class", "model_class", "random_seed"],
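The hunk above makes the "random_seed" key optional so that configuration files written before the key existed keep working, falling back to a seed of 97. A minimal sketch of the behaviour, using an illustrative config dict rather than the project's real config loader:

import torch

def resolve_seed(dataset_configs, default=97):
    # Older configs that predate the "random_seed" key fall back to the default.
    return dataset_configs["random_seed"] if "random_seed" in dataset_configs else default

old_config = {"dataset_package": "some.dataset.package", "dataset_class": "SomeDataset"}
new_config = {**old_config, "random_seed": 1234}

torch.manual_seed(resolve_seed(old_config))  # seeds with the default, 97
torch.manual_seed(resolve_seed(new_config))  # seeds with 1234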
@@ -3,6 +3,7 @@ import logging
 import os
 from pathlib import Path
 
+import numpy as np
 import torch
 
 from decentralizepy.sharing.Sharing import Sharing
@@ -155,7 +156,7 @@ class PartialModel(Sharing):
             if not self.dict_ordered:
                 raise NotImplementedError
 
-            m["indices"] = G_topk.numpy()
+            m["indices"] = G_topk.numpy().astype(np.int32)
 
             m["params"] = T_topk.numpy()
@@ -206,7 +207,7 @@ class PartialModel(Sharing):
                 tensors_to_cat.append(t)
 
             T = torch.cat(tensors_to_cat, dim=0)
-            index_tensor = torch.tensor(m["indices"])
+            index_tensor = torch.tensor(m["indices"], dtype=torch.long)
             logging.debug("Original tensor: {}".format(T[index_tensor]))
             T[index_tensor] = torch.tensor(m["params"])
             logging.debug("Final tensor: {}".format(T[index_tensor]))
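On the receiving side, the explicit dtype=torch.long matters because the indices now arrive as int32, while PyTorch's tensor indexing and indexed assignment expect long index tensors on at least some versions. A small sketch of the deserialization step, with illustrative sizes and values:

import numpy as np
import torch

T = torch.zeros(10)                                   # stand-in for the flattened model
wire_indices = np.array([1, 4, 7], dtype=np.int32)    # as produced by the int32 cast above
wire_params = np.array([0.5, -1.0, 2.0], dtype=np.float32)

index_tensor = torch.tensor(wire_indices, dtype=torch.long)  # upcast before indexing
T[index_tensor] = torch.tensor(wire_params)
print(T[index_tensor])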