From 6496eb69d412d56ae6864f2e216779fdd5b75937 Mon Sep 17 00:00:00 2001 From: Jeffrey Wigger <jeffrey.wigger@epfl.ch> Date: Fri, 4 Mar 2022 16:22:08 +0100 Subject: [PATCH] removing not needed to list; set_num_threads to 2; --- src/decentralizepy/node/Node.py | 3 +++ src/decentralizepy/sharing/PartialModel.py | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/decentralizepy/node/Node.py b/src/decentralizepy/node/Node.py index 0ef54cd..e612eca 100644 --- a/src/decentralizepy/node/Node.py +++ b/src/decentralizepy/node/Node.py @@ -3,6 +3,7 @@ import json import logging import os +import torch from matplotlib import pyplot as plt from decentralizepy import utils @@ -420,6 +421,8 @@ class Node: Other arguments """ + torch.set_num_threads(2) + torch.set_num_interop_threads(1) self.instantiate( rank, machine_id, diff --git a/src/decentralizepy/sharing/PartialModel.py b/src/decentralizepy/sharing/PartialModel.py index 28a790e..a87b868 100644 --- a/src/decentralizepy/sharing/PartialModel.py +++ b/src/decentralizepy/sharing/PartialModel.py @@ -155,9 +155,9 @@ class PartialModel(Sharing): if not self.dict_ordered: raise NotImplementedError - m["indices"] = G_topk.numpy().tolist() + m["indices"] = G_topk.numpy() - m["params"] = T_topk.numpy().tolist() + m["params"] = T_topk.numpy() assert len(m["indices"]) == len(m["params"]) logging.info("Elements sending: {}".format(len(m["indices"]))) -- GitLab