From 9d09cccff134eeddd0d5ea86dda2a9d0b4d91da8 Mon Sep 17 00:00:00 2001 From: Rishi Sharma <rishi.sharma@epfl.ch> Date: Wed, 11 Jan 2023 14:19:20 +0100 Subject: [PATCH] Add Bytes to CSVs, accuracy vs bytes plots --- eval/plot.py | 88 +++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 83 insertions(+), 5 deletions(-) diff --git a/eval/plot.py b/eval/plot.py index a4374d3..a37e760 100644 --- a/eval/plot.py +++ b/eval/plot.py @@ -26,9 +26,9 @@ def get_stats(l): return mean_dict, stdev_dict, min_dict, max_dict -def plot(means, stdevs, mins, maxs, title, label, loc): +def plot(means, stdevs, mins, maxs, title, label, loc, xlabel="communication rounds"): plt.title(title) - plt.xlabel("communication rounds") + plt.xlabel(xlabel) x_axis = np.array(list(means.keys())) y_axis = np.array(list(means.values())) err = np.array(list(stdevs.values())) @@ -37,6 +37,13 @@ def plot(means, stdevs, mins, maxs, title, label, loc): plt.legend(loc=loc) +def replace_dict_key(d_org: dict, d_other: dict): + result = {} + for x, y in d_org.items(): + result[d_other[x]] = y + return result + + def plot_results(path, centralized, data_machine="machine0", data_node=0): folders = os.listdir(path) if centralized.lower() in ["true", "1", "t", "y", "yes"]: @@ -74,19 +81,54 @@ def plot_results(path, centralized, data_machine="machine0", data_node=0): with open(folder_path / data_machine / f"{data_node}_results.json", "r") as f: main_data = json.load(f) main_data = [main_data] + + # Plotting bytes over time + plt.figure(10) + b_means, stdevs, mins, maxs = get_stats([x["total_bytes"] for x in results]) + plot(b_means, stdevs, mins, maxs, "Total Bytes", folder, "lower right") + df = pd.DataFrame( + { + "mean": list(b_means.values()), + "std": list(stdevs.values()), + "nr_nodes": [len(results)] * len(b_means), + }, + list(b_means.keys()), + columns=["mean", "std", "nr_nodes"], + ) + df.to_csv( + os.path.join(path, "total_bytes_" + folder + ".csv"), index_label="rounds" + ) + # Plot Training loss plt.figure(1) means, stdevs, mins, maxs = get_stats([x["train_loss"] for x in results]) plot(means, stdevs, mins, maxs, "Training Loss", folder, "upper right") + + correct_bytes = [b_means[x] for x in means] + df = pd.DataFrame( { "mean": list(means.values()), "std": list(stdevs.values()), "nr_nodes": [len(results)] * len(means), + "total_bytes": correct_bytes, }, list(means.keys()), - columns=["mean", "std", "nr_nodes"], + columns=["mean", "std", "nr_nodes", "total_bytes"], ) + plt.figure(11) + means = replace_dict_key(means, b_means) + plot( + means, + stdevs, + mins, + maxs, + "Training Loss", + folder, + "upper right", + "Total Bytes per node", + ) + df.to_csv( os.path.join(path, "train_loss_" + folder + ".csv"), index_label="rounds" ) @@ -102,10 +144,24 @@ def plot_results(path, centralized, data_machine="machine0", data_node=0): "mean": list(means.values()), "std": list(stdevs.values()), "nr_nodes": [len(results)] * len(means), + "total_bytes": correct_bytes, }, list(means.keys()), - columns=["mean", "std", "nr_nodes"], + columns=["mean", "std", "nr_nodes", "total_bytes"], ) + plt.figure(12) + means = replace_dict_key(means, b_means) + plot( + means, + stdevs, + mins, + maxs, + "Testing Loss", + folder, + "upper right", + "Total Bytes per node", + ) + df.to_csv( os.path.join(path, "test_loss_" + folder + ".csv"), index_label="rounds" ) @@ -121,9 +177,22 @@ def plot_results(path, centralized, data_machine="machine0", data_node=0): "mean": list(means.values()), "std": list(stdevs.values()), "nr_nodes": [len(results)] * len(means), + "total_bytes": correct_bytes, }, list(means.keys()), - columns=["mean", "std", "nr_nodes"], + columns=["mean", "std", "nr_nodes", "total_bytes"], + ) + plt.figure(13) + means = replace_dict_key(means, b_means) + plot( + means, + stdevs, + mins, + maxs, + "Testing Accuracy", + folder, + "lower right", + "Total Bytes per node", ) df.to_csv( os.path.join(path, "test_acc_" + folder + ".csv"), index_label="rounds" @@ -157,6 +226,15 @@ def plot_results(path, centralized, data_machine="machine0", data_node=0): data_means[folder] = list(means.values())[0] data_stdevs[folder] = list(stdevs.values())[0] + plt.figure(10) + plt.savefig(os.path.join(path, "total_bytes.png"), dpi=300) + plt.figure(11) + plt.savefig(os.path.join(path, "bytes_train_loss.png"), dpi=300) + plt.figure(12) + plt.savefig(os.path.join(path, "bytes_test_loss.png"), dpi=300) + plt.figure(13) + plt.savefig(os.path.join(path, "bytes_test_acc.png"), dpi=300) + plt.figure(1) plt.savefig(os.path.join(path, "train_loss.png"), dpi=300) plt.figure(2) -- GitLab