diff --git a/eval/run_all.sh b/eval/run_all.sh
new file mode 100755
index 0000000000000000000000000000000000000000..1afdf0291386cbcf17343e83014f33fb4680adfd
--- /dev/null
+++ b/eval/run_all.sh
@@ -0,0 +1,39 @@
+#!/bin/bash
+nfs_home=$1
+python_bin=$2
+decpy_path=$nfs_home/decentralizepy/eval
+cd $decpy_path
+
+env_python=$python_bin/python3
+graph=96_regular.edges #4_node_fullyConnected.edges
+config_file=~/tmp/config.ini
+procs_per_machine=16
+machines=6
+iterations=5
+test_after=21 # greater than the number of iterations, so we do not test
+eval_file=testing.py
+log_level=INFO
+
+ip_machines=$nfs_home/configs/ip_addr_6Machines.json
+
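+# Determine this machine's id by matching the local ens785 IP address against the entries in $ip_machines.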
+m=$(cat "$ip_machines" | grep "$(/sbin/ifconfig ens785 | grep 'inet ' | awk '{print $2}')" | cut -d'"' -f2)
+export PYTHONFAULTHANDLER=1
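+# Step configs to run one after another; each is copied to $config_file and launched with $eval_file.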
+tests=("step_configs/config_celeba.ini" "step_configs/config_celeba_100.ini" "step_configs/config_celeba_fft.ini" "step_configs/config_celeba_wavelet.ini"
+"step_configs/config_celeba_grow.ini" "step_configs/config_celeba_manualadapt.ini" "step_configs/config_celeba_randomalpha.ini"
+"step_configs/config_celeba_randomalphainc.ini" "step_configs/config_celeba_roundrobin.ini" "step_configs/config_celeba_subsampling.ini"
+"step_configs/config_celeba_topkrandom.ini" "step_configs/config_celeba_topkacc.ini" "step_configs/config_celeba_topkparam.ini")
+
+for i in "${tests[@]}"
+do
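+  # Derive the run name from the config file name: the part after the last '_' and before the extension (e.g. "topkacc").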
+  echo $i
+  IFS='_' read -ra NAMES <<< "$i"
+  IFS='.' read -ra NAME <<< "${NAMES[-1]}"
+  log_dir=$nfs_home/logs/testing/${NAME[0]}$(date '+%Y-%m-%dT%H:%M')/machine$m
+  mkdir -p $log_dir
+  cp $i $config_file
+  $python_bin/crudini --set $config_file COMMUNICATION addresses_filepath $ip_machines
+  $env_python $eval_file -ro 0 -ld $log_dir -mid $m -ps $procs_per_machine -ms $machines -is $iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level
+  echo $i is done
+  sleep 3
+  echo end of sleep
+done
diff --git a/eval/step_configs/config_celeba.ini b/eval/step_configs/config_celeba.ini
index 5cadf017749dbb602188323cbca82a1dfd40dcac..6c9a4b5b16c0086b8d950a90ea24451dabe7e527 100644
--- a/eval/step_configs/config_celeba.ini
+++ b/eval/step_configs/config_celeba.ini
@@ -2,9 +2,9 @@
 dataset_package = decentralizepy.datasets.Celeba
 dataset_class = Celeba
 model_class = CNN
-images_dir = /home/risharma/leaf/data/celeba/data/raw/img_align_celeba
-train_dir = /home/risharma/leaf/data/celeba/per_user_data/train
-test_dir = /home/risharma/leaf/data/celeba/data/test
+images_dir = /mnt/nfs/shared/leaf/data/celeba/data/raw/img_align_celeba
+train_dir = /mnt/nfs/shared/leaf/data/celeba/per_user_data/train
+test_dir = /mnt/nfs/shared/leaf/data/celeba/data/test
 ; python list of fractions below
 sizes = 
 
@@ -14,11 +14,11 @@ optimizer_class = Adam
 lr = 0.001
 
 [TRAIN_PARAMS]
-training_package = decentralizepy.training.GradientAccumulator
-training_class = GradientAccumulator
-rounds = 20
+training_package = decentralizepy.training.Training
+training_class = Training
+rounds = 4
 full_epochs = False
-batch_size = 64
+batch_size = 16
 shuffle = True
 loss_package = torch.nn
 loss_class = CrossEntropyLoss
diff --git a/eval/step_configs/config_celeba_100.ini b/eval/step_configs/config_celeba_100.ini
index 70e14bbaf2fe87717351567aa4420ce439e3fa98..caf05fa846ca9e73d3f5df58aa1a19f01e16de1a 100644
--- a/eval/step_configs/config_celeba_100.ini
+++ b/eval/step_configs/config_celeba_100.ini
@@ -2,9 +2,9 @@
 dataset_package = decentralizepy.datasets.Celeba
 dataset_class = Celeba
 model_class = CNN
-images_dir = /home/risharma/leaf/data/celeba/data/raw/img_align_celeba
-train_dir = /home/risharma/leaf/data/celeba/per_user_data/train
-test_dir = /home/risharma/leaf/data/celeba/data/test
+images_dir = /mnt/nfs/shared/leaf/data/celeba/data/raw/img_align_celeba
+train_dir = /mnt/nfs/shared/leaf/data/celeba/per_user_data/train
+test_dir = /mnt/nfs/shared/leaf/data/celeba/data/test
 ; python list of fractions below
 sizes = 
 
@@ -16,9 +16,9 @@ lr = 0.001
 [TRAIN_PARAMS]
 training_package = decentralizepy.training.Training
 training_class = Training
-rounds = 20
+rounds = 4
 full_epochs = False
-batch_size = 64
+batch_size = 16
 shuffle = True
 loss_package = torch.nn
 loss_class = CrossEntropyLoss
diff --git a/eval/step_configs/config_celeba_fft.ini b/eval/step_configs/config_celeba_fft.ini
new file mode 100644
index 0000000000000000000000000000000000000000..e8d6a70804ec7acd2c3b38248aa4e756b9a23cd2
--- /dev/null
+++ b/eval/step_configs/config_celeba_fft.ini
@@ -0,0 +1,36 @@
+[DATASET]
+dataset_package = decentralizepy.datasets.Celeba
+dataset_class = Celeba
+model_class = CNN
+images_dir = /mnt/nfs/shared/leaf/data/celeba/data/raw/img_align_celeba
+train_dir = /mnt/nfs/shared/leaf/data/celeba/per_user_data/train
+test_dir = /mnt/nfs/shared/leaf/data/celeba/data/test
+; python list of fractions below
+sizes = 
+
+[OPTIMIZER_PARAMS]
+optimizer_package = torch.optim
+optimizer_class = Adam
+lr = 0.001
+
+[TRAIN_PARAMS]
+training_package = decentralizepy.training.Training
+training_class = Training
+rounds = 4
+full_epochs = False
+batch_size = 16
+shuffle = True
+loss_package = torch.nn
+loss_class = CrossEntropyLoss
+
+[COMMUNICATION]
+comm_package = decentralizepy.communication.TCP
+comm_class = TCP
+addresses_filepath = ip_addr_6Machines.json
+
+[SHARING]
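+; share only the top-alpha fraction of fft coefficients; when change_based_selection is True they are picked by accumulated change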
+sharing_package = decentralizepy.sharing.FFT
+sharing_class = FFT
+alpha = 0.1
+change_based_selection = True
+accumulation = True
diff --git a/eval/step_configs/config_celeba_grow.ini b/eval/step_configs/config_celeba_grow.ini
index be0812e8c3bc1b91d14d803c3efe89e6d8e9b3e8..37e74ae0eb91a8847189e391da8a88b806a0b85b 100644
--- a/eval/step_configs/config_celeba_grow.ini
+++ b/eval/step_configs/config_celeba_grow.ini
@@ -2,9 +2,9 @@
 dataset_package = decentralizepy.datasets.Celeba
 dataset_class = Celeba
 model_class = CNN
-images_dir = /home/risharma/leaf/data/celeba/data/raw/img_align_celeba
-train_dir = /home/risharma/leaf/data/celeba/per_user_data/train
-test_dir = /home/risharma/leaf/data/celeba/data/test
+images_dir = /mnt/nfs/shared/leaf/data/celeba/data/raw/img_align_celeba
+train_dir = /mnt/nfs/shared/leaf/data/celeba/per_user_data/train
+test_dir = /mnt/nfs/shared/leaf/data/celeba/data/test
 ; python list of fractions below
 sizes = 
 
@@ -14,11 +14,11 @@ optimizer_class = Adam
 lr = 0.001
 
 [TRAIN_PARAMS]
-training_package = decentralizepy.training.GradientAccumulator
-training_class = GradientAccumulator
-rounds = 20
+training_package = decentralizepy.training.Training
+training_class = Training
+rounds = 4
 full_epochs = False
-batch_size = 64
+batch_size = 16
 shuffle = True
 loss_package = torch.nn
 loss_class = CrossEntropyLoss
diff --git a/eval/step_configs/config_celeba_manualadapt.ini b/eval/step_configs/config_celeba_manualadapt.ini
new file mode 100644
index 0000000000000000000000000000000000000000..1c117e2a40cabf297d305b2dd0cd20d88ca46299
--- /dev/null
+++ b/eval/step_configs/config_celeba_manualadapt.ini
@@ -0,0 +1,35 @@
+[DATASET]
+dataset_package = decentralizepy.datasets.Celeba
+dataset_class = Celeba
+model_class = CNN
+images_dir = /mnt/nfs/shared/leaf/data/celeba/data/raw/img_align_celeba
+train_dir = /mnt/nfs/shared/leaf/data/celeba/per_user_data/train
+test_dir = /mnt/nfs/shared/leaf/data/celeba/data/test
+; python list of fractions below
+sizes = 
+
+[OPTIMIZER_PARAMS]
+optimizer_package = torch.optim
+optimizer_class = Adam
+lr = 0.001
+
+[TRAIN_PARAMS]
+training_package = decentralizepy.training.Training
+training_class = Training
+rounds = 4
+full_epochs = False
+batch_size = 16
+shuffle = True
+loss_package = torch.nn
+loss_class = CrossEntropyLoss
+
+[COMMUNICATION]
+comm_package = decentralizepy.communication.TCP
+comm_class = TCP
+addresses_filepath = ip_addr_6Machines.json
+
+[SHARING]
+sharing_package = decentralizepy.sharing.ManualAdapt
+sharing_class = ManualAdapt
+change_alpha = [0.1, 0.5]
+change_rounds = [10, 30]
diff --git a/eval/step_configs/config_celeba_randomalpha.ini b/eval/step_configs/config_celeba_randomalpha.ini
new file mode 100644
index 0000000000000000000000000000000000000000..1c4b9893b67d0236cc6642a88f11fdf49d318ad5
--- /dev/null
+++ b/eval/step_configs/config_celeba_randomalpha.ini
@@ -0,0 +1,33 @@
+[DATASET]
+dataset_package = decentralizepy.datasets.Celeba
+dataset_class = Celeba
+model_class = CNN
+images_dir = /mnt/nfs/shared/leaf/data/celeba/data/raw/img_align_celeba
+train_dir = /mnt/nfs/shared/leaf/data/celeba/per_user_data/train
+test_dir = /mnt/nfs/shared/leaf/data/celeba/data/test
+; python list of fractions below
+sizes = 
+
+[OPTIMIZER_PARAMS]
+optimizer_package = torch.optim
+optimizer_class = Adam
+lr = 0.001
+
+[TRAIN_PARAMS]
+training_package = decentralizepy.training.Training
+training_class = Training
+rounds = 4
+full_epochs = False
+batch_size = 16
+shuffle = True
+loss_package = torch.nn
+loss_class = CrossEntropyLoss
+
+[COMMUNICATION]
+comm_package = decentralizepy.communication.TCP
+comm_class = TCP
+addresses_filepath = ip_addr_6Machines.json
+
+[SHARING]
+sharing_package = decentralizepy.sharing.RandomAlpha
+sharing_class = RandomAlpha
diff --git a/eval/step_configs/config_celeba_randomalphainc.ini b/eval/step_configs/config_celeba_randomalphainc.ini
new file mode 100644
index 0000000000000000000000000000000000000000..5171b64e5d01cdef69db572231c23d2d99f3cd5e
--- /dev/null
+++ b/eval/step_configs/config_celeba_randomalphainc.ini
@@ -0,0 +1,33 @@
+[DATASET]
+dataset_package = decentralizepy.datasets.Celeba
+dataset_class = Celeba
+model_class = CNN
+images_dir = /mnt/nfs/shared/leaf/data/celeba/data/raw/img_align_celeba
+train_dir = /mnt/nfs/shared/leaf/data/celeba/per_user_data/train
+test_dir = /mnt/nfs/shared/leaf/data/celeba/data/test
+; python list of fractions below
+sizes = 
+
+[OPTIMIZER_PARAMS]
+optimizer_package = torch.optim
+optimizer_class = Adam
+lr = 0.001
+
+[TRAIN_PARAMS]
+training_package = decentralizepy.training.Training
+training_class = Training
+rounds = 4
+full_epochs = False
+batch_size = 16
+shuffle = True
+loss_package = torch.nn
+loss_class = CrossEntropyLoss
+
+[COMMUNICATION]
+comm_package = decentralizepy.communication.TCP
+comm_class = TCP
+addresses_filepath = ip_addr_6Machines.json
+
+[SHARING]
+sharing_package = decentralizepy.sharing.RandomAlphaIncremental
+sharing_class = RandomAlphaIncremental
diff --git a/eval/step_configs/config_celeba_roundrobin.ini b/eval/step_configs/config_celeba_roundrobin.ini
new file mode 100644
index 0000000000000000000000000000000000000000..3dadf3274607c049aaefc3cf956a9baef239148b
--- /dev/null
+++ b/eval/step_configs/config_celeba_roundrobin.ini
@@ -0,0 +1,34 @@
+[DATASET]
+dataset_package = decentralizepy.datasets.Celeba
+dataset_class = Celeba
+model_class = CNN
+images_dir = /mnt/nfs/shared/leaf/data/celeba/data/raw/img_align_celeba
+train_dir = /mnt/nfs/shared/leaf/data/celeba/per_user_data/train
+test_dir = /mnt/nfs/shared/leaf/data/celeba/data/test
+; python list of fractions below
+sizes = 
+
+[OPTIMIZER_PARAMS]
+optimizer_package = torch.optim
+optimizer_class = Adam
+lr = 0.001
+
+[TRAIN_PARAMS]
+training_package = decentralizepy.training.Training
+training_class = Training
+rounds = 4
+full_epochs = False
+batch_size = 16
+shuffle = True
+loss_package = torch.nn
+loss_class = CrossEntropyLoss
+
+[COMMUNICATION]
+comm_package = decentralizepy.communication.TCP
+comm_class = TCP
+addresses_filepath = ip_addr_6Machines.json
+
+[SHARING]
+sharing_package = decentralizepy.sharing.RoundRobinPartial
+sharing_class = RoundRobinPartial
+alpha = 0.1
diff --git a/eval/step_configs/config_celeba_subsampling.ini b/eval/step_configs/config_celeba_subsampling.ini
new file mode 100644
index 0000000000000000000000000000000000000000..b8068984fd28b2d2fab9742687c32f4d56897413
--- /dev/null
+++ b/eval/step_configs/config_celeba_subsampling.ini
@@ -0,0 +1,34 @@
+[DATASET]
+dataset_package = decentralizepy.datasets.Celeba
+dataset_class = Celeba
+model_class = CNN
+images_dir = /mnt/nfs/shared/leaf/data/celeba/data/raw/img_align_celeba
+train_dir = /mnt/nfs/shared/leaf/data/celeba/per_user_data/train
+test_dir = /mnt/nfs/shared/leaf/data/celeba/data/test
+; python list of fractions below
+sizes = 
+
+[OPTIMIZER_PARAMS]
+optimizer_package = torch.optim
+optimizer_class = Adam
+lr = 0.001
+
+[TRAIN_PARAMS]
+training_package = decentralizepy.training.Training
+training_class = Training
+rounds = 4
+full_epochs = False
+batch_size = 16
+shuffle = True
+loss_package = torch.nn
+loss_class = CrossEntropyLoss
+
+[COMMUNICATION]
+comm_package = decentralizepy.communication.TCP
+comm_class = TCP
+addresses_filepath = ip_addr_6Machines.json
+
+[SHARING]
+sharing_package = decentralizepy.sharing.SubSampling
+sharing_class = SubSampling
+alpha = 0.1
diff --git a/eval/step_configs/config_celeba_topkacc.ini b/eval/step_configs/config_celeba_topkacc.ini
new file mode 100644
index 0000000000000000000000000000000000000000..89eef29dabffffd710686aeda651f17bbbc4c85c
--- /dev/null
+++ b/eval/step_configs/config_celeba_topkacc.ini
@@ -0,0 +1,35 @@
+[DATASET]
+dataset_package = decentralizepy.datasets.Celeba
+dataset_class = Celeba
+model_class = CNN
+images_dir = /mnt/nfs/shared/leaf/data/celeba/data/raw/img_align_celeba
+train_dir = /mnt/nfs/shared/leaf/data/celeba/per_user_data/train
+test_dir = /mnt/nfs/shared/leaf/data/celeba/data/test
+; python list of fractions below
+sizes = 
+
+[OPTIMIZER_PARAMS]
+optimizer_package = torch.optim
+optimizer_class = Adam
+lr = 0.001
+
+[TRAIN_PARAMS]
+training_package = decentralizepy.training.Training
+training_class = Training
+rounds = 4
+full_epochs = False
+batch_size = 16
+shuffle = True
+loss_package = torch.nn
+loss_class = CrossEntropyLoss
+
+[COMMUNICATION]
+comm_package = decentralizepy.communication.TCP
+comm_class = TCP
+addresses_filepath = ip_addr_6Machines.json
+
+[SHARING]
+sharing_package = decentralizepy.sharing.PartialModel
+sharing_class = PartialModel
+alpha = 0.1
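+; top-k selection uses parameter changes accumulated across rounds (see decentralizepy.sharing.PartialModel)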
+accumulation = True
diff --git a/eval/step_configs/config_celeba_topkparam.ini b/eval/step_configs/config_celeba_topkparam.ini
new file mode 100644
index 0000000000000000000000000000000000000000..babc3e96f76fc88b2276759c7accba551d8a587d
--- /dev/null
+++ b/eval/step_configs/config_celeba_topkparam.ini
@@ -0,0 +1,34 @@
+[DATASET]
+dataset_package = decentralizepy.datasets.Celeba
+dataset_class = Celeba
+model_class = CNN
+images_dir = /mnt/nfs/shared/leaf/data/celeba/data/raw/img_align_celeba
+train_dir = /mnt/nfs/shared/leaf/data/celeba/per_user_data/train
+test_dir = /mnt/nfs/shared/leaf/data/celeba/data/test
+; python list of fractions below
+sizes = 
+
+[OPTIMIZER_PARAMS]
+optimizer_package = torch.optim
+optimizer_class = Adam
+lr = 0.001
+
+[TRAIN_PARAMS]
+training_package = decentralizepy.training.Training
+training_class = Training
+rounds = 4
+full_epochs = False
+batch_size = 16
+shuffle = True
+loss_package = torch.nn
+loss_class = CrossEntropyLoss
+
+[COMMUNICATION]
+comm_package = decentralizepy.communication.TCP
+comm_class = TCP
+addresses_filepath = ip_addr_6Machines.json
+
+[SHARING]
+sharing_package = decentralizepy.sharing.TopKParams
+sharing_class = TopKParams
+alpha = 0.1
diff --git a/eval/step_configs/config_celeba_topkrandom.ini b/eval/step_configs/config_celeba_topkrandom.ini
new file mode 100644
index 0000000000000000000000000000000000000000..76749557b865379e92cee7cece1fe4978b707337
--- /dev/null
+++ b/eval/step_configs/config_celeba_topkrandom.ini
@@ -0,0 +1,34 @@
+[DATASET]
+dataset_package = decentralizepy.datasets.Celeba
+dataset_class = Celeba
+model_class = CNN
+images_dir = /mnt/nfs/shared/leaf/data/celeba/data/raw/img_align_celeba
+train_dir = /mnt/nfs/shared/leaf/data/celeba/per_user_data/train
+test_dir = /mnt/nfs/shared/leaf/data/celeba/data/test
+; python list of fractions below
+sizes = 
+
+[OPTIMIZER_PARAMS]
+optimizer_package = torch.optim
+optimizer_class = Adam
+lr = 0.001
+
+[TRAIN_PARAMS]
+training_package = decentralizepy.training.Training
+training_class = Training
+rounds = 4
+full_epochs = False
+batch_size = 16
+shuffle = True
+loss_package = torch.nn
+loss_class = CrossEntropyLoss
+
+[COMMUNICATION]
+comm_package = decentralizepy.communication.TCP
+comm_class = TCP
+addresses_filepath = ip_addr_6Machines.json
+
+[SHARING]
+sharing_package = decentralizepy.sharing.TopKPlusRandom
+sharing_class = TopKPlusRandom
+alpha = 0.1
diff --git a/eval/step_configs/config_celeba_wavelet.ini b/eval/step_configs/config_celeba_wavelet.ini
new file mode 100644
index 0000000000000000000000000000000000000000..70e9f155d303d397d42e4a48dff75ab5477a912e
--- /dev/null
+++ b/eval/step_configs/config_celeba_wavelet.ini
@@ -0,0 +1,38 @@
+[DATASET]
+dataset_package = decentralizepy.datasets.Celeba
+dataset_class = Celeba
+model_class = CNN
+images_dir = /mnt/nfs/shared/leaf/data/celeba/data/raw/img_align_celeba
+train_dir = /mnt/nfs/shared/leaf/data/celeba/per_user_data/train
+test_dir = /mnt/nfs/shared/leaf/data/celeba/data/test
+; python list of fractions below
+sizes = 
+
+[OPTIMIZER_PARAMS]
+optimizer_package = torch.optim
+optimizer_class = Adam
+lr = 0.001
+
+[TRAIN_PARAMS]
+training_package = decentralizepy.training.Training
+training_class = Training
+rounds = 4
+full_epochs = False
+batch_size = 16
+shuffle = True
+loss_package = torch.nn
+loss_class = CrossEntropyLoss
+
+[COMMUNICATION]
+comm_package = decentralizepy.communication.TCP
+comm_class = TCP
+addresses_filepath = ip_addr_6Machines.json
+
+[SHARING]
+sharing_package = decentralizepy.sharing.Wavelet
+sharing_class = Wavelet
+change_based_selection = True
+alpha = 0.1
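+; with level = None, pywt.wavedec picks the maximum decomposition level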
+wavelet = sym2
+level = None
+accumulation = True
diff --git a/eval/step_configs/config_femnist.ini b/eval/step_configs/config_femnist.ini
index 4814b8a3e077b06ad1597cefa0a66683f3d9d496..8063181b132ba8a862041f38aa66e8dd99b33fbb 100644
--- a/eval/step_configs/config_femnist.ini
+++ b/eval/step_configs/config_femnist.ini
@@ -14,8 +14,8 @@ optimizer_class = Adam
 lr = 0.001
 
 [TRAIN_PARAMS]
-training_package = decentralizepy.training.GradientAccumulator
-training_class = GradientAccumulator
+training_package = decentralizepy.training.Training
+training_class = Training
 rounds = 47
 full_epochs = False
 batch_size = 16
diff --git a/eval/step_configs/config_femnist_grow.ini b/eval/step_configs/config_femnist_grow.ini
index 9f18ad972e2c5b27e4ffca2242e08488290c083e..2a779c479a20331b965383b6e263373ef95ceedb 100644
--- a/eval/step_configs/config_femnist_grow.ini
+++ b/eval/step_configs/config_femnist_grow.ini
@@ -13,8 +13,8 @@ optimizer_class = Adam
 lr = 0.001
 
 [TRAIN_PARAMS]
-training_package = decentralizepy.training.GradientAccumulator
-training_class = GradientAccumulator
+training_package = decentralizepy.training.Training
+training_class = Training
 rounds = 20
 full_epochs = False
 batch_size = 64
diff --git a/eval/step_configs/config_femnist_topkacc.ini b/eval/step_configs/config_femnist_topkacc.ini
index 805004b6d56c95b3e79cd34debc0bc579e645ca5..2705fe7c8a86cdbb081e567bc14f1de525b41eeb 100644
--- a/eval/step_configs/config_femnist_topkacc.ini
+++ b/eval/step_configs/config_femnist_topkacc.ini
@@ -23,7 +23,6 @@ batch_size = 16
 shuffle = True
 loss_package = torch.nn
 loss_class = CrossEntropyLoss
-accumulation = True
 
 [COMMUNICATION]
 comm_package = decentralizepy.communication.TCP
@@ -33,4 +32,5 @@ addresses_filepath = ip_addr_6Machines.json
 [SHARING]
 sharing_package = decentralizepy.sharing.PartialModel
 sharing_class = PartialModel
-alpha = 0.1
\ No newline at end of file
+alpha = 0.1
+accumulation = True
\ No newline at end of file
diff --git a/eval/testing.py b/eval/testing.py
index abd6333e2710f05e66a72afc617fffaafaa9ba5a..b9c40814c47bb4fc3aaded502dc6c38a2b7bea15 100644
--- a/eval/testing.py
+++ b/eval/testing.py
@@ -65,3 +65,4 @@ if __name__ == "__main__":
             args.reset_optimizer,
         ],
     )
+    print("after spawn")
diff --git a/setup.cfg b/setup.cfg
index 0b85f720b6002ff5292413a3df5697926d622f0a..e174dd4425d7e9dcfb50da276c2f9bc8cfac1e5b 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -44,6 +44,7 @@ install_requires =
         localconfig
         PyWavelets
         pandas
+        crudini
 include_package_data = True
 python_requires = >=3.6
 [options.packages.find]
diff --git a/src/decentralizepy/models/Model.py b/src/decentralizepy/models/Model.py
index 196560812b9e5cfa324f360aab3adec8853408b8..f3635a947eafbcfd5f00b6ec1094db5888a5fb30 100644
--- a/src/decentralizepy/models/Model.py
+++ b/src/decentralizepy/models/Model.py
@@ -14,7 +14,7 @@ class Model(nn.Module):
 
         """
         super().__init__()
-        self.accumulated_gradients = []
+        self.model_change = None
         self._param_count_ot = None
         self._param_count_total = None
         self.accumulated_changes = None
diff --git a/src/decentralizepy/node/Node.py b/src/decentralizepy/node/Node.py
index 7854c38d2bd32832bf641c30c87a7d5d0855d2e4..463f57f7da696dfcea78624f3764283a2b52e806 100644
--- a/src/decentralizepy/node/Node.py
+++ b/src/decentralizepy/node/Node.py
@@ -481,3 +481,4 @@ class Node:
         )
 
         self.run()
+        logging.info("Node finished running")
diff --git a/src/decentralizepy/sharing/FFT.py b/src/decentralizepy/sharing/FFT.py
index 80b5a5d30503bc6ddbf0a31b3973bb5a8ca8e802..1bc7e0ef1b628e9f7df9603d4b41e8e8bc9f4817 100644
--- a/src/decentralizepy/sharing/FFT.py
+++ b/src/decentralizepy/sharing/FFT.py
@@ -8,10 +8,26 @@ import numpy as np
 import torch
 import torch.fft as fft
 
-from decentralizepy.sharing.Sharing import Sharing
+from decentralizepy.sharing.PartialModel import PartialModel
 
 
-class FFT(Sharing):
+def change_transformer_fft(x):
+    """
+    Transforms the model changes into frequency domain
+
+    Parameters
+    ----------
+    x : torch.Tensor
+        Model change in the space domain
+
+    Returns
+    -------
+    x : torch.Tensor
+        Representation of the change in the frequency domain
+    """
+    return fft.rfft(x)
+
+class FFT(PartialModel):
     """
     This class implements the fft version of model sharing
     It is based on PartialModel.py
@@ -32,8 +48,8 @@ class FFT(Sharing):
         dict_ordered=True,
         save_shared=False,
         metadata_cap=1.0,
-        pickle=True,
         change_based_selection=True,
+        save_accumulated="",
         accumulation=True,
     ):
         """
@@ -65,56 +81,19 @@ class FFT(Sharing):
             Specifies if the indices of shared parameters should be logged
         metadata_cap : float
             Share full model when self.alpha > metadata_cap
-        pickle : bool
-            use pickle to serialize the model parameters
         change_based_selection : bool
             use frequency change to select topk frequencies
+        save_accumulated : bool
+            True if the accumulated weight change in the frequency domain should be written to file. When accumulation
+            is enabled, the accumulated change is stored.
         accumulation : bool
             True if the the indices to share should be selected based on accumulated frequency change
         """
         super().__init__(
-            rank, machine_id, communication, mapping, graph, model, dataset, log_dir
+            rank, machine_id, communication, mapping, graph, model, dataset, log_dir, alpha, dict_ordered, save_shared,
+            metadata_cap, accumulation, save_accumulated, change_transformer_fft
         )
-        self.alpha = alpha
-        self.dict_ordered = dict_ordered
-        self.save_shared = save_shared
-        self.metadata_cap = metadata_cap
-        self.total_meta = 0
-
-        self.pickle = pickle
-
-        logging.info("subsampling pickling=" + str(pickle))
-
-        if self.save_shared:
-            # Only save for 2 procs: Save space
-            if rank != 0 or rank != 1:
-                self.save_shared = False
-
-        if self.save_shared:
-            self.folder_path = os.path.join(
-                self.log_dir, "shared_params/{}".format(self.rank)
-            )
-            Path(self.folder_path).mkdir(parents=True, exist_ok=True)
-
         self.change_based_selection = change_based_selection
-        self.accumulation = accumulation
-
-        # getting the initial model
-        with torch.no_grad():
-            self.model.accumulated_gradients = []
-            tensors_to_cat = [
-                v.data.flatten() for _, v in self.model.state_dict().items()
-            ]
-            concated = torch.cat(tensors_to_cat, dim=0)
-            self.init_model = fft.rfft(concated)
-            self.prev = None
-            if self.accumulation:
-                if self.model.accumulated_changes is None:
-                    self.model.accumulated_changes = torch.zeros_like(self.init_model)
-                    self.prev = self.init_model
-                else:
-                    self.model.accumulated_changes += self.init_model - self.prev
-                    self.prev = self.init_model
 
     def apply_fft(self):
         """
@@ -129,20 +108,19 @@ class FFT(Sharing):
         """
 
         logging.info("Returning fft compressed model weights")
-        tensors_to_cat = [v.data.flatten() for _, v in self.model.state_dict().items()]
-        concated = torch.cat(tensors_to_cat, dim=0)
-        flat_fft = fft.rfft(concated)
-        if self.change_based_selection:
-
-            assert len(self.model.accumulated_gradients) == 1
-            diff = self.model.accumulated_gradients[0]
-            _, index = torch.topk(
-                diff.abs(), round(self.alpha * len(flat_fft)), dim=0, sorted=False
-            )
-        else:
-            _, index = torch.topk(
-                flat_fft.abs(), round(self.alpha * len(flat_fft)), dim=0, sorted=False
-            )
+        with torch.no_grad():
+            tensors_to_cat = [v.data.flatten() for _, v in self.model.state_dict().items()]
+            concated = torch.cat(tensors_to_cat, dim=0)
+            flat_fft = self.change_transformer(concated)
+            if self.change_based_selection:
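+                # model_change is the training-induced change in the frequency domain, computed in PartialModel._pre_step (accumulated if enabled)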
+                diff = self.model.model_change
+                _, index = torch.topk(
+                    diff.abs(), round(self.alpha * len(diff)), dim=0, sorted=False
+                )
+            else:
+                _, index = torch.topk(
+                    flat_fft.abs(), round(self.alpha * len(flat_fft)), dim=0, sorted=False
+                )
 
         return flat_fft[index], index
 
@@ -199,7 +177,7 @@ class FFT(Sharing):
                 self.communication.encrypt(m["alpha"])
             )
 
-            return m
+        return m
 
     def deserialized_model(self, m):
         """
@@ -220,8 +198,6 @@ class FFT(Sharing):
             return super().deserialized_model(m)
 
         with torch.no_grad():
-            state_dict = self.model.state_dict()
-
             if not self.dict_ordered:
                 raise NotImplementedError
 
@@ -234,119 +210,47 @@ class FFT(Sharing):
             ret = dict()
             ret["indices"] = indices_tensor
             ret["params"] = params_tensor
-            return ret
+        return ret
 
-    def step(self):
+    def _averaging(self):
         """
-        Perform a sharing step. Implements D-PSGD.
+        Averages the received model with the local model
 
         """
-        t_start = time()
-        shapes = []
-        lens = []
-        end_model = None
-        change = 0
-        self.model.accumulated_gradients = []
         with torch.no_grad():
-            # FFT of this model
-            tensors_to_cat = []
-            for _, v in self.model.state_dict().items():
-                shapes.append(v.shape)
-                t = v.flatten()
-                lens.append(t.shape[0])
-                tensors_to_cat.append(t)
-            concated = torch.cat(tensors_to_cat, dim=0)
-            end_model = fft.rfft(concated)
-            change = end_model - self.init_model
-            if self.accumulation:
-                change += self.model.accumulated_changes
-            self.model.accumulated_gradients.append(change)
-        data = self.serialized_model()
-        t_post_serialize = time()
-        my_uid = self.mapping.get_uid(self.rank, self.machine_id)
-        all_neighbors = self.graph.neighbors(my_uid)
-        iter_neighbors = self.get_neighbors(all_neighbors)
-        data["degree"] = len(all_neighbors)
-        data["iteration"] = self.communication_round
-        for neighbor in iter_neighbors:
-            self.communication.send(neighbor, data)
-        t_post_send = time()
-        logging.info("Waiting for messages from neighbors")
-        while not self.received_from_all():
-            sender, data = self.communication.receive()
-            logging.debug("Received model from {}".format(sender))
-            degree = data["degree"]
-            iteration = data["iteration"]
-            del data["degree"]
-            del data["iteration"]
-            self.peer_deques[sender].append((degree, iteration, data))
-            logging.info(
-                "Deserialized received model from {} of iteration {}".format(
-                    sender, iteration
-                )
-            )
-        t_post_recv = time()
+            total = None
+            weight_total = 0
 
-        logging.info("Starting model averaging after receiving from all neighbors")
-        total = None
-        weight_total = 0
+            flat_fft = self.change_transformer(self.init_model)
 
-        flat_fft = end_model
-
-        for i, n in enumerate(self.peer_deques):
-            degree, iteration, data = self.peer_deques[n].popleft()
-            logging.debug(
-                "Averaging model from neighbor {} of iteration {}".format(n, iteration)
-            )
-            data = self.deserialized_model(data)
-            params = data["params"]
-            indices = data["indices"]
-            # use local data to complement
-            topkf = flat_fft.clone().detach()
-            topkf[indices] = params
-
-            weight = 1 / (max(len(self.peer_deques), degree) + 1)  # Metro-Hastings
-            weight_total += weight
-            if total is None:
-                total = weight * topkf
-            else:
-                total += weight * topkf
+            for i, n in enumerate(self.peer_deques):
+                degree, iteration, data = self.peer_deques[n].popleft()
+                logging.debug(
+                    "Averaging model from neighbor {} of iteration {}".format(n, iteration)
+                )
+                data = self.deserialized_model(data)
+                params = data["params"]
+                indices = data["indices"]
+                # use local data to complement
+                topkf = flat_fft.clone().detach()
+                topkf[indices] = params
+
+                weight = 1 / (max(len(self.peer_deques), degree) + 1)  # Metro-Hastings
+                weight_total += weight
+                if total is None:
+                    total = weight * topkf
+                else:
+                    total += weight * topkf
 
-        # Metro-Hastings
-        total += (1 - weight_total) * flat_fft
-        reverse_total = fft.irfft(total)
+            # Metro-Hastings
+            total += (1 - weight_total) * flat_fft
+            reverse_total = fft.irfft(total)
 
-        start_index = 0
-        std_dict = {}
-        for i, key in enumerate(self.model.state_dict()):
-            end_index = start_index + lens[i]
-            std_dict[key] = reverse_total[start_index:end_index].reshape(shapes[i])
-            start_index = end_index
+            start_index = 0
+            std_dict = {}
+            for i, key in enumerate(self.model.state_dict()):
+                end_index = start_index + self.lens[i]
+                std_dict[key] = reverse_total[start_index:end_index].reshape(self.shapes[i])
+                start_index = end_index
 
         self.model.load_state_dict(std_dict)
-
-        logging.info("Model averaging complete")
-
-        self.communication_round += 1
-
-        with torch.no_grad():
-            self.model.accumulated_gradients = []
-            tensors_to_cat = [
-                v.data.flatten() for _, v in self.model.state_dict().items()
-            ]
-            concated = torch.cat(tensors_to_cat, dim=0)
-            self.init_model = fft.rfft(concated)
-            if self.accumulation:
-                self.model.accumulated_changes += self.init_model - self.prev
-                self.prev = self.init_model
-
-        t_end = time()
-
-        logging.info(
-            "Sharing::step | Serialize: %f; Send: %f; Recv: %f; Averaging: %f; Total: %f",
-            t_post_serialize - t_start,
-            t_post_send - t_post_serialize,
-            t_post_recv - t_post_send,
-            t_end - t_post_recv,
-            t_end - t_start,
-        )
diff --git a/src/decentralizepy/sharing/PartialModel.py b/src/decentralizepy/sharing/PartialModel.py
index 204ee2334cb34686e34959ee3eb4b41dba3112c7..c961c43507281bb3daf98b870b25c0fd762bb167 100644
--- a/src/decentralizepy/sharing/PartialModel.py
+++ b/src/decentralizepy/sharing/PartialModel.py
@@ -7,6 +7,7 @@ import numpy as np
 import torch
 
 from decentralizepy.sharing.Sharing import Sharing
+from decentralizepy.utils import conditional_value, identity
 
 
 class PartialModel(Sharing):
@@ -29,6 +30,9 @@ class PartialModel(Sharing):
         dict_ordered=True,
         save_shared=False,
         metadata_cap=1.0,
+        accumulation=False,
+        save_accumulated="",
+        change_transformer=identity,
     ):
         """
         Constructor
@@ -59,6 +63,13 @@ class PartialModel(Sharing):
             Specifies if the indices of shared parameters should be logged
         metadata_cap : float
             Share full model when self.alpha > metadata_cap
+        accumulation : bool
+            True if the indices to share should be selected based on the accumulated change
+        save_accumulated : bool
+            True if the accumulated weight change should be written to file. When accumulation is enabled, the
+            accumulated change is stored; if a change_transformer is used, the transformed change is stored.
+        change_transformer : (x: Tensor) -> Tensor
+            A function that transforms the model change into another domain. Default: identity function
 
         """
         super().__init__(
@@ -69,6 +80,35 @@ class PartialModel(Sharing):
         self.save_shared = save_shared
         self.metadata_cap = metadata_cap
         self.total_meta = 0
+        self.accumulation = accumulation
+        self.save_accumulated = conditional_value(save_accumulated, "", False)
+        self.change_transformer = change_transformer
+
+        # getting the initial model
+        self.shapes = []
+        self.lens = []
+        with torch.no_grad():
+            tensors_to_cat = []
+            for _, v in self.model.state_dict().items():
+                self.shapes.append(v.shape)
+                t = v.flatten()
+                self.lens.append(t.shape[0])
+                tensors_to_cat.append(t)
+            self.init_model = torch.cat(tensors_to_cat, dim=0)
+            if self.accumulation:
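+                # accumulated changes are kept in the transformed domain (identity by default; frequency/wavelet domain for FFT/Wavelet sharing)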
+                self.model.accumulated_changes = torch.zeros_like(self.change_transformer(self.init_model))
+                self.prev = self.init_model
+
+        if self.save_accumulated:
+            self.model_change_path = os.path.join(
+                self.log_dir, "model_change/{}".format(self.rank)
+            )
+            Path(self.model_change_path).mkdir(parents=True, exist_ok=True)
+
+            self.model_val_path = os.path.join(
+                self.log_dir, "model_val/{}".format(self.rank)
+            )
+            Path(self.model_val_path).mkdir(parents=True, exist_ok=True)
 
         # Only save for 2 procs: Save space
         if self.save_shared and not (rank == 0 or rank == 1):
@@ -91,16 +131,9 @@ class PartialModel(Sharing):
             (a,b). a: The magnitudes of the topK gradients, b: Their indices.
 
         """
-        logging.info("Summing up gradients")
-        assert len(self.model.accumulated_gradients) > 0
-        gradient_sum = self.model.accumulated_gradients[0]
-        for i in range(1, len(self.model.accumulated_gradients)):
-            for key in self.model.accumulated_gradients[i]:
-                gradient_sum[key] += self.model.accumulated_gradients[i][key]
 
         logging.info("Returning topk gradients")
-        tensors_to_cat = [v.data.flatten() for _, v in gradient_sum.items()]
-        G_topk = torch.abs(torch.cat(tensors_to_cat, dim=0))
+        G_topk = torch.abs(self.model.model_change)
         std, mean = torch.std_mean(G_topk, unbiased=False)
         self.std = std.item()
         self.mean = mean.item()
@@ -123,8 +156,8 @@ class PartialModel(Sharing):
 
         with torch.no_grad():
             _, G_topk = self.extract_top_gradients()
-
-            self.model.rewind_accumulation(G_topk)
+            if self.accumulation:
+                self.model.rewind_accumulation(G_topk)
             if self.save_shared:
                 shared_params = dict()
                 shared_params["order"] = list(self.model.state_dict().keys())
@@ -219,3 +252,77 @@ class PartialModel(Sharing):
                 start_index = end_index
 
             return state_dict
+
+    def _pre_step(self):
+        """
+        Called at the beginning of step.
+
+        """
+        logging.info("PartialModel _pre_step")
+        with torch.no_grad():
+            tensors_to_cat = [
+                v.data.flatten() for _, v in self.model.state_dict().items()
+            ]
+            pre_share_model = torch.cat(tensors_to_cat, dim=0)
+            change = self.change_transformer(pre_share_model - self.init_model)
+            if self.accumulation:
+                change += self.model.accumulated_changes
+            # stores change of the model due to training, change due to averaging is not accounted
+            self.model.model_change = change
+
+    def _post_step(self):
+        """
+        Called at the end of step.
+
+        """
+        logging.info("PartialModel _post_step")
+        if self.save_accumulated:
+            # model_change (set in _pre_step) is still available here; save it before it is cleared
+            self.save_change()
+        with torch.no_grad():
+            self.model.model_change = None
+            tensors_to_cat = [
+                v.data.flatten() for _, v in self.model.state_dict().items()
+            ]
+            post_share_model = torch.cat(tensors_to_cat, dim=0)
+            self.init_model = post_share_model
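+            # the next round's change is measured against this post-averaging snapshot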
+            if self.accumulation:
+                self.model.accumulated_changes += self.change_transformer(self.init_model - self.prev)
+                self.prev = self.init_model
+
+    def save_vector(self, v, s):
+        """
+        Saves the given vector to the file.
+
+        Parameters
+        ----------
+        v : torch.tensor
+            The torch tensor to write to file
+        s : str
+            Path to folder to write to
+
+        """
+        output_dict = dict()
+        output_dict["order"] = list(self.model.state_dict().keys())
+        shapes = dict()
+        for k, v1 in self.model.state_dict().items():
+            shapes[k] = list(v1.shape)
+        output_dict["shapes"] = shapes
+
+        output_dict["tensor"] = v.tolist()
+
+        with open(
+            os.path.join(
+                s,
+                "{}.json".format(self.communication_round + 1),
+            ),
+            "w",
+        ) as of:
+            json.dump(output_dict, of)
+
+    def save_change(self):
+        """
+        Saves the model change for every iteration
+
+        """
+        self.save_vector(self.model.model_change, self.model_change_path)
\ No newline at end of file
diff --git a/src/decentralizepy/sharing/Sharing.py b/src/decentralizepy/sharing/Sharing.py
index 85fc07bea01cef5802d1178fc3dc1ac0ef5281e0..c998f404b55bec287fdf6382018fe98a21584888 100644
--- a/src/decentralizepy/sharing/Sharing.py
+++ b/src/decentralizepy/sharing/Sharing.py
@@ -31,7 +31,7 @@ class Sharing:
         model : decentralizepy.models.Model
             Model to train
         dataset : decentralizepy.datasets.Dataset
-            Dataset for sharing data. Not implemented yer! TODO
+            Dataset for sharing data. Not implemented yet! TODO
         log_dir : str
             Location to write shared_params (only writing for 2 procs per machine)
 
@@ -122,11 +122,53 @@ class Sharing:
             state_dict[key] = torch.from_numpy(value)
         return state_dict
 
+    def _pre_step(self):
+        """
+        Called at the beginning of step.
+
+        """
+        pass
+
+    def _post_step(self):
+        """
+        Called at the end of step.
+
+        """
+        pass
+
+    def _averaging(self):
+        """
+        Averages the received model with the local model
+
+        """
+        with torch.no_grad():
+            total = dict()
+            weight_total = 0
+            for i, n in enumerate(self.peer_deques):
+                degree, iteration, data = self.peer_deques[n].popleft()
+                logging.debug(
+                    "Averaging model from neighbor {} of iteration {}".format(n, iteration)
+                )
+                data = self.deserialized_model(data)
+                weight = 1 / (max(len(self.peer_deques), degree) + 1)  # Metro-Hastings
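+                # Metropolis-Hastings weight: 1 / (1 + max(own degree, sender degree)); the local model keeps the leftover weight below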
+                weight_total += weight
+                for key, value in data.items():
+                    if key in total:
+                        total[key] += value * weight
+                    else:
+                        total[key] = value * weight
+
+            for key, value in self.model.state_dict().items():
+                total[key] += (1 - weight_total) * value  # Metro-Hastings
+
+        self.model.load_state_dict(total)
+
     def step(self):
         """
         Perform a sharing step. Implements D-PSGD.
 
         """
+        self._pre_step()
         data = self.serialized_model()
         my_uid = self.mapping.get_uid(self.rank, self.machine_id)
         all_neighbors = self.graph.neighbors(my_uid)
@@ -152,27 +194,8 @@ class Sharing:
             )
 
         logging.info("Starting model averaging after receiving from all neighbors")
-        total = dict()
-        weight_total = 0
-        for i, n in enumerate(self.peer_deques):
-            degree, iteration, data = self.peer_deques[n].popleft()
-            logging.debug(
-                "Averaging model from neighbor {} of iteration {}".format(n, iteration)
-            )
-            data = self.deserialized_model(data)
-            weight = 1 / (max(len(self.peer_deques), degree) + 1)  # Metro-Hastings
-            weight_total += weight
-            for key, value in data.items():
-                if key in total:
-                    total[key] += value * weight
-                else:
-                    total[key] = value * weight
-
-        for key, value in self.model.state_dict().items():
-            total[key] += (1 - weight_total) * value  # Metro-Hastings
-
-        self.model.load_state_dict(total)
-
+        self._averaging()
         logging.info("Model averaging complete")
 
         self.communication_round += 1
+        self._post_step()
diff --git a/src/decentralizepy/sharing/TopKPlusRandom.py b/src/decentralizepy/sharing/TopKPlusRandom.py
index 1a31e433318b6c813dea2b25f30cdadbfd465840..728d5bfa48d71037a6d165d6396343f46dfa4e3e 100644
--- a/src/decentralizepy/sharing/TopKPlusRandom.py
+++ b/src/decentralizepy/sharing/TopKPlusRandom.py
@@ -84,16 +84,8 @@ class TopKPlusRandom(PartialModel):
             (a,b). a: The magnitudes of the topK gradients, b: Their indices.
 
         """
-        logging.info("Summing up gradients")
-        assert len(self.model.accumulated_gradients) > 0
-        gradient_sum = self.model.accumulated_gradients[0]
-        for i in range(1, len(self.model.accumulated_gradients)):
-            for key in self.model.accumulated_gradients[i]:
-                gradient_sum[key] += self.model.accumulated_gradients[i][key]
-
         logging.info("Returning topk gradients")
-        tensors_to_cat = [v.data.flatten() for _, v in gradient_sum.items()]
-        G = torch.abs(torch.cat(tensors_to_cat, dim=0))
+        G = torch.abs(self.model.model_change)
         std, mean = torch.std_mean(G, unbiased=False)
         self.std = std.item()
         self.mean = mean.item()
diff --git a/src/decentralizepy/sharing/Wavelet.py b/src/decentralizepy/sharing/Wavelet.py
index 2d651b001c8bf18d8a06b1220a10d8398b8a9522..1b73b298409fc25cfd54f18fa75b94d19dd9d4b2 100644
--- a/src/decentralizepy/sharing/Wavelet.py
+++ b/src/decentralizepy/sharing/Wavelet.py
@@ -8,10 +8,31 @@ import numpy as np
 import pywt
 import torch
 
-from decentralizepy.sharing.Sharing import Sharing
+from decentralizepy.sharing.PartialModel import PartialModel
 
+def change_transformer_wavelet(x, wavelet, level):
+    """
+    Transforms the model changes into the wavelet domain
+
+    Parameters
+    ----------
+    x : torch.Tensor
+        Model change in the space domain
+    wavelet : str
+        name of the wavelet to be used in gradient compression
+    level : int
+        decomposition level of the wavelet transform
 
-class Wavelet(Sharing):
+    Returns
+    -------
+    x : torch.Tensor
+        Representation of the change in the wavelet domain
+    """
+    coeff = pywt.wavedec(x, wavelet, level=level)
+    data, coeff_slices = pywt.coeffs_to_array(coeff)
+    return torch.from_numpy(data.ravel())
+
+class Wavelet(PartialModel):
     """
     This class implements the wavelet version of model sharing
     It is based on PartialModel.py
@@ -32,10 +53,10 @@ class Wavelet(Sharing):
         dict_ordered=True,
         save_shared=False,
         metadata_cap=1.0,
-        pickle=True,
         wavelet="haar",
         level=4,
         change_based_selection=True,
+        save_accumulated="",
         accumulation=False,
     ):
         """
@@ -67,65 +88,33 @@ class Wavelet(Sharing):
             Specifies if the indices of shared parameters should be logged
         metadata_cap : float
             Share full model when self.alpha > metadata_cap
-        pickle : bool
-            use pickle to serialize the model parameters
         wavelet: str
             name of the wavelet to be used in gradient compression
         level: int
             name of the wavelet to be used in gradient compression
         change_based_selection : bool
             use frequency change to select topk frequencies
+        save_accumulated : bool
+            True if the accumulated weight change in the wavelet domain should be written to file. When accumulation
+            is enabled, the accumulated change is stored.
         accumulation : bool
             True if the the indices to share should be selected based on accumulated frequency change
         """
-        super().__init__(
-            rank, machine_id, communication, mapping, graph, model, dataset, log_dir
-        )
-        self.alpha = alpha
-        self.dict_ordered = dict_ordered
-        self.save_shared = save_shared
-        self.metadata_cap = metadata_cap
-        self.total_meta = 0
-
-        self.pickle = pickle
         self.wavelet = wavelet
         self.level = level
-        self.accumulation = accumulation
-
-        logging.info("subsampling pickling=" + str(pickle))
 
-        if self.save_shared:
-            # Only save for 2 procs: Save space
-            if rank != 0 or rank != 1:
-                self.save_shared = False
-
-        if self.save_shared:
-            self.folder_path = os.path.join(
-                self.log_dir, "shared_params/{}".format(self.rank)
-            )
-            Path(self.folder_path).mkdir(parents=True, exist_ok=True)
+        super().__init__(
+            rank, machine_id, communication, mapping, graph, model, dataset, log_dir, alpha, dict_ordered, save_shared,
+            metadata_cap, accumulation, save_accumulated, lambda x: change_transformer_wavelet(x, wavelet, level)
+        )
 
         self.change_based_selection = change_based_selection
-        self.accumulation = accumulation
 
-        # getting the initial model
-        with torch.no_grad():
-            self.model.accumulated_gradients = []
-            tensors_to_cat = [
-                v.data.flatten() for _, v in self.model.state_dict().items()
-            ]
-            concated = torch.cat(tensors_to_cat, dim=0)
-            coeff = pywt.wavedec(concated.numpy(), self.wavelet, level=self.level)
-            data, coeff_slices = pywt.coeffs_to_array(coeff)
-            self.init_model = torch.from_numpy(data.ravel())
-            self.prev = None
-            if self.accumulation:
-                if self.model.accumulated_changes is None:
-                    self.model.accumulated_changes = torch.zeros_like(self.init_model)
-                    self.prev = self.init_model
-                else:
-                    self.model.accumulated_changes += self.init_model - self.prev
-                    self.prev = self.init_model
+        # Do a dummy transform to get the shape and the coefficient slices
+        coeff = pywt.wavedec(self.init_model.numpy(), self.wavelet, level=self.level)
+        data, coeff_slices = pywt.coeffs_to_array(coeff)
+        self.wt_shape = data.shape
+        self.coeff_slices = coeff_slices
 
     def apply_wavelet(self):
         """
@@ -142,31 +131,27 @@ class Wavelet(Sharing):
         logging.info("Returning dwt compressed model weights")
         tensors_to_cat = [v.data.flatten() for _, v in self.model.state_dict().items()]
         concated = torch.cat(tensors_to_cat, dim=0)
-
-        coeff = pywt.wavedec(concated.numpy(), self.wavelet, level=self.level)
-        data, coeff_slices = pywt.coeffs_to_array(
-            coeff
-        )  # coeff_slices will be reproduced on the receiver
-        data = data.ravel()
-
+        data = self.change_transformer(concated)
+        logging.info("produced wavelet representation of current model")
         if self.change_based_selection:
-            assert len(self.model.accumulated_gradients) == 1
-            diff = self.model.accumulated_gradients[0]
+            logging.info("change based selection")
+            diff = self.model.model_change
             _, index = torch.topk(
                 diff.abs(),
-                round(self.alpha * len(data)),
+                round(self.alpha * len(diff)),
                 dim=0,
                 sorted=False,
             )
+            logging.info("finished change based selection")
         else:
             _, index = torch.topk(
-                torch.from_numpy(data).abs(),
+                data.abs(),
                 round(self.alpha * len(data)),
                 dim=0,
                 sorted=False,
             )
 
-        return torch.from_numpy(data[index]), index
+        return data[index], index
 
     def serialized_model(self):
         """
@@ -178,6 +163,7 @@ class Wavelet(Sharing):
             Model converted to json dict
 
         """
+        logging.info("serializing wavelet model")
         if self.alpha > self.metadata_cap:  # Share fully
             return super().serialized_model()
 
@@ -185,7 +171,7 @@ class Wavelet(Sharing):
             topk, indices = self.apply_wavelet()
 
             self.model.rewind_accumulation(indices)
-
+            logging.info("finished rewind")
             if self.save_shared:
                 shared_params = dict()
                 shared_params["order"] = list(self.model.state_dict().keys())
@@ -227,12 +213,12 @@ class Wavelet(Sharing):
 
     def deserialized_model(self, m):
         """
-        Convert received json dict to state_dict.
+        Convert received dict to state_dict.
 
         Parameters
         ----------
         m : dict
-            json dict received
+            received dict
 
         Returns
         -------
@@ -240,26 +226,14 @@ class Wavelet(Sharing):
             state_dict of received
 
         """
+        logging.info("deserializing wavelet model")
         if self.alpha > self.metadata_cap:  # Share fully
             return super().deserialized_model(m)
 
         with torch.no_grad():
-            state_dict = self.model.state_dict()
-
             if not self.dict_ordered:
                 raise NotImplementedError
 
-            shapes = []
-            lens = []
-            tensors_to_cat = []
-            for _, v in state_dict.items():
-                shapes.append(v.shape)
-                t = v.flatten()
-                lens.append(t.shape[0])
-                tensors_to_cat.append(t)
-
-            T = torch.cat(tensors_to_cat, dim=0)
-
             indices = m["indices"]
             alpha = m["alpha"]
             params = m["params"]
@@ -271,128 +245,51 @@ class Wavelet(Sharing):
             ret["params"] = params_tensor
             return ret
 
-    def step(self):
+    def _averaging(self):
         """
-        Perform a sharing step. Implements D-PSGD.
+        Averages the received model with the local model
 
         """
-        t_start = time()
-        shapes = []
-        lens = []
-        end_model = None
-        change = 0
-        self.model.accumulated_gradients = []
         with torch.no_grad():
-            # FFT of this model
-            tensors_to_cat = []
-            for _, v in self.model.state_dict().items():
-                shapes.append(v.shape)
-                t = v.flatten()
-                lens.append(t.shape[0])
-                tensors_to_cat.append(t)
-            concated = torch.cat(tensors_to_cat, dim=0)
-            coeff = pywt.wavedec(concated.numpy(), self.wavelet, level=self.level)
-            data, coeff_slices = pywt.coeffs_to_array(coeff)
-            shape = data.shape
-            wt_params = data.ravel()
-            end_model = torch.from_numpy(wt_params)
-            change = end_model - self.init_model
-            if self.accumulation:
-                change += self.model.accumulated_changes
-            self.model.accumulated_gradients.append(change)
-        data = self.serialized_model()
-        t_post_serialize = time()
-        my_uid = self.mapping.get_uid(self.rank, self.machine_id)
-        all_neighbors = self.graph.neighbors(my_uid)
-        iter_neighbors = self.get_neighbors(all_neighbors)
-        data["degree"] = len(all_neighbors)
-        data["iteration"] = self.communication_round
-        for neighbor in iter_neighbors:
-            self.communication.send(neighbor, data)
-        t_post_send = time()
-        logging.info("Waiting for messages from neighbors")
-        while not self.received_from_all():
-            sender, data = self.communication.receive()
-            logging.debug("Received model from {}".format(sender))
-            degree = data["degree"]
-            iteration = data["iteration"]
-            del data["degree"]
-            del data["iteration"]
-            self.peer_deques[sender].append((degree, iteration, data))
-            logging.info(
-                "Deserialized received model from {} of iteration {}".format(
-                    sender, iteration
+            total = None
+            weight_total = 0
+            wt_params = self.change_transformer(self.init_model)
+            for i, n in enumerate(self.peer_deques):
+                degree, iteration, data = self.peer_deques[n].popleft()
+                logging.debug(
+                    "Averaging model from neighbor {} of iteration {}".format(n, iteration)
                 )
-            )
-        t_post_recv = time()
+                data = self.deserialized_model(data)
+                params = data["params"]
+                indices = data["indices"]
+                # use local data to complement
+                topkwf = wt_params.clone().detach()
+                topkwf[indices] = params
+                topkwf = topkwf.reshape(self.wt_shape)
+
+                weight = 1 / (max(len(self.peer_deques), degree) + 1)  # Metro-Hastings
+                weight_total += weight
+                if total is None:
+                    total = weight * topkwf
+                else:
+                    total += weight * topkwf
 
-        logging.info("Starting model averaging after receiving from all neighbors")
-        total = None
-        weight_total = 0
+            # Metro-Hastings
+            total += (1 - weight_total) * wt_params
 
-        for i, n in enumerate(self.peer_deques):
-            degree, iteration, data = self.peer_deques[n].popleft()
-            logging.debug(
-                "Averaging model from neighbor {} of iteration {}".format(n, iteration)
+            avg_wf_params = pywt.array_to_coeffs(
+                total.numpy(), self.coeff_slices, output_format="wavedec"
+            )
+            reverse_total = torch.from_numpy(
+                pywt.waverec(avg_wf_params, wavelet=self.wavelet)
             )
-            data = self.deserialized_model(data)
-            params = data["params"]
-            indices = data["indices"]
-            # use local data to complement
-            topkwf = wt_params.copy()  # .clone().detach()
-            topkwf[indices] = params
-            topkwf = torch.from_numpy(topkwf.reshape(shape))
-
-            weight = 1 / (max(len(self.peer_deques), degree) + 1)  # Metro-Hastings
-            weight_total += weight
-            if total is None:
-                total = weight * topkwf
-            else:
-                total += weight * topkwf
-
-        # Metro-Hastings
-        total += (1 - weight_total) * wt_params
-
-        avg_wf_params = pywt.array_to_coeffs(
-            total, coeff_slices, output_format="wavedec"
-        )
-        reverse_total = torch.from_numpy(
-            pywt.waverec(avg_wf_params, wavelet=self.wavelet)
-        )
 
-        start_index = 0
-        std_dict = {}
-        for i, key in enumerate(self.model.state_dict()):
-            end_index = start_index + lens[i]
-            std_dict[key] = reverse_total[start_index:end_index].reshape(shapes[i])
-            start_index = end_index
+            start_index = 0
+            std_dict = {}
+            for i, key in enumerate(self.model.state_dict()):
+                end_index = start_index + self.lens[i]
+                std_dict[key] = reverse_total[start_index:end_index].reshape(self.shapes[i])
+                start_index = end_index
 
         self.model.load_state_dict(std_dict)
 
-        logging.info("Model averaging complete")
-
-        self.communication_round += 1
-
-        with torch.no_grad():
-            self.model.accumulated_gradients = []
-            tensors_to_cat = [
-                v.data.flatten() for _, v in self.model.state_dict().items()
-            ]
-            concated = torch.cat(tensors_to_cat, dim=0)
-            coeff = pywt.wavedec(concated.numpy(), self.wavelet, level=self.level)
-            data, coeff_slices = pywt.coeffs_to_array(coeff)
-            self.init_model = torch.from_numpy(data.ravel())
-            if self.accumulation:
-                self.model.accumulated_changes += self.init_model - self.prev
-                self.prev = self.init_model
-
-        t_end = time()
-
-        logging.info(
-            "Sharing::step | Serialize: %f; Send: %f; Recv: %f; Averaging: %f; Total: %f",
-            t_post_serialize - t_start,
-            t_post_send - t_post_serialize,
-            t_post_recv - t_post_send,
-            t_end - t_post_recv,
-            t_end - t_start,
-        )
diff --git a/src/decentralizepy/training/GradientAccumulator.py b/src/decentralizepy/training/GradientAccumulator.py
deleted file mode 100644
index fcff8e6ec56e673c10098d32d973370251740913..0000000000000000000000000000000000000000
--- a/src/decentralizepy/training/GradientAccumulator.py
+++ /dev/null
@@ -1,112 +0,0 @@
-import logging
-
-from decentralizepy.training.Training import Training
-
-
-class GradientAccumulator(Training):
-    """
-    This class implements the training module which also accumulates gradients of steps in a list.
-
-    """
-
-    def __init__(
-        self,
-        rank,
-        machine_id,
-        mapping,
-        model,
-        optimizer,
-        loss,
-        log_dir,
-        rounds="",
-        full_epochs="",
-        batch_size="",
-        shuffle="",
-    ):
-        """
-        Constructor
-
-        Parameters
-        ----------
-        rank : int
-            Rank of process local to the machine
-        machine_id : int
-            Machine ID on which the process in running
-        mapping : decentralizepy.mappings
-            The object containing the mapping rank <--> uid
-        model : torch.nn.Module
-            Neural Network for training
-        optimizer : torch.optim
-            Optimizer to learn parameters
-        loss : function
-            Loss function
-        log_dir : str
-            Directory to log the model change.
-        rounds : int, optional
-            Number of steps/epochs per training call
-        full_epochs: bool, optional
-            True if 1 round = 1 epoch. False if 1 round = 1 minibatch
-        batch_size : int, optional
-            Number of items to learn over, in one batch
-        shuffle : bool
-            True if the dataset should be shuffled before training.
-
-        """
-        super().__init__(
-            rank,
-            machine_id,
-            mapping,
-            model,
-            optimizer,
-            loss,
-            log_dir,
-            rounds,
-            full_epochs,
-            batch_size,
-            shuffle,
-        )
-
-    def trainstep(self, data, target):
-        """
-        One training step on a minibatch.
-
-        Parameters
-        ----------
-        data : any
-            Data item
-        target : any
-            Label
-
-        Returns
-        -------
-        int
-            Loss Value for the step
-
-        """
-        self.model.zero_grad()
-        output = self.model(data)
-        loss_val = self.loss(output, target)
-        loss_val.backward()
-        logging.debug("Accumulating Gradients")
-        self.model.accumulated_gradients.append(
-            {
-                k: v.grad.clone().detach()
-                for k, v in zip(self.model.state_dict(), self.model.parameters())
-            }
-        )
-        self.optimizer.step()
-        return loss_val.item()
-
-    def train(self, dataset):
-        """
-        One training iteration with accumulation of gradients in model.accumulated_gradients.
-        Goes through the entire dataset.
-
-        Parameters
-        ----------
-        dataset : decentralizepy.datasets.Dataset
-            The training dataset. Should implement get_trainset(batch_size, shuffle)
-
-        """
-        self.model.accumulated_gradients = []
-        super().train(dataset)
diff --git a/src/decentralizepy/utils.py b/src/decentralizepy/utils.py
index 996e4bc4316f3590da1dcfcbc287df81d52fa9b7..f91946844e0519d4e8a3a70b6b4fd9f744348c16 100644
--- a/src/decentralizepy/utils.py
+++ b/src/decentralizepy/utils.py
@@ -108,3 +108,17 @@ def write_args(args, path):
     }
     with open(os.path.join(path, "args.json"), "w") as of:
         json.dump(data, of)
+
+def identity(obj):
+    """
+    Identity function
+
+    Parameters
+    ----------
+    obj
+        Some object
+
+    Returns
+    -------
+    obj
+        The same object
+
+    """
+    return obj
\ No newline at end of file