From 3e0be95958fc0a2a2cec0b390af155672adf5aca Mon Sep 17 00:00:00 2001 From: Rishi Sharma <rishi.sharma@epfl.ch> Date: Thu, 10 Feb 2022 11:49:07 +0100 Subject: [PATCH] Add grad_mean and grad_std plot --- eval/plot.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/eval/plot.py b/eval/plot.py index fa8be52..d15db46 100644 --- a/eval/plot.py +++ b/eval/plot.py @@ -69,6 +69,13 @@ def plot_results(path): plt.figure(3) means, stdevs, mins, maxs = get_stats([x["test_acc"] for x in results]) plot(means, stdevs, mins, maxs, "Testing Accuracy", folder, "lower right") + plt.figure(6) + means, stdevs, mins, maxs = get_stats([x["grad_std"] for x in results]) + plot(means, stdevs, mins, maxs, "Gradient Variation over Nodes", folder, "upper right") + # Plot Testing loss + plt.figure(7) + means, stdevs, mins, maxs = get_stats([x["grad_mean"] for x in results]) + plot(means, stdevs, mins, maxs, "Gradient Magnitude Mean", folder, "upper right") # Collect total_bytes shared bytes_list = [] for x in results: @@ -80,8 +87,11 @@ def plot_results(path): meta_list = [] for x in results: - max_key = str(max(list(map(int, x["total_meta"].keys())))) - meta_list.append({max_key: x["total_meta"][max_key]}) + if x["total_meta"]: + max_key = str(max(list(map(int, x["total_meta"].keys())))) + meta_list.append({max_key: x["total_meta"][max_key]}) + else: + meta_list.append({max_key: 0}) means, stdevs, mins, maxs = get_stats(meta_list) meta_means[folder] = list(means.values())[0] meta_stdevs[folder] = list(stdevs.values())[0] @@ -100,6 +110,10 @@ def plot_results(path): plt.savefig(os.path.join(path, "test_loss.png")) plt.figure(3) plt.savefig(os.path.join(path, "test_acc.png")) + plt.figure(6) + plt.savefig(os.path.join(path, "grad_std.png")) + plt.figure(7) + plt.savefig(os.path.join(path, "grad_mean.png")) # Plot total_bytes plt.figure(4) plt.title("Data Shared") @@ -116,9 +130,9 @@ def plot_results(path): plt.savefig(os.path.join(path, "data_shared.png")) # Plot stacked_bytes - plt.figure(4) + plt.figure(5) plt.title("Data Shared per Neighbor") - x_pos = np.arange(len(meta_means.keys())) + x_pos = np.arange(len(bytes_means.keys())) plt.bar( x_pos, np.array(list(data_means.values())) // (1024 * 1024), -- GitLab