Skip to content
Snippets Groups Projects
Commit d4c2b6c6 authored by Rishi Sharma's avatar Rishi Sharma
Browse files

Dynamic Peer sampler added

parent f1d5035f
No related branches found
No related tags found
No related merge requests found
Showing with 526 additions and 64 deletions
import logging
from pathlib import Path
from shutil import copy
from localconfig import LocalConfig
from torch import multiprocessing as mp
from decentralizepy import utils
from decentralizepy.graphs.Graph import Graph
from decentralizepy.mappings.Linear import Linear
from decentralizepy.node.DPSGDWithPeerSampler import DPSGDWithPeerSampler
from decentralizepy.node.PeerSamplerDynamic import PeerSamplerDynamic
def read_ini(file_path):
    """Parse an INI configuration file and echo its contents to stdout.

    Parameters
    ----------
    file_path : str
        Path to the INI configuration file.

    Returns
    -------
    LocalConfig
        The parsed configuration object.
    """
    cfg = LocalConfig(file_path)
    for section in cfg:
        print("Section: ", section)
        # cfg.items(section) yields (key, value) pairs; print each tuple.
        for item in cfg.items(section):
            print(item)
    # Echo the DATASET section as a dict for quick visual inspection.
    print(dict(cfg.items("DATASET")))
    return cfg
if __name__ == "__main__":
    # Parse CLI arguments and make sure the logging directory exists.
    args = utils.get_args()
    Path(args.log_dir).mkdir(parents=True, exist_ok=True)

    # Map the CLI log-level string onto the logging module's constants.
    log_levels = {
        "INFO": logging.INFO,
        "DEBUG": logging.DEBUG,
        "WARNING": logging.WARNING,
        "ERROR": logging.ERROR,
        "CRITICAL": logging.CRITICAL,
    }

    # Load the experiment configuration and flatten it into a plain dict
    # of {section: {key: value}} for the node constructors.
    config = read_ini(args.config_file)
    my_config = {section: dict(config.items(section)) for section in config}

    # Archive the config, the graph and the CLI arguments next to the logs
    # so the run is reproducible.
    copy(args.config_file, args.log_dir)
    copy(args.graph_file, args.log_dir)
    utils.write_args(args, args.log_dir)

    graph = Graph()
    graph.read_graph_from_file(args.graph_file, args.graph_type)
    mapping = Linear(args.machines, args.procs_per_machine)

    processes = []
    # The designated server machine additionally hosts the dynamic
    # peer-sampler service process.
    if args.server_machine == args.machine_id:
        processes.append(
            mp.Process(
                target=PeerSamplerDynamic,
                args=[
                    args.server_rank,
                    args.machine_id,
                    mapping,
                    graph,
                    my_config,
                    args.iterations,
                    args.log_dir,
                    log_levels[args.log_level],
                ],
            )
        )

    # One DPSGD worker process per local rank on this machine.
    for rank in range(args.procs_per_machine):
        processes.append(
            mp.Process(
                target=DPSGDWithPeerSampler,
                args=[
                    rank,
                    args.machine_id,
                    mapping,
                    graph,
                    my_config,
                    args.iterations,
                    args.log_dir,
                    args.weights_store_dir,
                    log_levels[args.log_level],
                    args.test_after,
                    args.train_evaluate_after,
                    args.reset_optimizer,
                    args.centralized_train_eval,
                    args.centralized_test_eval,
                ],
            )
        )

    for proc in processes:
        proc.start()
    for proc in processes:
        proc.join()
...@@ -160,6 +160,16 @@ class TCP(Communication): ...@@ -160,6 +160,16 @@ class TCP(Communication):
req.connect(self.addr(*self.mapping.get_machine_and_rank(neighbor))) req.connect(self.addr(*self.mapping.get_machine_and_rank(neighbor)))
self.peer_sockets[id] = req self.peer_sockets[id] = req
def destroy_connection(self, neighbor, linger=None):
id = str(neighbor).encode()
if self.already_connected(neighbor):
self.peer_sockets[id].close(linger=linger)
del self.peer_sockets[id]
def already_connected(self, neighbor):
id = str(neighbor).encode()
return id in self.peer_sockets
def receive(self): def receive(self):
""" """
Returns ONE message received. Returns ONE message received.
...@@ -177,7 +187,8 @@ class TCP(Communication): ...@@ -177,7 +187,8 @@ class TCP(Communication):
""" """
sender, recv = self.router.recv_multipart() sender, recv = self.router.recv_multipart()
return self.decrypt(sender, recv) s, r = self.decrypt(sender, recv)
return s, r
def send(self, uid, data, encrypt=True): def send(self, uid, data, encrypt=True):
""" """
......
...@@ -49,6 +49,17 @@ class DPSGDNode(Node): ...@@ -49,6 +49,17 @@ class DPSGDNode(Node):
plt.title(title) plt.title(title)
plt.savefig(filename) plt.savefig(filename)
def get_neighbors(self, node=None):
return self.my_neighbors
# def instantiate_peer_deques(self):
# for neighbor in self.my_neighbors:
# if neighbor not in self.peer_deques:
# self.peer_deques[neighbor] = deque()
def receive_DPSGD(self):
return self.receive_channel("DPSGD")
def run(self): def run(self):
""" """
Start the decentralized learning Start the decentralized learning
...@@ -90,28 +101,47 @@ class DPSGDNode(Node): ...@@ -90,28 +101,47 @@ class DPSGDNode(Node):
for iteration in range(self.iterations): for iteration in range(self.iterations):
logging.info("Starting training iteration: %d", iteration) logging.info("Starting training iteration: %d", iteration)
self.iteration = iteration
self.trainer.train(self.dataset) self.trainer.train(self.dataset)
new_neighbors = self.get_neighbors()
# for neighbor in self.my_neighbors:
# if neighbor not in new_neighbors:
# logging.info("Removing neighbor {}".format(neighbor))
# if neighbor in self.peer_deques:
# assert len(self.peer_deques[neighbor]) == 0
# del self.peer_deques[neighbor]
# self.communication.destroy_connection(neighbor, linger = 10000)
# self.barrier.remove(neighbor)
self.my_neighbors = new_neighbors
self.connect_neighbors()
logging.info("Connected to all neighbors")
# self.instantiate_peer_deques()
to_send = self.sharing.get_data_to_send() to_send = self.sharing.get_data_to_send()
to_send["CHANNEL"] = "DPSGD"
for neighbor in self.my_neighbors: for neighbor in self.my_neighbors:
self.communication.send(neighbor, to_send) self.communication.send(neighbor, to_send)
while not self.received_from_all(): while not self.received_from_all():
sender, data = self.receive() sender, data = self.receive_DPSGD()
logging.info(
if "HELLO" in data: "Received Model from {} of iteration {}".format(
logging.critical( sender, data["iteration"]
"Received unexpected {} from {}".format("HELLO", sender)
) )
raise RuntimeError("A neighbour wants to connect during training!") )
elif "BYE" in data: if sender not in self.peer_deques:
logging.debug("Received {} from {}".format("BYE", sender)) self.peer_deques[sender] = deque()
self.barrier.remove(sender) self.peer_deques[sender].append(data)
else:
logging.debug("Received message from {}".format(sender))
self.peer_deques[sender].append(data)
self.sharing._averaging(self.peer_deques) averaging_deque = dict()
for neighbor in self.my_neighbors:
averaging_deque[neighbor] = self.peer_deques[neighbor]
self.sharing._averaging(averaging_deque)
if self.reset_optimizer: if self.reset_optimizer:
self.optimizer = self.optimizer_class( self.optimizer = self.optimizer_class(
...@@ -385,16 +415,15 @@ class DPSGDNode(Node): ...@@ -385,16 +415,15 @@ class DPSGDNode(Node):
self.init_trainer(config["TRAIN_PARAMS"]) self.init_trainer(config["TRAIN_PARAMS"])
self.init_comm(config["COMMUNICATION"]) self.init_comm(config["COMMUNICATION"])
self.message_queue = deque() self.message_queue = dict()
self.barrier = set() self.barrier = set()
self.my_neighbors = self.graph.neighbors(self.uid) self.my_neighbors = self.graph.neighbors(self.uid)
self.init_sharing(config["SHARING"]) self.init_sharing(config["SHARING"])
self.peer_deques = dict() self.peer_deques = dict()
for n in self.my_neighbors:
self.peer_deques[n] = deque()
self.connect_neighbors() self.connect_neighbors()
# self.instantiate_peer_deques()
def received_from_all(self): def received_from_all(self):
""" """
...@@ -407,7 +436,7 @@ class DPSGDNode(Node): ...@@ -407,7 +436,7 @@ class DPSGDNode(Node):
""" """
for k in self.my_neighbors: for k in self.my_neighbors:
if len(self.peer_deques[k]) == 0: if (k not in self.peer_deques) or len(self.peer_deques[k]) == 0:
return False return False
return True return True
......
import logging
import math
import os
from collections import deque
import torch
from decentralizepy.graphs.Graph import Graph
from decentralizepy.mappings.Mapping import Mapping
from decentralizepy.node.DPSGDNode import DPSGDNode
class DPSGDWithPeerSampler(DPSGDNode):
    """
    DPSGD node that obtains its neighbors each iteration from a
    (possibly dynamic) peer-sampler service instead of a static graph.

    """

    def receive_neighbors(self):
        """Block until the peer sampler replies on the PEERS channel.

        Returns
        -------
        The neighbor set sent by the peer sampler.
        """
        return self.receive_channel("PEERS")[1]["NEIGHBORS"]

    def get_neighbors(self, node=None):
        """Request this node's neighbors for the current iteration.

        Sends a SERVER_REQUEST message tagged with ``self.iteration`` to the
        peer sampler and waits for its PEERS reply.

        Parameters
        ----------
        node : int, optional
            Unused; kept for interface compatibility with DPSGDNode.

        Returns
        -------
        The neighbor uids for this round.
        """
        logging.info("Requesting neighbors from the peer sampler.")
        self.communication.send(
            self.peer_sampler_uid,
            {
                "REQUEST_NEIGHBORS": self.uid,
                "iteration": self.iteration,
                "CHANNEL": "SERVER_REQUEST",
            },
        )
        my_neighbors = self.receive_neighbors()
        logging.info("Neighbors this round: {}".format(my_neighbors))
        return my_neighbors

    def __init__(
        self,
        rank: int,
        machine_id: int,
        mapping: Mapping,
        graph: Graph,
        config,
        iterations=1,
        log_dir=".",
        weights_store_dir=".",
        log_level=logging.INFO,
        test_after=5,
        train_evaluate_after=1,
        reset_optimizer=1,
        centralized_train_eval=0,
        centralized_test_eval=1,
        peer_sampler_uid=-1,
        *args
    ):
        """
        Constructor

        Parameters
        ----------
        rank : int
            Rank of process local to the machine
        machine_id : int
            Machine ID on which the process in running
        mapping : decentralizepy.mappings
            The object containing the mapping rank <--> uid
        graph : decentralizepy.graphs
            The object containing the global graph
        config : dict
            A dictionary of configurations. Must contain the following:
            [DATASET]
                dataset_package
                dataset_class
                model_class
            [OPTIMIZER_PARAMS]
                optimizer_package
                optimizer_class
            [TRAIN_PARAMS]
                training_package = decentralizepy.training.Training
                training_class = Training
                epochs_per_round = 25
                batch_size = 64
        iterations : int
            Number of iterations (communication steps) for which the model should be trained
        log_dir : str
            Logging directory
        weights_store_dir : str
            Directory in which to store model weights
        log_level : logging.Level
            One of DEBUG, INFO, WARNING, ERROR, CRITICAL
        test_after : int
            Number of iterations after which the test loss and accuracy are calculated
        train_evaluate_after : int
            Number of iterations after which the train loss is calculated
        reset_optimizer : int
            1 if optimizer should be reset every communication round, else 0
        centralized_train_eval : int
            If set then the train set evaluation happens at the node with uid 0.
            Note: If it is True then centralized_test_eval needs to be true as well!
        centralized_test_eval : int
            If set then the trainset evaluation happens at the node with uid 0
        peer_sampler_uid : int
            uid of the peer-sampler service process to request neighbors from
        args : optional
            Other arguments

        """
        # Normalize the 0/1 flags to booleans once, up front.
        centralized_train_eval = centralized_train_eval == 1
        centralized_test_eval = centralized_test_eval == 1
        # If centralized_train_eval is True then centralized_test_eval needs to be true as well!
        assert not centralized_train_eval or centralized_test_eval

        # Split the machine's cores evenly among the local processes so the
        # torch workers do not oversubscribe the CPU.
        total_threads = os.cpu_count()
        self.threads_per_proc = max(
            math.floor(total_threads / mapping.procs_per_machine), 1
        )
        torch.set_num_threads(self.threads_per_proc)
        torch.set_num_interop_threads(1)
        self.instantiate(
            rank,
            machine_id,
            mapping,
            graph,
            config,
            iterations,
            log_dir,
            weights_store_dir,
            log_level,
            test_after,
            train_evaluate_after,
            reset_optimizer,
            # Already booleans; the previous re-comparison to 1 was redundant.
            centralized_train_eval,
            centralized_test_eval,
            *args
        )
        logging.info(
            "Each proc uses %d threads out of %d.", self.threads_per_proc, total_threads
        )

        # Pre-create the PEERS channel queue so neighbor replies are buffered
        # even if they arrive interleaved with other traffic.
        self.message_queue["PEERS"] = deque()

        self.peer_sampler_uid = peer_sampler_uid
        self.connect_neighbor(self.peer_sampler_uid)
        self.wait_for_hello(self.peer_sampler_uid)

        self.run()

    def disconnect_neighbors(self):
        """
        Disconnects all neighbors and the peer-sampler service.

        Raises
        ------
        RuntimeError
            If received another message while waiting for BYEs

        """
        if not self.sent_disconnections:
            logging.info("Disconnecting neighbors")
            for uid in self.my_neighbors:
                self.communication.send(uid, {"BYE": self.uid, "CHANNEL": "DISCONNECT"})
            # The peer sampler listens on its request channel, not DISCONNECT.
            self.communication.send(
                self.peer_sampler_uid, {"BYE": self.uid, "CHANNEL": "SERVER_REQUEST"}
            )
            self.sent_disconnections = True

            # The peer sampler does not send a BYE back; drop it from the
            # barrier before draining the neighbors' BYEs.
            self.barrier.remove(self.peer_sampler_uid)

            while len(self.barrier):
                sender, _ = self.receive_disconnect()
                self.barrier.remove(sender)
...@@ -24,7 +24,36 @@ class Node: ...@@ -24,7 +24,36 @@ class Node:
""" """
logging.info("Sending connection request to {}".format(neighbor)) logging.info("Sending connection request to {}".format(neighbor))
self.communication.init_connection(neighbor) self.communication.init_connection(neighbor)
self.communication.send(neighbor, {"HELLO": self.uid}) self.communication.send(neighbor, {"HELLO": self.uid, "CHANNEL": "CONNECT"})
def receive_channel(self, channel):
if channel not in self.message_queue:
self.message_queue[channel] = deque()
if len(self.message_queue[channel]) > 0:
return self.message_queue[channel].popleft()
else:
sender, recv = self.communication.receive()
logging.info(
"Received some message from {} with CHANNEL: {}".format(
sender, recv["CHANNEL"]
)
)
assert "CHANNEL" in recv
while recv["CHANNEL"] != channel:
if recv["CHANNEL"] not in self.message_queue:
self.message_queue[recv["CHANNEL"]] = deque()
self.message_queue[recv["CHANNEL"]].append((sender, recv))
sender, recv = self.communication.receive()
logging.info(
"Received some message from {} with CHANNEL: {}".format(
sender, recv["CHANNEL"]
)
)
return (sender, recv)
def receive_hello(self):
return self.receive_channel("CONNECT")
def wait_for_hello(self, neighbor): def wait_for_hello(self, neighbor):
""" """
...@@ -37,30 +66,11 @@ class Node: ...@@ -37,30 +66,11 @@ class Node:
If received BYE while waiting for HELLO If received BYE while waiting for HELLO
""" """
while neighbor not in self.barrier: while neighbor not in self.barrier:
sender, recv = self.communication.receive() logging.info("Waiting HELLO from {}".format(neighbor))
sender, _ = self.receive_hello()
if "HELLO" in recv: logging.info("Received HELLO from {}".format(sender))
logging.debug("Received {} from {}".format("HELLO", sender)) self.barrier.add(sender)
self.barrier.add(sender)
elif "BYE" in recv:
logging.debug("Received {} from {}".format("BYE", sender))
raise RuntimeError(
"A neighbour wants to disconnect before training started!"
)
else:
logging.debug(
"Received message from {} @ connect_neighbors".format(sender)
)
self.message_queue.append((sender, recv))
def receive(self):
if len(self.message_queue) > 0:
resp = self.message_queue.popleft()
else:
resp = self.communication.receive()
return resp
def connect_neighbors(self): def connect_neighbors(self):
""" """
...@@ -74,12 +84,18 @@ class Node: ...@@ -74,12 +84,18 @@ class Node:
""" """
logging.info("Sending connection request to all neighbors") logging.info("Sending connection request to all neighbors")
wait_acknowledgements = []
for neighbor in self.my_neighbors: for neighbor in self.my_neighbors:
self.connect_neighbor(neighbor) if not self.communication.already_connected(neighbor):
self.connect_neighbor(neighbor)
wait_acknowledgements.append(neighbor)
for neighbor in self.my_neighbors: for neighbor in wait_acknowledgements:
self.wait_for_hello(neighbor) self.wait_for_hello(neighbor)
def receive_disconnect(self):
return self.receive_channel("DISCONNECT")
def disconnect_neighbors(self): def disconnect_neighbors(self):
""" """
Disconnects all neighbors. Disconnects all neighbors.
...@@ -93,20 +109,11 @@ class Node: ...@@ -93,20 +109,11 @@ class Node:
if not self.sent_disconnections: if not self.sent_disconnections:
logging.info("Disconnecting neighbors") logging.info("Disconnecting neighbors")
for uid in self.my_neighbors: for uid in self.my_neighbors:
self.communication.send(uid, {"BYE": self.uid}) self.communication.send(uid, {"BYE": self.uid, "CHANNEL": "DISCONNECT"})
self.sent_disconnections = True self.sent_disconnections = True
while len(self.barrier): while len(self.barrier):
sender, recv = self.receive() sender, _ = self.receive_disconnect()
if "BYE" in recv: self.barrier.remove(sender)
logging.debug("Received {} from {}".format("BYE", sender))
self.barrier.remove(sender)
else:
logging.critical(
"Received unexpected {} from {}".format(recv, sender)
)
raise RuntimeError(
"Received a message when expecting BYE from {}".format(sender)
)
def init_log(self, log_dir, rank, log_level, force=True): def init_log(self, log_dir, rank, log_level, force=True):
""" """
...@@ -364,7 +371,8 @@ class Node: ...@@ -364,7 +371,8 @@ class Node:
self.init_trainer(config["TRAIN_PARAMS"]) self.init_trainer(config["TRAIN_PARAMS"])
self.init_comm(config["COMMUNICATION"]) self.init_comm(config["COMMUNICATION"])
self.message_queue = deque() self.message_queue = dict()
self.barrier = set() self.barrier = set()
self.my_neighbors = self.graph.neighbors(self.uid) self.my_neighbors = self.graph.neighbors(self.uid)
......
import importlib import importlib
import logging import logging
import os
from collections import deque from collections import deque
from decentralizepy import utils from decentralizepy import utils
...@@ -14,6 +15,30 @@ class PeerSampler(Node): ...@@ -14,6 +15,30 @@ class PeerSampler(Node):
""" """
def init_log(self, log_dir, log_level, force=True):
"""
Instantiate Logging.
Parameters
----------
log_dir : str
Logging directory
rank : rank : int
Rank of process local to the machine
log_level : logging.Level
One of DEBUG, INFO, WARNING, ERROR, CRITICAL
force : bool
Argument to logging.basicConfig()
"""
log_file = os.path.join(log_dir, "PeerSampler.log")
logging.basicConfig(
filename=log_file,
format="[%(asctime)s][%(module)s][%(levelname)s] %(message)s",
level=log_level,
force=force,
)
def cache_fields( def cache_fields(
self, self,
rank, rank,
...@@ -123,12 +148,19 @@ class PeerSampler(Node): ...@@ -123,12 +148,19 @@ class PeerSampler(Node):
log_dir, log_dir,
) )
self.message_queue = deque() self.message_queue = dict()
self.barrier = set() self.barrier = set()
self.init_comm(config["COMMUNICATION"]) self.init_comm(config["COMMUNICATION"])
self.my_neighbors = self.graph.get_all_nodes() self.my_neighbors = self.graph.get_all_nodes()
self.connect_neighbours() self.connect_neighbors()
def get_neighbors(self, node, iteration=None):
return self.graph.neighbors(node)
def receive_server_request(self):
return self.receive_channel("SERVER_REQUEST")
def run(self): def run(self):
""" """
...@@ -136,13 +168,20 @@ class PeerSampler(Node): ...@@ -136,13 +168,20 @@ class PeerSampler(Node):
""" """
while len(self.barrier) > 0: while len(self.barrier) > 0:
sender, data = self.receive() sender, data = self.receive_server_request()
if "BYE" in data: if "BYE" in data:
logging.debug("Received {} from {}".format("BYE", sender)) logging.debug("Received {} from {}".format("BYE", sender))
self.barrier.remove(sender) self.barrier.remove(sender)
else:
elif "REQUEST_NEIGHBORS" in data:
logging.debug("Received {} from {}".format("Request", sender)) logging.debug("Received {} from {}".format("Request", sender))
resp = {"neighbors": self.get_neighbors(sender)} if "iteration" in data:
resp = {
"NEIGHBORS": self.get_neighbors(sender, data["iteration"]),
"CHANNEL": "PEERS",
}
else:
resp = {"NEIGHBORS": self.get_neighbors(sender), "CHANNEL": "PEERS"}
self.communication.send(sender, resp) self.communication.send(sender, resp)
def __init__( def __init__(
......
import logging
from collections import deque
from decentralizepy.graphs.Graph import Graph
from decentralizepy.graphs.Regular import Regular
from decentralizepy.mappings.Mapping import Mapping
from decentralizepy.node.PeerSampler import PeerSampler
class PeerSamplerDynamic(PeerSampler):
    """
    Peer-sampling service that serves a freshly generated random regular
    graph for every training iteration.

    """

    def get_neighbors(self, node, iteration=None):
        """Return the neighbors of ``node``.

        When ``iteration`` is given, the neighbors come from the random
        regular graph of that round (generated lazily on the first request
        for a new round); otherwise the static base graph is used.

        Parameters
        ----------
        node : int
            uid of the requesting node
        iteration : int, optional
            Training iteration the request is for

        Returns
        -------
        The neighbor uids of ``node``.
        """
        if iteration is not None:
            if iteration > self.iteration:
                logging.info(
                    "iteration, self.iteration: {}, {}".format(
                        iteration, self.iteration
                    )
                )
                # Rounds must advance one at a time; generate the topology
                # for the new round on its first request.
                assert iteration == self.iteration + 1
                self.iteration = iteration
                self.graphs.append(Regular(self.graph.n_procs, self.graph_degree))
            return self.graphs[iteration].neighbors(node)
        else:
            return self.graph.neighbors(node)

    def __init__(
        self,
        rank: int,
        machine_id: int,
        mapping: Mapping,
        graph: Graph,
        config,
        iterations=1,
        log_dir=".",
        log_level=logging.INFO,
        graph_degree=4,
        *args
    ):
        """
        Constructor

        Parameters
        ----------
        rank : int
            Rank of process local to the machine
        machine_id : int
            Machine ID on which the process in running
        mapping : decentralizepy.mappings
            The object containing the mapping rank <--> uid
        graph : decentralizepy.graphs
            The object containing the global graph
        config : dict
            A dictionary of configurations. Must contain the following:
            [DATASET]
                dataset_package
                dataset_class
                model_class
            [OPTIMIZER_PARAMS]
                optimizer_package
                optimizer_class
            [TRAIN_PARAMS]
                training_package = decentralizepy.training.Training
                training_class = Training
                epochs_per_round = 25
                batch_size = 64
        iterations : int
            Number of iterations (communication steps) for which the model should be trained
        log_dir : str
            Logging directory
        log_level : logging.Level
            One of DEBUG, INFO, WARNING, ERROR, CRITICAL
        graph_degree : int
            Degree of the random regular graph generated each round
        args : optional
            Other arguments

        """
        # -1 so the first request for iteration 0 triggers graph generation.
        self.iteration = -1
        self.graphs = []
        self.graph_degree = graph_degree

        self.instantiate(
            rank,
            machine_id,
            mapping,
            graph,
            config,
            iterations,
            log_dir,
            log_level,
            *args
        )

        self.run()
...@@ -262,6 +262,7 @@ class FFT(PartialModel): ...@@ -262,6 +262,7 @@ class FFT(PartialModel):
degree, iteration = data["degree"], data["iteration"] degree, iteration = data["degree"], data["iteration"]
del data["degree"] del data["degree"]
del data["iteration"] del data["iteration"]
del data["CHANNEL"]
logging.debug( logging.debug(
"Averaging model from neighbor {} of iteration {}".format( "Averaging model from neighbor {} of iteration {}".format(
n, iteration n, iteration
......
...@@ -199,6 +199,7 @@ class LowerBoundTopK(PartialModel): ...@@ -199,6 +199,7 @@ class LowerBoundTopK(PartialModel):
degree, iteration = data["degree"], data["iteration"] degree, iteration = data["degree"], data["iteration"]
del data["degree"] del data["degree"]
del data["iteration"] del data["iteration"]
del data["CHANNEL"]
logging.debug( logging.debug(
"Averaging model from neighbor {} of iteration {}".format( "Averaging model from neighbor {} of iteration {}".format(
n, iteration n, iteration
......
...@@ -164,6 +164,7 @@ class Sharing: ...@@ -164,6 +164,7 @@ class Sharing:
degree, iteration = data["degree"], data["iteration"] degree, iteration = data["degree"], data["iteration"]
del data["degree"] del data["degree"]
del data["iteration"] del data["iteration"]
del data["CHANNEL"]
logging.debug( logging.debug(
"Averaging model from neighbor {} of iteration {}".format( "Averaging model from neighbor {} of iteration {}".format(
n, iteration n, iteration
......
...@@ -189,6 +189,7 @@ class Sharing: ...@@ -189,6 +189,7 @@ class Sharing:
iteration = data["iteration"] iteration = data["iteration"]
del data["degree"] del data["degree"]
del data["iteration"] del data["iteration"]
del data["CHANNEL"]
self.peer_deques[sender].append((degree, iteration, data)) self.peer_deques[sender].append((degree, iteration, data))
logging.info( logging.info(
"Deserialized received model from {} of iteration {}".format( "Deserialized received model from {} of iteration {}".format(
......
...@@ -279,6 +279,7 @@ class Wavelet(PartialModel): ...@@ -279,6 +279,7 @@ class Wavelet(PartialModel):
degree, iteration = data["degree"], data["iteration"] degree, iteration = data["degree"], data["iteration"]
del data["degree"] del data["degree"]
del data["iteration"] del data["iteration"]
del data["CHANNEL"]
logging.debug( logging.debug(
"Averaging model from neighbor {} of iteration {}".format( "Averaging model from neighbor {} of iteration {}".format(
n, iteration n, iteration
......
...@@ -84,7 +84,9 @@ def get_args(): ...@@ -84,7 +84,9 @@ def get_args():
parser.add_argument("-tea", "--train_evaluate_after", type=int, default=1) parser.add_argument("-tea", "--train_evaluate_after", type=int, default=1)
parser.add_argument("-ro", "--reset_optimizer", type=int, default=1) parser.add_argument("-ro", "--reset_optimizer", type=int, default=1)
parser.add_argument("-ctr", "--centralized_train_eval", type=int, default=0) parser.add_argument("-ctr", "--centralized_train_eval", type=int, default=0)
parser.add_argument("-cte", "--centralized_test_eval", type=int, default=1) parser.add_argument("-cte", "--centralized_test_eval", type=int, default=0)
parser.add_argument("-sm", "--server_machine", type=int, default=0)
parser.add_argument("-sr", "--server_rank", type=int, default=-1)
args = parser.parse_args() args = parser.parse_args()
return args return args
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment