diff --git a/src/decentralizepy/node/Node.py b/src/decentralizepy/node/Node.py index 0ef54cdee09002171d062f3b3f8ea2cf91e537f0..e612eca9c07cf5cd2f58882f1358fe9d809717a1 100644 --- a/src/decentralizepy/node/Node.py +++ b/src/decentralizepy/node/Node.py @@ -3,6 +3,7 @@ import json import logging import os +import torch from matplotlib import pyplot as plt from decentralizepy import utils @@ -420,6 +421,8 @@ class Node: Other arguments """ + torch.set_num_threads(2) + torch.set_num_interop_threads(1) self.instantiate( rank, machine_id, diff --git a/src/decentralizepy/sharing/PartialModel.py b/src/decentralizepy/sharing/PartialModel.py index 28a790e994620425108d2b9a0e827e00060863db..a87b86847b74b3b63f0a76757c1c8ab1852245bc 100644 --- a/src/decentralizepy/sharing/PartialModel.py +++ b/src/decentralizepy/sharing/PartialModel.py @@ -155,9 +155,9 @@ class PartialModel(Sharing): if not self.dict_ordered: raise NotImplementedError - m["indices"] = G_topk.numpy().tolist() + m["indices"] = G_topk.numpy() - m["params"] = T_topk.numpy().tolist() + m["params"] = T_topk.numpy() assert len(m["indices"]) == len(m["params"]) logging.info("Elements sending: {}".format(len(m["indices"])))