diff --git a/config.ini b/config.ini index b0e596f0141cb39cd480bf420a49a40619d29795..c97844e1b92fe50a2c5f8b6aabb2274f36b9f257 100644 --- a/config.ini +++ b/config.ini @@ -16,7 +16,7 @@ lr = 0.01 [TRAIN_PARAMS] training_package = decentralizepy.training.Training training_class = Training -epochs_per_round = 10 +epochs_per_round = 4 batch_size = 1024 shuffle = True loss_package = torch.nn diff --git a/src/decentralizepy/datasets/Femnist.py b/src/decentralizepy/datasets/Femnist.py index ccb469c0332c10602029edbf079ba8088d14ce61..8cc09f9687176c6f10a311bc559489b57666efee 100644 --- a/src/decentralizepy/datasets/Femnist.py +++ b/src/decentralizepy/datasets/Femnist.py @@ -275,7 +275,10 @@ class Femnist(Dataset): logging.debug("Predicted on the test set") for key, value in enumerate(correct_pred): - accuracy = 100 * float(value) / total_pred[key] + if total_pred[key] != 0: + accuracy = 100 * float(value) / total_pred[key] + else: + accuracy = 100.0 logging.debug("Accuracy for class {} is: {:.1f} %".format(key, accuracy)) accuracy = 100 * float(total_correct) / total_predicted diff --git a/testing.py b/testing.py index 5a50a3b21450c45dcbf70347525f414604752d01..0d27c9b7071bea6e0620d836ee3f78ec12a9de2f 100644 --- a/testing.py +++ b/testing.py @@ -42,5 +42,5 @@ if __name__ == "__main__": mp.spawn( fn=Node, nprocs=procs_per_machine, - args=[m_id, l, g, my_config, 20, "results", logging.DEBUG], + args=[m_id, l, g, my_config, 20, "results", logging.INFO], )