Commit 573e4331 authored by Jeffrey Wigger

FFT and wavelet: move everything to sharing

parent ed4148ea
......@@ -15,15 +15,14 @@ lr = 0.001
# There are 734463 femnist samples
[TRAIN_PARAMS]
training_package = decentralizepy.training.FrequencyAccumulator
training_class = FrequencyAccumulator
training_package = decentralizepy.training.Training
training_class = Training
rounds = 47
full_epochs = False
batch_size = 16
shuffle = True
loss_package = torch.nn
loss_class = CrossEntropyLoss
accumulation = True
[COMMUNICATION]
comm_package = decentralizepy.communication.TCP
......@@ -34,4 +33,5 @@ addresses_filepath = ip_addr_6Machines.json
sharing_package = decentralizepy.sharing.FFT
sharing_class = FFT
alpha = 0.1
change_based_selection = True
\ No newline at end of file
change_based_selection = True
accumulation = True
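For reference, once this hunk is applied the sharing options of this config read as below (assuming, as the hunk context suggests, that these keys sit under the [SHARING] header); accumulation now lives with the other sharing settings instead of under [TRAIN_PARAMS]:

[SHARING]
sharing_package = decentralizepy.sharing.FFT
sharing_class = FFT
alpha = 0.1
change_based_selection = True
accumulation = True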
......@@ -15,17 +15,14 @@ lr = 0.001
# There are 734463 femnist samples
[TRAIN_PARAMS]
training_package = decentralizepy.training.FrequencyWaveletAccumulator
training_class = FrequencyWaveletAccumulator
training_package = decentralizepy.training.Training
training_class = Training
rounds = 47
full_epochs = False
batch_size = 16
shuffle = True
loss_package = torch.nn
loss_class = CrossEntropyLoss
wavelet = sym2
level = None
accumulation = True
[COMMUNICATION]
comm_package = decentralizepy.communication.TCP
......@@ -39,3 +36,4 @@ change_based_selection = True
alpha = 0.1
wavelet = sym2
level = None
accumulation = True
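Likewise, the wavelet config's sharing section ends up as follows (again assuming a [SHARING] header above the lines shown in the hunk):

change_based_selection = True
alpha = 0.1
wavelet = sym2
level = None
accumulation = True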
......@@ -99,6 +99,23 @@ class FFT(Sharing):
self.change_based_selection = change_based_selection
self.accumulation = accumulation
# take an initial FFT snapshot of the flattened model parameters
with torch.no_grad():
self.model.accumulated_gradients = []
tensors_to_cat = [
v.data.flatten() for _, v in self.model.state_dict().items()
]
concated = torch.cat(tensors_to_cat, dim=0)
self.init_model = fft.rfft(concated)
self.prev = None
if self.accumulation:
if self.model.accumulated_changes is None:
self.model.accumulated_changes = torch.zeros_like(self.init_model)
self.prev = self.init_model
else:
self.model.accumulated_changes += self.init_model - self.prev
self.prev = self.init_model
def apply_fft(self):
"""
Applies the FFT to the flattened model parameters and selects the top-k (an alpha fraction) of the coefficients in the frequency domain
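As a standalone illustration of the selection described in this docstring, here is a minimal sketch of keeping the top alpha fraction of FFT coefficients of a flattened parameter vector. It is not the project's implementation; the function name and the ranking by coefficient magnitude are assumptions (with change_based_selection the ranking would be over the change instead):

import torch
from torch import fft

def topk_fft_sketch(flat_params: torch.Tensor, alpha: float = 0.1):
    # real FFT of the flattened parameter vector
    freq = fft.rfft(flat_params)
    k = max(1, int(alpha * freq.shape[0]))
    # rank coefficients by magnitude and keep the k largest
    _, indices = torch.topk(freq.abs(), k, sorted=False)
    return indices, freq[indices]  # the (index, value) pairs that would be shared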
......@@ -225,6 +242,25 @@ class FFT(Sharing):
"""
t_start = time()
shapes = []
lens = []
end_model = None
change = 0
self.model.accumulated_gradients = []
with torch.no_grad():
# FFT of this model
tensors_to_cat = []
for _, v in self.model.state_dict().items():
shapes.append(v.shape)
t = v.flatten()
lens.append(t.shape[0])
tensors_to_cat.append(t)
concated = torch.cat(tensors_to_cat, dim=0)
end_model = fft.rfft(concated)
change = end_model - self.init_model
if self.accumulation:
change += self.model.accumulated_changes
self.model.accumulated_gradients.append(change)
data = self.serialized_model()
t_post_serialize = time()
my_uid = self.mapping.get_uid(self.rank, self.machine_id)
......@@ -255,17 +291,7 @@ class FFT(Sharing):
total = None
weight_total = 0
# FFT of this model
shapes = []
lens = []
tensors_to_cat = []
for _, v in self.model.state_dict().items():
shapes.append(v.shape)
t = v.flatten()
lens.append(t.shape[0])
tensors_to_cat.append(t)
concated = torch.cat(tensors_to_cat, dim=0)
flat_fft = fft.rfft(concated)
flat_fft = end_model
for i, n in enumerate(self.peer_deques):
degree, iteration, data = self.peer_deques[n].popleft()
......@@ -303,6 +329,17 @@ class FFT(Sharing):
self.communication_round += 1
with torch.no_grad():
self.model.accumulated_gradients = []
tensors_to_cat = [
v.data.flatten() for _, v in self.model.state_dict().items()
]
concated = torch.cat(tensors_to_cat, dim=0)
self.init_model = fft.rfft(concated)
if self.accumulation:
self.model.accumulated_changes += self.init_model - self.prev
self.prev = self.init_model
t_end = time()
logging.info(
......
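Taken together, the constructor, step and averaging hunks above implement one bookkeeping pattern: snapshot the model's FFT at the start of a round (init_model), remember the previous snapshot (prev), and fold the drift introduced by averaging into accumulated_changes, so the shared change always covers everything since the last reset. Below is a self-contained sketch of that pattern; the class and method names are illustrative, not the project's API:

import torch
from torch import fft

def flatten_state(model: torch.nn.Module) -> torch.Tensor:
    # concatenate every tensor in the state dict into one flat vector
    return torch.cat([v.data.flatten() for _, v in model.state_dict().items()], dim=0)

class AccumulatorSketch:
    def __init__(self, model, accumulation=True):
        self.model = model
        self.accumulation = accumulation
        self.init_model = fft.rfft(flatten_state(model))   # snapshot at round start
        self.prev = self.init_model
        self.accumulated_changes = torch.zeros_like(self.init_model)

    def change_to_share(self):
        # at sharing time: change w.r.t. the round-start snapshot,
        # plus any drift accumulated from earlier averaging steps
        end_model = fft.rfft(flatten_state(self.model))
        change = end_model - self.init_model
        if self.accumulation:
            change = change + self.accumulated_changes
        return change

    def after_averaging(self):
        # after averaging: re-snapshot and fold the averaging-induced
        # drift into the running accumulator
        self.init_model = fft.rfft(flatten_state(self.model))
        if self.accumulation:
            self.accumulated_changes += self.init_model - self.prev
            self.prev = self.init_model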
......@@ -106,6 +106,26 @@ class Wavelet(Sharing):
Path(self.folder_path).mkdir(parents=True, exist_ok=True)
self.change_based_selection = change_based_selection
self.accumulation = accumulation
# take an initial wavelet snapshot of the flattened model parameters
with torch.no_grad():
self.model.accumulated_gradients = []
tensors_to_cat = [
v.data.flatten() for _, v in self.model.state_dict().items()
]
concated = torch.cat(tensors_to_cat, dim=0)
coeff = pywt.wavedec(concated.numpy(), self.wavelet, level=self.level)
data, coeff_slices = pywt.coeffs_to_array(coeff)
self.init_model = torch.from_numpy(data.ravel())
self.prev = None
if self.accumulation:
if self.model.accumulated_changes is None:
self.model.accumulated_changes = torch.zeros_like(self.init_model)
self.prev = self.init_model
else:
self.model.accumulated_changes += self.init_model - self.prev
self.prev = self.init_model
def apply_wavelet(self):
"""
......@@ -257,6 +277,29 @@ class Wavelet(Sharing):
"""
t_start = time()
shapes = []
lens = []
end_model = None
change = 0
self.model.accumulated_gradients = []
with torch.no_grad():
# wavelet transform of this model
tensors_to_cat = []
for _, v in self.model.state_dict().items():
shapes.append(v.shape)
t = v.flatten()
lens.append(t.shape[0])
tensors_to_cat.append(t)
concated = torch.cat(tensors_to_cat, dim=0)
coeff = pywt.wavedec(concated.numpy(), self.wavelet, level=self.level)
data, coeff_slices = pywt.coeffs_to_array(coeff)
shape = data.shape
wt_params = data.ravel()
end_model = torch.from_numpy(wt_params)
change = end_model - self.init_model
if self.accumulation:
change += self.model.accumulated_changes
self.model.accumulated_gradients.append(change)
data = self.serialized_model()
t_post_serialize = time()
my_uid = self.mapping.get_uid(self.rank, self.machine_id)
......@@ -287,24 +330,6 @@ class Wavelet(Sharing):
total = None
weight_total = 0
# FFT of this model
shapes = []
lens = []
tensors_to_cat = []
# TODO: should we detach
for _, v in self.model.state_dict().items():
shapes.append(v.shape)
t = v.flatten()
lens.append(t.shape[0])
tensors_to_cat.append(t)
concated = torch.cat(tensors_to_cat, dim=0)
coeff = pywt.wavedec(concated.numpy(), self.wavelet, level=self.level)
wt_params, coeff_slices = pywt.coeffs_to_array(
coeff
) # coeff_slices will be reproduced on the receiver
shape = wt_params.shape
wt_params = wt_params.ravel()
for i, n in enumerate(self.peer_deques):
degree, iteration, data = self.peer_deques[n].popleft()
logging.debug(
......@@ -348,6 +373,19 @@ class Wavelet(Sharing):
self.communication_round += 1
with torch.no_grad():
self.model.accumulated_gradients = []
tensors_to_cat = [
v.data.flatten() for _, v in self.model.state_dict().items()
]
concated = torch.cat(tensors_to_cat, dim=0)
coeff = pywt.wavedec(concated.numpy(), self.wavelet, level=self.level)
data, coeff_slices = pywt.coeffs_to_array(coeff)
self.init_model = torch.from_numpy(data.ravel())
if self.accumulation:
self.model.accumulated_changes += self.init_model - self.prev
self.prev = self.init_model
t_end = time()
logging.info(
......
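The wavelet variant leans on the pywt coeffs_to_array / array_to_coeffs pair: the sender flattens the wavedec coefficients into a single array so it can diff and share them, and the receiver can rebuild the coefficient structure from the same slice bookkeeping (the "coeff_slices will be reproduced on the receiver" comment above). A minimal, purely illustrative roundtrip:

import numpy as np
import pywt
import torch

params = torch.randn(1000)                        # stand-in for the flattened model
coeff = pywt.wavedec(params.numpy(), "sym2", level=None)
flat, coeff_slices = pywt.coeffs_to_array(coeff)  # flat array + slice bookkeeping

# ... flat (or a selected subset of it) is what gets diffed and shared ...

coeff_back = pywt.array_to_coeffs(flat, coeff_slices, output_format="wavedec")
reconstructed = pywt.waverec(coeff_back, "sym2")
assert np.allclose(params.numpy(), reconstructed[: params.shape[0]], atol=1e-5)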
import logging
import torch
from torch import fft
from decentralizepy.training.Training import Training
class FrequencyAccumulator(Training):
"""
This class implements the training module which also accumulates the FFT frequency representation of the model at the beginning of a communication round.
"""
def __init__(
self,
rank,
machine_id,
mapping,
model,
optimizer,
loss,
log_dir,
rounds="",
full_epochs="",
batch_size="",
shuffle="",
accumulation=True,
):
"""
Constructor
Parameters
----------
rank : int
Rank of process local to the machine
machine_id : int
Machine ID on which the process is running
mapping : decentralizepy.mappings
The object containing the mapping rank <--> uid
model : torch.nn.Module
Neural Network for training
optimizer : torch.optim
Optimizer to learn parameters
loss : function
Loss function
log_dir : str
Directory to log the model change.
rounds : int, optional
Number of steps/epochs per training call
full_epochs: bool, optional
True if 1 round = 1 epoch. False if 1 round = 1 minibatch
batch_size : int, optional
Number of items to learn over, in one batch
shuffle : bool
True if the dataset should be shuffled before training.
accumulation : bool
True if the model change should be accumulated across communication steps
"""
super().__init__(
rank,
machine_id,
mapping,
model,
optimizer,
loss,
log_dir,
rounds,
full_epochs,
batch_size,
shuffle,
)
self.accumulation = accumulation
self.init_model = None
self.prev = None
def train(self, dataset):
"""
Does one training iteration.
If self.accumulation is True, the drift of the model's FFT representation since the previous call (introduced by averaging) is folded into model.accumulated_changes, and the change appended to model.accumulated_gradients covers both the current training call and that accumulated drift.
Otherwise only the change from the current training call is appended to model.accumulated_gradients.
Parameters
----------
dataset : decentralizepy.datasets.Dataset
The training dataset. Should implement get_trainset(batch_size, shuffle)
"""
with torch.no_grad():
self.model.accumulated_gradients = []
tensors_to_cat = [
v.data.flatten() for _, v in self.model.state_dict().items()
]
concated = torch.cat(tensors_to_cat, dim=0)
self.init_model = fft.rfft(concated)
if self.accumulation:
if self.model.accumulated_changes is None:
self.model.accumulated_changes = torch.zeros_like(self.init_model)
self.prev = self.init_model
else:
self.model.accumulated_changes += self.init_model - self.prev
self.prev = self.init_model
super().train(dataset)
with torch.no_grad():
tensors_to_cat = [
v.data.flatten() for _, v in self.model.state_dict().items()
]
concated = torch.cat(tensors_to_cat, dim=0)
end_model = fft.rfft(concated)
change = end_model - self.init_model
if self.accumulation:
change += self.model.accumulated_changes
self.model.accumulated_gradients.append(change)
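Before this commit the trainer above was wired in through [TRAIN_PARAMS] (the lines removed in the first config hunk); after it, the configs point at the plain Training class and the FFT snapshotting and accumulation shown here are done by decentralizepy.sharing.FFT instead:

[TRAIN_PARAMS]
training_package = decentralizepy.training.FrequencyAccumulator
training_class = FrequencyAccumulator
accumulation = True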
import logging
import numpy as np
import pywt
import torch
from decentralizepy.training.Training import Training
class FrequencyWaveletAccumulator(Training):
"""
This class implements the training module which also accumulates the wavelet representation of the model at the beginning of a communication round.
"""
def __init__(
self,
rank,
machine_id,
mapping,
model,
optimizer,
loss,
log_dir,
rounds="",
full_epochs="",
batch_size="",
shuffle="",
wavelet="haar",
level=4,
accumulation=True,
):
"""
Constructor
Parameters
----------
rank : int
Rank of process local to the machine
machine_id : int
Machine ID on which the process is running
mapping : decentralizepy.mappings
The object containing the mapping rank <--> uid
model : torch.nn.Module
Neural Network for training
optimizer : torch.optim
Optimizer to learn parameters
loss : function
Loss function
log_dir : str
Directory to log the model change.
rounds : int, optional
Number of steps/epochs per training call
full_epochs: bool, optional
True if 1 round = 1 epoch. False if 1 round = 1 minibatch
batch_size : int, optional
Number of items to learn over, in one batch
shuffle : bool
True if the dataset should be shuffled before training.
wavelet : str, optional
Name of the pywt wavelet to use (default "haar")
level : int, optional
Decomposition level passed to pywt.wavedec (default 4)
accumulation : bool
True if the model change should be accumulated across communication steps
"""
super().__init__(
rank,
machine_id,
mapping,
model,
optimizer,
loss,
log_dir,
rounds,
full_epochs,
batch_size,
shuffle,
)
self.wavelet = wavelet
self.level = level
self.accumulation = accumulation
def train(self, dataset):
"""
Does one training iteration.
If self.accumulation is True, the drift of the model's wavelet representation since the previous call (introduced by averaging) is folded into model.accumulated_changes, and the change appended to model.accumulated_gradients covers both the current training call and that accumulated drift.
Otherwise only the change from the current training call is appended to model.accumulated_gradients.
Parameters
----------
dataset : decentralizepy.datasets.Dataset
The training dataset. Should implement get_trainset(batch_size, shuffle)
"""
# snapshot the wavelet representation so the change since the last round's averaging can be measured
with torch.no_grad():
self.model.accumulated_gradients = []
tensors_to_cat = [
v.data.flatten() for _, v in self.model.state_dict().items()
]
concated = torch.cat(tensors_to_cat, dim=0)
coeff = pywt.wavedec(concated.numpy(), self.wavelet, level=self.level)
data, coeff_slices = pywt.coeffs_to_array(coeff)
self.init_model = torch.from_numpy(data.ravel())
if self.accumulation:
if self.model.accumulated_changes is None:
self.model.accumulated_changes = torch.zeros_like(self.init_model)
self.prev = self.init_model
else:
self.model.accumulated_changes += self.init_model - self.prev
self.prev = self.init_model
super().train(dataset)
with torch.no_grad():
tensors_to_cat = [
v.data.flatten() for _, v in self.model.state_dict().items()
]
concated = torch.cat(tensors_to_cat, dim=0)
coeff = pywt.wavedec(concated.numpy(), self.wavelet, level=self.level)
data, coeff_slices = pywt.coeffs_to_array(coeff)
end_model = torch.from_numpy(data.ravel())
change = end_model - self.init_model
if self.accumulation:
change += self.model.accumulated_changes
self.model.accumulated_gradients.append(change)
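The level argument controls how many decomposition bands pywt.wavedec produces (level=None lets pywt pick the maximum level for the signal length, which is what the wavelet config above requests). A quick, illustrative way to see what a given setting yields:

import numpy as np
import pywt

signal = np.zeros(2 ** 12)                    # stand-in for the flattened parameters
for level in (1, 4, None):
    coeff = pywt.wavedec(signal, "sym2", level=level)
    flat, _ = pywt.coeffs_to_array(coeff)
    print(level, "-> bands:", len(coeff), "total coefficients:", flat.size)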