From 6496eb69d412d56ae6864f2e216779fdd5b75937 Mon Sep 17 00:00:00 2001
From: Jeffrey Wigger <jeffrey.wigger@epfl.ch>
Date: Fri, 4 Mar 2022 16:22:08 +0100
Subject: [PATCH] removing not needed to list; set_num_threads to 2;

---
 src/decentralizepy/node/Node.py            | 3 +++
 src/decentralizepy/sharing/PartialModel.py | 4 ++--
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/decentralizepy/node/Node.py b/src/decentralizepy/node/Node.py
index 0ef54cd..e612eca 100644
--- a/src/decentralizepy/node/Node.py
+++ b/src/decentralizepy/node/Node.py
@@ -3,6 +3,7 @@ import json
 import logging
 import os
 
+import torch
 from matplotlib import pyplot as plt
 
 from decentralizepy import utils
@@ -420,6 +421,8 @@ class Node:
             Other arguments
 
         """
+        torch.set_num_threads(2)
+        torch.set_num_interop_threads(1)
         self.instantiate(
             rank,
             machine_id,
diff --git a/src/decentralizepy/sharing/PartialModel.py b/src/decentralizepy/sharing/PartialModel.py
index 28a790e..a87b868 100644
--- a/src/decentralizepy/sharing/PartialModel.py
+++ b/src/decentralizepy/sharing/PartialModel.py
@@ -155,9 +155,9 @@ class PartialModel(Sharing):
             if not self.dict_ordered:
                 raise NotImplementedError
 
-            m["indices"] = G_topk.numpy().tolist()
+            m["indices"] = G_topk.numpy()
 
-            m["params"] = T_topk.numpy().tolist()
+            m["params"] = T_topk.numpy()
 
             assert len(m["indices"]) == len(m["params"])
             logging.info("Elements sending: {}".format(len(m["indices"])))
-- 
GitLab