diff --git a/src/decentralizepy/mappings/Linear.py b/src/decentralizepy/mappings/Linear.py
index 57ef628c99817d32a8e1cddfefe5b5ff6a22a5cd..9419fbd40a18d2c9ca1a4992854a6971ce937dde 100644
--- a/src/decentralizepy/mappings/Linear.py
+++ b/src/decentralizepy/mappings/Linear.py
@@ -59,3 +59,16 @@ class Linear(Mapping):
 
         """
         return (uid % self.procs_per_machine), (uid // self.procs_per_machine)
+
+    def get_local_procs_count(self):
+        """
+        Gives number of processes that run on the node
+
+        Returns
+        -------
+        int
+            the number of local processes
+
+        """
+
+        return self.procs_per_machine
diff --git a/src/decentralizepy/mappings/Mapping.py b/src/decentralizepy/mappings/Mapping.py
index 9f764ed2875abd2dbaaf01e769c1c4d88d7d858b..bb8523fe2919f1d79c4fa0834a0d231db4701ac0 100644
--- a/src/decentralizepy/mappings/Mapping.py
+++ b/src/decentralizepy/mappings/Mapping.py
@@ -67,3 +67,16 @@ class Mapping:
         """
 
         raise NotImplementedError
+
+    def get_local_procs_count(self):
+        """
+        Gives number of processes that run on the node
+
+        Returns
+        -------
+        int
+            the number of local processes
+
+        """
+
+        raise NotImplementedError
diff --git a/src/decentralizepy/node/Node.py b/src/decentralizepy/node/Node.py
index e612eca9c07cf5cd2f58882f1358fe9d809717a1..9f1b7dae9357a988edc1d0e526a461e7b3abcb42 100644
--- a/src/decentralizepy/node/Node.py
+++ b/src/decentralizepy/node/Node.py
@@ -1,6 +1,7 @@
 import importlib
 import json
 import logging
+import math
 import os
 
 import torch
@@ -421,7 +422,9 @@ class Node:
             Other arguments
 
         """
-        torch.set_num_threads(2)
+        total_threads = os.cpu_count()
+        threads_per_proc = max(math.floor(total_threads / mapping.procs_per_machine), 1)
+        torch.set_num_threads(threads_per_proc)
         torch.set_num_interop_threads(1)
         self.instantiate(
             rank,
@@ -435,5 +438,8 @@ class Node:
             test_after,
             *args
         )
+        logging.info(
+            "Each proc uses %d threads out of %d.", threads_per_proc, total_threads
+        )
 
         self.run()