Skip to content
Snippets Groups Projects
plot.py 11.1 KiB
Newer Older
Rishi Sharma's avatar
Rishi Sharma committed
import json
import os
import sys
Jeffrey Wigger's avatar
Jeffrey Wigger committed
from pathlib import Path
Rishi Sharma's avatar
Rishi Sharma committed

import numpy as np
import pandas as pd
import torch
Rishi Sharma's avatar
Rishi Sharma committed
from matplotlib import pyplot as plt
Rishi Sharma's avatar
Rishi Sharma committed


def get_stats(l):
    """Aggregate per-round values across nodes.

    Args:
        l: non-empty list of dicts, one per node, each mapping a round key
            (anything ``int()`` accepts, e.g. the string keys of a JSON
            object) to a numeric value. The keys of ``l[0]`` decide which
            rounds are aggregated; every other dict must contain them too.

    Returns:
        Tuple ``(mean, stdev, min, max)`` of dicts keyed by ``int(round)``,
        each value computed over all nodes for that round.

    Raises:
        ValueError: if ``l`` is empty.
    """
    # An explicit check (not ``assert``) so it still fires under ``python -O``.
    if len(l) == 0:
        raise ValueError("get_stats() needs results from at least one node")
    mean_dict, stdev_dict, min_dict, max_dict = {}, {}, {}, {}
    for key in l[0].keys():
        values = np.array([node[key] for node in l])
        round_key = int(key)
        mean_dict[round_key] = np.mean(values)
        stdev_dict[round_key] = np.std(values)
        min_dict[round_key] = np.min(values)
        max_dict[round_key] = np.max(values)
    return mean_dict, stdev_dict, min_dict, max_dict


def plot(means, stdevs, mins, maxs, title, label, loc, xlabel="communication rounds"):
    """Draw a mean curve with a +/- one-standard-deviation band on the current figure.

    The keys of ``means`` are the x values and its values are the y values.
    ``stdevs`` is read in iteration order and must line up entry-for-entry
    with ``means``. ``mins`` and ``maxs`` are accepted to mirror the tuple
    returned by ``get_stats`` but are not drawn.
    """
    xs = np.array([*means])
    ys = np.array([*means.values()])
    spread = np.array([*stdevs.values()])

    plt.title(title)
    plt.xlabel(xlabel)
    plt.plot(xs, ys, label=label)
    lower, upper = ys - spread, ys + spread
    plt.fill_between(xs, lower, upper, alpha=0.4)
    plt.legend(loc=loc)


def replace_dict_key(d_org: dict, d_other: dict):
    """Return a copy of ``d_org`` whose keys are translated through ``d_other``.

    Each entry ``k: v`` becomes ``d_other[k]: v``, in ``d_org``'s order. If
    two keys translate to the same new key, the later entry wins. A key
    missing from ``d_other`` raises ``KeyError``.
    """
    return {d_other[old_key]: value for old_key, value in d_org.items()}


def _read_node_results(folder_path):
    """Load every ``*_results.json`` file found one directory level below *folder_path*.

    Each immediate sub-directory of *folder_path* (e.g. ``machine0``) is
    scanned; plain files directly inside *folder_path* are ignored. Returns
    the list of parsed JSON objects.
    """
    results = []
    for machine_folder in os.listdir(folder_path):
        mf_path = os.path.join(folder_path, machine_folder)
        if not os.path.isdir(mf_path):
            continue
        for f in os.listdir(mf_path):
            if not f.endswith("_results.json"):
                continue
            with open(os.path.join(mf_path, f), "r") as inf:
                results.append(json.load(inf))
    return results


def _last_value(series, default=0):
    """Return the value stored under the numerically largest key of *series*.

    *series* maps round numbers stored as strings (JSON object keys) to
    values. Returns *default* when *series* is empty or ``None``.
    """
    if not series:
        return default
    return series[str(max(int(k) for k in series))]


def _final_stats(results, key):
    """Return (mean, std) over nodes of each node's last recorded ``results[i][key]`` value."""
    finals = np.array([_last_value(node[key]) for node in results])
    return np.mean(finals), np.std(finals)


def _plot_metric(
    path, folder, stats, b_means, nr_nodes, title, loc, round_fig, bytes_fig, csv_prefix
):
    """Plot one per-round metric against rounds and against total bytes, then save its CSV.

    Args:
        path: output directory for the CSV.
        folder: experiment name, used as the plot label and in the CSV name.
        stats: the ``(means, stdevs, mins, maxs)`` tuple from ``get_stats``.
        b_means: mean total bytes per round, used to map rounds to bytes.
        nr_nodes: value written to the ``nr_nodes`` CSV column.
        title, loc: plot title and legend location.
        round_fig, bytes_fig: figure numbers for the per-round and per-byte plots.
        csv_prefix: CSV file name prefix, e.g. ``"train_loss_"``.
    """
    means, stdevs, mins, maxs = stats
    plt.figure(round_fig)
    plot(means, stdevs, mins, maxs, title, folder, loc)

    # Map this metric's own rounds to bytes: rounds of different metrics are
    # not guaranteed to coincide, so another metric's byte list could be
    # misaligned or of a different length.
    correct_bytes = [b_means[r] for r in means]
    df = pd.DataFrame(
        {
            "mean": list(means.values()),
            "std": list(stdevs.values()),
            "nr_nodes": [nr_nodes] * len(means),
            "total_bytes": correct_bytes,
        },
        list(means.keys()),
        columns=["mean", "std", "nr_nodes", "total_bytes"],
    )

    plt.figure(bytes_fig)
    plot(
        replace_dict_key(means, b_means),
        stdevs,
        mins,
        maxs,
        title,
        folder,
        loc,
        "Total Bytes per node",
    )
    df.to_csv(os.path.join(path, csv_prefix + folder + ".csv"), index_label="rounds")


def plot_results(path, centralized, data_machine="machine0", data_node=0):
    """Plot and tabulate results for every experiment folder under *path*.

    Every sub-directory of *path* except ``weights`` is one experiment. Its
    per-node ``*_results.json`` files are aggregated with ``get_stats``.
    Curves are drawn for total bytes, training loss, test loss and test
    accuracy (per round and per total bytes), and a CSV is written for each.
    Bar charts of the final data shared are drawn across experiments. All
    PNGs and CSVs are written into *path*.

    Args:
        path: directory containing one sub-directory per experiment.
        centralized: "true"/"1"/"t"/"y"/"yes" (case-insensitive) or a bool.
            When true, test loss/accuracy are read from the single file
            ``<folder>/<data_machine>/<node>_results.json`` instead of being
            aggregated over all nodes.
        data_machine: sub-directory holding that centralized results file.
        data_node: node id of the centralized results file for folders not
            starting with "FL" or "Parameter Server" (those always use -1).
    """
    folders = os.listdir(path)
    # str() lets callers pass a real bool as well as the CLI string.
    if str(centralized).lower() in ["true", "1", "t", "y", "yes"]:
        centralized = True
        print("Centralized")
    else:
        centralized = False

    folders.sort()
    print("Reading folders from: ", path)
    print("Folders: ", folders)

    bytes_means, bytes_stdevs = {}, {}
    meta_means, meta_stdevs = {}, {}
    data_means, data_stdevs = {}, {}

    for folder in folders:
        folder_path = Path(os.path.join(path, folder))
        if not folder_path.is_dir() or folder_path.name == "weights":
            continue
        results = _read_node_results(folder_path)
        nr_nodes = len(results)

        # Test metrics come either from one designated results file
        # (centralized evaluation) or from all nodes.
        if centralized:
            node = -1 if folder.startswith(("FL", "Parameter Server")) else data_node
            with open(folder_path / data_machine / f"{node}_results.json", "r") as f:
                test_source = [json.load(f)]
        else:
            test_source = results

        # Plotting bytes over time
        plt.figure(10)
        b_means, stdevs, mins, maxs = get_stats([x["total_bytes"] for x in results])
        plot(b_means, stdevs, mins, maxs, "Total Bytes", folder, "lower right")
        df = pd.DataFrame(
            {
                "mean": list(b_means.values()),
                "std": list(stdevs.values()),
                "nr_nodes": [nr_nodes] * len(b_means),
            },
            list(b_means.keys()),
            columns=["mean", "std", "nr_nodes"],
        )
        df.to_csv(
            os.path.join(path, "total_bytes_" + folder + ".csv"), index_label="rounds"
        )

        _plot_metric(
            path,
            folder,
            get_stats([x["train_loss"] for x in results]),
            b_means,
            nr_nodes,
            title="Training Loss",
            loc="upper right",
            round_fig=1,
            bytes_fig=11,
            csv_prefix="train_loss_",
        )
        _plot_metric(
            path,
            folder,
            get_stats([x["test_loss"] for x in test_source]),
            b_means,
            nr_nodes,
            title="Testing Loss",
            loc="upper right",
            round_fig=2,
            bytes_fig=12,
            csv_prefix="test_loss_",
        )
        _plot_metric(
            path,
            folder,
            get_stats([x["test_acc"] for x in test_source]),
            b_means,
            nr_nodes,
            title="Testing Accuracy",
            loc="lower right",
            round_fig=3,
            bytes_fig=13,
            csv_prefix="test_acc_",
        )

        # Each node's last-round totals, summarised across nodes for the bar
        # charts. An empty "total_meta" series counts as 0.
        bytes_means[folder], bytes_stdevs[folder] = _final_stats(results, "total_bytes")
        meta_means[folder], meta_stdevs[folder] = _final_stats(results, "total_meta")
        data_means[folder], data_stdevs[folder] = _final_stats(
            results, "total_data_per_n"
        )

    plt.figure(10)
    plt.savefig(os.path.join(path, "total_bytes.png"), dpi=300)
    plt.figure(11)
    plt.savefig(os.path.join(path, "bytes_train_loss.png"), dpi=300)
    plt.figure(12)
    plt.savefig(os.path.join(path, "bytes_test_loss.png"), dpi=300)
    plt.figure(13)
    plt.savefig(os.path.join(path, "bytes_test_acc.png"), dpi=300)

    plt.figure(1)
    plt.savefig(os.path.join(path, "train_loss.png"), dpi=300)
    plt.figure(2)
    plt.savefig(os.path.join(path, "test_loss.png"), dpi=300)
    plt.figure(3)
    plt.savefig(os.path.join(path, "test_acc.png"), dpi=300)

    # Plot total_bytes
    plt.figure(4)
    plt.title("Data Shared")
    x_pos = np.arange(len(bytes_means.keys()))
    plt.bar(
        x_pos,
        np.array(list(bytes_means.values())) // (1024 * 1024),
        yerr=np.array(list(bytes_stdevs.values())) // (1024 * 1024),
        align="center",
    )
    plt.ylabel("Total data shared in MBs")
    plt.xlabel("Fraction of Model Shared")
    plt.xticks(x_pos, list(bytes_means.keys()))
    plt.savefig(os.path.join(path, "data_shared.png"), dpi=300)

    # Plot stacked_bytes: parameters with metadata stacked on top
    plt.figure(5)
    plt.title("Data Shared per Neighbor")
    x_pos = np.arange(len(bytes_means.keys()))
    plt.bar(
        x_pos,
        np.array(list(data_means.values())) // (1024 * 1024),
        yerr=np.array(list(data_stdevs.values())) // (1024 * 1024),
        align="center",
        label="Parameters",
    )
    plt.bar(
        x_pos,
        np.array(list(meta_means.values())) // (1024 * 1024),
        bottom=np.array(list(data_means.values())) // (1024 * 1024),
        yerr=np.array(list(meta_stdevs.values())) // (1024 * 1024),
        align="center",
        label="Metadata",
    )
    plt.ylabel("Data shared in MBs")
    plt.xlabel("Fraction of Model Shared")
    plt.xticks(x_pos, list(meta_means.keys()))
    plt.savefig(os.path.join(path, "parameters_metadata.png"), dpi=300)


def _count_shared_indices(loaded_dict):
    """Count, for each parameter index, in how many entries of *loaded_dict* it appears.

    *loaded_dict* maps keys (must include ``"0"``) to lists of integer
    parameter indices. The returned float array has length at least
    ``len(loaded_dict["0"])`` and is zero-extended at the end so every index
    seen fits.

    Raises:
        ValueError: if ``"0"`` is missing or its index list is empty.
    """
    if "0" not in loaded_dict:
        raise ValueError("shared-params data has no entry for key '0'")
    if len(loaded_dict["0"]) == 0:
        raise ValueError("shared-params entry '0' is empty")
    counts = np.zeros(len(loaded_dict["0"]))
    for indices in loaded_dict.values():
        indices = np.asarray(indices, dtype=np.int64)
        if indices.size == 0:
            continue
        # Grow only at the end: np.pad with a scalar width pads BOTH sides,
        # which would shift existing counts. Index i needs length i + 1.
        missing = int(indices.max()) + 1 - counts.shape[0]
        if missing > 0:
            counts = np.pad(counts, (0, missing), "constant", constant_values=0)
        # np.add.at accumulates repeated indices; counts[indices] += 1 would
        # count a repeated index only once.
        np.add.at(counts, indices, 1)
    return counts


def plot_parameters(path):
    """Scatter-plot, per sub-directory of *path*, how often each parameter index was shared.

    For every ``*_shared_params.json`` in a sub-directory, the ``order`` and
    ``shapes`` entries are ignored and the remaining index lists are counted
    into one series. All series of a sub-directory go on a fresh figure,
    saved as ``shared_params.png`` in that sub-directory.
    """
    for folder in os.listdir(path):
        folder_path = os.path.join(path, folder)
        if not os.path.isdir(folder_path):
            continue
        # A fresh figure per folder so earlier folders' points (or other
        # figures, such as the bar chart in figure 4) don't leak in.
        fig = plt.figure()
        for f in os.listdir(folder_path):
            if not f.endswith("_shared_params.json"):
                continue
            filepath = os.path.join(folder_path, f)
            print("Working with ", filepath)
            with open(filepath, "r") as inf:
                loaded_dict = json.load(inf)
            loaded_dict.pop("order", None)
            loaded_dict.pop("shapes", None)
            counts = _count_shared_indices(loaded_dict)
            plt.plot(np.arange(0, counts.shape[0]), counts, ".")
        print("Saving scatterplot")
        plt.savefig(os.path.join(folder_path, "shared_params.png"))
        plt.close(fig)


if __name__ == "__main__":
    # Usage: python plot.py <results_folder> <centralized>
    #   1: the folder with the data (one sub-folder per experiment)
    #   2: True/False: if True, the evaluation on the test set was centralized
    # Federated-learning folder names must start with "FL" (or
    # "Parameter Server") so their results are read from node -1.
    if len(sys.argv) != 3:
        sys.exit(f"Usage: {sys.argv[0]} <results_folder> <centralized: True/False>")
    plot_results(sys.argv[1], sys.argv[2])
    # plot_parameters(sys.argv[1])