Skip to content
Snippets Groups Projects
Commit 3c3007a0 authored by Rishi Sharma's avatar Rishi Sharma
Browse files

Fix plot x_axis

parent 40ccc424
No related branches found
No related tags found
No related merge requests found
import json
import os
import sys
import numpy as np
from matplotlib import pyplot as plt
from numpy.core.numeric import indices
def get_stats(l):
assert len(l) > 0
mean_dict, stdev_dict, min_dict, max_dict = {}, {}, {}, {}
for key in l[0].keys():
all_nodes = [i[key] for i in l]
all_nodes = np.array(all_nodes)
mean = np.mean(all_nodes)
std = np.std(all_nodes)
min = np.min(all_nodes)
max = np.max(all_nodes)
mean_dict[int(key)] = mean
stdev_dict[int(key)] = std
min_dict[int(key)] = min
max_dict[int(key)] = max
return mean_dict, stdev_dict, min_dict, max_dict
def plot(means, stdevs, mins, maxs, title, label, loc):
plt.title(title)
plt.xlabel("communication rounds")
x_axis = list(means.keys())
y_axis = list(means.values())
err = list(stdevs.values())
plt.errorbar(x_axis, y_axis, yerr=err, label=label)
plt.legend(loc=loc)
def plot_results(path):
folders = os.listdir(path)
print("Reading folders from: ", path)
print("Folders: ", folders)
for folder in folders:
folder_path = os.path.join(path, folder)
if not os.path.isdir(folder_path):
continue
results = []
files = os.listdir(folder_path)
files = [f for f in files if f.endswith("_results.json")]
for f in files:
filepath = os.path.join(folder_path, f)
with open(filepath, "r") as inf:
results.append(json.load(inf))
plt.figure(1)
means, stdevs, mins, maxs = get_stats([x["train_loss"] for x in results])
plot(means, stdevs, mins, maxs, "Training Loss", folder, "upper right")
plt.figure(2)
means, stdevs, mins, maxs = get_stats([x["test_loss"] for x in results])
plot(means, stdevs, mins, maxs, "Testing Loss", folder, "upper right")
plt.figure(3)
means, stdevs, mins, maxs = get_stats([x["test_acc"] for x in results])
plot(means, stdevs, mins, maxs, "Testing Accuracy", folder, "lower right")
plt.figure(1)
plt.savefig(os.path.join(path, "train_loss.png"))
plt.figure(2)
plt.savefig(os.path.join(path, "test_loss.png"))
plt.figure(3)
plt.savefig(os.path.join(path, "test_acc.png"))
def plot_parameters(path):
plt.figure(4)
folders = os.listdir(path)
for folder in folders:
folder_path = os.path.join(path, folder)
if not os.path.isdir(folder_path):
continue
files = os.listdir(folder_path)
files = [f for f in files if f.endswith("_shared_params.json")]
for f in files:
filepath = os.path.join(folder_path, f)
print("Working with ", filepath)
with open(filepath, "r") as inf:
loaded_dict = json.load(inf)
del loaded_dict["order"]
del loaded_dict["shapes"]
assert len(loaded_dict["0"]) > 0
assert "0" in loaded_dict.keys()
counts = np.zeros(len(loaded_dict["0"]))
for key in loaded_dict.keys():
indices = np.array(loaded_dict[key])
counts = np.pad(
counts,
max(np.max(indices) - counts.shape[0], 0),
"constant",
constant_values=0,
)
counts[indices] += 1
plt.plot(np.arange(0, counts.shape[0]), counts, ".")
print("Saving scatterplot")
plt.savefig(os.path.join(folder_path, "shared_params.png"))
if __name__ == "__main__":
assert len(sys.argv) == 2
plot_results(sys.argv[1])
plot_parameters(sys.argv[1])
......@@ -18,8 +18,8 @@ class Node:
def save_plot(self, l, label, title, xlabel, filename):
plt.clf()
x_axis = l.keys()
y_axis = [l[key] for key in x_axis]
y_axis = [l[key] for key in l.keys()]
x_axis = list(map(int, l.keys()))
plt.plot(x_axis, y_axis, label=label)
plt.xlabel(xlabel)
plt.title(title)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment