Skip to content
Snippets Groups Projects
Commit 89dfaa00 authored by Rishi Sharma's avatar Rishi Sharma
Browse files

Fix disconnect

parent d4c2b6c6
No related branches found
No related tags found
1 merge request!15Refactor and add federated + parameter server + central peer sampling
......@@ -29,10 +29,11 @@ def get_stats(l):
def plot(means, stdevs, mins, maxs, title, label, loc):
plt.title(title)
plt.xlabel("communication rounds")
x_axis = list(means.keys())
y_axis = list(means.values())
err = list(stdevs.values())
plt.errorbar(x_axis, y_axis, yerr=err, label=label)
x_axis = np.array(list(means.keys()))
y_axis = np.array(list(means.values()))
err = np.array(list(stdevs.values()))
plt.plot(x_axis, y_axis, label=label)
plt.fill_between(x_axis, y_axis - err, y_axis + err, alpha=0.4)
plt.legend(loc=loc)
......
......@@ -154,7 +154,7 @@ class DPSGDWithPeerSampler(DPSGDNode):
"""
if not self.sent_disconnections:
logging.info("Disconnecting neighbors")
for uid in self.my_neighbors:
for uid in self.barrier:
self.communication.send(uid, {"BYE": self.uid, "CHANNEL": "DISCONNECT"})
self.communication.send(
self.peer_sampler_uid, {"BYE": self.uid, "CHANNEL": "SERVER_REQUEST"}
......
......@@ -108,7 +108,7 @@ class Node:
"""
if not self.sent_disconnections:
logging.info("Disconnecting neighbors")
for uid in self.my_neighbors:
for uid in self.barrier:
self.communication.send(uid, {"BYE": self.uid, "CHANNEL": "DISCONNECT"})
self.sent_disconnections = True
while len(self.barrier):
......
......@@ -258,3 +258,5 @@ class PeerSampler(Node):
)
self.run()
logging.info("Peer Sampler exiting")
......@@ -96,3 +96,5 @@ class PeerSamplerDynamic(PeerSampler):
)
self.run()
logging.info("Peer Sampler exiting")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment