From 3e0be95958fc0a2a2cec0b390af155672adf5aca Mon Sep 17 00:00:00 2001
From: Rishi Sharma <rishi.sharma@epfl.ch>
Date: Thu, 10 Feb 2022 11:49:07 +0100
Subject: [PATCH] Add grad_mean and grad_std plot

---
 eval/plot.py | 22 ++++++++++++++++++----
 1 file changed, 18 insertions(+), 4 deletions(-)

diff --git a/eval/plot.py b/eval/plot.py
index fa8be52..d15db46 100644
--- a/eval/plot.py
+++ b/eval/plot.py
@@ -69,6 +69,13 @@ def plot_results(path):
         plt.figure(3)
         means, stdevs, mins, maxs = get_stats([x["test_acc"] for x in results])
         plot(means, stdevs, mins, maxs, "Testing Accuracy", folder, "lower right")
+        plt.figure(6)
+        means, stdevs, mins, maxs = get_stats([x["grad_std"] for x in results])
+        plot(means, stdevs, mins, maxs, "Gradient Variation over Nodes", folder, "upper right")
+        # Plot Testing loss
+        plt.figure(7)
+        means, stdevs, mins, maxs = get_stats([x["grad_mean"] for x in results])
+        plot(means, stdevs, mins, maxs, "Gradient Magnitude Mean", folder, "upper right")
         # Collect total_bytes shared
         bytes_list = []
         for x in results:
@@ -80,8 +87,11 @@ def plot_results(path):
 
         meta_list = []
         for x in results:
-            max_key = str(max(list(map(int, x["total_meta"].keys()))))
-            meta_list.append({max_key: x["total_meta"][max_key]})
+            if x["total_meta"]:
+                max_key = str(max(list(map(int, x["total_meta"].keys()))))
+                meta_list.append({max_key: x["total_meta"][max_key]})
+            else:
+                meta_list.append({max_key: 0})
         means, stdevs, mins, maxs = get_stats(meta_list)
         meta_means[folder] = list(means.values())[0]
         meta_stdevs[folder] = list(stdevs.values())[0]
@@ -100,6 +110,10 @@ def plot_results(path):
     plt.savefig(os.path.join(path, "test_loss.png"))
     plt.figure(3)
     plt.savefig(os.path.join(path, "test_acc.png"))
+    plt.figure(6)
+    plt.savefig(os.path.join(path, "grad_std.png"))
+    plt.figure(7)
+    plt.savefig(os.path.join(path, "grad_mean.png"))
     # Plot total_bytes
     plt.figure(4)
     plt.title("Data Shared")
@@ -116,9 +130,9 @@ def plot_results(path):
     plt.savefig(os.path.join(path, "data_shared.png"))
 
     # Plot stacked_bytes
-    plt.figure(4)
+    plt.figure(5)
     plt.title("Data Shared per Neighbor")
-    x_pos = np.arange(len(meta_means.keys()))
+    x_pos = np.arange(len(bytes_means.keys()))
     plt.bar(
         x_pos,
         np.array(list(data_means.values())) // (1024 * 1024),
-- 
GitLab