Skip to content
Snippets Groups Projects
Commit 8be9d956 authored by Rishi Sharma's avatar Rishi Sharma
Browse files

DEBUG:Divide by zero

parent 0bc45f9b
No related branches found
No related tags found
No related merge requests found
...@@ -16,7 +16,7 @@ lr = 0.01 ...@@ -16,7 +16,7 @@ lr = 0.01
[TRAIN_PARAMS] [TRAIN_PARAMS]
training_package = decentralizepy.training.Training training_package = decentralizepy.training.Training
training_class = Training training_class = Training
epochs_per_round = 4 epochs_per_round = 1
batch_size = 1024 batch_size = 1024
shuffle = True shuffle = True
loss_package = torch.nn loss_package = torch.nn
......
...@@ -272,7 +272,7 @@ class Femnist(Dataset): ...@@ -272,7 +272,7 @@ class Femnist(Dataset):
total_pred[label] += 1 total_pred[label] += 1
total_predicted += 1 total_predicted += 1
logging.debug("Predicted on the test set") logging.info("Predicted on the test set")
for key, value in enumerate(correct_pred): for key, value in enumerate(correct_pred):
if total_pred[key] != 0: if total_pred[key] != 0:
...@@ -283,7 +283,7 @@ class Femnist(Dataset): ...@@ -283,7 +283,7 @@ class Femnist(Dataset):
accuracy = 100 * float(total_correct) / total_predicted accuracy = 100 * float(total_correct) / total_predicted
logging.info("Overall accuracy is: {:.1f} %".format(accuracy)) logging.info("Overall accuracy is: {:.1f} %".format(accuracy))
logging.debug("Evaluating complete.") logging.info("Evaluating complete.")
class LogisticRegression(nn.Module): class LogisticRegression(nn.Module):
......
...@@ -159,7 +159,9 @@ class Node: ...@@ -159,7 +159,9 @@ class Node:
self.trainer.train(self.dataset) self.trainer.train(self.dataset)
self.sharing.step() self.sharing.step()
self.optimizer = optimizer_class(self.model.parameters(), **optimizer_params) # Reset optimizer state self.optimizer = optimizer_class(
self.model.parameters(), **optimizer_params
) # Reset optimizer state
self.trainer.reset_optimizer(self.optimizer) self.trainer.reset_optimizer(self.optimizer)
rounds_to_test -= 1 rounds_to_test -= 1
......
...@@ -45,7 +45,6 @@ class Training: ...@@ -45,7 +45,6 @@ class Training:
def reset_optimizer(self, optimizer): def reset_optimizer(self, optimizer):
self.optimizer = optimizer self.optimizer = optimizer
def train(self, dataset): def train(self, dataset):
""" """
......
...@@ -42,5 +42,5 @@ if __name__ == "__main__": ...@@ -42,5 +42,5 @@ if __name__ == "__main__":
mp.spawn( mp.spawn(
fn=Node, fn=Node,
nprocs=procs_per_machine, nprocs=procs_per_machine,
args=[m_id, l, g, my_config, 20, "results", logging.INFO], args=[m_id, l, g, my_config, 20, "results", logging.DEBUG],
) )
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment