import json
import os
import sys

import numpy as np
from matplotlib import pyplot as plt
from numpy.core.numeric import indices


def get_stats(l):
    """Aggregate a list of per-run metric dicts into per-key statistics.

    Args:
        l: Non-empty list of dicts. The keys of the first dict drive the
            iteration, so every other dict must contain those keys too.
            Each key must be convertible with ``int()``.

    Returns:
        A tuple ``(mean_dict, stdev_dict, min_dict, max_dict)``. Each dict
        maps ``int(key)`` to the corresponding NumPy reduction (``np.mean``,
        ``np.std``, ``np.min``, ``np.max``) taken over the values stored
        under that key in all dicts of ``l``.

    Raises:
        AssertionError: If ``l`` is empty.
        KeyError: If a later dict lacks a key present in the first dict.
    """
    assert len(l) > 0, "get_stats needs at least one result dict"
    mean_dict, stdev_dict, min_dict, max_dict = {}, {}, {}, {}
    for key in l[0].keys():
        # Gather the value recorded under this key by every run.
        samples = np.array([run[key] for run in l])
        # Local names deliberately avoid shadowing the builtins min/max.
        mean_value = np.mean(samples)
        std_value = np.std(samples)
        min_value = np.min(samples)
        max_value = np.max(samples)
        # Keys are converted to int for the returned dicts.
        int_key = int(key)
        mean_dict[int_key] = mean_value
        stdev_dict[int_key] = std_value
        min_dict[int_key] = min_value
        max_dict[int_key] = max_value
    return mean_dict, stdev_dict, min_dict, max_dict


def plot(means, stdevs, mins, maxs, title, label, loc):
    """Draw one error-bar curve on the currently active pyplot figure.

    Args:
        means: Dict whose keys become the x values and whose values become
            the y values of the curve.
        stdevs: Dict whose values are used, in iteration order, as the
            y error bars.
        mins: Accepted for signature compatibility; not used here.
        maxs: Accepted for signature compatibility; not used here.
        title: Title set on the current axes.
        label: Legend entry for this curve.
        loc: Legend location string passed to ``plt.legend``.
    """
    plt.title(title)
    plt.xlabel("communication rounds")
    # x/y come from `means`; error bars are matched to them by position.
    plt.errorbar(
        list(means.keys()),
        list(means.values()),
        yerr=list(stdevs.values()),
        label=label,
    )
    plt.legend(loc=loc)


def plot_results(path):
    """Plot aggregated loss/accuracy curves for each experiment folder in *path*.

    Every immediate subdirectory of ``path`` is treated as one experiment.
    All files in it whose name ends with ``_results.json`` are loaded, and
    their ``"train_loss"``, ``"test_loss"`` and ``"test_acc"`` entries are
    aggregated with ``get_stats`` and drawn with ``plot`` on pyplot figures
    1, 2 and 3 respectively, labelled with the folder name. The figures are
    finally saved as ``train_loss.png``, ``test_loss.png`` and
    ``test_acc.png`` inside ``path``.

    Subdirectories that contain no ``*_results.json`` file are skipped
    (with a message) instead of tripping the non-empty assertion in
    ``get_stats`` and aborting the whole run.

    Args:
        path: Directory containing one subdirectory per experiment.
    """
    folders = os.listdir(path)
    folders.sort()
    print("Reading folders from: ", path)
    print("Folders: ", folders)

    # (figure number, key in the results JSON, title, legend location, output file)
    metrics = (
        (1, "train_loss", "Training Loss", "upper right", "train_loss.png"),
        (2, "test_loss", "Testing Loss", "upper right", "test_loss.png"),
        (3, "test_acc", "Testing Accuracy", "lower right", "test_acc.png"),
    )

    for folder in folders:
        folder_path = os.path.join(path, folder)
        if not os.path.isdir(folder_path):
            continue

        results = []
        for name in os.listdir(folder_path):
            if not name.endswith("_results.json"):
                continue
            with open(os.path.join(folder_path, name), "r") as inf:
                results.append(json.load(inf))

        if not results:
            # Nothing to aggregate here; get_stats() would assert on [].
            print("Skipping folder without result files: ", folder_path)
            continue

        for figure_id, key, title, loc, _ in metrics:
            plt.figure(figure_id)
            means, stdevs, mins, maxs = get_stats([x[key] for x in results])
            plot(means, stdevs, mins, maxs, title, folder, loc)

    # Save each figure once, after all folders have been drawn onto it.
    for figure_id, _, _, _, filename in metrics:
        plt.figure(figure_id)
        plt.savefig(os.path.join(path, filename))


def _count_shared_indices(rounds):
    """Return an array whose i-th entry counts the rounds that list index i.

    ``rounds`` maps round keys to lists of integer indices and must contain
    a non-empty ``"0"`` entry, which sets the initial array length. The
    array grows at its tail whenever a round references a larger index.
    """
    counts = np.zeros(len(rounds["0"]))
    for round_indices in rounds.values():
        idx = np.array(round_indices)
        if idx.size == 0:
            # An empty round contributes nothing (and np.max would raise).
            continue
        # Index k needs length k + 1. Pad only at the END so that counts
        # already accumulated stay at their original positions.
        missing = int(np.max(idx)) + 1 - counts.shape[0]
        if missing > 0:
            counts = np.pad(counts, (0, missing), "constant", constant_values=0)
        counts[idx] += 1
    return counts


def plot_parameters(path):
    """Scatter-plot how often each index occurs in the shared-parameter files.

    For every immediate subdirectory of ``path``, each file whose name ends
    with ``_shared_params.json`` is loaded. Its ``"order"`` and ``"shapes"``
    entries are dropped and every remaining entry is treated as a list of
    indices; the per-index occurrence counts are plotted as dots on pyplot
    figure 4. After a folder's files are processed the figure is saved as
    ``shared_params.png`` inside that folder.

    Args:
        path: Directory containing one subdirectory per experiment.

    Raises:
        KeyError: If a file lacks the ``"order"`` or ``"shapes"`` entry.
        AssertionError: If a file has no ``"0"`` entry or it is empty.
    """
    plt.figure(4)
    folders = os.listdir(path)
    for folder in folders:
        folder_path = os.path.join(path, folder)
        if not os.path.isdir(folder_path):
            continue
        files = os.listdir(folder_path)
        files = [f for f in files if f.endswith("_shared_params.json")]
        for f in files:
            filepath = os.path.join(folder_path, f)
            print("Working with ", filepath)
            with open(filepath, "r") as inf:
                loaded_dict = json.load(inf)
            # Metadata entries are not index lists; remove them before counting.
            del loaded_dict["order"]
            del loaded_dict["shapes"]
            # Check presence first so a missing key fails this assertion
            # rather than raising KeyError in the length check below.
            assert "0" in loaded_dict
            assert len(loaded_dict["0"]) > 0
            counts = _count_shared_indices(loaded_dict)
            plt.plot(np.arange(0, counts.shape[0]), counts, ".")
        # NOTE(review): figure 4 is never cleared, so each saved image also
        # contains the points of previously processed folders - confirm
        # whether that accumulation is intended.
        print("Saving scatterplot")
        plt.savefig(os.path.join(folder_path, "shared_params.png"))


if __name__ == "__main__":
    # Validate the command line explicitly: an `assert` would be stripped
    # when running under `python -O` and gives no hint about correct usage.
    if len(sys.argv) != 2:
        sys.exit(
            "Expected exactly one argument: "
            "the directory containing the experiment folders"
        )
    plot_results(sys.argv[1])
    # plot_parameters(sys.argv[1])