Skip to content
Snippets Groups Projects
Commit 2efaec4a authored by Jeffrey Wigger
Browse files

Changing the accumulation implementation

parent 9c9efb16
No related branches found
No related tags found
1 merge request: !3 "FFT Wavelets and more"
......@@ -51,6 +51,7 @@ class FFT(PartialModel):
change_based_selection=True,
save_accumulated="",
accumulation=True,
accumulate_averaging_changes=False
):
"""
Constructor
......@@ -88,10 +89,13 @@ class FFT(PartialModel):
the accumulated change is stored.
accumulation : bool
True if the indices to share should be selected based on accumulated frequency change
accumulate_averaging_changes: bool
True if the accumulation should account for the model change due to averaging
"""
super().__init__(
rank, machine_id, communication, mapping, graph, model, dataset, log_dir, alpha, dict_ordered, save_shared,
metadata_cap, accumulation, save_accumulated, change_transformer_fft
metadata_cap, accumulation, save_accumulated, change_transformer_fft, accumulate_averaging_changes
)
self.change_based_selection = change_based_selection
......
......@@ -32,7 +32,8 @@ class PartialModel(Sharing):
metadata_cap=1.0,
accumulation = False,
save_accumulated="",
change_transformer = identity
change_transformer = identity,
accumulate_averaging_changes = False
):
"""
Constructor
......@@ -70,6 +71,8 @@ class PartialModel(Sharing):
is stored. If a change_transformer is used then the transformed change is stored.
change_transformer : (x: Tensor) -> Tensor
A function that transforms the model change into other domains. Default: identity function
accumulate_averaging_changes: bool
True if the accumulation should account for the model change due to averaging
"""
super().__init__(
......@@ -83,6 +86,7 @@ class PartialModel(Sharing):
self.accumulation = accumulation
self.save_accumulated = conditional_value(save_accumulated, "", False)
self.change_transformer = change_transformer
self.accumulate_averaging_changes = accumulate_averaging_changes
# getting the initial model
self.shapes = []
......@@ -266,7 +270,14 @@ class PartialModel(Sharing):
pre_share_model = torch.cat(tensors_to_cat, dim=0)
change = self.change_transformer(pre_share_model - self.init_model)
if self.accumulation:
change += self.model.accumulated_changes
if not self.accumulate_averaging_changes:
# Need to accumulate in _pre_step as the accumulation gets rewound during the step
self.model.accumulated_changes += change
change = self.model.accumulated_changes.clone().detach()
else:
# For the legacy implementation, we only rewind the currently accumulated values
# and add the model change due to averaging at the end
change += self.model.accumulated_changes
# Stores the change of the model due to training; the change due to averaging is not accounted for
self.model.model_change = change
......@@ -277,16 +288,16 @@ class PartialModel(Sharing):
"""
logging.info("PartialModel _post_step")
with torch.no_grad():
self.model.model_change = None
tensors_to_cat = [
v.data.flatten() for _, v in self.model.state_dict().items()
]
post_share_model = torch.cat(tensors_to_cat, dim=0)
self.init_model = post_share_model
if self.accumulation:
self.model.accumulated_changes += self.change_transformer(self.init_model - self.prev)
if self.accumulate_averaging_changes:
self.model.accumulated_changes += self.change_transformer(self.init_model - self.prev)
self.prev = self.init_model
self.model.model_change = None
if self.save_accumulated:
self.save_change()
......
......@@ -58,6 +58,7 @@ class Wavelet(PartialModel):
change_based_selection=True,
save_accumulated="",
accumulation=False,
accumulate_averaging_changes = False
):
"""
Constructor
......@@ -99,13 +100,16 @@ class Wavelet(PartialModel):
the accumulated change is stored.
accumulation : bool
True if the indices to share should be selected based on accumulated frequency change
accumulate_averaging_changes: bool
True if the accumulation should account for the model change due to averaging
"""
self.wavelet = wavelet
self.level = level
super().__init__(
rank, machine_id, communication, mapping, graph, model, dataset, log_dir, alpha, dict_ordered, save_shared,
metadata_cap, accumulation, save_accumulated, lambda x : change_transformer_wavelet(x, wavelet, level)
metadata_cap, accumulation, save_accumulated, lambda x : change_transformer_wavelet(x, wavelet, level),
accumulate_averaging_changes
)
self.change_based_selection = change_based_selection
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment