From 3c3007a0b9817e2c284726f1f6607c72dc2e5071 Mon Sep 17 00:00:00 2001
From: Rishi Sharma <rishi.sharma@epfl.ch>
Date: Wed, 8 Dec 2021 22:42:27 +0100
Subject: [PATCH] Fix plot x_axis

---
 eval/plot.py                    | 106 ++++++++++++++++++++++++++++++++
 src/decentralizepy/node/Node.py |   4 +-
 2 files changed, 108 insertions(+), 2 deletions(-)
 create mode 100644 eval/plot.py

diff --git a/eval/plot.py b/eval/plot.py
new file mode 100644
index 0000000..af04b0e
--- /dev/null
+++ b/eval/plot.py
@@ -0,0 +1,106 @@
+import json
+import os
+import sys
+
+import numpy as np
+from matplotlib import pyplot as plt
+from numpy.core.numeric import indices
+
+
+def get_stats(l):
+    assert len(l) > 0
+    mean_dict, stdev_dict, min_dict, max_dict = {}, {}, {}, {}
+    for key in l[0].keys():
+        all_nodes = [i[key] for i in l]
+        all_nodes = np.array(all_nodes)
+        mean = np.mean(all_nodes)
+        std = np.std(all_nodes)
+        min = np.min(all_nodes)
+        max = np.max(all_nodes)
+        mean_dict[int(key)] = mean
+        stdev_dict[int(key)] = std
+        min_dict[int(key)] = min
+        max_dict[int(key)] = max
+    return mean_dict, stdev_dict, min_dict, max_dict
+
+
+def plot(means, stdevs, mins, maxs, title, label, loc):
+    plt.title(title)
+    plt.xlabel("communication rounds")
+    x_axis = list(means.keys())
+    y_axis = list(means.values())
+    err = list(stdevs.values())
+    plt.errorbar(x_axis, y_axis, yerr=err, label=label)
+    plt.legend(loc=loc)
+
+
+def plot_results(path):
+    folders = os.listdir(path)
+    print("Reading folders from: ", path)
+    print("Folders: ", folders)
+    for folder in folders:
+        folder_path = os.path.join(path, folder)
+        if not os.path.isdir(folder_path):
+            continue
+        results = []
+        files = os.listdir(folder_path)
+        files = [f for f in files if f.endswith("_results.json")]
+        for f in files:
+            filepath = os.path.join(folder_path, f)
+            with open(filepath, "r") as inf:
+                results.append(json.load(inf))
+        plt.figure(1)
+        means, stdevs, mins, maxs = get_stats([x["train_loss"] for x in results])
+        plot(means, stdevs, mins, maxs, "Training Loss", folder, "upper right")
+        plt.figure(2)
+        means, stdevs, mins, maxs = get_stats([x["test_loss"] for x in results])
+        plot(means, stdevs, mins, maxs, "Testing Loss", folder, "upper right")
+        plt.figure(3)
+        means, stdevs, mins, maxs = get_stats([x["test_acc"] for x in results])
+        plot(means, stdevs, mins, maxs, "Testing Accuracy", folder, "lower right")
+
+    plt.figure(1)
+    plt.savefig(os.path.join(path, "train_loss.png"))
+    plt.figure(2)
+    plt.savefig(os.path.join(path, "test_loss.png"))
+    plt.figure(3)
+    plt.savefig(os.path.join(path, "test_acc.png"))
+
+
+def plot_parameters(path):
+    plt.figure(4)
+    folders = os.listdir(path)
+    for folder in folders:
+        folder_path = os.path.join(path, folder)
+        if not os.path.isdir(folder_path):
+            continue
+        files = os.listdir(folder_path)
+        files = [f for f in files if f.endswith("_shared_params.json")]
+        for f in files:
+            filepath = os.path.join(folder_path, f)
+            print("Working with ", filepath)
+            with open(filepath, "r") as inf:
+                loaded_dict = json.load(inf)
+                del loaded_dict["order"]
+                del loaded_dict["shapes"]
+            assert len(loaded_dict["0"]) > 0
+            assert "0" in loaded_dict.keys()
+            counts = np.zeros(len(loaded_dict["0"]))
+            for key in loaded_dict.keys():
+                indices = np.array(loaded_dict[key])
+                counts = np.pad(
+                    counts,
+                    max(np.max(indices) - counts.shape[0], 0),
+                    "constant",
+                    constant_values=0,
+                )
+                counts[indices] += 1
+            plt.plot(np.arange(0, counts.shape[0]), counts, ".")
+        print("Saving scatterplot")
+        plt.savefig(os.path.join(folder_path, "shared_params.png"))
+
+
+if __name__ == "__main__":
+    assert len(sys.argv) == 2
+    plot_results(sys.argv[1])
+    plot_parameters(sys.argv[1])
diff --git a/src/decentralizepy/node/Node.py b/src/decentralizepy/node/Node.py
index 49c6abe..d6897b4 100644
--- a/src/decentralizepy/node/Node.py
+++ b/src/decentralizepy/node/Node.py
@@ -18,8 +18,8 @@ class Node:
 
     def save_plot(self, l, label, title, xlabel, filename):
         plt.clf()
-        x_axis = l.keys()
-        y_axis = [l[key] for key in x_axis]
+        y_axis = [l[key] for key in l.keys()]
+        x_axis = list(map(int, l.keys()))
         plt.plot(x_axis, y_axis, label=label)
         plt.xlabel(xlabel)
         plt.title(title)
-- 
GitLab