From 565957aec90d333d9afa6741142b2b938b84b57f Mon Sep 17 00:00:00 2001 From: Rishi Sharma <rishi.sharma@epfl.ch> Date: Tue, 23 Nov 2021 14:12:56 +0100 Subject: [PATCH] Multi-machine-fix --- config.ini | 2 +- src/decentralizepy/datasets/Femnist.py | 5 ++++- testing.py | 2 +- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/config.ini b/config.ini index b0e596f..c97844e 100644 --- a/config.ini +++ b/config.ini @@ -16,7 +16,7 @@ lr = 0.01 [TRAIN_PARAMS] training_package = decentralizepy.training.Training training_class = Training -epochs_per_round = 10 +epochs_per_round = 4 batch_size = 1024 shuffle = True loss_package = torch.nn diff --git a/src/decentralizepy/datasets/Femnist.py b/src/decentralizepy/datasets/Femnist.py index ccb469c..8cc09f9 100644 --- a/src/decentralizepy/datasets/Femnist.py +++ b/src/decentralizepy/datasets/Femnist.py @@ -275,7 +275,10 @@ class Femnist(Dataset): logging.debug("Predicted on the test set") for key, value in enumerate(correct_pred): - accuracy = 100 * float(value) / total_pred[key] + if total_pred[key] != 0: + accuracy = 100 * float(value) / total_pred[key] + else: + accuracy = 100.0 logging.debug("Accuracy for class {} is: {:.1f} %".format(key, accuracy)) accuracy = 100 * float(total_correct) / total_predicted diff --git a/testing.py b/testing.py index 5a50a3b..0d27c9b 100644 --- a/testing.py +++ b/testing.py @@ -42,5 +42,5 @@ if __name__ == "__main__": mp.spawn( fn=Node, nprocs=procs_per_machine, - args=[m_id, l, g, my_config, 20, "results", logging.DEBUG], + args=[m_id, l, g, my_config, 20, "results", logging.INFO], ) -- GitLab