From 573e433159fa7c0718fa934d597330a650a43db8 Mon Sep 17 00:00:00 2001
From: Jeffrey Wigger <jeffrey.wigger@epfl.ch>
Date: Tue, 15 Mar 2022 22:28:13 +0100
Subject: [PATCH] fft and wavelet moving everything to sharing

---
 eval/step_configs/config_femnist_fft.ini      |   8 +-
 eval/step_configs/config_femnist_wavelet.ini  |   8 +-
 src/decentralizepy/sharing/FFT.py             |  59 +++++++--
 src/decentralizepy/sharing/Wavelet.py         |  74 ++++++++---
 .../training/FrequencyAccumulator.py          | 116 ----------------
 .../training/FrequencyWaveletAccumulator.py   | 125 ------------------
 6 files changed, 111 insertions(+), 279 deletions(-)
 delete mode 100644 src/decentralizepy/training/FrequencyAccumulator.py
 delete mode 100644 src/decentralizepy/training/FrequencyWaveletAccumulator.py

diff --git a/eval/step_configs/config_femnist_fft.ini b/eval/step_configs/config_femnist_fft.ini
index 13a769c..afac1f4 100644
--- a/eval/step_configs/config_femnist_fft.ini
+++ b/eval/step_configs/config_femnist_fft.ini
@@ -15,15 +15,14 @@ lr = 0.001
 
 # There are 734463 femnist samples
 [TRAIN_PARAMS]
-training_package = decentralizepy.training.FrequencyAccumulator
-training_class = FrequencyAccumulator
+training_package = decentralizepy.training.Training
+training_class = Training
 rounds = 47
 full_epochs = False
 batch_size = 16
 shuffle = True
 loss_package = torch.nn
 loss_class = CrossEntropyLoss
-accumulation = True
 
 [COMMUNICATION]
 comm_package = decentralizepy.communication.TCP
@@ -34,4 +33,5 @@ addresses_filepath = ip_addr_6Machines.json
 sharing_package = decentralizepy.sharing.FFT
 sharing_class = FFT
 alpha = 0.1
-change_based_selection = True
\ No newline at end of file
+change_based_selection = True
+accumulation = True
diff --git a/eval/step_configs/config_femnist_wavelet.ini b/eval/step_configs/config_femnist_wavelet.ini
index ac3bac2..68704a3 100644
--- a/eval/step_configs/config_femnist_wavelet.ini
+++ b/eval/step_configs/config_femnist_wavelet.ini
@@ -15,17 +15,14 @@ lr = 0.001
 
 # There are 734463 femnist samples
 [TRAIN_PARAMS]
-training_package = decentralizepy.training.FrequencyWaveletAccumulator
-training_class = FrequencyWaveletAccumulator
+training_package = decentralizepy.training.Training
+training_class = Training
 rounds = 47
 full_epochs = False
 batch_size = 16
 shuffle = True
 loss_package = torch.nn
 loss_class = CrossEntropyLoss
-wavelet=sym2
-level= None
-accumulation = True
 
 [COMMUNICATION]
 comm_package = decentralizepy.communication.TCP
@@ -39,3 +36,4 @@ change_based_selection = True
 alpha = 0.1
 wavelet=sym2
 level= None
+accumulation = True
diff --git a/src/decentralizepy/sharing/FFT.py b/src/decentralizepy/sharing/FFT.py
index a4c3b59..80b5a5d 100644
--- a/src/decentralizepy/sharing/FFT.py
+++ b/src/decentralizepy/sharing/FFT.py
@@ -99,6 +99,23 @@ class FFT(Sharing):
         self.change_based_selection = change_based_selection
         self.accumulation = accumulation
 
+        # getting the initial model
+        with torch.no_grad():
+            self.model.accumulated_gradients = []
+            tensors_to_cat = [
+                v.data.flatten() for _, v in self.model.state_dict().items()
+            ]
+            concated = torch.cat(tensors_to_cat, dim=0)
+            self.init_model = fft.rfft(concated)
+            self.prev = None
+            if self.accumulation:
+                if self.model.accumulated_changes is None:
+                    self.model.accumulated_changes = torch.zeros_like(self.init_model)
+                    self.prev = self.init_model
+                else:
+                    self.model.accumulated_changes += self.init_model - self.prev
+                    self.prev = self.init_model
+
     def apply_fft(self):
         """
         Does fft transformation of the model parameters and selects topK (alpha) of them in the frequency domain
@@ -225,6 +242,25 @@ class FFT(Sharing):
 
         """
         t_start = time()
+        shapes = []
+        lens = []
+        end_model = None
+        change = 0
+        self.model.accumulated_gradients = []
+        with torch.no_grad():
+            # FFT of this model
+            tensors_to_cat = []
+            for _, v in self.model.state_dict().items():
+                shapes.append(v.shape)
+                t = v.flatten()
+                lens.append(t.shape[0])
+                tensors_to_cat.append(t)
+            concated = torch.cat(tensors_to_cat, dim=0)
+            end_model = fft.rfft(concated)
+            change = end_model - self.init_model
+            if self.accumulation:
+                change += self.model.accumulated_changes
+            self.model.accumulated_gradients.append(change)
         data = self.serialized_model()
         t_post_serialize = time()
         my_uid = self.mapping.get_uid(self.rank, self.machine_id)
@@ -255,17 +291,7 @@ class FFT(Sharing):
             total = None
             weight_total = 0
 
-            # FFT of this model
-            shapes = []
-            lens = []
-            tensors_to_cat = []
-            for _, v in self.model.state_dict().items():
-                shapes.append(v.shape)
-                t = v.flatten()
-                lens.append(t.shape[0])
-                tensors_to_cat.append(t)
-            concated = torch.cat(tensors_to_cat, dim=0)
-            flat_fft = fft.rfft(concated)
+            flat_fft = end_model
 
             for i, n in enumerate(self.peer_deques):
                 degree, iteration, data = self.peer_deques[n].popleft()
@@ -303,6 +329,17 @@ class FFT(Sharing):
 
         self.communication_round += 1
 
+        with torch.no_grad():
+            self.model.accumulated_gradients = []
+            tensors_to_cat = [
+                v.data.flatten() for _, v in self.model.state_dict().items()
+            ]
+            concated = torch.cat(tensors_to_cat, dim=0)
+            self.init_model = fft.rfft(concated)
+            if self.accumulation:
+                self.model.accumulated_changes += self.init_model - self.prev
+                self.prev = self.init_model
+
         t_end = time()
 
         logging.info(
diff --git a/src/decentralizepy/sharing/Wavelet.py b/src/decentralizepy/sharing/Wavelet.py
index 2ec700a..2d651b0 100644
--- a/src/decentralizepy/sharing/Wavelet.py
+++ b/src/decentralizepy/sharing/Wavelet.py
@@ -106,6 +106,26 @@ class Wavelet(Sharing):
             Path(self.folder_path).mkdir(parents=True, exist_ok=True)
 
         self.change_based_selection = change_based_selection
+        self.accumulation = accumulation
+
+        # getting the initial model
+        with torch.no_grad():
+            self.model.accumulated_gradients = []
+            tensors_to_cat = [
+                v.data.flatten() for _, v in self.model.state_dict().items()
+            ]
+            concated = torch.cat(tensors_to_cat, dim=0)
+            coeff = pywt.wavedec(concated.numpy(), self.wavelet, level=self.level)
+            data, coeff_slices = pywt.coeffs_to_array(coeff)
+            self.init_model = torch.from_numpy(data.ravel())
+            self.prev = None
+            if self.accumulation:
+                if self.model.accumulated_changes is None:
+                    self.model.accumulated_changes = torch.zeros_like(self.init_model)
+                    self.prev = self.init_model
+                else:
+                    self.model.accumulated_changes += self.init_model - self.prev
+                    self.prev = self.init_model
 
     def apply_wavelet(self):
         """
@@ -257,6 +277,29 @@ class Wavelet(Sharing):
 
         """
         t_start = time()
+        shapes = []
+        lens = []
+        end_model = None
+        change = 0
+        self.model.accumulated_gradients = []
+        with torch.no_grad():
+            # FFT of this model
+            tensors_to_cat = []
+            for _, v in self.model.state_dict().items():
+                shapes.append(v.shape)
+                t = v.flatten()
+                lens.append(t.shape[0])
+                tensors_to_cat.append(t)
+            concated = torch.cat(tensors_to_cat, dim=0)
+            coeff = pywt.wavedec(concated.numpy(), self.wavelet, level=self.level)
+            data, coeff_slices = pywt.coeffs_to_array(coeff)
+            shape = data.shape
+            wt_params = data.ravel()
+            end_model = torch.from_numpy(wt_params)
+            change = end_model - self.init_model
+            if self.accumulation:
+                change += self.model.accumulated_changes
+            self.model.accumulated_gradients.append(change)
         data = self.serialized_model()
         t_post_serialize = time()
         my_uid = self.mapping.get_uid(self.rank, self.machine_id)
@@ -287,24 +330,6 @@ class Wavelet(Sharing):
             total = None
             weight_total = 0
 
-            # FFT of this model
-            shapes = []
-            lens = []
-            tensors_to_cat = []
-            # TODO: should we detach
-            for _, v in self.model.state_dict().items():
-                shapes.append(v.shape)
-                t = v.flatten()
-                lens.append(t.shape[0])
-                tensors_to_cat.append(t)
-            concated = torch.cat(tensors_to_cat, dim=0)
-            coeff = pywt.wavedec(concated.numpy(), self.wavelet, level=self.level)
-            wt_params, coeff_slices = pywt.coeffs_to_array(
-                coeff
-            )  # coeff_slices will be reproduced on the receiver
-            shape = wt_params.shape
-            wt_params = wt_params.ravel()
-
             for i, n in enumerate(self.peer_deques):
                 degree, iteration, data = self.peer_deques[n].popleft()
                 logging.debug(
@@ -348,6 +373,19 @@ class Wavelet(Sharing):
 
         self.communication_round += 1
 
+        with torch.no_grad():
+            self.model.accumulated_gradients = []
+            tensors_to_cat = [
+                v.data.flatten() for _, v in self.model.state_dict().items()
+            ]
+            concated = torch.cat(tensors_to_cat, dim=0)
+            coeff = pywt.wavedec(concated.numpy(), self.wavelet, level=self.level)
+            data, coeff_slices = pywt.coeffs_to_array(coeff)
+            self.init_model = torch.from_numpy(data.ravel())
+            if self.accumulation:
+                self.model.accumulated_changes += self.init_model - self.prev
+                self.prev = self.init_model
+
         t_end = time()
 
         logging.info(
diff --git a/src/decentralizepy/training/FrequencyAccumulator.py b/src/decentralizepy/training/FrequencyAccumulator.py
deleted file mode 100644
index 91e74b3..0000000
--- a/src/decentralizepy/training/FrequencyAccumulator.py
+++ /dev/null
@@ -1,116 +0,0 @@
-import logging
-
-import torch
-from torch import fft
-
-from decentralizepy.training.Training import Training
-
-
-class FrequencyAccumulator(Training):
-    """
-    This class implements the training module which also accumulates the fft frequency at the beginning of steps a communication round.
-
-    """
-
-    def __init__(
-        self,
-        rank,
-        machine_id,
-        mapping,
-        model,
-        optimizer,
-        loss,
-        log_dir,
-        rounds="",
-        full_epochs="",
-        batch_size="",
-        shuffle="",
-        accumulation=True,
-    ):
-        """
-        Constructor
-
-        Parameters
-        ----------
-        rank : int
-            Rank of process local to the machine
-        machine_id : int
-            Machine ID on which the process in running
-        mapping : decentralizepy.mappings
-            The object containing the mapping rank <--> uid
-        model : torch.nn.Module
-            Neural Network for training
-        optimizer : torch.optim
-            Optimizer to learn parameters
-        loss : function
-            Loss function
-        log_dir : str
-            Directory to log the model change.
-        rounds : int, optional
-            Number of steps/epochs per training call
-        full_epochs: bool, optional
-            True if 1 round = 1 epoch. False if 1 round = 1 minibatch
-        batch_size : int, optional
-            Number of items to learn over, in one batch
-        shuffle : bool
-            True if the dataset should be shuffled before training.
-        accumulation : bool
-            True if the model change should be accumulated across communication steps
-        """
-        super().__init__(
-            rank,
-            machine_id,
-            mapping,
-            model,
-            optimizer,
-            loss,
-            log_dir,
-            rounds,
-            full_epochs,
-            batch_size,
-            shuffle,
-        )
-        self.accumulation = accumulation
-        self.init_model = None
-        self.prev = None
-
-    def train(self, dataset):
-        """
-        Does one training iteration.
-        If self.accumulation is True then it accumulates model fft frequency changes in model.accumulated_frequency.
-        Otherwise it stores the current fft frequency representation of the model in model.accumulated_frequency.
-
-        Parameters
-        ----------
-        dataset : decentralizepy.datasets.Dataset
-            The training dataset. Should implement get_trainset(batch_size, shuffle)
-
-        """
-        with torch.no_grad():
-            self.model.accumulated_gradients = []
-            tensors_to_cat = [
-                v.data.flatten() for _, v in self.model.state_dict().items()
-            ]
-            concated = torch.cat(tensors_to_cat, dim=0)
-            self.init_model = fft.rfft(concated)
-            if self.accumulation:
-                if self.model.accumulated_changes is None:
-                    self.model.accumulated_changes = torch.zeros_like(self.init_model)
-                    self.prev = self.init_model
-                else:
-                    self.model.accumulated_changes += self.init_model - self.prev
-                    self.prev = self.init_model
-
-        super().train(dataset)
-
-        with torch.no_grad():
-            tensors_to_cat = [
-                v.data.flatten() for _, v in self.model.state_dict().items()
-            ]
-            concated = torch.cat(tensors_to_cat, dim=0)
-            end_model = fft.rfft(concated)
-            change = end_model - self.init_model
-            if self.accumulation:
-                change += self.model.accumulated_changes
-
-            self.model.accumulated_gradients.append(change)
diff --git a/src/decentralizepy/training/FrequencyWaveletAccumulator.py b/src/decentralizepy/training/FrequencyWaveletAccumulator.py
deleted file mode 100644
index 54238ab..0000000
--- a/src/decentralizepy/training/FrequencyWaveletAccumulator.py
+++ /dev/null
@@ -1,125 +0,0 @@
-import logging
-
-import numpy as np
-import pywt
-import torch
-
-from decentralizepy.training.Training import Training
-
-
-class FrequencyWaveletAccumulator(Training):
-    """
-    This class implements the training module which also accumulates the wavelet frequency at the beginning of steps a communication round.
-
-    """
-
-    def __init__(
-        self,
-        rank,
-        machine_id,
-        mapping,
-        model,
-        optimizer,
-        loss,
-        log_dir,
-        rounds="",
-        full_epochs="",
-        batch_size="",
-        shuffle="",
-        wavelet="haar",
-        level=4,
-        accumulation=True,
-    ):
-        """
-        Constructor
-
-        Parameters
-        ----------
-        rank : int
-            Rank of process local to the machine
-        machine_id : int
-            Machine ID on which the process in running
-        mapping : decentralizepy.mappings
-            The object containing the mapping rank <--> uid
-        model : torch.nn.Module
-            Neural Network for training
-        optimizer : torch.optim
-            Optimizer to learn parameters
-        loss : function
-            Loss function
-        log_dir : str
-            Directory to log the model change.
-        rounds : int, optional
-            Number of steps/epochs per training call
-        full_epochs: bool, optional
-            True if 1 round = 1 epoch. False if 1 round = 1 minibatch
-        batch_size : int, optional
-            Number of items to learn over, in one batch
-        shuffle : bool
-            True if the dataset should be shuffled before training.
-        accumulation : bool
-            True if the model change should be accumulated across communication steps
-        """
-        super().__init__(
-            rank,
-            machine_id,
-            mapping,
-            model,
-            optimizer,
-            loss,
-            log_dir,
-            rounds,
-            full_epochs,
-            batch_size,
-            shuffle,
-        )
-        self.wavelet = wavelet
-        self.level = level
-        self.accumulation = accumulation
-
-    def train(self, dataset):
-        """
-        Does one training iteration.
-        If self.accumulation is True then it accumulates model wavelet frequency changes in model.accumulated_frequency.
-        Otherwise it stores the current wavelet frequency representation of the model in model.accumulated_frequency.
-
-        Parameters
-        ----------
-        dataset : decentralizepy.datasets.Dataset
-            The training dataset. Should implement get_trainset(batch_size, shuffle)
-
-        """
-
-        # this looks at the change from the last round averaging of the frequencies
-        with torch.no_grad():
-            self.model.accumulated_gradients = []
-            tensors_to_cat = [
-                v.data.flatten() for _, v in self.model.state_dict().items()
-            ]
-            concated = torch.cat(tensors_to_cat, dim=0)
-            coeff = pywt.wavedec(concated.numpy(), self.wavelet, level=self.level)
-            data, coeff_slices = pywt.coeffs_to_array(coeff)
-            self.init_model = torch.from_numpy(data.ravel())
-            if self.accumulation:
-                if self.model.accumulated_changes is None:
-                    self.model.accumulated_changes = torch.zeros_like(self.init_model)
-                    self.prev = self.init_model
-                else:
-                    self.model.accumulated_changes += self.init_model - self.prev
-                    self.prev = self.init_model
-
-        super().train(dataset)
-
-        with torch.no_grad():
-            tensors_to_cat = [
-                v.data.flatten() for _, v in self.model.state_dict().items()
-            ]
-            concated = torch.cat(tensors_to_cat, dim=0)
-            coeff = pywt.wavedec(concated.numpy(), self.wavelet, level=self.level)
-            data, coeff_slices = pywt.coeffs_to_array(coeff)
-            end_model = torch.from_numpy(data.ravel())
-            change = end_model - self.init_model
-            if self.accumulation:
-                change += self.model.accumulated_changes
-
-            self.model.accumulated_gradients.append(change)
--
GitLab
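
Note for readers (illustrative, not part of the commit): after this change the transform-and-accumulate bookkeeping lives entirely in FFT.step() and Wavelet.step(). Each communication round has three moments: take a transform-domain snapshot of the flattened model (init_model), share the change since the last snapshot plus the running accumulator, and, after averaging, fold the drift since the previous snapshot into model.accumulated_changes. The accumulation flag is now read from the [SHARING] section of the step configs (see the two .ini hunks above) rather than [TRAIN_PARAMS]. The sketch below mirrors this flow for the FFT case only; the class ChangeAccumulator and its method names are hypothetical and do not exist in decentralizepy.

import torch
from torch import fft


class ChangeAccumulator:
    """Minimal sketch of the accumulation bookkeeping, assuming an rfft transform."""

    def __init__(self, flat_params, accumulation=True):
        # Transform-domain snapshot taken when sharing is initialised
        # (mirrors init_model / prev / model.accumulated_changes in the patch).
        self.accumulation = accumulation
        self.init_model = fft.rfft(flat_params)
        self.prev = self.init_model
        self.accumulated_changes = torch.zeros_like(self.init_model)

    def pre_share(self, flat_params):
        # Before serializing: the shared value is the change since the last
        # snapshot, plus everything accumulated over earlier rounds.
        end_model = fft.rfft(flat_params)
        change = end_model - self.init_model
        if self.accumulation:
            change = change + self.accumulated_changes
        return change

    def post_average(self, flat_params):
        # After averaging with neighbours: re-snapshot the model and fold the
        # drift since the previous snapshot into the running accumulator.
        self.init_model = fft.rfft(flat_params)
        if self.accumulation:
            self.accumulated_changes += self.init_model - self.prev
            self.prev = self.init_model


if __name__ == "__main__":
    params = torch.randn(16)
    acc = ChangeAccumulator(params, accumulation=True)
    trained = params + 0.01 * torch.randn(16)  # stands in for local training
    shared = acc.pre_share(trained)            # what gets topK-selected and sent
    acc.post_average(trained)                  # called once averaging is done
    print(shared.shape)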