From 573e433159fa7c0718fa934d597330a650a43db8 Mon Sep 17 00:00:00 2001
From: Jeffrey Wigger <jeffrey.wigger@epfl.ch>
Date: Tue, 15 Mar 2022 22:28:13 +0100
Subject: [PATCH] fft and wavelet moving everything to sharing

---
 eval/step_configs/config_femnist_fft.ini      |   8 +-
 eval/step_configs/config_femnist_wavelet.ini  |   8 +-
 src/decentralizepy/sharing/FFT.py             |  59 +++++++--
 src/decentralizepy/sharing/Wavelet.py         |  74 ++++++++---
 .../training/FrequencyAccumulator.py          | 116 ----------------
 .../training/FrequencyWaveletAccumulator.py   | 125 ------------------
 6 files changed, 111 insertions(+), 279 deletions(-)
 delete mode 100644 src/decentralizepy/training/FrequencyAccumulator.py
 delete mode 100644 src/decentralizepy/training/FrequencyWaveletAccumulator.py

diff --git a/eval/step_configs/config_femnist_fft.ini b/eval/step_configs/config_femnist_fft.ini
index 13a769c..afac1f4 100644
--- a/eval/step_configs/config_femnist_fft.ini
+++ b/eval/step_configs/config_femnist_fft.ini
@@ -15,15 +15,14 @@ lr = 0.001
 
 # There are 734463 femnist samples
 [TRAIN_PARAMS]
-training_package = decentralizepy.training.FrequencyAccumulator
-training_class = FrequencyAccumulator
+training_package = decentralizepy.training.Training
+training_class = Training
 rounds = 47
 full_epochs = False
 batch_size = 16
 shuffle = True
 loss_package = torch.nn
 loss_class = CrossEntropyLoss
-accumulation = True
 
 [COMMUNICATION]
 comm_package = decentralizepy.communication.TCP
@@ -34,4 +33,5 @@ addresses_filepath = ip_addr_6Machines.json
 sharing_package = decentralizepy.sharing.FFT
 sharing_class = FFT
 alpha = 0.1
-change_based_selection = True
\ No newline at end of file
+change_based_selection = True
+accumulation = True
diff --git a/eval/step_configs/config_femnist_wavelet.ini b/eval/step_configs/config_femnist_wavelet.ini
index ac3bac2..68704a3 100644
--- a/eval/step_configs/config_femnist_wavelet.ini
+++ b/eval/step_configs/config_femnist_wavelet.ini
@@ -15,17 +15,14 @@ lr = 0.001
 
 # There are 734463 femnist samples
 [TRAIN_PARAMS]
-training_package = decentralizepy.training.FrequencyWaveletAccumulator
-training_class = FrequencyWaveletAccumulator
+training_package = decentralizepy.training.Training
+training_class = Training
 rounds = 47
 full_epochs = False
 batch_size = 16
 shuffle = True
 loss_package = torch.nn
 loss_class = CrossEntropyLoss
-wavelet=sym2
-level= None
-accumulation = True
 
 [COMMUNICATION]
 comm_package = decentralizepy.communication.TCP
@@ -39,3 +36,4 @@ change_based_selection = True
 alpha = 0.1
 wavelet=sym2
 level= None
+accumulation = True
diff --git a/src/decentralizepy/sharing/FFT.py b/src/decentralizepy/sharing/FFT.py
index a4c3b59..80b5a5d 100644
--- a/src/decentralizepy/sharing/FFT.py
+++ b/src/decentralizepy/sharing/FFT.py
@@ -99,6 +99,23 @@ class FFT(Sharing):
         self.change_based_selection = change_based_selection
         self.accumulation = accumulation
 
+        # getting the initial model
+        with torch.no_grad():
+            self.model.accumulated_gradients = []
+            tensors_to_cat = [
+                v.data.flatten() for _, v in self.model.state_dict().items()
+            ]
+            concated = torch.cat(tensors_to_cat, dim=0)
+            self.init_model = fft.rfft(concated)
+            self.prev = None
+            if self.accumulation:
+                if self.model.accumulated_changes is None:
+                    self.model.accumulated_changes = torch.zeros_like(self.init_model)
+                    self.prev = self.init_model
+                else:
+                    self.model.accumulated_changes += self.init_model - self.prev
+                    self.prev = self.init_model
+
     def apply_fft(self):
         """
         Does fft transformation of the model parameters and selects topK (alpha) of them in the frequency domain
@@ -225,6 +242,25 @@ class FFT(Sharing):
 
         """
         t_start = time()
+        shapes = []
+        lens = []
+        end_model = None
+        change = 0
+        self.model.accumulated_gradients = []
+        with torch.no_grad():
+            # FFT of this model
+            tensors_to_cat = []
+            for _, v in self.model.state_dict().items():
+                shapes.append(v.shape)
+                t = v.flatten()
+                lens.append(t.shape[0])
+                tensors_to_cat.append(t)
+            concated = torch.cat(tensors_to_cat, dim=0)
+            end_model = fft.rfft(concated)
+            change = end_model - self.init_model
+            if self.accumulation:
+                change += self.model.accumulated_changes
+            self.model.accumulated_gradients.append(change)
         data = self.serialized_model()
         t_post_serialize = time()
         my_uid = self.mapping.get_uid(self.rank, self.machine_id)
@@ -255,17 +291,7 @@ class FFT(Sharing):
             total = None
             weight_total = 0
 
-            # FFT of this model
-            shapes = []
-            lens = []
-            tensors_to_cat = []
-            for _, v in self.model.state_dict().items():
-                shapes.append(v.shape)
-                t = v.flatten()
-                lens.append(t.shape[0])
-                tensors_to_cat.append(t)
-            concated = torch.cat(tensors_to_cat, dim=0)
-            flat_fft = fft.rfft(concated)
+            flat_fft = end_model
 
             for i, n in enumerate(self.peer_deques):
                 degree, iteration, data = self.peer_deques[n].popleft()
@@ -303,6 +329,17 @@ class FFT(Sharing):
 
         self.communication_round += 1
 
+        with torch.no_grad():
+            self.model.accumulated_gradients = []
+            tensors_to_cat = [
+                v.data.flatten() for _, v in self.model.state_dict().items()
+            ]
+            concated = torch.cat(tensors_to_cat, dim=0)
+            self.init_model = fft.rfft(concated)
+            if self.accumulation:
+                self.model.accumulated_changes += self.init_model - self.prev
+                self.prev = self.init_model
+
         t_end = time()
 
         logging.info(
diff --git a/src/decentralizepy/sharing/Wavelet.py b/src/decentralizepy/sharing/Wavelet.py
index 2ec700a..2d651b0 100644
--- a/src/decentralizepy/sharing/Wavelet.py
+++ b/src/decentralizepy/sharing/Wavelet.py
@@ -106,6 +106,26 @@ class Wavelet(Sharing):
             Path(self.folder_path).mkdir(parents=True, exist_ok=True)
 
         self.change_based_selection = change_based_selection
+        self.accumulation = accumulation
+
+        # getting the initial model
+        with torch.no_grad():
+            self.model.accumulated_gradients = []
+            tensors_to_cat = [
+                v.data.flatten() for _, v in self.model.state_dict().items()
+            ]
+            concated = torch.cat(tensors_to_cat, dim=0)
+            coeff = pywt.wavedec(concated.numpy(), self.wavelet, level=self.level)
+            data, coeff_slices = pywt.coeffs_to_array(coeff)
+            self.init_model = torch.from_numpy(data.ravel())
+            self.prev = None
+            if self.accumulation:
+                if self.model.accumulated_changes is None:
+                    self.model.accumulated_changes = torch.zeros_like(self.init_model)
+                    self.prev = self.init_model
+                else:
+                    self.model.accumulated_changes += self.init_model - self.prev
+                    self.prev = self.init_model
 
     def apply_wavelet(self):
         """
@@ -257,6 +277,29 @@ class Wavelet(Sharing):
 
         """
         t_start = time()
+        shapes = []
+        lens = []
+        end_model = None
+        change = 0
+        self.model.accumulated_gradients = []
+        with torch.no_grad():
+            # FFT of this model
+            tensors_to_cat = []
+            for _, v in self.model.state_dict().items():
+                shapes.append(v.shape)
+                t = v.flatten()
+                lens.append(t.shape[0])
+                tensors_to_cat.append(t)
+            concated = torch.cat(tensors_to_cat, dim=0)
+            coeff = pywt.wavedec(concated.numpy(), self.wavelet, level=self.level)
+            data, coeff_slices = pywt.coeffs_to_array(coeff)
+            shape = data.shape
+            wt_params = data.ravel()
+            end_model = torch.from_numpy(wt_params)
+            change = end_model - self.init_model
+            if self.accumulation:
+                change += self.model.accumulated_changes
+            self.model.accumulated_gradients.append(change)
         data = self.serialized_model()
         t_post_serialize = time()
         my_uid = self.mapping.get_uid(self.rank, self.machine_id)
@@ -287,24 +330,6 @@ class Wavelet(Sharing):
             total = None
             weight_total = 0
 
-            # FFT of this model
-            shapes = []
-            lens = []
-            tensors_to_cat = []
-            # TODO: should we detach
-            for _, v in self.model.state_dict().items():
-                shapes.append(v.shape)
-                t = v.flatten()
-                lens.append(t.shape[0])
-                tensors_to_cat.append(t)
-            concated = torch.cat(tensors_to_cat, dim=0)
-            coeff = pywt.wavedec(concated.numpy(), self.wavelet, level=self.level)
-            wt_params, coeff_slices = pywt.coeffs_to_array(
-                coeff
-            )  # coeff_slices will be reproduced on the receiver
-            shape = wt_params.shape
-            wt_params = wt_params.ravel()
-
             for i, n in enumerate(self.peer_deques):
                 degree, iteration, data = self.peer_deques[n].popleft()
                 logging.debug(
@@ -348,6 +373,19 @@ class Wavelet(Sharing):
 
         self.communication_round += 1
 
+        with torch.no_grad():
+            self.model.accumulated_gradients = []
+            tensors_to_cat = [
+                v.data.flatten() for _, v in self.model.state_dict().items()
+            ]
+            concated = torch.cat(tensors_to_cat, dim=0)
+            coeff = pywt.wavedec(concated.numpy(), self.wavelet, level=self.level)
+            data, coeff_slices = pywt.coeffs_to_array(coeff)
+            self.init_model = torch.from_numpy(data.ravel())
+            if self.accumulation:
+                self.model.accumulated_changes += self.init_model - self.prev
+                self.prev = self.init_model
+
         t_end = time()
 
         logging.info(
diff --git a/src/decentralizepy/training/FrequencyAccumulator.py b/src/decentralizepy/training/FrequencyAccumulator.py
deleted file mode 100644
index 91e74b3..0000000
--- a/src/decentralizepy/training/FrequencyAccumulator.py
+++ /dev/null
@@ -1,116 +0,0 @@
-import logging
-
-import torch
-from torch import fft
-
-from decentralizepy.training.Training import Training
-
-
-class FrequencyAccumulator(Training):
-    """
-    This class implements the training module which also accumulates the fft frequency at the beginning of steps a communication round.
-
-    """
-
-    def __init__(
-        self,
-        rank,
-        machine_id,
-        mapping,
-        model,
-        optimizer,
-        loss,
-        log_dir,
-        rounds="",
-        full_epochs="",
-        batch_size="",
-        shuffle="",
-        accumulation=True,
-    ):
-        """
-        Constructor
-
-        Parameters
-        ----------
-        rank : int
-            Rank of process local to the machine
-        machine_id : int
-            Machine ID on which the process in running
-        mapping : decentralizepy.mappings
-            The object containing the mapping rank <--> uid
-        model : torch.nn.Module
-            Neural Network for training
-        optimizer : torch.optim
-            Optimizer to learn parameters
-        loss : function
-            Loss function
-        log_dir : str
-            Directory to log the model change.
-        rounds : int, optional
-            Number of steps/epochs per training call
-        full_epochs: bool, optional
-            True if 1 round = 1 epoch. False if 1 round = 1 minibatch
-        batch_size : int, optional
-            Number of items to learn over, in one batch
-        shuffle : bool
-            True if the dataset should be shuffled before training.
-        accumulation : bool
-            True if the model change should be accumulated across communication steps
-        """
-        super().__init__(
-            rank,
-            machine_id,
-            mapping,
-            model,
-            optimizer,
-            loss,
-            log_dir,
-            rounds,
-            full_epochs,
-            batch_size,
-            shuffle,
-        )
-        self.accumulation = accumulation
-        self.init_model = None
-        self.prev = None
-
-    def train(self, dataset):
-        """
-        Does one training iteration.
-        If self.accumulation is True then it accumulates model fft frequency changes in model.accumulated_frequency.
-        Otherwise it stores the current fft frequency representation of the model in model.accumulated_frequency.
-
-        Parameters
-        ----------
-        dataset : decentralizepy.datasets.Dataset
-            The training dataset. Should implement get_trainset(batch_size, shuffle)
-
-        """
-        with torch.no_grad():
-            self.model.accumulated_gradients = []
-            tensors_to_cat = [
-                v.data.flatten() for _, v in self.model.state_dict().items()
-            ]
-            concated = torch.cat(tensors_to_cat, dim=0)
-            self.init_model = fft.rfft(concated)
-            if self.accumulation:
-                if self.model.accumulated_changes is None:
-                    self.model.accumulated_changes = torch.zeros_like(self.init_model)
-                    self.prev = self.init_model
-                else:
-                    self.model.accumulated_changes += self.init_model - self.prev
-                    self.prev = self.init_model
-
-        super().train(dataset)
-
-        with torch.no_grad():
-            tensors_to_cat = [
-                v.data.flatten() for _, v in self.model.state_dict().items()
-            ]
-            concated = torch.cat(tensors_to_cat, dim=0)
-            end_model = fft.rfft(concated)
-            change = end_model - self.init_model
-            if self.accumulation:
-                change += self.model.accumulated_changes
-
-            self.model.accumulated_gradients.append(change)
diff --git a/src/decentralizepy/training/FrequencyWaveletAccumulator.py b/src/decentralizepy/training/FrequencyWaveletAccumulator.py
deleted file mode 100644
index 54238ab..0000000
--- a/src/decentralizepy/training/FrequencyWaveletAccumulator.py
+++ /dev/null
@@ -1,125 +0,0 @@
-import logging
-
-import numpy as np
-import pywt
-import torch
-
-from decentralizepy.training.Training import Training
-
-
-class FrequencyWaveletAccumulator(Training):
-    """
-    This class implements the training module which also accumulates the wavelet frequency at the beginning of steps a communication round.
-
-    """
-
-    def __init__(
-        self,
-        rank,
-        machine_id,
-        mapping,
-        model,
-        optimizer,
-        loss,
-        log_dir,
-        rounds="",
-        full_epochs="",
-        batch_size="",
-        shuffle="",
-        wavelet="haar",
-        level=4,
-        accumulation=True,
-    ):
-        """
-        Constructor
-
-        Parameters
-        ----------
-        rank : int
-            Rank of process local to the machine
-        machine_id : int
-            Machine ID on which the process in running
-        mapping : decentralizepy.mappings
-            The object containing the mapping rank <--> uid
-        model : torch.nn.Module
-            Neural Network for training
-        optimizer : torch.optim
-            Optimizer to learn parameters
-        loss : function
-            Loss function
-        log_dir : str
-            Directory to log the model change.
-        rounds : int, optional
-            Number of steps/epochs per training call
-        full_epochs: bool, optional
-            True if 1 round = 1 epoch. False if 1 round = 1 minibatch
-        batch_size : int, optional
-            Number of items to learn over, in one batch
-        shuffle : bool
-            True if the dataset should be shuffled before training.
-        accumulation : bool
-            True if the model change should be accumulated across communication steps
-        """
-        super().__init__(
-            rank,
-            machine_id,
-            mapping,
-            model,
-            optimizer,
-            loss,
-            log_dir,
-            rounds,
-            full_epochs,
-            batch_size,
-            shuffle,
-        )
-        self.wavelet = wavelet
-        self.level = level
-        self.accumulation = accumulation
-
-    def train(self, dataset):
-        """
-        Does one training iteration.
-        If self.accumulation is True then it accumulates model wavelet frequency changes in model.accumulated_frequency.
-        Otherwise it stores the current wavelet frequency representation of the model in model.accumulated_frequency.
-
-        Parameters
-        ----------
-        dataset : decentralizepy.datasets.Dataset
-            The training dataset. Should implement get_trainset(batch_size, shuffle)
-
-        """
-
-        # this looks at the change from the last round averaging of the frequencies
-        with torch.no_grad():
-            self.model.accumulated_gradients = []
-            tensors_to_cat = [
-                v.data.flatten() for _, v in self.model.state_dict().items()
-            ]
-            concated = torch.cat(tensors_to_cat, dim=0)
-            coeff = pywt.wavedec(concated.numpy(), self.wavelet, level=self.level)
-            data, coeff_slices = pywt.coeffs_to_array(coeff)
-            self.init_model = torch.from_numpy(data.ravel())
-            if self.accumulation:
-                if self.model.accumulated_changes is None:
-                    self.model.accumulated_changes = torch.zeros_like(self.init_model)
-                    self.prev = self.init_model
-                else:
-                    self.model.accumulated_changes += self.init_model - self.prev
-                    self.prev = self.init_model
-
-        super().train(dataset)
-
-        with torch.no_grad():
-            tensors_to_cat = [
-                v.data.flatten() for _, v in self.model.state_dict().items()
-            ]
-            concated = torch.cat(tensors_to_cat, dim=0)
-            coeff = pywt.wavedec(concated.numpy(), self.wavelet, level=self.level)
-            data, coeff_slices = pywt.coeffs_to_array(coeff)
-            end_model = torch.from_numpy(data.ravel())
-            change = end_model - self.init_model
-            if self.accumulation:
-                change += self.model.accumulated_changes
-
-            self.model.accumulated_gradients.append(change)
--
GitLab
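
Note for readers (illustrative, not part of the commit): after this change the transform-and-accumulate bookkeeping lives entirely in FFT.step() and Wavelet.step(). Each communication round has three moments: take a transform-domain snapshot of the flattened model (init_model), share the change since the last snapshot plus the running accumulator, and, after averaging, fold the drift since the previous snapshot into model.accumulated_changes. The accumulation flag is now read from the [SHARING] section of the step configs (see the two .ini hunks above) rather than [TRAIN_PARAMS]. The sketch below mirrors this flow for the FFT case only; the class ChangeAccumulator and its method names are hypothetical and do not exist in decentralizepy.

import torch
from torch import fft


class ChangeAccumulator:
    """Minimal sketch of the accumulation bookkeeping, assuming an rfft transform."""

    def __init__(self, flat_params, accumulation=True):
        # Transform-domain snapshot taken when sharing is initialised
        # (mirrors init_model / prev / model.accumulated_changes in the patch).
        self.accumulation = accumulation
        self.init_model = fft.rfft(flat_params)
        self.prev = self.init_model
        self.accumulated_changes = torch.zeros_like(self.init_model)

    def pre_share(self, flat_params):
        # Before serializing: the shared value is the change since the last
        # snapshot, plus everything accumulated over earlier rounds.
        end_model = fft.rfft(flat_params)
        change = end_model - self.init_model
        if self.accumulation:
            change = change + self.accumulated_changes
        return change

    def post_average(self, flat_params):
        # After averaging with neighbours: re-snapshot the model and fold the
        # drift since the previous snapshot into the running accumulator.
        self.init_model = fft.rfft(flat_params)
        if self.accumulation:
            self.accumulated_changes += self.init_model - self.prev
            self.prev = self.init_model


if __name__ == "__main__":
    params = torch.randn(16)
    acc = ChangeAccumulator(params, accumulation=True)
    trained = params + 0.01 * torch.randn(16)  # stands in for local training
    shared = acc.pre_share(trained)            # what gets topK-selected and sent
    acc.post_average(trained)                  # called once averaging is done
    print(shared.shape)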