Skip to content
Snippets Groups Projects
Commit 9d09cccf authored by Rishi Sharma's avatar Rishi Sharma
Browse files

Add Bytes to CSVs, accuracy vs bytes plots

parent 10a6c713
No related branches found
No related tags found
No related merge requests found
......@@ -26,9 +26,9 @@ def get_stats(l):
return mean_dict, stdev_dict, min_dict, max_dict
def plot(means, stdevs, mins, maxs, title, label, loc):
def plot(means, stdevs, mins, maxs, title, label, loc, xlabel="communication rounds"):
plt.title(title)
plt.xlabel("communication rounds")
plt.xlabel(xlabel)
x_axis = np.array(list(means.keys()))
y_axis = np.array(list(means.values()))
err = np.array(list(stdevs.values()))
......@@ -37,6 +37,13 @@ def plot(means, stdevs, mins, maxs, title, label, loc):
plt.legend(loc=loc)
def replace_dict_key(d_org: dict, d_other: dict):
result = {}
for x, y in d_org.items():
result[d_other[x]] = y
return result
def plot_results(path, centralized, data_machine="machine0", data_node=0):
folders = os.listdir(path)
if centralized.lower() in ["true", "1", "t", "y", "yes"]:
......@@ -74,19 +81,54 @@ def plot_results(path, centralized, data_machine="machine0", data_node=0):
with open(folder_path / data_machine / f"{data_node}_results.json", "r") as f:
main_data = json.load(f)
main_data = [main_data]
# Plotting bytes over time
plt.figure(10)
b_means, stdevs, mins, maxs = get_stats([x["total_bytes"] for x in results])
plot(b_means, stdevs, mins, maxs, "Total Bytes", folder, "lower right")
df = pd.DataFrame(
{
"mean": list(b_means.values()),
"std": list(stdevs.values()),
"nr_nodes": [len(results)] * len(b_means),
},
list(b_means.keys()),
columns=["mean", "std", "nr_nodes"],
)
df.to_csv(
os.path.join(path, "total_bytes_" + folder + ".csv"), index_label="rounds"
)
# Plot Training loss
plt.figure(1)
means, stdevs, mins, maxs = get_stats([x["train_loss"] for x in results])
plot(means, stdevs, mins, maxs, "Training Loss", folder, "upper right")
correct_bytes = [b_means[x] for x in means]
df = pd.DataFrame(
{
"mean": list(means.values()),
"std": list(stdevs.values()),
"nr_nodes": [len(results)] * len(means),
"total_bytes": correct_bytes,
},
list(means.keys()),
columns=["mean", "std", "nr_nodes"],
columns=["mean", "std", "nr_nodes", "total_bytes"],
)
plt.figure(11)
means = replace_dict_key(means, b_means)
plot(
means,
stdevs,
mins,
maxs,
"Training Loss",
folder,
"upper right",
"Total Bytes per node",
)
df.to_csv(
os.path.join(path, "train_loss_" + folder + ".csv"), index_label="rounds"
)
......@@ -102,10 +144,24 @@ def plot_results(path, centralized, data_machine="machine0", data_node=0):
"mean": list(means.values()),
"std": list(stdevs.values()),
"nr_nodes": [len(results)] * len(means),
"total_bytes": correct_bytes,
},
list(means.keys()),
columns=["mean", "std", "nr_nodes"],
columns=["mean", "std", "nr_nodes", "total_bytes"],
)
plt.figure(12)
means = replace_dict_key(means, b_means)
plot(
means,
stdevs,
mins,
maxs,
"Testing Loss",
folder,
"upper right",
"Total Bytes per node",
)
df.to_csv(
os.path.join(path, "test_loss_" + folder + ".csv"), index_label="rounds"
)
......@@ -121,9 +177,22 @@ def plot_results(path, centralized, data_machine="machine0", data_node=0):
"mean": list(means.values()),
"std": list(stdevs.values()),
"nr_nodes": [len(results)] * len(means),
"total_bytes": correct_bytes,
},
list(means.keys()),
columns=["mean", "std", "nr_nodes"],
columns=["mean", "std", "nr_nodes", "total_bytes"],
)
plt.figure(13)
means = replace_dict_key(means, b_means)
plot(
means,
stdevs,
mins,
maxs,
"Testing Accuracy",
folder,
"lower right",
"Total Bytes per node",
)
df.to_csv(
os.path.join(path, "test_acc_" + folder + ".csv"), index_label="rounds"
......@@ -157,6 +226,15 @@ def plot_results(path, centralized, data_machine="machine0", data_node=0):
data_means[folder] = list(means.values())[0]
data_stdevs[folder] = list(stdevs.values())[0]
plt.figure(10)
plt.savefig(os.path.join(path, "total_bytes.png"), dpi=300)
plt.figure(11)
plt.savefig(os.path.join(path, "bytes_train_loss.png"), dpi=300)
plt.figure(12)
plt.savefig(os.path.join(path, "bytes_test_loss.png"), dpi=300)
plt.figure(13)
plt.savefig(os.path.join(path, "bytes_test_acc.png"), dpi=300)
plt.figure(1)
plt.savefig(os.path.join(path, "train_loss.png"), dpi=300)
plt.figure(2)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment