Commit 2efaec4a authored by Jeffrey Wigger

Changing the accumulation implementation

parent 9c9efb16
1 merge request: !3 FFT Wavelets and more
@@ -51,6 +51,7 @@ class FFT(PartialModel):
         change_based_selection=True,
         save_accumulated="",
         accumulation=True,
+        accumulate_averaging_changes=False
     ):
         """
         Constructor
@@ -88,10 +89,13 @@ class FFT(PartialModel):
             the accumulated change is stored.
         accumulation : bool
             True if the indices to share should be selected based on accumulated frequency change
+        accumulate_averaging_changes : bool
+            True if the accumulation should account for the model change due to averaging
         """
         super().__init__(
             rank, machine_id, communication, mapping, graph, model, dataset, log_dir, alpha, dict_ordered, save_shared,
-            metadata_cap, accumulation, save_accumulated, change_transformer_fft
+            metadata_cap, accumulation, save_accumulated, change_transformer_fft, accumulate_averaging_changes
         )
         self.change_based_selection = change_based_selection
...
@@ -32,7 +32,8 @@ class PartialModel(Sharing):
         metadata_cap=1.0,
         accumulation = False,
         save_accumulated="",
-        change_transformer = identity
+        change_transformer = identity,
+        accumulate_averaging_changes = False
     ):
         """
         Constructor
@@ -70,6 +71,8 @@ class PartialModel(Sharing):
             is stored. If a change_transformer is used then the transformed change is stored.
         change_transformer : (x: Tensor) -> Tensor
             A function that transforms the model change into other domains. Default: identity function
+        accumulate_averaging_changes : bool
+            True if the accumulation should account for the model change due to averaging
         """
         super().__init__(
@@ -83,6 +86,7 @@ class PartialModel(Sharing):
         self.accumulation = accumulation
         self.save_accumulated = conditional_value(save_accumulated, "", False)
         self.change_transformer = change_transformer
+        self.accumulate_averaging_changes = accumulate_averaging_changes
         # getting the initial model
         self.shapes = []
@@ -266,7 +270,14 @@ class PartialModel(Sharing):
             pre_share_model = torch.cat(tensors_to_cat, dim=0)
             change = self.change_transformer(pre_share_model - self.init_model)
             if self.accumulation:
-                change += self.model.accumulated_changes
+                if not self.accumulate_averaging_changes:
+                    # Need to accumulate in _pre_step, as the accumulation gets rewound during the step
+                    self.model.accumulated_changes += change
+                    change = self.model.accumulated_changes.clone().detach()
+                else:
+                    # The legacy implementation only rewinds the currently accumulated values
+                    # and adds the model change due to averaging at the end
+                    change += self.model.accumulated_changes
             # stores change of the model due to training; change due to averaging is not accounted for
             self.model.model_change = change
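
The new default path (accumulate_averaging_changes=False) folds each round's training change into model.accumulated_changes already in _pre_step and then hands out a detached snapshot of the accumulator, so that later in-place rewinds of accumulated_changes cannot alter the value used for index selection. A minimal standalone sketch of this branch, using toy tensors and a hypothetical helper (pre_step_accumulate) instead of the real Sharing/PartialModel classes:

import torch

def pre_step_accumulate(pre_share_model, init_model, accumulated_changes, change_transformer=lambda x: x):
    # Hypothetical helper, sketch only: mirrors the accumulate_averaging_changes=False branch.
    # Change of the model due to local training since the last _post_step.
    change = change_transformer(pre_share_model - init_model)
    # Fold it into the running accumulator here, because the accumulator may be rewound later in the step.
    accumulated_changes += change
    # Return a snapshot so later in-place updates of the accumulator do not alias this value.
    return accumulated_changes.clone().detach()

acc = torch.zeros(4)
change = pre_step_accumulate(torch.tensor([1.0, 2.0, 3.0, 4.0]), torch.zeros(4), acc)
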
@@ -277,16 +288,16 @@ class PartialModel(Sharing):
         """
         logging.info("PartialModel _post_step")
         with torch.no_grad():
-            self.model.model_change = None
             tensors_to_cat = [
                 v.data.flatten() for _, v in self.model.state_dict().items()
             ]
             post_share_model = torch.cat(tensors_to_cat, dim=0)
             self.init_model = post_share_model
             if self.accumulation:
-                self.model.accumulated_changes += self.change_transformer(self.init_model - self.prev)
+                if self.accumulate_averaging_changes:
+                    self.model.accumulated_changes += self.change_transformer(self.init_model - self.prev)
                 self.prev = self.init_model
+            self.model.model_change = None
         if self.save_accumulated:
             self.save_change()
...
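
After this change, _post_step only touches the accumulator in legacy mode: when accumulate_averaging_changes is True it adds init_model - prev, a delta that covers both local training and the averaging with neighbours, whereas the new default relies entirely on the accumulation done in _pre_step and therefore never folds the averaging delta in. A toy comparison of the two bookkeeping strategies, with made-up one-dimensional tensors standing in for the flattened model:

import torch

prev = torch.zeros(3)                              # model right after the previous _post_step (self.prev)
after_training = torch.tensor([1.0, 1.0, 1.0])     # model when _pre_step runs
after_averaging = torch.tensor([0.5, 1.5, 1.0])    # model when _post_step runs

# New default (accumulate_averaging_changes=False): the training change is accumulated
# in _pre_step, and _post_step adds nothing to the accumulator.
acc_new = torch.zeros(3)
acc_new += after_training - prev

# Legacy mode (accumulate_averaging_changes=True): the accumulator is updated in _post_step,
# so the delta also contains the change caused by averaging.
acc_legacy = torch.zeros(3)
acc_legacy += after_averaging - prev

print(acc_new)     # tensor([1., 1., 1.])
print(acc_legacy)  # tensor([0.5000, 1.5000, 1.0000])
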
@@ -58,6 +58,7 @@ class Wavelet(PartialModel):
         change_based_selection=True,
         save_accumulated="",
         accumulation=False,
+        accumulate_averaging_changes = False
     ):
         """
         Constructor
@@ -99,13 +100,16 @@ class Wavelet(PartialModel):
             the accumulated change is stored.
         accumulation : bool
             True if the indices to share should be selected based on accumulated frequency change
+        accumulate_averaging_changes : bool
+            True if the accumulation should account for the model change due to averaging
         """
         self.wavelet = wavelet
         self.level = level
         super().__init__(
             rank, machine_id, communication, mapping, graph, model, dataset, log_dir, alpha, dict_ordered, save_shared,
-            metadata_cap, accumulation, save_accumulated, lambda x : change_transformer_wavelet(x, wavelet, level)
+            metadata_cap, accumulation, save_accumulated, lambda x : change_transformer_wavelet(x, wavelet, level),
+            accumulate_averaging_changes
         )
         self.change_based_selection = change_based_selection
...
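
The Wavelet subclass keeps forwarding its transform as a lambda that closes over wavelet and level before the new flag is appended to the argument list; the lambda is what lets the base class keep calling change_transformer(x) with a single argument. A self-contained sketch of that closure pattern with a stand-in transform (change_transformer_stub is hypothetical; the real change_transformer_wavelet, not shown in this diff, takes the wavelet name and decomposition level as extra arguments):

import torch

def change_transformer_stub(x, wavelet, level):
    # Stand-in for change_transformer_wavelet(x, wavelet, level); only the call shape matters here.
    return x * level

wavelet, level = "haar", 4
# Bind wavelet and level now so callers only ever pass the change tensor.
change_transformer = lambda x: change_transformer_stub(x, wavelet, level)
print(change_transformer(torch.ones(3)))   # tensor([4., 4., 4.])
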