From 3c3007a0b9817e2c284726f1f6607c72dc2e5071 Mon Sep 17 00:00:00 2001 From: Rishi Sharma <rishi.sharma@epfl.ch> Date: Wed, 8 Dec 2021 22:42:27 +0100 Subject: [PATCH] Fix plot x_axis --- eval/plot.py | 106 ++++++++++++++++++++++++++++++++ src/decentralizepy/node/Node.py | 4 +- 2 files changed, 108 insertions(+), 2 deletions(-) create mode 100644 eval/plot.py diff --git a/eval/plot.py b/eval/plot.py new file mode 100644 index 0000000..af04b0e --- /dev/null +++ b/eval/plot.py @@ -0,0 +1,106 @@ +import json +import os +import sys + +import numpy as np +from matplotlib import pyplot as plt +from numpy.core.numeric import indices + + +def get_stats(l): + assert len(l) > 0 + mean_dict, stdev_dict, min_dict, max_dict = {}, {}, {}, {} + for key in l[0].keys(): + all_nodes = [i[key] for i in l] + all_nodes = np.array(all_nodes) + mean = np.mean(all_nodes) + std = np.std(all_nodes) + min = np.min(all_nodes) + max = np.max(all_nodes) + mean_dict[int(key)] = mean + stdev_dict[int(key)] = std + min_dict[int(key)] = min + max_dict[int(key)] = max + return mean_dict, stdev_dict, min_dict, max_dict + + +def plot(means, stdevs, mins, maxs, title, label, loc): + plt.title(title) + plt.xlabel("communication rounds") + x_axis = list(means.keys()) + y_axis = list(means.values()) + err = list(stdevs.values()) + plt.errorbar(x_axis, y_axis, yerr=err, label=label) + plt.legend(loc=loc) + + +def plot_results(path): + folders = os.listdir(path) + print("Reading folders from: ", path) + print("Folders: ", folders) + for folder in folders: + folder_path = os.path.join(path, folder) + if not os.path.isdir(folder_path): + continue + results = [] + files = os.listdir(folder_path) + files = [f for f in files if f.endswith("_results.json")] + for f in files: + filepath = os.path.join(folder_path, f) + with open(filepath, "r") as inf: + results.append(json.load(inf)) + plt.figure(1) + means, stdevs, mins, maxs = get_stats([x["train_loss"] for x in results]) + plot(means, stdevs, mins, maxs, "Training Loss", folder, "upper right") + plt.figure(2) + means, stdevs, mins, maxs = get_stats([x["test_loss"] for x in results]) + plot(means, stdevs, mins, maxs, "Testing Loss", folder, "upper right") + plt.figure(3) + means, stdevs, mins, maxs = get_stats([x["test_acc"] for x in results]) + plot(means, stdevs, mins, maxs, "Testing Accuracy", folder, "lower right") + + plt.figure(1) + plt.savefig(os.path.join(path, "train_loss.png")) + plt.figure(2) + plt.savefig(os.path.join(path, "test_loss.png")) + plt.figure(3) + plt.savefig(os.path.join(path, "test_acc.png")) + + +def plot_parameters(path): + plt.figure(4) + folders = os.listdir(path) + for folder in folders: + folder_path = os.path.join(path, folder) + if not os.path.isdir(folder_path): + continue + files = os.listdir(folder_path) + files = [f for f in files if f.endswith("_shared_params.json")] + for f in files: + filepath = os.path.join(folder_path, f) + print("Working with ", filepath) + with open(filepath, "r") as inf: + loaded_dict = json.load(inf) + del loaded_dict["order"] + del loaded_dict["shapes"] + assert len(loaded_dict["0"]) > 0 + assert "0" in loaded_dict.keys() + counts = np.zeros(len(loaded_dict["0"])) + for key in loaded_dict.keys(): + indices = np.array(loaded_dict[key]) + counts = np.pad( + counts, + max(np.max(indices) - counts.shape[0], 0), + "constant", + constant_values=0, + ) + counts[indices] += 1 + plt.plot(np.arange(0, counts.shape[0]), counts, ".") + print("Saving scatterplot") + plt.savefig(os.path.join(folder_path, "shared_params.png")) + + +if __name__ == "__main__": + assert len(sys.argv) == 2 + plot_results(sys.argv[1]) + plot_parameters(sys.argv[1]) diff --git a/src/decentralizepy/node/Node.py b/src/decentralizepy/node/Node.py index 49c6abe..d6897b4 100644 --- a/src/decentralizepy/node/Node.py +++ b/src/decentralizepy/node/Node.py @@ -18,8 +18,8 @@ class Node: def save_plot(self, l, label, title, xlabel, filename): plt.clf() - x_axis = l.keys() - y_axis = [l[key] for key in x_axis] + y_axis = [l[key] for key in l.keys()] + x_axis = list(map(int, l.keys())) plt.plot(x_axis, y_axis, label=label) plt.xlabel(xlabel) plt.title(title) -- GitLab