diff --git a/src/decentralizepy/node/Node.py b/src/decentralizepy/node/Node.py index 4a6440484a6c925624a1289730be388a745d724e..49c6abe43b8beab005da11b5bdc8aca75d0cbca1 100644 --- a/src/decentralizepy/node/Node.py +++ b/src/decentralizepy/node/Node.py @@ -173,9 +173,6 @@ class Node: self.testset = self.dataset.get_testset() rounds_to_test = test_after - self.train_loss = dict() - self.test_loss = dict() - self.test_acc = dict() for iteration in range(iterations): logging.info("Starting training iteration: %d", iteration) @@ -188,7 +185,25 @@ class Node: self.trainer.reset_optimizer(self.optimizer) loss_after_sharing = self.trainer.eval_loss(self.dataset) - self.train_loss[iteration + 1] = loss_after_sharing + + if iteration: + with open( + os.path.join(self.log_dir, "{}_results.json".format(self.rank)), + "r", + ) as inf: + results_dict = json.load(inf) + else: + results_dict = {"train_loss": {}, "test_loss": {}, "test_acc": {}} + + results_dict["train_loss"][iteration + 1] = loss_after_sharing + + self.save_plot( + results_dict["train_loss"], + "train_loss", + "Training Loss", + "Communication Rounds", + os.path.join(log_dir, "{}_train_loss.png".format(self.rank)), + ) rounds_to_test -= 1 @@ -196,42 +211,27 @@ class Node: logging.info("Evaluating on test set.") rounds_to_test = test_after ta, tl = self.dataset.test(self.model, self.loss) - self.test_acc[iteration + 1] = ta - self.test_loss[iteration + 1] = tl + results_dict["test_acc"][iteration + 1] = ta + results_dict["test_loss"][iteration + 1] = tl self.save_plot( - self.train_loss, - "train_loss", - "Training Loss", - "Communication Rounds", - os.path.join(log_dir, "{}_train_loss.png".format(self.rank)), - ) - self.save_plot( - self.test_loss, + results_dict["test_loss"], "test_loss", "Testing Loss", "Communication Rounds", os.path.join(log_dir, "{}_test_loss.png".format(self.rank)), ) self.save_plot( - self.test_acc, + results_dict["test_acc"], "test_acc", "Testing Accuracy", "Communication Rounds", os.path.join(log_dir, "{}_test_acc.png".format(self.rank)), ) - with open( - os.path.join(log_dir, "{}_train_loss.json".format(self.rank)), "w" - ) as of: - json.dump(self.train_loss, of) - with open( - os.path.join(log_dir, "{}_test_loss.json".format(self.rank)), "w" - ) as of: - json.dump(self.test_loss, of) - with open( - os.path.join(log_dir, "{}_test_acc.json".format(self.rank)), "w" - ) as of: - json.dump(self.test_acc, of) + with open( + os.path.join(log_dir, "{}_results.json".format(self.rank)), "w" + ) as of: + json.dump(results_dict, of) self.communication.disconnect_neighbors()