diff --git a/eval/testingPeerSampler.py b/eval/testingPeerSampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..ecc26365407227a086f71220d0e7a2661f0ba169
--- /dev/null
+++ b/eval/testingPeerSampler.py
@@ -0,0 +1,102 @@
+import logging
+from pathlib import Path
+from shutil import copy
+
+from localconfig import LocalConfig
+from torch import multiprocessing as mp
+
+from decentralizepy import utils
+from decentralizepy.graphs.Graph import Graph
+from decentralizepy.mappings.Linear import Linear
+from decentralizepy.node.DPSGDWithPeerSampler import DPSGDWithPeerSampler
+from decentralizepy.node.PeerSamplerDynamic import PeerSamplerDynamic
+
+
+def read_ini(file_path):
+    config = LocalConfig(file_path)
+    for section in config:
+        print("Section: ", section)
+        for key, value in config.items(section):
+            print((key, value))
+    print(dict(config.items("DATASET")))
+    return config
+
+
+if __name__ == "__main__":
+    args = utils.get_args()
+
+    Path(args.log_dir).mkdir(parents=True, exist_ok=True)
+
+    log_level = {
+        "INFO": logging.INFO,
+        "DEBUG": logging.DEBUG,
+        "WARNING": logging.WARNING,
+        "ERROR": logging.ERROR,
+        "CRITICAL": logging.CRITICAL,
+    }
+
+    config = read_ini(args.config_file)
+    my_config = dict()
+    for section in config:
+        my_config[section] = dict(config.items(section))
+
+    copy(args.config_file, args.log_dir)
+    copy(args.graph_file, args.log_dir)
+    utils.write_args(args, args.log_dir)
+
+    g = Graph()
+    g.read_graph_from_file(args.graph_file, args.graph_type)
+    n_machines = args.machines
+    procs_per_machine = args.procs_per_machine
+    l = Linear(n_machines, procs_per_machine)
+    m_id = args.machine_id
+
+    sm = args.server_machine
+    sr = args.server_rank
+
+    processes = []
+    if sm == m_id:
+        processes.append(
+            mp.Process(
+                target=PeerSamplerDynamic,
+                args=[
+                    sr,
+                    m_id,
+                    l,
+                    g,
+                    my_config,
+                    args.iterations,
+                    args.log_dir,
+                    log_level[args.log_level],
+                ],
+            )
+        )
+
+    for r in range(0, procs_per_machine):
+        processes.append(
+            mp.Process(
+                target=DPSGDWithPeerSampler,
+                args=[
+                    r,
+                    m_id,
+                    l,
+                    g,
+                    my_config,
+                    args.iterations,
+                    args.log_dir,
+                    args.weights_store_dir,
+                    log_level[args.log_level],
+                    args.test_after,
+                    args.train_evaluate_after,
+                    args.reset_optimizer,
+                    args.centralized_train_eval,
+                    args.centralized_test_eval,
+                ],
+            )
+        )
+
+    for p in processes:
+        p.start()
+
+    for p in processes:
+        p.join()
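The launcher starts one PeerSamplerDynamic process on the machine selected by --server_machine (with rank --server_rank) and procs_per_machine DPSGDWithPeerSampler processes on every machine; all of them address each other through the global uid that the Linear mapping derives from a (rank, machine_id) pair. A minimal sketch of that translation, assuming Linear numbers processes machine by machine (the layout in the comment is an assumption, not something this patch defines):

    from decentralizepy.mappings.Linear import Linear

    mapping = Linear(2, 4)  # 2 machines, 4 processes per machine
    # Assumed layout: uid = machine_id * procs_per_machine + rank,
    # so rank 2 on machine 1 would get uid 6.
    uid = mapping.get_uid(2, 1)
    machine_id, rank = mapping.get_machine_and_rank(uid)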
diff --git a/src/decentralizepy/communication/TCP.py b/src/decentralizepy/communication/TCP.py
index 8c8587b5292092f989eedba7352d1bb2039f6fa4..16de517dc2c1c88dc951ed544bece2799fd426a9 100644
--- a/src/decentralizepy/communication/TCP.py
+++ b/src/decentralizepy/communication/TCP.py
@@ -160,6 +160,16 @@ class TCP(Communication):
         req.connect(self.addr(*self.mapping.get_machine_and_rank(neighbor)))
         self.peer_sockets[id] = req
 
+    def destroy_connection(self, neighbor, linger=None):
+        id = str(neighbor).encode()
+        if self.already_connected(neighbor):
+            self.peer_sockets[id].close(linger=linger)
+            del self.peer_sockets[id]
+
+    def already_connected(self, neighbor):
+        id = str(neighbor).encode()
+        return id in self.peer_sockets
+
     def receive(self):
         """
         Returns ONE message received.
@@ -177,7 +187,8 @@ class TCP(Communication):
 
         """
         sender, recv = self.router.recv_multipart()
-        return self.decrypt(sender, recv)
+        s, r = self.decrypt(sender, recv)
+        return s, r
 
     def send(self, uid, data, encrypt=True):
         """
diff --git a/src/decentralizepy/node/DPSGDNode.py b/src/decentralizepy/node/DPSGDNode.py
index 964f10361baefb9aa7a6b90dd40631496e19dd64..0c8f0434d74249ff704d833a869907263b4c2c09 100644
--- a/src/decentralizepy/node/DPSGDNode.py
+++ b/src/decentralizepy/node/DPSGDNode.py
@@ -49,6 +49,17 @@ class DPSGDNode(Node):
         plt.title(title)
         plt.savefig(filename)
 
+    def get_neighbors(self, node=None):
+        return self.my_neighbors
+
+    # def instantiate_peer_deques(self):
+    #     for neighbor in self.my_neighbors:
+    #         if neighbor not in self.peer_deques:
+    #             self.peer_deques[neighbor] = deque()
+
+    def receive_DPSGD(self):
+        return self.receive_channel("DPSGD")
+
     def run(self):
         """
         Start the decentralized learning
@@ -90,28 +101,47 @@ class DPSGDNode(Node):
         for iteration in range(self.iterations):
             logging.info("Starting training iteration: %d", iteration)
+            self.iteration = iteration
             self.trainer.train(self.dataset)
+
+            new_neighbors = self.get_neighbors()
+
+            # for neighbor in self.my_neighbors:
+            #     if neighbor not in new_neighbors:
+            #         logging.info("Removing neighbor {}".format(neighbor))
+            #         if neighbor in self.peer_deques:
+            #             assert len(self.peer_deques[neighbor]) == 0
+            #             del self.peer_deques[neighbor]
+            #         self.communication.destroy_connection(neighbor, linger = 10000)
+            #         self.barrier.remove(neighbor)
+
+            self.my_neighbors = new_neighbors
+            self.connect_neighbors()
+            logging.info("Connected to all neighbors")
+            # self.instantiate_peer_deques()
 
             to_send = self.sharing.get_data_to_send()
+            to_send["CHANNEL"] = "DPSGD"
 
             for neighbor in self.my_neighbors:
                 self.communication.send(neighbor, to_send)
 
             while not self.received_from_all():
-                sender, data = self.receive()
-
-                if "HELLO" in data:
-                    logging.critical(
-                        "Received unexpected {} from {}".format("HELLO", sender)
+                sender, data = self.receive_DPSGD()
+                logging.info(
+                    "Received Model from {} of iteration {}".format(
+                        sender, data["iteration"]
                     )
-                    raise RuntimeError("A neighbour wants to connect during training!")
-                elif "BYE" in data:
-                    logging.debug("Received {} from {}".format("BYE", sender))
-                    self.barrier.remove(sender)
-                else:
-                    logging.debug("Received message from {}".format(sender))
-                    self.peer_deques[sender].append(data)
+                )
+                if sender not in self.peer_deques:
+                    self.peer_deques[sender] = deque()
+                self.peer_deques[sender].append(data)
 
-            self.sharing._averaging(self.peer_deques)
+            averaging_deque = dict()
+            for neighbor in self.my_neighbors:
+                averaging_deque[neighbor] = self.peer_deques[neighbor]
+
+            self.sharing._averaging(averaging_deque)
 
             if self.reset_optimizer:
                 self.optimizer = self.optimizer_class(
@@ -385,16 +415,15 @@ class DPSGDNode(Node):
         self.init_trainer(config["TRAIN_PARAMS"])
         self.init_comm(config["COMMUNICATION"])
 
-        self.message_queue = deque()
+        self.message_queue = dict()
 
+        self.barrier = set()
         self.my_neighbors = self.graph.neighbors(self.uid)
 
         self.init_sharing(config["SHARING"])
         self.peer_deques = dict()
-        for n in self.my_neighbors:
-            self.peer_deques[n] = deque()
-
         self.connect_neighbors()
+        # self.instantiate_peer_deques()
 
     def received_from_all(self):
         """
@@ -407,7 +436,7 @@
 
         """
         for k in self.my_neighbors:
-            if len(self.peer_deques[k]) == 0:
+            if (k not in self.peer_deques) or len(self.peer_deques[k]) == 0:
                 return False
         return True
 
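DPSGDNode now tags every outgoing model with CHANNEL = "DPSGD" and reads replies through receive_DPSGD/receive_channel, so a message that arrives for a different purpose (a late BYE, a peer-sampler reply) is parked in a per-channel queue instead of aborting training. A self-contained sketch of that buffering idea, with illustrative names rather than the decentralizepy API:

    from collections import deque

    class ChannelInbox:
        def __init__(self, receive_fn):
            self.receive_fn = receive_fn  # returns (sender, message_dict)
            self.queues = {}              # channel name -> deque of (sender, message)

        def receive(self, channel):
            q = self.queues.setdefault(channel, deque())
            if q:
                return q.popleft()
            while True:
                sender, msg = self.receive_fn()
                if msg["CHANNEL"] == channel:
                    return sender, msg
                # Park messages addressed to other channels until they are requested.
                self.queues.setdefault(msg["CHANNEL"], deque()).append((sender, msg))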
diff --git a/src/decentralizepy/node/DPSGDWithPeerSampler.py b/src/decentralizepy/node/DPSGDWithPeerSampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..2508675a5ae0815eadb1bbf32b7e7cf33eeece50
--- /dev/null
+++ b/src/decentralizepy/node/DPSGDWithPeerSampler.py
@@ -0,0 +1,168 @@
+import logging
+import math
+import os
+from collections import deque
+
+import torch
+
+from decentralizepy.graphs.Graph import Graph
+from decentralizepy.mappings.Mapping import Mapping
+from decentralizepy.node.DPSGDNode import DPSGDNode
+
+
+class DPSGDWithPeerSampler(DPSGDNode):
+    """
+    This class defines the DPSGD node that requests its neighbors from a peer sampler
+
+    """
+
+    def receive_neighbors(self):
+        return self.receive_channel("PEERS")[1]["NEIGHBORS"]
+
+    def get_neighbors(self, node=None):
+        logging.info("Requesting neighbors from the peer sampler.")
+        self.communication.send(
+            self.peer_sampler_uid,
+            {
+                "REQUEST_NEIGHBORS": self.uid,
+                "iteration": self.iteration,
+                "CHANNEL": "SERVER_REQUEST",
+            },
+        )
+        my_neighbors = self.receive_neighbors()
+        logging.info("Neighbors this round: {}".format(my_neighbors))
+        return my_neighbors
+
+    def __init__(
+        self,
+        rank: int,
+        machine_id: int,
+        mapping: Mapping,
+        graph: Graph,
+        config,
+        iterations=1,
+        log_dir=".",
+        weights_store_dir=".",
+        log_level=logging.INFO,
+        test_after=5,
+        train_evaluate_after=1,
+        reset_optimizer=1,
+        centralized_train_eval=0,
+        centralized_test_eval=1,
+        peer_sampler_uid=-1,
+        *args
+    ):
+        """
+        Constructor
+
+        Parameters
+        ----------
+        rank : int
+            Rank of process local to the machine
+        machine_id : int
+            Machine ID on which the process is running
+        mapping : decentralizepy.mappings
+            The object containing the mapping rank <--> uid
+        graph : decentralizepy.graphs
+            The object containing the global graph
+        config : dict
+            A dictionary of configurations. Must contain the following:
+            [DATASET]
+                dataset_package
+                dataset_class
+                model_class
+            [OPTIMIZER_PARAMS]
+                optimizer_package
+                optimizer_class
+            [TRAIN_PARAMS]
+                training_package = decentralizepy.training.Training
+                training_class = Training
+                epochs_per_round = 25
+                batch_size = 64
+        iterations : int
+            Number of iterations (communication steps) for which the model should be trained
+        log_dir : str
+            Logging directory
+        weights_store_dir : str
+            Directory in which to store model weights
+        log_level : logging.Level
+            One of DEBUG, INFO, WARNING, ERROR, CRITICAL
+        test_after : int
+            Number of iterations after which the test loss and accuracy are calculated
+        train_evaluate_after : int
+            Number of iterations after which the train loss is calculated
+        reset_optimizer : int
+            1 if optimizer should be reset every communication round, else 0
+        centralized_train_eval : int
+            If set then the train set evaluation happens at the node with uid 0.
+            Note: If it is True then centralized_test_eval needs to be true as well!
+        centralized_test_eval : int
+            If set then the test set evaluation happens at the node with uid 0
+        args : optional
+            Other arguments
+
+        """
+        centralized_train_eval = centralized_train_eval == 1
+        centralized_test_eval = centralized_test_eval == 1
+        # If centralized_train_eval is True then centralized_test_eval needs to be true as well!
+        assert not centralized_train_eval or centralized_test_eval
+
+        total_threads = os.cpu_count()
+        self.threads_per_proc = max(
+            math.floor(total_threads / mapping.procs_per_machine), 1
+        )
+        torch.set_num_threads(self.threads_per_proc)
+        torch.set_num_interop_threads(1)
+        self.instantiate(
+            rank,
+            machine_id,
+            mapping,
+            graph,
+            config,
+            iterations,
+            log_dir,
+            weights_store_dir,
+            log_level,
+            test_after,
+            train_evaluate_after,
+            reset_optimizer,
+            centralized_train_eval == 1,
+            centralized_test_eval == 1,
+            *args
+        )
+        logging.info(
+            "Each proc uses %d threads out of %d.", self.threads_per_proc, total_threads
+        )
+
+        self.message_queue["PEERS"] = deque()
+
+        self.peer_sampler_uid = peer_sampler_uid
+        self.connect_neighbor(self.peer_sampler_uid)
+        self.wait_for_hello(self.peer_sampler_uid)
+
+        self.run()
+
+    def disconnect_neighbors(self):
+        """
+        Disconnects all neighbors.
+
+        Raises
+        ------
+        RuntimeError
+            If received another message while waiting for BYEs
+
+        """
+        if not self.sent_disconnections:
+            logging.info("Disconnecting neighbors")
+            for uid in self.my_neighbors:
+                self.communication.send(uid, {"BYE": self.uid, "CHANNEL": "DISCONNECT"})
+            self.communication.send(
+                self.peer_sampler_uid, {"BYE": self.uid, "CHANNEL": "SERVER_REQUEST"}
+            )
+            self.sent_disconnections = True
+
+            self.barrier.remove(self.peer_sampler_uid)
+
+            while len(self.barrier):
+                sender, _ = self.receive_disconnect()
+                self.barrier.remove(sender)
diff --git a/src/decentralizepy/node/Node.py b/src/decentralizepy/node/Node.py
index 67ee659bd2ef42cbe6deea5b63496c96d87cd3ac..305064ec4953697248451d9ed2cf36d5fb62f580 100644
--- a/src/decentralizepy/node/Node.py
+++ b/src/decentralizepy/node/Node.py
@@ -24,7 +24,36 @@ class Node:
         """
         logging.info("Sending connection request to {}".format(neighbor))
         self.communication.init_connection(neighbor)
-        self.communication.send(neighbor, {"HELLO": self.uid})
+        self.communication.send(neighbor, {"HELLO": self.uid, "CHANNEL": "CONNECT"})
+
+    def receive_channel(self, channel):
+        if channel not in self.message_queue:
+            self.message_queue[channel] = deque()
+
+        if len(self.message_queue[channel]) > 0:
+            return self.message_queue[channel].popleft()
+        else:
+            sender, recv = self.communication.receive()
+            logging.info(
+                "Received some message from {} with CHANNEL: {}".format(
+                    sender, recv["CHANNEL"]
+                )
+            )
+            assert "CHANNEL" in recv
+            while recv["CHANNEL"] != channel:
+                if recv["CHANNEL"] not in self.message_queue:
+                    self.message_queue[recv["CHANNEL"]] = deque()
+                self.message_queue[recv["CHANNEL"]].append((sender, recv))
+                sender, recv = self.communication.receive()
+                logging.info(
+                    "Received some message from {} with CHANNEL: {}".format(
+                        sender, recv["CHANNEL"]
+                    )
+                )
+            return (sender, recv)
+
+    def receive_hello(self):
+        return self.receive_channel("CONNECT")
 
     def wait_for_hello(self, neighbor):
         """
@@ -37,30 +66,11 @@
             If received BYE while waiting for HELLO
 
         """
-        while neighbor not in self.barrier:
-            sender, recv = self.communication.receive()
-
-            if "HELLO" in recv:
-                logging.debug("Received {} from {}".format("HELLO", sender))
-                self.barrier.add(sender)
-            elif "BYE" in recv:
-                logging.debug("Received {} from {}".format("BYE", sender))
-                raise RuntimeError(
-                    "A neighbour wants to disconnect before training started!"
-                )
-            else:
-                logging.debug(
-                    "Received message from {} @ connect_neighbors".format(sender)
-                )
-                self.message_queue.append((sender, recv))
-
-    def receive(self):
-        if len(self.message_queue) > 0:
-            resp = self.message_queue.popleft()
-        else:
-            resp = self.communication.receive()
-        return resp
+        logging.info("Waiting HELLO from {}".format(neighbor))
+        sender, _ = self.receive_hello()
+        logging.info("Received HELLO from {}".format(sender))
+        self.barrier.add(sender)
 
     def connect_neighbors(self):
         """
@@ -74,12 +84,18 @@
 
         """
         logging.info("Sending connection request to all neighbors")
+        wait_acknowledgements = []
         for neighbor in self.my_neighbors:
-            self.connect_neighbor(neighbor)
+            if not self.communication.already_connected(neighbor):
+                self.connect_neighbor(neighbor)
+                wait_acknowledgements.append(neighbor)
 
-        for neighbor in self.my_neighbors:
+        for neighbor in wait_acknowledgements:
             self.wait_for_hello(neighbor)
 
+    def receive_disconnect(self):
+        return self.receive_channel("DISCONNECT")
+
     def disconnect_neighbors(self):
         """
         Disconnects all neighbors.
@@ -93,20 +109,11 @@
         if not self.sent_disconnections:
             logging.info("Disconnecting neighbors")
             for uid in self.my_neighbors:
-                self.communication.send(uid, {"BYE": self.uid})
+                self.communication.send(uid, {"BYE": self.uid, "CHANNEL": "DISCONNECT"})
             self.sent_disconnections = True
             while len(self.barrier):
-                sender, recv = self.receive()
-                if "BYE" in recv:
-                    logging.debug("Received {} from {}".format("BYE", sender))
-                    self.barrier.remove(sender)
-                else:
-                    logging.critical(
-                        "Received unexpected {} from {}".format(recv, sender)
-                    )
-                    raise RuntimeError(
-                        "Received a message when expecting BYE from {}".format(sender)
-                    )
+                sender, _ = self.receive_disconnect()
+                self.barrier.remove(sender)
 
     def init_log(self, log_dir, rank, log_level, force=True):
         """
@@ -364,7 +371,8 @@ class Node:
         self.init_trainer(config["TRAIN_PARAMS"])
         self.init_comm(config["COMMUNICATION"])
 
-        self.message_queue = deque()
+        self.message_queue = dict()
 
+        self.barrier = set()
         self.my_neighbors = self.graph.neighbors(self.uid)
 
diff --git a/src/decentralizepy/node/PeerSampler.py b/src/decentralizepy/node/PeerSampler.py
index 8f8db6fd225119e77840e5620650e6ee355e164c..6c76156e2d58251f1b17388a0dc3f4ed02d8da6a 100644
--- a/src/decentralizepy/node/PeerSampler.py
+++ b/src/decentralizepy/node/PeerSampler.py
@@ -1,5 +1,6 @@
 import importlib
 import logging
+import os
 from collections import deque
 
 from decentralizepy import utils
@@ -14,6 +15,30 @@ class PeerSampler(Node):
 
     """
 
+    def init_log(self, log_dir, log_level, force=True):
+        """
+        Instantiate Logging.
+
+        Parameters
+        ----------
+        log_dir : str
+            Logging directory
+        rank : int
+            Rank of process local to the machine
+        log_level : logging.Level
+            One of DEBUG, INFO, WARNING, ERROR, CRITICAL
+        force : bool
+            Argument to logging.basicConfig()
+
+        """
+        log_file = os.path.join(log_dir, "PeerSampler.log")
+        logging.basicConfig(
+            filename=log_file,
+            format="[%(asctime)s][%(module)s][%(levelname)s] %(message)s",
+            level=log_level,
+            force=force,
+        )
+
     def cache_fields(
         self,
         rank,
@@ -123,12 +148,19 @@
             log_dir,
         )
 
-        self.message_queue = deque()
+        self.message_queue = dict()
 
+        self.barrier = set()
         self.init_comm(config["COMMUNICATION"])
 
         self.my_neighbors = self.graph.get_all_nodes()
-        self.connect_neighbours()
+        self.connect_neighbors()
+
+    def get_neighbors(self, node, iteration=None):
+        return self.graph.neighbors(node)
+
+    def receive_server_request(self):
+        return self.receive_channel("SERVER_REQUEST")
 
     def run(self):
         """
@@ -136,13 +168,20 @@
 
         """
         while len(self.barrier) > 0:
-            sender, data = self.receive()
+            sender, data = self.receive_server_request()
             if "BYE" in data:
                 logging.debug("Received {} from {}".format("BYE", sender))
                 self.barrier.remove(sender)
-            else:
+
+            elif "REQUEST_NEIGHBORS" in data:
                 logging.debug("Received {} from {}".format("Request", sender))
-                resp = {"neighbors": self.get_neighbors(sender)}
+                if "iteration" in data:
+                    resp = {
+                        "NEIGHBORS": self.get_neighbors(sender, data["iteration"]),
+                        "CHANNEL": "PEERS",
+                    }
+                else:
+                    resp = {"NEIGHBORS": self.get_neighbors(sender), "CHANNEL": "PEERS"}
                 self.communication.send(sender, resp)
 
     def __init__(
diff --git a/src/decentralizepy/node/PeerSamplerDynamic.py b/src/decentralizepy/node/PeerSamplerDynamic.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c9dc83e363d33e14a746d505e54051660fef752
--- /dev/null
+++ b/src/decentralizepy/node/PeerSamplerDynamic.py
@@ -0,0 +1,98 @@
+import logging
+from collections import deque
+
+from decentralizepy.graphs.Graph import Graph
+from decentralizepy.graphs.Regular import Regular
+from decentralizepy.mappings.Mapping import Mapping
+from decentralizepy.node.PeerSampler import PeerSampler
+
+
+class PeerSamplerDynamic(PeerSampler):
+    """
+    This class defines the dynamic peer sampling service
+
+    """
+
+    def get_neighbors(self, node, iteration=None):
+        if iteration is not None:
+            if iteration > self.iteration:
+                logging.info(
+                    "iteration, self.iteration: {}, {}".format(
+                        iteration, self.iteration
+                    )
+                )
+                assert iteration == self.iteration + 1
+                self.iteration = iteration
+                self.graphs.append(Regular(self.graph.n_procs, self.graph_degree))
+            return self.graphs[iteration].neighbors(node)
+        else:
+            return self.graph.neighbors(node)
+
+    def __init__(
+        self,
+        rank: int,
+        machine_id: int,
+        mapping: Mapping,
+        graph: Graph,
+        config,
+        iterations=1,
+        log_dir=".",
+        log_level=logging.INFO,
+        graph_degree=4,
+        *args
+    ):
+        """
+        Constructor
+
+        Parameters
+        ----------
+        rank : int
+            Rank of process local to the machine
+        machine_id : int
+            Machine ID on which the process is running
+        mapping : decentralizepy.mappings
+            The object containing the mapping rank <--> uid
+        graph : decentralizepy.graphs
+            The object containing the global graph
+        config : dict
+            A dictionary of configurations. Must contain the following:
+            [DATASET]
+                dataset_package
+                dataset_class
+                model_class
+            [OPTIMIZER_PARAMS]
+                optimizer_package
+                optimizer_class
+            [TRAIN_PARAMS]
+                training_package = decentralizepy.training.Training
+                training_class = Training
+                epochs_per_round = 25
+                batch_size = 64
+        iterations : int
+            Number of iterations (communication steps) for which the model should be trained
+        log_dir : str
+            Logging directory
+        log_level : logging.Level
+            One of DEBUG, INFO, WARNING, ERROR, CRITICAL
+        args : optional
+            Other arguments
+
+        """
+
+        self.iteration = -1
+        self.graphs = []
+        self.graph_degree = graph_degree
+
+        self.instantiate(
+            rank,
+            machine_id,
+            mapping,
+            graph,
+            config,
+            iterations,
+            log_dir,
+            log_level,
+            *args
+        )
+
+        self.run()
diff --git a/src/decentralizepy/sharing/FFT.py b/src/decentralizepy/sharing/FFT.py
index 60c60ec7e3232ef30231a403c9fb6b5ac16c9b39..4babfc285f771fb681daeafe3a4c11dbdddbde3e 100644
--- a/src/decentralizepy/sharing/FFT.py
+++ b/src/decentralizepy/sharing/FFT.py
@@ -262,6 +262,7 @@ class FFT(PartialModel):
                 degree, iteration = data["degree"], data["iteration"]
                 del data["degree"]
                 del data["iteration"]
+                del data["CHANNEL"]
                 logging.debug(
                     "Averaging model from neighbor {} of iteration {}".format(
                         n, iteration
diff --git a/src/decentralizepy/sharing/LowerBoundTopK.py b/src/decentralizepy/sharing/LowerBoundTopK.py
index 86b9c3bd82239c4782fc20de9a5bc2c61353b311..9d227b22ba106bf1d3d08b2ed40492f083fcf9c7 100644
--- a/src/decentralizepy/sharing/LowerBoundTopK.py
+++ b/src/decentralizepy/sharing/LowerBoundTopK.py
@@ -199,6 +199,7 @@ class LowerBoundTopK(PartialModel):
                 degree, iteration = data["degree"], data["iteration"]
                 del data["degree"]
                 del data["iteration"]
+                del data["CHANNEL"]
                 logging.debug(
                     "Averaging model from neighbor {} of iteration {}".format(
                         n, iteration
diff --git a/src/decentralizepy/sharing/Sharing.py b/src/decentralizepy/sharing/Sharing.py
index 7ef18cc33c560cd1b10ed8c35c7f5712245c149c..a41de33203d72cafd4aecca10a8b14b347625d72 100644
--- a/src/decentralizepy/sharing/Sharing.py
+++ b/src/decentralizepy/sharing/Sharing.py
@@ -164,6 +164,7 @@ class Sharing:
                 degree, iteration = data["degree"], data["iteration"]
                 del data["degree"]
                 del data["iteration"]
+                del data["CHANNEL"]
                 logging.debug(
                     "Averaging model from neighbor {} of iteration {}".format(
                         n, iteration
diff --git a/src/decentralizepy/sharing/SharingCentrality.py b/src/decentralizepy/sharing/SharingCentrality.py
index 05986acfe32c33242c1c59c73cfba674272a28c0..8b10f3cc231af4fea49ca2b8a3dbfc0704783ccf 100644
--- a/src/decentralizepy/sharing/SharingCentrality.py
+++ b/src/decentralizepy/sharing/SharingCentrality.py
@@ -189,6 +189,7 @@ class Sharing:
             iteration = data["iteration"]
             del data["degree"]
            del data["iteration"]
+            del data["CHANNEL"]
             self.peer_deques[sender].append((degree, iteration, data))
             logging.info(
                 "Deserialized received model from {} of iteration {}".format(
diff --git a/src/decentralizepy/sharing/Wavelet.py b/src/decentralizepy/sharing/Wavelet.py
index ffc255867f55a10bf0dc897baaa53078dc32f9eb..24db77afea84b7a8424303f2239121423f5f054c 100644
--- a/src/decentralizepy/sharing/Wavelet.py
+++ b/src/decentralizepy/sharing/Wavelet.py
@@ -279,6 +279,7 @@ class Wavelet(PartialModel):
                 degree, iteration = data["degree"], data["iteration"]
                 del data["degree"]
                 del data["iteration"]
+                del data["CHANNEL"]
                 logging.debug(
                     "Averaging model from neighbor {} of iteration {}".format(
                         n, iteration
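Every sharing variant above applies the same one-line change: the routing key "CHANNEL" is deleted alongside "degree" and "iteration" before the remaining dictionary entries are deserialized as model data. A minimal sketch of the same idea as a hypothetical helper (not part of the patch):

    def split_metadata(data, keys=("degree", "iteration", "CHANNEL")):
        # Strip known metadata keys; whatever remains is the serialized model payload.
        meta = {k: data.pop(k) for k in keys if k in data}
        return meta, data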
diff --git a/src/decentralizepy/utils.py b/src/decentralizepy/utils.py
index 3c8c8581a4d14144d00280f7c0369030b7688026..aad8ddaa19db72e7cdeb5293b0f45a63a708bae7 100644
--- a/src/decentralizepy/utils.py
+++ b/src/decentralizepy/utils.py
@@ -84,7 +84,9 @@ def get_args():
     parser.add_argument("-tea", "--train_evaluate_after", type=int, default=1)
     parser.add_argument("-ro", "--reset_optimizer", type=int, default=1)
     parser.add_argument("-ctr", "--centralized_train_eval", type=int, default=0)
-    parser.add_argument("-cte", "--centralized_test_eval", type=int, default=1)
+    parser.add_argument("-cte", "--centralized_test_eval", type=int, default=0)
+    parser.add_argument("-sm", "--server_machine", type=int, default=0)
+    parser.add_argument("-sr", "--server_rank", type=int, default=-1)
 
     args = parser.parse_args()
     return args
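Taken together, the new --server_machine/--server_rank arguments, the peer-sampler node classes, and the CHANNEL-tagged messages define a small request/reply protocol between each training node and the sampler. A hedged summary of the message shapes used in this patch (field names come from the diff; the variable names and example values are illustrative):

    node_uid, current_iteration = 3, 7

    # Training node -> peer sampler, sent on the "SERVER_REQUEST" channel.
    request = {
        "REQUEST_NEIGHBORS": node_uid,
        "iteration": current_iteration,
        "CHANNEL": "SERVER_REQUEST",
    }

    # Peer sampler -> training node, answered on the "PEERS" channel.
    reply = {"NEIGHBORS": [0, 5, 9], "CHANNEL": "PEERS"}

    # Shutdown: BYE to neighbors on "DISCONNECT", BYE to the sampler on "SERVER_REQUEST".
    bye_to_neighbor = {"BYE": node_uid, "CHANNEL": "DISCONNECT"}
    bye_to_sampler = {"BYE": node_uid, "CHANNEL": "SERVER_REQUEST"}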