diff --git a/eval/96_regular.edges b/eval/96_regular.edges new file mode 100644 index 0000000000000000000000000000000000000000..0db09a2763b09045c8218ce25e640f052b6089e8 --- /dev/null +++ b/eval/96_regular.edges @@ -0,0 +1,381 @@ +96 +0 24 +0 1 +0 26 +0 95 +1 2 +1 0 +1 82 +1 83 +2 33 +2 90 +2 3 +2 1 +3 2 +3 4 +3 14 +3 79 +4 3 +4 12 +4 5 +4 86 +5 64 +5 42 +5 4 +5 6 +6 9 +6 5 +6 62 +6 7 +7 24 +7 8 +7 45 +7 6 +8 81 +8 17 +8 9 +8 7 +9 8 +9 10 +9 53 +9 6 +10 9 +10 11 +10 29 +10 31 +11 80 +11 10 +11 36 +11 12 +12 11 +12 4 +12 13 +12 70 +13 12 +13 53 +13 30 +13 14 +14 3 +14 15 +14 13 +14 47 +15 16 +15 26 +15 14 +16 41 +16 17 +16 15 +17 8 +17 16 +17 18 +17 83 +18 17 +18 19 +18 95 +18 63 +19 82 +19 18 +19 20 +19 22 +20 19 +20 59 +20 21 +20 22 +21 72 +21 58 +21 20 +21 22 +22 19 +22 20 +22 21 +22 23 +23 24 +23 65 +23 85 +23 22 +24 0 +24 25 +24 23 +24 7 +25 32 +25 24 +25 26 +25 38 +26 0 +26 25 +26 27 +26 15 +27 32 +27 26 +27 28 +27 63 +28 27 +28 92 +28 29 +28 39 +29 10 +29 52 +29 28 +29 30 +30 66 +30 29 +30 13 +30 31 +31 32 +31 10 +31 36 +31 30 +32 25 +32 27 +32 31 +32 33 +33 32 +33 2 +33 84 +33 34 +34 33 +34 50 +34 35 +34 93 +35 57 +35 34 +35 43 +35 36 +36 35 +36 11 +36 37 +36 31 +37 88 +37 36 +37 38 +37 79 +38 25 +38 37 +38 39 +38 49 +39 40 +39 28 +39 77 +39 38 +40 41 +40 91 +40 39 +40 87 +41 16 +41 40 +41 42 +41 51 +42 41 +42 43 +42 5 +43 42 +43 35 +43 44 +44 72 +44 43 +44 75 +44 45 +45 67 +45 44 +45 46 +45 7 +46 76 +46 45 +46 54 +46 47 +47 48 +47 65 +47 14 +47 46 +48 56 +48 49 +48 61 +48 47 +49 48 +49 50 +49 38 +49 71 +50 49 +50 34 +50 51 +50 93 +51 41 +51 50 +51 52 +51 95 +52 51 +52 74 +52 53 +52 29 +53 9 +53 52 +53 13 +53 54 +54 75 +54 53 +54 46 +54 55 +55 56 +55 69 +55 85 +55 54 +56 48 +56 57 +56 69 +56 55 +57 56 +57 89 +57 58 +57 35 +58 57 +58 59 +58 21 +58 86 +59 73 +59 58 +59 20 +59 60 +60 62 +60 59 +60 61 +60 78 +61 48 +61 62 +61 60 +61 94 +62 60 +62 61 +62 6 +62 63 +63 64 +63 18 +63 27 +63 62 +64 65 +64 84 +64 5 +64 63 +65 64 +65 66 +65 23 +65 47 +66 65 +66 89 +66 67 +66 30 +67 80 +67 66 +67 68 +67 45 +68 67 +68 92 +68 69 +68 94 +69 56 +69 68 +69 70 +69 55 +70 90 +70 12 +70 69 +70 71 +71 72 +71 49 +71 70 +71 87 +72 73 +72 44 +72 21 +72 71 +73 72 +73 91 +73 59 +73 74 +74 73 +74 75 +74 52 +74 76 +75 74 +75 44 +75 54 +75 76 +76 74 +76 75 +76 77 +76 46 +77 81 +77 76 +77 78 +77 39 +78 88 +78 60 +78 77 +78 79 +79 80 +79 3 +79 37 +79 78 +80 81 +80 67 +80 11 +80 79 +81 8 +81 82 +81 80 +81 77 +82 81 +82 1 +82 83 +82 19 +83 1 +83 82 +83 84 +83 17 +84 64 +84 33 +84 83 +84 85 +85 84 +85 55 +85 86 +85 23 +86 58 +86 4 +86 85 +86 87 +87 40 +87 88 +87 86 +87 71 +88 89 +88 37 +88 78 +88 87 +89 88 +89 57 +89 66 +89 90 +90 89 +90 2 +90 91 +90 70 +91 40 +91 73 +91 90 +91 92 +92 93 +92 91 +92 68 +92 28 +93 50 +93 34 +93 94 +93 92 +94 93 +94 68 +94 61 +94 95 +95 0 +95 18 +95 51 +95 94 diff --git a/eval/plot.py b/eval/plot.py index d3c3a393fcd495ba2b5088b3df9fb8e5fb02253f..f3f82c77ad55872c562a06ffb81b04f4b7269c65 100644 --- a/eval/plot.py +++ b/eval/plot.py @@ -3,6 +3,7 @@ import os import sys import numpy as np +import pandas as pd from matplotlib import pyplot as plt @@ -61,14 +62,50 @@ def plot_results(path): plt.figure(1) means, stdevs, mins, maxs = get_stats([x["train_loss"] for x in results]) plot(means, stdevs, mins, maxs, "Training Loss", folder, "upper right") + df = pd.DataFrame( + { + "mean": list(means.values()), + "std": list(stdevs.values()), + "nr_nodes": [len(results)] * len(means), + }, + list(means.keys()), + columns=["mean", "std", "nr_nodes"], + ) + df.to_csv( + os.path.join(path, 
"train_loss_" + folder + ".csv"), index_label="rounds" + ) # Plot Testing loss plt.figure(2) means, stdevs, mins, maxs = get_stats([x["test_loss"] for x in results]) plot(means, stdevs, mins, maxs, "Testing Loss", folder, "upper right") + df = pd.DataFrame( + { + "mean": list(means.values()), + "std": list(stdevs.values()), + "nr_nodes": [len(results)] * len(means), + }, + list(means.keys()), + columns=["mean", "std", "nr_nodes"], + ) + df.to_csv( + os.path.join(path, "test_loss_" + folder + ".csv"), index_label="rounds" + ) # Plot Testing Accuracy plt.figure(3) means, stdevs, mins, maxs = get_stats([x["test_acc"] for x in results]) plot(means, stdevs, mins, maxs, "Testing Accuracy", folder, "lower right") + df = pd.DataFrame( + { + "mean": list(means.values()), + "std": list(stdevs.values()), + "nr_nodes": [len(results)] * len(means), + }, + list(means.keys()), + columns=["mean", "std", "nr_nodes"], + ) + df.to_csv( + os.path.join(path, "test_acc_" + folder + ".csv"), index_label="rounds" + ) plt.figure(6) means, stdevs, mins, maxs = get_stats([x["grad_std"] for x in results]) plot( diff --git a/eval/run.sh b/eval/run.sh index 9869a176adb4bb06d22f300da2d70b767f1114cc..528bdc97ab23b77be71c5d8d1413a740eacd946f 100755 --- a/eval/run.sh +++ b/eval/run.sh @@ -4,29 +4,20 @@ decpy_path=~/Gitlab/decentralizepy/eval cd $decpy_path env_python=~/miniconda3/envs/decpy/bin/python3 -graph=96_nodes_random1.edges +graph=96_regular.edges original_config=epoch_configs/config_celeba.ini config_file=/tmp/config.ini procs_per_machine=16 machines=6 -iterations=76 -test_after=2 +iterations=200 +test_after=10 eval_file=testing.py log_level=INFO m=`cat $(grep addresses_filepath $original_config | awk '{print $3}') | grep $(/sbin/ifconfig ens785 | grep 'inet ' | awk '{print $2}') | cut -d'"' -f2` +log_dir=$(date '+%Y-%m-%dT%H:%M')/machine$m +mkdir -p $log_dir cp $original_config $config_file -echo "alpha = 0.75" >> $config_file -$env_python $eval_file -mid $m -ps $procs_per_machine -ms $machines -is $iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level - -cp $original_config $config_file -echo "alpha = 0.50" >> $config_file -$env_python $eval_file -mid $m -ps $procs_per_machine -ms $machines -is $iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level - -cp $original_config $config_file -echo "alpha = 0.10" >> $config_file -$env_python $eval_file -mid $m -ps $procs_per_machine -ms $machines -is $iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level - -config_file=epoch_configs/config_celeba_100.ini -$env_python $eval_file -mid $m -ps $procs_per_machine -ms $machines -is $iterations -gf $graph -ta $test_after -cf $original_config -ll $log_level +# echo "alpha = 0.10" >> $config_file +$env_python $eval_file -ro 0 -ld $log_dir -mid $m -ps $procs_per_machine -ms $machines -is $iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level \ No newline at end of file diff --git a/eval/run_all.sh b/eval/run_all.sh new file mode 100755 index 0000000000000000000000000000000000000000..d5d0c043479f20a605c823b077f9f384c90f5808 --- /dev/null +++ b/eval/run_all.sh @@ -0,0 +1,40 @@ +#!/bin/bash +nfs_home=$1 +python_bin=$2 +decpy_path=$nfs_home/decentralizepy/eval +cd $decpy_path + +env_python=$python_bin/python3 +graph=96_regular.edges #4_node_fullyConnected.edges +config_file=~/tmp/config.ini +procs_per_machine=16 +machines=6 +iterations=5 +train_evaluate_after=5 +test_after=21 # we do not test +eval_file=testing.py +log_level=INFO + 
+ip_machines=$nfs_home/configs/ip_addr_6Machines.json + +m=`cat $ip_machines | grep $(/sbin/ifconfig ens785 | grep 'inet ' | awk '{print $2}') | cut -d'"' -f2` +export PYTHONFAULTHANDLER=1 +tests=("step_configs/config_celeba_partialmodel.ini" "step_configs/config_celeba_sharing.ini" "step_configs/config_celeba_fft.ini" "step_configs/config_celeba_wavelet.ini" +"step_configs/config_celeba_grow.ini" "step_configs/config_celeba_manualadapt.ini" "step_configs/config_celeba_randomalpha.ini" +"step_configs/config_celeba_randomalphainc.ini" "step_configs/config_celeba_roundrobin.ini" "step_configs/config_celeba_subsampling.ini" +"step_configs/config_celeba_topkrandom.ini" "step_configs/config_celeba_topkacc.ini" "step_configs/config_celeba_topkparam.ini") + +for i in "${tests[@]}" +do + echo $i + IFS='_' read -ra NAMES <<< $i + IFS='.' read -ra NAME <<< ${NAMES[-1]} + log_dir=$nfs_home/logs/testing/${NAME[0]}$(date '+%Y-%m-%dT%H:%M')/machine$m + mkdir -p $log_dir + cp $i $config_file + $python_bin/crudini --set $config_file COMMUNICATION addresses_filepath $ip_machines + $env_python $eval_file -ro 0 -tea $train_evaluate_after -ld $log_dir -mid $m -ps $procs_per_machine -ms $machines -is $iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level + echo $i is done + sleep 3 + echo end of sleep +done diff --git a/eval/step_configs/config_celeba_fft.ini b/eval/step_configs/config_celeba_fft.ini new file mode 100644 index 0000000000000000000000000000000000000000..e8d6a70804ec7acd2c3b38248aa4e756b9a23cd2 --- /dev/null +++ b/eval/step_configs/config_celeba_fft.ini @@ -0,0 +1,36 @@ +[DATASET] +dataset_package = decentralizepy.datasets.Celeba +dataset_class = Celeba +model_class = CNN +images_dir = /mnt/nfs/shared/leaf/data/celeba/data/raw/img_align_celeba +train_dir = /mnt/nfs/shared/leaf/data/celeba/per_user_data/train +test_dir = /mnt/nfs/shared/leaf/data/celeba/data/test +; python list of fractions below +sizes = + +[OPTIMIZER_PARAMS] +optimizer_package = torch.optim +optimizer_class = Adam +lr = 0.001 + +[TRAIN_PARAMS] +training_package = decentralizepy.training.Training +training_class = Training +rounds = 4 +full_epochs = False +batch_size = 16 +shuffle = True +loss_package = torch.nn +loss_class = CrossEntropyLoss + +[COMMUNICATION] +comm_package = decentralizepy.communication.TCP +comm_class = TCP +addresses_filepath = ip_addr_6Machines.json + +[SHARING] +sharing_package = decentralizepy.sharing.FFT +sharing_class = FFT +alpha = 0.1 +change_based_selection = True +accumulation = True diff --git a/eval/step_configs/config_celeba_grow.ini b/eval/step_configs/config_celeba_grow.ini index be0812e8c3bc1b91d14d803c3efe89e6d8e9b3e8..37e74ae0eb91a8847189e391da8a88b806a0b85b 100644 --- a/eval/step_configs/config_celeba_grow.ini +++ b/eval/step_configs/config_celeba_grow.ini @@ -2,9 +2,9 @@ dataset_package = decentralizepy.datasets.Celeba dataset_class = Celeba model_class = CNN -images_dir = /home/risharma/leaf/data/celeba/data/raw/img_align_celeba -train_dir = /home/risharma/leaf/data/celeba/per_user_data/train -test_dir = /home/risharma/leaf/data/celeba/data/test +images_dir = /mnt/nfs/shared/leaf/data/celeba/data/raw/img_align_celeba +train_dir = /mnt/nfs/shared/leaf/data/celeba/per_user_data/train +test_dir = /mnt/nfs/shared/leaf/data/celeba/data/test ; python list of fractions below sizes = @@ -14,11 +14,11 @@ optimizer_class = Adam lr = 0.001 [TRAIN_PARAMS] -training_package = decentralizepy.training.GradientAccumulator -training_class = GradientAccumulator -rounds = 20 +training_package 
= decentralizepy.training.Training +training_class = Training +rounds = 4 full_epochs = False -batch_size = 64 +batch_size = 16 shuffle = True loss_package = torch.nn loss_class = CrossEntropyLoss diff --git a/eval/step_configs/config_celeba_manualadapt.ini b/eval/step_configs/config_celeba_manualadapt.ini new file mode 100644 index 0000000000000000000000000000000000000000..1c117e2a40cabf297d305b2dd0cd20d88ca46299 --- /dev/null +++ b/eval/step_configs/config_celeba_manualadapt.ini @@ -0,0 +1,35 @@ +[DATASET] +dataset_package = decentralizepy.datasets.Celeba +dataset_class = Celeba +model_class = CNN +images_dir = /mnt/nfs/shared/leaf/data/celeba/data/raw/img_align_celeba +train_dir = /mnt/nfs/shared/leaf/data/celeba/per_user_data/train +test_dir = /mnt/nfs/shared/leaf/data/celeba/data/test +; python list of fractions below +sizes = + +[OPTIMIZER_PARAMS] +optimizer_package = torch.optim +optimizer_class = Adam +lr = 0.001 + +[TRAIN_PARAMS] +training_package = decentralizepy.training.Training +training_class = Training +rounds = 4 +full_epochs = False +batch_size = 16 +shuffle = True +loss_package = torch.nn +loss_class = CrossEntropyLoss + +[COMMUNICATION] +comm_package = decentralizepy.communication.TCP +comm_class = TCP +addresses_filepath = ip_addr_6Machines.json + +[SHARING] +sharing_package = decentralizepy.sharing.ManualAdapt +sharing_class = ManualAdapt +change_alpha = [0.1, 0.5] +change_rounds = [10,30] diff --git a/eval/step_configs/config_celeba.ini b/eval/step_configs/config_celeba_partialmodel.ini similarity index 63% rename from eval/step_configs/config_celeba.ini rename to eval/step_configs/config_celeba_partialmodel.ini index 5cadf017749dbb602188323cbca82a1dfd40dcac..6c9a4b5b16c0086b8d950a90ea24451dabe7e527 100644 --- a/eval/step_configs/config_celeba.ini +++ b/eval/step_configs/config_celeba_partialmodel.ini @@ -2,9 +2,9 @@ dataset_package = decentralizepy.datasets.Celeba dataset_class = Celeba model_class = CNN -images_dir = /home/risharma/leaf/data/celeba/data/raw/img_align_celeba -train_dir = /home/risharma/leaf/data/celeba/per_user_data/train -test_dir = /home/risharma/leaf/data/celeba/data/test +images_dir = /mnt/nfs/shared/leaf/data/celeba/data/raw/img_align_celeba +train_dir = /mnt/nfs/shared/leaf/data/celeba/per_user_data/train +test_dir = /mnt/nfs/shared/leaf/data/celeba/data/test ; python list of fractions below sizes = @@ -14,11 +14,11 @@ optimizer_class = Adam lr = 0.001 [TRAIN_PARAMS] -training_package = decentralizepy.training.GradientAccumulator -training_class = GradientAccumulator -rounds = 20 +training_package = decentralizepy.training.Training +training_class = Training +rounds = 4 full_epochs = False -batch_size = 64 +batch_size = 16 shuffle = True loss_package = torch.nn loss_class = CrossEntropyLoss diff --git a/eval/step_configs/config_celeba_randomalpha.ini b/eval/step_configs/config_celeba_randomalpha.ini new file mode 100644 index 0000000000000000000000000000000000000000..1c4b9893b67d0236cc6642a88f11fdf49d318ad5 --- /dev/null +++ b/eval/step_configs/config_celeba_randomalpha.ini @@ -0,0 +1,33 @@ +[DATASET] +dataset_package = decentralizepy.datasets.Celeba +dataset_class = Celeba +model_class = CNN +images_dir = /mnt/nfs/shared/leaf/data/celeba/data/raw/img_align_celeba +train_dir = /mnt/nfs/shared/leaf/data/celeba/per_user_data/train +test_dir = /mnt/nfs/shared/leaf/data/celeba/data/test +; python list of fractions below +sizes = + +[OPTIMIZER_PARAMS] +optimizer_package = torch.optim +optimizer_class = Adam +lr = 0.001 + +[TRAIN_PARAMS] 
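run_all.sh above relies on crudini to point every copied config at the shared address file. A rough standard-library equivalent of that single crudini --set call, with placeholder paths (note that configparser, unlike crudini, drops comments when rewriting the file):

import configparser

# Equivalent of: crudini --set $config_file COMMUNICATION addresses_filepath $ip_machines
config = configparser.ConfigParser()
config.read("/tmp/config.ini")  # placeholder for $config_file
config["COMMUNICATION"]["addresses_filepath"] = "/mnt/nfs/configs/ip_addr_6Machines.json"
with open("/tmp/config.ini", "w") as f:
    config.write(f)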
+training_package = decentralizepy.training.Training +training_class = Training +rounds = 4 +full_epochs = False +batch_size = 16 +shuffle = True +loss_package = torch.nn +loss_class = CrossEntropyLoss + +[COMMUNICATION] +comm_package = decentralizepy.communication.TCP +comm_class = TCP +addresses_filepath = ip_addr_6Machines.json + +[SHARING] +sharing_package = decentralizepy.sharing.RandomAlpha +sharing_class = RandomAlpha diff --git a/eval/step_configs/config_celeba_randomalphainc.ini b/eval/step_configs/config_celeba_randomalphainc.ini new file mode 100644 index 0000000000000000000000000000000000000000..5171b64e5d01cdef69db572231c23d2d99f3cd5e --- /dev/null +++ b/eval/step_configs/config_celeba_randomalphainc.ini @@ -0,0 +1,33 @@ +[DATASET] +dataset_package = decentralizepy.datasets.Celeba +dataset_class = Celeba +model_class = CNN +images_dir = /mnt/nfs/shared/leaf/data/celeba/data/raw/img_align_celeba +train_dir = /mnt/nfs/shared/leaf/data/celeba/per_user_data/train +test_dir = /mnt/nfs/shared/leaf/data/celeba/data/test +; python list of fractions below +sizes = + +[OPTIMIZER_PARAMS] +optimizer_package = torch.optim +optimizer_class = Adam +lr = 0.001 + +[TRAIN_PARAMS] +training_package = decentralizepy.training.Training +training_class = Training +rounds = 4 +full_epochs = False +batch_size = 16 +shuffle = True +loss_package = torch.nn +loss_class = CrossEntropyLoss + +[COMMUNICATION] +comm_package = decentralizepy.communication.TCP +comm_class = TCP +addresses_filepath = ip_addr_6Machines.json + +[SHARING] +sharing_package = decentralizepy.sharing.RandomAlphaIncremental +sharing_class = RandomAlphaIncremental diff --git a/eval/step_configs/config_celeba_roundrobin.ini b/eval/step_configs/config_celeba_roundrobin.ini new file mode 100644 index 0000000000000000000000000000000000000000..3dadf3274607c049aaefc3cf956a9baef239148b --- /dev/null +++ b/eval/step_configs/config_celeba_roundrobin.ini @@ -0,0 +1,34 @@ +[DATASET] +dataset_package = decentralizepy.datasets.Celeba +dataset_class = Celeba +model_class = CNN +images_dir = /mnt/nfs/shared/leaf/data/celeba/data/raw/img_align_celeba +train_dir = /mnt/nfs/shared/leaf/data/celeba/per_user_data/train +test_dir = /mnt/nfs/shared/leaf/data/celeba/data/test +; python list of fractions below +sizes = + +[OPTIMIZER_PARAMS] +optimizer_package = torch.optim +optimizer_class = Adam +lr = 0.001 + +[TRAIN_PARAMS] +training_package = decentralizepy.training.Training +training_class = Training +rounds = 4 +full_epochs = False +batch_size = 16 +shuffle = True +loss_package = torch.nn +loss_class = CrossEntropyLoss + +[COMMUNICATION] +comm_package = decentralizepy.communication.TCP +comm_class = TCP +addresses_filepath = ip_addr_6Machines.json + +[SHARING] +sharing_package = decentralizepy.sharing.RoundRobinPartial +sharing_class = RoundRobinPartial +alpha = 0.1 diff --git a/eval/step_configs/config_celeba_100.ini b/eval/step_configs/config_celeba_sharing.ini similarity index 74% rename from eval/step_configs/config_celeba_100.ini rename to eval/step_configs/config_celeba_sharing.ini index 70e14bbaf2fe87717351567aa4420ce439e3fa98..caf05fa846ca9e73d3f5df58aa1a19f01e16de1a 100644 --- a/eval/step_configs/config_celeba_100.ini +++ b/eval/step_configs/config_celeba_sharing.ini @@ -2,9 +2,9 @@ dataset_package = decentralizepy.datasets.Celeba dataset_class = Celeba model_class = CNN -images_dir = /home/risharma/leaf/data/celeba/data/raw/img_align_celeba -train_dir = /home/risharma/leaf/data/celeba/per_user_data/train -test_dir = 
/home/risharma/leaf/data/celeba/data/test +images_dir = /mnt/nfs/shared/leaf/data/celeba/data/raw/img_align_celeba +train_dir = /mnt/nfs/shared/leaf/data/celeba/per_user_data/train +test_dir = /mnt/nfs/shared/leaf/data/celeba/data/test ; python list of fractions below sizes = @@ -16,9 +16,9 @@ lr = 0.001 [TRAIN_PARAMS] training_package = decentralizepy.training.Training training_class = Training -rounds = 20 +rounds = 4 full_epochs = False -batch_size = 64 +batch_size = 16 shuffle = True loss_package = torch.nn loss_class = CrossEntropyLoss diff --git a/eval/step_configs/config_celeba_subsampling.ini b/eval/step_configs/config_celeba_subsampling.ini new file mode 100644 index 0000000000000000000000000000000000000000..b8068984fd28b2d2fab9742687c32f4d56897413 --- /dev/null +++ b/eval/step_configs/config_celeba_subsampling.ini @@ -0,0 +1,34 @@ +[DATASET] +dataset_package = decentralizepy.datasets.Celeba +dataset_class = Celeba +model_class = CNN +images_dir = /mnt/nfs/shared/leaf/data/celeba/data/raw/img_align_celeba +train_dir = /mnt/nfs/shared/leaf/data/celeba/per_user_data/train +test_dir = /mnt/nfs/shared/leaf/data/celeba/data/test +; python list of fractions below +sizes = + +[OPTIMIZER_PARAMS] +optimizer_package = torch.optim +optimizer_class = Adam +lr = 0.001 + +[TRAIN_PARAMS] +training_package = decentralizepy.training.Training +training_class = Training +rounds = 4 +full_epochs = False +batch_size = 16 +shuffle = True +loss_package = torch.nn +loss_class = CrossEntropyLoss + +[COMMUNICATION] +comm_package = decentralizepy.communication.TCP +comm_class = TCP +addresses_filepath = ip_addr_6Machines.json + +[SHARING] +sharing_package = decentralizepy.sharing.SubSampling +sharing_class = SubSampling +alpha = 0.1 diff --git a/eval/step_configs/config_celeba_topkacc.ini b/eval/step_configs/config_celeba_topkacc.ini new file mode 100644 index 0000000000000000000000000000000000000000..89eef29dabffffd710686aeda651f17bbbc4c85c --- /dev/null +++ b/eval/step_configs/config_celeba_topkacc.ini @@ -0,0 +1,35 @@ +[DATASET] +dataset_package = decentralizepy.datasets.Celeba +dataset_class = Celeba +model_class = CNN +images_dir = /mnt/nfs/shared/leaf/data/celeba/data/raw/img_align_celeba +train_dir = /mnt/nfs/shared/leaf/data/celeba/per_user_data/train +test_dir = /mnt/nfs/shared/leaf/data/celeba/data/test +; python list of fractions below +sizes = + +[OPTIMIZER_PARAMS] +optimizer_package = torch.optim +optimizer_class = Adam +lr = 0.001 + +[TRAIN_PARAMS] +training_package = decentralizepy.training.Training +training_class = Training +rounds = 4 +full_epochs = False +batch_size = 16 +shuffle = True +loss_package = torch.nn +loss_class = CrossEntropyLoss + +[COMMUNICATION] +comm_package = decentralizepy.communication.TCP +comm_class = TCP +addresses_filepath = ip_addr_6Machines.json + +[SHARING] +sharing_package = decentralizepy.sharing.PartialModel +sharing_class = PartialModel +alpha = 0.1 +accumulation = True diff --git a/eval/step_configs/config_celeba_topkparam.ini b/eval/step_configs/config_celeba_topkparam.ini new file mode 100644 index 0000000000000000000000000000000000000000..babc3e96f76fc88b2276759c7accba551d8a587d --- /dev/null +++ b/eval/step_configs/config_celeba_topkparam.ini @@ -0,0 +1,34 @@ +[DATASET] +dataset_package = decentralizepy.datasets.Celeba +dataset_class = Celeba +model_class = CNN +images_dir = /mnt/nfs/shared/leaf/data/celeba/data/raw/img_align_celeba +train_dir = /mnt/nfs/shared/leaf/data/celeba/per_user_data/train +test_dir = /mnt/nfs/shared/leaf/data/celeba/data/test +; 
python list of fractions below +sizes = + +[OPTIMIZER_PARAMS] +optimizer_package = torch.optim +optimizer_class = Adam +lr = 0.001 + +[TRAIN_PARAMS] +training_package = decentralizepy.training.Training +training_class = Training +rounds = 4 +full_epochs = False +batch_size = 16 +shuffle = True +loss_package = torch.nn +loss_class = CrossEntropyLoss + +[COMMUNICATION] +comm_package = decentralizepy.communication.TCP +comm_class = TCP +addresses_filepath = ip_addr_6Machines.json + +[SHARING] +sharing_package = decentralizepy.sharing.TopKParams +sharing_class = TopKParams +alpha = 0.1 diff --git a/eval/step_configs/config_celeba_topkrandom.ini b/eval/step_configs/config_celeba_topkrandom.ini new file mode 100644 index 0000000000000000000000000000000000000000..76749557b865379e92cee7cece1fe4978b707337 --- /dev/null +++ b/eval/step_configs/config_celeba_topkrandom.ini @@ -0,0 +1,34 @@ +[DATASET] +dataset_package = decentralizepy.datasets.Celeba +dataset_class = Celeba +model_class = CNN +images_dir = /mnt/nfs/shared/leaf/data/celeba/data/raw/img_align_celeba +train_dir = /mnt/nfs/shared/leaf/data/celeba/per_user_data/train +test_dir = /mnt/nfs/shared/leaf/data/celeba/data/test +; python list of fractions below +sizes = + +[OPTIMIZER_PARAMS] +optimizer_package = torch.optim +optimizer_class = Adam +lr = 0.001 + +[TRAIN_PARAMS] +training_package = decentralizepy.training.Training +training_class = Training +rounds = 4 +full_epochs = False +batch_size = 16 +shuffle = True +loss_package = torch.nn +loss_class = CrossEntropyLoss + +[COMMUNICATION] +comm_package = decentralizepy.communication.TCP +comm_class = TCP +addresses_filepath = ip_addr_6Machines.json + +[SHARING] +sharing_package = decentralizepy.sharing.TopKPlusRandom +sharing_class = TopKPlusRandom +alpha = 0.1 diff --git a/eval/step_configs/config_celeba_wavelet.ini b/eval/step_configs/config_celeba_wavelet.ini new file mode 100644 index 0000000000000000000000000000000000000000..1c97eb9f7e4c5220fd65f231e4423b9b49edc723 --- /dev/null +++ b/eval/step_configs/config_celeba_wavelet.ini @@ -0,0 +1,38 @@ +[DATASET] +dataset_package = decentralizepy.datasets.Celeba +dataset_class = Celeba +model_class = CNN +images_dir = /mnt/nfs/shared/leaf/data/celeba/data/raw/img_align_celeba +train_dir = /mnt/nfs/shared/leaf/data/celeba/per_user_data/train +test_dir = /mnt/nfs/shared/leaf/data/celeba/data/test +; python list of fractions below +sizes = + +[OPTIMIZER_PARAMS] +optimizer_package = torch.optim +optimizer_class = Adam +lr = 0.001 + +[TRAIN_PARAMS] +training_package = decentralizepy.training.Training +training_class = Training +rounds = 4 +full_epochs = False +batch_size = 16 +shuffle = True +loss_package = torch.nn +loss_class = CrossEntropyLoss + +[COMMUNICATION] +comm_package = decentralizepy.communication.TCP +comm_class = TCP +addresses_filepath = ip_addr_6Machines.json + +[SHARING] +sharing_package = decentralizepy.sharing.Wavelet +sharing_class = Wavelet +change_based_selection = True +alpha = 0.1 +wavelet=sym2 +level= 4 +accumulation = True diff --git a/eval/step_configs/config_femnist_fft.ini b/eval/step_configs/config_femnist_fft.ini new file mode 100644 index 0000000000000000000000000000000000000000..afac1f43678d6b505dd9fceb8e27f91ce890b0ab --- /dev/null +++ b/eval/step_configs/config_femnist_fft.ini @@ -0,0 +1,37 @@ +[DATASET] +dataset_package = decentralizepy.datasets.Femnist +dataset_class = Femnist +random_seed = 97 +model_class = CNN +train_dir = /mnt/nfs/shared/leaf/data/femnist/per_user_data/train +test_dir = 
/mnt/nfs/shared/leaf/data/femnist/data/test +; python list of fractions below +sizes = + +[OPTIMIZER_PARAMS] +optimizer_package = torch.optim +optimizer_class = Adam +lr = 0.001 + +# There are 734463 femnist samples +[TRAIN_PARAMS] +training_package = decentralizepy.training.Training +training_class = Training +rounds = 47 +full_epochs = False +batch_size = 16 +shuffle = True +loss_package = torch.nn +loss_class = CrossEntropyLoss + +[COMMUNICATION] +comm_package = decentralizepy.communication.TCP +comm_class = TCP +addresses_filepath = ip_addr_6Machines.json + +[SHARING] +sharing_package = decentralizepy.sharing.FFT +sharing_class = FFT +alpha = 0.1 +change_based_selection = True +accumulation = True diff --git a/eval/step_configs/config_femnist_grow.ini b/eval/step_configs/config_femnist_grow.ini index 9f18ad972e2c5b27e4ffca2242e08488290c083e..2a779c479a20331b965383b6e263373ef95ceedb 100644 --- a/eval/step_configs/config_femnist_grow.ini +++ b/eval/step_configs/config_femnist_grow.ini @@ -13,8 +13,8 @@ optimizer_class = Adam lr = 0.001 [TRAIN_PARAMS] -training_package = decentralizepy.training.GradientAccumulator -training_class = GradientAccumulator +training_package = decentralizepy.training.Training +training_class = Training rounds = 20 full_epochs = False batch_size = 64 diff --git a/eval/step_configs/config_femnist.ini b/eval/step_configs/config_femnist_partialmodel.ini similarity index 68% rename from eval/step_configs/config_femnist.ini rename to eval/step_configs/config_femnist_partialmodel.ini index 43bb07dc32e862ed73155ba8a879df569d8266c3..de4f1cef82d2a90665228ef403b02a32fff3188d 100644 --- a/eval/step_configs/config_femnist.ini +++ b/eval/step_configs/config_femnist_partialmodel.ini @@ -1,9 +1,10 @@ [DATASET] dataset_package = decentralizepy.datasets.Femnist dataset_class = Femnist +random_seed = 97 model_class = CNN -train_dir = /home/risharma/leaf/data/femnist/per_user_data/train -test_dir = /home/risharma/leaf/data/femnist/data/test +train_dir = /mnt/nfs/shared/leaf/data/femnist/per_user_data/train +test_dir = /mnt/nfs/shared/leaf/data/femnist/data/test ; python list of fractions below sizes = @@ -13,11 +14,11 @@ optimizer_class = Adam lr = 0.001 [TRAIN_PARAMS] -training_package = decentralizepy.training.GradientAccumulator -training_class = GradientAccumulator -rounds = 20 +training_package = decentralizepy.training.Training +training_class = Training +rounds = 47 full_epochs = False -batch_size = 64 +batch_size = 16 shuffle = True loss_package = torch.nn loss_class = CrossEntropyLoss @@ -30,3 +31,4 @@ addresses_filepath = ip_addr_6Machines.json [SHARING] sharing_package = decentralizepy.sharing.PartialModel sharing_class = PartialModel +alpha=0.1 diff --git a/eval/step_configs/config_femnist_100.ini b/eval/step_configs/config_femnist_sharing.ini similarity index 77% rename from eval/step_configs/config_femnist_100.ini rename to eval/step_configs/config_femnist_sharing.ini index 4e3e9ba57519f265240130dc8c054355e3bd4d18..e1af10b3822402030cd23e80b67318e370df80aa 100644 --- a/eval/step_configs/config_femnist_100.ini +++ b/eval/step_configs/config_femnist_sharing.ini @@ -1,11 +1,12 @@ [DATASET] dataset_package = decentralizepy.datasets.Femnist dataset_class = Femnist +random_seed = 97 model_class = CNN -train_dir = /home/risharma/leaf/data/femnist/per_user_data/train -test_dir = /home/risharma/leaf/data/femnist/data/test +train_dir = /mnt/nfs/shared/leaf/data/femnist/per_user_data/train +test_dir = /mnt/nfs/shared/leaf/data/femnist/data/test ; python list of fractions below 
-sizes = +sizes = [OPTIMIZER_PARAMS] optimizer_package = torch.optim @@ -15,9 +16,9 @@ lr = 0.001 [TRAIN_PARAMS] training_package = decentralizepy.training.Training training_class = Training -rounds = 20 +rounds = 47 full_epochs = False -batch_size = 64 +batch_size = 16 shuffle = True loss_package = torch.nn loss_class = CrossEntropyLoss diff --git a/eval/step_configs/config_femnist_subsampling.ini b/eval/step_configs/config_femnist_subsampling.ini new file mode 100644 index 0000000000000000000000000000000000000000..61a1e9a1358861eb89053c9396bd91ae906ed0ea --- /dev/null +++ b/eval/step_configs/config_femnist_subsampling.ini @@ -0,0 +1,35 @@ +[DATASET] +dataset_package = decentralizepy.datasets.Femnist +dataset_class = Femnist +random_seed = 97 +model_class = CNN +train_dir = /mnt/nfs/shared/leaf/data/femnist/per_user_data/train +test_dir = /mnt/nfs/shared/leaf/data/femnist/data/test +; python list of fractions below +sizes = + +[OPTIMIZER_PARAMS] +optimizer_package = torch.optim +optimizer_class = Adam +lr = 0.001 + +# There are 734463 femnist samples +[TRAIN_PARAMS] +training_package = decentralizepy.training.Training +training_class = Training +rounds = 47 +full_epochs = False +batch_size = 16 +shuffle = True +loss_package = torch.nn +loss_class = CrossEntropyLoss + +[COMMUNICATION] +comm_package = decentralizepy.communication.TCP +comm_class = TCP +addresses_filepath = ip_addr_6Machines.json + +[SHARING] +sharing_package = decentralizepy.sharing.SubSampling +sharing_class = SubSampling +alpha = 0.1 diff --git a/eval/step_configs/config_femnist_topkacc.ini b/eval/step_configs/config_femnist_topkacc.ini new file mode 100644 index 0000000000000000000000000000000000000000..c9155d148fc62f743ea26de12d6ee18b88357a45 --- /dev/null +++ b/eval/step_configs/config_femnist_topkacc.ini @@ -0,0 +1,36 @@ +[DATASET] +dataset_package = decentralizepy.datasets.Femnist +dataset_class = Femnist +random_seed = 97 +model_class = CNN +train_dir = /mnt/nfs/shared/leaf/data/femnist/per_user_data/train +test_dir = /mnt/nfs/shared/leaf/data/femnist/data/test +; python list of fractions below +sizes = + +[OPTIMIZER_PARAMS] +optimizer_package = torch.optim +optimizer_class = Adam +lr = 0.001 + +# There are 734463 femnist samples +[TRAIN_PARAMS] +training_package = decentralizepy.training.Training +training_class = Training +rounds = 47 +full_epochs = False +batch_size = 16 +shuffle = True +loss_package = torch.nn +loss_class = CrossEntropyLoss + +[COMMUNICATION] +comm_package = decentralizepy.communication.TCP +comm_class = TCP +addresses_filepath = ip_addr_6Machines.json + +[SHARING] +sharing_package = decentralizepy.sharing.PartialModel +sharing_class = PartialModel +alpha = 0.1 +accumulation = True \ No newline at end of file diff --git a/eval/step_configs/config_femnist_topkparam.ini b/eval/step_configs/config_femnist_topkparam.ini new file mode 100644 index 0000000000000000000000000000000000000000..ada3c3fb402a9a8ac1cb9e1526347033a3325947 --- /dev/null +++ b/eval/step_configs/config_femnist_topkparam.ini @@ -0,0 +1,35 @@ +[DATASET] +dataset_package = decentralizepy.datasets.Femnist +dataset_class = Femnist +random_seed = 97 +model_class = CNN +train_dir = /mnt/nfs/shared/leaf/data/femnist/per_user_data/train +test_dir = /mnt/nfs/shared/leaf/data/femnist/data/test +; python list of fractions below +sizes = + +[OPTIMIZER_PARAMS] +optimizer_package = torch.optim +optimizer_class = Adam +lr = 0.001 + +# There are 734463 femnist samples +[TRAIN_PARAMS] +training_package = decentralizepy.training.Training 
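The sample-count comment above appears to motivate the FEMNIST settings: with 96 nodes, batch_size = 16 and rounds = 47, one communication round consumes about 96 x 47 x 16 = 72,192 samples, roughly a tenth of the 734,463 FEMNIST samples. This is an inference from the numbers, not something the configs state explicitly.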
+training_class = Training +rounds = 47 +full_epochs = False +batch_size = 16 +shuffle = True +loss_package = torch.nn +loss_class = CrossEntropyLoss + +[COMMUNICATION] +comm_package = decentralizepy.communication.TCP +comm_class = TCP +addresses_filepath = ip_addr_6Machines.json + +[SHARING] +sharing_package = decentralizepy.sharing.TopKParams +sharing_class = TopKParams +alpha = 0.1 diff --git a/eval/step_configs/config_femnist_wavelet.ini b/eval/step_configs/config_femnist_wavelet.ini new file mode 100644 index 0000000000000000000000000000000000000000..b6ff27856b3b7f43fd9ef1bd2c321003bda80a38 --- /dev/null +++ b/eval/step_configs/config_femnist_wavelet.ini @@ -0,0 +1,39 @@ +[DATASET] +dataset_package = decentralizepy.datasets.Femnist +dataset_class = Femnist +random_seed = 97 +model_class = CNN +train_dir = /mnt/nfs/shared/leaf/data/femnist/per_user_data/train +test_dir = /mnt/nfs/shared/leaf/data/femnist/data/test +; python list of fractions below +sizes = + +[OPTIMIZER_PARAMS] +optimizer_package = torch.optim +optimizer_class = Adam +lr = 0.001 + +# There are 734463 femnist samples +[TRAIN_PARAMS] +training_package = decentralizepy.training.Training +training_class = Training +rounds = 47 +full_epochs = False +batch_size = 16 +shuffle = True +loss_package = torch.nn +loss_class = CrossEntropyLoss + +[COMMUNICATION] +comm_package = decentralizepy.communication.TCP +comm_class = TCP +addresses_filepath = ip_addr_6Machines.json + +[SHARING] +sharing_package = decentralizepy.sharing.Wavelet +sharing_class = Wavelet +change_based_selection = True +alpha = 0.1 +wavelet=sym2 +level= 4 +accumulation = True diff --git a/eval/testing.py b/eval/testing.py index abd6333e2710f05e66a72afc617fffaafaa9ba5a..efb80dfa107be747ef083888741d3d413c30361c 100644 --- a/eval/testing.py +++ b/eval/testing.py @@ -62,6 +62,7 @@ if __name__ == "__main__": args.log_dir, log_level[args.log_level], args.test_after, + args.train_evaluate_after, args.reset_optimizer, ], ) diff --git a/setup.cfg b/setup.cfg index 3faa1f36fc490a44ab218e6ce2c38aa78c9b9016..e174dd4425d7e9dcfb50da276c2f9bc8cfac1e5b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -42,6 +42,9 @@ install_requires = pillow smallworld localconfig + PyWavelets + pandas + crudini include_package_data = True python_requires = >=3.6 [options.packages.find] diff --git a/src/decentralizepy/models/Model.py b/src/decentralizepy/models/Model.py index f7575007730e1376aba8b4bc3726216e738b905d..f3635a947eafbcfd5f00b6ec1094db5888a5fb30 100644 --- a/src/decentralizepy/models/Model.py +++ b/src/decentralizepy/models/Model.py @@ -14,9 +14,10 @@ class Model(nn.Module): """ super().__init__() - self.accumulated_gradients = [] + self.model_change = None self._param_count_ot = None self._param_count_total = None + self.accumulated_changes = None def count_params(self, only_trainable=False): """ @@ -43,3 +44,16 @@ if not self._param_count_total: self._param_count_total = sum(p.numel() for p in self.parameters()) return self._param_count_total + + def rewind_accumulation(self, indices): + """ + Resets accumulated_changes at the given indices + + Parameters + ---------- + indices : torch.Tensor + Tensor that contains indices corresponding to the flattened model + + """ + if self.accumulated_changes is not None: + self.accumulated_changes[indices] = 0.0 diff --git a/src/decentralizepy/node/Node.py b/src/decentralizepy/node/Node.py index e5764aea2430e0c9f8de57a34825e77400d4765d..74de4e1376bb1bae0b713e7c1b7d439316e7facb 100644 --- a/src/decentralizepy/node/Node.py +++ 
b/src/decentralizepy/node/Node.py @@ -77,6 +77,7 @@ class Node: iterations, log_dir, test_after, + train_evaluate_after, reset_optimizer, ): """ @@ -92,8 +93,14 @@ The object containing the mapping rank <--> uid graph : decentralizepy.graphs The object containing the global graph + iterations : int + Number of iterations (communication steps) for which the model should be trained log_dir : str Logging directory + test_after : int + Number of iterations after which the test loss and accuracy are calculated + train_evaluate_after : int + Number of iterations after which the train loss is calculated reset_optimizer : int 1 if optimizer should be reset every communication round, else 0 @@ -106,6 +113,7 @@ self.log_dir = log_dir self.iterations = iterations self.test_after = test_after + self.train_evaluate_after = train_evaluate_after self.reset_optimizer = reset_optimizer logging.debug("Rank: %d", self.rank) @@ -260,6 +268,7 @@ log_dir=".", log_level=logging.INFO, test_after=5, + train_evaluate_after=1, reset_optimizer=1, *args ): """ @@ -278,10 +287,16 @@ The object containing the global graph config : dict A dictionary of configurations. + iterations : int + Number of iterations (communication steps) for which the model should be trained log_dir : str Logging directory log_level : logging.Level One of DEBUG, INFO, WARNING, ERROR, CRITICAL + test_after : int + Number of iterations after which the test loss and accuracy are calculated + train_evaluate_after : int + Number of iterations after which the train loss is calculated reset_optimizer : int 1 if optimizer should be reset every communication round, else 0 args : optional @@ -298,6 +313,7 @@ iterations, log_dir, test_after, + train_evaluate_after, reset_optimizer, ) self.init_log(log_dir, rank, log_level) @@ -315,6 +331,7 @@ self.testset = self.dataset.get_testset() self.communication.connect_neighbors(self.graph.neighbors(self.uid)) rounds_to_test = self.test_after + rounds_to_train_evaluate = self.train_evaluate_after for iteration in range(self.iterations): logging.info("Starting training iteration: %d", iteration) @@ -328,7 +345,6 @@ ) # Reset optimizer state self.trainer.reset_optimizer(self.optimizer) - loss_after_sharing = self.trainer.eval_loss(self.dataset) if iteration: with open( @@ -348,7 +364,6 @@ "grad_std": {}, } - results_dict["train_loss"][iteration + 1] = loss_after_sharing results_dict["total_bytes"][iteration + 1] = self.communication.total_bytes if hasattr(self.sharing, "total_meta"): @@ -361,14 +376,21 @@ results_dict["grad_mean"][iteration + 1] = self.sharing.mean if hasattr(self.sharing, "std"): results_dict["grad_std"][iteration + 1] = self.sharing.std - - self.save_plot( - results_dict["train_loss"], - "train_loss", - "Training Loss", - "Communication Rounds", - os.path.join(self.log_dir, "{}_train_loss.png".format(self.rank)), - ) + + rounds_to_train_evaluate -= 1 + + if rounds_to_train_evaluate == 0: + logging.info("Evaluating on train set.") + rounds_to_train_evaluate = self.train_evaluate_after + loss_after_sharing = self.trainer.eval_loss(self.dataset) + results_dict["train_loss"][iteration + 1] = loss_after_sharing + self.save_plot( + results_dict["train_loss"], + "train_loss", + "Training Loss", + "Communication Rounds", + os.path.join(self.log_dir, "{}_train_loss.png".format(self.rank)), + ) rounds_to_test -= 1 @@ -413,6 +435,7 @@ log_dir=".", log_level=logging.INFO, test_after=5, + 
train_evaluate_after=1, reset_optimizer=1, *args ): """ @@ -443,10 +466,16 @@ training_class = Training epochs_per_round = 25 batch_size = 64 + iterations : int + Number of iterations (communication steps) for which the model should be trained log_dir : str Logging directory log_level : logging.Level One of DEBUG, INFO, WARNING, ERROR, CRITICAL + test_after : int + Number of iterations after which the test loss and accuracy are calculated + train_evaluate_after : int + Number of iterations after which the train loss is calculated reset_optimizer : int 1 if optimizer should be reset every communication round, else 0 args : optional @@ -467,6 +496,7 @@ log_dir, log_level, test_after, + train_evaluate_after, reset_optimizer, *args ) diff --git a/src/decentralizepy/sharing/FFT.py b/src/decentralizepy/sharing/FFT.py new file mode 100644 index 0000000000000000000000000000000000000000..0c0172f5ce27739165b6c22f98461b6839d6d515 --- /dev/null +++ b/src/decentralizepy/sharing/FFT.py @@ -0,0 +1,287 @@ +import json +import logging +import os +from pathlib import Path +from time import time + +import numpy as np +import torch +import torch.fft as fft + +from decentralizepy.sharing.PartialModel import PartialModel + + +def change_transformer_fft(x): + """ + Transforms the model changes into the frequency domain + + Parameters + ---------- + x : torch.Tensor + Model change in the space domain + + Returns + ------- + x : torch.Tensor + Representation of the change in the frequency domain + """ + return fft.rfft(x) + + +class FFT(PartialModel): + """ + This class implements the fft version of model sharing. + It is based on PartialModel.py + + """ + + def __init__( + self, + rank, + machine_id, + communication, + mapping, + graph, + model, + dataset, + log_dir, + alpha=1.0, + dict_ordered=True, + save_shared=False, + metadata_cap=1.0, + change_based_selection=True, + save_accumulated="", + accumulation=True, + accumulate_averaging_changes=False, + ): + """ + Constructor + + Parameters + ---------- + rank : int + Local rank + machine_id : int + Global machine id + communication : decentralizepy.communication.Communication + Communication module used to send and receive messages + mapping : decentralizepy.mappings.Mapping + Mapping (rank, machine_id) -> uid + graph : decentralizepy.graphs.Graph + Graph representing neighbors + model : decentralizepy.models.Model + Model to train + dataset : decentralizepy.datasets.Dataset + Dataset for sharing data. Not implemented yet! TODO + log_dir : str + Location to write shared_params (only writing for 2 procs per machine) + alpha : float + Percentage of model to share + dict_ordered : bool + Specifies if the python dict maintains the order of insertion + save_shared : bool + Specifies if the indices of shared parameters should be logged + metadata_cap : float + Share full model when self.alpha > metadata_cap + change_based_selection : bool + Use frequency change to select topk frequencies + save_accumulated : bool + True if accumulated weight change in the frequency domain should be written to file. In case of accumulation + the accumulated change is stored. 
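For context on the transform chosen in change_transformer_fft above: torch.fft.rfft keeps only the non-redundant half-spectrum of a real signal, and torch.fft.irfft assumes an even-length original unless n is passed. A standalone round-trip sketch, separate from the project code:

import torch
import torch.fft as fft

x = torch.randn(1001)                # stand-in for a flattened model, odd length
freq = fft.rfft(x)                   # complex half-spectrum, length 1001 // 2 + 1 = 501
back = fft.irfft(freq, n=x.numel())  # n is required to recover an odd length
assert torch.allclose(x, back, atol=1e-4)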
+ accumulation : bool + True if the indices to share should be selected based on accumulated frequency change + accumulate_averaging_changes: bool + True if the accumulation should account for the model change due to averaging + + """ + super().__init__( + rank, + machine_id, + communication, + mapping, + graph, + model, + dataset, + log_dir, + alpha, + dict_ordered, + save_shared, + metadata_cap, + accumulation, + save_accumulated, + change_transformer_fft, + accumulate_averaging_changes, + ) + self.change_based_selection = change_based_selection + + def apply_fft(self): + """ + Does fft transformation of the model parameters and selects topK (alpha) of them in the frequency domain + based on the change undergone during the current training step + + Returns + ------- + tuple + (a,b). a: selected fft frequencies (complex numbers), b: Their indices. + + """ + + logging.info("Returning fft compressed model weights") + with torch.no_grad(): + tensors_to_cat = [ + v.data.flatten() for _, v in self.model.state_dict().items() + ] + concated = torch.cat(tensors_to_cat, dim=0) + flat_fft = self.change_transformer(concated) + if self.change_based_selection: + diff = self.model.model_change + _, index = torch.topk( + diff.abs(), round(self.alpha * len(diff)), dim=0, sorted=False + ) + else: + _, index = torch.topk( + flat_fft.abs(), + round(self.alpha * len(flat_fft)), + dim=0, + sorted=False, + ) + + return flat_fft[index], index + + def serialized_model(self): + """ + Convert model to json dict. self.alpha specifies the fraction of model to send. + + Returns + ------- + dict + Model converted to json dict + + """ + if self.alpha > self.metadata_cap: # Share fully + return super().serialized_model() + + with torch.no_grad(): + topk, indices = self.apply_fft() + + self.model.rewind_accumulation(indices) + + if self.save_shared: + shared_params = dict() + shared_params["order"] = list(self.model.state_dict().keys()) + shapes = dict() + for k, v in self.model.state_dict().items(): + shapes[k] = list(v.shape) + shared_params["shapes"] = shapes + + shared_params[self.communication_round] = indices.tolist() # is slow + + shared_params["alpha"] = self.alpha + + with open( + os.path.join( + self.folder_path, + "{}_shared_params.json".format(self.communication_round + 1), + ), + "w", + ) as of: + json.dump(shared_params, of) + + m = dict() + + if not self.dict_ordered: + raise NotImplementedError + + m["alpha"] = self.alpha + m["params"] = topk.numpy() + m["indices"] = indices.numpy().astype(np.int32) + + self.total_data += len(self.communication.encrypt(m["params"])) + self.total_meta += len(self.communication.encrypt(m["indices"])) + len( + self.communication.encrypt(m["alpha"]) + ) + + return m + + def deserialized_model(self, m): + """ + Convert received json dict to state_dict. 
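A self-contained sketch of the compression round trip that apply_fft and serialized_model above implement, including the "use local data to complement" merge done on the receiving side in _averaging further below (toy sizes, not project code):

import torch
import torch.fft as fft

alpha = 0.1
local = torch.randn(1000)                  # receiver's flattened model
remote = local + 0.01 * torch.randn(1000)  # sender's model after training

remote_fft = fft.rfft(remote)
_, idx = torch.topk(remote_fft.abs(), round(alpha * len(remote_fft)), sorted=False)
payload = (remote_fft[idx], idx)           # the values and indices that travel

merged = fft.rfft(local).clone()
merged[idx] = payload[0]                   # unshared coefficients stay local
approx = fft.irfft(merged, n=local.numel())  # back to parameter space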
+ + Parameters + ---------- + m : dict + json dict received + + Returns + ------- + state_dict + state_dict of received + + """ + if self.alpha > self.metadata_cap: # Share fully + return super().deserialized_model(m) + + with torch.no_grad(): + if not self.dict_ordered: + raise NotImplementedError + + indices = m["indices"] + alpha = m["alpha"] + params = m["params"] + + params_tensor = torch.tensor(params) + indices_tensor = torch.tensor(indices, dtype=torch.long) + ret = dict() + ret["indices"] = indices_tensor + ret["params"] = params_tensor + return ret + + def _averaging(self): + """ + Averages the received model with the local model + + """ + with torch.no_grad(): + total = None + weight_total = 0 + tensors_to_cat = [ + v.data.flatten() for _, v in self.model.state_dict().items() + ] + pre_share_model = torch.cat(tensors_to_cat, dim=0) + flat_fft = self.change_transformer(pre_share_model) + + for i, n in enumerate(self.peer_deques): + degree, iteration, data = self.peer_deques[n].popleft() + logging.debug( + "Averaging model from neighbor {} of iteration {}".format( + n, iteration + ) + ) + data = self.deserialized_model(data) + params = data["params"] + indices = data["indices"] + # use local data to complement + topkf = flat_fft.clone().detach() + topkf[indices] = params + + weight = 1 / (max(len(self.peer_deques), degree) + 1) # Metro-Hastings + weight_total += weight + if total is None: + total = weight * topkf + else: + total += weight * topkf + + # Metro-Hastings + total += (1 - weight_total) * flat_fft + reverse_total = fft.irfft(total) + + start_index = 0 + std_dict = {} + for i, key in enumerate(self.model.state_dict()): + end_index = start_index + self.lens[i] + std_dict[key] = reverse_total[start_index:end_index].reshape( + self.shapes[i] + ) + start_index = end_index + + self.model.load_state_dict(std_dict) diff --git a/src/decentralizepy/sharing/PartialModel.py b/src/decentralizepy/sharing/PartialModel.py index 6a8f0cb71a3468f7aa33d0bbd91172c4be4ffefd..97c702b9b334f05273063ba25aaa7df4472ab27e 100644 --- a/src/decentralizepy/sharing/PartialModel.py +++ b/src/decentralizepy/sharing/PartialModel.py @@ -7,6 +7,7 @@ import numpy as np import torch from decentralizepy.sharing.Sharing import Sharing +from decentralizepy.utils import conditional_value, identity class PartialModel(Sharing): @@ -29,6 +30,10 @@ dict_ordered=True, save_shared=False, metadata_cap=1.0, + accumulation=False, + save_accumulated="", + change_transformer=identity, + accumulate_averaging_changes=False, ): """ Constructor @@ -59,6 +64,15 @@ Specifies if the indices of shared parameters should be logged metadata_cap : float Share full model when self.alpha > metadata_cap + accumulation : bool + True if the indices to share should be selected based on accumulated frequency change + save_accumulated : bool + True if accumulated weight change should be written to file. In case of accumulation the accumulated change + is stored. If a change_transformer is used then the transformed change is stored. + change_transformer : (x: Tensor) -> Tensor + A function that transforms the model change into other domains. 
Default: identity function + accumulate_averaging_changes: bool + True if the accumulation should account for the model change due to averaging """ super().__init__( @@ -69,6 +83,38 @@ self.save_shared = save_shared self.metadata_cap = metadata_cap self.total_meta = 0 + self.accumulation = accumulation + self.save_accumulated = conditional_value(save_accumulated, "", False) + self.change_transformer = change_transformer + self.accumulate_averaging_changes = accumulate_averaging_changes + + # getting the initial model + self.shapes = [] + self.lens = [] + with torch.no_grad(): + tensors_to_cat = [] + for _, v in self.model.state_dict().items(): + self.shapes.append(v.shape) + t = v.flatten() + self.lens.append(t.shape[0]) + tensors_to_cat.append(t) + self.init_model = torch.cat(tensors_to_cat, dim=0) + if self.accumulation: + self.model.accumulated_changes = torch.zeros_like( + self.change_transformer(self.init_model) + ) + self.prev = self.init_model + + if self.save_accumulated: + self.model_change_path = os.path.join( + self.log_dir, "model_change/{}".format(self.rank) + ) + Path(self.model_change_path).mkdir(parents=True, exist_ok=True) + + self.model_val_path = os.path.join( + self.log_dir, "model_val/{}".format(self.rank) + ) + Path(self.model_val_path).mkdir(parents=True, exist_ok=True) # Only save for 2 procs: Save space if self.save_shared and not (rank == 0 or rank == 1): @@ -91,16 +137,9 @@ (a,b). a: The magnitudes of the topK gradients, b: Their indices. """ - logging.info("Summing up gradients") - assert len(self.model.accumulated_gradients) > 0 - gradient_sum = self.model.accumulated_gradients[0] - for i in range(1, len(self.model.accumulated_gradients)): - for key in self.model.accumulated_gradients[i]: - gradient_sum[key] += self.model.accumulated_gradients[i][key] logging.info("Returning topk gradients") - tensors_to_cat = [v.data.flatten() for _, v in gradient_sum.items()] - G_topk = torch.abs(torch.cat(tensors_to_cat, dim=0)) + G_topk = torch.abs(self.model.model_change) std, mean = torch.std_mean(G_topk, unbiased=False) self.std = std.item() self.mean = mean.item() @@ -118,12 +157,13 @@ Model converted to a dict """ - if self.alpha > self.metadata_cap: # Share fully + if self.alpha >= self.metadata_cap: # Share fully return super().serialized_model() with torch.no_grad(): _, G_topk = self.extract_top_gradients() - + if self.accumulation: + self.model.rewind_accumulation(G_topk) if self.save_shared: shared_params = dict() shared_params["order"] = list(self.model.state_dict().keys()) @@ -218,3 +258,86 @@ start_index = end_index return state_dict + + def _pre_step(self): + """ + Called at the beginning of step. 
+ + """ + logging.info("PartialModel _pre_step") + with torch.no_grad(): + tensors_to_cat = [ + v.data.flatten() for _, v in self.model.state_dict().items() + ] + pre_share_model = torch.cat(tensors_to_cat, dim=0) + change = self.change_transformer(pre_share_model - self.init_model) + if self.accumulation: + if not self.accumulate_averaging_changes: + # Need to accumulate in _pre_step as the accumulation gets rewind during the step + self.model.accumulated_changes += change + change = self.model.accumulated_changes.clone().detach() + else: + # For the legacy implementation, we will only rewind currently accumulated values + # and add the model change due to averaging in the end + change += self.model.accumulated_changes + # stores change of the model due to training, change due to averaging is not accounted + self.model.model_change = change + + def _post_step(self): + """ + Called at the end of step. + + """ + logging.info("PartialModel _post_step") + with torch.no_grad(): + tensors_to_cat = [ + v.data.flatten() for _, v in self.model.state_dict().items() + ] + post_share_model = torch.cat(tensors_to_cat, dim=0) + self.init_model = post_share_model + if self.accumulation: + if self.accumulate_averaging_changes: + self.model.accumulated_changes += self.change_transformer( + self.init_model - self.prev + ) + self.prev = self.init_model + self.model.model_change = None + if self.save_accumulated: + self.save_change() + + def save_vector(self, v, s): + """ + Saves the given vector to the file. + + Parameters + ---------- + v : torch.tensor + The torch tensor to write to file + s : str + Path to folder to write to + + """ + output_dict = dict() + output_dict["order"] = list(self.model.state_dict().keys()) + shapes = dict() + for k, v1 in self.model.state_dict().items(): + shapes[k] = list(v1.shape) + output_dict["shapes"] = shapes + + output_dict["tensor"] = v.tolist() + + with open( + os.path.join( + s, + "{}.json".format(self.communication_round + 1), + ), + "w", + ) as of: + json.dump(output_dict, of) + + def save_change(self): + """ + Saves the change and the gradient values for every iteration + + """ + self.save_vector(self.model.model_change, self.model_change_path) diff --git a/src/decentralizepy/sharing/Sharing.py b/src/decentralizepy/sharing/Sharing.py index 85fc07bea01cef5802d1178fc3dc1ac0ef5281e0..3fe189c14c94fdcc740fb0046bce5c78ebc36e25 100644 --- a/src/decentralizepy/sharing/Sharing.py +++ b/src/decentralizepy/sharing/Sharing.py @@ -31,7 +31,7 @@ class Sharing: model : decentralizepy.models.Model Model to train dataset : decentralizepy.datasets.Dataset - Dataset for sharing data. Not implemented yer! TODO + Dataset for sharing data. Not implemented yet! TODO log_dir : str Location to write shared_params (only writing for 2 procs per machine) @@ -122,11 +122,55 @@ class Sharing: state_dict[key] = torch.from_numpy(value) return state_dict + def _pre_step(self): + """ + Called at the beginning of step. + + """ + pass + + def _post_step(self): + """ + Called at the end of step. 
+ + """ + pass + + def _averaging(self): + """ + Averages the received model with the local model + + """ + with torch.no_grad(): + total = dict() + weight_total = 0 + for i, n in enumerate(self.peer_deques): + degree, iteration, data = self.peer_deques[n].popleft() + logging.debug( + "Averaging model from neighbor {} of iteration {}".format( + n, iteration + ) + ) + data = self.deserialized_model(data) + weight = 1 / (max(len(self.peer_deques), degree) + 1) # Metro-Hastings + weight_total += weight + for key, value in data.items(): + if key in total: + total[key] += value * weight + else: + total[key] = value * weight + + for key, value in self.model.state_dict().items(): + total[key] += (1 - weight_total) * value # Metro-Hastings + + self.model.load_state_dict(total) + def step(self): """ Perform a sharing step. Implements D-PSGD. """ + self._pre_step() data = self.serialized_model() my_uid = self.mapping.get_uid(self.rank, self.machine_id) all_neighbors = self.graph.neighbors(my_uid) @@ -152,27 +196,8 @@ class Sharing: ) logging.info("Starting model averaging after receiving from all neighbors") - total = dict() - weight_total = 0 - for i, n in enumerate(self.peer_deques): - degree, iteration, data = self.peer_deques[n].popleft() - logging.debug( - "Averaging model from neighbor {} of iteration {}".format(n, iteration) - ) - data = self.deserialized_model(data) - weight = 1 / (max(len(self.peer_deques), degree) + 1) # Metro-Hastings - weight_total += weight - for key, value in data.items(): - if key in total: - total[key] += value * weight - else: - total[key] = value * weight - - for key, value in self.model.state_dict().items(): - total[key] += (1 - weight_total) * value # Metro-Hastings - - self.model.load_state_dict(total) - + self._averaging() logging.info("Model averaging complete") self.communication_round += 1 + self._post_step() diff --git a/src/decentralizepy/sharing/SubSampling.py b/src/decentralizepy/sharing/SubSampling.py new file mode 100644 index 0000000000000000000000000000000000000000..5ec0c44f054eb44debf056189513ab814f620a9f --- /dev/null +++ b/src/decentralizepy/sharing/SubSampling.py @@ -0,0 +1,285 @@ +import json +import logging +import os +from pathlib import Path + +import torch + +from decentralizepy.sharing.Sharing import Sharing + + +class SubSampling(Sharing): + """ + This class implements the subsampling version of model sharing + It is based on PartialModel.py + + """ + + def __init__( + self, + rank, + machine_id, + communication, + mapping, + graph, + model, + dataset, + log_dir, + alpha=1.0, + dict_ordered=True, + save_shared=False, + metadata_cap=1.0, + pickle=True, + layerwise=False, + ): + """ + Constructor + + Parameters + ---------- + rank : int + Local rank + machine_id : int + Global machine id + communication : decentralizepy.communication.Communication + Communication module used to send and receive messages + mapping : decentralizepy.mappings.Mapping + Mapping (rank, machine_id) -> uid + graph : decentralizepy.graphs.Graph + Graph reprensenting neighbors + model : decentralizepy.models.Model + Model to train + dataset : decentralizepy.datasets.Dataset + Dataset for sharing data. Not implemented yet! 
diff --git a/src/decentralizepy/sharing/SubSampling.py b/src/decentralizepy/sharing/SubSampling.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ec0c44f054eb44debf056189513ab814f620a9f
--- /dev/null
+++ b/src/decentralizepy/sharing/SubSampling.py
@@ -0,0 +1,285 @@
+import json
+import logging
+import os
+from pathlib import Path
+
+import torch
+
+from decentralizepy.sharing.Sharing import Sharing
+
+
+class SubSampling(Sharing):
+    """
+    This class implements the subsampling version of model sharing.
+    It is based on PartialModel.py.
+
+    """
+
+    def __init__(
+        self,
+        rank,
+        machine_id,
+        communication,
+        mapping,
+        graph,
+        model,
+        dataset,
+        log_dir,
+        alpha=1.0,
+        dict_ordered=True,
+        save_shared=False,
+        metadata_cap=1.0,
+        pickle=True,
+        layerwise=False,
+    ):
+        """
+        Constructor
+
+        Parameters
+        ----------
+        rank : int
+            Local rank
+        machine_id : int
+            Global machine id
+        communication : decentralizepy.communication.Communication
+            Communication module used to send and receive messages
+        mapping : decentralizepy.mappings.Mapping
+            Mapping (rank, machine_id) -> uid
+        graph : decentralizepy.graphs.Graph
+            Graph representing neighbors
+        model : decentralizepy.models.Model
+            Model to train
+        dataset : decentralizepy.datasets.Dataset
+            Dataset for sharing data. Not implemented yet! TODO
+        log_dir : str
+            Location to write shared_params (only writing for 2 procs per machine)
+        alpha : float
+            Percentage of model to share
+        dict_ordered : bool
+            Specifies if the python dict maintains the order of insertion
+        save_shared : bool
+            Specifies if the indices of shared parameters should be logged
+        metadata_cap : float
+            Share full model when self.alpha > metadata_cap
+        pickle : bool
+            Use pickle to serialize the model parameters
+
+        """
+        super().__init__(
+            rank, machine_id, communication, mapping, graph, model, dataset, log_dir
+        )
+        self.alpha = alpha
+        self.dict_ordered = dict_ordered
+        self.save_shared = save_shared
+        self.metadata_cap = metadata_cap
+        self.total_meta = 0
+
+        # self.random_seed_generator = torch.Generator()
+        # # Will use the random device if supported by CPU, else uses the system time
+        # # In the latter case we could get duplicate seeds on some of the machines
+        # self.random_seed_generator.seed()
+
+        self.random_generator = torch.Generator()
+        # Will use the random device if supported by CPU, else uses the system time
+        # In the latter case we could get duplicate seeds on some of the machines
+        self.random_generator.seed()
+        self.seed = self.random_generator.initial_seed()
+
+        self.pickle = pickle
+        self.layerwise = layerwise
+
+        logging.info("subsampling pickling=" + str(pickle))
+
+        if self.save_shared:
+            # Only save for 2 procs: Save space
+            if rank != 0 and rank != 1:
+                self.save_shared = False
+
+        if self.save_shared:
+            self.folder_path = os.path.join(
+                self.log_dir, "shared_params/{}".format(self.rank)
+            )
+            Path(self.folder_path).mkdir(parents=True, exist_ok=True)
+
+    def apply_subsampling(self):
+        """
+        Creates a random binary mask that is used to subsample the parameters that will be shared
+
+        Returns
+        -------
+        tuple
+            (a,b,c). a: The selected parameters as a flat vector, b: The random seed used to create the binary mask,
+            c: The alpha value
+
+        """
+
+        logging.info("Returning subsampled parameters")
+        if not self.layerwise:
+            tensors_to_cat = [
+                v.data.flatten() for _, v in self.model.state_dict().items()
+            ]
+            concated = torch.cat(tensors_to_cat, dim=0)
+
+            curr_seed = self.seed + self.communication_round  # is incremented in step()
+            self.random_generator.manual_seed(curr_seed)
+            # logging.debug("Subsampling seed for uid = " + str(self.uid) + " is: " + str(curr_seed))
+            # Or we could use torch.bernoulli
+            binary_mask = (
+                torch.rand(
+                    size=(concated.size(dim=0),), generator=self.random_generator
+                )
+                <= self.alpha
+            )
+            subsample = concated[binary_mask]
+            # logging.debug("Subsampling vector is of size: " + str(subsample.size(dim = 0)))
+            return (subsample, curr_seed, self.alpha)
+        else:
+            values_list = []
+            offsets = [0]
+            off = 0
+            curr_seed = self.seed + self.communication_round  # is incremented in step()
+            self.random_generator.manual_seed(curr_seed)
+            for _, v in self.model.state_dict().items():
+                flat = v.flatten()
+                binary_mask = (
+                    torch.rand(
+                        size=(flat.size(dim=0),), generator=self.random_generator
+                    )
+                    <= self.alpha
+                )
+                selected = flat[binary_mask]
+                values_list.append(selected)
+                off += selected.size(dim=0)
+                offsets.append(off)
+            subsample = torch.cat(values_list, dim=0)
+            return (subsample, curr_seed, self.alpha)
+
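
The property `apply_subsampling` relies on is that a `torch.Generator` seeded identically produces identical masks, so only `(seed, alpha, values)` need to be transmitted. A minimal sketch:

    import torch

    # Illustration only, not part of the patch.
    n, alpha, seed = 10, 0.3, 1234

    gen = torch.Generator()
    gen.manual_seed(seed)
    sender_mask = torch.rand(size=(n,), generator=gen) <= alpha

    gen = torch.Generator()
    gen.manual_seed(seed)
    receiver_mask = torch.rand(size=(n,), generator=gen) <= alpha

    assert torch.equal(sender_mask, receiver_mask)  # same seed, same mask
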
+    def serialized_model(self):
+        """
+        Convert the model to a json dict. self.alpha specifies the fraction of the model to send.
+
+        Returns
+        -------
+        dict
+            Model converted to json dict
+
+        """
+        if self.alpha > self.metadata_cap:  # Share fully
+            return super().serialized_model()
+
+        with torch.no_grad():
+            subsample, seed, alpha = self.apply_subsampling()
+
+            if self.save_shared:
+                shared_params = dict()
+                shared_params["order"] = list(self.model.state_dict().keys())
+                shapes = dict()
+                for k, v in self.model.state_dict().items():
+                    shapes[k] = list(v.shape)
+                shared_params["shapes"] = shapes
+
+                # TODO: should store the shared indices and not the value
+                # shared_params[self.communication_round] = subsample.tolist() # is slow
+
+                shared_params["seed"] = seed
+
+                shared_params["alpha"] = alpha
+
+                with open(
+                    os.path.join(
+                        self.folder_path,
+                        "{}_shared_params.json".format(self.communication_round + 1),
+                    ),
+                    "w",
+                ) as of:
+                    json.dump(shared_params, of)
+
+            m = dict()
+
+            if not self.dict_ordered:
+                raise NotImplementedError
+
+            m["seed"] = seed
+            m["alpha"] = alpha
+            m["params"] = subsample.numpy()
+
+            # logging.info("Converted dictionary to json")
+            self.total_data += len(self.communication.encrypt(m["params"]))
+            self.total_meta += len(self.communication.encrypt(m["seed"])) + len(
+                self.communication.encrypt(m["alpha"])
+            )
+
+            return m
+
+    def deserialized_model(self, m):
+        """
+        Convert received json dict to state_dict.
+
+        Parameters
+        ----------
+        m : dict
+            The received json dict
+
+        Returns
+        -------
+        state_dict
+            state_dict of the received model
+
+        """
+        if self.alpha > self.metadata_cap:  # Share fully
+            return super().deserialized_model(m)
+
+        with torch.no_grad():
+            state_dict = self.model.state_dict()
+
+            if not self.dict_ordered:
+                raise NotImplementedError
+
+            seed = m["seed"]
+            alpha = m["alpha"]
+            params = m["params"]
+
+            random_generator = (
+                torch.Generator()
+            )  # new generator, so that we do not overwrite the other one
+            random_generator.manual_seed(seed)
+
+            shapes = []
+            lens = []
+            tensors_to_cat = []
+            binary_submasks = []
+            for _, v in state_dict.items():
+                shapes.append(v.shape)
+                t = v.flatten()
+                lens.append(t.shape[0])
+                tensors_to_cat.append(t)
+                if self.layerwise:
+                    binary_mask = (
+                        torch.rand(size=(t.size(dim=0),), generator=random_generator)
+                        <= alpha
+                    )
+                    binary_submasks.append(binary_mask)
+
+            T = torch.cat(tensors_to_cat, dim=0)
+
+            params_tensor = torch.from_numpy(params)
+
+            if not self.layerwise:
+                binary_mask = (
+                    torch.rand(size=(T.size(dim=0),), generator=random_generator)
+                    <= alpha
+                )
+            else:
+                binary_mask = torch.cat(binary_submasks, dim=0)
+
+            logging.debug("Original tensor: {}".format(T[binary_mask]))
+            T[binary_mask] = params_tensor
+            logging.debug("Final tensor: {}".format(T[binary_mask]))
+
+            start_index = 0
+            for i, key in enumerate(state_dict):
+                end_index = start_index + lens[i]
+                state_dict[key] = T[start_index:end_index].reshape(shapes[i])
+                start_index = end_index
+
+            return state_dict
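
For a rough sense of the savings from subsampling (illustrative numbers, ignoring the encoding overhead measured via `communication.encrypt`):

    # Illustration only, not part of the patch; numbers are made up.
    n_params = 100_000
    alpha = 0.1
    bytes_per_float = 4

    full = n_params * bytes_per_float                     # 400000 bytes
    subsampled = int(n_params * alpha) * bytes_per_float  # 40000 bytes
    metadata = 8 + 4                                      # seed (int64) + alpha (float)

    print(full, subsampled + metadata)                    # roughly a 10x reduction
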
diff --git a/src/decentralizepy/sharing/TopKParams.py b/src/decentralizepy/sharing/TopKParams.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6535ce5052b25b0f7665010504cffc5e8f26be0
--- /dev/null
+++ b/src/decentralizepy/sharing/TopKParams.py
@@ -0,0 +1,226 @@
+import json
+import logging
+import os
+from pathlib import Path
+
+import numpy as np
+import torch
+
+from decentralizepy.sharing.Sharing import Sharing
+
+
+class TopKParams(Sharing):
+    """
+    This class implements TopK parameter sharing: the largest-magnitude parameters are selected layerwise.
+
+    """
+
+    def __init__(
+        self,
+        rank,
+        machine_id,
+        communication,
+        mapping,
+        graph,
+        model,
+        dataset,
+        log_dir,
+        alpha=1.0,
+        dict_ordered=True,
+        save_shared=False,
+        metadata_cap=1.0,
+    ):
+        """
+        Constructor
+
+        Parameters
+        ----------
+        rank : int
+            Local rank
+        machine_id : int
+            Global machine id
+        communication : decentralizepy.communication.Communication
+            Communication module used to send and receive messages
+        mapping : decentralizepy.mappings.Mapping
+            Mapping (rank, machine_id) -> uid
+        graph : decentralizepy.graphs.Graph
+            Graph representing neighbors
+        model : decentralizepy.models.Model
+            Model to train
+        dataset : decentralizepy.datasets.Dataset
+            Dataset for sharing data. Not implemented yet! TODO
+        log_dir : str
+            Location to write shared_params (only writing for 2 procs per machine)
+        alpha : float
+            Percentage of model to share
+        dict_ordered : bool
+            Specifies if the python dict maintains the order of insertion
+        save_shared : bool
+            Specifies if the indices of shared parameters should be logged
+        metadata_cap : float
+            Share full model when self.alpha > metadata_cap
+
+        """
+        super().__init__(
+            rank, machine_id, communication, mapping, graph, model, dataset, log_dir
+        )
+        self.alpha = alpha
+        self.dict_ordered = dict_ordered
+        self.save_shared = save_shared
+        self.metadata_cap = metadata_cap
+        self.total_meta = 0
+
+        if self.save_shared:
+            # Only save for 2 procs: Save space
+            if rank != 0 and rank != 1:
+                self.save_shared = False
+
+        if self.save_shared:
+            self.folder_path = os.path.join(
+                self.log_dir, "shared_params/{}".format(self.rank)
+            )
+            Path(self.folder_path).mkdir(parents=True, exist_ok=True)
+
+    def extract_top_params(self):
+        """
+        Extract the indices and values of the topK params, layerwise.
+        The selection is based on the current parameter magnitudes.
+
+        Returns
+        -------
+        tuple
+            (a,b,c). a: The topK params, b: Their indices, c: The offsets
+
+        """
+
+        logging.info("Returning TopK params")
+        values_list = []
+        index_list = []
+        offsets = [0]
+        off = 0
+        for _, v in self.model.state_dict().items():
+            flat = v.flatten()
+            values, index = torch.topk(
+                flat.abs(), round(self.alpha * flat.size(dim=0)), dim=0, sorted=False
+            )
+            values_list.append(flat[index])
+            index_list.append(index)
+            off += values.size(dim=0)
+            offsets.append(off)
+        cat_values = torch.cat(values_list, dim=0)
+        cat_index = torch.cat(index_list, dim=0)
+
+        # logging.debug("Subsampling vector is of size: " + str(subsample.size(dim = 0)))
+        return (cat_values, cat_index, offsets)
+
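
`extract_top_params` keeps, per layer, the `alpha` fraction of entries with the largest magnitude; note that `torch.topk` runs on `flat.abs()` while the signed values `flat[index]` are what get shared. A toy example:

    import torch

    # Illustration only, not part of the patch. Toy two-layer state_dict.
    alpha = 0.5
    state = {
        "w1": torch.tensor([0.1, -2.0, 0.3, 4.0]),
        "w2": torch.tensor([-1.0, 0.05]),
    }

    for name, v in state.items():
        flat = v.flatten()
        k = round(alpha * flat.size(dim=0))
        _, index = torch.topk(flat.abs(), k, dim=0, sorted=False)
        print(name, index.tolist(), flat[index].tolist())  # signed values are kept
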
+    def serialized_model(self):
+        """
+        Convert the model to a json dict. self.alpha specifies the fraction of the model to send.
+
+        Returns
+        -------
+        dict
+            Model converted to json dict
+
+        """
+        if self.alpha > self.metadata_cap:  # Share fully
+            return super().serialized_model()
+
+        with torch.no_grad():
+            values, index, offsets = self.extract_top_params()
+
+            if self.save_shared:
+                shared_params = dict()
+                shared_params["order"] = list(self.model.state_dict().keys())
+                shapes = dict()
+                for k, v in self.model.state_dict().items():
+                    shapes[k] = list(v.shape)
+                shared_params["shapes"] = shapes
+
+                shared_params[self.communication_round] = index.tolist()
+                # TODO: store offsets
+
+                with open(
+                    os.path.join(
+                        self.folder_path,
+                        "{}_shared_params.json".format(self.communication_round + 1),
+                    ),
+                    "w",
+                ) as of:
+                    json.dump(shared_params, of)
+
+            logging.info("Extracted topk params")
+
+            logging.info("Generating dictionary to send")
+
+            m = dict()
+
+            if not self.dict_ordered:
+                raise NotImplementedError
+
+            m["indices"] = index.numpy().astype(np.int32)
+            m["params"] = values.numpy()
+            m["offsets"] = offsets
+
+            assert len(m["indices"]) == len(m["params"])
+            logging.info("Elements being sent: {}".format(len(m["indices"])))
+
+            logging.info("Generated dictionary to send")
+
+            # for key in m:
+            #     m[key] = json.dumps(m[key])
+
+            # logging.info("Converted dictionary to json")
+            self.total_data += len(self.communication.encrypt(m["params"]))
+            self.total_meta += len(self.communication.encrypt(m["indices"])) + len(
+                self.communication.encrypt(m["offsets"])
+            )
+
+            return m
+
+    def deserialized_model(self, m):
+        """
+        Convert received json dict to state_dict.
+
+        Parameters
+        ----------
+        m : dict
+            The received json dict
+
+        Returns
+        -------
+        state_dict
+            state_dict of the received model
+
+        """
+        if self.alpha > self.metadata_cap:  # Share fully
+            return super().deserialized_model(m)
+
+        with torch.no_grad():
+            state_dict = self.model.state_dict()
+
+            if not self.dict_ordered:
+                raise NotImplementedError
+
+            shapes = []
+            lens = []
+            tensors_to_cat = []
+            offsets = m["offsets"]
+            params = torch.tensor(m["params"])
+            indices = torch.tensor(m["indices"], dtype=torch.long)
+
+            for i, (_, v) in enumerate(state_dict.items()):
+                shapes.append(v.shape)
+                t = v.flatten().clone().detach()  # flatten() may return a view, so clone explicitly
+                lens.append(t.shape[0])
+                index = indices[offsets[i] : offsets[i + 1]]
+                t[index] = params[offsets[i] : offsets[i + 1]]
+                tensors_to_cat.append(t)
+
+            start_index = 0
+            for i, key in enumerate(state_dict):
+                end_index = start_index + lens[i]
+                state_dict[key] = tensors_to_cat[i].reshape(shapes[i])
+                start_index = end_index
+
+            return state_dict
diff --git a/src/decentralizepy/sharing/TopKPlusRandom.py b/src/decentralizepy/sharing/TopKPlusRandom.py
index 1a31e433318b6c813dea2b25f30cdadbfd465840..728d5bfa48d71037a6d165d6396343f46dfa4e3e 100644
--- a/src/decentralizepy/sharing/TopKPlusRandom.py
+++ b/src/decentralizepy/sharing/TopKPlusRandom.py
@@ -84,16 +84,8 @@ class TopKPlusRandom(PartialModel):
             (a,b). a: The magnitudes of the topK gradients, b: Their indices.
""" - logging.info("Summing up gradients") - assert len(self.model.accumulated_gradients) > 0 - gradient_sum = self.model.accumulated_gradients[0] - for i in range(1, len(self.model.accumulated_gradients)): - for key in self.model.accumulated_gradients[i]: - gradient_sum[key] += self.model.accumulated_gradients[i][key] - logging.info("Returning topk gradients") - tensors_to_cat = [v.data.flatten() for _, v in gradient_sum.items()] - G = torch.abs(torch.cat(tensors_to_cat, dim=0)) + G = torch.abs(self.model.model_change) std, mean = torch.std_mean(G, unbiased=False) self.std = std.item() self.mean = mean.item() diff --git a/src/decentralizepy/sharing/Wavelet.py b/src/decentralizepy/sharing/Wavelet.py new file mode 100644 index 0000000000000000000000000000000000000000..363a487d098c76b36a388c20f09f5c618a2868a5 --- /dev/null +++ b/src/decentralizepy/sharing/Wavelet.py @@ -0,0 +1,315 @@ +import json +import logging +import os +from pathlib import Path +from time import time + +import numpy as np +import pywt +import torch + +from decentralizepy.sharing.PartialModel import PartialModel + + +def change_transformer_wavelet(x, wavelet, level): + """ + Transforms the model changes into wavelet frequency domain + + Parameters + ---------- + x : torch.Tensor + Model change in the space domain + wavelet : str + name of the wavelet to be used in gradient compression + level: int + name of the wavelet to be used in gradient compression + + Returns + ------- + x : torch.Tensor + Representation of the change int the wavelet domain + """ + coeff = pywt.wavedec(x, wavelet, level=level) + data, coeff_slices = pywt.coeffs_to_array(coeff) + return torch.from_numpy(data.ravel()) + + +class Wavelet(PartialModel): + """ + This class implements the wavelet version of model sharing + It is based on PartialModel.py + + """ + + def __init__( + self, + rank, + machine_id, + communication, + mapping, + graph, + model, + dataset, + log_dir, + alpha=1.0, + dict_ordered=True, + save_shared=False, + metadata_cap=1.0, + wavelet="haar", + level=4, + change_based_selection=True, + save_accumulated="", + accumulation=False, + accumulate_averaging_changes=False, + ): + """ + Constructor + + Parameters + ---------- + rank : int + Local rank + machine_id : int + Global machine id + communication : decentralizepy.communication.Communication + Communication module used to send and receive messages + mapping : decentralizepy.mappings.Mapping + Mapping (rank, machine_id) -> uid + graph : decentralizepy.graphs.Graph + Graph reprensenting neighbors + model : decentralizepy.models.Model + Model to train + dataset : decentralizepy.datasets.Dataset + Dataset for sharing data. Not implemented yet! TODO + log_dir : str + Location to write shared_params (only writing for 2 procs per machine) + alpha : float + Percentage of model to share + dict_ordered : bool + Specifies if the python dict maintains the order of insertion + save_shared : bool + Specifies if the indices of shared parameters should be logged + metadata_cap : float + Share full model when self.alpha > metadata_cap + wavelet: str + name of the wavelet to be used in gradient compression + level: int + name of the wavelet to be used in gradient compression + change_based_selection : bool + use frequency change to select topk frequencies + save_accumulated : bool + True if accumulated weight change in the wavelet domain should be written to file. In case of accumulation + the accumulated change is stored. 
+
+
+class Wavelet(PartialModel):
+    """
+    This class implements the wavelet version of model sharing.
+    It is based on PartialModel.py.
+
+    """
+
+    def __init__(
+        self,
+        rank,
+        machine_id,
+        communication,
+        mapping,
+        graph,
+        model,
+        dataset,
+        log_dir,
+        alpha=1.0,
+        dict_ordered=True,
+        save_shared=False,
+        metadata_cap=1.0,
+        wavelet="haar",
+        level=4,
+        change_based_selection=True,
+        save_accumulated="",
+        accumulation=False,
+        accumulate_averaging_changes=False,
+    ):
+        """
+        Constructor
+
+        Parameters
+        ----------
+        rank : int
+            Local rank
+        machine_id : int
+            Global machine id
+        communication : decentralizepy.communication.Communication
+            Communication module used to send and receive messages
+        mapping : decentralizepy.mappings.Mapping
+            Mapping (rank, machine_id) -> uid
+        graph : decentralizepy.graphs.Graph
+            Graph representing neighbors
+        model : decentralizepy.models.Model
+            Model to train
+        dataset : decentralizepy.datasets.Dataset
+            Dataset for sharing data. Not implemented yet! TODO
+        log_dir : str
+            Location to write shared_params (only writing for 2 procs per machine)
+        alpha : float
+            Percentage of model to share
+        dict_ordered : bool
+            Specifies if the python dict maintains the order of insertion
+        save_shared : bool
+            Specifies if the indices of shared parameters should be logged
+        metadata_cap : float
+            Share full model when self.alpha > metadata_cap
+        wavelet : str
+            Name of the wavelet to be used in gradient compression
+        level : int
+            Level of the wavelet decomposition
+        change_based_selection : bool
+            Use the frequency-domain change to select the topk frequencies
+        save_accumulated : bool
+            True if the accumulated weight change in the wavelet domain should be written to file. In case of accumulation
+            the accumulated change is stored.
+        accumulation : bool
+            True if the indices to share should be selected based on the accumulated frequency change
+        accumulate_averaging_changes : bool
+            True if the accumulation should account for the model change due to averaging
+        """
+        self.wavelet = wavelet
+        self.level = level
+
+        super().__init__(
+            rank,
+            machine_id,
+            communication,
+            mapping,
+            graph,
+            model,
+            dataset,
+            log_dir,
+            alpha,
+            dict_ordered,
+            save_shared,
+            metadata_cap,
+            accumulation,
+            save_accumulated,
+            lambda x: change_transformer_wavelet(x, wavelet, level),
+            accumulate_averaging_changes,
+        )
+
+        self.change_based_selection = change_based_selection
+
+        # Do a dummy transform to get the shape and coefficient slices
+        coeff = pywt.wavedec(self.init_model.numpy(), self.wavelet, level=self.level)
+        data, coeff_slices = pywt.coeffs_to_array(coeff)
+        self.wt_shape = data.shape
+        self.coeff_slices = coeff_slices
+
+    def apply_wavelet(self):
+        """
+        Does a wavelet transformation of the model parameters and selects the topK (alpha) of them in the frequency domain,
+        based on the change undergone during the current training step
+
+        Returns
+        -------
+        tuple
+            (a,b). a: The selected wavelet coefficients, b: Their indices.
+
+        """
+
+        logging.info("Returning wavelet compressed model weights")
+        tensors_to_cat = [v.data.flatten() for _, v in self.model.state_dict().items()]
+        concated = torch.cat(tensors_to_cat, dim=0)
+        data = self.change_transformer(concated)
+        if self.change_based_selection:
+            diff = self.model.model_change
+            _, index = torch.topk(
+                diff.abs(),
+                round(self.alpha * len(diff)),
+                dim=0,
+                sorted=False,
+            )
+        else:
+            _, index = torch.topk(
+                data.abs(),
+                round(self.alpha * len(data)),
+                dim=0,
+                sorted=False,
+            )
+
+        return data[index], index
+
+    def serialized_model(self):
+        """
+        Convert the model to a json dict. self.alpha specifies the fraction of the model to send.
+
+        Returns
+        -------
+        dict
+            Model converted to json dict
+
+        """
+        if self.alpha > self.metadata_cap:  # Share fully
+            return super().serialized_model()
+
+        with torch.no_grad():
+            topk, indices = self.apply_wavelet()
+
+            self.model.rewind_accumulation(indices)
+            if self.save_shared:
+                shared_params = dict()
+                shared_params["order"] = list(self.model.state_dict().keys())
+                shapes = dict()
+                for k, v in self.model.state_dict().items():
+                    shapes[k] = list(v.shape)
+                shared_params["shapes"] = shapes
+
+                shared_params[self.communication_round] = indices.tolist()  # is slow
+
+                shared_params["alpha"] = self.alpha
+
+                with open(
+                    os.path.join(
+                        self.folder_path,
+                        "{}_shared_params.json".format(self.communication_round + 1),
+                    ),
+                    "w",
+                ) as of:
+                    json.dump(shared_params, of)
+
+            m = dict()
+
+            if not self.dict_ordered:
+                raise NotImplementedError
+
+            m["alpha"] = self.alpha
+
+            m["params"] = topk.numpy()
+
+            m["indices"] = indices.numpy().astype(np.int32)
+
+            self.total_data += len(self.communication.encrypt(m["params"]))
+            self.total_meta += len(self.communication.encrypt(m["indices"])) + len(
+                self.communication.encrypt(m["alpha"])
+            )
+
+            return m
+
+    def deserialized_model(self, m):
+        """
+        Convert received dict to state_dict.
+ + Parameters + ---------- + m : dict + received dict + + Returns + ------- + state_dict + state_dict of received + + """ + if self.alpha > self.metadata_cap: # Share fully + return super().deserialized_model(m) + + with torch.no_grad(): + if not self.dict_ordered: + raise NotImplementedError + + indices = m["indices"] + alpha = m["alpha"] + params = m["params"] + + params_tensor = torch.tensor(params) + indices_tensor = torch.tensor(indices, dtype=torch.long) + ret = dict() + ret["indices"] = indices_tensor + ret["params"] = params_tensor + return ret + + def _averaging(self): + """ + Averages the received model with the local model + + """ + with torch.no_grad(): + total = None + weight_total = 0 + tensors_to_cat = [ + v.data.flatten() for _, v in self.model.state_dict().items() + ] + pre_share_model = torch.cat(tensors_to_cat, dim=0) + wt_params = self.change_transformer(pre_share_model) + for i, n in enumerate(self.peer_deques): + degree, iteration, data = self.peer_deques[n].popleft() + logging.debug( + "Averaging model from neighbor {} of iteration {}".format( + n, iteration + ) + ) + data = self.deserialized_model(data) + params = data["params"] + indices = data["indices"] + # use local data to complement + topkwf = wt_params.clone().detach() + topkwf[indices] = params + topkwf = topkwf.reshape(self.wt_shape) + + weight = 1 / (max(len(self.peer_deques), degree) + 1) # Metro-Hastings + weight_total += weight + if total is None: + total = weight * topkwf + else: + total += weight * topkwf + + # Metro-Hastings + total += (1 - weight_total) * wt_params + + avg_wf_params = pywt.array_to_coeffs( + total.numpy(), self.coeff_slices, output_format="wavedec" + ) + reverse_total = torch.from_numpy( + pywt.waverec(avg_wf_params, wavelet=self.wavelet) + ) + + start_index = 0 + std_dict = {} + for i, key in enumerate(self.model.state_dict()): + end_index = start_index + self.lens[i] + std_dict[key] = reverse_total[start_index:end_index].reshape( + self.shapes[i] + ) + start_index = end_index + + self.model.load_state_dict(std_dict) diff --git a/src/decentralizepy/training/ChangeAccumulator.py b/src/decentralizepy/training/ChangeAccumulator.py deleted file mode 100644 index 6ee5dc79f1e32aa36cc09ab06f1309bc28fd2f32..0000000000000000000000000000000000000000 --- a/src/decentralizepy/training/ChangeAccumulator.py +++ /dev/null @@ -1,167 +0,0 @@ -import json -import os -from pathlib import Path - -import torch - -from decentralizepy.training.Training import Training -from decentralizepy.utils import conditional_value - - -class ChangeAccumulator(Training): - """ - This class implements the training module which also accumulates model change in a list. - - """ - - def __init__( - self, - rank, - machine_id, - mapping, - model, - optimizer, - loss, - log_dir, - rounds="", - full_epochs="", - batch_size="", - shuffle="", - save_accumulated="", - ): - """ - Constructor - - Parameters - ---------- - rank : int - Rank of process local to the machine - machine_id : int - Machine ID on which the process in running - mapping : decentralizepy.mappings - The object containing the mapping rank <--> uid - model : torch.nn.Module - Neural Network for training - optimizer : torch.optim - Optimizer to learn parameters - loss : function - Loss function - log_dir : str - Directory to log the model change. - rounds : int, optional - Number of steps/epochs per training call - full_epochs: bool, optional - True if 1 round = 1 epoch. 
False if 1 round = 1 minibatch - batch_size : int, optional - Number of items to learn over, in one batch - shuffle : bool - True if the dataset should be shuffled before training. - save_accumulated : bool - True if accumulated weight change should be written to file - - """ - super().__init__( - rank, - machine_id, - mapping, - model, - optimizer, - loss, - log_dir, - rounds, - full_epochs, - batch_size, - shuffle, - ) - self.save_accumulated = conditional_value(save_accumulated, "", False) - self.communication_round = 0 - if self.save_accumulated: - self.model_change_path = os.path.join( - self.log_dir, "model_change/{}".format(self.rank) - ) - Path(self.model_change_path).mkdir(parents=True, exist_ok=True) - - self.model_val_path = os.path.join( - self.log_dir, "model_val/{}".format(self.rank) - ) - Path(self.model_val_path).mkdir(parents=True, exist_ok=True) - - def save_vector(self, v, s): - """ - Saves the given vector to the file. - - Parameters - ---------- - v : torch.tensor - The torch tensor to write to file - s : str - Path to folder to write to - - """ - output_dict = dict() - output_dict["order"] = list(self.model.state_dict().keys()) - shapes = dict() - for k, v1 in self.model.state_dict().items(): - shapes[k] = list(v1.shape) - output_dict["shapes"] = shapes - - output_dict["tensor"] = v.tolist() - - with open( - os.path.join( - s, - "{}.json".format(self.communication_round + 1), - ), - "w", - ) as of: - json.dump(output_dict, of) - - def save_change(self): - """ - Saves the change and the gradient values for every iteration - - """ - tensors_to_cat = [ - v.data.flatten() for _, v in self.model.accumulated_gradients[0].items() - ] - change = torch.abs(torch.cat(tensors_to_cat, dim=0)) - self.save_vector(change, self.model_change_path) - - def save_model_params(self): - """ - Saves the change and the gradient values for every iteration - - """ - tensors_to_cat = [v.data.flatten() for _, v in self.model.state_dict().items()] - params = torch.abs(torch.cat(tensors_to_cat, dim=0)) - self.save_vector(params, self.model_val_path) - - def train(self, dataset): - """ - One training iteration with accumulation of model change in model.accumulated_gradients. - Goes through the entire dataset. - - Parameters - ---------- - dataset : decentralizepy.datasets.Dataset - The training dataset. Should implement get_trainset(batch_size, shuffle) - - """ - self.model.accumulated_gradients = [] - self.init_model = { - k: v.data.clone().detach() - for k, v in zip(self.model.state_dict(), self.model.parameters()) - } - super().train(dataset) - with torch.no_grad(): - change = { - k: v.data.clone().detach() - self.init_model[k] - for k, v in zip(self.model.state_dict(), self.model.parameters()) - } - self.model.accumulated_gradients.append(change) - - if self.save_accumulated: - self.save_change() - self.save_model_params() - - self.communication_round += 1 diff --git a/src/decentralizepy/training/GradientAccumulator.py b/src/decentralizepy/training/GradientAccumulator.py deleted file mode 100644 index fcff8e6ec56e673c10098d32d973370251740913..0000000000000000000000000000000000000000 --- a/src/decentralizepy/training/GradientAccumulator.py +++ /dev/null @@ -1,112 +0,0 @@ -import logging - -from decentralizepy.training.Training import Training - - -class GradientAccumulator(Training): - """ - This class implements the training module which also accumulates gradients of steps in a list. 
- - """ - - def __init__( - self, - rank, - machine_id, - mapping, - model, - optimizer, - loss, - log_dir, - rounds="", - full_epochs="", - batch_size="", - shuffle="", - ): - """ - Constructor - - Parameters - ---------- - rank : int - Rank of process local to the machine - machine_id : int - Machine ID on which the process in running - mapping : decentralizepy.mappings - The object containing the mapping rank <--> uid - model : torch.nn.Module - Neural Network for training - optimizer : torch.optim - Optimizer to learn parameters - loss : function - Loss function - log_dir : str - Directory to log the model change. - rounds : int, optional - Number of steps/epochs per training call - full_epochs: bool, optional - True if 1 round = 1 epoch. False if 1 round = 1 minibatch - batch_size : int, optional - Number of items to learn over, in one batch - shuffle : bool - True if the dataset should be shuffled before training. - - """ - super().__init__( - rank, - machine_id, - mapping, - model, - optimizer, - loss, - log_dir, - rounds, - full_epochs, - batch_size, - shuffle, - ) - - def trainstep(self, data, target): - """ - One training step on a minibatch. - - Parameters - ---------- - data : any - Data item - target : any - Label - - Returns - ------- - int - Loss Value for the step - - """ - self.model.zero_grad() - output = self.model(data) - loss_val = self.loss(output, target) - loss_val.backward() - logging.debug("Accumulating Gradients") - self.model.accumulated_gradients.append( - { - k: v.grad.clone().detach() - for k, v in zip(self.model.state_dict(), self.model.parameters()) - } - ) - self.optimizer.step() - return loss_val.item() - - def train(self, dataset): - """ - One training iteration with accumulation of gradients in model.accumulated_gradients. - Goes through the entire dataset. - - Parameters - ---------- - dataset : decentralizepy.datasets.Dataset - The training dataset. Should implement get_trainset(batch_size, shuffle) - - """ - self.model.accumulated_gradients = [] - super().train(dataset) diff --git a/src/decentralizepy/training/Training.py b/src/decentralizepy/training/Training.py index 3b99befa457b31580880b3113bee76ba2e4e3feb..5adc4a964a80000e48bf1c3c6991ebbfc3f98499 100644 --- a/src/decentralizepy/training/Training.py +++ b/src/decentralizepy/training/Training.py @@ -46,7 +46,7 @@ class Training: Directory to log the model change. rounds : int, optional Number of steps/epochs per training call - full_epochs: bool, optional + full_epochs : bool, optional True if 1 round = 1 epoch. 
             False if 1 round = 1 minibatch
         batch_size : int, optional
             Number of items to learn over, in one batch
diff --git a/src/decentralizepy/utils.py b/src/decentralizepy/utils.py
index 996e4bc4316f3590da1dcfcbc287df81d52fa9b7..3ca85f529eff9a531fe8e29695da733e4c92c11c 100644
--- a/src/decentralizepy/utils.py
+++ b/src/decentralizepy/utils.py
@@ -75,6 +75,7 @@ def get_args():
     parser.add_argument("-gf", "--graph_file", type=str, default="36_nodes.edges")
     parser.add_argument("-gt", "--graph_type", type=str, default="edges")
     parser.add_argument("-ta", "--test_after", type=int, default=5)
+    parser.add_argument("-tea", "--train_evaluate_after", type=int, default=1)
     parser.add_argument("-ro", "--reset_optimizer", type=int, default=1)

     args = parser.parse_args()
@@ -108,3 +109,21 @@ def write_args(args, path):
     }
     with open(os.path.join(path, "args.json"), "w") as of:
         json.dump(data, of)
+
+
+def identity(obj):
+    """
+    Identity function
+
+    Parameters
+    ----------
+    obj
+        Some object
+
+    Returns
+    -------
+    obj
+        The same object
+
+    """
+    return obj
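
For context, `identity` matches the call signature of `change_transformer_wavelet` once the wavelet name and level are bound, so it presumably serves as the default `change_transformer` for plain, parameter-domain sharing. A sketch under that assumption; the wiring below is illustrative, not taken from the repository:

    import torch

    # Illustration only, not part of the patch.
    def identity(obj):
        # Same as the helper added to utils.py: a no-op change transformer.
        return obj

    pre_share_model = torch.randn(16)
    init_model = torch.zeros(16)

    change_transformer = identity  # or: lambda x: change_transformer_wavelet(x, "haar", 4)
    change = change_transformer(pre_share_model - init_model)
    assert change.shape == pre_share_model.shape
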