Commit dd797778 authored by Rishi Sharma

Fix writes of acc, loss; Log which params are shared

parent 5985bb90
@@ -219,11 +219,17 @@ class Node:
                 os.path.join(log_dir, "{}_test_acc.png".format(self.rank)),
             )
-            with open(os.path.join(log_dir, "{}_train_loss.json"), "w") as of:
+            with open(
+                os.path.join(log_dir, "{}_train_loss.json".format(self.rank)), "w"
+            ) as of:
                 json.dump(self.train_loss, of)
-            with open(os.path.join(log_dir, "{}_test_loss.json"), "w") as of:
+            with open(
+                os.path.join(log_dir, "{}_test_loss.json".format(self.rank)), "w"
+            ) as of:
                 json.dump(self.test_loss, of)
-            with open(os.path.join(log_dir, "{}_test_acc.json"), "w") as of:
+            with open(
+                os.path.join(log_dir, "{}_test_acc.json".format(self.rank)), "w"
+            ) as of:
                 json.dump(self.test_acc, of)
         self.communication.disconnect_neighbors()
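The first hunk fixes a real bug: the old code joined the literal string "{}_train_loss.json" into the path without calling .format(self.rank), so every node wrote its losses to the same file named {}_train_loss.json and clobbered the others. A minimal standalone sketch of the difference (illustrative values only):

    import os

    rank = 3
    log_dir = "/tmp/logs"

    # Old behaviour: the placeholder survives into the path, so every
    # rank collides on the same literal "{}" file.
    print(os.path.join(log_dir, "{}_train_loss.json"))
    # -> /tmp/logs/{}_train_loss.json

    # Fixed behaviour: each rank gets its own file.
    print(os.path.join(log_dir, "{}_train_loss.json".format(rank)))
    # -> /tmp/logs/3_train_loss.json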
 import json
 import logging
+import os
 import numpy
 import torch
@@ -17,14 +18,16 @@ class PartialModel(Sharing):
         graph,
         model,
         dataset,
+        log_dir,
         alpha=1.0,
         dict_ordered=True,
     ):
         super().__init__(
-            rank, machine_id, communication, mapping, graph, model, dataset
+            rank, machine_id, communication, mapping, graph, model, dataset, log_dir
         )
         self.alpha = alpha
         self.dict_ordered = dict_ordered
+        self.communication_round = 0

     def extract_top_gradients(self):
         logging.info("Summing up gradients")
@@ -44,6 +47,31 @@ class PartialModel(Sharing):
     def serialized_model(self):
         with torch.no_grad():
             _, G_topk = self.extract_top_gradients()
+
+            if self.communication_round:
+                with open(
+                    os.path.join(
+                        self.log_dir, "{}_shared_params.json".format(self.rank)
+                    ),
+                    "r",
+                ) as inf:
+                    shared_params = json.load(inf)
+            else:
+                shared_params = dict()
+                shared_params["order"] = list(self.model.state_dict().keys())
+                shapes = dict()
+                for k, v in self.model.state_dict().items():
+                    shapes[k] = list(v.shape)
+                shared_params["shapes"] = shapes
+
+            shared_params[self.communication_round] = G_topk.tolist()
+
+            with open(
+                os.path.join(self.log_dir, "{}_shared_params.json".format(self.rank)),
+                "w",
+            ) as of:
+                json.dump(shared_params, of)
+
             logging.info("Extracting topk params")
             tensors_to_cat = [v.data.flatten() for v in self.model.parameters()]
...
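The new block appends, for every communication round, the flat indices of the top-k gradient entries (G_topk) under the round number as key; on round 0 it also records the parameter order and per-tensor shapes so those flat indices can later be mapped back to named tensors. A sketch of how such a log could be read back, assuming only the JSON layout written above (shared_param_names is a hypothetical helper, not part of this repository):

    import json
    import os

    import numpy as np

    def shared_param_names(log_dir, rank, round_no):
        # Hypothetical helper: maps the flat indices logged for one round
        # back to parameter names, using the "order" and "shapes" entries.
        with open(
            os.path.join(log_dir, "{}_shared_params.json".format(rank))
        ) as inf:
            shared = json.load(inf)

        # Offset of each parameter tensor inside the flattened vector.
        sizes = [int(np.prod(shared["shapes"][name])) for name in shared["order"]]
        offsets = np.cumsum([0] + sizes)

        names = []
        for idx in shared[str(round_no)]:  # json.dump stringifies int keys
            pos = np.searchsorted(offsets, idx, side="right") - 1
            names.append(shared["order"][pos])
        return names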
@@ -11,7 +11,9 @@ class Sharing:
     API defining who to share with and what, and what to do on receiving
     """

-    def __init__(self, rank, machine_id, communication, mapping, graph, model, dataset):
+    def __init__(
+        self, rank, machine_id, communication, mapping, graph, model, dataset, log_dir
+    ):
         self.rank = rank
         self.machine_id = machine_id
         self.uid = mapping.get_uid(rank, machine_id)
@@ -20,6 +22,7 @@ class Sharing:
         self.graph = graph
         self.model = model
         self.dataset = dataset
+        self.log_dir = log_dir
         self.peer_deques = dict()

         my_neighbors = self.graph.neighbors(self.uid)
...
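The last two hunks plumb log_dir through the constructor chain: the node passes it to PartialModel, whose super().__init__ call forwards it to Sharing, which stores it for any sharing strategy to use. A stripped-down, runnable sketch of that pattern (stub classes with illustrative names, not the repository's real signatures):

    class Sharing:
        def __init__(self, rank, log_dir):
            self.rank = rank
            self.log_dir = log_dir  # now available to every sharing strategy

    class PartialModel(Sharing):
        def __init__(self, rank, log_dir, alpha=1.0):
            super().__init__(rank, log_dir)  # forward log_dir to the base class
            self.alpha = alpha
            self.communication_round = 0

    sharing = PartialModel(rank=0, log_dir="/tmp/logs", alpha=0.1)
    print(sharing.log_dir)  # -> /tmp/logs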