diff --git a/src/decentralizepy/communication/TCP.py b/src/decentralizepy/communication/TCP.py index 18423a86bdaf8934634ea78bc1779ff41585d983..e23a579f8876cbf05b73e33022def3dbec2977fc 100644 --- a/src/decentralizepy/communication/TCP.py +++ b/src/decentralizepy/communication/TCP.py @@ -98,7 +98,7 @@ class TCP(Communication): elif recv == BYE: logging.info("Recieved {} from {}".format(BYE, sender)) self.barrier.remove(sender) - self.disconnect_neighbors() + return self.receive() else: logging.debug("Recieved message from {}".format(sender)) return self.decrypt(sender, recv) @@ -114,3 +114,15 @@ class TCP(Communication): for sock in self.peer_sockets.values(): sock.send(BYE) self.sent_disconnections = True + while len(self.barrier): + sender, recv = self.router.recv_multipart() + if recv == BYE: + logging.info("Recieved {} from {}".format(BYE, sender)) + self.barrier.remove(sender) + else: + logging.critical( + "Recieved unexpected {} from {}".format(recv, sender) + ) + raise RuntimeError( + "Received a message when expecting BYE from {}".format(sender) + ) diff --git a/testing.py b/testing.py index 0d27c9b7071bea6e0620d836ee3f78ec12a9de2f..0a32f59c34de4273389d652c2c4cfedc14df0aa3 100644 --- a/testing.py +++ b/testing.py @@ -1,5 +1,7 @@ import argparse +import datetime import logging +from pathlib import Path from localconfig import LocalConfig from torch import multiprocessing as mp @@ -20,20 +22,31 @@ def read_ini(file_path): if __name__ == "__main__": - config = read_ini("config.ini") - my_config = dict() - for section in config: - my_config[section] = dict(config.items(section)) parser = argparse.ArgumentParser() parser.add_argument("-mid", "--machine_id", type=int, default=0) parser.add_argument("-ps", "--procs_per_machine", type=int, default=1) parser.add_argument("-ms", "--machines", type=int, default=1) + parser.add_argument( + "-ld", "--log_dir", type=str, default="./{}".format(datetime.datetime.now()) + ) + parser.add_argument("-is", "--iterations", type=int, default=1) + parser.add_argument("-cf", "--config_file", type=str, default="config.ini") + parser.add_argument("-ll", "--log_level", type=str, default="INFO") + parser.add_argument("-gf", "--graph_file", type=str, default="36_nodes.edges") + parser.add_argument("-gt", "--graph_type", type=str, default="edges") args = parser.parse_args() + Path(args.log_dir).mkdir(parents=True, exist_ok=True) + + config = read_ini(args.config_file) + my_config = dict() + for section in config: + my_config[section] = dict(config.items(section)) + g = Graph() - g.read_graph_from_file("36_nodes.edges", "edges") + g.read_graph_from_file(args.graph_file, args.graph_type) n_machines = args.machines procs_per_machine = args.procs_per_machine l = Linear(n_machines, procs_per_machine) @@ -42,5 +55,5 @@ if __name__ == "__main__": mp.spawn( fn=Node, nprocs=procs_per_machine, - args=[m_id, l, g, my_config, 20, "results", logging.INFO], + args=[m_id, l, g, my_config, args.iterations, args.log_dir, args.log_level], )