From 2efaec4a2261f339a279d7ee8ed19f14bcae91be Mon Sep 17 00:00:00 2001 From: Jeffrey Wigger <jeffrey.wigger@epfl.ch> Date: Sat, 19 Mar 2022 13:23:11 +0100 Subject: [PATCH] Changing the accumulation implementation --- src/decentralizepy/sharing/FFT.py | 6 +++++- src/decentralizepy/sharing/PartialModel.py | 21 ++++++++++++++++----- src/decentralizepy/sharing/Wavelet.py | 6 +++++- 3 files changed, 26 insertions(+), 7 deletions(-) diff --git a/src/decentralizepy/sharing/FFT.py b/src/decentralizepy/sharing/FFT.py index 1bc7e0e..6af75e3 100644 --- a/src/decentralizepy/sharing/FFT.py +++ b/src/decentralizepy/sharing/FFT.py @@ -51,6 +51,7 @@ class FFT(PartialModel): change_based_selection=True, save_accumulated="", accumulation=True, + accumulate_averaging_changes=False ): """ Constructor @@ -88,10 +89,13 @@ class FFT(PartialModel): the accumulated change is stored. accumulation : bool True if the the indices to share should be selected based on accumulated frequency change + accumulate_averaging_changes: bool + True if the accumulation should account for the model change due to averaging + """ super().__init__( rank, machine_id, communication, mapping, graph, model, dataset, log_dir, alpha, dict_ordered, save_shared, - metadata_cap, accumulation, save_accumulated, change_transformer_fft + metadata_cap, accumulation, save_accumulated, change_transformer_fft, accumulate_averaging_changes ) self.change_based_selection = change_based_selection diff --git a/src/decentralizepy/sharing/PartialModel.py b/src/decentralizepy/sharing/PartialModel.py index c961c43..935f302 100644 --- a/src/decentralizepy/sharing/PartialModel.py +++ b/src/decentralizepy/sharing/PartialModel.py @@ -32,7 +32,8 @@ class PartialModel(Sharing): metadata_cap=1.0, accumulation = False, save_accumulated="", - change_transformer = identity + change_transformer = identity, + accumulate_averaging_changes = False ): """ Constructor @@ -70,6 +71,8 @@ class PartialModel(Sharing): is stored. 
If a change_transformer is used then the transformed change is stored. change_transformer : (x: Tensor) -> Tensor A function that transforms the model change into other domains. Default: identity function + accumulate_averaging_changes: bool + True if the accumulation should account for the model change due to averaging """ super().__init__( @@ -83,6 +86,7 @@ class PartialModel(Sharing): self.accumulation = accumulation self.save_accumulated = conditional_value(save_accumulated, "", False) self.change_transformer = change_transformer + self.accumulate_averaging_changes = accumulate_averaging_changes # getting the initial model self.shapes = [] @@ -266,7 +270,14 @@ class PartialModel(Sharing): pre_share_model = torch.cat(tensors_to_cat, dim=0) change = self.change_transformer(pre_share_model - self.init_model) if self.accumulation: - change += self.model.accumulated_changes + if not self.accumulate_averaging_changes: + # Need to accumulate in _pre_step as the accumulation gets rewound during the step + self.model.accumulated_changes += change + change = self.model.accumulated_changes.clone().detach() + else: + # For the legacy implementation, we will only rewind currently accumulated values + # and add the model change due to averaging in the end + change += self.model.accumulated_changes # stores change of the model due to training, change due to averaging is not accounted self.model.model_change = change @@ -277,16 +288,16 @@ class PartialModel(Sharing): """ logging.info("PartialModel _post_step") with torch.no_grad(): - self.model.model_change = None tensors_to_cat = [ v.data.flatten() for _, v in self.model.state_dict().items() ] post_share_model = torch.cat(tensors_to_cat, dim=0) self.init_model = post_share_model if self.accumulation: - self.model.accumulated_changes += self.change_transformer(self.init_model - self.prev) + if self.accumulate_averaging_changes: + self.model.accumulated_changes += self.change_transformer(self.init_model - self.prev) self.prev = 
self.init_model - + self.model.model_change = None if self.save_accumulated: self.save_change() diff --git a/src/decentralizepy/sharing/Wavelet.py b/src/decentralizepy/sharing/Wavelet.py index 1b73b29..cd039f8 100644 --- a/src/decentralizepy/sharing/Wavelet.py +++ b/src/decentralizepy/sharing/Wavelet.py @@ -58,6 +58,7 @@ class Wavelet(PartialModel): change_based_selection=True, save_accumulated="", accumulation=False, + accumulate_averaging_changes = False ): """ Constructor @@ -99,13 +100,16 @@ class Wavelet(PartialModel): the accumulated change is stored. accumulation : bool True if the the indices to share should be selected based on accumulated frequency change + accumulate_averaging_changes: bool + True if the accumulation should account for the model change due to averaging """ self.wavelet = wavelet self.level = level super().__init__( rank, machine_id, communication, mapping, graph, model, dataset, log_dir, alpha, dict_ordered, save_shared, - metadata_cap, accumulation, save_accumulated, lambda x : change_transformer_wavelet(x, wavelet, level) + metadata_cap, accumulation, save_accumulated, lambda x : change_transformer_wavelet(x, wavelet, level), + accumulate_averaging_changes ) self.change_based_selection = change_based_selection -- GitLab