From 2efaec4a2261f339a279d7ee8ed19f14bcae91be Mon Sep 17 00:00:00 2001
From: Jeffrey Wigger <jeffrey.wigger@epfl.ch>
Date: Sat, 19 Mar 2022 13:23:11 +0100
Subject: [PATCH] Changing the accumulation implementation
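
Accumulation is now controlled by an additional accumulate_averaging_changes
flag on PartialModel, FFT and Wavelet (default: False). When the flag is
disabled, the change due to local training is added to the accumulated changes
directly in _pre_step and a detached copy of the accumulator is shared; the
change caused by averaging is no longer accumulated in _post_step. When the
flag is enabled, the legacy behaviour is kept: _pre_step only adds the
previously accumulated values to the current change, and _post_step
accumulates the full round-to-round delta, which includes the effect of
averaging. model_change is now cleared at the end of _post_step, after the
accumulator update.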

---
 src/decentralizepy/sharing/FFT.py          |  6 +++++-
 src/decentralizepy/sharing/PartialModel.py | 21 ++++++++++++++++-----
 src/decentralizepy/sharing/Wavelet.py      |  6 +++++-
 3 files changed, 26 insertions(+), 7 deletions(-)
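
Note (not part of the commit message): the stand-alone sketch below
illustrates the two accumulation modes with plain torch tensors. The
AccumulationDemo class, its constructor arguments and the stubbed model
vectors are assumptions for illustration only; the attribute names
(accumulated_changes, init_model, prev, change_transformer) mirror
PartialModel, and accumulation itself is assumed to be enabled.

    import torch

    def identity(x):
        return x

    class AccumulationDemo:
        def __init__(self, n, accumulate_averaging_changes=False,
                     change_transformer=identity):
            self.accumulate_averaging_changes = accumulate_averaging_changes
            self.change_transformer = change_transformer
            self.accumulated_changes = torch.zeros(n)
            self.init_model = torch.zeros(n)  # flattened model before local training
            self.prev = torch.zeros(n)        # flattened model after the previous round

        def _pre_step(self, pre_share_model):
            # Change caused by local training since the last sharing step.
            change = self.change_transformer(pre_share_model - self.init_model)
            if not self.accumulate_averaging_changes:
                # New default: accumulate here and share a detached copy, since
                # the accumulated changes get rewound during the step.
                self.accumulated_changes += change
                change = self.accumulated_changes.clone().detach()
            else:
                # Legacy mode: only add the already accumulated values; the
                # averaging change is accumulated later in _post_step.
                change += self.accumulated_changes
            return change  # what PartialModel stores in model.model_change

        def _post_step(self, post_share_model):
            self.init_model = post_share_model
            if self.accumulate_averaging_changes:
                # Legacy mode: the round-to-round delta includes the averaging step.
                self.accumulated_changes += self.change_transformer(
                    self.init_model - self.prev
                )
            self.prev = self.init_model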

diff --git a/src/decentralizepy/sharing/FFT.py b/src/decentralizepy/sharing/FFT.py
index 1bc7e0e..6af75e3 100644
--- a/src/decentralizepy/sharing/FFT.py
+++ b/src/decentralizepy/sharing/FFT.py
@@ -51,6 +51,7 @@ class FFT(PartialModel):
         change_based_selection=True,
         save_accumulated="",
         accumulation=True,
+        accumulate_averaging_changes=False
     ):
         """
         Constructor
@@ -88,10 +89,13 @@ class FFT(PartialModel):
             the accumulated change is stored.
         accumulation : bool
             True if the the indices to share should be selected based on accumulated frequency change
+        accumulate_averaging_changes : bool
+            True if the accumulation should account for the model change due to averaging
+
         """
         super().__init__(
             rank, machine_id, communication, mapping, graph, model, dataset, log_dir, alpha, dict_ordered, save_shared,
-            metadata_cap, accumulation, save_accumulated, change_transformer_fft
+            metadata_cap, accumulation, save_accumulated, change_transformer_fft, accumulate_averaging_changes
         )
         self.change_based_selection = change_based_selection
 
diff --git a/src/decentralizepy/sharing/PartialModel.py b/src/decentralizepy/sharing/PartialModel.py
index c961c43..935f302 100644
--- a/src/decentralizepy/sharing/PartialModel.py
+++ b/src/decentralizepy/sharing/PartialModel.py
@@ -32,7 +32,8 @@ class PartialModel(Sharing):
         metadata_cap=1.0,
         accumulation = False,
         save_accumulated="",
-        change_transformer = identity
+        change_transformer = identity,
+        accumulate_averaging_changes = False
     ):
         """
         Constructor
@@ -70,6 +71,8 @@ class PartialModel(Sharing):
             is stored. If a change_transformer is used then the transformed change is stored.
         change_transformer : (x: Tensor) -> Tensor
             A function that transforms the model change into other domains. Default: identity function
+        accumulate_averaging_changes : bool
+            True if the accumulation should account for the model change due to averaging
 
         """
         super().__init__(
@@ -83,6 +86,7 @@ class PartialModel(Sharing):
         self.accumulation = accumulation
         self.save_accumulated = conditional_value(save_accumulated, "", False)
         self.change_transformer = change_transformer
+        self.accumulate_averaging_changes = accumulate_averaging_changes
 
         # getting the initial model
         self.shapes = []
@@ -266,7 +270,14 @@ class PartialModel(Sharing):
             pre_share_model = torch.cat(tensors_to_cat, dim=0)
             change = self.change_transformer(pre_share_model - self.init_model)
             if self.accumulation:
-                change += self.model.accumulated_changes
+                if not self.accumulate_averaging_changes:
+                    # Need to accumulate in _pre_step, as the accumulated changes get rewound during the step
+                    self.model.accumulated_changes += change
+                    change = self.model.accumulated_changes.clone().detach()
+                else:
+                    # For the legacy implementation, only rewind the currently accumulated values
+                    # and add the model change due to averaging at the end
+                    change += self.model.accumulated_changes
             # stores change of the model due to training, change due to averaging is not accounted
             self.model.model_change = change
 
@@ -277,16 +288,16 @@ class PartialModel(Sharing):
         """
         logging.info("PartialModel _post_step")
         with torch.no_grad():
-            self.model.model_change = None
             tensors_to_cat = [
                 v.data.flatten() for _, v in self.model.state_dict().items()
             ]
             post_share_model = torch.cat(tensors_to_cat, dim=0)
             self.init_model = post_share_model
             if self.accumulation:
-                self.model.accumulated_changes += self.change_transformer(self.init_model - self.prev)
+                if self.accumulate_averaging_changes:
+                    self.model.accumulated_changes += self.change_transformer(self.init_model - self.prev)
                 self.prev = self.init_model
-
+            self.model.model_change = None
         if self.save_accumulated:
             self.save_change()
 
diff --git a/src/decentralizepy/sharing/Wavelet.py b/src/decentralizepy/sharing/Wavelet.py
index 1b73b29..cd039f8 100644
--- a/src/decentralizepy/sharing/Wavelet.py
+++ b/src/decentralizepy/sharing/Wavelet.py
@@ -58,6 +58,7 @@ class Wavelet(PartialModel):
         change_based_selection=True,
         save_accumulated="",
         accumulation=False,
+        accumulate_averaging_changes=False
     ):
         """
         Constructor
@@ -99,13 +100,16 @@ class Wavelet(PartialModel):
             the accumulated change is stored.
         accumulation : bool
             True if the the indices to share should be selected based on accumulated frequency change
+        accumulate_averaging_changes : bool
+            True if the accumulation should account for the model change due to averaging
         """
         self.wavelet = wavelet
         self.level = level
 
         super().__init__(
             rank, machine_id, communication, mapping, graph, model, dataset, log_dir, alpha, dict_ordered, save_shared,
-            metadata_cap, accumulation, save_accumulated, lambda x : change_transformer_wavelet(x, wavelet, level)
+            metadata_cap, accumulation, save_accumulated, lambda x : change_transformer_wavelet(x, wavelet, level),
+            accumulate_averaging_changes
         )
 
         self.change_based_selection = change_based_selection
-- 
GitLab