Skip to content
Snippets Groups Projects
Commit aec49a05 authored by Rishi Sharma's avatar Rishi Sharma
Browse files

Fix termination bug TCP

parent 3849b18b
No related branches found
No related tags found
No related merge requests found
......@@ -98,7 +98,7 @@ class TCP(Communication):
elif recv == BYE:
logging.info("Recieved {} from {}".format(BYE, sender))
self.barrier.remove(sender)
self.disconnect_neighbors()
return self.receive()
else:
logging.debug("Recieved message from {}".format(sender))
return self.decrypt(sender, recv)
......@@ -114,3 +114,15 @@ class TCP(Communication):
for sock in self.peer_sockets.values():
sock.send(BYE)
self.sent_disconnections = True
while len(self.barrier):
sender, recv = self.router.recv_multipart()
if recv == BYE:
logging.info("Recieved {} from {}".format(BYE, sender))
self.barrier.remove(sender)
else:
logging.critical(
"Recieved unexpected {} from {}".format(recv, sender)
)
raise RuntimeError(
"Received a message when expecting BYE from {}".format(sender)
)
import argparse
import datetime
import logging
from pathlib import Path
from localconfig import LocalConfig
from torch import multiprocessing as mp
......@@ -20,20 +22,31 @@ def read_ini(file_path):
if __name__ == "__main__":
config = read_ini("config.ini")
my_config = dict()
for section in config:
my_config[section] = dict(config.items(section))
parser = argparse.ArgumentParser()
parser.add_argument("-mid", "--machine_id", type=int, default=0)
parser.add_argument("-ps", "--procs_per_machine", type=int, default=1)
parser.add_argument("-ms", "--machines", type=int, default=1)
parser.add_argument(
"-ld", "--log_dir", type=str, default="./{}".format(datetime.datetime.now())
)
parser.add_argument("-is", "--iterations", type=int, default=1)
parser.add_argument("-cf", "--config_file", type=str, default="config.ini")
parser.add_argument("-ll", "--log_level", type=str, default="INFO")
parser.add_argument("-gf", "--graph_file", type=str, default="36_nodes.edges")
parser.add_argument("-gt", "--graph_type", type=str, default="edges")
args = parser.parse_args()
Path(args.log_dir).mkdir(parents=True, exist_ok=True)
config = read_ini(args.config_file)
my_config = dict()
for section in config:
my_config[section] = dict(config.items(section))
g = Graph()
g.read_graph_from_file("36_nodes.edges", "edges")
g.read_graph_from_file(args.graph_file, args.graph_type)
n_machines = args.machines
procs_per_machine = args.procs_per_machine
l = Linear(n_machines, procs_per_machine)
......@@ -42,5 +55,5 @@ if __name__ == "__main__":
mp.spawn(
fn=Node,
nprocs=procs_per_machine,
args=[m_id, l, g, my_config, 20, "results", logging.INFO],
args=[m_id, l, g, my_config, args.iterations, args.log_dir, args.log_level],
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment