From 327519cce88383b30c75191263f16054c5dbb3ad Mon Sep 17 00:00:00 2001
From: Rishi Sharma <rishi.sharma@epfl.ch>
Date: Sun, 9 Jan 2022 17:53:18 +0100
Subject: [PATCH] Debug TCP: fold metadata_cap handling into PartialModel

Move the metadata_cap logic from GrowingAlpha into PartialModel:
serialized_model() and deserialized_model() now fall back to the full
Sharing path whenever alpha exceeds metadata_cap, so GrowingAlpha no
longer needs to lazily construct a separate Sharing instance and its
step() simply delegates to PartialModel. The default cap moves from
0.6 (old GrowingAlpha default) to 1.0, leaving the fallback off unless
a smaller cap is configured, and the growing alpha is clamped so it
never exceeds max_alpha.

---
 src/decentralizepy/sharing/GrowingAlpha.py | 23 ++++------------------
 src/decentralizepy/sharing/PartialModel.py |  8 ++++++++
 2 files changed, 12 insertions(+), 19 deletions(-)
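
Not part of the commit: a minimal, self-contained sketch of the alpha
schedule GrowingAlpha.step() follows after this patch, assuming the
constructor defaults visible below (init_alpha=0.0, max_alpha=1.0,
k=10). No decentralizepy imports are needed to run it.

    init_alpha, max_alpha, k = 0.0, 1.0, 10

    alpha = init_alpha
    for communication_round in range(100):
        # Every k-th round, alpha grows by (max_alpha - init_alpha) / k,
        # clamped at max_alpha, so it ramps from init_alpha to max_alpha
        # over k * k rounds and then stays there.
        if (communication_round + 1) % k == 0:
            alpha += (max_alpha - init_alpha) / k
            alpha = min(alpha, max_alpha)
            print(f"round {communication_round + 1:3d}: alpha = {alpha:.1f}")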

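Also not part of the commit: an illustrative stand-in for the new early
return in PartialModel.serialized_model() (Base and Partial are invented
names for this sketch, not decentralizepy classes). Under the defaults
above, with metadata_cap=1.0 and alpha clamped at max_alpha=1.0, the
full-sharing branch never fires; it takes effect only when a smaller
cap is configured.

    class Base:
        def serialized_model(self):
            # Stands in for Sharing.serialized_model(): the full model.
            return "full state_dict"

    class Partial(Base):
        def __init__(self, alpha, metadata_cap=1.0):
            self.alpha = alpha
            self.metadata_cap = metadata_cap

        def serialized_model(self):
            if self.alpha > self.metadata_cap:  # share fully past the cap
                return super().serialized_model()
            return f"top {self.alpha:.0%} of gradients + index metadata"

    print(Partial(alpha=0.3).serialized_model())                    # partial
    print(Partial(alpha=1.0, metadata_cap=0.6).serialized_model())  # full
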
diff --git a/src/decentralizepy/sharing/GrowingAlpha.py b/src/decentralizepy/sharing/GrowingAlpha.py
index f587cf5..1fb8410 100644
--- a/src/decentralizepy/sharing/GrowingAlpha.py
+++ b/src/decentralizepy/sharing/GrowingAlpha.py
@@ -1,7 +1,6 @@
 import logging
 
 from decentralizepy.sharing.PartialModel import PartialModel
-from decentralizepy.sharing.Sharing import Sharing
 
 
 class GrowingAlpha(PartialModel):
@@ -18,9 +17,9 @@ class GrowingAlpha(PartialModel):
         init_alpha=0.0,
         max_alpha=1.0,
         k=10,
-        metadata_cap=0.6,
         dict_ordered=True,
         save_shared=False,
+        metadata_cap=1.0,
     ):
         super().__init__(
             rank,
@@ -34,34 +33,20 @@ class GrowingAlpha(PartialModel):
             init_alpha,
             dict_ordered,
             save_shared,
+            metadata_cap,
         )
         self.init_alpha = init_alpha
         self.max_alpha = max_alpha
         self.k = k
-        self.metadata_cap = metadata_cap
-        self.base = None
 
     def step(self):
         if (self.communication_round + 1) % self.k == 0:
             self.alpha += (self.max_alpha - self.init_alpha) / self.k
+            self.alpha = min(self.alpha, self.max_alpha)  # never exceed max_alpha
 
         if self.alpha == 0.0:
             logging.info("Not sending/receiving data (alpha=0.0)")
             self.communication_round += 1
             return
 
-        if self.alpha > self.metadata_cap:
-            if self.base == None:
-                self.base = Sharing(
-                    self.rank,
-                    self.machine_id,
-                    self.communication,
-                    self.mapping,
-                    self.graph,
-                    self.model,
-                    self.dataset,
-                )
-                self.base.communication_round = self.communication_round
-            self.base.step()
-        else:
-            super().step()
+        super().step()
diff --git a/src/decentralizepy/sharing/PartialModel.py b/src/decentralizepy/sharing/PartialModel.py
index 1df11c4..0836d78 100644
--- a/src/decentralizepy/sharing/PartialModel.py
+++ b/src/decentralizepy/sharing/PartialModel.py
@@ -22,6 +22,7 @@ class PartialModel(Sharing):
         alpha=1.0,
         dict_ordered=True,
         save_shared=False,
+        metadata_cap=1.0,
     ):
         """
         Constructor
@@ -51,6 +52,7 @@ class PartialModel(Sharing):
         self.alpha = alpha
         self.dict_ordered = dict_ordered
         self.save_shared = save_shared
+        self.metadata_cap = metadata_cap
 
         # Only save for 2 procs
         if rank == 0 or rank == 1:
@@ -78,6 +80,9 @@ class PartialModel(Sharing):
         )
 
     def serialized_model(self):
+        if self.alpha > self.metadata_cap:  # Share fully
+            return super().serialized_model()
+
         with torch.no_grad():
             _, G_topk = self.extract_top_gradients()
 
@@ -129,6 +134,9 @@ class PartialModel(Sharing):
             return m
 
     def deserialized_model(self, m):
+        if self.alpha > self.metadata_cap:  # Share fully
+            return super().deserialized_model(m)
+
         with torch.no_grad():
             state_dict = self.model.state_dict()
 
-- 
GitLab