Commit b8731d2c authored by Jeffrey Wigger, committed by Rishi Sharma

Shared parameter counter

parent 4b2edf5e
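
The change in a nutshell: the model now carries a per-parameter counter, each sharing strategy increments that counter at the indices it actually transmits in a round, and each rank dumps its counter to JSON when the node shuts down so the new plotting script can aggregate the counts. A minimal, self-contained sketch of that mechanism with simplified identifiers (the real code sits on the Model, Node and Sharing classes touched in the diff below; n_params is a placeholder for the flattened model size):

    import json
    import torch

    n_params = 10_000  # hypothetical; the diff uses change_transformer(init_model).shape[0]
    shared_parameters_counter = torch.zeros(n_params, dtype=torch.int32)

    def record_share(indices):
        # called once per communication round with the indices that were sent
        shared_parameters_counter[indices] += 1

    record_share(torch.tensor([0, 5, 42]))

    # at the end of the run, each rank writes its counter for the plotting script
    with open("0_shared_parameters.json", "w") as of:
        json.dump(shared_parameters_counter.numpy().tolist(), of)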
@@ -5,6 +5,7 @@ import sys
 import numpy as np
 import pandas as pd
 from matplotlib import pyplot as plt
+import torch

 def get_stats(l):
@@ -152,15 +153,15 @@ def plot_results(path):
         data_stdevs[folder] = list(stdevs.values())[0]

     plt.figure(1)
-    plt.savefig(os.path.join(path, "train_loss.png"))
+    plt.savefig(os.path.join(path, "train_loss.png"), dpi=300)
     plt.figure(2)
-    plt.savefig(os.path.join(path, "test_loss.png"))
+    plt.savefig(os.path.join(path, "test_loss.png"), dpi=300)
     plt.figure(3)
-    plt.savefig(os.path.join(path, "test_acc.png"))
+    plt.savefig(os.path.join(path, "test_acc.png"), dpi=300)
     plt.figure(6)
-    plt.savefig(os.path.join(path, "grad_std.png"))
+    plt.savefig(os.path.join(path, "grad_std.png"), dpi=300)
     plt.figure(7)
-    plt.savefig(os.path.join(path, "grad_mean.png"))
+    plt.savefig(os.path.join(path, "grad_mean.png"), dpi=300)

     # Plot total_bytes
     plt.figure(4)
     plt.title("Data Shared")
@@ -174,7 +175,7 @@ def plot_results(path):
     plt.ylabel("Total data shared in MBs")
     plt.xlabel("Fraction of Model Shared")
     plt.xticks(x_pos, list(bytes_means.keys()))
-    plt.savefig(os.path.join(path, "data_shared.png"))
+    plt.savefig(os.path.join(path, "data_shared.png"), dpi=300)

     # Plot stacked_bytes
     plt.figure(5)
@@ -198,7 +199,7 @@ def plot_results(path):
     plt.ylabel("Data shared in MBs")
     plt.xlabel("Fraction of Model Shared")
     plt.xticks(x_pos, list(meta_means.keys()))
-    plt.savefig(os.path.join(path, "parameters_metadata.png"))
+    plt.savefig(os.path.join(path, "parameters_metadata.png"), dpi=300)

 def plot_parameters(path):
@@ -237,4 +238,4 @@ def plot_parameters(path):
 if __name__ == "__main__":
     assert len(sys.argv) == 2
     plot_results(sys.argv[1])
-    # plot_parameters(sys.argv[1])
+    # plot_parameters(sys.argv[1])
\ No newline at end of file
import json
import os
import sys

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import torch


def get_stats(l):
    assert len(l) > 0
    mean_dict, stdev_dict, min_dict, max_dict = {}, {}, {}, {}
    for key in l[0].keys():
        all_nodes = [i[key] for i in l]
        all_nodes = np.array(all_nodes)
        mean = np.mean(all_nodes)
        std = np.std(all_nodes)
        min = np.min(all_nodes)
        max = np.max(all_nodes)
        mean_dict[int(key)] = mean
        stdev_dict[int(key)] = std
        min_dict[int(key)] = min
        max_dict[int(key)] = max
    return mean_dict, stdev_dict, min_dict, max_dict


def plot(means, stdevs, mins, maxs, title, label, loc):
    plt.title(title)
    plt.xlabel("communication rounds")
    x_axis = list(means.keys())
    y_axis = list(means.values())
    err = list(stdevs.values())
    plt.errorbar(x_axis, y_axis, yerr=err, label=label)
    plt.legend(loc=loc)


def plot_results(path):
    """
    Plots percentiles of the shared-parameter counts. Based on plot.py.

    Parameters
    ----------
    path
        Path to the folders from which to create the percentile plots.

    """
    folders = os.listdir(path)
    folders.sort()
    print("Reading folders from: ", path)
    print("Folders: ", folders)
    for folder in folders:
        folder_path = os.path.join(path, folder)
        if not os.path.isdir(folder_path):
            continue
        results = []
        all_shared_params = []
        machine_folders = os.listdir(folder_path)
        for machine_folder in machine_folders:
            mf_path = os.path.join(folder_path, machine_folder)
            if not os.path.isdir(mf_path):
                continue
            files = os.listdir(mf_path)
            shared_params = [f for f in files if f.endswith("_shared_parameters.json")]
            files = [f for f in files if f.endswith("_results.json")]
            for f in files:
                filepath = os.path.join(mf_path, f)
                with open(filepath, "r") as inf:
                    results.append(json.load(inf))
            for sp in shared_params:
                filepath = os.path.join(mf_path, sp)
                with open(filepath, "r") as spf:
                    all_shared_params.append(np.array(json.load(spf), dtype=np.int32))

        # Figure 1: percentile curve of the counts, averaged over all ranks
        plt.figure(1)
        # Average of the shared parameters
        mean = np.mean(all_shared_params, axis=0)
        std = np.std(all_shared_params, axis=0)
        with open(
            os.path.join(path, "shared_params_avg_" + folder + ".json"), "w"
        ) as mf:
            json.dump(mean.tolist(), mf)
        with open(
            os.path.join(path, "shared_params_std_" + folder + ".json"), "w"
        ) as sf:
            json.dump(std.tolist(), sf)
        percentile = np.percentile(mean, np.arange(0, 100, 1))
        plt.plot(np.arange(0, 100, 1), percentile, label=folder)
        plt.title("Shared parameters Percentiles")
        # plt.ylabel("Absolute frequency value")
        plt.xlabel("Percentiles")
        plt.xticks(np.arange(0, 110, 10))
        plt.legend(loc="lower right")

        # Figure 2: total share count per 5% percentile band, as a histogram
        plt.figure(2)
        sort = torch.sort(torch.tensor(mean)).values
        print(sort)
        length = sort.shape[0]
        length = int(length / 20)
        bins = [torch.sum(sort[length * i : length * (i + 1)]).item() for i in range(20)]
        total = np.sum(bins)
        perc = bins / total
        print(perc)
        plt.bar(np.arange(0, 97.5, 5), perc, width=5, align="edge", label=folder)
        plt.title("Shared parameters Percentiles")
        # plt.ylabel("Absolute frequency value")
        plt.xlabel("Percentiles")
        plt.legend(loc="lower right")
        plt.savefig(os.path.join(path, f"percentiles_histogram_{folder}.png"), dpi=300)
        plt.clf()
        plt.cla()

    plt.figure(1)
    plt.savefig(os.path.join(path, "percentiles.png"), dpi=300)


if __name__ == "__main__":
    assert len(sys.argv) == 2
    plot_results(sys.argv[1])
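
For reference, the figure-2 histogram above sorts the averaged counts, splits them into twenty chunks covering 5% of the parameters each, and plots each chunk's fraction of the total share count. A numpy-only sketch of the same binning on toy data (the counts here are synthetic, not taken from any experiment):

    import numpy as np

    rng = np.random.default_rng(0)
    mean = rng.poisson(lam=3.0, size=1000).astype(float)  # synthetic per-parameter share counts

    sort = np.sort(mean)                                   # ascending counts
    chunk = len(sort) // 20                                # 5% of the parameters per bin
    bins = [sort[i * chunk : (i + 1) * chunk].sum() for i in range(20)]
    perc = np.array(bins) / np.sum(bins)                   # fraction of all shares per 5% band
    print(perc[-1])                                        # contribution of the most-shared 5%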
@@ -18,6 +18,7 @@ class Model(nn.Module):
         self._param_count_ot = None
         self._param_count_total = None
         self.accumulated_changes = None
+        self.shared_parameters_counter = None

     def count_params(self, only_trainable=False):
         """
@@ -419,7 +419,12 @@ class Node:
             os.path.join(self.log_dir, "{}_results.json".format(self.rank)), "w"
         ) as of:
             json.dump(results_dict, of)
+        if self.model.shared_parameters_counter is not None:
+            logging.info("Saving the shared parameter counts")
+            with open(
+                os.path.join(self.log_dir, "{}_shared_parameters.json".format(self.rank)), "w"
+            ) as of:
+                json.dump(self.model.shared_parameters_counter.numpy().tolist(), of)
         self.communication.disconnect_neighbors()
         logging.info("All neighbors disconnected. Process complete!")
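
A minimal sketch of consuming the per-rank dumps written above (file names follow the "{rank}_shared_parameters.json" pattern from the diff; the glob pattern and directory are illustrative):

    import glob
    import json
    import numpy as np

    counts = []
    for fname in sorted(glob.glob("*_shared_parameters.json")):
        with open(fname) as f:
            counts.append(np.array(json.load(f), dtype=np.int32))
    mean_counts = np.mean(counts, axis=0)  # average share count per parameter across ranks
    print(mean_counts.shape, mean_counts.max())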
@@ -163,7 +163,7 @@ class FFT(PartialModel):
         with torch.no_grad():
             topk, indices = self.apply_fft()
+            self.model.shared_parameters_counter[indices] += 1
             self.model.rewind_accumulation(indices)
             if self.save_shared:
@@ -126,6 +126,8 @@ class PartialModel(Sharing):
             )
             Path(self.folder_path).mkdir(parents=True, exist_ok=True)
+        self.model.shared_parameters_counter = torch.zeros(self.change_transformer(self.init_model).shape[0], dtype=torch.int32)
+
     def extract_top_gradients(self):
         """
         Extract the indices and values of the topK gradients.
@@ -162,6 +164,7 @@ class PartialModel(Sharing):
         with torch.no_grad():
             _, G_topk = self.extract_top_gradients()
+            self.model.shared_parameters_counter[G_topk] += 1
             if self.accumulation:
                 self.model.rewind_accumulation(G_topk)
             if self.save_shared:
@@ -84,6 +84,7 @@ class RoundRobinPartial(Sharing):
             block_end = min(T.shape[0], (self.current_block + 1) * self.block_size)
             self.current_block = (self.current_block + 1) % self.num_blocks
             T_send = T[block_start:block_end]
+            self.model.shared_parameters_counter[block_start:block_end] += 1
             logging.info("Range sending: {}-{}".format(block_start, block_end))
             logging.info("Generating dictionary to send")
@@ -131,6 +131,7 @@ class SubSampling(Sharing):
                 <= self.alpha
             )
             subsample = concated[binary_mask]
+            self.model.shared_parameters_counter[binary_mask] += 1
             # logging.debug("Subsampling vector is of size: " + str(subsample.size(dim = 0)))
             return (subsample, curr_seed, self.alpha)
         else:
@@ -147,6 +148,7 @@ class SubSampling(Sharing):
                 )
                 <= self.alpha
             )
+            # TODO: support shared_parameters_counter
             selected = flat[binary_mask]
             values_list.append(selected)
             off += selected.size(dim=0)
@@ -128,7 +128,7 @@ class TopKParams(Sharing):
         with torch.no_grad():
             values, index, offsets = self.extract_top_params()
+            self.model.shared_parameters_counter[index] += 1
             if self.save_shared:
                 shared_params = dict()
                 shared_params["order"] = list(self.model.state_dict().keys())
@@ -184,7 +184,7 @@ class Wavelet(PartialModel):
         with torch.no_grad():
             topk, indices = self.apply_wavelet()
+            self.model.shared_parameters_counter[indices] += 1
             self.model.rewind_accumulation(indices)
             if self.save_shared:
                 shared_params = dict()