From 327519cce88383b30c75191263f16054c5dbb3ad Mon Sep 17 00:00:00 2001 From: Rishi Sharma <rishi.sharma@epfl.ch> Date: Sun, 9 Jan 2022 17:53:18 +0100 Subject: [PATCH] Debug TCP --- src/decentralizepy/sharing/GrowingAlpha.py | 23 ++++------------------ src/decentralizepy/sharing/PartialModel.py | 8 ++++++++ 2 files changed, 12 insertions(+), 19 deletions(-) diff --git a/src/decentralizepy/sharing/GrowingAlpha.py b/src/decentralizepy/sharing/GrowingAlpha.py index f587cf5..1fb8410 100644 --- a/src/decentralizepy/sharing/GrowingAlpha.py +++ b/src/decentralizepy/sharing/GrowingAlpha.py @@ -1,7 +1,6 @@ import logging from decentralizepy.sharing.PartialModel import PartialModel -from decentralizepy.sharing.Sharing import Sharing class GrowingAlpha(PartialModel): @@ -18,9 +17,9 @@ class GrowingAlpha(PartialModel): init_alpha=0.0, max_alpha=1.0, k=10, - metadata_cap=0.6, dict_ordered=True, save_shared=False, + metadata_cap=1.0, ): super().__init__( rank, @@ -34,34 +33,20 @@ class GrowingAlpha(PartialModel): init_alpha, dict_ordered, save_shared, + metadata_cap, ) self.init_alpha = init_alpha self.max_alpha = max_alpha self.k = k - self.metadata_cap = metadata_cap - self.base = None def step(self): if (self.communication_round + 1) % self.k == 0: self.alpha += (self.max_alpha - self.init_alpha) / self.k + self.alpha = min(self.alpha, 1.00) if self.alpha == 0.0: logging.info("Not sending/receiving data (alpha=0.0)") self.communication_round += 1 return - if self.alpha > self.metadata_cap: - if self.base == None: - self.base = Sharing( - self.rank, - self.machine_id, - self.communication, - self.mapping, - self.graph, - self.model, - self.dataset, - ) - self.base.communication_round = self.communication_round - self.base.step() - else: - super().step() + super().step() diff --git a/src/decentralizepy/sharing/PartialModel.py b/src/decentralizepy/sharing/PartialModel.py index 1df11c4..0836d78 100644 --- a/src/decentralizepy/sharing/PartialModel.py +++ b/src/decentralizepy/sharing/PartialModel.py @@ -22,6 +22,7 @@ class PartialModel(Sharing): alpha=1.0, dict_ordered=True, save_shared=False, + metadata_cap=1.0, ): """ Constructor @@ -51,6 +52,7 @@ class PartialModel(Sharing): self.alpha = alpha self.dict_ordered = dict_ordered self.save_shared = save_shared + self.metadata_cap = metadata_cap # Only save for 2 procs if rank == 0 or rank == 1: @@ -78,6 +80,9 @@ class PartialModel(Sharing): ) def serialized_model(self): + if self.alpha > self.metadata_cap: # Share fully + return super().serialized_model() + with torch.no_grad(): _, G_topk = self.extract_top_gradients() @@ -129,6 +134,9 @@ class PartialModel(Sharing): return m def deserialized_model(self, m): + if self.alpha > self.metadata_cap: # Share fully + return super().deserialized_model(m) + with torch.no_grad(): state_dict = self.model.state_dict() -- GitLab