From 76f1fbb08d290a13893778491b25f2185bd1553d Mon Sep 17 00:00:00 2001
From: Jeffrey Wigger <jeffrey.wigger@epfl.ch>
Date: Wed, 9 Mar 2022 12:20:31 +0100
Subject: [PATCH] integrating new sharing methods

---
 eval/96_regular.edges                         | 381 ++++++++++++++++++
 eval/plot.py                                  |   6 +
 eval/run.sh                                   |  22 +-
 eval/step_configs/config_femnist_fft.ini      |  37 ++
 eval/step_configs/config_femnist_sharing.ini  |  32 ++
 .../config_femnist_subsampling.ini            |  34 ++
 eval/step_configs/config_femnist_topk.ini     |  36 ++
 .../step_configs/config_femnist_topkparam.ini |  34 ++
 eval/step_configs/config_femnist_wavelet.ini  |  41 ++
 eval/testing.py                               |   3 +-
 setup.cfg                                     |   1 +
 src/decentralizepy/models/Model.py            |   3 +
 src/decentralizepy/node/Node.py               |   6 +
 src/decentralizepy/sharing/FFT.py             | 333 +++++++++++++++
 src/decentralizepy/sharing/SubSampling.py     | 287 +++++++++++++
 src/decentralizepy/sharing/TopK.py            | 227 +++++++++++
 src/decentralizepy/sharing/TopKParams.py      | 225 +++++++++++
 src/decentralizepy/sharing/Wavelet.py         | 370 +++++++++++++++++
 .../training/FrequencyAccumulator.py          | 105 +++++
 .../training/FrequencyWaveletAccumulator.py   | 113 ++++++
 .../training/ModelChangeAccumulator.py        | 103 +++++
 src/decentralizepy/training/Training.py       |   2 +-
 22 files changed, 2384 insertions(+), 17 deletions(-)
 create mode 100644 eval/96_regular.edges
 create mode 100644 eval/step_configs/config_femnist_fft.ini
 create mode 100644 eval/step_configs/config_femnist_sharing.ini
 create mode 100644 eval/step_configs/config_femnist_subsampling.ini
 create mode 100644 eval/step_configs/config_femnist_topk.ini
 create mode 100644 eval/step_configs/config_femnist_topkparam.ini
 create mode 100644 eval/step_configs/config_femnist_wavelet.ini
 create mode 100644 src/decentralizepy/sharing/FFT.py
 create mode 100644 src/decentralizepy/sharing/SubSampling.py
 create mode 100644 src/decentralizepy/sharing/TopK.py
 create mode 100644 src/decentralizepy/sharing/TopKParams.py
 create mode 100644 src/decentralizepy/sharing/Wavelet.py
 create mode 100644 src/decentralizepy/training/FrequencyAccumulator.py
 create mode 100644 src/decentralizepy/training/FrequencyWaveletAccumulator.py
 create mode 100644 src/decentralizepy/training/ModelChangeAccumulator.py

diff --git a/eval/96_regular.edges b/eval/96_regular.edges
new file mode 100644
index 0000000..0db09a2
--- /dev/null
+++ b/eval/96_regular.edges
@@ -0,0 +1,381 @@
+96
+0 24
+0 1
+0 26
+0 95
+1 2
+1 0
+1 82
+1 83
+2 33
+2 90
+2 3
+2 1
+3 2
+3 4
+3 14
+3 79
+4 3
+4 12
+4 5
+4 86
+5 64
+5 42
+5 4
+5 6
+6 9
+6 5
+6 62
+6 7
+7 24
+7 8
+7 45
+7 6
+8 81
+8 17
+8 9
+8 7
+9 8
+9 10
+9 53
+9 6
+10 9
+10 11
+10 29
+10 31
+11 80
+11 10
+11 36
+11 12
+12 11
+12 4
+12 13
+12 70
+13 12
+13 53
+13 30
+13 14
+14 3
+14 15
+14 13
+14 47
+15 16
+15 26
+15 14
+16 41
+16 17
+16 15
+17 8
+17 16
+17 18
+17 83
+18 17
+18 19
+18 95
+18 63
+19 82
+19 18
+19 20
+19 22
+20 19
+20 59
+20 21
+20 22
+21 72
+21 58
+21 20
+21 22
+22 19
+22 20
+22 21
+22 23
+23 24
+23 65
+23 85
+23 22
+24 0
+24 25
+24 23
+24 7
+25 32
+25 24
+25 26
+25 38
+26 0
+26 25
+26 27
+26 15
+27 32
+27 26
+27 28
+27 63
+28 27
+28 92
+28 29
+28 39
+29 10
+29 52
+29 28
+29 30
+30 66
+30 29
+30 13
+30 31
+31 32
+31 10
+31 36
+31 30
+32 25
+32 27
+32 31
+32 33
+33 32
+33 2
+33 84
+33 34
+34 33
+34 50
+34 35
+34 93
+35 57
+35 34
+35 43
+35 36
+36 35
+36 11
+36 37
+36 31
+37 88
+37 36
+37 38
+37 79
+38 25
+38 37
+38 39
+38 49
+39 40
+39 28
+39 77
+39 38
+40 41
+40 91
+40 39
+40 87
+41 16
+41 40
+41 42
+41 51
+42 41
+42 43
+42 5
+43 42
+43 35
+43 44
+44 72
+44 43
+44 75
+44 45
+45 67
+45 44
+45 46
+45 7
+46 76
+46 45
+46 54
+46 47
+47 48
+47 65
+47 14
+47 46
+48 56
+48 49
+48 61
+48 47
+49 48
+49 50
+49 38
+49 71
+50 49
+50 34
+50 51
+50 93
+51 41
+51 50
+51 52
+51 95
+52 51
+52 74
+52 53
+52 29
+53 9
+53 52
+53 13
+53 54
+54 75
+54 53
+54 46
+54 55
+55 56
+55 69
+55 85
+55 54
+56 48
+56 57
+56 69
+56 55
+57 56
+57 89
+57 58
+57 35
+58 57
+58 59
+58 21
+58 86
+59 73
+59 58
+59 20
+59 60
+60 62
+60 59
+60 61
+60 78
+61 48
+61 62
+61 60
+61 94
+62 60
+62 61
+62 6
+62 63
+63 64
+63 18
+63 27
+63 62
+64 65
+64 84
+64 5
+64 63
+65 64
+65 66
+65 23
+65 47
+66 65
+66 89
+66 67
+66 30
+67 80
+67 66
+67 68
+67 45
+68 67
+68 92
+68 69
+68 94
+69 56
+69 68
+69 70
+69 55
+70 90
+70 12
+70 69
+70 71
+71 72
+71 49
+71 70
+71 87
+72 73
+72 44
+72 21
+72 71
+73 72
+73 91
+73 59
+73 74
+74 73
+74 75
+74 52
+74 76
+75 74
+75 44
+75 54
+75 76
+76 74
+76 75
+76 77
+76 46
+77 81
+77 76
+77 78
+77 39
+78 88
+78 60
+78 77
+78 79
+79 80
+79 3
+79 37
+79 78
+80 81
+80 67
+80 11
+80 79
+81 8
+81 82
+81 80
+81 77
+82 81
+82 1
+82 83
+82 19
+83 1
+83 82
+83 84
+83 17
+84 64
+84 33
+84 83
+84 85
+85 84
+85 55
+85 86
+85 23
+86 58
+86 4
+86 85
+86 87
+87 40
+87 88
+87 86
+87 71
+88 89
+88 37
+88 78
+88 87
+89 88
+89 57
+89 66
+89 90
+90 89
+90 2
+90 91
+90 70
+91 40
+91 73
+91 90
+91 92
+92 93
+92 91
+92 68
+92 28
+93 50
+93 34
+93 94
+93 92
+94 93
+94 68
+94 61
+94 95
+95 0
+95 18
+95 51
+95 94
diff --git a/eval/plot.py b/eval/plot.py
index d3c3a39..f354937 100644
--- a/eval/plot.py
+++ b/eval/plot.py
@@ -61,14 +61,20 @@ def plot_results(path):
         plt.figure(1)
         means, stdevs, mins, maxs = get_stats([x["train_loss"] for x in results])
         plot(means, stdevs, mins, maxs, "Training Loss", folder, "upper right")
+        with open(os.path.join(path, "train_loss_" + folder + ".json"), "w") as f:
+            json.dump({"mean": means, "std": stdevs}, f)
         # Plot Testing loss
         plt.figure(2)
         means, stdevs, mins, maxs = get_stats([x["test_loss"] for x in results])
         plot(means, stdevs, mins, maxs, "Testing Loss", folder, "upper right")
+        with open(os.path.join(path, "test_loss_" + folder + ".json"), "w") as f:
+            json.dump({"mean": means, "std": stdevs}, f)
         # Plot Testing Accuracy
         plt.figure(3)
         means, stdevs, mins, maxs = get_stats([x["test_acc"] for x in results])
         plot(means, stdevs, mins, maxs, "Testing Accuracy", folder, "lower right")
+        with open(os.path.join(path, "test_acc_" + folder + ".json"), "w") as f:
+            json.dump({"mean": means, "std": stdevs}, f)
         plt.figure(6)
         means, stdevs, mins, maxs = get_stats([x["grad_std"] for x in results])
         plot(
diff --git a/eval/run.sh b/eval/run.sh
index 9869a17..0198413 100755
--- a/eval/run.sh
+++ b/eval/run.sh
@@ -4,29 +4,21 @@ decpy_path=~/Gitlab/decentralizepy/eval
 cd $decpy_path
 
 env_python=~/miniconda3/envs/decpy/bin/python3
-graph=96_nodes_random1.edges
+graph=96_regular.edges
 original_config=epoch_configs/config_celeba.ini
 config_file=/tmp/config.ini
 procs_per_machine=16
 machines=6
-iterations=76
-test_after=2
+iterations=200
+test_after=10
 eval_file=testing.py
 log_level=INFO
+log_dir_base=/mnt/nfs/some_user/logs/test
 
 m=`cat $(grep addresses_filepath $original_config | awk '{print $3}') | grep $(/sbin/ifconfig ens785 | grep 'inet ' | awk '{print $2}') | cut -d'"' -f2`
 
-cp $original_config $config_file
-echo "alpha = 0.75" >> $config_file
-$env_python $eval_file -mid $m -ps $procs_per_machine -ms $machines -is $iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level
-
-cp $original_config $config_file
-echo "alpha = 0.50" >> $config_file
-$env_python $eval_file -mid $m -ps $procs_per_machine -ms $machines -is $iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level
+log_dir=$log_dir_base$m
 
 cp $original_config $config_file
-echo "alpha = 0.10" >> $config_file
-$env_python $eval_file -mid $m -ps $procs_per_machine -ms $machines -is $iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level
-
-config_file=epoch_configs/config_celeba_100.ini
-$env_python $eval_file -mid $m -ps $procs_per_machine -ms $machines -is $iterations -gf $graph -ta $test_after -cf $original_config -ll $log_level
+# echo "alpha = 0.10" >> $config_file
+$env_python $eval_file -ld $log_dir -mid $m -ps $procs_per_machine -ms $machines -is $iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level
\ No newline at end of file
diff --git a/eval/step_configs/config_femnist_fft.ini b/eval/step_configs/config_femnist_fft.ini
new file mode 100644
index 0000000..32c5e17
--- /dev/null
+++ b/eval/step_configs/config_femnist_fft.ini
@@ -0,0 +1,37 @@
+[DATASET]
+dataset_package = decentralizepy.datasets.Femnist
+dataset_class = Femnist
+model_class = CNN
+train_dir = /mnt/nfs/shared/leaf/data/femnist/per_user_data/train
+test_dir = /mnt/nfs/shared/leaf/data/femnist/data/test
+; python list of fractions below
+sizes = 
+
+[OPTIMIZER_PARAMS]
+optimizer_package = torch.optim
+optimizer_class = Adam
+lr = 0.001
+
+# There are 734463 femnist samples
+[TRAIN_PARAMS]
+training_package = decentralizepy.training.FrequencyAccumulator
+training_class = FrequencyAccumulator
+rounds = 47
+full_epochs = False
+batch_size = 16
+shuffle = True
+loss_package = torch.nn
+loss_class = CrossEntropyLoss
+accumulation = True
+
+[COMMUNICATION]
+comm_package = decentralizepy.communication.TCP
+comm_class = TCP
+addresses_filepath = ip_addr_6Machines.json
+
+[SHARING]
+sharing_package = decentralizepy.sharing.FFT
+sharing_class = FFT
+alpha = 0.1
+change_based_selection = True
+accumulation = True
\ No newline at end of file
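
The [TRAIN_PARAMS] and [SHARING] options above are forwarded to the training and sharing classes named in those sections. As a rough illustration only (the framework itself uses localconfig and its Node bootstrap rather than this helper; load_sharing_section and the hard-coded path are invented names), the [SHARING] section of this file could be parsed with Python's standard configparser like so:

    # Hedged sketch: parse the [SHARING] section of config_femnist_fft.ini and build
    # the keyword arguments that decentralizepy.sharing.FFT.FFT accepts.
    import importlib
    from configparser import ConfigParser

    def load_sharing_section(path="step_configs/config_femnist_fft.ini"):
        parser = ConfigParser()
        parser.read(path)
        sharing = parser["SHARING"]
        module = importlib.import_module(sharing["sharing_package"])   # decentralizepy.sharing.FFT
        cls = getattr(module, sharing["sharing_class"])                # FFT
        kwargs = {
            "alpha": sharing.getfloat("alpha"),                        # fraction of coefficients shared
            "change_based_selection": sharing.getboolean("change_based_selection"),
            "accumulation": sharing.getboolean("accumulation"),
        }
        return cls, kwargs
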
diff --git a/eval/step_configs/config_femnist_sharing.ini b/eval/step_configs/config_femnist_sharing.ini
new file mode 100644
index 0000000..42ab50c
--- /dev/null
+++ b/eval/step_configs/config_femnist_sharing.ini
@@ -0,0 +1,32 @@
+[DATASET]
+dataset_package = decentralizepy.datasets.Femnist
+dataset_class = Femnist
+model_class = CNN
+train_dir = /mnt/nfs/shared/leaf/data/femnist/per_user_data/train
+test_dir = /mnt/nfs/shared/leaf/data/femnist/data/test
+; python list of fractions below
+sizes = 
+
+[OPTIMIZER_PARAMS]
+optimizer_package = torch.optim
+optimizer_class = Adam
+lr = 0.001
+
+[TRAIN_PARAMS]
+training_package = decentralizepy.training.Training
+training_class = Training
+rounds = 10
+full_epochs = False
+batch_size = 16
+shuffle = True
+loss_package = torch.nn
+loss_class = CrossEntropyLoss
+
+[COMMUNICATION]
+comm_package = decentralizepy.communication.TCP
+comm_class = TCP
+addresses_filepath = ip_addr_6Machines.json
+
+[SHARING]
+sharing_package = decentralizepy.sharing.Sharing
+sharing_class = Sharing
diff --git a/eval/step_configs/config_femnist_subsampling.ini b/eval/step_configs/config_femnist_subsampling.ini
new file mode 100644
index 0000000..53121d8
--- /dev/null
+++ b/eval/step_configs/config_femnist_subsampling.ini
@@ -0,0 +1,34 @@
+[DATASET]
+dataset_package = decentralizepy.datasets.Femnist
+dataset_class = Femnist
+model_class = CNN
+train_dir = /mnt/nfs/shared/leaf/data/femnist/per_user_data/train
+test_dir = /mnt/nfs/shared/leaf/data/femnist/data/test
+; python list of fractions below
+sizes = 
+
+[OPTIMIZER_PARAMS]
+optimizer_package = torch.optim
+optimizer_class = Adam
+lr = 0.001
+
+# There are 734463 femnist samples
+[TRAIN_PARAMS]
+training_package = decentralizepy.training.Training
+training_class = Training
+rounds = 47
+full_epochs = False
+batch_size = 16
+shuffle = True
+loss_package = torch.nn
+loss_class = CrossEntropyLoss
+
+[COMMUNICATION]
+comm_package = decentralizepy.communication.TCP
+comm_class = TCP
+addresses_filepath = ip_addr_6Machines.json
+
+[SHARING]
+sharing_package = decentralizepy.sharing.SubSampling
+sharing_class = SubSampling
+alpha = 0.1
diff --git a/eval/step_configs/config_femnist_topk.ini b/eval/step_configs/config_femnist_topk.ini
new file mode 100644
index 0000000..57ba8f0
--- /dev/null
+++ b/eval/step_configs/config_femnist_topk.ini
@@ -0,0 +1,36 @@
+[DATASET]
+dataset_package = decentralizepy.datasets.Femnist
+dataset_class = Femnist
+model_class = CNN
+train_dir = /mnt/nfs/shared/leaf/data/femnist/per_user_data/train
+test_dir = /mnt/nfs/shared/leaf/data/femnist/data/test
+; python list of fractions below
+sizes =
+
+[OPTIMIZER_PARAMS]
+optimizer_package = torch.optim
+optimizer_class = Adam
+lr = 0.001
+
+# There are 734463 femnist samples
+[TRAIN_PARAMS]
+training_package = decentralizepy.training.ModelChangeAccumulator
+training_class = ModelChangeAccumulator
+rounds = 47
+full_epochs = False
+batch_size = 16
+shuffle = True
+loss_package = torch.nn
+loss_class = CrossEntropyLoss
+accumulation = True
+
+[COMMUNICATION]
+comm_package = decentralizepy.communication.TCP
+comm_class = TCP
+addresses_filepath = ip_addr_6Machines.json
+
+[SHARING]
+sharing_package = decentralizepy.sharing.TopK
+sharing_class = TopK
+alpha = 0.1
+accumulation = True
\ No newline at end of file
diff --git a/eval/step_configs/config_femnist_topkparam.ini b/eval/step_configs/config_femnist_topkparam.ini
new file mode 100644
index 0000000..41c50c0
--- /dev/null
+++ b/eval/step_configs/config_femnist_topkparam.ini
@@ -0,0 +1,34 @@
+[DATASET]
+dataset_package = decentralizepy.datasets.Femnist
+dataset_class = Femnist
+model_class = CNN
+train_dir = /mnt/nfs/shared/leaf/data/femnist/per_user_data/train
+test_dir = /mnt/nfs/shared/leaf/data/femnist/data/test
+; python list of fractions below
+sizes = 
+
+[OPTIMIZER_PARAMS]
+optimizer_package = torch.optim
+optimizer_class = Adam
+lr = 0.001
+
+# There are 734463 femnist samples
+[TRAIN_PARAMS]
+training_package = decentralizepy.training.Training
+training_class = Training
+rounds = 47
+full_epochs = False
+batch_size = 16
+shuffle = True
+loss_package = torch.nn
+loss_class = CrossEntropyLoss
+
+[COMMUNICATION]
+comm_package = decentralizepy.communication.TCP
+comm_class = TCP
+addresses_filepath = ip_addr_6Machines.json
+
+[SHARING]
+sharing_package = decentralizepy.sharing.TopKParams
+sharing_class = TopKParams
+alpha = 0.1
diff --git a/eval/step_configs/config_femnist_wavelet.ini b/eval/step_configs/config_femnist_wavelet.ini
new file mode 100644
index 0000000..e53e3ea
--- /dev/null
+++ b/eval/step_configs/config_femnist_wavelet.ini
@@ -0,0 +1,41 @@
+[DATASET]
+dataset_package = decentralizepy.datasets.Femnist
+dataset_class = Femnist
+model_class = CNN
+train_dir = /mnt/nfs/shared/leaf/data/femnist/per_user_data/train
+test_dir = /mnt/nfs/shared/leaf/data/femnist/data/test
+; python list of fractions below
+sizes = 
+
+[OPTIMIZER_PARAMS]
+optimizer_package = torch.optim
+optimizer_class = Adam
+lr = 0.001
+
+# There are 734463 femnist samples
+[TRAIN_PARAMS]
+training_package = decentralizepy.training.FrequencyWaveletAccumulator
+training_class = FrequencyWaveletAccumulator
+rounds = 47
+full_epochs = False
+batch_size = 16
+shuffle = True
+loss_package = torch.nn
+loss_class = CrossEntropyLoss
+wavelet = sym2
+level = None
+accumulation = True
+
+[COMMUNICATION]
+comm_package = decentralizepy.communication.TCP
+comm_class = TCP
+addresses_filepath = ip_addr_6Machines.json
+
+[SHARING]
+sharing_package = decentralizepy.sharing.Wavelet
+sharing_class = Wavelet
+change_based_selection = True
+alpha = 0.1
+wavelet = sym2
+level = None
+accumulation = True
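
The wavelet and level values above are handed to PyWavelets inside Wavelet.py; sym2 with level = None lets PyWavelets choose the maximum useful decomposition level. Below is a minimal, self-contained sketch of the transform round trip the sharing class relies on (the random params vector is a stand-in for the flattened model):

    # Hedged sketch of the wavelet round trip: wavedec -> flat array -> waverec.
    import numpy as np
    import pywt

    params = np.random.randn(1024).astype(np.float32)   # stand-in for the flattened model

    coeffs = pywt.wavedec(params, "sym2", level=None)    # list of approximation/detail arrays
    flat, slices = pywt.coeffs_to_array(coeffs)          # one 1-D array + bookkeeping slices
    flat = flat.ravel()

    # ... top-k selection and exchange of entries of `flat` happens here in Wavelet.py ...

    coeffs_back = pywt.array_to_coeffs(flat, slices, output_format="wavedec")
    reconstructed = pywt.waverec(coeffs_back, "sym2")[: len(params)]
    assert np.allclose(params, reconstructed, atol=1e-4)
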
diff --git a/eval/testing.py b/eval/testing.py
index abd6333..0ae70de 100644
--- a/eval/testing.py
+++ b/eval/testing.py
@@ -24,7 +24,8 @@ def read_ini(file_path):
 if __name__ == "__main__":
     args = utils.get_args()
 
-    Path(args.log_dir).mkdir(parents=True, exist_ok=True)
+    # prevents accidental log overwrites
+    Path(args.log_dir).mkdir(parents=True, exist_ok=False)
 
     log_level = {
         "INFO": logging.INFO,
diff --git a/setup.cfg b/setup.cfg
index 3faa1f3..2ffd572 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -42,6 +42,7 @@ install_requires =
         pillow
         smallworld
         localconfig
+        PyWavelets
 include_package_data = True
 python_requires = >=3.6
 [options.packages.find]
diff --git a/src/decentralizepy/models/Model.py b/src/decentralizepy/models/Model.py
index f757500..e9e556b 100644
--- a/src/decentralizepy/models/Model.py
+++ b/src/decentralizepy/models/Model.py
@@ -17,6 +17,9 @@ class Model(nn.Module):
         self.accumulated_gradients = []
         self._param_count_ot = None
         self._param_count_total = None
+        self.accumulated_frequency = None
+        self.prev_model_params = None
+        self.prev = None
 
     def count_params(self, only_trainable=False):
         """
diff --git a/src/decentralizepy/node/Node.py b/src/decentralizepy/node/Node.py
index e5764ae..fd2c75f 100644
--- a/src/decentralizepy/node/Node.py
+++ b/src/decentralizepy/node/Node.py
@@ -92,6 +92,8 @@ class Node:
             The object containing the mapping rank <--> uid
         graph : decentralizepy.graphs
             The object containing the global graph
+        iterations : int
+            Number of iterations (communication steps) for which the model should be trained
         log_dir : str
             Logging directory
         reset_optimizer : int
@@ -278,6 +280,8 @@ class Node:
             The object containing the global graph
         config : dict
             A dictionary of configurations.
+        iterations : int
+            Number of iterations (communication steps) for which the model should be trained
         log_dir : str
             Logging directory
         log_level : logging.Level
@@ -443,6 +447,8 @@ class Node:
                 training_class = Training
                 epochs_per_round = 25
                 batch_size = 64
+        iterations : int
+            Number of iterations (communication steps) for which the model should be trained
         log_dir : str
             Logging directory
         log_level : logging.Level
diff --git a/src/decentralizepy/sharing/FFT.py b/src/decentralizepy/sharing/FFT.py
new file mode 100644
index 0000000..4a3ee36
--- /dev/null
+++ b/src/decentralizepy/sharing/FFT.py
@@ -0,0 +1,333 @@
+import base64
+import json
+import logging
+import os
+import pickle
+from pathlib import Path
+from time import time
+
+import torch
+import torch.fft as fft
+
+from decentralizepy.sharing.Sharing import Sharing
+
+
+class FFT(Sharing):
+    """
+    This class implements the FFT version of model sharing.
+    It is based on PartialModel.py.
+
+    """
+
+    def __init__(
+        self,
+        rank,
+        machine_id,
+        communication,
+        mapping,
+        graph,
+        model,
+        dataset,
+        log_dir,
+        alpha=1.0,
+        dict_ordered=True,
+        save_shared=False,
+        metadata_cap=1.0,
+        pickle=True,
+        change_based_selection=True,
+        accumulation=True,
+    ):
+        """
+        Constructor
+
+        Parameters
+        ----------
+        rank : int
+            Local rank
+        machine_id : int
+            Global machine id
+        communication : decentralizepy.communication.Communication
+            Communication module used to send and receive messages
+        mapping : decentralizepy.mappings.Mapping
+            Mapping (rank, machine_id) -> uid
+        graph : decentralizepy.graphs.Graph
+            Graph representing neighbors
+        model : decentralizepy.models.Model
+            Model to train
+        dataset : decentralizepy.datasets.Dataset
+            Dataset for sharing data. Not implemented yet! TODO
+        log_dir : str
+            Location to write shared_params (only writing for 2 procs per machine)
+        alpha : float
+            Percentage of model to share
+        dict_ordered : bool
+            Specifies if the python dict maintains the order of insertion
+        save_shared : bool
+            Specifies if the indices of shared parameters should be logged
+        metadata_cap : float
+            Share full model when self.alpha > metadata_cap
+        pickle : bool
+            use pickle to serialize the model parameters
+        change_based_selection : bool
+            use frequency change to select topk frequencies
+        accumulation : bool
+            True if the indices to share should be selected based on accumulated frequency change
+        """
+        super().__init__(
+            rank, machine_id, communication, mapping, graph, model, dataset, log_dir
+        )
+        self.alpha = alpha
+        self.dict_ordered = dict_ordered
+        self.save_shared = save_shared
+        self.metadata_cap = metadata_cap
+        self.total_meta = 0
+
+        self.pickle = pickle
+
+        logging.info("subsampling pickling=" + str(pickle))
+
+        if self.save_shared:
+            # Only save for 2 procs: Save space
+            if rank != 0 and rank != 1:
+                self.save_shared = False
+
+        if self.save_shared:
+            self.folder_path = os.path.join(
+                self.log_dir, "shared_params/{}".format(self.rank)
+            )
+            Path(self.folder_path).mkdir(parents=True, exist_ok=True)
+
+        self.change_based_selection = change_based_selection
+        self.accumulation = accumulation
+
+    def apply_fft(self):
+        """
+        Applies an FFT to the model parameters and selects the topK (alpha fraction) of them in the frequency domain,
+        based on the change undergone during the current training step
+
+        Returns
+        -------
+        tuple
+            (a,b). a: selected fft frequencies (complex numbers), b: Their indices.
+
+        """
+
+        logging.info("Returning fft compressed model weights")
+        tensors_to_cat = [v.data.flatten() for _, v in self.model.state_dict().items()]
+        concated = torch.cat(tensors_to_cat, dim=0)
+
+        if self.change_based_selection:
+            flat_fft = fft.rfft(concated)
+            if self.accumulation:
+                logging.info(
+                    "fft topk extract frequencies based on accumulated model frequency change"
+                )
+                diff = self.model.accumulated_frequency + (flat_fft - self.model.prev)
+            else:
+                diff = flat_fft - self.model.accumulated_frequency
+            _, index = torch.topk(
+                diff.abs(), round(self.alpha * len(flat_fft)), dim=0, sorted=False
+            )
+        else:
+            flat_fft = fft.rfft(concated)
+            _, index = torch.topk(
+                flat_fft.abs(), round(self.alpha * len(flat_fft)), dim=0, sorted=False
+            )
+
+        if self.accumulation:
+            self.model.accumulated_frequency[index] = 0.0
+        return flat_fft[index], index
+
+    def serialized_model(self):
+        """
+        Convert model to json dict. self.alpha specifies the fraction of model to send.
+
+        Returns
+        -------
+        dict
+            Model converted to json dict
+
+        """
+        if self.alpha > self.metadata_cap:  # Share fully
+            return super().serialized_model()
+
+        with torch.no_grad():
+            topk, indices = self.apply_fft()
+
+            if self.save_shared:
+                shared_params = dict()
+                shared_params["order"] = list(self.model.state_dict().keys())
+                shapes = dict()
+                for k, v in self.model.state_dict().items():
+                    shapes[k] = list(v.shape)
+                shared_params["shapes"] = shapes
+
+                shared_params[self.communication_round] = indices.tolist()  # is slow
+
+                shared_params["alpha"] = self.alpha
+
+                with open(
+                    os.path.join(
+                        self.folder_path,
+                        "{}_shared_params.json".format(self.communication_round + 1),
+                    ),
+                    "w",
+                ) as of:
+                    json.dump(shared_params, of)
+
+            m = dict()
+
+            if not self.dict_ordered:
+                raise NotImplementedError
+
+            m["alpha"] = self.alpha
+            m["params"] = topk.numpy()
+            m["indices"] = indices.numpy()
+
+            self.total_data += len(self.communication.encrypt(m["params"]))
+            self.total_meta += len(self.communication.encrypt(m["indices"])) + len(
+                self.communication.encrypt(m["alpha"])
+            )
+
+            return m
+
+    def deserialized_model(self, m):
+        """
+        Convert received json dict to state_dict.
+
+        Parameters
+        ----------
+        m : dict
+            json dict received
+
+        Returns
+        -------
+        dict
+            dict containing the received topK frequency components and their indices
+
+        """
+        if self.alpha > self.metadata_cap:  # Share fully
+            return super().deserialized_model(m)
+
+        with torch.no_grad():
+            state_dict = self.model.state_dict()
+
+            if not self.dict_ordered:
+                raise NotImplementedError
+
+            shapes = []
+            lens = []
+            tensors_to_cat = []
+            for _, v in state_dict.items():
+                shapes.append(v.shape)
+                t = v.flatten()
+                lens.append(t.shape[0])
+                tensors_to_cat.append(t)
+
+            T = torch.cat(tensors_to_cat, dim=0)
+
+            indices = m["indices"]
+            alpha = m["alpha"]
+            params = m["params"]
+
+            params_tensor = torch.tensor(params)
+            indices_tensor = torch.tensor(indices)
+            ret = dict()
+            ret["indices"] = indices_tensor
+            ret["params"] = params_tensor
+            return ret
+
+    def step(self):
+        """
+        Perform a sharing step. Implements D-PSGD.
+
+        """
+        t_start = time()
+        data = self.serialized_model()
+        t_post_serialize = time()
+        my_uid = self.mapping.get_uid(self.rank, self.machine_id)
+        all_neighbors = self.graph.neighbors(my_uid)
+        iter_neighbors = self.get_neighbors(all_neighbors)
+        data["degree"] = len(all_neighbors)
+        data["iteration"] = self.communication_round
+        for neighbor in iter_neighbors:
+            self.communication.send(neighbor, data)
+        t_post_send = time()
+        logging.info("Waiting for messages from neighbors")
+        while not self.received_from_all():
+            sender, data = self.communication.receive()
+            logging.debug("Received model from {}".format(sender))
+            degree = data["degree"]
+            iteration = data["iteration"]
+            del data["degree"]
+            del data["iteration"]
+            self.peer_deques[sender].append((degree, iteration, data))
+            logging.info(
+                "Deserialized received model from {} of iteration {}".format(
+                    sender, iteration
+                )
+            )
+        t_post_recv = time()
+
+        logging.info("Starting model averaging after receiving from all neighbors")
+        total = None
+        weight_total = 0
+
+        # FFT of this model
+        shapes = []
+        lens = []
+        tensors_to_cat = []
+        for _, v in self.model.state_dict().items():
+            shapes.append(v.shape)
+            t = v.flatten()
+            lens.append(t.shape[0])
+            tensors_to_cat.append(t)
+        concated = torch.cat(tensors_to_cat, dim=0)
+        flat_fft = fft.rfft(concated)
+
+        for i, n in enumerate(self.peer_deques):
+            degree, iteration, data = self.peer_deques[n].popleft()
+            logging.debug(
+                "Averaging model from neighbor {} of iteration {}".format(n, iteration)
+            )
+            data = self.deserialized_model(data)
+            params = data["params"]
+            indices = data["indices"]
+            # use local data to complement
+            topkf = flat_fft.clone().detach()
+            topkf[indices] = params
+
+            weight = 1 / (max(len(self.peer_deques), degree) + 1)  # Metropolis-Hastings
+            weight_total += weight
+            if total is None:
+                total = weight * topkf
+            else:
+                total += weight * topkf
+
+        # Metropolis-Hastings
+        total += (1 - weight_total) * flat_fft
+        reverse_total = fft.irfft(total)
+
+        start_index = 0
+        std_dict = {}
+        for i, key in enumerate(self.model.state_dict()):
+            end_index = start_index + lens[i]
+            std_dict[key] = reverse_total[start_index:end_index].reshape(shapes[i])
+            start_index = end_index
+
+        self.model.load_state_dict(std_dict)
+
+        logging.info("Model averaging complete")
+
+        self.communication_round += 1
+
+        t_end = time()
+
+        logging.info(
+            "Sharing::step | Serialize: %f; Send: %f; Recv: %f; Averaging: %f; Total: %f",
+            t_post_serialize - t_start,
+            t_post_send - t_post_serialize,
+            t_post_recv - t_post_send,
+            t_end - t_post_recv,
+            t_end - t_start,
+        )
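
A standalone, hedged illustration of the two numerical pieces that FFT.step() combines: top-k selection of rFFT coefficients by magnitude, and Metropolis-Hastings style weighting of a neighbor's contribution. The tensors and degrees below are made up for the example; only the torch calls mirror the code above.

    import torch
    import torch.fft as fft

    alpha = 0.1
    params = torch.randn(1000)                      # stand-in for the flattened model
    flat_fft = fft.rfft(params)

    # (1) keep the alpha fraction of coefficients with the largest magnitude
    k = round(alpha * len(flat_fft))
    _, idx = torch.topk(flat_fft.abs(), k, dim=0, sorted=False)

    # a neighbor's contribution: start from our own spectrum, overwrite the shared entries
    neighbor_fft = flat_fft.clone()
    neighbor_fft[idx] = flat_fft[idx] * 0.9          # pretend the neighbor sent these values

    # (2) Metropolis-Hastings weight: 1 / (max(own degree, neighbor degree) + 1)
    own_degree, neighbor_degree = 4, 4
    w = 1.0 / (max(own_degree, neighbor_degree) + 1)
    total = w * neighbor_fft + (1 - w) * flat_fft    # leftover mass stays on the local model

    averaged_params = fft.irfft(total, n=len(params))
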
diff --git a/src/decentralizepy/sharing/SubSampling.py b/src/decentralizepy/sharing/SubSampling.py
new file mode 100644
index 0000000..6fe3f93
--- /dev/null
+++ b/src/decentralizepy/sharing/SubSampling.py
@@ -0,0 +1,287 @@
+import base64
+import json
+import logging
+import os
+import pickle
+from pathlib import Path
+
+import torch
+
+from decentralizepy.sharing.Sharing import Sharing
+
+
+class SubSampling(Sharing):
+    """
+    This class implements the subsampling version of model sharing.
+    It is based on PartialModel.py.
+
+    """
+
+    def __init__(
+        self,
+        rank,
+        machine_id,
+        communication,
+        mapping,
+        graph,
+        model,
+        dataset,
+        log_dir,
+        alpha=1.0,
+        dict_ordered=True,
+        save_shared=False,
+        metadata_cap=1.0,
+        pickle=True,
+        layerwise=False,
+    ):
+        """
+        Constructor
+
+        Parameters
+        ----------
+        rank : int
+            Local rank
+        machine_id : int
+            Global machine id
+        communication : decentralizepy.communication.Communication
+            Communication module used to send and receive messages
+        mapping : decentralizepy.mappings.Mapping
+            Mapping (rank, machine_id) -> uid
+        graph : decentralizepy.graphs.Graph
+            Graph representing neighbors
+        model : decentralizepy.models.Model
+            Model to train
+        dataset : decentralizepy.datasets.Dataset
+            Dataset for sharing data. Not implemented yet! TODO
+        log_dir : str
+            Location to write shared_params (only writing for 2 procs per machine)
+        alpha : float
+            Percentage of model to share
+        dict_ordered : bool
+            Specifies if the python dict maintains the order of insertion
+        save_shared : bool
+            Specifies if the indices of shared parameters should be logged
+        metadata_cap : float
+            Share full model when self.alpha > metadata_cap
+        pickle : bool
+            use pickle to serialize the model parameters
+
+        """
+        super().__init__(
+            rank, machine_id, communication, mapping, graph, model, dataset, log_dir
+        )
+        self.alpha = alpha
+        self.dict_ordered = dict_ordered
+        self.save_shared = save_shared
+        self.metadata_cap = metadata_cap
+        self.total_meta = 0
+
+        # self.random_seed_generator = torch.Generator()
+        # # Will use the random device if supported by CPU, else uses the system time
+        # # In the latter case we could get duplicate seeds on some of the machines
+        # self.random_seed_generator.seed()
+
+        self.random_generator = torch.Generator()
+        # Will use the random device if supported by CPU, else uses the system time
+        # In the latter case we could get duplicate seeds on some of the machines
+        self.random_generator.seed()
+        self.seed = self.random_generator.initial_seed()
+
+        self.pickle = pickle
+        self.layerwise = layerwise
+
+        logging.info("subsampling pickling=" + str(pickle))
+
+        if self.save_shared:
+            # Only save for 2 procs: Save space
+            if rank != 0 and rank != 1:
+                self.save_shared = False
+
+        if self.save_shared:
+            self.folder_path = os.path.join(
+                self.log_dir, "shared_params/{}".format(self.rank)
+            )
+            Path(self.folder_path).mkdir(parents=True, exist_ok=True)
+
+    def apply_subsampling(self):
+        """
+        Creates a random binary mask that is used to subsample the parameters that will be shared
+
+        Returns
+        -------
+        tuple
+            (a,b,c). a: the selected parameters as a flat vector, b: the random seed used to create the binary mask,
+                     c: the alpha
+
+        """
+
+        logging.info("Returning subsampling gradients")
+        if not self.layerwise:
+            tensors_to_cat = [
+                v.data.flatten() for _, v in self.model.state_dict().items()
+            ]
+            concated = torch.cat(tensors_to_cat, dim=0)
+
+            curr_seed = self.seed + self.communication_round  # is increased in step
+            self.random_generator.manual_seed(curr_seed)
+            # logging.debug("Subsampling seed for uid = " + str(self.uid) + " is: " + str(curr_seed))
+            # Or we could use torch.bernoulli
+            binary_mask = (
+                torch.rand(
+                    size=(concated.size(dim=0),), generator=self.random_generator
+                )
+                <= self.alpha
+            )
+            subsample = concated[binary_mask]
+            # logging.debug("Subsampling vector is of size: " + str(subsample.size(dim = 0)))
+            return (subsample, curr_seed, self.alpha)
+        else:
+            values_list = []
+            offsets = [0]
+            off = 0
+            curr_seed = self.seed + self.communication_round  # is increased in step
+            self.random_generator.manual_seed(curr_seed)
+            for _, v in self.model.state_dict().items():
+                flat = v.flatten()
+                binary_mask = (
+                    torch.rand(
+                        size=(flat.size(dim=0),), generator=self.random_generator
+                    )
+                    <= self.alpha
+                )
+                selected = flat[binary_mask]
+                values_list.append(selected)
+                off += selected.size(dim=0)
+                offsets.append(off)
+            subsample = torch.cat(values_list, dim=0)
+            return (subsample, curr_seed, self.alpha)
+
+    def serialized_model(self):
+        """
+        Convert model to json dict. self.alpha specifies the fraction of model to send.
+
+        Returns
+        -------
+        dict
+            Model converted to json dict
+
+        """
+        if self.alpha > self.metadata_cap:  # Share fully
+            return super().serialized_model()
+
+        with torch.no_grad():
+            subsample, seed, alpha = self.apply_subsampling()
+
+            if self.save_shared:
+                shared_params = dict()
+                shared_params["order"] = list(self.model.state_dict().keys())
+                shapes = dict()
+                for k, v in self.model.state_dict().items():
+                    shapes[k] = list(v.shape)
+                shared_params["shapes"] = shapes
+
+                # TODO: should store the shared indices and not the value
+                # shared_params[self.communication_round] = subsample.tolist() # is slow
+
+                shared_params["seed"] = seed
+
+                shared_params["alpha"] = alpha
+
+                with open(
+                    os.path.join(
+                        self.folder_path,
+                        "{}_shared_params.json".format(self.communication_round + 1),
+                    ),
+                    "w",
+                ) as of:
+                    json.dump(shared_params, of)
+
+            m = dict()
+
+            if not self.dict_ordered:
+                raise NotImplementedError
+
+            m["seed"] = seed
+            m["alpha"] = alpha
+            m["params"] = subsample.numpy()
+
+            # logging.info("Converted dictionary to json")
+            self.total_data += len(self.communication.encrypt(m["params"]))
+            self.total_meta += len(self.communication.encrypt(m["seed"])) + len(
+                self.communication.encrypt(m["alpha"])
+            )
+
+            return m
+
+    def deserialized_model(self, m):
+        """
+        Convert received json dict to state_dict.
+
+        Parameters
+        ----------
+        m : dict
+            json dict received
+
+        Returns
+        -------
+        state_dict
+            state_dict of received
+
+        """
+        if self.alpha > self.metadata_cap:  # Share fully
+            return super().deserialized_model(m)
+
+        with torch.no_grad():
+            state_dict = self.model.state_dict()
+
+            if not self.dict_ordered:
+                raise NotImplementedError
+
+            seed = m["seed"]
+            alpha = m["alpha"]
+            params = m["params"]
+
+            random_generator = (
+                torch.Generator()
+            )  # new generator, such that we do not overwrite the other one
+            random_generator.manual_seed(seed)
+
+            shapes = []
+            lens = []
+            tensors_to_cat = []
+            binary_submasks = []
+            for _, v in state_dict.items():
+                shapes.append(v.shape)
+                t = v.flatten()
+                lens.append(t.shape[0])
+                tensors_to_cat.append(t)
+                if self.layerwise:
+                    binary_mask = (
+                        torch.rand(size=(t.size(dim=0),), generator=random_generator)
+                        <= alpha
+                    )
+                    binary_submasks.append(binary_mask)
+
+            T = torch.cat(tensors_to_cat, dim=0)
+
+            params_tensor = torch.from_numpy(params)
+
+            if not self.layerwise:
+                binary_mask = (
+                    torch.rand(size=(T.size(dim=0),), generator=random_generator)
+                    <= alpha
+                )
+            else:
+                binary_mask = torch.cat(binary_submasks, dim=0)
+
+            logging.debug("Original tensor: {}".format(T[binary_mask]))
+            T[binary_mask] = params_tensor
+            logging.debug("Final tensor: {}".format(T[binary_mask]))
+
+            start_index = 0
+            for i, key in enumerate(state_dict):
+                end_index = start_index + lens[i]
+                state_dict[key] = T[start_index:end_index].reshape(shapes[i])
+                start_index = end_index
+
+            return state_dict
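
What makes SubSampling cheap is that only the selected values and the seed travel over the network: the receiver regenerates an identical Bernoulli(alpha) mask from the seed, so no index list has to be sent. A minimal sketch of that round trip (tensor size and seed are arbitrary):

    import torch

    alpha = 0.1
    params = torch.randn(10_000)                     # sender's flattened model

    # --- sender ---
    seed = 1234                                      # in SubSampling.py: initial_seed + communication_round
    gen = torch.Generator().manual_seed(seed)
    mask = torch.rand(params.shape[0], generator=gen) <= alpha
    payload = {"seed": seed, "alpha": alpha, "params": params[mask]}

    # --- receiver ---
    gen_rx = torch.Generator().manual_seed(payload["seed"])
    mask_rx = torch.rand(params.shape[0], generator=gen_rx) <= payload["alpha"]
    received = torch.zeros_like(params)              # in practice: the receiver's own flattened model
    received[mask_rx] = payload["params"]

    assert torch.equal(mask, mask_rx)                # same seed -> same mask -> indices line up
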
diff --git a/src/decentralizepy/sharing/TopK.py b/src/decentralizepy/sharing/TopK.py
new file mode 100644
index 0000000..47b4151
--- /dev/null
+++ b/src/decentralizepy/sharing/TopK.py
@@ -0,0 +1,227 @@
+import json
+import logging
+import os
+from pathlib import Path
+
+import torch
+
+from decentralizepy.sharing.Sharing import Sharing
+
+
+class TopK(Sharing):
+    """
+    This class implements topk selection of model parameters based on the model change since the beginning of the
+    communication step (use together with the ModelChangeAccumulator training class).
+
+    """
+
+    def __init__(
+        self,
+        rank,
+        machine_id,
+        communication,
+        mapping,
+        graph,
+        model,
+        dataset,
+        log_dir,
+        alpha=1.0,
+        dict_ordered=True,
+        save_shared=False,
+        metadata_cap=1.0,
+        accumulation=False,
+    ):
+        """
+        Constructor
+
+        Parameters
+        ----------
+        rank : int
+            Local rank
+        machine_id : int
+            Global machine id
+        communication : decentralizepy.communication.Communication
+            Communication module used to send and receive messages
+        mapping : decentralizepy.mappings.Mapping
+            Mapping (rank, machine_id) -> uid
+        graph : decentralizepy.graphs.Graph
+            Graph representing neighbors
+        model : decentralizepy.models.Model
+            Model to train
+        dataset : decentralizepy.datasets.Dataset
+            Dataset for sharing data. Not implemented yet! TODO
+        log_dir : str
+            Location to write shared_params (only writing for 2 procs per machine)
+        alpha : float
+            Percentage of model to share
+        dict_ordered : bool
+            Specifies if the python dict maintains the order of insertion
+        save_shared : bool
+            Specifies if the indices of shared parameters should be logged
+        metadata_cap : float
+            Share full model when self.alpha > metadata_cap
+
+        """
+        super().__init__(
+            rank, machine_id, communication, mapping, graph, model, dataset, log_dir
+        )
+        self.alpha = alpha
+        self.dict_ordered = dict_ordered
+        self.save_shared = save_shared
+        self.metadata_cap = metadata_cap
+        self.total_meta = 0
+        self.accumulation = accumulation
+
+        if self.save_shared:
+            # Only save for 2 procs: Save space
+            if rank != 0 and rank != 1:
+                self.save_shared = False
+
+        if self.save_shared:
+            self.folder_path = os.path.join(
+                self.log_dir, "shared_params/{}".format(self.rank)
+            )
+            Path(self.folder_path).mkdir(parents=True, exist_ok=True)
+
+    def extract_top_gradients(self):
+        """
+        Extract the indices and values of the topK gradients.
+        The gradients must have been accumulated.
+
+        Returns
+        -------
+        tuple
+            (a,b). a: The magnitudes of the topK gradients, b: Their indices.
+
+        """
+        tensors_to_cat = [v.data.flatten() for _, v in self.model.state_dict().items()]
+        concated = torch.cat(tensors_to_cat, dim=0)
+        if self.accumulation:
+            logging.info(
+                "TopK extract gradients based on accumulated model parameter change"
+            )
+            diff = self.model.prev_model_params + (concated - self.model.prev)
+        else:
+            diff = concated - self.model.prev_model_params
+        G_topk = torch.abs(diff)
+
+        std, mean = torch.std_mean(G_topk, unbiased=False)
+        self.std = std.item()
+        self.mean = mean.item()
+        value, ind = torch.topk(
+            G_topk, round(self.alpha * G_topk.shape[0]), dim=0, sorted=False
+        )
+
+        # only needed when ModelChangeAccumulator.accumulation = True
+        # does not cause problems otherwise
+        if self.accumulation:
+            self.model.prev_model_params[ind] = 0.0  # torch.zeros((len(G_topk),))
+        return value, ind
+
+    def serialized_model(self):
+        """
+        Convert model to a dict. self.alpha specifies the fraction of model to send.
+
+        Returns
+        -------
+        dict
+            Model converted to a dict
+
+        """
+        if self.alpha > self.metadata_cap:  # Share fully
+            return super().serialized_model()
+
+        with torch.no_grad():
+            _, G_topk = self.extract_top_gradients()
+
+            if self.save_shared:
+                shared_params = dict()
+                shared_params["order"] = list(self.model.state_dict().keys())
+                shapes = dict()
+                for k, v in self.model.state_dict().items():
+                    shapes[k] = list(v.shape)
+                shared_params["shapes"] = shapes
+
+                shared_params[self.communication_round] = G_topk.tolist()
+
+                with open(
+                    os.path.join(
+                        self.folder_path,
+                        "{}_shared_params.json".format(self.communication_round + 1),
+                    ),
+                    "w",
+                ) as of:
+                    json.dump(shared_params, of)
+
+            logging.info("Extracting topk params")
+
+            tensors_to_cat = [v.data.flatten() for v in self.model.parameters()]
+            T = torch.cat(tensors_to_cat, dim=0)
+            T_topk = T[G_topk]
+
+            logging.info("Generating dictionary to send")
+
+            m = dict()
+
+            if not self.dict_ordered:
+                raise NotImplementedError
+
+            m["indices"] = G_topk.numpy()
+            m["params"] = T_topk.numpy()
+
+            assert len(m["indices"]) == len(m["params"])
+            logging.info("Elements sending: {}".format(len(m["indices"])))
+
+            logging.info("Generated dictionary to send")
+
+            logging.info("Converted dictionary to pickle")
+            self.total_data += len(self.communication.encrypt(m["params"]))
+            self.total_meta += len(self.communication.encrypt(m["indices"]))
+
+            return m
+
+    def deserialized_model(self, m):
+        """
+        Convert received dict to state_dict.
+
+        Parameters
+        ----------
+        m : dict
+            dict received
+
+        Returns
+        -------
+        state_dict
+            state_dict of received
+
+        """
+        if self.alpha > self.metadata_cap:  # Share fully
+            return super().deserialized_model(m)
+
+        with torch.no_grad():
+            state_dict = self.model.state_dict()
+
+            if not self.dict_ordered:
+                raise NotImplementedError
+
+            shapes = []
+            lens = []
+            tensors_to_cat = []
+            for _, v in state_dict.items():
+                shapes.append(v.shape)
+                t = v.flatten()
+                lens.append(t.shape[0])
+                tensors_to_cat.append(t)
+
+            T = torch.cat(tensors_to_cat, dim=0)
+            index_tensor = torch.tensor(m["indices"])
+            logging.debug("Original tensor: {}".format(T[index_tensor]))
+            T[index_tensor] = torch.tensor(m["params"])
+            logging.debug("Final tensor: {}".format(T[index_tensor]))
+            start_index = 0
+            for i, key in enumerate(state_dict):
+                end_index = start_index + lens[i]
+                state_dict[key] = T[start_index:end_index].reshape(shapes[i])
+                start_index = end_index
+
+            return state_dict
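
TopK's selection rule in isolation: keep the alpha fraction of parameters whose values moved the most since the previous snapshot, and scatter them back at the same indices on the receiver's side. A hedged, standalone sketch with made-up tensors:

    import torch

    alpha = 0.1
    prev = torch.randn(1000)                          # model.prev_model_params (previous snapshot)
    curr = prev + 0.01 * torch.randn(1000)            # parameters after local training

    change = (curr - prev).abs()
    k = round(alpha * change.shape[0])
    _, indices = torch.topk(change, k, dim=0, sorted=False)
    values = curr[indices]                            # what actually gets sent

    # receiver side: overwrite only the shared positions of its own flat model
    receiver_flat = torch.randn(1000)
    receiver_flat[indices] = values
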
diff --git a/src/decentralizepy/sharing/TopKParams.py b/src/decentralizepy/sharing/TopKParams.py
new file mode 100644
index 0000000..3beb10f
--- /dev/null
+++ b/src/decentralizepy/sharing/TopKParams.py
@@ -0,0 +1,225 @@
+import json
+import logging
+import os
+from pathlib import Path
+
+import torch
+
+from decentralizepy.sharing.Sharing import Sharing
+
+
+class TopKParams(Sharing):
+    """
+    This class implements a layer-wise topK version of partial model sharing.
+
+    """
+
+    def __init__(
+        self,
+        rank,
+        machine_id,
+        communication,
+        mapping,
+        graph,
+        model,
+        dataset,
+        log_dir,
+        alpha=1.0,
+        dict_ordered=True,
+        save_shared=False,
+        metadata_cap=1.0,
+    ):
+        """
+        Constructor
+
+        Parameters
+        ----------
+        rank : int
+            Local rank
+        machine_id : int
+            Global machine id
+        communication : decentralizepy.communication.Communication
+            Communication module used to send and receive messages
+        mapping : decentralizepy.mappings.Mapping
+            Mapping (rank, machine_id) -> uid
+        graph : decentralizepy.graphs.Graph
+            Graph representing neighbors
+        model : decentralizepy.models.Model
+            Model to train
+        dataset : decentralizepy.datasets.Dataset
+            Dataset for sharing data. Not implemented yet! TODO
+        log_dir : str
+            Location to write shared_params (only writing for 2 procs per machine)
+        alpha : float
+            Percentage of model to share
+        dict_ordered : bool
+            Specifies if the python dict maintains the order of insertion
+        save_shared : bool
+            Specifies if the indices of shared parameters should be logged
+        metadata_cap : float
+            Share full model when self.alpha > metadata_cap
+
+        """
+        super().__init__(
+            rank, machine_id, communication, mapping, graph, model, dataset, log_dir
+        )
+        self.alpha = alpha
+        self.dict_ordered = dict_ordered
+        self.save_shared = save_shared
+        self.metadata_cap = metadata_cap
+        self.total_meta = 0
+
+        if self.save_shared:
+            # Only save for 2 procs: Save space
+            if rank != 0 and rank != 1:
+                self.save_shared = False
+
+        if self.save_shared:
+            self.folder_path = os.path.join(
+                self.log_dir, "shared_params/{}".format(self.rank)
+            )
+            Path(self.folder_path).mkdir(parents=True, exist_ok=True)
+
+    def extract_top_params(self):
+        """
+        Extract the indices and values of the topK params layerwise.
+        Selection is based on the current parameter magnitudes; no accumulation is required.
+
+        Returns
+        -------
+        tuple
+            (a,b,c). a: The topK params, b: Their indices, c: The offsets
+
+        """
+
+        logging.info("Returning TopKParams gradients")
+        values_list = []
+        index_list = []
+        offsets = [0]
+        off = 0
+        for _, v in self.model.state_dict().items():
+            flat = v.flatten()
+            values, index = torch.topk(
+                flat.abs(), round(self.alpha * flat.size(dim=0)), dim=0, sorted=False
+            )
+            values_list.append(flat[index])
+            index_list.append(index)
+            off += values.size(dim=0)
+            offsets.append(off)
+        cat_values = torch.cat(values_list, dim=0)
+        cat_index = torch.cat(index_list, dim=0)
+
+        # logging.debug("Subsampling vector is of size: " + str(subsample.size(dim = 0)))
+        return (cat_values, cat_index, offsets)
+
+    def serialized_model(self):
+        """
+        Convert model to json dict. self.alpha specifies the fraction of model to send.
+
+        Returns
+        -------
+        dict
+            Model converted to json dict
+
+        """
+        if self.alpha > self.metadata_cap:  # Share fully
+            return super().serialized_model()
+
+        with torch.no_grad():
+            values, index, offsets = self.extract_top_params()
+
+            if self.save_shared:
+                shared_params = dict()
+                shared_params["order"] = list(self.model.state_dict().keys())
+                shapes = dict()
+                for k, v in self.model.state_dict().items():
+                    shapes[k] = list(v.shape)
+                shared_params["shapes"] = shapes
+
+                shared_params[self.communication_round] = index.tolist()
+                # TODO: store offsets
+
+                with open(
+                    os.path.join(
+                        self.folder_path,
+                        "{}_shared_params.json".format(self.communication_round + 1),
+                    ),
+                    "w",
+                ) as of:
+                    json.dump(shared_params, of)
+
+            logging.info("Extracting topk params")
+
+            logging.info("Generating dictionary to send")
+
+            m = dict()
+
+            if not self.dict_ordered:
+                raise NotImplementedError
+
+            m["indices"] = index.numpy()
+            m["params"] = values.numpy()
+            m["offsets"] = offsets
+
+            assert len(m["indices"]) == len(m["params"])
+            logging.info("Elements sending: {}".format(len(m["indices"])))
+
+            logging.info("Generated dictionary to send")
+
+            # for key in m:
+            #    m[key] = json.dumps(m[key])
+
+            logging.info("Converted dictionary to json")
+            self.total_data += len(self.communication.encrypt(m["params"]))
+            self.total_meta += len(self.communication.encrypt(m["indices"])) + len(
+                self.communication.encrypt(m["offsets"])
+            )
+
+            return m
+
+    def deserialized_model(self, m):
+        """
+        Convert received json dict to state_dict.
+
+        Parameters
+        ----------
+        m : dict
+            json dict received
+
+        Returns
+        -------
+        state_dict
+            state_dict of received
+
+        """
+        if self.alpha > self.metadata_cap:  # Share fully
+            return super().deserialized_model(m)
+
+        with torch.no_grad():
+            state_dict = self.model.state_dict()
+
+            if not self.dict_ordered:
+                raise NotImplementedError
+
+            shapes = []
+            lens = []
+            tensors_to_cat = []
+            offsets = m["offsets"]
+            params = torch.tensor(m["params"])
+            indices = torch.tensor(m["indices"])
+
+            for i, (_, v) in enumerate(state_dict.items()):
+                shapes.append(v.shape)
+                t = v.flatten().clone().detach()  # it is not always copied
+                lens.append(t.shape[0])
+                index = indices[offsets[i] : offsets[i + 1]]
+                t[index] = params[offsets[i] : offsets[i + 1]]
+                tensors_to_cat.append(t)
+
+            start_index = 0
+            for i, key in enumerate(state_dict):
+                end_index = start_index + lens[i]
+                state_dict[key] = tensors_to_cat[i].reshape(shapes[i])
+                start_index = end_index
+
+            return state_dict
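
TopKParams differs from TopK in that selection happens per layer on the current parameter magnitudes, and the offsets list tells the receiver where each layer's slice of indices and values begins. A small sketch of that bookkeeping (layer names and sizes are invented):

    import torch

    alpha = 0.1
    layers = {"conv.weight": torch.randn(500), "fc.weight": torch.randn(300)}  # flattened layers

    values_list, index_list, offsets, off = [], [], [0], 0
    for name, flat in layers.items():
        k = round(alpha * flat.numel())
        _, idx = torch.topk(flat.abs(), k, dim=0, sorted=False)
        values_list.append(flat[idx])
        index_list.append(idx)                         # indices are local to this layer
        off += k
        offsets.append(off)

    values = torch.cat(values_list)
    indices = torch.cat(index_list)

    # receiver: layer i's share lives in the slice offsets[i]:offsets[i+1]
    for i, (name, flat) in enumerate(layers.items()):
        sl = slice(offsets[i], offsets[i + 1])
        flat[indices[sl]] = values[sl]
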
diff --git a/src/decentralizepy/sharing/Wavelet.py b/src/decentralizepy/sharing/Wavelet.py
new file mode 100644
index 0000000..a6cccaf
--- /dev/null
+++ b/src/decentralizepy/sharing/Wavelet.py
@@ -0,0 +1,370 @@
+import base64
+import json
+import logging
+import os
+import pickle
+from pathlib import Path
+from time import time
+
+import pywt
+import torch
+
+from decentralizepy.sharing.Sharing import Sharing
+
+
+class Wavelet(Sharing):
+    """
+    This class implements the wavelet version of model sharing.
+    It is based on PartialModel.py.
+
+    """
+
+    def __init__(
+        self,
+        rank,
+        machine_id,
+        communication,
+        mapping,
+        graph,
+        model,
+        dataset,
+        log_dir,
+        alpha=1.0,
+        dict_ordered=True,
+        save_shared=False,
+        metadata_cap=1.0,
+        pickle=True,
+        wavelet="haar",
+        level=4,
+        change_based_selection=True,
+        accumulation=False,
+    ):
+        """
+        Constructor
+
+        Parameters
+        ----------
+        rank : int
+            Local rank
+        machine_id : int
+            Global machine id
+        communication : decentralizepy.communication.Communication
+            Communication module used to send and receive messages
+        mapping : decentralizepy.mappings.Mapping
+            Mapping (rank, machine_id) -> uid
+        graph : decentralizepy.graphs.Graph
+            Graph representing neighbors
+        model : decentralizepy.models.Model
+            Model to train
+        dataset : decentralizepy.datasets.Dataset
+            Dataset for sharing data. Not implemented yet! TODO
+        log_dir : str
+            Location to write shared_params (only writing for 2 procs per machine)
+        alpha : float
+            Percentage of model to share
+        dict_ordered : bool
+            Specifies if the python dict maintains the order of insertion
+        save_shared : bool
+            Specifies if the indices of shared parameters should be logged
+        metadata_cap : float
+            Share full model when self.alpha > metadata_cap
+        pickle : bool
+            use pickle to serialize the model parameters
+        wavelet: str
+            name of the wavelet to be used in gradient compression
+        level: int or None
+            decomposition level of the wavelet transform (None lets PyWavelets pick the maximum useful level)
+        change_based_selection : bool
+            use frequency change to select topk frequencies
+        accumulation : bool
+            True if the indices to share should be selected based on accumulated frequency change
+        """
+        super().__init__(
+            rank, machine_id, communication, mapping, graph, model, dataset, log_dir
+        )
+        self.alpha = alpha
+        self.dict_ordered = dict_ordered
+        self.save_shared = save_shared
+        self.metadata_cap = metadata_cap
+        self.total_meta = 0
+
+        self.pickle = pickle
+        self.wavelet = wavelet
+        self.level = level
+        self.accumulation = accumulation
+
+        logging.info("subsampling pickling=" + str(pickle))
+
+        if self.save_shared:
+            # Only save for 2 procs: Save space
+            if rank != 0 and rank != 1:
+                self.save_shared = False
+
+        if self.save_shared:
+            self.folder_path = os.path.join(
+                self.log_dir, "shared_params/{}".format(self.rank)
+            )
+            Path(self.folder_path).mkdir(parents=True, exist_ok=True)
+
+        self.change_based_selection = change_based_selection
+
+    def apply_wavelet(self):
+        """
+        Applies a wavelet transform to the model parameters and selects the top alpha fraction of coefficients in the
+        frequency domain, either by magnitude or by the change they underwent during the current training step
+
+        Returns
+        -------
+        tuple
+            (a,b). a: selected wavelet coefficients, b: Their indices.
+
+        """
+
+        logging.info("Returning dwt compressed model weights")
+        tensors_to_cat = [v.data.flatten() for _, v in self.model.state_dict().items()]
+        concated = torch.cat(tensors_to_cat, dim=0)
+        if self.change_based_selection:
+            coeff = pywt.wavedec(concated.numpy(), self.wavelet, level=self.level)
+            data, coeff_slices = pywt.coeffs_to_array(
+                coeff
+            )  # coeff_slices will be reproduced on the receiver
+            data = data.ravel()
+
+            if self.accumulation:
+                logging.info(
+                    "wavelet topk extract frequencies based on accumulated model frequency change"
+                )
+                diff = self.model.accumulated_frequency + (data - self.model.prev)
+            else:
+                diff = data - self.model.accumulated_frequency
+            _, index = torch.topk(
+                torch.from_numpy(diff).abs(),
+                round(self.alpha * len(data)),
+                dim=0,
+                sorted=False,
+            )
+        else:
+            coeff = pywt.wavedec(concated.numpy(), self.wavelet, level=self.level)
+            data, coeff_slices = pywt.coeffs_to_array(
+                coeff
+            )  # coeff_slices will be reproduced on the receiver
+            data = data.ravel()
+            _, index = torch.topk(
+                torch.from_numpy(data).abs(),
+                round(self.alpha * len(data)),
+                dim=0,
+                sorted=False,
+            )
+
+        if self.accumulation:
+            self.model.accumulated_frequency[index] = 0.0
+        return torch.from_numpy(data[index]), index
+
+    def serialized_model(self):
+        """
+        Convert model to json dict. self.alpha specifies the fraction of model to send.
+
+        Returns
+        -------
+        dict
+            Model converted to json dict
+
+        """
+        if self.alpha > self.metadata_cap:  # Share fully
+            return super().serialized_model()
+
+        with torch.no_grad():
+            topk, indices = self.apply_wavelet()
+
+            if self.save_shared:
+                shared_params = dict()
+                shared_params["order"] = list(self.model.state_dict().keys())
+                shapes = dict()
+                for k, v in self.model.state_dict().items():
+                    shapes[k] = list(v.shape)
+                shared_params["shapes"] = shapes
+
+                shared_params[self.communication_round] = indices.tolist()  # is slow
+
+                shared_params["alpha"] = self.alpha
+
+                with open(
+                    os.path.join(
+                        self.folder_path,
+                        "{}_shared_params.json".format(self.communication_round + 1),
+                    ),
+                    "w",
+                ) as of:
+                    json.dump(shared_params, of)
+
+            m = dict()
+
+            if not self.dict_ordered:
+                raise NotImplementedError
+
+            m["alpha"] = self.alpha
+
+            m["params"] = topk.numpy()
+
+            m["indices"] = indices.numpy()
+
+            self.total_data += len(self.communication.encrypt(m["params"]))
+            self.total_meta += len(self.communication.encrypt(m["indices"])) + len(
+                self.communication.encrypt(m["alpha"])
+            )
+
+            return m
+
+    def deserialized_model(self, m):
+        """
+        Convert a received dict of compressed wavelet coefficients back into tensors.
+
+        Parameters
+        ----------
+        m : dict
+            received dict of wavelet coefficients and their indices
+
+        Returns
+        -------
+        dict
+            dict with "params" and "indices" tensors (or a full state_dict when sharing the full model)
+
+        """
+        if self.alpha > self.metadata_cap:  # Share fully
+            return super().deserialized_model(m)
+
+        with torch.no_grad():
+            if not self.dict_ordered:
+                raise NotImplementedError
+
+            # The local model is not needed here; step() merges the received
+            # coefficients into its own wavelet representation.
+            indices = m["indices"]
+            alpha = m["alpha"]
+            params = m["params"]
+
+            params_tensor = torch.tensor(params)
+            indices_tensor = torch.tensor(indices)
+            ret = dict()
+            ret["indices"] = indices_tensor
+            ret["params"] = params_tensor
+            return ret
+
+    def step(self):
+        """
+        Perform a sharing step. Implements D-PSGD.
+
+        """
+        t_start = time()
+        data = self.serialized_model()
+        t_post_serialize = time()
+        my_uid = self.mapping.get_uid(self.rank, self.machine_id)
+        all_neighbors = self.graph.neighbors(my_uid)
+        iter_neighbors = self.get_neighbors(all_neighbors)
+        data["degree"] = len(all_neighbors)
+        data["iteration"] = self.communication_round
+        for neighbor in iter_neighbors:
+            self.communication.send(neighbor, data)
+        t_post_send = time()
+        logging.info("Waiting for messages from neighbors")
+        while not self.received_from_all():
+            sender, data = self.communication.receive()
+            logging.debug("Received model from {}".format(sender))
+            degree = data["degree"]
+            iteration = data["iteration"]
+            del data["degree"]
+            del data["iteration"]
+            self.peer_deques[sender].append((degree, iteration, data))
+            logging.info(
+                "Deserialized received model from {} of iteration {}".format(
+                    sender, iteration
+                )
+            )
+        t_post_recv = time()
+
+        logging.info("Starting model averaging after receiving from all neighbors")
+        total = None
+        weight_total = 0
+
+        # Wavelet transform of this model
+        shapes = []
+        lens = []
+        tensors_to_cat = []
+        # TODO: should we detach
+        for _, v in self.model.state_dict().items():
+            shapes.append(v.shape)
+            t = v.flatten()
+            lens.append(t.shape[0])
+            tensors_to_cat.append(t)
+        concated = torch.cat(tensors_to_cat, dim=0)
+        coeff = pywt.wavedec(concated.numpy(), self.wavelet, level=self.level)
+        wt_params, coeff_slices = pywt.coeffs_to_array(
+            coeff
+        )  # coeff_slices will be reproduced on the receiver
+        shape = wt_params.shape
+        wt_params = wt_params.ravel()
+
+        for i, n in enumerate(self.peer_deques):
+            degree, iteration, data = self.peer_deques[n].popleft()
+            logging.debug(
+                "Averaging model from neighbor {} of iteration {}".format(n, iteration)
+            )
+            data = self.deserialized_model(data)
+            params = data["params"]
+            indices = data["indices"]
+            # use local data to complement
+            topkwf = wt_params.copy()  # .clone().detach()
+            topkwf[indices] = params
+            topkwf = torch.from_numpy(topkwf.reshape(shape))
+
+            weight = 1 / (max(len(self.peer_deques), degree) + 1)  # Metropolis-Hastings
+            weight_total += weight
+            if total is None:
+                total = weight * topkwf
+            else:
+                total += weight * topkwf
+
+        # Metropolis-Hastings
+        total += (1 - weight_total) * wt_params
+
+        avg_wf_params = pywt.array_to_coeffs(
+            total, coeff_slices, output_format="wavedec"
+        )
+        reverse_total = torch.from_numpy(
+            pywt.waverec(avg_wf_params, wavelet=self.wavelet)
+        )
+
+        start_index = 0
+        std_dict = {}
+        for i, key in enumerate(self.model.state_dict()):
+            end_index = start_index + lens[i]
+            std_dict[key] = reverse_total[start_index:end_index].reshape(shapes[i])
+            start_index = end_index
+
+        self.model.load_state_dict(std_dict)
+
+        logging.info("Model averaging complete")
+
+        self.communication_round += 1
+
+        t_end = time()
+
+        logging.info(
+            "Sharing::step | Serialize: %f; Send: %f; Recv: %f; Averaging: %f; Total: %f",
+            t_post_serialize - t_start,
+            t_post_send - t_post_serialize,
+            t_post_recv - t_post_send,
+            t_end - t_post_recv,
+            t_end - t_start,
+        )
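For reference, the pywt round trip that apply_wavelet() and step() rely on can be sketched in isolation. The vector length, alpha, and variable names below are illustrative assumptions, not part of the patch:

    import pywt
    import torch

    wavelet, level, alpha = "haar", 4, 0.1
    params = torch.randn(1024)  # stands in for the concatenated model parameters

    # Decompose and flatten the coefficients into a single array, as apply_wavelet does.
    coeff = pywt.wavedec(params.numpy(), wavelet, level=level)
    data, coeff_slices = pywt.coeffs_to_array(coeff)
    shape = data.shape
    data = data.ravel()

    # Keep only the largest alpha fraction of coefficients and zero out the rest.
    _, index = torch.topk(torch.from_numpy(data).abs(), round(alpha * len(data)), sorted=False)
    sparse = torch.zeros_like(torch.from_numpy(data))
    sparse[index] = torch.from_numpy(data)[index]

    # Reconstruct an approximate parameter vector from the sparse coefficients,
    # mirroring what step() does after averaging.
    coeffs = pywt.array_to_coeffs(sparse.numpy().reshape(shape), coeff_slices, output_format="wavedec")
    approx = pywt.waverec(coeffs, wavelet=wavelet)
    print(approx.shape)  # (1024,) here; odd-length inputs can come back padded by one element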
diff --git a/src/decentralizepy/training/FrequencyAccumulator.py b/src/decentralizepy/training/FrequencyAccumulator.py
new file mode 100644
index 0000000..9c264cc
--- /dev/null
+++ b/src/decentralizepy/training/FrequencyAccumulator.py
@@ -0,0 +1,105 @@
+import logging
+
+import torch
+from torch import fft
+
+from decentralizepy.training.Training import Training
+
+
+class FrequencyAccumulator(Training):
+    """
+    This class implements the training module which additionally accumulates the FFT frequency representation of the model at the beginning of each communication round.
+
+    """
+
+    def __init__(
+        self,
+        rank,
+        machine_id,
+        mapping,
+        model,
+        optimizer,
+        loss,
+        log_dir,
+        rounds="",
+        full_epochs="",
+        batch_size="",
+        shuffle="",
+        accumulation=True,
+    ):
+        """
+        Constructor
+
+        Parameters
+        ----------
+        rank : int
+            Rank of process local to the machine
+        machine_id : int
+            Machine ID on which the process is running
+        mapping : decentralizepy.mappings
+            The object containing the mapping rank <--> uid
+        model : torch.nn.Module
+            Neural Network for training
+        optimizer : torch.optim
+            Optimizer to learn parameters
+        loss : function
+            Loss function
+        log_dir : str
+            Directory to log the model change.
+        rounds : int, optional
+            Number of steps/epochs per training call
+        full_epochs : bool, optional
+            True if 1 round = 1 epoch. False if 1 round = 1 minibatch
+        batch_size : int, optional
+            Number of items to learn over, in one batch
+        shuffle : bool
+            True if the dataset should be shuffled before training.
+        accumulation : bool
+            True if the model change should be accumulated across communication steps
+        """
+        super().__init__(
+            rank,
+            machine_id,
+            mapping,
+            model,
+            optimizer,
+            loss,
+            log_dir,
+            rounds,
+            full_epochs,
+            batch_size,
+            shuffle,
+        )
+        self.accumulation = accumulation
+
+    def train(self, dataset):
+        """
+        Does one training iteration.
+        If self.accumulation is True then it accumulates model fft frequency changes in model.accumulated_frequency.
+        Otherwise it stores the current fft frequency representation of the model in model.accumulated_frequency.
+
+        Parameters
+        ----------
+        dataset : decentralizepy.datasets.Dataset
+            The training dataset. Should implement get_trainset(batch_size, shuffle)
+
+        """
+
+        # Track the frequency change since the last round's averaging.
+        tensors_to_cat = [v.data.flatten() for _, v in self.model.state_dict().items()]
+        concated = torch.cat(tensors_to_cat, dim=0)
+        flat_fft = fft.rfft(concated)
+        if self.accumulation:
+            if self.model.accumulated_frequency is None:
+                logging.info("Initialize fft frequency accumulation")
+                self.model.accumulated_frequency = torch.zeros_like(flat_fft)
+                self.model.prev = flat_fft
+            else:
+                logging.info("fft frequency accumulation step")
+                self.model.accumulated_frequency += flat_fft - self.model.prev
+                self.model.prev = flat_fft
+        else:
+            logging.info("fft frequency accumulation reset")
+            self.model.accumulated_frequency = flat_fft
+
+        super().train(dataset)
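The accumulation logic above can be illustrated with a standalone two-round loop; the toy parameter vectors are assumptions for illustration only:

    import torch
    from torch import fft

    # Stand-ins for model.accumulated_frequency and model.prev across two rounds.
    accumulated, prev = None, None
    for params in [torch.ones(8), 2 * torch.ones(8)]:
        flat_fft = fft.rfft(params)
        if accumulated is None:
            accumulated = torch.zeros_like(flat_fft)  # first round: nothing accumulated yet
            prev = flat_fft
        else:
            accumulated += flat_fft - prev  # add this round's frequency change
            prev = flat_fft

    # After the second round, `accumulated` holds the total frequency change.
    expected = fft.rfft(2 * torch.ones(8)) - fft.rfft(torch.ones(8))
    print(torch.allclose(accumulated, expected))  # True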
diff --git a/src/decentralizepy/training/FrequencyWaveletAccumulator.py b/src/decentralizepy/training/FrequencyWaveletAccumulator.py
new file mode 100644
index 0000000..cf65724
--- /dev/null
+++ b/src/decentralizepy/training/FrequencyWaveletAccumulator.py
@@ -0,0 +1,113 @@
+import logging
+
+import numpy as np
+import pywt
+import torch
+
+from decentralizepy.training.Training import Training
+
+
+class FrequencyWaveletAccumulator(Training):
+    """
+    This class implements the training module which additionally accumulates the wavelet coefficient representation of the model at the beginning of each communication round.
+
+    """
+
+    def __init__(
+        self,
+        rank,
+        machine_id,
+        mapping,
+        model,
+        optimizer,
+        loss,
+        log_dir,
+        rounds="",
+        full_epochs="",
+        batch_size="",
+        shuffle="",
+        wavelet="haar",
+        level=4,
+        accumulation=True,
+    ):
+        """
+        Constructor
+
+        Parameters
+        ----------
+        rank : int
+            Rank of process local to the machine
+        machine_id : int
+            Machine ID on which the process is running
+        mapping : decentralizepy.mappings
+            The object containing the mapping rank <--> uid
+        model : torch.nn.Module
+            Neural Network for training
+        optimizer : torch.optim
+            Optimizer to learn parameters
+        loss : function
+            Loss function
+        log_dir : str
+            Directory to log the model change.
+        rounds : int, optional
+            Number of steps/epochs per training call
+        full_epochs : bool, optional
+            True if 1 round = 1 epoch. False if 1 round = 1 minibatch
+        batch_size : int, optional
+            Number of items to learn over, in one batch
+        shuffle : bool
+            True if the dataset should be shuffled before training.
+        wavelet : str
+            name of the wavelet used for the decomposition
+        level : int
+            decomposition level of the wavelet transform
+        accumulation : bool
+            True if the model change should be accumulated across communication steps
+        """
+        super().__init__(
+            rank,
+            machine_id,
+            mapping,
+            model,
+            optimizer,
+            loss,
+            log_dir,
+            rounds,
+            full_epochs,
+            batch_size,
+            shuffle,
+        )
+        self.wavelet = wavelet
+        self.level = level
+        self.accumulation = accumulation
+
+    def train(self, dataset):
+        """
+        Does one training iteration.
+        If self.accumulation is True then it accumulates model wavelet frequency changes in model.accumulated_frequency.
+        Otherwise it stores the current wavelet frequency representation of the model in model.accumulated_frequency.
+
+        Parameters
+        ----------
+        dataset : decentralizepy.datasets.Dataset
+            The training dataset. Should implement get_trainset(batch_size, shuffle)
+
+        """
+
+        # Track the wavelet coefficient change since the last round's averaging.
+        tensors_to_cat = [v.data.flatten() for _, v in self.model.state_dict().items()]
+        concated = torch.cat(tensors_to_cat, dim=0)
+        coeff = pywt.wavedec(concated.numpy(), self.wavelet, level=self.level)
+        data, coeff_slices = pywt.coeffs_to_array(coeff)
+        data = data.ravel()
+        if self.accumulation:
+            if self.model.accumulated_frequency is None:
+                logging.info("Initialize wavelet frequency accumulation")
+                self.model.accumulated_frequency = np.zeros_like(
+                    data
+                )  # torch.zeros_like(data)
+                self.model.prev = data
+            else:
+                logging.info("wavelet frequency accumulation step")
+                self.model.accumulated_frequency += data - self.model.prev
+                self.model.prev = data
+        else:
+            logging.info("wavelet frequency accumulation reset")
+            self.model.accumulated_frequency = data
+        super().train(dataset)
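Combined with Wavelet.apply_wavelet (change_based_selection and accumulation enabled), the accumulated coefficients drive the top-k selection. A minimal sketch, with toy random vectors standing in for the real flattened parameters:

    import numpy as np
    import pywt
    import torch

    wavelet, level, alpha = "haar", 4, 0.1

    # Wavelet coefficients of the model before and after one local training step.
    before = pywt.coeffs_to_array(pywt.wavedec(np.random.randn(256), wavelet, level=level))[0].ravel()
    after = pywt.coeffs_to_array(pywt.wavedec(np.random.randn(256), wavelet, level=level))[0].ravel()

    accumulated = np.zeros_like(before)  # model.accumulated_frequency after the first train() call
    prev = before                        # model.prev

    # Selection as in Wavelet.apply_wavelet with accumulation enabled.
    diff = accumulated + (after - prev)
    _, index = torch.topk(torch.from_numpy(diff).abs(), round(alpha * len(diff)), sorted=False)
    accumulated[index.numpy()] = 0.0  # shared positions stop accumulating until they change again
    print(index.shape)  # 26 of the 256 coefficient positions get shared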
diff --git a/src/decentralizepy/training/ModelChangeAccumulator.py b/src/decentralizepy/training/ModelChangeAccumulator.py
new file mode 100644
index 0000000..1c70283
--- /dev/null
+++ b/src/decentralizepy/training/ModelChangeAccumulator.py
@@ -0,0 +1,103 @@
+import logging
+
+import torch
+from torch import fft
+
+from decentralizepy.training.Training import Training
+
+
+class ModelChangeAccumulator(Training):
+    """
+    This class implements the training module which also accumulates the model change at the beginning of a communication round.
+
+    """
+
+    def __init__(
+        self,
+        rank,
+        machine_id,
+        mapping,
+        model,
+        optimizer,
+        loss,
+        log_dir,
+        rounds="",
+        full_epochs="",
+        batch_size="",
+        shuffle="",
+        accumulation=True,
+    ):
+        """
+        Constructor
+
+        Parameters
+        ----------
+        rank : int
+            Rank of process local to the machine
+        machine_id : int
+            Machine ID on which the process is running
+        mapping : decentralizepy.mappings
+            The object containing the mapping rank <--> uid
+        model : torch.nn.Module
+            Neural Network for training
+        optimizer : torch.optim
+            Optimizer to learn parameters
+        loss : function
+            Loss function
+        log_dir : str
+            Directory to log the model change.
+        rounds : int, optional
+            Number of steps/epochs per training call
+        full_epochs : bool, optional
+            True if 1 round = 1 epoch. False if 1 round = 1 minibatch
+        batch_size : int, optional
+            Number of items to learn over, in one batch
+        shuffle : bool
+            True if the dataset should be shuffled before training.
+        accumulation : bool
+            True if the model change should be accumulated across communication steps
+
+        """
+        super().__init__(
+            rank,
+            machine_id,
+            mapping,
+            model,
+            optimizer,
+            loss,
+            log_dir,
+            rounds,
+            full_epochs,
+            batch_size,
+            shuffle,
+        )
+        self.accumulation = accumulation
+
+    def train(self, dataset):
+        """
+        Does one training iteration.
+        If self.accumulation is True then it accumulates model parameter changes in model.prev_model_params.
+        Otherwise it stores the current model parameters in model.prev_model_params.
+
+        Parameters
+        ----------
+        dataset : decentralizepy.datasets.Dataset
+            The training dataset. Should implement get_trainset(batch_size, shuffle)
+
+        """
+
+        tensors_to_cat = [v.data.flatten() for _, v in self.model.state_dict().items()]
+        concated = torch.cat(tensors_to_cat, dim=0)
+        if self.accumulation:
+            if self.model.prev_model_params is None:
+                logging.info("Initialize model parameter accumulation.")
+                self.model.prev_model_params = torch.zeros_like(concated)
+                self.model.prev = concated
+            else:
+                logging.info("model parameter accumulation step")
+                self.model.prev_model_params += concated - self.model.prev
+                self.model.prev = concated
+        else:
+            logging.info("model parameter reset")
+            self.model.prev_model_params = concated
+        super().train(dataset)
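All three accumulators, as well as the sharing classes above, rely on the same flatten/split convention for state_dicts. A minimal sketch with a hypothetical two-layer model (the layer sizes are illustrative):

    import torch

    model = torch.nn.Sequential(torch.nn.Linear(4, 3), torch.nn.Linear(3, 2))
    state_dict = model.state_dict()

    # Flatten every tensor and remember its shape and length, as train()/step() do.
    shapes, lens, tensors_to_cat = [], [], []
    for _, v in state_dict.items():
        shapes.append(v.shape)
        t = v.data.flatten()
        lens.append(t.shape[0])
        tensors_to_cat.append(t)
    concated = torch.cat(tensors_to_cat, dim=0)

    # Split the flat vector back into a state_dict (the inverse used after averaging).
    start, restored = 0, {}
    for i, key in enumerate(state_dict):
        restored[key] = concated[start : start + lens[i]].reshape(shapes[i])
        start += lens[i]

    assert all(torch.equal(restored[k], state_dict[k]) for k in state_dict)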
diff --git a/src/decentralizepy/training/Training.py b/src/decentralizepy/training/Training.py
index 3b99bef..5adc4a9 100644
--- a/src/decentralizepy/training/Training.py
+++ b/src/decentralizepy/training/Training.py
@@ -46,7 +46,7 @@ class Training:
             Directory to log the model change.
         rounds : int, optional
             Number of steps/epochs per training call
-        full_epochs: bool, optional
+        full_epochs : bool, optional
             True if 1 round = 1 epoch. False if 1 round = 1 minibatch
         batch_size : int, optional
             Number of items to learn over, in one batch
-- 
GitLab