Skip to content
Snippets Groups Projects
plot_percentile.py 4.01 KiB
Newer Older
import json
import os
import sys

import numpy as np
import pandas as pd
import torch
Rishi Sharma's avatar
Rishi Sharma committed
from matplotlib import pyplot as plt


def get_stats(l):
    assert len(l) > 0
    mean_dict, stdev_dict, min_dict, max_dict = {}, {}, {}, {}
    for key in l[0].keys():
        all_nodes = [i[key] for i in l]
        all_nodes = np.array(all_nodes)
        mean = np.mean(all_nodes)
        std = np.std(all_nodes)
        min = np.min(all_nodes)
        max = np.max(all_nodes)
        mean_dict[int(key)] = mean
        stdev_dict[int(key)] = std
        min_dict[int(key)] = min
        max_dict[int(key)] = max
    return mean_dict, stdev_dict, min_dict, max_dict


def plot(means, stdevs, mins, maxs, title, label, loc):
    plt.title(title)
    plt.xlabel("communication rounds")
    x_axis = list(means.keys())
    y_axis = list(means.values())
    err = list(stdevs.values())
    plt.errorbar(x_axis, y_axis, yerr=err, label=label)
    plt.legend(loc=loc)


def plot_results(path):
    """
    plots the percentiles
    Based on plot.py
    Parameters
    ----------
    path
        path to the folders from which to create the percentiles plots

    """
    folders = os.listdir(path)
    folders.sort()
    print("Reading folders from: ", path)
    print("Folders: ", folders)
    for folder in folders:
        folder_path = os.path.join(path, folder)
        if not os.path.isdir(folder_path):
            continue
        results = []
        all_shared_params = []
        machine_folders = os.listdir(folder_path)
        for machine_folder in machine_folders:
            mf_path = os.path.join(folder_path, machine_folder)
            if not os.path.isdir(mf_path):
                continue
            files = os.listdir(mf_path)
            shared_params = [f for f in files if f.endswith("_shared_parameters.json")]
            files = [f for f in files if f.endswith("_results.json")]
            for f in files:
                filepath = os.path.join(mf_path, f)
                with open(filepath, "r") as inf:
                    results.append(json.load(inf))
            for sp in shared_params:
                filepath = os.path.join(mf_path, sp)
                with open(filepath, "r") as spf:
Rishi Sharma's avatar
Rishi Sharma committed
                    all_shared_params.append(np.array(json.load(spf), dtype=np.int32))

        # Plot Training loss
        plt.figure(1)
        # Average of the shared parameters
        mean = np.mean(all_shared_params, axis=0)
        std = np.std(all_shared_params, axis=0)
        with open(
Rishi Sharma's avatar
Rishi Sharma committed
            os.path.join(path, "shared_params_avg_" + folder + ".json"), "w"
        ) as mf:
            json.dump(mean.tolist(), mf)

        with open(
Rishi Sharma's avatar
Rishi Sharma committed
            os.path.join(path, "shared_params_std_" + folder + ".json"), "w"
        ) as sf:
            json.dump(std.tolist(), sf)

        # copy jupyter notebook code
        percentile = np.percentile(mean, np.arange(0, 100, 1))
        plt.plot(np.arange(0, 100, 1), percentile, label=folder)
Rishi Sharma's avatar
Rishi Sharma committed
        plt.title("Shared parameters Percentiles")
        # plt.ylabel("Absolute frequency value")
        plt.xlabel("Percentiles")
        plt.xticks(np.arange(0, 110, 10))
        plt.legend(loc="lower right")

        plt.figure(2)
        sort = torch.sort(torch.tensor(mean)).values
        print(sort)
        length = sort.shape[0]
        length = int(length / 20)
Rishi Sharma's avatar
Rishi Sharma committed
        bins = [
            torch.sum(sort[length * i : length * (i + 1)]).item() for i in range(20)
        ]
        total = np.sum(bins)
Rishi Sharma's avatar
Rishi Sharma committed
        perc = bins / total  # np.divide(bins, total)
        print(perc)
Rishi Sharma's avatar
Rishi Sharma committed
        plt.bar(np.arange(0, 97.5, 5), perc, width=5, align="edge", label=folder)
Rishi Sharma's avatar
Rishi Sharma committed
        plt.title("Shared parameters Percentiles")
        # plt.ylabel("Absolute frequency value")
        plt.xlabel("Percentiles")
        plt.legend(loc="lower right")
        plt.savefig(os.path.join(path, f"percentiles_histogram_{folder}.png"), dpi=300)
        plt.clf()
        plt.cla()

    plt.figure(1)
    plt.savefig(os.path.join(path, "percentiles.png"), dpi=300)


if __name__ == "__main__":
    assert len(sys.argv) == 2
Rishi Sharma's avatar
Rishi Sharma committed
    plot_results(sys.argv[1])