Skip to content
Snippets Groups Projects
Commit 1effbc35 authored by Rishi Sharma
Browse files

Partial Model multi-machine

parent 18c957e9
No related branches found
No related tags found
No related merge requests found
......@@ -2,7 +2,7 @@
dataset_package = decentralizepy.datasets.Femnist
dataset_class = Femnist
model_class = CNN
n_procs = 4
n_procs = 16
train_dir = /home/risharma/Gitlab/decentralizepy/leaf/data/femnist/per_user_data/train
test_dir = /home/risharma/Gitlab/decentralizepy/leaf/data/femnist/data/test
; python list of fractions below
......@@ -16,7 +16,7 @@ lr = 0.01
[TRAIN_PARAMS]
training_package = decentralizepy.training.GradientAccumulator
training_class = GradientAccumulator
epochs_per_round = 2
epochs_per_round = 3
batch_size = 1024
shuffle = True
loss_package = torch.nn
......@@ -30,4 +30,4 @@ addresses_filepath = ip_addr_6Machines.json
[SHARING]
sharing_package = decentralizepy.sharing.PartialModel
sharing_class = PartialModel
alpha = 0.95
\ No newline at end of file
alpha = 1.0
\ No newline at end of file
......@@ -68,16 +68,16 @@ class TCP(Communication):
sender, recv = self.router.recv_multipart()
if recv == HELLO:
logging.info("Recieved {} from {}".format(HELLO, sender))
logging.info("Received {} from {}".format(HELLO, sender))
self.barrier.add(sender)
elif recv == BYE:
logging.info("Recieved {} from {}".format(BYE, sender))
logging.info("Received {} from {}".format(BYE, sender))
raise RuntimeError(
"A neighbour wants to disconnect before training started!"
)
else:
logging.debug(
"Recieved message from {} @ connect_neighbors".format(sender)
"Received message from {} @ connect_neighbors".format(sender)
)
self.peer_deque.append(self.decrypt(sender, recv))
......@@ -91,16 +91,16 @@ class TCP(Communication):
sender, recv = self.router.recv_multipart()
if recv == HELLO:
logging.info("Recieved {} from {}".format(HELLO, sender))
logging.info("Received {} from {}".format(HELLO, sender))
raise RuntimeError(
"A neighbour wants to connect when everyone is connected!"
)
elif recv == BYE:
logging.info("Recieved {} from {}".format(BYE, sender))
logging.info("Received {} from {}".format(BYE, sender))
self.barrier.remove(sender)
return self.receive()
else:
logging.debug("Recieved message from {}".format(sender))
logging.debug("Received message from {}".format(sender))
return self.decrypt(sender, recv)
def send(self, uid, data):
......@@ -117,11 +117,11 @@ class TCP(Communication):
while len(self.barrier):
sender, recv = self.router.recv_multipart()
if recv == BYE:
logging.info("Recieved {} from {}".format(BYE, sender))
logging.info("Received {} from {}".format(BYE, sender))
self.barrier.remove(sender)
else:
logging.critical(
"Recieved unexpected {} from {}".format(recv, sender)
"Received unexpected {} from {}".format(recv, sender)
)
raise RuntimeError(
"Received a message when expecting BYE from {}".format(sender)
......
0% Loading. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment