diff --git a/src/decentralizepy/training/GradientAccumulator.py b/src/decentralizepy/training/GradientAccumulator.py index 23d4fdc7ad68b10c4254c37e2166e9a33d452c05..5022308fd78800b6029ce6a7a9c4a9604ef5dce1 100644 --- a/src/decentralizepy/training/GradientAccumulator.py +++ b/src/decentralizepy/training/GradientAccumulator.py @@ -5,7 +5,7 @@ from decentralizepy.training.Training import Training class GradientAccumulator(Training): def __init__( - self, model, optimizer, loss, epochs_per_round="", batch_size="", shuffle="" + self, model, optimizer, loss, rounds="", full_epochs="", batch_size="", shuffle="" ): """ Constructor @@ -22,9 +22,9 @@ class GradientAccumulator(Training): batch_size : int, optional Number of items to learn over, in one batch shuffle : bool - True if the dataset should be shuffled before training. Not implemented yet! TODO + True if the dataset should be shuffled before training. """ - super().__init__(model, optimizer, loss, epochs_per_round, batch_size, shuffle) + super().__init__(model, optimizer, loss, rounds, full_epochs, batch_size, shuffle) def trainstep(self, data, target): """ diff --git a/src/decentralizepy/training/Training.py b/src/decentralizepy/training/Training.py index d5b3e9c14594557a94f0ab5a1c4e0124d43cf628..7d594351fd28dc1be3207ca8b74700f13e45b9c6 100644 --- a/src/decentralizepy/training/Training.py +++ b/src/decentralizepy/training/Training.py @@ -37,7 +37,7 @@ class Training: batch_size : int, optional Number of items to learn over, in one batch shuffle : bool - True if the dataset should be shuffled before training. Not implemented yet! TODO + True if the dataset should be shuffled before training. """ self.model = model self.optimizer = optimizer