Skip to content
Snippets Groups Projects
Commit 4f14c17f authored by Jeffrey Wigger
Browse files

global epoch plotting and option for non centralized plotting

parent d8b2fe11
No related branches found
No related tags found
No related merge requests found
...@@ -36,8 +36,14 @@ def plot(means, stdevs, mins, maxs, title, label, loc): ...@@ -36,8 +36,14 @@ def plot(means, stdevs, mins, maxs, title, label, loc):
plt.legend(loc=loc) plt.legend(loc=loc)
def plot_results(path, data_machine="machine0", data_node=0): def plot_results(path, centralized, data_machine="machine0", data_node=0):
folders = os.listdir(path) folders = os.listdir(path)
if centralized.lower() in ['true', '1', 't', 'y', 'yes']:
centralized = True
print("Centralized")
else:
centralized = False
folders.sort() folders.sort()
print("Reading folders from: ", path) print("Reading folders from: ", path)
print("Folders: ", folders) print("Folders: ", folders)
...@@ -82,7 +88,10 @@ def plot_results(path, data_machine="machine0", data_node=0): ...@@ -82,7 +88,10 @@ def plot_results(path, data_machine="machine0", data_node=0):
) )
# Plot Testing loss # Plot Testing loss
plt.figure(2) plt.figure(2)
means, stdevs, mins, maxs = get_stats([x["test_loss"] for x in main_data]) if centralized:
means, stdevs, mins, maxs = get_stats([x["test_loss"] for x in main_data])
else:
means, stdevs, mins, maxs = get_stats([x["test_loss"] for x in results])
plot(means, stdevs, mins, maxs, "Testing Loss", folder, "upper right") plot(means, stdevs, mins, maxs, "Testing Loss", folder, "upper right")
df = pd.DataFrame( df = pd.DataFrame(
{ {
...@@ -98,7 +107,10 @@ def plot_results(path, data_machine="machine0", data_node=0): ...@@ -98,7 +107,10 @@ def plot_results(path, data_machine="machine0", data_node=0):
) )
# Plot Testing Accuracy # Plot Testing Accuracy
plt.figure(3) plt.figure(3)
means, stdevs, mins, maxs = get_stats([x["test_acc"] for x in main_data]) if centralized:
means, stdevs, mins, maxs = get_stats([x["test_acc"] for x in main_data])
else:
means, stdevs, mins, maxs = get_stats([x["test_acc"] for x in results])
plot(means, stdevs, mins, maxs, "Testing Accuracy", folder, "lower right") plot(means, stdevs, mins, maxs, "Testing Accuracy", folder, "lower right")
df = pd.DataFrame( df = pd.DataFrame(
{ {
...@@ -241,6 +253,9 @@ def plot_parameters(path): ...@@ -241,6 +253,9 @@ def plot_parameters(path):
if __name__ == "__main__": if __name__ == "__main__":
assert len(sys.argv) == 2 assert len(sys.argv) == 3
plot_results(sys.argv[1]) # The args are:
# 1: the folder with the data
# 2: True/False: If True then the evaluation on the test set was centralized
plot_results(sys.argv[1], sys.argv[2])
# plot_parameters(sys.argv[1]) # plot_parameters(sys.argv[1])
import distutils
import json
import os
import sys
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
def plot(x_axis, means, stdevs, pos, nb_plots, title, label, loc, xlabel):
    """Add one error-bar curve to the current pyplot figure.

    The curve colour is taken from the "gist_rainbow" colormap at position
    ``1 / nb_plots * pos``, so curve number ``pos`` out of ``nb_plots`` gets
    its own colour.

    Args:
        x_axis: Iterable of x coordinates.
        means: Iterable of y values, one per x coordinate.
        stdevs: Iterable of error-bar sizes, one per x coordinate.
        pos: Index of this curve among the curves drawn on the figure.
        nb_plots: Total number of curves drawn on the figure.
        title: Title set on the figure.
        label: Legend entry for this curve.
        loc: Legend location passed to ``plt.legend``.
        xlabel: Label of the x axis.
    """
    palette = plt.get_cmap("gist_rainbow")

    plt.title(title)
    plt.xlabel(xlabel)

    heights = list(means)
    spreads = list(stdevs)

    # Debug output: which curve is drawn and with which colour.
    print("label:", label)
    colour = palette(1 / nb_plots * pos)
    print("color: ", colour)

    plt.errorbar(
        list(x_axis),
        heights,
        yerr=spreads,
        label=label,
        color=colour,
    )
    plt.legend(loc=loc)
# Number of trailing characters removed from a result file name to obtain the
# run label. The files are assumed to end in ":<timestamp>.csv" with a
# timestamp of exactly this shape -- NOTE(review): confirm against the writer.
_TIMESTAMP_SUFFIX_LEN = len(":2022-03-24T17:54.csv")

# Strings (compared case-insensitively) that switch on global-epoch plotting.
_TRUTHY = ("true", "1", "t", "y", "yes")


def _run_label(filename, prefix):
    """Strip the metric prefix (plus one separator char) and the timestamp suffix."""
    return filename[len(prefix) + 1 : -_TIMESTAMP_SUFFIX_LEN]


def _load_curve(filepath, epochs, global_epochs):
    """Read one result CSV and return ``(x_axis, means, stdevs, x_label)``.

    The CSV must provide the columns ``rounds``, ``mean`` and ``std``.
    Raises ``ValueError`` when no row falls inside the requested range.
    """
    with open(filepath, "r") as inf:
        results_csv = pd.read_csv(inf)

    if global_epochs:
        # The first logged round number is used as the number of rounds per
        # global epoch (presumably the first evaluation happens after exactly
        # one epoch -- confirm against the experiment setup).
        rounds = results_csv["rounds"].iloc[0]
        print("Rounds: ", rounds)
        results_cr = results_csv[results_csv.rounds <= epochs * rounds]
        x_axis = results_cr["rounds"].to_numpy() / rounds
        x_label = "global epochs"
    else:
        results_cr = results_csv[results_csv.rounds <= epochs]
        x_axis = results_cr["rounds"].to_numpy()
        x_label = "communication rounds"

    if results_cr.empty:
        # Fail with a readable message instead of NumPy's zero-size
        # reduction error further down.
        raise ValueError(
            "no rows within the first {} {} in {}".format(epochs, x_label, filepath)
        )

    means = results_cr["mean"].to_numpy()
    stdevs = results_cr["std"].to_numpy()
    return x_axis, means, stdevs, x_label


def _plot_metric(path, files, prefix, figure, title, loc, epochs, global_epochs, pick_best):
    """Plot every file of one metric on ``figure``; return the best value per file."""
    best_values = []
    for i, f in enumerate(files):
        x_axis, means, stdevs, x_label = _load_curve(
            os.path.join(path, f), epochs, global_epochs
        )
        best_values.append(pick_best(means))
        plt.figure(figure)
        plot(
            x_axis,
            means,
            stdevs,
            i,
            len(files),
            title,
            _run_label(f, prefix),
            loc,
            x_label,
        )
    return best_values


def plot_results(path, epochs, global_epochs="True"):
    """Plot training loss, testing loss and testing accuracy from CSV files.

    Every ``train_loss*.csv``, ``test_loss*.csv`` and ``test_acc*.csv`` file
    found directly in ``path`` is drawn as one curve on figure 1, 2 and 3
    respectively. The figures are saved to ``path`` as ``ge_train_loss.png``,
    ``ge_test_loss.png`` and ``ge_test_acc.png``. A summary of the best value
    per run (minimum loss, maximum accuracy), sorted by descending test
    accuracy, is written to ``best_results.csv`` in the same folder.

    Args:
        path: Folder containing the CSV files; also the output folder.
        epochs: Number of global epochs (or communication rounds) to plot;
            anything accepted by ``int()``.
        global_epochs: ``bool``, or a string such as "True"/"yes"/"1". When
            true the x axis is in global epochs, otherwise in communication
            rounds.

    Raises:
        ValueError: If ``epochs`` is not an integer, if a CSV has no rows in
            the requested range, or if the three metrics do not have the same
            number of files (the summary table cannot be built then).
    """
    # Accept real booleans as well as the command-line string form.
    if isinstance(global_epochs, str):
        global_epochs = global_epochs.lower() in _TRUTHY
    else:
        global_epochs = bool(global_epochs)
    epochs = int(epochs)

    entries = sorted(os.listdir(path))
    print("Reading folders from: ", path)
    print("Folders: ", entries)

    # `entries` is already sorted, so each filtered list stays sorted too.
    files = [f for f in entries if f.endswith(".csv")]
    train_loss = [f for f in files if f.startswith("train_loss")]
    test_acc = [f for f in files if f.startswith("test_acc")]
    test_loss = [f for f in files if f.startswith("test_loss")]

    min_losses = _plot_metric(
        path, train_loss, "train_loss", 1, "Training Loss", "upper right",
        epochs, global_epochs, np.min,
    )
    min_tlosses = _plot_metric(
        path, test_loss, "test_loss", 2, "Testing Loss", "upper right",
        epochs, global_epochs, np.min,
    )
    max_taccs = _plot_metric(
        path, test_acc, "test_acc", 3, "Testing Accuracy", "lower right",
        epochs, global_epochs, np.max,
    )

    names_loss = [_run_label(f, "train_loss") for f in train_loss]
    names_acc = [_run_label(f, "test_acc") for f in test_acc]
    print(names_loss)
    print(names_acc)

    pf = pd.DataFrame(
        {
            "test_accuracy": max_taccs,
            "test_losses": min_tlosses,
            "train_losses": min_losses,
        },
        index=names_loss,
    )
    # `axis` must be passed by keyword: pandas >= 2.0 rejects it positionally.
    pf = pf.sort_values(by=["test_accuracy"], axis=0, ascending=False)
    pf.to_csv(os.path.join(path, "best_results.csv"))

    for figure, filename in (
        (1, "ge_train_loss.png"),
        (2, "ge_test_loss.png"),
        (3, "ge_test_acc.png"),
    ):
        plt.figure(figure)
        plt.savefig(os.path.join(path, filename), dpi=300)
if __name__ == "__main__":
    # The args are:
    # 1: the folder with the csv files,
    # 2: the number of epochs / comm rounds to plot for,
    # 3: True/False with True meaning plot global epochs and False plot communication rounds
    if len(sys.argv) != 4:
        # Explicit check rather than `assert`: assertions are removed under
        # `python -O`, and a usage line is more helpful than a bare traceback.
        sys.exit(
            "usage: {} <csv_folder> <epochs> <global_epochs: True|False>".format(
                sys.argv[0]
            )
        )
    print(sys.argv[1], sys.argv[2], sys.argv[3])
    plot_results(sys.argv[1], sys.argv[2], sys.argv[3])
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment