From bee802722b15acebb2d594dc8fea29a26660bd5e Mon Sep 17 00:00:00 2001 From: Jeffrey Wigger <jeffrey.wigger@epfl.ch> Date: Tue, 8 Mar 2022 10:09:42 +0100 Subject: [PATCH] dynamically limiting the number of threads per proc --- src/decentralizepy/mappings/Linear.py | 13 +++++++++++++ src/decentralizepy/mappings/Mapping.py | 13 +++++++++++++ src/decentralizepy/node/Node.py | 8 +++++++- 3 files changed, 33 insertions(+), 1 deletion(-) diff --git a/src/decentralizepy/mappings/Linear.py b/src/decentralizepy/mappings/Linear.py index 57ef628..9419fbd 100644 --- a/src/decentralizepy/mappings/Linear.py +++ b/src/decentralizepy/mappings/Linear.py @@ -59,3 +59,16 @@ class Linear(Mapping): """ return (uid % self.procs_per_machine), (uid // self.procs_per_machine) + + def get_local_procs_count(self): + """ + Gives number of processes that run on the node + + Returns + ------- + int + the number of local processes + + """ + + return self.procs_per_machine diff --git a/src/decentralizepy/mappings/Mapping.py b/src/decentralizepy/mappings/Mapping.py index 9f764ed..bb8523f 100644 --- a/src/decentralizepy/mappings/Mapping.py +++ b/src/decentralizepy/mappings/Mapping.py @@ -67,3 +67,16 @@ class Mapping: """ raise NotImplementedError + + def get_local_procs_count(self): + """ + Gives number of processes that run on the node + + Returns + ------- + int + the number of local processes + + """ + + raise NotImplementedError diff --git a/src/decentralizepy/node/Node.py b/src/decentralizepy/node/Node.py index e612eca..9f1b7da 100644 --- a/src/decentralizepy/node/Node.py +++ b/src/decentralizepy/node/Node.py @@ -1,6 +1,7 @@ import importlib import json import logging +import math import os import torch @@ -421,7 +422,9 @@ class Node: Other arguments """ - torch.set_num_threads(2) + total_threads = os.cpu_count() + threads_per_proc = max(math.floor(total_threads / mapping.procs_per_machine), 1) + torch.set_num_threads(threads_per_proc) torch.set_num_interop_threads(1) self.instantiate( rank, @@ -435,5 +438,8 @@ class Node: test_after, *args ) + logging.info( + "Each proc uses %d threads out of %d.", threads_per_proc, total_threads + ) self.run() -- GitLab