From da846d2a1f038e136112325b7490d3712c23834a Mon Sep 17 00:00:00 2001
From: Rishi Sharma <rishi.sharma@epfl.ch>
Date: Thu, 28 Apr 2022 22:30:40 +0200
Subject: [PATCH] Update random alpha

---
 src/decentralizepy/sharing/PartialModel.py    |  2 +
 src/decentralizepy/sharing/RandomAlpha.py     | 19 +++-
 .../sharing/RandomAlphaWavelet.py             | 93 +++++++++++++++++++
 src/decentralizepy/sharing/Wavelet.py         |  6 +-
 4 files changed, 114 insertions(+), 6 deletions(-)
 create mode 100644 src/decentralizepy/sharing/RandomAlphaWavelet.py

diff --git a/src/decentralizepy/sharing/PartialModel.py b/src/decentralizepy/sharing/PartialModel.py
index 7d8e7fc..cec6685 100644
--- a/src/decentralizepy/sharing/PartialModel.py
+++ b/src/decentralizepy/sharing/PartialModel.py
@@ -201,6 +201,8 @@ class PartialModel(Sharing):
             if not self.dict_ordered:
                 raise NotImplementedError
 
+            m["alpha"] = self.alpha
+
             m["indices"] = G_topk.numpy().astype(np.int32)
 
             m["params"] = T_topk.numpy()
diff --git a/src/decentralizepy/sharing/RandomAlpha.py b/src/decentralizepy/sharing/RandomAlpha.py
index c81e233..a91ba1d 100644
--- a/src/decentralizepy/sharing/RandomAlpha.py
+++ b/src/decentralizepy/sharing/RandomAlpha.py
@@ -1,6 +1,7 @@
 import random
 
 from decentralizepy.sharing.PartialModel import PartialModel
+from decentralizepy.utils import identity
 
 
 class RandomAlpha(PartialModel):
@@ -19,9 +20,14 @@ class RandomAlpha(PartialModel):
         model,
         dataset,
         log_dir,
+        alpha_list="[0.1, 0.2, 0.3, 0.4, 1.0]",
         dict_ordered=True,
         save_shared=False,
         metadata_cap=1.0,
+        accumulation=False,
+        save_accumulated="",
+        change_transformer=identity,
+        accumulate_averaging_changes=False,
     ):
         """
         Constructor
@@ -65,6 +71,14 @@ class RandomAlpha(PartialModel):
             dict_ordered,
             save_shared,
             metadata_cap,
+            accumulation,
+            save_accumulated,
+            change_transformer,
+            accumulate_averaging_changes
+        )
+        self.alpha_list = eval(alpha_list)
+        random.seed(
+            self.mapping.get_uid(self.rank, self.machine_id)
         )
 
     def step(self):
@@ -72,8 +86,5 @@ class RandomAlpha(PartialModel):
         Perform a sharing step. Implements D-PSGD with alpha randomly chosen.
 
         """
-        random.seed(
-            self.mapping.get_uid(self.rank, self.machine_id) + self.communication_round
-        )
-        self.alpha = random.randint(1, 7) / 10.0
+        self.alpha = random.choice(self.alpha_list)
         super().step()
diff --git a/src/decentralizepy/sharing/RandomAlphaWavelet.py b/src/decentralizepy/sharing/RandomAlphaWavelet.py
new file mode 100644
index 0000000..61c17bb
--- /dev/null
+++ b/src/decentralizepy/sharing/RandomAlphaWavelet.py
@@ -0,0 +1,93 @@
+import random
+
+from decentralizepy.sharing.Wavelet import Wavelet
+
+
+class RandomAlpha(Wavelet):
+    """
+    This class implements the partial model sharing with a random alpha each iteration.
+
+    """
+
+    def __init__(
+        self,
+        rank,
+        machine_id,
+        communication,
+        mapping,
+        graph,
+        model,
+        dataset,
+        log_dir,
+        alpha_list="[0.1, 0.2, 0.3, 0.4, 1.0]",
+        dict_ordered=True,
+        save_shared=False,
+        metadata_cap=1.0,
+        wavelet="haar",
+        level=4,
+        change_based_selection=True,
+        save_accumulated="",
+        accumulation=False,
+        accumulate_averaging_changes=False,
+    ):
+        """
+        Constructor
+
+        Parameters
+        ----------
+        rank : int
+            Local rank
+        machine_id : int
+            Global machine id
+        communication : decentralizepy.communication.Communication
+            Communication module used to send and receive messages
+        mapping : decentralizepy.mappings.Mapping
+            Mapping (rank, machine_id) -> uid
+        graph : decentralizepy.graphs.Graph
+            Graph representing neighbors
+        model : decentralizepy.models.Model
+            Model to train
+        dataset : decentralizepy.datasets.Dataset
+            Dataset for sharing data. Not implemented yet! TODO
+        log_dir : str
+            Location to write shared_params (only writing for 2 procs per machine)
+        dict_ordered : bool
+            Specifies if the python dict maintains the order of insertion
+        save_shared : bool
+            Specifies if the indices of shared parameters should be logged
+        metadata_cap : float
+            Share full model when self.alpha > metadata_cap
+
+        """
+        super().__init__(
+            rank,
+            machine_id,
+            communication,
+            mapping,
+            graph,
+            model,
+            dataset,
+            log_dir,
+            1.0,
+            dict_ordered,
+            save_shared,
+            metadata_cap,
+            wavelet,
+            level,
+            change_based_selection,
+            save_accumulated,
+            accumulation,
+            accumulate_averaging_changes,
+        )
+        self.alpha_list = eval(alpha_list)
+        random.seed(
+            self.mapping.get_uid(self.rank, self.machine_id)
+        )
+
+    def step(self):
+        """
+        Perform a sharing step. Implements D-PSGD with alpha randomly chosen.
+
+        """
+        self.alpha = random.choice(self.alpha_list)
+        super().step()
diff --git a/src/decentralizepy/sharing/Wavelet.py b/src/decentralizepy/sharing/Wavelet.py
index 9a02a64..cf67d98 100644
--- a/src/decentralizepy/sharing/Wavelet.py
+++ b/src/decentralizepy/sharing/Wavelet.py
@@ -179,7 +179,7 @@ class Wavelet(PartialModel):
             Model converted to json dict
 
         """
-        if self.alpha > self.metadata_cap:  # Share fully
+        if self.alpha >= self.metadata_cap:  # Share fully
             return super().serialized_model()
 
         with torch.no_grad():
@@ -218,6 +218,8 @@ class Wavelet(PartialModel):
 
             m["indices"] = indices.numpy().astype(np.int32)
 
+            m["send_partial"] = True
+
             self.total_data += len(self.communication.encrypt(m["params"]))
             self.total_meta += len(self.communication.encrypt(m["indices"])) + len(
                 self.communication.encrypt(m["alpha"])
@@ -240,7 +242,7 @@ class Wavelet(PartialModel):
             state_dict of received
 
         """
-        if self.alpha > self.metadata_cap:  # Share fully
+        if "send_partial" not in m:
             return super().deserialized_model(m)
 
         with torch.no_grad():
-- 
GitLab