Newer
Older
import json
import os
import sys
from pathlib import Path
import numpy as np
from matplotlib import pyplot as plt
def plot(x, y, label, *args):
plt.plot(x, y, *args, label=label)
plt.legend()
def plot_shared(path, title):
model_path = os.path.join(path, "plots")
Path(model_path).mkdir(parents=True, exist_ok=True)
files = [f for f in os.listdir(path) if f.endswith("json")]
assert len(files) > 0
for i, file in enumerate(files):
filepath = os.path.join(path, file)
with open(filepath, "r") as inf:
model_vec = json.load(inf)
del model_vec["order"]
if i == 0:
total_params = 0
for l in model_vec["shapes"].values():
current_params = 1
for v in l:
current_params *= v
total_params += current_params
print("Total Params: ", str(total_params))
shared_count = np.zeros(total_params, dtype=int)
del model_vec["shapes"]
model_vec = np.array(model_vec[list(model_vec.keys())[0]])
shared_count[model_vec] += 1
print("sum: ", np.sum(shared_count))
num_elements = shared_count.shape[0]
x_axis = np.arange(1, num_elements + 1)
plt.clf()
plt.title(title)
plot(x_axis, shared_count, "unsorted", ".")
shared_count = np.sort(shared_count)
plot(x_axis, shared_count, "sorted")
plt.savefig(os.path.join(model_path, "shared_plot.png"))
if __name__ == "__main__":
assert len(sys.argv) == 2
plot_shared(sys.argv[1], "Shared Parameters")