From 9d09cccff134eeddd0d5ea86dda2a9d0b4d91da8 Mon Sep 17 00:00:00 2001
From: Rishi Sharma <rishi.sharma@epfl.ch>
Date: Wed, 11 Jan 2023 14:19:20 +0100
Subject: [PATCH] Add Bytes to CSVs, accuracy vs bytes plots

---
 eval/plot.py | 88 +++++++++++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 83 insertions(+), 5 deletions(-)

diff --git a/eval/plot.py b/eval/plot.py
index a4374d3..a37e760 100644
--- a/eval/plot.py
+++ b/eval/plot.py
@@ -26,9 +26,9 @@ def get_stats(l):
     return mean_dict, stdev_dict, min_dict, max_dict
 
 
-def plot(means, stdevs, mins, maxs, title, label, loc):
+def plot(means, stdevs, mins, maxs, title, label, loc, xlabel="communication rounds"):
     plt.title(title)
-    plt.xlabel("communication rounds")
+    plt.xlabel(xlabel)
     x_axis = np.array(list(means.keys()))
     y_axis = np.array(list(means.values()))
     err = np.array(list(stdevs.values()))
@@ -37,6 +37,13 @@ def plot(means, stdevs, mins, maxs, title, label, loc):
     plt.legend(loc=loc)
 
 
+def replace_dict_key(d_org: dict, d_other: dict):
+    result = {}
+    for x, y in d_org.items():
+        result[d_other[x]] = y
+    return result
+
+
 def plot_results(path, centralized, data_machine="machine0", data_node=0):
     folders = os.listdir(path)
     if centralized.lower() in ["true", "1", "t", "y", "yes"]:
@@ -74,19 +81,54 @@ def plot_results(path, centralized, data_machine="machine0", data_node=0):
         with open(folder_path / data_machine / f"{data_node}_results.json", "r") as f:
             main_data = json.load(f)
         main_data = [main_data]
+
+        # Plotting bytes over time
+        plt.figure(10)
+        b_means, stdevs, mins, maxs = get_stats([x["total_bytes"] for x in results])
+        plot(b_means, stdevs, mins, maxs, "Total Bytes", folder, "lower right")
+        df = pd.DataFrame(
+            {
+                "mean": list(b_means.values()),
+                "std": list(stdevs.values()),
+                "nr_nodes": [len(results)] * len(b_means),
+            },
+            list(b_means.keys()),
+            columns=["mean", "std", "nr_nodes"],
+        )
+        df.to_csv(
+            os.path.join(path, "total_bytes_" + folder + ".csv"), index_label="rounds"
+        )
+
         # Plot Training loss
         plt.figure(1)
         means, stdevs, mins, maxs = get_stats([x["train_loss"] for x in results])
         plot(means, stdevs, mins, maxs, "Training Loss", folder, "upper right")
+
+        correct_bytes = [b_means[x] for x in means]
+
         df = pd.DataFrame(
             {
                 "mean": list(means.values()),
                 "std": list(stdevs.values()),
                 "nr_nodes": [len(results)] * len(means),
+                "total_bytes": correct_bytes,
             },
             list(means.keys()),
-            columns=["mean", "std", "nr_nodes"],
+            columns=["mean", "std", "nr_nodes", "total_bytes"],
         )
+        plt.figure(11)
+        means = replace_dict_key(means, b_means)
+        plot(
+            means,
+            stdevs,
+            mins,
+            maxs,
+            "Training Loss",
+            folder,
+            "upper right",
+            "Total Bytes per node",
+        )
+
         df.to_csv(
             os.path.join(path, "train_loss_" + folder + ".csv"), index_label="rounds"
         )
@@ -102,10 +144,24 @@ def plot_results(path, centralized, data_machine="machine0", data_node=0):
                 "mean": list(means.values()),
                 "std": list(stdevs.values()),
                 "nr_nodes": [len(results)] * len(means),
+                "total_bytes": correct_bytes,
             },
             list(means.keys()),
-            columns=["mean", "std", "nr_nodes"],
+            columns=["mean", "std", "nr_nodes", "total_bytes"],
         )
+        plt.figure(12)
+        means = replace_dict_key(means, b_means)
+        plot(
+            means,
+            stdevs,
+            mins,
+            maxs,
+            "Testing Loss",
+            folder,
+            "upper right",
+            "Total Bytes per node",
+        )
+
         df.to_csv(
             os.path.join(path, "test_loss_" + folder + ".csv"), index_label="rounds"
         )
@@ -121,9 +177,22 @@ def plot_results(path, centralized, data_machine="machine0", data_node=0):
                 "mean": list(means.values()),
                 "std": list(stdevs.values()),
                 "nr_nodes": [len(results)] * len(means),
+                "total_bytes": correct_bytes,
             },
             list(means.keys()),
-            columns=["mean", "std", "nr_nodes"],
+            columns=["mean", "std", "nr_nodes", "total_bytes"],
+        )
+        plt.figure(13)
+        means = replace_dict_key(means, b_means)
+        plot(
+            means,
+            stdevs,
+            mins,
+            maxs,
+            "Testing Accuracy",
+            folder,
+            "lower right",
+            "Total Bytes per node",
         )
         df.to_csv(
             os.path.join(path, "test_acc_" + folder + ".csv"), index_label="rounds"
@@ -157,6 +226,15 @@ def plot_results(path, centralized, data_machine="machine0", data_node=0):
         data_means[folder] = list(means.values())[0]
         data_stdevs[folder] = list(stdevs.values())[0]
 
+    plt.figure(10)
+    plt.savefig(os.path.join(path, "total_bytes.png"), dpi=300)
+    plt.figure(11)
+    plt.savefig(os.path.join(path, "bytes_train_loss.png"), dpi=300)
+    plt.figure(12)
+    plt.savefig(os.path.join(path, "bytes_test_loss.png"), dpi=300)
+    plt.figure(13)
+    plt.savefig(os.path.join(path, "bytes_test_acc.png"), dpi=300)
+
     plt.figure(1)
     plt.savefig(os.path.join(path, "train_loss.png"), dpi=300)
     plt.figure(2)
-- 
GitLab