Commit a6db3f59 authored by Rishi Sharma

plot data_shared

parent ae89f9ad
@@ -39,26 +39,43 @@ def plot_results(path):
     folders.sort()
     print("Reading folders from: ", path)
     print("Folders: ", folders)
+    bytes_means, bytes_stdevs = {}, {}
     for folder in folders:
         folder_path = os.path.join(path, folder)
         if not os.path.isdir(folder_path):
             continue
         results = []
-        files = os.listdir(folder_path)
-        files = [f for f in files if f.endswith("_results.json")]
-        for f in files:
-            filepath = os.path.join(folder_path, f)
-            with open(filepath, "r") as inf:
-                results.append(json.load(inf))
+        machine_folders = os.listdir(folder_path)
+        for machine_folder in machine_folders:
+            mf_path = os.path.join(folder_path, machine_folder)
+            if not os.path.isdir(mf_path):
+                continue
+            files = os.listdir(mf_path)
+            files = [f for f in files if f.endswith("_results.json")]
+            for f in files:
+                filepath = os.path.join(mf_path, f)
+                with open(filepath, "r") as inf:
+                    results.append(json.load(inf))
         # Plot Training loss
         plt.figure(1)
         means, stdevs, mins, maxs = get_stats([x["train_loss"] for x in results])
         plot(means, stdevs, mins, maxs, "Training Loss", folder, "upper right")
         # Plot Testing loss
         plt.figure(2)
         means, stdevs, mins, maxs = get_stats([x["test_loss"] for x in results])
         plot(means, stdevs, mins, maxs, "Testing Loss", folder, "upper right")
         # Plot Testing Accuracy
         plt.figure(3)
         means, stdevs, mins, maxs = get_stats([x["test_acc"] for x in results])
         plot(means, stdevs, mins, maxs, "Testing Accuracy", folder, "lower right")
+        # Collect total_bytes shared
+        bytes_list = []
+        for x in results:
+            max_key = str(max(list(map(int, x["total_bytes"].keys()))))
+            bytes_list.append({max_key: x["total_bytes"][max_key]})
+        means, stdevs, mins, maxs = get_stats(bytes_list)
+        bytes_means[folder] = list(means.values())[0]
+        bytes_stdevs[folder] = list(stdevs.values())[0]
     plt.figure(1)
     plt.savefig(os.path.join(path, "train_loss.png"))
@@ -66,6 +83,20 @@ def plot_results(path):
     plt.savefig(os.path.join(path, "test_loss.png"))
     plt.figure(3)
     plt.savefig(os.path.join(path, "test_acc.png"))
+    # Plot total_bytes
+    plt.figure(4)
+    plt.title("Data Shared")
+    x_pos = np.arange(len(bytes_means.keys()))
+    plt.bar(
+        x_pos,
+        np.array(list(bytes_means.values())) // (1024 * 1024),
+        yerr=np.array(list(bytes_stdevs.values())) // (1024 * 1024),
+        align="center",
+    )
+    plt.ylabel("Total data shared in MBs")
+    plt.xlabel("Fraction of Model Shared")
+    plt.xticks(x_pos, list(bytes_means.keys()))
+    plt.savefig(os.path.join(path, "data_shared.png"))


 def plot_parameters(path):
......
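For reference, a minimal sketch (not part of the commit) of the on-disk layout the revised loop expects: each experiment folder under path now contains one subfolder per machine, and each machine folder holds the *_results.json files. Folder names, file names, and values below are made up for illustration; only the result keys ("train_loss", "test_loss", "test_acc", "total_bytes") come from the code above.

import json
import os

layout = {
    "alpha_0.1": ["machine0", "machine1"],  # hypothetical experiment/machine folders
}
sample_result = {
    "train_loss": {"1": 2.3, "2": 1.9},
    "test_loss": {"1": 2.4, "2": 2.0},
    "test_acc": {"1": 41.0, "2": 55.2},
    # total_bytes is keyed by iteration; plot_results() keeps only the
    # largest key, i.e. the cumulative bytes at the end of training.
    "total_bytes": {"1": 1048576, "2": 2097152},
}
for experiment, machines in layout.items():
    for machine in machines:
        d = os.path.join("sample_run", experiment, machine)
        os.makedirs(d, exist_ok=True)
        with open(os.path.join(d, "0_results.json"), "w") as of:
            json.dump(sample_result, of)
# plot_results("sample_run")  # would emit train_loss.png ... data_shared.png

Note that the bar plot converts bytes to megabytes with floor division (// (1024 * 1024)), so means and error bars below 1 MB are floored to zero.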
echo "[Cloning leaf repository]"
git clone https://github.com/TalwalkarLab/leaf.git
echo "[Installing unzip]"
sudo apt-get install unzip
cd leaf/data/shakespeare
echo "[Generating non-iid data]"
./preprocess.sh -s niid --sf 1.0 -k 0 -t sample -tf 0.8 --smplseed 10 --spltseed 10
echo "[Data Generated]"
\ No newline at end of file
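For context, per the LEAF repository's documentation: -s niid samples the Shakespeare data in a non-IID manner, --sf 1.0 keeps the full dataset, -k 0 imposes no minimum number of samples per user, -t sample splits train/test within each user's samples, -tf 0.8 puts 80% of each user's samples in the training set, and the two seeds make the sampling and splitting reproducible.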
@@ -22,6 +22,7 @@ class PartialModel(Sharing):
         log_dir,
         alpha=1.0,
         dict_ordered=True,
+        save_shared=False,
     ):
         super().__init__(
             rank, machine_id, communication, mapping, graph, model, dataset, log_dir
@@ -29,10 +30,17 @@ class PartialModel(Sharing):
         self.alpha = alpha
         self.dict_ordered = dict_ordered
         self.communication_round = 0
-        self.folder_path = os.path.join(
-            self.log_dir, "shared_params/{}".format(self.rank)
-        )
-        Path(self.folder_path).mkdir(parents=True, exist_ok=True)
+        self.save_shared = save_shared
+
+        # Only save for 2 procs
+        if rank == 0 or rank == 1:
+            self.save_shared = True
+
+        if self.save_shared:
+            self.folder_path = os.path.join(
+                self.log_dir, "shared_params/{}".format(self.rank)
+            )
+            Path(self.folder_path).mkdir(parents=True, exist_ok=True)

     def extract_top_gradients(self):
         logging.info("Summing up gradients")
@@ -53,23 +61,24 @@ class PartialModel(Sharing):
         with torch.no_grad():
             _, G_topk = self.extract_top_gradients()

-            shared_params = dict()
-            shared_params["order"] = list(self.model.state_dict().keys())
-            shapes = dict()
-            for k, v in self.model.state_dict().items():
-                shapes[k] = list(v.shape)
-            shared_params["shapes"] = shapes
-            shared_params[self.communication_round] = G_topk.tolist()
-            with open(
-                os.path.join(
-                    self.folder_path,
-                    "{}_shared_params.json".format(self.communication_round + 1),
-                ),
-                "w",
-            ) as of:
-                json.dump(shared_params, of)
+            if self.save_shared:
+                shared_params = dict()
+                shared_params["order"] = list(self.model.state_dict().keys())
+                shapes = dict()
+                for k, v in self.model.state_dict().items():
+                    shapes[k] = list(v.shape)
+                shared_params["shapes"] = shapes
+                shared_params[self.communication_round] = G_topk.tolist()
+                with open(
+                    os.path.join(
+                        self.folder_path,
+                        "{}_shared_params.json".format(self.communication_round + 1),
+                    ),
+                    "w",
+                ) as of:
+                    json.dump(shared_params, of)

             logging.info("Extracting topk params")
......
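The JSON files written above are straightforward to post-process. Below is a minimal reader-side sketch (not part of the commit) that counts how many shared top-k entries fall in each parameter tensor; it assumes G_topk holds flat indices into the concatenation of all parameters listed under "order" (an assumption suggested by the saved "shapes" metadata, not confirmed by this diff).

import json
import numpy as np

def shared_count_per_layer(path, communication_round):
    with open(path, "r") as inf:
        shared = json.load(inf)
    # json.dump stringifies the integer round key on write
    indices = np.array(shared[str(communication_round)])
    counts, offset = {}, 0
    for name in shared["order"]:
        size = int(np.prod(shared["shapes"][name]))
        counts[name] = int(((indices >= offset) & (indices < offset + size)).sum())
        offset += size
    return counts

Note that the file name uses communication_round + 1 while the key stored inside the JSON is the un-incremented round.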
@@ -21,7 +21,10 @@ def get_args():
     parser.add_argument("-ps", "--procs_per_machine", type=int, default=1)
     parser.add_argument("-ms", "--machines", type=int, default=1)
     parser.add_argument(
-        "-ld", "--log_dir", type=str, default="./{}".format(datetime.datetime.now().isoformat(timespec='minutes'))
+        "-ld",
+        "--log_dir",
+        type=str,
+        default="./{}".format(datetime.datetime.now().isoformat(timespec="minutes")),
     )
     parser.add_argument("-is", "--iterations", type=int, default=1)
     parser.add_argument("-cf", "--config_file", type=str, default="config.ini")
......
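For reference, the reformatted log_dir default resolves to an ISO-8601 timestamp truncated to minutes; the value shown below is illustrative only.

>>> import datetime
>>> "./{}".format(datetime.datetime.now().isoformat(timespec="minutes"))
'./2021-12-07T14:32'  # hypothetical run time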