diff --git a/README.rst b/README.rst
index 38cba02c7e0a8ce478cd67641a3ce0f34fdc55ec..5acb1c43daadf40a8d6e0e06a61d0bcfa354b0b5 100644
--- a/README.rst
+++ b/README.rst
@@ -15,6 +15,7 @@ Setting up decentralizepy
     pip3 install --upgrade pip
     pip install --upgrade pip
 
+* On Mac M1, installing ``pyzmq`` fails with ``pip``. Use ``conda``.
 * Install decentralizepy for development. ::
 
     pip3 install --editable .\[dev\]
diff --git a/eval/main.ipynb b/eval/main.ipynb
index db66e893b9325a7fffbbcaa421197b7cbc33899b..80daae672a600f6358cfa9f559b50ecb47cf2261 100644
--- a/eval/main.ipynb
+++ b/eval/main.ipynb
@@ -268,7 +268,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 24,
+   "execution_count": 1,
    "metadata": {},
    "outputs": [
     {
@@ -300,6 +300,26 @@
      "\n"
     ]
    },
+   {
+    "cell_type": "code",
+    "execution_count": 2,
+    "metadata": {},
+    "outputs": [
+     {
+      "data": {
+       "text/plain": [
+        "<generator object Module.parameters at 0x111ec1ba0>"
+       ]
+      },
+      "execution_count": 2,
+      "metadata": {},
+      "output_type": "execute_result"
+     }
+    ],
+    "source": [
+     "m1.parameters()"
+    ]
+   },
   {
    "cell_type": "code",
    "execution_count": 25,
@@ -5715,7 +5735,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.9.7"
+   "version": "3.8.11"
   },
  "orig_nbformat": 4
 },
diff --git a/src/decentralizepy/node/Node.py b/src/decentralizepy/node/Node.py
index ef66d1752dae14a341b8c04042ee42961544f91a..24c316818feb5698531e90eb498ced15253e03f5 100644
--- a/src/decentralizepy/node/Node.py
+++ b/src/decentralizepy/node/Node.py
@@ -177,7 +177,7 @@ class Node:
             ],
         )
         self.trainer = train_class(
-            self.model, self.optimizer, self.loss, **train_params
+            self.model, self.optimizer, self.loss, self.log_dir, **train_params
         )
 
     def init_comm(self, comm_configs):
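With the ``Node.py`` change above, every training class now receives the node's log directory as its fourth positional argument. A minimal sketch of the resulting construction, assuming ``decentralizepy`` is importable; the model, optimizer, loss, and keyword values are illustrative placeholders, not taken from a real config. ::

    import torch
    from torch import nn
    from torch.nn import functional as F

    from decentralizepy.training.Training import Training

    model = nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # Mirrors the updated call in Node.__init__: log_dir is the newly
    # threaded-through fourth argument; the keyword arguments normally
    # come from train_params in the node's config.
    trainer = Training(
        model,
        optimizer,
        F.cross_entropy,
        "./logs",  # log_dir
        rounds=5,
        full_epochs=False,
        batch_size=64,
        shuffle=True,
    )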
diff --git a/src/decentralizepy/training/ChangeAccumulator.py b/src/decentralizepy/training/ChangeAccumulator.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf9641fc8da1b117671f384650615ec7b4e8c77e
--- /dev/null
+++ b/src/decentralizepy/training/ChangeAccumulator.py
@@ -0,0 +1,148 @@
+import json
+import os
+from pathlib import Path
+
+import torch
+
+from decentralizepy.training.Training import Training
+from decentralizepy.utils import conditional_value
+
+
+class ChangeAccumulator(Training):
+    """
+    This class implements the training module, which also accumulates the model change in a list.
+
+    """
+
+    def __init__(
+        self,
+        model,
+        optimizer,
+        loss,
+        log_dir,
+        rounds="",
+        full_epochs="",
+        batch_size="",
+        shuffle="",
+        save_accumulated="",
+    ):
+        """
+        Constructor
+
+        Parameters
+        ----------
+        model : torch.nn.Module
+            Neural Network for training
+        optimizer : torch.optim
+            Optimizer to learn parameters
+        loss : function
+            Loss function
+        log_dir : str
+            Directory to log the model change.
+        rounds : int, optional
+            Number of steps/epochs per training call
+        full_epochs : bool, optional
+            True if 1 round = 1 epoch. False if 1 round = 1 minibatch
+        batch_size : int, optional
+            Number of items to learn over, in one batch
+        shuffle : bool
+            True if the dataset should be shuffled before training.
+        save_accumulated : bool
+            True if the accumulated weight change should be written to file
+
+        """
+        super().__init__(
+            model, optimizer, loss, log_dir, rounds, full_epochs, batch_size, shuffle
+        )
+        self.save_accumulated = conditional_value(save_accumulated, "", True)
+        self.communication_round = 0
+        if self.save_accumulated:
+            self.model_change_path = os.path.join(
+                self.log_dir, "model_change/{}".format(self.rank)
+            )
+            Path(self.model_change_path).mkdir(parents=True, exist_ok=True)
+
+            self.model_val_path = os.path.join(
+                self.log_dir, "model_val/{}".format(self.rank)
+            )
+            Path(self.model_val_path).mkdir(parents=True, exist_ok=True)
+
+    def save_vector(self, v, s):
+        """
+        Saves the given vector to a JSON file, together with the parameter order and shapes.
+
+        Parameters
+        ----------
+        v : torch.tensor
+            The torch tensor to write to file
+        s : str
+            Path to the folder to write to
+
+        """
+        output_dict = dict()
+        output_dict["order"] = list(self.model.state_dict().keys())
+        shapes = dict()
+        # Use a loop variable distinct from the parameter v, so that the
+        # vector written below is the one passed to this method.
+        for k, w in self.model.state_dict().items():
+            shapes[k] = list(w.shape)
+        output_dict["shapes"] = shapes
+
+        output_dict[self.communication_round] = v.tolist()
+
+        with open(
+            os.path.join(
+                s,
+                "{}.json".format(self.communication_round + 1),
+            ),
+            "w",
+        ) as of:
+            json.dump(output_dict, of)
+
+    def save_change(self):
+        """
+        Saves the absolute model change for the current iteration
+
+        """
+        tensors_to_cat = [
+            v.data.flatten() for _, v in self.model.accumulated_gradients[0].items()
+        ]
+        change = torch.abs(torch.cat(tensors_to_cat, dim=0))
+        self.save_vector(change, self.model_change_path)
+
+    def save_model_params(self):
+        """
+        Saves the absolute model parameter values for the current iteration
+
+        """
+        tensors_to_cat = [
+            v.data.flatten() for _, v in self.model.state_dict().items()
+        ]
+        params = torch.abs(torch.cat(tensors_to_cat, dim=0))
+        self.save_vector(params, self.model_val_path)
+
+    def train(self, dataset):
+        """
+        One training iteration with accumulation of the model change in model.accumulated_gradients.
+        Goes through the entire dataset.
+
+        Parameters
+        ----------
+        dataset : decentralizepy.datasets.Dataset
+            The training dataset. Should implement get_trainset(batch_size, shuffle)
+
+        """
+        self.model.accumulated_gradients = []
+        # Snapshot the parameters before training ...
+        self.init_model = {
+            k: v.data.clone().detach()
+            for k, v in zip(self.model.state_dict(), self.model.parameters())
+        }
+        super().train(dataset)
+        # ... and store the post-training difference as the accumulated change.
+        with torch.no_grad():
+            change = {
+                k: v.data.clone().detach() - self.init_model[k]
+                for k, v in zip(self.model.state_dict(), self.model.parameters())
+            }
+        self.model.accumulated_gradients.append(change)
+
+        if self.save_accumulated:
+            self.save_change()
+            self.save_model_params()
+
+        self.communication_round += 1
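A vector written by ``save_vector`` above can be read back and un-flattened with a sketch like the following. ``load_saved_vector`` is a hypothetical helper, not part of the package; it assumes the flat vector follows the ``order`` and ``shapes`` entries that ``save_vector`` records (which holds for ``save_model_params`` always, and for ``save_change`` when the model has no buffers), and that ``json.dump`` has turned the integer round key into a string. ::

    import json

    import torch


    def load_saved_vector(path, communication_round):
        # Reads one JSON file produced by ChangeAccumulator.save_vector and
        # splits the flat vector back into per-parameter tensors.
        with open(path) as f:
            saved = json.load(f)
        flat = torch.tensor(saved[str(communication_round)])
        result, offset = {}, 0
        for name in saved["order"]:
            shape = saved["shapes"][name]
            numel = 1
            for d in shape:
                numel *= d
            result[name] = flat[offset : offset + numel].reshape(shape)
            offset += numel
        return result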
diff --git a/src/decentralizepy/training/GradientAccumulator.py b/src/decentralizepy/training/GradientAccumulator.py
index 3c160594d74e6f50ae107cc0f486077e5428c6bc..3171019f7397cb49b757e3a41315ec967fa5d27f 100644
--- a/src/decentralizepy/training/GradientAccumulator.py
+++ b/src/decentralizepy/training/GradientAccumulator.py
@@ -14,6 +14,7 @@ class GradientAccumulator(Training):
         model,
         optimizer,
         loss,
+        log_dir,
         rounds="",
         full_epochs="",
         batch_size="",
@@ -30,6 +31,8 @@ class GradientAccumulator(Training):
             Optimizer to learn parameters
         loss : function
             Loss function
+        log_dir : str
+            Directory to log the model change.
         rounds : int, optional
             Number of steps/epochs per training call
         full_epochs : bool, optional
@@ -41,7 +44,7 @@ class GradientAccumulator(Training):
 
         """
         super().__init__(
-            model, optimizer, loss, rounds, full_epochs, batch_size, shuffle
+            model, optimizer, loss, log_dir, rounds, full_epochs, batch_size, shuffle
         )
 
     def trainstep(self, data, target):
diff --git a/src/decentralizepy/training/Training.py b/src/decentralizepy/training/Training.py
index 6696f357bbd631151be5a5f6a9eaa747a8b01493..47c8f778fb38761f766504d7461b72c9b2005618 100644
--- a/src/decentralizepy/training/Training.py
+++ b/src/decentralizepy/training/Training.py
@@ -16,6 +16,7 @@ class Training:
         model,
         optimizer,
         loss,
+        log_dir,
         rounds="",
         full_epochs="",
         batch_size="",
@@ -32,6 +33,8 @@ class Training:
             Optimizer to learn parameters
         loss : function
             Loss function
+        log_dir : str
+            Directory to log the model change.
         rounds : int, optional
             Number of steps/epochs per training call
         full_epochs : bool, optional
@@ -45,6 +48,7 @@ class Training:
         self.model = model
         self.optimizer = optimizer
         self.loss = loss
+        self.log_dir = log_dir
         self.rounds = utils.conditional_value(rounds, "", int(1))
         self.full_epochs = utils.conditional_value(full_epochs, "", False)
         self.batch_size = utils.conditional_value(batch_size, "", int(1))
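Any out-of-tree subclass of ``Training`` has to mirror the ``GradientAccumulator`` change above and forward the new ``log_dir`` positional argument to ``super().__init__``. A minimal sketch, with a hypothetical class name and log file. ::

    import os

    from decentralizepy.training.Training import Training


    class LossLoggingTraining(Training):
        # Hypothetical subclass: forwards log_dir to the base class and
        # derives the path of its own log file under that directory.
        def __init__(
            self, model, optimizer, loss, log_dir,
            rounds="", full_epochs="", batch_size="", shuffle="",
        ):
            super().__init__(
                model, optimizer, loss, log_dir, rounds, full_epochs, batch_size, shuffle
            )
            self.loss_log_path = os.path.join(self.log_dir, "train_loss.txt")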