diff --git a/src/decentralizepy/node/Node.py b/src/decentralizepy/node/Node.py
index ae2aff9ade98e7623e29f7268c4f202c189fce2f..90140e802d56c0a20323f815f158ad80f8c5acf0 100644
--- a/src/decentralizepy/node/Node.py
+++ b/src/decentralizepy/node/Node.py
@@ -219,11 +219,17 @@ class Node:
             os.path.join(log_dir, "{}_test_acc.png".format(self.rank)),
         )
 
-        with open(os.path.join(log_dir, "{}_train_loss.json"), "w") as of:
+        with open(
+            os.path.join(log_dir, "{}_train_loss.json".format(self.rank)), "w"
+        ) as of:
             json.dump(self.train_loss, of)
-        with open(os.path.join(log_dir, "{}_test_loss.json"), "w") as of:
+        with open(
+            os.path.join(log_dir, "{}_test_loss.json".format(self.rank)), "w"
+        ) as of:
             json.dump(self.test_loss, of)
-        with open(os.path.join(log_dir, "{}_test_acc.json"), "w") as of:
+        with open(
+            os.path.join(log_dir, "{}_test_acc.json".format(self.rank)), "w"
+        ) as of:
             json.dump(self.test_acc, of)
 
         self.communication.disconnect_neighbors()
diff --git a/src/decentralizepy/sharing/PartialModel.py b/src/decentralizepy/sharing/PartialModel.py
index 0b2f60675314ba3a20703f510287456bd1fa11d1..ac175e337e7c1d9c98dbf35775adec494ba2c55f 100644
--- a/src/decentralizepy/sharing/PartialModel.py
+++ b/src/decentralizepy/sharing/PartialModel.py
@@ -1,5 +1,6 @@
 import json
 import logging
+import os
 
 import numpy
 import torch
@@ -17,14 +18,16 @@ class PartialModel(Sharing):
         graph,
         model,
         dataset,
+        log_dir,
         alpha=1.0,
         dict_ordered=True,
     ):
         super().__init__(
-            rank, machine_id, communication, mapping, graph, model, dataset
+            rank, machine_id, communication, mapping, graph, model, dataset, log_dir
         )
         self.alpha = alpha
         self.dict_ordered = dict_ordered
+        self.communication_round = 0
 
     def extract_top_gradients(self):
         logging.info("Summing up gradients")
@@ -44,6 +47,31 @@ class PartialModel(Sharing):
     def serialized_model(self):
         with torch.no_grad():
             _, G_topk = self.extract_top_gradients()
+
+            if self.communication_round:
+                with open(
+                    os.path.join(
+                        self.log_dir, "{}_shared_params.json".format(self.rank)
+                    ),
+                    "r",
+                ) as inf:
+                    shared_params = json.load(inf)
+            else:
+                shared_params = dict()
+                shared_params["order"] = list(self.model.state_dict().keys())
+                shapes = dict()
+                for k, v in self.model.state_dict().items():
+                    shapes[k] = list(v.shape)
+                shared_params["shapes"] = shapes
+
+            shared_params[self.communication_round] = G_topk.tolist()
+
+            with open(
+                os.path.join(self.log_dir, "{}_shared_params.json".format(self.rank)),
+                "w",
+            ) as of:
+                json.dump(shared_params, of)
+
             logging.info("Extracting topk params")
             tensors_to_cat = [v.data.flatten() for v in self.model.parameters()]
             T = torch.cat(tensors_to_cat, dim=0)
diff --git a/src/decentralizepy/sharing/Sharing.py b/src/decentralizepy/sharing/Sharing.py
index 186c2d2f69567204a465837a636a9b43749ed11a..03e80fd52a8e4b14c381dd3b463a926e94741041 100644
--- a/src/decentralizepy/sharing/Sharing.py
+++ b/src/decentralizepy/sharing/Sharing.py
@@ -11,7 +11,9 @@ class Sharing:
     API defining who to share with and what, and what to do on receiving
     """
 
-    def __init__(self, rank, machine_id, communication, mapping, graph, model, dataset):
+    def __init__(
+        self, rank, machine_id, communication, mapping, graph, model, dataset, log_dir
+    ):
         self.rank = rank
         self.machine_id = machine_id
         self.uid = mapping.get_uid(rank, machine_id)
@@ -20,6 +22,7 @@ class Sharing:
         self.graph = graph
         self.model = model
         self.dataset = dataset
+        self.log_dir = log_dir
         self.peer_deques = dict()
 
         my_neighbors = self.graph.neighbors(self.uid)
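
Note on the new `{rank}_shared_params.json` log written by `PartialModel.serialized_model`: on the first communication round the file is seeded with two metadata entries (`order`, the state-dict key order, and `shapes`, the per-tensor shapes), and every round appends its list of top-k flattened gradient indices under the round number. Below is a minimal sketch of how this log could be consumed offline; `shared_index_frequencies` is a hypothetical helper for illustration, not part of this patch.

```python
import json
import os
from collections import Counter


def shared_index_frequencies(log_dir, rank):
    """Count how often each flattened parameter index was shared by a rank.

    Reads the {rank}_shared_params.json file produced by
    PartialModel.serialized_model.
    """
    with open(
        os.path.join(log_dir, "{}_shared_params.json".format(rank)), "r"
    ) as inf:
        shared_params = json.load(inf)

    counts = Counter()
    for key, value in shared_params.items():
        # json.dump stringifies the integer round keys, so every key is a
        # string after reloading; skip the two metadata entries.
        if key in ("order", "shapes"):
            continue
        counts.update(value)  # one list of top-k indices per round
    return counts
```

Since the logged indices point into the concatenation of all flattened parameters, `order` and `shapes` together suffice to map each index back to the tensor it belongs to when analyzing what the node chose to share.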