Commit 5ab284a8 authored by Rishi Sharma

Merge branch 'shared_param_counter_pr' into 'main'

Shared parameter counter
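This change adds a shared_parameters_counter tensor to Model, with one integer slot per entry of the flattened parameter (or change) vector. Each sharing strategy (FFT, PartialModel, RoundRobinPartial, SubSampling, TopKParams, Wavelet) increments the slots it selects for sharing in a round, each Node dumps its counts to {rank}_shared_parameters.json before disconnecting, and a new plotting script aggregates those files across machines into percentile plots.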

See merge request !7
parents 4b2edf5e b8731d2c
@@ -5,6 +5,7 @@ import sys
 import numpy as np
 import pandas as pd
 from matplotlib import pyplot as plt
+import torch


 def get_stats(l):
@@ -152,15 +153,15 @@ def plot_results(path):
         data_stdevs[folder] = list(stdevs.values())[0]
     plt.figure(1)
-    plt.savefig(os.path.join(path, "train_loss.png"))
+    plt.savefig(os.path.join(path, "train_loss.png"), dpi=300)
     plt.figure(2)
-    plt.savefig(os.path.join(path, "test_loss.png"))
+    plt.savefig(os.path.join(path, "test_loss.png"), dpi=300)
     plt.figure(3)
-    plt.savefig(os.path.join(path, "test_acc.png"))
+    plt.savefig(os.path.join(path, "test_acc.png"), dpi=300)
     plt.figure(6)
-    plt.savefig(os.path.join(path, "grad_std.png"))
+    plt.savefig(os.path.join(path, "grad_std.png"), dpi=300)
     plt.figure(7)
-    plt.savefig(os.path.join(path, "grad_mean.png"))
+    plt.savefig(os.path.join(path, "grad_mean.png"), dpi=300)

     # Plot total_bytes
     plt.figure(4)
     plt.title("Data Shared")
@@ -174,7 +175,7 @@ def plot_results(path):
     plt.ylabel("Total data shared in MBs")
     plt.xlabel("Fraction of Model Shared")
     plt.xticks(x_pos, list(bytes_means.keys()))
-    plt.savefig(os.path.join(path, "data_shared.png"))
+    plt.savefig(os.path.join(path, "data_shared.png"), dpi=300)

     # Plot stacked_bytes
     plt.figure(5)
@@ -198,7 +199,7 @@ def plot_results(path):
     plt.ylabel("Data shared in MBs")
     plt.xlabel("Fraction of Model Shared")
     plt.xticks(x_pos, list(meta_means.keys()))
-    plt.savefig(os.path.join(path, "parameters_metadata.png"))
+    plt.savefig(os.path.join(path, "parameters_metadata.png"), dpi=300)


 def plot_parameters(path):
@@ -237,4 +238,4 @@ def plot_parameters(path):
 if __name__ == "__main__":
     assert len(sys.argv) == 2
     plot_results(sys.argv[1])
-    # plot_parameters(sys.argv[1])
+    # plot_parameters(sys.argv[1])
\ No newline at end of file
import json
import os
import sys

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import torch


def get_stats(l):
    assert len(l) > 0
    mean_dict, stdev_dict, min_dict, max_dict = {}, {}, {}, {}
    for key in l[0].keys():
        all_nodes = np.array([i[key] for i in l])
        mean_dict[int(key)] = np.mean(all_nodes)
        stdev_dict[int(key)] = np.std(all_nodes)
        min_dict[int(key)] = np.min(all_nodes)
        max_dict[int(key)] = np.max(all_nodes)
    return mean_dict, stdev_dict, min_dict, max_dict


def plot(means, stdevs, mins, maxs, title, label, loc):
    plt.title(title)
    plt.xlabel("communication rounds")
    x_axis = list(means.keys())
    y_axis = list(means.values())
    err = list(stdevs.values())
    plt.errorbar(x_axis, y_axis, yerr=err, label=label)
    plt.legend(loc=loc)


def plot_results(path):
    """
    Plots the percentile distribution of the shared-parameter counters.
    Based on plot.py.

    Parameters
    ----------
    path
        Path to the folders from which to create the percentile plots.
    """
    folders = os.listdir(path)
    folders.sort()
    print("Reading folders from: ", path)
    print("Folders: ", folders)
    for folder in folders:
        folder_path = os.path.join(path, folder)
        if not os.path.isdir(folder_path):
            continue
        results = []
        all_shared_params = []
        machine_folders = os.listdir(folder_path)
        for machine_folder in machine_folders:
            mf_path = os.path.join(folder_path, machine_folder)
            if not os.path.isdir(mf_path):
                continue
            files = os.listdir(mf_path)
            shared_params = [f for f in files if f.endswith("_shared_parameters.json")]
            files = [f for f in files if f.endswith("_results.json")]
            # The *_results.json files are read for parity with plot.py
            # but are not used below.
            for f in files:
                filepath = os.path.join(mf_path, f)
                with open(filepath, "r") as inf:
                    results.append(json.load(inf))
            for sp in shared_params:
                filepath = os.path.join(mf_path, sp)
                with open(filepath, "r") as spf:
                    all_shared_params.append(np.array(json.load(spf), dtype=np.int32))

        # Percentile curve of the per-parameter share counts, averaged over nodes
        plt.figure(1)
        mean = np.mean(all_shared_params, axis=0)
        std = np.std(all_shared_params, axis=0)
        with open(os.path.join(path, "shared_params_avg_" + folder + ".json"), "w") as mf:
            json.dump(mean.tolist(), mf)
        with open(os.path.join(path, "shared_params_std_" + folder + ".json"), "w") as sf:
            json.dump(std.tolist(), sf)
        percentile = np.percentile(mean, np.arange(0, 100, 1))
        plt.plot(np.arange(0, 100, 1), percentile, label=folder)
        plt.title("Shared parameters Percentiles")
        # plt.ylabel("Absolute frequency value")
        plt.xlabel("Percentiles")
        plt.xticks(np.arange(0, 110, 10))
        plt.legend(loc="lower right")

        # Histogram: share of the total count held by each 5%-wide bin
        plt.figure(2)
        sort = torch.sort(torch.tensor(mean)).values
        print(sort)
        length = int(sort.shape[0] / 20)
        bins = [torch.sum(sort[length * i : length * (i + 1)]).item() for i in range(20)]
        total = np.sum(bins)
        perc = bins / total  # np.divide(bins, total)
        print(perc)
        plt.bar(np.arange(0, 97.5, 5), perc, width=5, align="edge", label=folder)
        plt.title("Shared parameters Percentiles")
        # plt.ylabel("Absolute frequency value")
        plt.xlabel("Percentiles")
        plt.legend(loc="lower right")
        plt.savefig(os.path.join(path, f"percentiles_histogram_{folder}.png"), dpi=300)
        plt.clf()
        plt.cla()

    plt.figure(1)
    plt.savefig(os.path.join(path, "percentiles.png"), dpi=300)


if __name__ == "__main__":
    assert len(sys.argv) == 2
    plot_results(sys.argv[1])
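A hedged usage sketch for the new plotting script above (the filename plot_shared.py is an assumption; the directory layout and output names follow from the code):

    python plot_shared.py <results_dir>
    # expects: <results_dir>/<config_folder>/<machine_folder>/<rank>_shared_parameters.json
    # writes per config: shared_params_avg_<config>.json, shared_params_std_<config>.json,
    #                    percentiles_histogram_<config>.png
    # and one cumulative percentiles.png across all configs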
@@ -18,6 +18,7 @@ class Model(nn.Module):
         self._param_count_ot = None
         self._param_count_total = None
         self.accumulated_changes = None
+        self.shared_parameters_counter = None

     def count_params(self, only_trainable=False):
         """
@@ -419,7 +419,12 @@ class Node:
         with open(
             os.path.join(self.log_dir, "{}_results.json".format(self.rank)), "w"
         ) as of:
             json.dump(results_dict, of)
+        if self.model.shared_parameters_counter is not None:
+            logging.info("Saving the shared parameter counts")
+            with open(
+                os.path.join(self.log_dir, "{}_shared_parameters.json".format(self.rank)), "w"
+            ) as of:
+                json.dump(self.model.shared_parameters_counter.numpy().tolist(), of)
         self.communication.disconnect_neighbors()
         logging.info("All neighbors disconnected. Process complete!")
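For downstream analysis, a minimal sketch of reading one node's dump written by the block above; the path and rank are hypothetical, but the format (a flat JSON list with one count per flattened parameter) is what json.dump produces here:

    import json
    import numpy as np

    # hypothetical machine folder and rank
    with open("machine0/3_shared_parameters.json") as f:
        counter = np.array(json.load(f), dtype=np.int32)
    # counter[i] = number of rounds in which flattened parameter i was shared
    print(counter.size, counter.max(), np.count_nonzero(counter == 0))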
@@ -163,7 +163,7 @@ class FFT(PartialModel):
         with torch.no_grad():
             topk, indices = self.apply_fft()
+            self.model.shared_parameters_counter[indices] += 1
             self.model.rewind_accumulation(indices)
             if self.save_shared:
@@ -126,6 +126,8 @@ class PartialModel(Sharing):
         )
         Path(self.folder_path).mkdir(parents=True, exist_ok=True)
+        self.model.shared_parameters_counter = torch.zeros(
+            self.change_transformer(self.init_model).shape[0], dtype=torch.int32)

     def extract_top_gradients(self):
         """
         Extract the indices and values of the topK gradients.
@@ -162,6 +164,7 @@ class PartialModel(Sharing):
         with torch.no_grad():
             _, G_topk = self.extract_top_gradients()
+            self.model.shared_parameters_counter[G_topk] += 1
             if self.accumulation:
                 self.model.rewind_accumulation(G_topk)
             if self.save_shared:
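The strategies update the counter through three PyTorch indexing forms: integer index tensors (G_topk, indices, index), a slice (RoundRobinPartial), and a boolean mask (SubSampling). A toy sketch of the semantics, with illustrative values; note that += through integer indices writes duplicates only once, which is harmless here because topk-style index tensors are unique:

    import torch

    counter = torch.zeros(10, dtype=torch.int32)
    counter[torch.tensor([1, 3, 3])] += 1  # duplicate index 3 is counted once, not twice
    counter[2:5] += 1                      # slice increment, RoundRobinPartial-style
    mask = torch.rand(10) <= 0.5
    counter[mask] += 1                     # boolean-mask increment, SubSampling-style
    print(counter)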
@@ -84,6 +84,7 @@ class RoundRobinPartial(Sharing):
             block_end = min(T.shape[0], (self.current_block + 1) * self.block_size)
             self.current_block = (self.current_block + 1) % self.num_blocks
             T_send = T[block_start:block_end]
+            self.model.shared_parameters_counter[block_start:block_end] += 1
             logging.info("Range sending: {}-{}".format(block_start, block_end))
             logging.info("Generating dictionary to send")
@@ -131,6 +131,7 @@ class SubSampling(Sharing):
                 <= self.alpha
             )
             subsample = concated[binary_mask]
+            self.model.shared_parameters_counter[binary_mask] += 1
             # logging.debug("Subsampling vector is of size: " + str(subsample.size(dim = 0)))
             return (subsample, curr_seed, self.alpha)
         else:
@@ -147,6 +148,7 @@ class SubSampling(Sharing):
                     )
                     <= self.alpha
                 )
+                # TODO: support shared_parameters_counter
                 selected = flat[binary_mask]
                 values_list.append(selected)
                 off += selected.size(dim=0)
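One possible way to resolve the TODO above, not part of this merge request (count_subsampled and all other names are illustrative): track each tensor's offset into the flat parameter vector and add its Bernoulli mask to the matching counter slice.

    import torch

    def count_subsampled(state_dict, counter, alpha, generator=None):
        # Mirror the per-tensor subsampling loop: walk tensors in
        # state-dict order and mark sampled positions in the global counter.
        off = 0
        for _, v in state_dict.items():
            flat = v.flatten()
            mask = torch.rand(flat.size(0), generator=generator) <= alpha
            counter[off : off + flat.size(0)] += mask.to(counter.dtype)
            off += flat.size(0)

    model = torch.nn.Linear(4, 2)
    n_params = sum(v.numel() for v in model.state_dict().values())
    counter = torch.zeros(n_params, dtype=torch.int32)
    count_subsampled(model.state_dict(), counter, alpha=0.5)
    print(counter)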
@@ -128,7 +128,7 @@ class TopKParams(Sharing):
         with torch.no_grad():
             values, index, offsets = self.extract_top_params()
+            self.model.shared_parameters_counter[index] += 1
             if self.save_shared:
                 shared_params = dict()
                 shared_params["order"] = list(self.model.state_dict().keys())
@@ -184,7 +184,7 @@ class Wavelet(PartialModel):
         with torch.no_grad():
             topk, indices = self.apply_wavelet()
+            self.model.shared_parameters_counter[indices] += 1
             self.model.rewind_accumulation(indices)
             if self.save_shared:
                 shared_params = dict()