diff --git a/src/decentralizepy/mappings/Linear.py b/src/decentralizepy/mappings/Linear.py index 57ef628c99817d32a8e1cddfefe5b5ff6a22a5cd..9419fbd40a18d2c9ca1a4992854a6971ce937dde 100644 --- a/src/decentralizepy/mappings/Linear.py +++ b/src/decentralizepy/mappings/Linear.py @@ -59,3 +59,16 @@ class Linear(Mapping): """ return (uid % self.procs_per_machine), (uid // self.procs_per_machine) + + def get_local_procs_count(self): + """ + Gives number of processes that run on the node + + Returns + ------- + int + the number of local processes + + """ + + return self.procs_per_machine diff --git a/src/decentralizepy/mappings/Mapping.py b/src/decentralizepy/mappings/Mapping.py index 9f764ed2875abd2dbaaf01e769c1c4d88d7d858b..bb8523fe2919f1d79c4fa0834a0d231db4701ac0 100644 --- a/src/decentralizepy/mappings/Mapping.py +++ b/src/decentralizepy/mappings/Mapping.py @@ -67,3 +67,16 @@ class Mapping: """ raise NotImplementedError + + def get_local_procs_count(self): + """ + Gives number of processes that run on the node + + Returns + ------- + int + the number of local processes + + """ + + raise NotImplementedError diff --git a/src/decentralizepy/node/Node.py b/src/decentralizepy/node/Node.py index e612eca9c07cf5cd2f58882f1358fe9d809717a1..9f1b7dae9357a988edc1d0e526a461e7b3abcb42 100644 --- a/src/decentralizepy/node/Node.py +++ b/src/decentralizepy/node/Node.py @@ -1,6 +1,7 @@ import importlib import json import logging +import math import os import torch @@ -421,7 +422,9 @@ class Node: Other arguments """ - torch.set_num_threads(2) + total_threads = os.cpu_count() + threads_per_proc = max(math.floor(total_threads / mapping.procs_per_machine), 1) + torch.set_num_threads(threads_per_proc) torch.set_num_interop_threads(1) self.instantiate( rank, @@ -435,5 +438,8 @@ class Node: test_after, *args ) + logging.info( + "Each proc uses %d threads out of %d.", threads_per_proc, total_threads + ) self.run()