Skip to content
Snippets Groups Projects
Commit 3e0be959 authored by Rishi Sharma's avatar Rishi Sharma
Browse files

Add grad_mean and grad_std plot

parent 806a5638
No related branches found
No related tags found
No related merge requests found
...@@ -69,6 +69,13 @@ def plot_results(path): ...@@ -69,6 +69,13 @@ def plot_results(path):
plt.figure(3) plt.figure(3)
means, stdevs, mins, maxs = get_stats([x["test_acc"] for x in results]) means, stdevs, mins, maxs = get_stats([x["test_acc"] for x in results])
plot(means, stdevs, mins, maxs, "Testing Accuracy", folder, "lower right") plot(means, stdevs, mins, maxs, "Testing Accuracy", folder, "lower right")
plt.figure(6)
means, stdevs, mins, maxs = get_stats([x["grad_std"] for x in results])
plot(means, stdevs, mins, maxs, "Gradient Variation over Nodes", folder, "upper right")
# Plot Testing loss
plt.figure(7)
means, stdevs, mins, maxs = get_stats([x["grad_mean"] for x in results])
plot(means, stdevs, mins, maxs, "Gradient Magnitude Mean", folder, "upper right")
# Collect total_bytes shared # Collect total_bytes shared
bytes_list = [] bytes_list = []
for x in results: for x in results:
...@@ -80,8 +87,11 @@ def plot_results(path): ...@@ -80,8 +87,11 @@ def plot_results(path):
meta_list = [] meta_list = []
for x in results: for x in results:
max_key = str(max(list(map(int, x["total_meta"].keys())))) if x["total_meta"]:
meta_list.append({max_key: x["total_meta"][max_key]}) max_key = str(max(list(map(int, x["total_meta"].keys()))))
meta_list.append({max_key: x["total_meta"][max_key]})
else:
meta_list.append({max_key: 0})
means, stdevs, mins, maxs = get_stats(meta_list) means, stdevs, mins, maxs = get_stats(meta_list)
meta_means[folder] = list(means.values())[0] meta_means[folder] = list(means.values())[0]
meta_stdevs[folder] = list(stdevs.values())[0] meta_stdevs[folder] = list(stdevs.values())[0]
...@@ -100,6 +110,10 @@ def plot_results(path): ...@@ -100,6 +110,10 @@ def plot_results(path):
plt.savefig(os.path.join(path, "test_loss.png")) plt.savefig(os.path.join(path, "test_loss.png"))
plt.figure(3) plt.figure(3)
plt.savefig(os.path.join(path, "test_acc.png")) plt.savefig(os.path.join(path, "test_acc.png"))
plt.figure(6)
plt.savefig(os.path.join(path, "grad_std.png"))
plt.figure(7)
plt.savefig(os.path.join(path, "grad_mean.png"))
# Plot total_bytes # Plot total_bytes
plt.figure(4) plt.figure(4)
plt.title("Data Shared") plt.title("Data Shared")
...@@ -116,9 +130,9 @@ def plot_results(path): ...@@ -116,9 +130,9 @@ def plot_results(path):
plt.savefig(os.path.join(path, "data_shared.png")) plt.savefig(os.path.join(path, "data_shared.png"))
# Plot stacked_bytes # Plot stacked_bytes
plt.figure(4) plt.figure(5)
plt.title("Data Shared per Neighbor") plt.title("Data Shared per Neighbor")
x_pos = np.arange(len(meta_means.keys())) x_pos = np.arange(len(bytes_means.keys()))
plt.bar( plt.bar(
x_pos, x_pos,
np.array(list(data_means.values())) // (1024 * 1024), np.array(list(data_means.values())) // (1024 * 1024),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment