Commit 7575bd98 authored by Rishi Sharma

Separate files shared_params

parent 3c3007a0
{
"0": "10.90.41.130",
"1": "10.90.41.131",
"2": "10.90.41.132",
"3": "10.90.41.133"
}
\ No newline at end of file
{
"0": "10.90.41.129",
"1": "10.90.41.130",
"2": "10.90.41.131",
"3": "10.90.41.132",
"4": "10.90.41.133"
}
\ No newline at end of file
@@ -1,5 +1,6 @@
 import json
 import logging
 import os
+from pathlib import Path
 import numpy
 import torch
@@ -28,6 +29,10 @@ class PartialModel(Sharing):
         self.alpha = alpha
         self.dict_ordered = dict_ordered
         self.communication_round = 0
+        self.folder_path = os.path.join(
+            self.log_dir, "shared_params/{}".format(self.rank)
+        )
+        Path(self.folder_path).mkdir(parents=True, exist_ok=True)
 
     def extract_top_gradients(self):
         logging.info("Summing up gradients")
@@ -48,26 +53,20 @@
         with torch.no_grad():
             _, G_topk = self.extract_top_gradients()
-            if self.communication_round:
-                with open(
-                    os.path.join(
-                        self.log_dir, "{}_shared_params.json".format(self.rank)
-                    ),
-                    "r",
-                ) as inf:
-                    shared_params = json.load(inf)
-            else:
-                shared_params = dict()
-                shared_params["order"] = list(self.model.state_dict().keys())
-                shapes = dict()
-                for k, v in self.model.state_dict().items():
-                    shapes[k] = list(v.shape)
-                shared_params["shapes"] = shapes
+            shared_params = dict()
+            shared_params["order"] = list(self.model.state_dict().keys())
+            shapes = dict()
+            for k, v in self.model.state_dict().items():
+                shapes[k] = list(v.shape)
+            shared_params["shapes"] = shapes
             shared_params[self.communication_round] = G_topk.tolist()
             with open(
-                os.path.join(self.log_dir, "{}_shared_params.json".format(self.rank)),
+                os.path.join(
+                    self.folder_path,
+                    "{}_shared_params.json".format(self.communication_round + 1),
+                ),
                 "w",
             ) as of:
                 json.dump(shared_params, of)
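extract_top_gradients() is called in the hunk above, but its body lies outside the diff; only the "Summing up gradients" log line is visible. The _, G_topk unpacking suggests it returns a (values, indices) pair, with self.alpha controlling the fraction of entries kept. A hedged sketch of what such a top-alpha selection could look like, written here as a standalone function; the body is a guess, not the method from this repository:

import torch

def extract_top_gradients(model, alpha):
    # Sketch only -- the real method body is not shown in this diff.
    # Flatten all parameter gradients into one vector and keep the
    # alpha-fraction of entries with the largest magnitude.
    G = torch.cat([p.grad.abs().flatten() for p in model.parameters()])
    k = round(alpha * G.numel())
    return torch.topk(G, k, sorted=False)  # (values, indices)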
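The serialization change itself: previously every round re-read and re-wrote a single growing <rank>_shared_params.json under log_dir, so the per-round I/O cost grew with the number of rounds. After this commit each round writes its own <round+1>_shared_params.json under shared_params/<rank>/, repeating the "order" and "shapes" metadata in every file. To recover the old single-dict view, for example for offline analysis, the per-round files can be merged back together; a sketch, with merge_shared_params as a hypothetical post-processing helper that is not part of the commit:

import glob
import json
import os

def merge_shared_params(folder_path):
    # Hypothetical helper: folder_path is e.g. <log_dir>/shared_params/<rank>.
    merged = {}
    for fname in sorted(
        glob.glob(os.path.join(folder_path, "*_shared_params.json")),
        key=lambda p: int(os.path.basename(p).split("_")[0]),
    ):
        with open(fname) as f:
            # Each file carries "order", "shapes", and one round's indices.
            merged.update(json.load(f))
    return merged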