From bee802722b15acebb2d594dc8fea29a26660bd5e Mon Sep 17 00:00:00 2001
From: Jeffrey Wigger <jeffrey.wigger@epfl.ch>
Date: Tue, 8 Mar 2022 10:09:42 +0100
Subject: [PATCH] dynamically limiting the number of threads per proc

---
 src/decentralizepy/mappings/Linear.py  | 13 +++++++++++++
 src/decentralizepy/mappings/Mapping.py | 13 +++++++++++++
 src/decentralizepy/node/Node.py        |  8 +++++++-
 3 files changed, 33 insertions(+), 1 deletion(-)

diff --git a/src/decentralizepy/mappings/Linear.py b/src/decentralizepy/mappings/Linear.py
index 57ef628..9419fbd 100644
--- a/src/decentralizepy/mappings/Linear.py
+++ b/src/decentralizepy/mappings/Linear.py
@@ -59,3 +59,16 @@ class Linear(Mapping):
 
         """
         return (uid % self.procs_per_machine), (uid // self.procs_per_machine)
+
+    def get_local_procs_count(self):
+        """
+        Gives the number of processes that run on the local machine
+
+        Returns
+        -------
+        int
+            the number of processes per machine
+
+        """
+
+        return self.procs_per_machine
diff --git a/src/decentralizepy/mappings/Mapping.py b/src/decentralizepy/mappings/Mapping.py
index 9f764ed..bb8523f 100644
--- a/src/decentralizepy/mappings/Mapping.py
+++ b/src/decentralizepy/mappings/Mapping.py
@@ -67,3 +67,16 @@ class Mapping:
         """
 
         raise NotImplementedError
+
+    def get_local_procs_count(self):
+        """
+        Gives the number of processes that run on the local machine
+
+        Returns
+        -------
+        int
+            the number of local processes
+
+        """
+
+        raise NotImplementedError
diff --git a/src/decentralizepy/node/Node.py b/src/decentralizepy/node/Node.py
index e612eca..9f1b7da 100644
--- a/src/decentralizepy/node/Node.py
+++ b/src/decentralizepy/node/Node.py
@@ -1,6 +1,7 @@
 import importlib
 import json
 import logging
+import math
 import os
 
 import torch
@@ -421,7 +422,9 @@ class Node:
             Other arguments
 
         """
-        torch.set_num_threads(2)
+        total_threads = os.cpu_count()
+        threads_per_proc = max(math.floor(total_threads / mapping.procs_per_machine), 1)
+        torch.set_num_threads(threads_per_proc)
         torch.set_num_interop_threads(1)
         self.instantiate(
             rank,
@@ -435,5 +438,8 @@ class Node:
             test_after,
             *args
         )
+        logging.info(
+            "Each proc uses %d threads out of %d.", threads_per_proc, total_threads
+        )
 
         self.run()
-- 
GitLab