Skip to content
Snippets Groups Projects
Commit a6db3f59 authored by Rishi Sharma's avatar Rishi Sharma
Browse files

plot data_shared

parent ae89f9ad
No related branches found
No related tags found
No related merge requests found
...@@ -39,26 +39,43 @@ def plot_results(path): ...@@ -39,26 +39,43 @@ def plot_results(path):
folders.sort() folders.sort()
print("Reading folders from: ", path) print("Reading folders from: ", path)
print("Folders: ", folders) print("Folders: ", folders)
bytes_means, bytes_stdevs = {}, {}
for folder in folders: for folder in folders:
folder_path = os.path.join(path, folder) folder_path = os.path.join(path, folder)
if not os.path.isdir(folder_path): if not os.path.isdir(folder_path):
continue continue
results = [] results = []
files = os.listdir(folder_path) machine_folders = os.listdir(folder_path)
files = [f for f in files if f.endswith("_results.json")] for machine_folder in machine_folders:
for f in files: mf_path = os.path.join(folder_path, machine_folder)
filepath = os.path.join(folder_path, f) if not os.path.isdir(mf_path):
with open(filepath, "r") as inf: continue
results.append(json.load(inf)) files = os.listdir(mf_path)
files = [f for f in files if f.endswith("_results.json")]
for f in files:
filepath = os.path.join(mf_path, f)
with open(filepath, "r") as inf:
results.append(json.load(inf))
# Plot Training loss
plt.figure(1) plt.figure(1)
means, stdevs, mins, maxs = get_stats([x["train_loss"] for x in results]) means, stdevs, mins, maxs = get_stats([x["train_loss"] for x in results])
plot(means, stdevs, mins, maxs, "Training Loss", folder, "upper right") plot(means, stdevs, mins, maxs, "Training Loss", folder, "upper right")
# Plot Testing loss
plt.figure(2) plt.figure(2)
means, stdevs, mins, maxs = get_stats([x["test_loss"] for x in results]) means, stdevs, mins, maxs = get_stats([x["test_loss"] for x in results])
plot(means, stdevs, mins, maxs, "Testing Loss", folder, "upper right") plot(means, stdevs, mins, maxs, "Testing Loss", folder, "upper right")
# Plot Testing Accuracy
plt.figure(3) plt.figure(3)
means, stdevs, mins, maxs = get_stats([x["test_acc"] for x in results]) means, stdevs, mins, maxs = get_stats([x["test_acc"] for x in results])
plot(means, stdevs, mins, maxs, "Testing Accuracy", folder, "lower right") plot(means, stdevs, mins, maxs, "Testing Accuracy", folder, "lower right")
# Collect total_bytes shared
bytes_list = []
for x in results:
max_key = str(max(list(map(int, x["total_bytes"].keys()))))
bytes_list.append({max_key: x["total_bytes"][max_key]})
means, stdevs, mins, maxs = get_stats(bytes_list)
bytes_means[folder] = list(means.values())[0]
bytes_stdevs[folder] = list(stdevs.values())[0]
plt.figure(1) plt.figure(1)
plt.savefig(os.path.join(path, "train_loss.png")) plt.savefig(os.path.join(path, "train_loss.png"))
...@@ -66,6 +83,20 @@ def plot_results(path): ...@@ -66,6 +83,20 @@ def plot_results(path):
plt.savefig(os.path.join(path, "test_loss.png")) plt.savefig(os.path.join(path, "test_loss.png"))
plt.figure(3) plt.figure(3)
plt.savefig(os.path.join(path, "test_acc.png")) plt.savefig(os.path.join(path, "test_acc.png"))
# Plot total_bytes
plt.figure(4)
plt.title("Data Shared")
x_pos = np.arange(len(bytes_means.keys()))
plt.bar(
x_pos,
np.array(list(bytes_means.values())) // (1024 * 1024),
yerr=np.array(list(bytes_stdevs.values())) // (1024 * 1024),
align="center",
)
plt.ylabel("Total data shared in MBs")
plt.xlabel("Fraction of Model Shared")
plt.xticks(x_pos, list(bytes_means.keys()))
plt.savefig(os.path.join(path, "data_shared.png"))
def plot_parameters(path): def plot_parameters(path):
......
#!/usr/bin/env bash
# Fetch the LEAF benchmark repository and generate the non-IID Shakespeare
# dataset split used for federated-learning experiments.
# Fixed --smplseed/--spltseed values make the generated split reproducible.
set -euo pipefail  # abort on any failed step (e.g. clone or cd failure)

echo "[Cloning leaf repository]"
git clone https://github.com/TalwalkarLab/leaf.git

echo "[Installing unzip]"
# -y: non-interactive install so an unattended run does not block on a prompt
sudo apt-get install -y unzip

cd leaf/data/shakespeare

echo "[Generating non-iid data]"
# Flags per LEAF's preprocess.sh: -s niid (non-IID sampling), --sf 1.0
# (full dataset), -k 0 (no minimum samples per user), -t sample / -tf 0.8
# (80/20 train/test split within each user's samples), fixed seeds for
# reproducibility.
./preprocess.sh -s niid --sf 1.0 -k 0 -t sample -tf 0.8 --smplseed 10 --spltseed 10

echo "[Data Generated]"
\ No newline at end of file
...@@ -22,6 +22,7 @@ class PartialModel(Sharing): ...@@ -22,6 +22,7 @@ class PartialModel(Sharing):
log_dir, log_dir,
alpha=1.0, alpha=1.0,
dict_ordered=True, dict_ordered=True,
save_shared=False,
): ):
super().__init__( super().__init__(
rank, machine_id, communication, mapping, graph, model, dataset, log_dir rank, machine_id, communication, mapping, graph, model, dataset, log_dir
...@@ -29,10 +30,17 @@ class PartialModel(Sharing): ...@@ -29,10 +30,17 @@ class PartialModel(Sharing):
self.alpha = alpha self.alpha = alpha
self.dict_ordered = dict_ordered self.dict_ordered = dict_ordered
self.communication_round = 0 self.communication_round = 0
self.folder_path = os.path.join( self.save_shared = save_shared
self.log_dir, "shared_params/{}".format(self.rank)
) # Only save for 2 procs
Path(self.folder_path).mkdir(parents=True, exist_ok=True) if rank == 0 or rank == 1:
self.save_shared = True
if self.save_shared:
self.folder_path = os.path.join(
self.log_dir, "shared_params/{}".format(self.rank)
)
Path(self.folder_path).mkdir(parents=True, exist_ok=True)
def extract_top_gradients(self): def extract_top_gradients(self):
logging.info("Summing up gradients") logging.info("Summing up gradients")
...@@ -53,23 +61,24 @@ class PartialModel(Sharing): ...@@ -53,23 +61,24 @@ class PartialModel(Sharing):
with torch.no_grad(): with torch.no_grad():
_, G_topk = self.extract_top_gradients() _, G_topk = self.extract_top_gradients()
shared_params = dict() if self.save_shared:
shared_params["order"] = list(self.model.state_dict().keys()) shared_params = dict()
shapes = dict() shared_params["order"] = list(self.model.state_dict().keys())
for k, v in self.model.state_dict().items(): shapes = dict()
shapes[k] = list(v.shape) for k, v in self.model.state_dict().items():
shared_params["shapes"] = shapes shapes[k] = list(v.shape)
shared_params["shapes"] = shapes
shared_params[self.communication_round] = G_topk.tolist()
shared_params[self.communication_round] = G_topk.tolist()
with open(
os.path.join( with open(
self.folder_path, os.path.join(
"{}_shared_params.json".format(self.communication_round + 1), self.folder_path,
), "{}_shared_params.json".format(self.communication_round + 1),
"w", ),
) as of: "w",
json.dump(shared_params, of) ) as of:
json.dump(shared_params, of)
logging.info("Extracting topk params") logging.info("Extracting topk params")
......
...@@ -21,7 +21,10 @@ def get_args(): ...@@ -21,7 +21,10 @@ def get_args():
parser.add_argument("-ps", "--procs_per_machine", type=int, default=1) parser.add_argument("-ps", "--procs_per_machine", type=int, default=1)
parser.add_argument("-ms", "--machines", type=int, default=1) parser.add_argument("-ms", "--machines", type=int, default=1)
parser.add_argument( parser.add_argument(
"-ld", "--log_dir", type=str, default="./{}".format(datetime.datetime.now().isoformat(timespec='minutes')) "-ld",
"--log_dir",
type=str,
default="./{}".format(datetime.datetime.now().isoformat(timespec="minutes")),
) )
parser.add_argument("-is", "--iterations", type=int, default=1) parser.add_argument("-is", "--iterations", type=int, default=1)
parser.add_argument("-cf", "--config_file", type=str, default="config.ini") parser.add_argument("-cf", "--config_file", type=str, default="config.ini")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment