From 4f14c17f9ac0571ffc812da72cca0785e36b124d Mon Sep 17 00:00:00 2001 From: Jeffrey Wigger <jeffrey.wigger@epfl.ch> Date: Wed, 4 May 2022 13:36:28 +0200 Subject: [PATCH] global epoch plotting and option for non centralized plotting --- eval/plot.py | 25 ++++-- eval/plotting_from_csv.py | 183 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 203 insertions(+), 5 deletions(-) create mode 100644 eval/plotting_from_csv.py diff --git a/eval/plot.py b/eval/plot.py index 9fdccde..601d8e8 100644 --- a/eval/plot.py +++ b/eval/plot.py @@ -36,8 +36,14 @@ def plot(means, stdevs, mins, maxs, title, label, loc): plt.legend(loc=loc) -def plot_results(path, data_machine="machine0", data_node=0): +def plot_results(path, centralized, data_machine="machine0", data_node=0): folders = os.listdir(path) + if centralized.lower() in ['true', '1', 't', 'y', 'yes']: + centralized = True + print("Centralized") + else: + centralized = False + folders.sort() print("Reading folders from: ", path) print("Folders: ", folders) @@ -82,7 +88,10 @@ def plot_results(path, data_machine="machine0", data_node=0): ) # Plot Testing loss plt.figure(2) - means, stdevs, mins, maxs = get_stats([x["test_loss"] for x in main_data]) + if centralized: + means, stdevs, mins, maxs = get_stats([x["test_loss"] for x in main_data]) + else: + means, stdevs, mins, maxs = get_stats([x["test_loss"] for x in results]) plot(means, stdevs, mins, maxs, "Testing Loss", folder, "upper right") df = pd.DataFrame( { @@ -98,7 +107,10 @@ def plot_results(path, data_machine="machine0", data_node=0): ) # Plot Testing Accuracy plt.figure(3) - means, stdevs, mins, maxs = get_stats([x["test_acc"] for x in main_data]) + if centralized: + means, stdevs, mins, maxs = get_stats([x["test_acc"] for x in main_data]) + else: + means, stdevs, mins, maxs = get_stats([x["test_acc"] for x in results]) plot(means, stdevs, mins, maxs, "Testing Accuracy", folder, "lower right") df = pd.DataFrame( { @@ -241,6 +253,9 @@ def plot_parameters(path): if __name__ == "__main__": - assert len(sys.argv) == 2 - plot_results(sys.argv[1]) + assert len(sys.argv) == 3 + # The args are: + # 1: the folder with the data + # 2: True/False: If True then the evaluation on the test set was centralized + plot_results(sys.argv[1], sys.argv[2]) # plot_parameters(sys.argv[1]) diff --git a/eval/plotting_from_csv.py b/eval/plotting_from_csv.py new file mode 100644 index 0000000..b8d4320 --- /dev/null +++ b/eval/plotting_from_csv.py @@ -0,0 +1,183 @@ +import distutils +import json +import os +import sys + +import numpy as np +import pandas as pd +from matplotlib import pyplot as plt + + +def plot(x_axis, means, stdevs, pos, nb_plots, title, label, loc, xlabel): + cmap = plt.get_cmap("gist_rainbow") + plt.title(title) + plt.xlabel(xlabel) + y_axis = list(means) + err = list(stdevs) + print("label:", label) + print("color: ", cmap(1 / nb_plots * pos)) + plt.errorbar( + list(x_axis), y_axis, yerr=err, label=label, color=cmap(1 / nb_plots * pos) + ) + plt.legend(loc=loc) + + +def plot_results(path, epochs, global_epochs="True"): + if global_epochs.lower() in ['true', '1', 't', 'y', 'yes']: + global_epochs = True + else: + global_epochs = False + epochs = int(epochs) + # rounds = int(rounds) + folders = os.listdir(path) + folders.sort() + print("Reading folders from: ", path) + print("Folders: ", folders) + bytes_means, bytes_stdevs = {}, {} + meta_means, meta_stdevs = {}, {} + data_means, data_stdevs = {}, {} + + files = os.listdir(path) + files = [f for f in files if f.endswith(".csv")] + train_loss = sorted([f for f in files if f.startswith("train_loss")]) + test_acc = sorted([f for f in files if f.startswith("test_acc")]) + test_loss = sorted([f for f in files if f.startswith("test_loss")]) + min_losses = [] + for i, f in enumerate(train_loss): + filepath = os.path.join(path, f) + with open(filepath, "r") as inf: + results_csv = pd.read_csv(inf) + # Plot Training loss + plt.figure(1) + if global_epochs: + rounds = results_csv["rounds"].iloc[0] + print("Rounds: ", rounds) + results_cr = results_csv[results_csv.rounds <= epochs*rounds] + means = results_cr["mean"].to_numpy() + stdevs = results_cr["std"].to_numpy() + x_axis = results_cr["rounds"].to_numpy() / rounds # list(np.arange(0, len(means), 1)) + x_label = "global epochs" + else: + results_cr = results_csv[results_csv.rounds <= epochs] + means = results_cr["mean"].to_numpy() + stdevs = results_cr["std"].to_numpy() + x_axis = results_cr["rounds"].to_numpy() + x_label = "communication rounds" + min_losses.append(np.min(means)) + + plot( + x_axis, + means, + stdevs, + i, + len(train_loss), + "Training Loss", + f[len("train_loss") + 1 : -len(":2022-03-24T17:54.csv")], + "upper right", + x_label, + ) + + min_tlosses = [] + for i, f in enumerate(test_loss): + filepath = os.path.join(path, f) + with open(filepath, "r") as inf: + results_csv = pd.read_csv(inf) + if global_epochs: + rounds = results_csv["rounds"].iloc[0] + print("Rounds: ", rounds) + results_cr = results_csv[results_csv.rounds <= epochs*rounds] + means = results_cr["mean"].to_numpy() + stdevs = results_cr["std"].to_numpy() + x_axis = results_cr["rounds"].to_numpy() / rounds # list(np.arange(0, len(means), 1)) + x_label = "global epochs" + else: + results_cr = results_csv[results_csv.rounds <= epochs] + means = results_cr["mean"].to_numpy() + stdevs = results_cr["std"].to_numpy() + x_axis = results_cr["rounds"].to_numpy() + x_label = "communication rounds" + print("x axis:", x_axis) + min_tlosses.append(np.min(means)) + # Plot Testing loss + plt.figure(2) + plot( + x_axis, + means, + stdevs, + i, + len(test_loss), + "Testing Loss", + f[len("test_loss") + 1 : -len(":2022-03-24T17:54.csv")], + "upper right", + x_label, + ) + + max_taccs = [] + for i, f in enumerate(test_acc): + filepath = os.path.join(path, f) + with open(filepath, "r") as inf: + results_csv = pd.read_csv(inf) + if global_epochs: + rounds = results_csv["rounds"].iloc[0] + print("Rounds: ", rounds) + results_cr = results_csv[results_csv.rounds <= epochs*rounds] + means = results_cr["mean"].to_numpy() + stdevs = results_cr["std"].to_numpy() + x_axis = results_cr["rounds"].to_numpy() / rounds # list(np.arange(0, len(means), 1)) + x_label = "global epochs" + else: + results_cr = results_csv[results_csv.rounds <= epochs] + means = results_cr["mean"].to_numpy() + stdevs = results_cr["std"].to_numpy() + x_axis = results_cr["rounds"].to_numpy() + x_label = "communication rounds" + max_taccs.append(np.max(means)) + # Plot Testing Accuracy + plt.figure(3) + plot( + x_axis, + means, + stdevs, + i, + len(test_acc), + "Testing Accuracy", + f[len("test_acc") + 1 : -len(":2022-03-24T17:54.csv")], + "lower right", + x_label, + ) + + names_loss = [ + f[len("train_loss") + 1 : -len(":2022-03-24T17:54.csv")] for f in train_loss + ] + names_acc = [ + f[len("test_acc") + 1 : -len(":2022-03-24T17:54.csv")] for f in test_acc + ] + print(names_loss) + print(names_acc) + pf = pd.DataFrame( + { + "test_accuracy": max_taccs, + "test_losses": min_tlosses, + "train_losses": min_losses, + }, + names_loss, + ) + pf = pf.sort_values(["test_accuracy"], 0, ascending=False) + pf.to_csv(os.path.join(path, "best_results.csv")) + + plt.figure(1) + plt.savefig(os.path.join(path, "ge_train_loss.png"), dpi=300) + plt.figure(2) + plt.savefig(os.path.join(path, "ge_test_loss.png"), dpi=300) + plt.figure(3) + plt.savefig(os.path.join(path, "ge_test_acc.png"), dpi=300) + + +if __name__ == "__main__": + assert len(sys.argv) == 4 + # The args are: + # 1: the folder with the csv files, + # 2: the number of epochs / comm rounds to plot for, + # 3: True/False with True meaning plot global epochs and False plot communication rounds + print(sys.argv[1], sys.argv[2], sys.argv[3]) + plot_results(sys.argv[1], sys.argv[2], sys.argv[3]) -- GitLab