From 1070c082d2de2983cc8f1c496efcee015e448ced Mon Sep 17 00:00:00 2001 From: Rishi Sharma <rishi.sharma@epfl.ch> Date: Thu, 20 Oct 2022 22:19:28 +0200 Subject: [PATCH] Testing Peer Sampler Dynamic --- eval/testingPeerSamplerDynamic.py | 100 ++++++++++++++++++++++++++++++ 1 file changed, 100 insertions(+) create mode 100644 eval/testingPeerSamplerDynamic.py diff --git a/eval/testingPeerSamplerDynamic.py b/eval/testingPeerSamplerDynamic.py new file mode 100644 index 0000000..7045085 --- /dev/null +++ b/eval/testingPeerSamplerDynamic.py @@ -0,0 +1,100 @@ +import logging +from pathlib import Path +from shutil import copy + +from localconfig import LocalConfig +from torch import multiprocessing as mp + +from decentralizepy import utils +from decentralizepy.graphs.Graph import Graph +from decentralizepy.mappings.Linear import Linear +from decentralizepy.node.DPSGDWithPeerSampler import DPSGDWithPeerSampler +from decentralizepy.node.PeerSamplerDynamic import PeerSamplerDynamic + + +def read_ini(file_path): + config = LocalConfig(file_path) + for section in config: + print("Section: ", section) + for key, value in config.items(section): + print((key, value)) + print(dict(config.items("DATASET"))) + return config + + +if __name__ == "__main__": + args = utils.get_args() + + Path(args.log_dir).mkdir(parents=True, exist_ok=True) + + log_level = { + "INFO": logging.INFO, + "DEBUG": logging.DEBUG, + "WARNING": logging.WARNING, + "ERROR": logging.ERROR, + "CRITICAL": logging.CRITICAL, + } + + config = read_ini(args.config_file) + my_config = dict() + for section in config: + my_config[section] = dict(config.items(section)) + + copy(args.config_file, args.log_dir) + copy(args.graph_file, args.log_dir) + utils.write_args(args, args.log_dir) + + g = Graph() + g.read_graph_from_file(args.graph_file, args.graph_type) + n_machines = args.machines + procs_per_machine = args.procs_per_machine + l = Linear(n_machines, procs_per_machine) + m_id = args.machine_id + + sm = args.server_machine + sr = args.server_rank + + processes = [] + if sm == m_id: + processes.append( + mp.Process( + target=PeerSamplerDynamic, + args=[ + sr, + m_id, + l, + g, + my_config, + args.iterations, + args.log_dir, + log_level[args.log_level], + ], + ) + ) + + for r in range(0, procs_per_machine): + processes.append( + mp.Process( + target=DPSGDWithPeerSampler, + args=[ + r, + m_id, + l, + g, + my_config, + args.iterations, + args.log_dir, + args.weights_store_dir, + log_level[args.log_level], + args.test_after, + args.train_evaluate_after, + args.reset_optimizer, + ], + ) + ) + + for p in processes: + p.start() + + for p in processes: + p.join() -- GitLab