diff --git a/eval/main.ipynb b/eval/main.ipynb
index 80daae672a600f6358cfa9f559b50ecb47cf2261..0873005aa5e53dd1cf5abc0ffe0d38eba3127093 100644
--- a/eval/main.ipynb
+++ b/eval/main.ipynb
@@ -5709,6 +5709,41 @@
     "print(i)"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from torch import multiprocessing as mp\n",
+    "from decentralizepy.node.PeerSampler import PeerSampler\n",
+    "from decentralizepy.node.Node import Node\n",
+    "from decentralizepy.mappings.Linear import Linear\n",
+    "from decentralizepy.graphs.Regular import Regular\n",
+    "\n",
+    "l = Linear(1, 6)\n",
+    "g = Regular(6, 2)\n",
+    "processes = [mp.Process(target = PeerSampler, args=[-1, 0, l, g, None]),\n",
+    "             mp.Process(target = Node, args=[1, 0, l, g, None]),\n",
+    "             mp.Process(target = Node, args=[2, 0, l, g, None]),\n",
+    "             mp.Process(target = Node, args=[3, 0, l, g, None]),\n",
+    "             mp.Process(target = Node, args=[4, 0, l, g, None]),\n",
+    "             mp.Process(target = Node, args=[5, 0, l, g, None]),\n",
+    "             mp.Process(target = Node, args=[6, 0, l, g, None]),\n",
+    "            ]\n",
+    "\n",
+    "for p in processes:\n",
+    "    p.start()\n",
+    "\n",
+    "for p in processes:\n",
+    "    p.join()\n",
+    "\n",
+    "\n",
+    "\n",
+    "\n",
+    "\n"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -5718,11 +5753,9 @@
   }
  ],
  "metadata": {
-  "interpreter": {
-   "hash": "996934296aa9d79be6c3d800a38d8fdb7dfa8fe7bb07df178f1397cde2cb8742"
-  },
   "kernelspec": {
-   "display_name": "Python 3.9.7 64-bit ('tff': conda)",
+   "display_name": "Python 3.9.7 ('decpy')",
+   "language": "python",
    "name": "python3"
   },
   "language_info": {
@@ -5735,9 +5768,14 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.8.11"
+   "version": "3.9.12"
   },
-  "orig_nbformat": 4
+  "orig_nbformat": 4,
+  "vscode": {
+   "interpreter": {
+    "hash": "70be49349d3cda3718db277e01495433e35b5db6f514174958763e3b43682235"
+   }
+  }
  },
  "nbformat": 4,
  "nbformat_minor": 2
diff --git a/eval/plot.py b/eval/plot.py
index 601d8e86bfda57e69a0f88fa305a6a7fa674e23a..62ac2302ff53dd2f2f7093e3fdc5e22fc9ec2cab 100644
--- a/eval/plot.py
+++ b/eval/plot.py
@@ -29,16 +29,17 @@ def get_stats(l):
 def plot(means, stdevs, mins, maxs, title, label, loc):
     plt.title(title)
     plt.xlabel("communication rounds")
-    x_axis = list(means.keys())
-    y_axis = list(means.values())
-    err = list(stdevs.values())
-    plt.errorbar(x_axis, y_axis, yerr=err, label=label)
+    x_axis = np.array(list(means.keys()))
+    y_axis = np.array(list(means.values()))
+    err = np.array(list(stdevs.values()))
+    plt.plot(x_axis, y_axis, label=label)
+    plt.fill_between(x_axis, y_axis - err, y_axis + err, alpha=0.4)
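+    # shaded band spans mean ± one standard deviation, replacing the previous discrete error bars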
     plt.legend(loc=loc)
 
 
 def plot_results(path, centralized, data_machine="machine0", data_node=0):
     folders = os.listdir(path)
-    if centralized.lower() in ['true', '1', 't', 'y', 'yes']:
+    if centralized.lower() in ["true", "1", "t", "y", "yes"]:
         centralized = True
         print("Centralized")
     else:
@@ -66,7 +67,10 @@ def plot_results(path, centralized, data_machine="machine0", data_node=0):
                 filepath = os.path.join(mf_path, f)
                 with open(filepath, "r") as inf:
                     results.append(json.load(inf))
-
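+        # FL / parameter-server runs are assumed to store the aggregated results under the server's id (-1)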
+        if folder.startswith("FL") or folder.startswith("Parameter Server"):
+            data_node = -1
+        else:
+            data_node = 0
         with open(folder_path / data_machine / f"{data_node}_results.json", "r") as f:
             main_data = json.load(f)
         main_data = [main_data]
@@ -124,23 +128,7 @@ def plot_results(path, centralized, data_machine="machine0", data_node=0):
         df.to_csv(
             os.path.join(path, "test_acc_" + folder + ".csv"), index_label="rounds"
         )
-        plt.figure(6)
-        means, stdevs, mins, maxs = get_stats([x["grad_std"] for x in results])
-        plot(
-            means,
-            stdevs,
-            mins,
-            maxs,
-            "Gradient Variation over Nodes",
-            folder,
-            "upper right",
-        )
-        # Plot Testing loss
-        plt.figure(7)
-        means, stdevs, mins, maxs = get_stats([x["grad_mean"] for x in results])
-        plot(
-            means, stdevs, mins, maxs, "Gradient Magnitude Mean", folder, "upper right"
-        )
+
         # Collect total_bytes shared
         bytes_list = []
         for x in results:
@@ -149,6 +137,7 @@ def plot_results(path, centralized, data_machine="machine0", data_node=0):
         means, stdevs, mins, maxs = get_stats(bytes_list)
         bytes_means[folder] = list(means.values())[0]
         bytes_stdevs[folder] = list(stdevs.values())[0]
+        print(bytes_list)
 
         meta_list = []
         for x in results:
@@ -175,10 +164,6 @@ def plot_results(path, centralized, data_machine="machine0", data_node=0):
     plt.savefig(os.path.join(path, "test_loss.png"), dpi=300)
     plt.figure(3)
     plt.savefig(os.path.join(path, "test_acc.png"), dpi=300)
-    plt.figure(6)
-    plt.savefig(os.path.join(path, "grad_std.png"), dpi=300)
-    plt.figure(7)
-    plt.savefig(os.path.join(path, "grad_mean.png"), dpi=300)
     # Plot total_bytes
     plt.figure(4)
     plt.title("Data Shared")
@@ -257,5 +242,6 @@ if __name__ == "__main__":
     # The args are:
     # 1: the folder with the data
     # 2: True/False: If True then the evaluation on the test set was centralized
+    # For federated learning, the folder name must start with "FL"!
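+    # e.g. (hypothetical invocation): python plot.py ./results/experiment1 False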
     plot_results(sys.argv[1], sys.argv[2])
     # plot_parameters(sys.argv[1])
diff --git a/eval/plotting_from_csv.py b/eval/plotting_from_csv.py
index b8d4320c507d1b7532f3fef1fba5f2d02c8d926a..dbd1c1a49c7aeeb6fcc6f2378b75e636cd8c9628 100644
--- a/eval/plotting_from_csv.py
+++ b/eval/plotting_from_csv.py
@@ -23,7 +23,7 @@ def plot(x_axis, means, stdevs, pos, nb_plots, title, label, loc, xlabel):
 
 
 def plot_results(path, epochs, global_epochs="True"):
-    if global_epochs.lower() in ['true', '1', 't', 'y', 'yes']:
+    if global_epochs.lower() in ["true", "1", "t", "y", "yes"]:
         global_epochs = True
     else:
         global_epochs = False
@@ -52,10 +52,12 @@ def plot_results(path, epochs, global_epochs="True"):
         if global_epochs:
             rounds = results_csv["rounds"].iloc[0]
             print("Rounds: ", rounds)
-            results_cr = results_csv[results_csv.rounds <= epochs*rounds]
+            results_cr = results_csv[results_csv.rounds <= epochs * rounds]
             means = results_cr["mean"].to_numpy()
             stdevs = results_cr["std"].to_numpy()
-            x_axis = results_cr["rounds"].to_numpy() / rounds # list(np.arange(0, len(means), 1))
+            x_axis = (
+                results_cr["rounds"].to_numpy() / rounds
+            )  # list(np.arange(0, len(means), 1))
             x_label = "global epochs"
         else:
             results_cr = results_csv[results_csv.rounds <= epochs]
@@ -85,10 +87,12 @@ def plot_results(path, epochs, global_epochs="True"):
         if global_epochs:
             rounds = results_csv["rounds"].iloc[0]
             print("Rounds: ", rounds)
-            results_cr = results_csv[results_csv.rounds <= epochs*rounds]
+            results_cr = results_csv[results_csv.rounds <= epochs * rounds]
             means = results_cr["mean"].to_numpy()
             stdevs = results_cr["std"].to_numpy()
-            x_axis = results_cr["rounds"].to_numpy() / rounds # list(np.arange(0, len(means), 1))
+            x_axis = (
+                results_cr["rounds"].to_numpy() / rounds
+            )  # list(np.arange(0, len(means), 1))
             x_label = "global epochs"
         else:
             results_cr = results_csv[results_csv.rounds <= epochs]
@@ -120,10 +124,12 @@ def plot_results(path, epochs, global_epochs="True"):
         if global_epochs:
             rounds = results_csv["rounds"].iloc[0]
             print("Rounds: ", rounds)
-            results_cr = results_csv[results_csv.rounds <= epochs*rounds]
+            results_cr = results_csv[results_csv.rounds <= epochs * rounds]
             means = results_cr["mean"].to_numpy()
             stdevs = results_cr["std"].to_numpy()
-            x_axis = results_cr["rounds"].to_numpy() / rounds # list(np.arange(0, len(means), 1))
+            x_axis = (
+                results_cr["rounds"].to_numpy() / rounds
+            )  # list(np.arange(0, len(means), 1))
             x_label = "global epochs"
         else:
             results_cr = results_csv[results_csv.rounds <= epochs]
diff --git a/eval/run.sh b/eval/run.sh
index ff6392f81283aa075f7db62ba6656c30d714455e..9d71d96167a2c303a0ccd4957e65508c68f4681b 100755
--- a/eval/run.sh
+++ b/eval/run.sh
@@ -1,17 +1,17 @@
 #!/bin/bash
 
-decpy_path=/mnt/nfs/risharma/Gitlab/decentralizepy/eval
+decpy_path=/mnt/nfs/kirsten/Gitlab/jac_decentralizepy/decentralizepy/eval
 cd $decpy_path
 
 env_python=~/miniconda3/envs/decpy/bin/python3
-graph=/mnt/nfs/risharma/Gitlab/tutorial/regular_16.txt
-original_config=/mnt/nfs/risharma/Gitlab/tutorial/config_celeba_sharing.ini
-config_file=~/tmp/config.ini
+graph=/mnt/nfs/kirsten/Gitlab/tutorial/regular_16.txt
+original_config=/mnt/nfs/kirsten/Gitlab/tutorial/config_celeba_sharing.ini
+config_file=~/tmp/config_celeba_sharing.ini
 procs_per_machine=8
 machines=2
 iterations=5
 test_after=2
-eval_file=testing.py
+eval_file=testingPeerSampler.py
 log_level=INFO
 
 m=`cat $(grep addresses_filepath $original_config | awk '{print $3}') | grep $(/sbin/ifconfig ens785 | grep 'inet ' | awk '{print $2}') | cut -d'"' -f2`
diff --git a/eval/run_xtimes.sh b/eval/run_xtimes.sh
index 3176f5357cc53ae6d251cb8136821e09caa4cbf4..4ca373a942b989d046933610e385d3ab95d40c81 100755
--- a/eval/run_xtimes.sh
+++ b/eval/run_xtimes.sh
@@ -1,5 +1,6 @@
 #!/bin/bash
 # Documentation
+# Note: the documentation below was not written for this run file, so the actual behaviour may differ
 # This bash file takes three inputs. The first argument (nfs_home) is the path to the nfs home directory.
 # The second one (python_bin) is the path to the python bin folder.
 # The last argument (logs_subfolder) is the path to the logs folder with respect to the nfs home directory.
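+# Example invocation (hypothetical paths):
+#   ./run_xtimes.sh /mnt/nfs/<user> ~/miniconda3/envs/decpy/bin logs/test_run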
@@ -18,8 +19,10 @@
 # Each node needs a folder called 'tmp' in the user's home directory
 #
 # Note:
-# - The script does not change the optimizer. All configs are writen to use SGD.
-# - The script will set '--test_after' and '--train_evaluate_after' such that it happens at the end of a global epoch.
+# - The script does not change the optimizer. All configs are written to use Adam.
+#   For SGD these need to be changed manually.
+# - The script will set '--test_after' and '--train_evaluate_after' to comm_rounds_per_global_epoch, i.e., the evaluation
+#   on the train set and on the test set is carried out every global epoch.
 # - The '--reset_optimizer' option is set to 0, i.e., the optimizer is not reset after a communication round (only
 #   relevant for Adam and other optimizers with internal state)
 #
@@ -37,41 +40,40 @@ decpy_path=$nfs_home/decentralizepy/eval
 cd $decpy_path
 
 env_python=$python_bin/python3
-graph=96_regular.edges
 config_file=~/tmp/config.ini
 procs_per_machine=16
 machines=6
 global_epochs=150
-eval_file=testing.py
+eval_file=testingPeerSampler.py
 log_level=INFO
-
-ip_machines=$nfs_home/configs/ip_addr_6Machines.json
-
+ip_machines=$nfs_home/$logs_subfolder/ip_addr_6Machines.json
 m=`cat $ip_machines | grep $(/sbin/ifconfig ens785 | grep 'inet ' | awk '{print $2}') | cut -d'"' -f2`
 export PYTHONFAULTHANDLER=1
-
 # Base configs for which the grid search is done
-tests=("step_configs/config_femnist_partialmodel.ini" "step_configs/config_femnist_topkacc.ini" "step_configs/config_femnist_wavelet.ini")
+tests="$nfs_home/$logs_subfolder/config.ini"
+#tests=("$nfs_home/$logs_subfolder/config_cifar_sharing.ini" "$nfs_home/$logs_subfolder/config_cifar_partialmodel.ini" "$nfs_home/$logs_subfolder/config_cifar_topkacc.ini" "$nfs_home/$logs_subfolder/config_cifar_topkaccRandomAlpha.ini" "$nfs_home/$logs_subfolder/config_cifar_subsampling.ini" "$nfs_home/$logs_subfolder/config_cifar_wavelet.ini" "$nfs_home/$logs_subfolder/config_cifar_waveletRandomAlpha.ini")
+#tests=("$nfs_home/$logs_subfolder/config_cifar_partialmodel.ini" "$nfs_home/$logs_subfolder/config_cifar_topkacc.ini" "$nfs_home/$logs_subfolder/config_cifar_topkaccRandomAlpha.ini" "$nfs_home/$logs_subfolder/config_cifar_subsampling.ini" "$nfs_home/$logs_subfolder/config_cifar_wavelet.ini" "$nfs_home/$logs_subfolder/config_cifar_waveletRandomAlpha.ini")
+#tests=("$nfs_home/$logs_subfolder/config_cifar_subsampling.ini" "$nfs_home/$logs_subfolder/config_cifar_sharing.ini" "$nfs_home/$logs_subfolder/config_cifar_waveletRandomAlpha.ini")
+#tests=("$nfs_home/$logs_subfolder/config_cifar_waveletRandomAlpha.ini")
 # Learning rates
-lr="0.001"
+lr="0.01"
 # Batch size
-batchsize="16"
+batchsize="8"
 # The number of communication rounds per global epoch
-comm_rounds_per_global_epoch="1"
+comm_rounds_per_global_epoch="20"
 procs=`expr $procs_per_machine \* $machines`
 echo procs: $procs
 # Celeba has 63741 samples
 # Reddit has 70642
 # Femnist 734463
-# Shakespeares 3678451, subsampled 678696
-# cifar 50000
-dataset_size=734463
+# Shakespeare 3678451
+dataset_size=50000
 # Calculating the number of samples that each user/proc will have on average
 samples_per_user=`expr $dataset_size / $procs`
 echo samples per user: $samples_per_user
-
 # random_seeds for which to rerun the experiments
-random_seeds=("97")
+# random_seeds=("90" "91" "92" "93" "94")
+random_seeds=("94")
 # random_seed = 97
 echo batchsize: $batchsize
 echo communication rounds per global epoch: $comm_rounds_per_global_epoch
@@ -85,10 +87,10 @@ echo iterations: $iterations
 batches_per_comm_round=$($env_python -c "from math import floor; x = floor($batches_per_epoch / $comm_rounds_per_global_epoch); print(1 if x==0 else x)")
 # since the batches per communication round were rounded down we need to change the number of iterations to reflect that
 new_iterations=$($env_python -c "from math import floor; tmp = floor($batches_per_epoch / $comm_rounds_per_global_epoch); x = 1 if tmp == 0 else tmp; y = floor((($batches_per_epoch / $comm_rounds_per_global_epoch)/x)*$iterations); print($iterations if y<$iterations else y)")
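+# hypothetical example: with 65 batches per epoch and 20 communication rounds per global epoch,
+# batches_per_comm_round = floor(65/20) = 3 and iterations are scaled by (65/20)/3 ≈ 1.08 to compensate for the rounding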
-echo batches per communication round: $batches_per_comm_round
-echo corrected iterations: $new_iterations
 test_after=$(($new_iterations / $global_epochs))
 echo test after: $test_after
+echo batches per communication round: $batches_per_comm_round
+echo corrected iterations: $new_iterations
 for i in "${tests[@]}"
 do
   for seed in "${random_seeds[@]}"
@@ -96,9 +98,14 @@ do
     echo $i
     IFS='_' read -ra NAMES <<< $i
     IFS='.' read -ra NAME <<< ${NAMES[-1]}
-    log_dir=$nfs_home$logs_subfolder/${NAME[0]}:lr=$lr:r=$comm_rounds_per_global_epoch:b=$batchsize:$(date '+%Y-%m-%dT%H:%M')/machine$m
-    echo results are stored in: $log_dir
+    #log_dir_base=$nfs_home$logs_subfolder/${NAME[0]}:lr=$lr:r=$comm_rounds_per_global_epoch:b=$batchsize:$(date '+%Y-%m-%dT%H:%M')
+    log_dir_base=$nfs_home/$logs_subfolder/lr=$lr:r=$comm_rounds_per_global_epoch:b=$batchsize:$(date '+%Y-%m-%dT%H:%M')
+    echo results are stored in: $log_dir_base
+    log_dir=$log_dir_base/machine$m
     mkdir -p $log_dir
+    weight_store_dir=$log_dir_base/weights
+    mkdir -p $weight_store_dir
+    graph=$nfs_home/decentralizepy/eval/96_regular.edges
     cp $i $config_file
     # changing the config files to reflect the values of the current grid search state
     $python_bin/crudini --set $config_file COMMUNICATION addresses_filepath $ip_machines
@@ -106,10 +113,14 @@ do
     $python_bin/crudini --set $config_file TRAIN_PARAMS rounds $batches_per_comm_round
     $python_bin/crudini --set $config_file TRAIN_PARAMS batch_size $batchsize
     $python_bin/crudini --set $config_file DATASET random_seed $seed
+    $python_bin/crudini --set $config_file COMMUNICATION addresses_filepath $ip_machines
+    $python_bin/crudini --set $config_file COMMUNICATION offset 10720
+    # $env_python $eval_file -cte 0 -ro 0 -tea $test_after -ld $log_dir -wsd $weight_store_dir -mid $m -ps $procs_per_machine -ms $machines -is $new_iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level
     $env_python $eval_file -ro 0 -tea $test_after -ld $log_dir -mid $m -ps $procs_per_machine -ms $machines -is $new_iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level
+
     echo $i is done
     sleep 200
     echo end of sleep
     done
 done
-#
\ No newline at end of file
+#
diff --git a/eval/run_xtimes_celeba.sh b/eval/run_xtimes_celeba.sh
index c6d3e060c4fd60115e40e8b99bb39560b4cd1fe8..7eab580ffeda45780772b2f287b4e0efd6e3f588 100755
--- a/eval/run_xtimes_celeba.sh
+++ b/eval/run_xtimes_celeba.sh
@@ -38,11 +38,11 @@ cd $decpy_path
 
 env_python=$python_bin/python3
 graph=96_regular.edges
-config_file=~/tmp/config.ini
+config_file=~/tmp/config_celeba_sharing.ini
 procs_per_machine=16
 machines=6
 global_epochs=150
-eval_file=testing.py
+eval_file=testingFederated.py
 log_level=INFO
 
 ip_machines=$nfs_home/configs/ip_addr_6Machines.json
@@ -51,7 +51,8 @@ m=`cat $ip_machines | grep $(/sbin/ifconfig ens785 | grep 'inet ' | awk '{print
 export PYTHONFAULTHANDLER=1
 
 # Base configs for which the grid search is done
-tests=("step_configs/config_celeba_sharing.ini" "step_configs/config_celeba_partialmodel.ini" "step_configs/config_celeba_topkacc.ini" "step_configs/config_celeba_subsampling.ini" "step_configs/config_celeba_wavelet.ini")
+#tests=("step_configs/config_celeba_sharing.ini" "step_configs/config_celeba_partialmodel.ini" "step_configs/config_celeba_topkacc.ini" "step_configs/config_celeba_subsampling.ini" "step_configs/config_celeba_wavelet.ini")
+tests=("step_configs/config_celeba_sharing.ini")
 # Learning rates
 lr="0.001"
 # Batch size
@@ -66,7 +67,8 @@ samples_per_user=`expr $dataset_size / $procs`
 echo samples per user: $samples_per_user
 
 # random_seeds for which to rerun the experiments
-random_seeds=("90" "91" "92" "93" "94")
+#random_seeds=("90" "91" "92" "93" "94")
+random_seeds=("90")
 # random_seed = 97
 echo batchsize: $batchsize
 echo communication rounds per global epoch: $comm_rounds_per_global_epoch
@@ -91,7 +93,7 @@ do
     echo $i
     IFS='_' read -ra NAMES <<< $i
     IFS='.' read -ra NAME <<< ${NAMES[-1]}
-    log_dir_base=$nfs_home$logs_subfolder/${NAME[0]}:lr=$lr:r=$comm_rounds_per_global_epoch:b=$batchsize:$(date '+%Y-%m-%dT%H:%M')
+    log_dir_base=$nfs_home/$logs_subfolder/${NAME[0]}:lr=$lr:r=$comm_rounds_per_global_epoch:b=$batchsize:$(date '+%Y-%m-%dT%H:%M')
     echo results are stored in: $log_dir_base
     log_dir=$log_dir_base/machine$m
     mkdir -p $log_dir
diff --git a/eval/run_xtimes_cifar.sh b/eval/run_xtimes_cifar.sh
index 0c545c651e5150f42332ebf81cc2baae4c0f5ef6..a5b9eca69f34bd4d621e965aa03fcc801a6f368a 100755
--- a/eval/run_xtimes_cifar.sh
+++ b/eval/run_xtimes_cifar.sh
@@ -41,9 +41,10 @@ graph=96_regular.edges
 config_file=~/tmp/config.ini
 procs_per_machine=16
 machines=6
-global_epochs=300
-eval_file=testing.py
+global_epochs=100
+eval_file=testingFederated.py
 log_level=INFO
+working_rate=0.1
 
 ip_machines=$nfs_home/configs/ip_addr_6Machines.json
 
@@ -51,7 +52,7 @@ m=`cat $ip_machines | grep $(/sbin/ifconfig ens785 | grep 'inet ' | awk '{print
 export PYTHONFAULTHANDLER=1
 
 # Base configs for which the grid search is done
-tests=("step_configs/config_cifar_sharing.ini" "step_configs/config_cifar_partialmodel.ini" "step_configs/config_cifar_topkacc.ini" "step_configs/config_cifar_subsampling.ini" "step_configs/config_cifar_wavelet.ini")
+tests=("step_configs/config_cifar_sharing.ini")
 # Learning rates
 lr="0.01"
 # Batch size
@@ -66,7 +67,7 @@ samples_per_user=`expr $dataset_size / $procs`
 echo samples per user: $samples_per_user
 
 # random_seeds for which to rerun the experiments
-random_seeds=("90" "91" "92" "93" "94")
+random_seeds=("90")
 # random_seed = 97
 echo batchsize: $batchsize
 echo communication rounds per global epoch: $comm_rounds_per_global_epoch
@@ -91,7 +92,7 @@ do
     echo $i
     IFS='_' read -ra NAMES <<< $i
     IFS='.' read -ra NAME <<< ${NAMES[-1]}
-    log_dir_base=$nfs_home$logs_subfolder/${NAME[0]}:lr=$lr:r=$comm_rounds_per_global_epoch:b=$batchsize:$(date '+%Y-%m-%dT%H:%M')
+    log_dir_base=$nfs_home/$logs_subfolder/${NAME[0]}:lr=$lr:r=$comm_rounds_per_global_epoch:b=$batchsize:$(date '+%Y-%m-%dT%H:%M')
     echo results are stored in: $log_dir_base
     log_dir=$log_dir_base/machine$m
     mkdir -p $log_dir
@@ -104,7 +105,7 @@ do
     $python_bin/crudini --set $config_file TRAIN_PARAMS rounds $batches_per_comm_round
     $python_bin/crudini --set $config_file TRAIN_PARAMS batch_size $batchsize
     $python_bin/crudini --set $config_file DATASET random_seed $seed
-    $env_python $eval_file -ro 0 -tea $test_after -ld $log_dir -wsd $weight_store_dir -mid $m -ps $procs_per_machine -ms $machines -is $new_iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level
+    $env_python $eval_file -ro 0 -tea $test_after -ld $log_dir -wsd $weight_store_dir -mid $m -ps $procs_per_machine -ms $machines -is $new_iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level -wr $working_rate
     echo $i is done
     sleep 200
     echo end of sleep
diff --git a/eval/run_xtimes_femnist.sh b/eval/run_xtimes_femnist.sh
index c5e18a2ab958997e08b96fc4682c1b64f021cb25..cd65a1297dfc5173cc22e5cad3987e5bedad0551 100755
--- a/eval/run_xtimes_femnist.sh
+++ b/eval/run_xtimes_femnist.sh
@@ -38,11 +38,11 @@ cd $decpy_path
 
 env_python=$python_bin/python3
 graph=96_regular.edges
-config_file=~/tmp/config.ini
+config_file=~/tmp/config_femnist_sharing.ini
 procs_per_machine=16
 machines=6
 global_epochs=80
-eval_file=testing.py
+eval_file=testingFederated.py
 log_level=INFO
 
 ip_machines=$nfs_home/configs/ip_addr_6Machines.json
@@ -51,7 +51,8 @@ m=`cat $ip_machines | grep $(/sbin/ifconfig ens785 | grep 'inet ' | awk '{print
 export PYTHONFAULTHANDLER=1
 
 # Base configs for which the grid search is done
-tests=("step_configs/config_femnist_sharing.ini" "step_configs/config_femnist_partialmodel.ini" "step_configs/config_femnist_topkacc.ini" "step_configs/config_femnist_subsampling.ini" "step_configs/config_femnist_wavelet.ini")
+#tests=("step_configs/config_femnist_sharing.ini" "step_configs/config_femnist_partialmodel.ini" "step_configs/config_femnist_topkacc.ini" "step_configs/config_femnist_subsampling.ini" "step_configs/config_femnist_wavelet.ini")
+tests=("step_configs/config_femnist_sharing.ini")
 # Learning rates
 lr="0.01"
 # Batch size
@@ -66,7 +67,8 @@ samples_per_user=`expr $dataset_size / $procs`
 echo samples per user: $samples_per_user
 
 # random_seeds for which to rerun the experiments
-random_seeds=("90" "91" "92" "93" "94")
+#random_seeds=("90" "91" "92" "93" "94")
+random_seeds=("90")
 # random_seed = 97
 echo batchsize: $batchsize
 echo communication rounds per global epoch: $comm_rounds_per_global_epoch
@@ -91,7 +93,7 @@ do
     echo $i
     IFS='_' read -ra NAMES <<< $i
     IFS='.' read -ra NAME <<< ${NAMES[-1]}
-    log_dir_base=$nfs_home$logs_subfolder/${NAME[0]}:lr=$lr:r=$comm_rounds_per_global_epoch:b=$batchsize:$(date '+%Y-%m-%dT%H:%M')
+    log_dir_base=$nfs_home/$logs_subfolder/${NAME[0]}:lr=$lr:r=$comm_rounds_per_global_epoch:b=$batchsize:$(date '+%Y-%m-%dT%H:%M')
     echo results are stored in: $log_dir_base
     log_dir=$log_dir_base/machine$m
     mkdir -p $log_dir
diff --git a/eval/run_xtimes_reddit.sh b/eval/run_xtimes_reddit.sh
index 589f52a2978107a3efac300b592048adb09f1717..2f738c0c60ad8d949a1242481b4e918c2286d099 100755
--- a/eval/run_xtimes_reddit.sh
+++ b/eval/run_xtimes_reddit.sh
@@ -38,11 +38,11 @@ cd $decpy_path
 
 env_python=$python_bin/python3
 graph=96_regular.edges
-config_file=~/tmp/config.ini
+config_file=~/tmp/config_reddit_sharing.ini
 procs_per_machine=16
 machines=6
 global_epochs=50
-eval_file=testing.py
+eval_file=testingFederated.py
 log_level=INFO
 
 ip_machines=$nfs_home/configs/ip_addr_6Machines.json
@@ -51,7 +51,8 @@ m=`cat $ip_machines | grep $(/sbin/ifconfig ens785 | grep 'inet ' | awk '{print
 export PYTHONFAULTHANDLER=1
 
 # Base configs for which the grid search is done
-tests=("step_configs/config_reddit_sharing.ini" "step_configs/config_reddit_partialmodel.ini" "step_configs/config_reddit_topkacc.ini" "step_configs/config_reddit_subsampling.ini" "step_configs/config_reddit_wavelet.ini")
+#tests=("step_configs/config_reddit_sharing.ini" "step_configs/config_reddit_partialmodel.ini" "step_configs/config_reddit_topkacc.ini" "step_configs/config_reddit_subsampling.ini" "step_configs/config_reddit_wavelet.ini")
+tests=("step_configs/config_reddit_sharing.ini")
 # Learning rates
 lr="1"
 # Batch size
@@ -66,7 +67,8 @@ samples_per_user=`expr $dataset_size / $procs`
 echo samples per user: $samples_per_user
 
 # random_seeds for which to rerun the experiments
-random_seeds=("90" "91" "92" "93" "94")
+#random_seeds=("90" "91" "92" "93" "94")
+random_seeds=("90")
 # random_seed = 97
 echo batchsize: $batchsize
 echo communication rounds per global epoch: $comm_rounds_per_global_epoch
@@ -91,7 +93,7 @@ do
     echo $i
     IFS='_' read -ra NAMES <<< $i
     IFS='.' read -ra NAME <<< ${NAMES[-1]}
-    log_dir_base=$nfs_home$logs_subfolder/${NAME[0]}:lr=$lr:r=$comm_rounds_per_global_epoch:b=$batchsize:$(date '+%Y-%m-%dT%H:%M')
+    log_dir_base=$nfs_home/$logs_subfolder/${NAME[0]}:lr=$lr:r=$comm_rounds_per_global_epoch:b=$batchsize:$(date '+%Y-%m-%dT%H:%M')
     echo results are stored in: $log_dir_base
     log_dir=$log_dir_base/machine$m
     mkdir -p $log_dir
diff --git a/eval/run_xtimes_shakespeare.sh b/eval/run_xtimes_shakespeare.sh
index 8e268a18248eb789bc91a6fb68d2183d6a1b96b3..b6c6d6c0174f0ec76809adb42797df612991338c 100755
--- a/eval/run_xtimes_shakespeare.sh
+++ b/eval/run_xtimes_shakespeare.sh
@@ -38,11 +38,11 @@ cd $decpy_path
 
 env_python=$python_bin/python3
 graph=96_regular.edges
-config_file=~/tmp/config.ini
+config_file=~/tmp/config_shakespeare_sharing.ini
 procs_per_machine=16
 machines=6
-global_epochs=200
-eval_file=testing.py
+global_epochs=100
+eval_file=testingFederated.py
 log_level=INFO
 
 ip_machines=$nfs_home/configs/ip_addr_6Machines.json
@@ -51,7 +51,8 @@ m=`cat $ip_machines | grep $(/sbin/ifconfig ens785 | grep 'inet ' | awk '{print
 export PYTHONFAULTHANDLER=1
 
 # Base configs for which the grid search is done
-tests=("step_configs/config_shakespeare_sharing.ini" "step_configs/config_shakespeare_partialmodel.ini" "step_configs/config_shakespeare_topkacc.ini" "step_configs/config_shakespeare_subsampling.ini" "step_configs/config_shakespeare_wavelet.ini")
+#tests=("step_configs/config_shakespeare_sharing.ini" "step_configs/config_shakespeare_partialmodel.ini" "step_configs/config_shakespeare_topkacc.ini" "step_configs/config_shakespeare_subsampling.ini" "step_configs/config_shakespeare_wavelet.ini")
+tests=("step_configs/config_shakespeare_sharing.ini")
 # Learning rates
 lr="0.5"
 # Batch size
@@ -66,7 +67,8 @@ samples_per_user=`expr $dataset_size / $procs`
 echo samples per user: $samples_per_user
 
 # random_seeds for which to rerun the experiments
-random_seeds=("90" "91" "92" "93" "94")
+#random_seeds=("90" "91" "92" "93" "94")
+random_seeds=("90")
 # random_seed = 97
 echo batchsize: $batchsize
 echo communication rounds per global epoch: $comm_rounds_per_global_epoch
@@ -91,7 +93,7 @@ do
     echo $i
     IFS='_' read -ra NAMES <<< $i
     IFS='.' read -ra NAME <<< ${NAMES[-1]}
-    log_dir_base=$nfs_home$logs_subfolder/${NAME[0]}:lr=$lr:r=$comm_rounds_per_global_epoch:b=$batchsize:$(date '+%Y-%m-%dT%H:%M')
+    log_dir_base=$nfs_home/$logs_subfolder/${NAME[0]}:lr=$lr:r=$comm_rounds_per_global_epoch:b=$batchsize:$(date '+%Y-%m-%dT%H:%M')
     echo results are stored in: $log_dir_base
     log_dir=$log_dir_base/machine$m
     mkdir -p $log_dir
diff --git a/eval/testing.py b/eval/testing.py
index 9125828cc2ffbf5e9ccf7178355c136c77cebdb8..556d1796580e268b29691fdef81b33017473d2be 100644
--- a/eval/testing.py
+++ b/eval/testing.py
@@ -8,7 +8,7 @@ from torch import multiprocessing as mp
 from decentralizepy import utils
 from decentralizepy.graphs.Graph import Graph
 from decentralizepy.mappings.Linear import Linear
-from decentralizepy.node.Node import Node
+from decentralizepy.node.DPSGDNode import DPSGDNode
 
 
 def read_ini(file_path):
@@ -51,7 +51,7 @@ if __name__ == "__main__":
     m_id = args.machine_id
 
     mp.spawn(
-        fn=Node,
+        fn=DPSGDNode,
         nprocs=procs_per_machine,
         args=[
             m_id,
@@ -65,7 +65,33 @@ if __name__ == "__main__":
             args.test_after,
             args.train_evaluate_after,
             args.reset_optimizer,
-            args.centralized_train_eval,
-            args.centralized_test_eval,
         ],
     )
+
+    processes = []
+    for r in range(procs_per_machine):
+        processes.append(
+            mp.Process(
+                target=DPSGDNode,
+                args=[
+                    r,
+                    m_id,
+                    l,
+                    g,
+                    my_config,
+                    args.iterations,
+                    args.log_dir,
+                    args.weights_store_dir,
+                    log_level[args.log_level],
+                    args.test_after,
+                    args.train_evaluate_after,
+                    args.reset_optimizer,
+                ],
+            )
+        )
+
+    for p in processes:
+        p.start()
+
+    for p in processes:
+        p.join()
diff --git a/eval/testingFederated.py b/eval/testingFederated.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0666d791cff5fd778b8ffe4d499ed344ed1e5a5
--- /dev/null
+++ b/eval/testingFederated.py
@@ -0,0 +1,104 @@
+import logging
+from pathlib import Path
+from shutil import copy
+
+from localconfig import LocalConfig
+from torch import multiprocessing as mp
+
+from decentralizepy import utils
+from decentralizepy.graphs.Graph import Graph
+from decentralizepy.mappings.Linear import Linear
+from decentralizepy.node.DPSGDNodeFederated import DPSGDNodeFederated
+from decentralizepy.node.FederatedParameterServer import FederatedParameterServer
+
+
+def read_ini(file_path):
+    config = LocalConfig(file_path)
+    for section in config:
+        print("Section: ", section)
+        for key, value in config.items(section):
+            print((key, value))
+    print(dict(config.items("DATASET")))
+    return config
+
+
+if __name__ == "__main__":
+    args = utils.get_args()
+
+    Path(args.log_dir).mkdir(parents=True, exist_ok=True)
+
+    log_level = {
+        "INFO": logging.INFO,
+        "DEBUG": logging.DEBUG,
+        "WARNING": logging.WARNING,
+        "ERROR": logging.ERROR,
+        "CRITICAL": logging.CRITICAL,
+    }
+
+    config = read_ini(args.config_file)
+    my_config = dict()
+    for section in config:
+        my_config[section] = dict(config.items(section))
+
+    copy(args.config_file, args.log_dir)
+    copy(args.graph_file, args.log_dir)
+    utils.write_args(args, args.log_dir)
+
+    g = Graph()
+    g.read_graph_from_file(args.graph_file, args.graph_type)
+    n_machines = args.machines
+    procs_per_machine = args.procs_per_machine
+    l = Linear(n_machines, procs_per_machine)
+    m_id = args.machine_id
+
+    sm = args.server_machine
+    sr = args.server_rank
+
+    processes = []
+    if sm == m_id:
+        processes.append(
+            mp.Process(
+                target=FederatedParameterServer,
+                args=[
+                    sr,
+                    m_id,
+                    l,
+                    g,
+                    my_config,
+                    args.iterations,
+                    args.log_dir,
+                    args.weights_store_dir,
+                    log_level[args.log_level],
+                    args.test_after,
+                    args.train_evaluate_after,
+                    args.working_rate,
+                ],
+            )
+        )
+
+    for r in range(0, procs_per_machine):
+        processes.append(
+            mp.Process(
+                target=DPSGDNodeFederated,
+                args=[
+                    r,
+                    m_id,
+                    l,
+                    g,
+                    my_config,
+                    args.iterations,
+                    args.log_dir,
+                    args.weights_store_dir,
+                    log_level[args.log_level],
+                    args.test_after,
+                    args.train_evaluate_after,
+                    args.reset_optimizer,
+                ],
+            )
+        )
+
+    for p in processes:
+        p.start()
+
+    for p in processes:
+        p.join()
diff --git a/eval/testingPeerSampler.py b/eval/testingPeerSampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e0b39a838254bd078f0ae709f68b967f97d3caa
--- /dev/null
+++ b/eval/testingPeerSampler.py
@@ -0,0 +1,102 @@
+import logging
+from pathlib import Path
+from shutil import copy
+
+from localconfig import LocalConfig
+from torch import multiprocessing as mp
+
+from decentralizepy import utils
+from decentralizepy.graphs.Graph import Graph
+from decentralizepy.mappings.Linear import Linear
+from decentralizepy.node.DPSGDWithPeerSampler import DPSGDWithPeerSampler
+from decentralizepy.node.PeerSampler import PeerSampler
+# from decentralizepy.node.PeerSamplerDynamic import PeerSamplerDynamic
+
+
+def read_ini(file_path):
+    config = LocalConfig(file_path)
+    for section in config:
+        print("Section: ", section)
+        for key, value in config.items(section):
+            print((key, value))
+    print(dict(config.items("DATASET")))
+    return config
+
+
+if __name__ == "__main__":
+    args = utils.get_args()
+
+    Path(args.log_dir).mkdir(parents=True, exist_ok=True)
+
+    log_level = {
+        "INFO": logging.INFO,
+        "DEBUG": logging.DEBUG,
+        "WARNING": logging.WARNING,
+        "ERROR": logging.ERROR,
+        "CRITICAL": logging.CRITICAL,
+    }
+
+    config = read_ini(args.config_file)
+    my_config = dict()
+    for section in config:
+        my_config[section] = dict(config.items(section))
+
+    copy(args.config_file, args.log_dir)
+    copy(args.graph_file, args.log_dir)
+    utils.write_args(args, args.log_dir)
+
+    g = Graph()
+    g.read_graph_from_file(args.graph_file, args.graph_type)
+    n_machines = args.machines
+    procs_per_machine = args.procs_per_machine
+    l = Linear(n_machines, procs_per_machine)
+    m_id = args.machine_id
+
+    sm = args.server_machine
+    sr = args.server_rank
+
+    processes = []
+    if sm == m_id:
+        processes.append(
+            mp.Process(
+                # target=PeerSamplerDynamic,
+                target=PeerSampler,
+                args=[
+                    sr,
+                    m_id,
+                    l,
+                    g,
+                    my_config,
+                    args.iterations,
+                    args.log_dir,
+                    log_level[args.log_level],
+                ],
+            )
+        )
+
+    for r in range(0, procs_per_machine):
+        processes.append(
+            mp.Process(
+                target=DPSGDWithPeerSampler,
+                args=[
+                    r,
+                    m_id,
+                    l,
+                    g,
+                    my_config,
+                    args.iterations,
+                    args.log_dir,
+                    args.weights_store_dir,
+                    log_level[args.log_level],
+                    args.test_after,
+                    args.train_evaluate_after,
+                    args.reset_optimizer,
+                ],
+            )
+        )
+
+    for p in processes:
+        p.start()
+
+    for p in processes:
+        p.join()
diff --git a/setup.cfg b/setup.cfg
index 1b3f6c715474db0d582c94ea94baa8b6193ff5bf..2df457a2da8208ebc42ad6b41012b99e6704d14c 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -54,7 +54,7 @@ python_requires = >=3.6
 where = src
 [options.extras_require]
 dev =
-        black
+        black>22.3.0
         coverage
         isort
         pytest
diff --git a/src/decentralizepy/communication/TCP.py b/src/decentralizepy/communication/TCP.py
index 6d72e1cfc61e6f9888cac3d6cdb52c81c84d24e4..16de517dc2c1c88dc951ed544bece2799fd426a9 100644
--- a/src/decentralizepy/communication/TCP.py
+++ b/src/decentralizepy/communication/TCP.py
@@ -1,4 +1,3 @@
-import importlib
 import json
 import logging
 import pickle
@@ -36,7 +35,8 @@ class TCP(Communication):
 
         """
         machine_addr = self.ip_addrs[str(machine_id)]
-        port = rank + self.offset
+        port = (2 * rank + 1) + self.offset
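+        # the affine map keeps ports distinct: rank -1 (assumed to be the peer sampler / server) lands just below the offset, regular ranks above it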
+        assert port > 0
         return "tcp://{}:{}".format(machine_addr, port)
 
     def __init__(
@@ -46,10 +46,7 @@ class TCP(Communication):
         mapping,
         total_procs,
         addresses_filepath,
-        compress=False,
-        offset=20000,
-        compression_package=None,
-        compression_class=None,
+        offset=9000,
     ):
         """
         Constructor
@@ -81,30 +78,19 @@ class TCP(Communication):
         self.rank = rank
         self.machine_id = machine_id
         self.mapping = mapping
-        self.offset = 20000 + offset
+        self.offset = offset
         self.uid = mapping.get_uid(rank, machine_id)
         self.identity = str(self.uid).encode()
         self.context = zmq.Context()
         self.router = self.context.socket(zmq.ROUTER)
         self.router.setsockopt(zmq.IDENTITY, self.identity)
         self.router.bind(self.addr(rank, machine_id))
-        self.sent_disconnections = False
-        self.compress = compress
-
-        if compression_package and compression_class:
-            compressor_module = importlib.import_module(compression_package)
-            compressor_class = getattr(compressor_module, compression_class)
-            self.compressor = compressor_class()
-            logging.info(f"Using the {compressor_class} to compress the data")
-        else:
-            assert not self.compress
 
         self.total_data = 0
         self.total_meta = 0
 
         self.peer_deque = deque()
         self.peer_sockets = dict()
-        self.barrier = set()
 
     def __del__(self):
         """
@@ -128,26 +114,12 @@ class TCP(Communication):
             Encoded data
 
         """
-        if self.compress:
-            if "indices" in data:
-                data["indices"] = self.compressor.compress(data["indices"])
-
-            assert "params" in data
-            data["params"] = self.compressor.compress_float(data["params"])
+        data_len = 0
+        if "params" in data:
             data_len = len(pickle.dumps(data["params"]))
-            output = pickle.dumps(data)
-
-            # the compressed meta data gets only a few bytes smaller after pickling
-            self.total_meta += len(output) - data_len
-            self.total_data += data_len
-        else:
-            output = pickle.dumps(data)
-            # centralized testing uses its own instance
-            if type(data) == dict:
-                assert "params" in data
-                data_len = len(pickle.dumps(data["params"]))
-                self.total_meta += len(output) - data_len
-                self.total_data += data_len
+        output = pickle.dumps(data)
+        self.total_meta += len(output) - data_len
+        self.total_data += data_len
         return output
 
     def decrypt(self, sender, data):
@@ -168,63 +140,35 @@ class TCP(Communication):
 
         """
         sender = int(sender.decode())
-        if self.compress:
-            data = pickle.loads(data)
-            if "indices" in data:
-                data["indices"] = self.compressor.decompress(data["indices"])
-            if "params" in data:
-                data["params"] = self.compressor.decompress_float(data["params"])
-        else:
-            data = pickle.loads(data)
+        data = pickle.loads(data)
         return sender, data
 
-    def connect_neighbors(self, neighbors):
+    def init_connection(self, neighbor):
         """
-        Connects all neighbors. Sends HELLO. Waits for HELLO.
-        Caches any data received while waiting for HELLOs.
+        Initiates a socket to a given node.
 
         Parameters
         ----------
-        neighbors : list(int)
-            List of neighbors
-
-        Raises
-        ------
-        RuntimeError
-            If received BYE while waiting for HELLO
+        neighbor : int
+            neighbor to connect to
 
         """
-        logging.info("Sending connection request to neighbors")
-        for uid in neighbors:
-            logging.debug("Connecting to my neighbour: {}".format(uid))
-            id = str(uid).encode()
-            req = self.context.socket(zmq.DEALER)
-            req.setsockopt(zmq.IDENTITY, self.identity)
-            req.connect(self.addr(*self.mapping.get_machine_and_rank(uid)))
-            self.peer_sockets[id] = req
-            req.send(HELLO)
-
-        num_neighbors = len(neighbors)
-        while len(self.barrier) < num_neighbors:
-            sender, recv = self.router.recv_multipart()
-
-            if recv == HELLO:
-                logging.debug("Received {} from {}".format(HELLO, sender))
-                self.barrier.add(sender)
-            elif recv == BYE:
-                logging.debug("Received {} from {}".format(BYE, sender))
-                raise RuntimeError(
-                    "A neighbour wants to disconnect before training started!"
-                )
-            else:
-                logging.debug(
-                    "Received message from {} @ connect_neighbors".format(sender)
-                )
-
-                self.peer_deque.append(self.decrypt(sender, recv))
-
-        logging.info("Connected to all neighbors")
-        self.initialized = True
+        logging.debug("Connecting to my neighbour: {}".format(neighbor))
+        id = str(neighbor).encode()
+        req = self.context.socket(zmq.DEALER)
+        req.setsockopt(zmq.IDENTITY, self.identity)
+        req.connect(self.addr(*self.mapping.get_machine_and_rank(neighbor)))
+        self.peer_sockets[id] = req
+
+    def destroy_connection(self, neighbor, linger=None):
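+        """
+        Closes the socket to `neighbor`, if one exists (`linger` is forwarded to zmq's close).
+
+        """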
+        id = str(neighbor).encode()
+        if self.already_connected(neighbor):
+            self.peer_sockets[id].close(linger=linger)
+            del self.peer_sockets[id]
+
+    def already_connected(self, neighbor):
+        id = str(neighbor).encode()
+        return id in self.peer_sockets
 
     def receive(self):
         """
@@ -241,25 +185,10 @@ class TCP(Communication):
             If received HELLO
 
         """
-        assert self.initialized == True
-        if len(self.peer_deque) != 0:
-            resp = self.peer_deque.popleft()
-            return resp
 
         sender, recv = self.router.recv_multipart()
-
-        if recv == HELLO:
-            logging.debug("Received {} from {}".format(HELLO, sender))
-            raise RuntimeError(
-                "A neighbour wants to connect when everyone is connected!"
-            )
-        elif recv == BYE:
-            logging.debug("Received {} from {}".format(BYE, sender))
-            self.barrier.remove(sender)
-            return self.receive()
-        else:
-            logging.debug("Received message from {}".format(sender))
-            return self.decrypt(sender, recv)
+        return self.decrypt(sender, recv)
 
     def send(self, uid, data, encrypt=True):
         """
@@ -273,7 +202,6 @@ class TCP(Communication):
             Message as a Python dictionary
 
         """
-        assert self.initialized == True
         if encrypt:
             to_send = self.encrypt(data)
         else:
@@ -283,28 +211,4 @@ class TCP(Communication):
         id = str(uid).encode()
         self.peer_sockets[id].send(to_send)
         logging.debug("{} sent the message to {}.".format(self.uid, uid))
-        logging.info("Sent this round: {}".format(data_size))
-
-    def disconnect_neighbors(self):
-        """
-        Disconnects all neighbors.
-
-        """
-        assert self.initialized == True
-        if not self.sent_disconnections:
-            logging.info("Disconnecting neighbors")
-            for sock in self.peer_sockets.values():
-                sock.send(BYE)
-            self.sent_disconnections = True
-            while len(self.barrier):
-                sender, recv = self.router.recv_multipart()
-                if recv == BYE:
-                    logging.debug("Received {} from {}".format(BYE, sender))
-                    self.barrier.remove(sender)
-                else:
-                    logging.critical(
-                        "Received unexpected {} from {}".format(recv, sender)
-                    )
-                    raise RuntimeError(
-                        "Received a message when expecting BYE from {}".format(sender)
-                    )
+        logging.info("Sent message size: {}".format(data_size))
diff --git a/src/decentralizepy/compression/Compression.py b/src/decentralizepy/compression/Compression.py
index 0924cafe370b774763874592d9f49601f77168cf..b45e6415d1cf2e068c623ef784156e0748cb23b7 100644
--- a/src/decentralizepy/compression/Compression.py
+++ b/src/decentralizepy/compression/Compression.py
@@ -1,6 +1,3 @@
-import numpy as np
-
-
 class Compression:
     """
     Compression API
diff --git a/src/decentralizepy/compression/Elias.py b/src/decentralizepy/compression/Elias.py
index 235cf002e6f5e9e36b2d388ddcaf81af0f0adab5..0d408d88a7a3c1f89c3d9f0127f13806ac821ee9 100644
--- a/src/decentralizepy/compression/Elias.py
+++ b/src/decentralizepy/compression/Elias.py
@@ -1,6 +1,5 @@
 # elias implementation: taken from this stack overflow post:
 # https://stackoverflow.com/questions/62843156/python-fast-compression-of-large-amount-of-numbers-with-elias-gamma
-import fpzip
 import numpy as np
 
 from decentralizepy.compression.Compression import Compression
diff --git a/src/decentralizepy/compression/EliasFpzip.py b/src/decentralizepy/compression/EliasFpzip.py
index 0c82560aae28efb64bdba2a9cf0abf81d7fcdda7..0142dd95efdf9e51fdc29fdf016267bae93376da 100644
--- a/src/decentralizepy/compression/EliasFpzip.py
+++ b/src/decentralizepy/compression/EliasFpzip.py
@@ -1,7 +1,6 @@
 # elias implementation: taken from this stack overflow post:
 # https://stackoverflow.com/questions/62843156/python-fast-compression-of-large-amount-of-numbers-with-elias-gamma
 import fpzip
-import numpy as np
 
 from decentralizepy.compression.Elias import Elias
 
diff --git a/src/decentralizepy/compression/EliasFpzipLossy.py b/src/decentralizepy/compression/EliasFpzipLossy.py
index 617a78b2b27ff88bd57db29a3a65d71ad1e0a843..0b60307a289b7605e02aeffbf039888c4e0062dd 100644
--- a/src/decentralizepy/compression/EliasFpzipLossy.py
+++ b/src/decentralizepy/compression/EliasFpzipLossy.py
@@ -1,7 +1,6 @@
 # elias implementation: taken from this stack overflow post:
 # https://stackoverflow.com/questions/62843156/python-fast-compression-of-large-amount-of-numbers-with-elias-gamma
 import fpzip
-import numpy as np
 
 from decentralizepy.compression.Elias import Elias
 
diff --git a/src/decentralizepy/datasets/MovieLens.py b/src/decentralizepy/datasets/MovieLens.py
index dafb4cee9883c942859af036960ccda5da42c2f5..95e55cc6dd61f0d4139f28ec9f2681116a48c366 100644
--- a/src/decentralizepy/datasets/MovieLens.py
+++ b/src/decentralizepy/datasets/MovieLens.py
@@ -3,7 +3,6 @@ import math
 import os
 import zipfile
 
-import numpy as np
 import pandas as pd
 import requests
 import torch
diff --git a/src/decentralizepy/datasets/Shakespeare.py b/src/decentralizepy/datasets/Shakespeare.py
index 0c0293208717385fe69ee41bd8355cc15b6997e1..c7ede740a574a82a6779f9434cfd2751867d2b64 100644
--- a/src/decentralizepy/datasets/Shakespeare.py
+++ b/src/decentralizepy/datasets/Shakespeare.py
@@ -1,7 +1,6 @@
 import json
 import logging
 import os
-import re
 from collections import defaultdict
 
 import numpy as np
diff --git a/src/decentralizepy/graphs/Graph.py b/src/decentralizepy/graphs/Graph.py
index 689d2dc62544d603ea492807bc6790b9e90c5f95..dc66eef2c3b0f5cefbec3b5d5b05400c965eaf90 100644
--- a/src/decentralizepy/graphs/Graph.py
+++ b/src/decentralizepy/graphs/Graph.py
@@ -22,6 +22,9 @@ class Graph:
             self.n_procs = n_procs
             self.adj_list = [set() for i in range(self.n_procs)]
 
+    def get_all_nodes(self):
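+        """Return the ids of all nodes in the graph: 0 .. n_procs - 1."""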
+        return [i for i in range(self.n_procs)]
+
     def __insert_adj__(self, node, neighbours):
         """
         Inserts `neighbours` into the adjacency list of `node`
diff --git a/src/decentralizepy/mappings/Linear.py b/src/decentralizepy/mappings/Linear.py
index 9419fbd40a18d2c9ca1a4992854a6971ce937dde..f166dc9fb145a8e0c0387fffef378135f87e75d9 100644
--- a/src/decentralizepy/mappings/Linear.py
+++ b/src/decentralizepy/mappings/Linear.py
@@ -8,7 +8,7 @@ class Linear(Mapping):
 
     """
 
-    def __init__(self, n_machines, procs_per_machine):
+    def __init__(self, n_machines, procs_per_machine, global_service_machine=0):
         """
         Constructor
 
@@ -23,6 +23,7 @@ class Linear(Mapping):
         super().__init__(n_machines * procs_per_machine)
         self.n_machines = n_machines
         self.procs_per_machine = procs_per_machine
+        self.global_service_machine = global_service_machine
 
     def get_uid(self, rank: int, machine_id: int):
         """
@@ -41,6 +42,8 @@ class Linear(Mapping):
             the unique identifier
 
         """
+        if rank < 0:
+            return rank
         return machine_id * self.procs_per_machine + rank
 
     def get_machine_and_rank(self, uid: int):
@@ -58,6 +61,8 @@ class Linear(Mapping):
             a tuple of rank and machine_id
 
         """
+        if uid < 0:
+            return uid, self.global_service_machine
         return (uid % self.procs_per_machine), (uid // self.procs_per_machine)
 
     def get_local_procs_count(self):
diff --git a/src/decentralizepy/node/DPSGDNode.py b/src/decentralizepy/node/DPSGDNode.py
new file mode 100644
index 0000000000000000000000000000000000000000..3951ad297f99137c19435386ecba05c7d8db5730
--- /dev/null
+++ b/src/decentralizepy/node/DPSGDNode.py
@@ -0,0 +1,449 @@
+import importlib
+import json
+import logging
+import math
+import os
+from collections import deque
+
+import torch
+from matplotlib import pyplot as plt
+
+from decentralizepy import utils
+from decentralizepy.graphs.Graph import Graph
+from decentralizepy.mappings.Mapping import Mapping
+from decentralizepy.node.Node import Node
+
+
+class DPSGDNode(Node):
+    """
+    This class defines the node for DPSGD
+
+    """
+
+    def save_plot(self, l, label, title, xlabel, filename):
+        """
+        Save Matplotlib plot. Clears previous plots.
+
+        Parameters
+        ----------
+        l : dict
+            dict of x -> y. `x` must be castable to int.
+        label : str
+            label of the plot. Used for legend.
+        title : str
+            Header
+        xlabel : str
+            x-axis label
+        filename : str
+            Name of file to save the plot as.
+
+        """
+        plt.clf()
+        y_axis = [l[key] for key in l.keys()]
+        x_axis = list(map(int, l.keys()))
+        plt.plot(x_axis, y_axis, label=label)
+        plt.xlabel(xlabel)
+        plt.title(title)
+        plt.savefig(filename)
+
+    def get_neighbors(self, node=None):
+        return self.my_neighbors
+
+    # def instantiate_peer_deques(self):
+    #     for neighbor in self.my_neighbors:
+    #         if neighbor not in self.peer_deques:
+    #             self.peer_deques[neighbor] = deque()
+
+    def receive_DPSGD(self):
+        return self.receive_channel("DPSGD")
+
+    def run(self):
+        """
+        Start the decentralized learning
+
+        """
+        self.testset = self.dataset.get_testset()
+        rounds_to_test = self.test_after
+        rounds_to_train_evaluate = self.train_evaluate_after
+        global_epoch = 1
+        change = 1
+
+        for iteration in range(self.iterations):
+            logging.info("Starting training iteration: %d", iteration)
+            rounds_to_train_evaluate -= 1
+            rounds_to_test -= 1
+
+            self.iteration = iteration
+            self.trainer.train(self.dataset)
+
+            new_neighbors = self.get_neighbors()
+
+            # The following code does not work because TCP sockets are supposed to be long lived.
+            # for neighbor in self.my_neighbors:
+            #     if neighbor not in new_neighbors:
+            #         logging.info("Removing neighbor {}".format(neighbor))
+            #         if neighbor in self.peer_deques:
+            #             assert len(self.peer_deques[neighbor]) == 0
+            #             del self.peer_deques[neighbor]
+            #         self.communication.destroy_connection(neighbor, linger = 10000)
+            #         self.barrier.remove(neighbor)
+
+            self.my_neighbors = new_neighbors
+            self.connect_neighbors()
+            logging.info("Connected to all neighbors")
+            # self.instantiate_peer_deques()
+
+            to_send = self.sharing.get_data_to_send()
+            to_send["CHANNEL"] = "DPSGD"
+
+            for neighbor in self.my_neighbors:
+                self.communication.send(neighbor, to_send)
+
+            while not self.received_from_all():
+                sender, data = self.receive_DPSGD()
+                logging.info(
+                    "Received Model from {} of iteration {}".format(
+                        sender, data["iteration"]
+                    )
+                )
+                if sender not in self.peer_deques:
+                    self.peer_deques[sender] = deque()
+                self.peer_deques[sender].append(data)
+
+            averaging_deque = dict()
+            for neighbor in self.my_neighbors:
+                averaging_deque[neighbor] = self.peer_deques[neighbor]
+
+            self.sharing._averaging(averaging_deque)
+
+            if self.reset_optimizer:
+                self.optimizer = self.optimizer_class(
+                    self.model.parameters(), **self.optimizer_params
+                )  # Reset optimizer state
+                self.trainer.reset_optimizer(self.optimizer)
+
+            if iteration:
+                with open(
+                    os.path.join(self.log_dir, "{}_results.json".format(self.rank)),
+                    "r",
+                ) as inf:
+                    results_dict = json.load(inf)
+            else:
+                results_dict = {
+                    "train_loss": {},
+                    "test_loss": {},
+                    "test_acc": {},
+                    "total_bytes": {},
+                    "total_meta": {},
+                    "total_data_per_n": {},
+                }
+
+            results_dict["total_bytes"][iteration + 1] = self.communication.total_bytes
+
+            if hasattr(self.communication, "total_meta"):
+                results_dict["total_meta"][
+                    iteration + 1
+                ] = self.communication.total_meta
+            if hasattr(self.communication, "total_data"):
+                results_dict["total_data_per_n"][
+                    iteration + 1
+                ] = self.communication.total_data
+
+            if rounds_to_train_evaluate == 0:
+                logging.info("Evaluating on train set.")
+                rounds_to_train_evaluate = self.train_evaluate_after * change
+                loss_after_sharing = self.trainer.eval_loss(self.dataset)
+                results_dict["train_loss"][iteration + 1] = loss_after_sharing
+                self.save_plot(
+                    results_dict["train_loss"],
+                    "train_loss",
+                    "Training Loss",
+                    "Communication Rounds",
+                    os.path.join(self.log_dir, "{}_train_loss.png".format(self.rank)),
+                )
+
+            if self.dataset.__testing__ and rounds_to_test == 0:
+                rounds_to_test = self.test_after * change
+                logging.info("Evaluating on test set.")
+                ta, tl = self.dataset.test(self.model, self.loss)
+                results_dict["test_acc"][iteration + 1] = ta
+                results_dict["test_loss"][iteration + 1] = tl
+
+                if global_epoch == 49:
+                    change *= 2
+
+                global_epoch += change
+
+            with open(
+                os.path.join(self.log_dir, "{}_results.json".format(self.rank)), "w"
+            ) as of:
+                json.dump(results_dict, of)
+        if self.model.shared_parameters_counter is not None:
+            logging.info("Saving the shared parameter counts")
+            with open(
+                os.path.join(
+                    self.log_dir, "{}_shared_parameters.json".format(self.rank)
+                ),
+                "w",
+            ) as of:
+                json.dump(self.model.shared_parameters_counter.numpy().tolist(), of)
+        self.disconnect_neighbors()
+        logging.info("Storing final weight")
+        self.model.dump_weights(self.weights_store_dir, self.uid, iteration)
+        logging.info("All neighbors disconnected. Process complete!")
+
+    def cache_fields(
+        self,
+        rank,
+        machine_id,
+        mapping,
+        graph,
+        iterations,
+        log_dir,
+        weights_store_dir,
+        test_after,
+        train_evaluate_after,
+        reset_optimizer,
+    ):
+        """
+        Instantiate object fields with the given arguments.
+
+        Parameters
+        ----------
+        rank : int
+            Rank of process local to the machine
+        machine_id : int
+            Machine ID on which the process is running
+        mapping : decentralizepy.mappings
+            The object containing the mapping rank <--> uid
+        graph : decentralizepy.graphs
+            The object containing the global graph
+        iterations : int
+            Number of iterations (communication steps) for which the model should be trained
+        log_dir : str
+            Logging directory
+        weights_store_dir : str
+            Directory in which to store model weights
+        test_after : int
+            Number of iterations after which the test loss and accuracy are calculated
+        train_evaluate_after : int
+            Number of iterations after which the train loss is calculated
+        reset_optimizer : int
+            1 if optimizer should be reset every communication round, else 0
+        """
+        self.rank = rank
+        self.machine_id = machine_id
+        self.graph = graph
+        self.mapping = mapping
+        self.uid = self.mapping.get_uid(rank, machine_id)
+        self.log_dir = log_dir
+        self.weights_store_dir = weights_store_dir
+        self.iterations = iterations
+        self.test_after = test_after
+        self.train_evaluate_after = train_evaluate_after
+        self.reset_optimizer = reset_optimizer
+        self.sent_disconnections = False
+
+        logging.info("Rank: %d", self.rank)
+        logging.info("type(graph): %s", str(type(self.rank)))
+        logging.info("type(mapping): %s", str(type(self.mapping)))
+
+    def init_comm(self, comm_configs):
+        """
+        Instantiate communication module from config.
+
+        Parameters
+        ----------
+        comm_configs : dict
+            Python dict containing communication config params
+
+        """
+        comm_module = importlib.import_module(comm_configs["comm_package"])
+        comm_class = getattr(comm_module, comm_configs["comm_class"])
+        comm_params = utils.remove_keys(comm_configs, ["comm_package", "comm_class"])
+        self.addresses_filepath = comm_params.get("addresses_filepath", None)
+        self.communication = comm_class(
+            self.rank, self.machine_id, self.mapping, self.graph.n_procs, **comm_params
+        )
+
+    def instantiate(
+        self,
+        rank: int,
+        machine_id: int,
+        mapping: Mapping,
+        graph: Graph,
+        config,
+        iterations=1,
+        log_dir=".",
+        weights_store_dir=".",
+        log_level=logging.INFO,
+        test_after=5,
+        train_evaluate_after=1,
+        reset_optimizer=1,
+        *args
+    ):
+        """
+        Construct objects.
+
+        Parameters
+        ----------
+        rank : int
+            Rank of process local to the machine
+        machine_id : int
+            Machine ID on which the process is running
+        mapping : decentralizepy.mappings
+            The object containing the mapping rank <--> uid
+        graph : decentralizepy.graphs
+            The object containing the global graph
+        config : dict
+            A dictionary of configurations.
+        iterations : int
+            Number of iterations (communication steps) for which the model should be trained
+        log_dir : str
+            Logging directory
+        weights_store_dir : str
+            Directory in which to store model weights
+        log_level : logging.Level
+            One of DEBUG, INFO, WARNING, ERROR, CRITICAL
+        test_after : int
+            Number of iterations after which the test loss and accuracy are calculated
+        train_evaluate_after : int
+            Number of iterations after which the train loss is calculated
+        reset_optimizer : int
+            1 if optimizer should be reset every communication round, else 0
+        args : optional
+            Other arguments
+
+        """
+        logging.info("Started process.")
+
+        self.init_log(log_dir, rank, log_level)
+
+        self.cache_fields(
+            rank,
+            machine_id,
+            mapping,
+            graph,
+            iterations,
+            log_dir,
+            weights_store_dir,
+            test_after,
+            train_evaluate_after,
+            reset_optimizer,
+        )
+        self.init_dataset_model(config["DATASET"])
+        self.init_optimizer(config["OPTIMIZER_PARAMS"])
+        self.init_trainer(config["TRAIN_PARAMS"])
+        self.init_comm(config["COMMUNICATION"])
+
+        self.message_queue = dict()
+
+        self.barrier = set()
+        self.my_neighbors = self.graph.neighbors(self.uid)
+
+        self.init_sharing(config["SHARING"])
+        self.peer_deques = dict()
+        self.connect_neighbors()
+
+    def received_from_all(self):
+        """
+        Check if all neighbors have sent the current iteration
+
+        Returns
+        -------
+        bool
+            True if required data has been received, False otherwise
+
+        """
+        for k in self.my_neighbors:
+            if (k not in self.peer_deques) or len(self.peer_deques[k]) == 0:
+                return False
+        return True
+
+    def __init__(
+        self,
+        rank: int,
+        machine_id: int,
+        mapping: Mapping,
+        graph: Graph,
+        config,
+        iterations=1,
+        log_dir=".",
+        weights_store_dir=".",
+        log_level=logging.INFO,
+        test_after=5,
+        train_evaluate_after=1,
+        reset_optimizer=1,
+        *args
+    ):
+        """
+        Constructor
+
+        Parameters
+        ----------
+        rank : int
+            Rank of process local to the machine
+        machine_id : int
+            Machine ID on which the process is running
+        mapping : decentralizepy.mappings
+            The object containing the mapping rank <--> uid
+        graph : decentralizepy.graphs
+            The object containing the global graph
+        config : dict
+            A dictionary of configurations. Must contain the following:
+            [DATASET]
+                dataset_package
+                dataset_class
+                model_class
+            [OPTIMIZER_PARAMS]
+                optimizer_package
+                optimizer_class
+            [TRAIN_PARAMS]
+                training_package = decentralizepy.training.Training
+                training_class = Training
+                epochs_per_round = 25
+                batch_size = 64
+        iterations : int
+            Number of iterations (communication steps) for which the model should be trained
+        log_dir : str
+            Logging directory
+        weights_store_dir : str
+            Directory in which to store model weights
+        log_level : logging.Level
+            One of DEBUG, INFO, WARNING, ERROR, CRITICAL
+        test_after : int
+            Number of iterations after which the test loss and accuracy are calculated
+        train_evaluate_after : int
+            Number of iterations after which the train loss is calculated
+        reset_optimizer : int
+            1 if optimizer should be reset every communication round, else 0
+        args : optional
+            Other arguments
+
+        """
+
+        total_threads = os.cpu_count()
+        self.threads_per_proc = max(
+            math.floor(total_threads / mapping.procs_per_machine), 1
+        )
+        torch.set_num_threads(self.threads_per_proc)
+        torch.set_num_interop_threads(1)
+        self.instantiate(
+            rank,
+            machine_id,
+            mapping,
+            graph,
+            config,
+            iterations,
+            log_dir,
+            weights_store_dir,
+            log_level,
+            test_after,
+            train_evaluate_after,
+            reset_optimizer,
+            *args
+        )
+        logging.info(
+            "Each proc uses %d threads out of %d.", self.threads_per_proc, total_threads
+        )
+        self.run()
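
The constructor docstring above only names the required config sections; the sketch below spells out the nested dict that instantiate() ends up consuming (DATASET, OPTIMIZER_PARAMS, TRAIN_PARAMS, COMMUNICATION, SHARING). The section names and the comm_package/comm_class/addresses_filepath keys are taken from the code in this file; every concrete package, class, path, and hyperparameter value is an illustrative assumption and would normally come from an experiment config file.

# Hypothetical config dict for a DPSGDNode; only the section names and the
# COMMUNICATION keys are dictated by the code above, all values are placeholders.
config = {
    "DATASET": {
        "dataset_package": "decentralizepy.datasets.Femnist",  # assumed dataset module
        "dataset_class": "Femnist",
        "model_class": "CNN",
    },
    "OPTIMIZER_PARAMS": {
        "optimizer_package": "torch.optim",
        "optimizer_class": "SGD",
        "lr": 0.01,  # assumed learning rate
    },
    "TRAIN_PARAMS": {
        "training_package": "decentralizepy.training.Training",
        "training_class": "Training",
        "epochs_per_round": 25,
        "batch_size": 64,
    },
    "COMMUNICATION": {
        "comm_package": "decentralizepy.communication.TCP",
        "comm_class": "TCP",
        "addresses_filepath": "ip_addr.json",  # hypothetical address file
    },
    "SHARING": {
        "sharing_package": "decentralizepy.sharing.Sharing",  # assumed sharing module
        "sharing_class": "Sharing",
    },
}
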
diff --git a/src/decentralizepy/node/DPSGDNodeFederated.py b/src/decentralizepy/node/DPSGDNodeFederated.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1e36df1b936b33a4f415b469f1d803e936f1eec
--- /dev/null
+++ b/src/decentralizepy/node/DPSGDNodeFederated.py
@@ -0,0 +1,353 @@
+import importlib
+import json
+import logging
+import math
+import os
+from collections import deque
+
+import torch
+
+from decentralizepy import utils
+from decentralizepy.graphs.Graph import Graph
+from decentralizepy.mappings.Mapping import Mapping
+from decentralizepy.node.Node import Node
+
+
+class DPSGDNodeFederated(Node):
+    """
+    This class defines the node for federated DPSGD
+
+    """
+
+    def run(self):
+        """
+        Start the federated learning at this worker node
+
+        """
+        while len(self.barrier):
+            sender, data = self.receive_channel("WORKER_REQUEST")
+
+            if "BYE" in data:
+                logging.info("Received {} from {}".format("BYE", sender))
+                self.barrier.remove(sender)
+                break
+
+            iteration = data["iteration"]
+            del data["iteration"]
+            del data["CHANNEL"]
+
+            self.model.load_state_dict(data["params"])
+            self.sharing._post_step()
+            self.sharing.communication_round += 1
+
+            logging.info(
+                "Received worker request at node {}, global iteration {}, local round {}".format(
+                    self.uid, iteration, self.participated
+                )
+            )
+
+            if self.reset_optimizer:
+                self.optimizer = self.optimizer_class(
+                    self.model.parameters(), **self.optimizer_params
+                )  # Reset optimizer state
+                self.trainer.reset_optimizer(self.optimizer)
+
+            # Perform iteration
+            logging.info("Starting training iteration")
+            self.trainer.train(self.dataset)
+
+            # Send update to server
+            to_send = self.sharing.get_data_to_send()
+            to_send["CHANNEL"] = "DPSGD"
+            self.communication.send(self.parameter_server_uid, to_send)
+
+            if self.participated > 0:
+                with open(
+                    os.path.join(self.log_dir, "{}_results.json".format(self.rank)),
+                    "r",
+                ) as inf:
+                    results_dict = json.load(inf)
+            else:
+                results_dict = {
+                    "train_loss": {},
+                    "test_loss": {},
+                    "test_acc": {},
+                    "total_bytes": {},
+                    "total_meta": {},
+                    "total_data_per_n": {},
+                }
+
+            results_dict["total_bytes"][iteration + 1] = self.communication.total_bytes
+
+            if hasattr(self.communication, "total_meta"):
+                results_dict["total_meta"][
+                    iteration + 1
+                ] = self.communication.total_meta
+            if hasattr(self.communication, "total_data"):
+                results_dict["total_data_per_n"][
+                    iteration + 1
+                ] = self.communication.total_data
+
+            with open(
+                os.path.join(self.log_dir, "{}_results.json".format(self.rank)), "w"
+            ) as of:
+                json.dump(results_dict, of)
+
+            self.participated += 1
+
+        # Only dump the final weights if this node participated in training
+        if self.participated > 0:
+            logging.info("Storing final weight")
+            self.model.dump_weights(self.weights_store_dir, self.uid, iteration)
+
+        logging.info("Server disconnected. Process complete!")
+
+    def cache_fields(
+        self,
+        rank,
+        machine_id,
+        mapping,
+        graph,
+        iterations,
+        log_dir,
+        weights_store_dir,
+        test_after,
+        train_evaluate_after,
+        reset_optimizer,
+    ):
+        """
+        Instantiate object fields with the given arguments.
+
+        Parameters
+        ----------
+        rank : int
+            Rank of process local to the machine
+        machine_id : int
+            Machine ID on which the process is running
+        mapping : decentralizepy.mappings
+            The object containing the mapping rank <--> uid
+        graph : decentralizepy.graphs
+            The object containing the global graph
+        iterations : int
+            Number of iterations (communication steps) for which the model should be trained
+        log_dir : str
+            Logging directory
+        weights_store_dir : str
+            Directory in which to store model weights
+        test_after : int
+            Number of iterations after which the test loss and accuracy are calculated
+        train_evaluate_after : int
+            Number of iterations after which the train loss is calculated
+        reset_optimizer : int
+            1 if optimizer should be reset every communication round, else 0
+        """
+        self.rank = rank
+        self.machine_id = machine_id
+        self.graph = graph
+        self.mapping = mapping
+        self.uid = self.mapping.get_uid(rank, machine_id)
+        self.log_dir = log_dir
+        self.weights_store_dir = weights_store_dir
+        self.iterations = iterations
+        self.test_after = test_after
+        self.train_evaluate_after = train_evaluate_after
+        self.reset_optimizer = reset_optimizer
+        self.sent_disconnections = False
+
+        logging.info("Rank: %d", self.rank)
+        logging.info("type(graph): %s", str(type(self.rank)))
+        logging.info("type(mapping): %s", str(type(self.mapping)))
+
+    def init_comm(self, comm_configs):
+        """
+        Instantiate communication module from config.
+
+        Parameters
+        ----------
+        comm_configs : dict
+            Python dict containing communication config params
+
+        """
+        comm_module = importlib.import_module(comm_configs["comm_package"])
+        comm_class = getattr(comm_module, comm_configs["comm_class"])
+        comm_params = utils.remove_keys(comm_configs, ["comm_package", "comm_class"])
+        self.addresses_filepath = comm_params.get("addresses_filepath", None)
+        self.communication = comm_class(
+            self.rank, self.machine_id, self.mapping, self.graph.n_procs, **comm_params
+        )
+
+    def instantiate(
+        self,
+        rank: int,
+        machine_id: int,
+        mapping: Mapping,
+        graph: Graph,
+        config,
+        iterations=1,
+        log_dir=".",
+        weights_store_dir=".",
+        log_level=logging.INFO,
+        test_after=5,
+        train_evaluate_after=1,
+        reset_optimizer=1,
+        *args
+    ):
+        """
+        Construct objects.
+
+        Parameters
+        ----------
+        rank : int
+            Rank of process local to the machine
+        machine_id : int
+            Machine ID on which the process in running
+        mapping : decentralizepy.mappings
+            The object containing the mapping rank <--> uid
+        graph : decentralizepy.graphs
+            The object containing the global graph
+        config : dict
+            A dictionary of configurations.
+        iterations : int
+            Number of iterations (communication steps) for which the model should be trained
+        log_dir : str
+            Logging directory
+        weights_store_dir : str
+            Directory in which to store model weights
+        log_level : logging.Level
+            One of DEBUG, INFO, WARNING, ERROR, CRITICAL
+        test_after : int
+            Number of iterations after which the test loss and accuracy are calculated
+        train_evaluate_after : int
+            Number of iterations after which the train loss is calculated
+        reset_optimizer : int
+            1 if optimizer should be reset every communication round, else 0
+        args : optional
+            Other arguments
+
+        """
+        logging.info("Started process.")
+
+        self.init_log(log_dir, rank, log_level)
+
+        self.cache_fields(
+            rank,
+            machine_id,
+            mapping,
+            graph,
+            iterations,
+            log_dir,
+            weights_store_dir,
+            test_after,
+            train_evaluate_after,
+            reset_optimizer,
+        )
+        self.init_dataset_model(config["DATASET"])
+        self.init_optimizer(config["OPTIMIZER_PARAMS"])
+        self.init_trainer(config["TRAIN_PARAMS"])
+        self.init_comm(config["COMMUNICATION"])
+
+        self.message_queue = dict()
+
+        self.barrier = set()
+
+        self.participated = 0
+
+        self.init_sharing(config["SHARING"])
+
+    def __init__(
+        self,
+        rank: int,
+        machine_id: int,
+        mapping: Mapping,
+        graph: Graph,
+        config,
+        iterations=1,
+        log_dir=".",
+        weights_store_dir=".",
+        log_level=logging.INFO,
+        test_after=5,
+        train_evaluate_after=1,
+        reset_optimizer=1,
+        parameter_server_uid=-1,
+        *args
+    ):
+        """
+        Constructor
+
+        Parameters
+        ----------
+        rank : int
+            Rank of process local to the machine
+        machine_id : int
+            Machine ID on which the process is running
+        mapping : decentralizepy.mappings
+            The object containing the mapping rank <--> uid
+        graph : decentralizepy.graphs
+            The object containing the global graph
+        config : dict
+            A dictionary of configurations. Must contain the following:
+            [DATASET]
+                dataset_package
+                dataset_class
+                model_class
+            [OPTIMIZER_PARAMS]
+                optimizer_package
+                optimizer_class
+            [TRAIN_PARAMS]
+                training_package = decentralizepy.training.Training
+                training_class = Training
+                epochs_per_round = 25
+                batch_size = 64
+        iterations : int
+            Number of iterations (communication steps) for which the model should be trained
+        log_dir : str
+            Logging directory
+        weights_store_dir : str
+            Directory in which to store model weights
+        log_level : logging.Level
+            One of DEBUG, INFO, WARNING, ERROR, CRITICAL
+        test_after : int
+            Number of iterations after which the test loss and accuracy are calculated
+        train_evaluate_after : int
+            Number of iterations after which the train loss is calculated
+        reset_optimizer : int
+            1 if optimizer should be reset every communication round, else 0
+        parameter_server_uid: int
+            The parameter server's uid
+        args : optional
+            Other arguments
+
+        """
+
+        total_threads = os.cpu_count()
+        self.threads_per_proc = max(
+            math.floor(total_threads / mapping.procs_per_machine), 1
+        )
+        torch.set_num_threads(self.threads_per_proc)
+        torch.set_num_interop_threads(1)
+        self.instantiate(
+            rank,
+            machine_id,
+            mapping,
+            graph,
+            config,
+            iterations,
+            log_dir,
+            weights_store_dir,
+            log_level,
+            test_after,
+            train_evaluate_after,
+            reset_optimizer,
+            *args
+        )
+        logging.info(
+            "Each proc uses %d threads out of %d.", self.threads_per_proc, total_threads
+        )
+
+        self.message_queue["PEERS"] = deque()
+
+        self.parameter_server_uid = parameter_server_uid
+        self.connect_neighbor(self.parameter_server_uid)
+        self.wait_for_hello(self.parameter_server_uid)
+
+        self.run()
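
For reference, the worker loop above distinguishes three message shapes on its two channels; a minimal sketch is given below. The field names follow the code in run() (CHANNEL, iteration, params, BYE); the concrete values and the exact payload produced by sharing.get_data_to_send() are placeholders that depend on the configured sharing class.

# Sketch of the messages seen by DPSGDNodeFederated.run(); values are illustrative.
server_uid = 0  # hypothetical uid of the parameter server

# Server -> worker: model broadcast that triggers one local training round.
work_request = {
    "CHANNEL": "WORKER_REQUEST",
    "iteration": 7,   # global round number
    "params": {},     # state_dict of the global model
}

# Worker -> server: the sharing payload, tagged with the DPSGD channel.
model_update = {"CHANNEL": "DPSGD"}  # merged with sharing.get_data_to_send()

# Server -> worker: terminates the while-loop over the barrier.
goodbye = {"CHANNEL": "WORKER_REQUEST", "BYE": server_uid}
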
diff --git a/src/decentralizepy/node/DPSGDWithPeerSampler.py b/src/decentralizepy/node/DPSGDWithPeerSampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4f90e744e90d51916b3bf8f39667b74a49f6259
--- /dev/null
+++ b/src/decentralizepy/node/DPSGDWithPeerSampler.py
@@ -0,0 +1,155 @@
+import logging
+import math
+import os
+from collections import deque
+
+import torch
+
+from decentralizepy.graphs.Graph import Graph
+from decentralizepy.mappings.Mapping import Mapping
+from decentralizepy.node.DPSGDNode import DPSGDNode
+
+
+class DPSGDWithPeerSampler(DPSGDNode):
+    """
+    This class defines the node for DPSGD with a peer sampler
+
+    """
+
+    def receive_neighbors(self):
+        return self.receive_channel("PEERS")[1]["NEIGHBORS"]
+
+    def get_neighbors(self, node=None):
+        logging.info("Requesting neighbors from the peer sampler.")
+        self.communication.send(
+            self.peer_sampler_uid,
+            {
+                "REQUEST_NEIGHBORS": self.uid,
+                "iteration": self.iteration,
+                "CHANNEL": "SERVER_REQUEST",
+            },
+        )
+        my_neighbors = self.receive_neighbors()
+        logging.info("Neighbors this round: {}".format(my_neighbors))
+        return my_neighbors
+
+    def __init__(
+        self,
+        rank: int,
+        machine_id: int,
+        mapping: Mapping,
+        graph: Graph,
+        config,
+        iterations=1,
+        log_dir=".",
+        weights_store_dir=".",
+        log_level=logging.INFO,
+        test_after=5,
+        train_evaluate_after=1,
+        reset_optimizer=1,
+        peer_sampler_uid=-1,
+        *args
+    ):
+        """
+        Constructor
+
+        Parameters
+        ----------
+        rank : int
+            Rank of process local to the machine
+        machine_id : int
+            Machine ID on which the process is running
+        mapping : decentralizepy.mappings
+            The object containing the mapping rank <--> uid
+        graph : decentralizepy.graphs
+            The object containing the global graph
+        config : dict
+            A dictionary of configurations. Must contain the following:
+            [DATASET]
+                dataset_package
+                dataset_class
+                model_class
+            [OPTIMIZER_PARAMS]
+                optimizer_package
+                optimizer_class
+            [TRAIN_PARAMS]
+                training_package = decentralizepy.training.Training
+                training_class = Training
+                epochs_per_round = 25
+                batch_size = 64
+        iterations : int
+            Number of iterations (communication steps) for which the model should be trained
+        log_dir : str
+            Logging directory
+        weights_store_dir : str
+            Directory in which to store model weights
+        log_level : logging.Level
+            One of DEBUG, INFO, WARNING, ERROR, CRITICAL
+        test_after : int
+            Number of iterations after which the test loss and accuracy are calculated
+        train_evaluate_after : int
+            Number of iterations after which the train loss is calculated
+        reset_optimizer : int
+            1 if optimizer should be reset every communication round, else 0
+        args : optional
+            Other arguments
+
+        """
+
+        total_threads = os.cpu_count()
+        self.threads_per_proc = max(
+            math.floor(total_threads / mapping.procs_per_machine), 1
+        )
+        torch.set_num_threads(self.threads_per_proc)
+        torch.set_num_interop_threads(1)
+        self.instantiate(
+            rank,
+            machine_id,
+            mapping,
+            graph,
+            config,
+            iterations,
+            log_dir,
+            weights_store_dir,
+            log_level,
+            test_after,
+            train_evaluate_after,
+            reset_optimizer,
+            *args
+        )
+        logging.info(
+            "Each proc uses %d threads out of %d.", self.threads_per_proc, total_threads
+        )
+
+        self.message_queue["PEERS"] = deque()
+
+        self.peer_sampler_uid = peer_sampler_uid
+        self.connect_neighbor(self.peer_sampler_uid)
+        self.wait_for_hello(self.peer_sampler_uid)
+
+        self.run()
+
+    def disconnect_neighbors(self):
+        """
+        Disconnects all neighbors.
+
+        Raises
+        ------
+        RuntimeError
+            If received another message while waiting for BYEs
+
+        """
+        if not self.sent_disconnections:
+            logging.info("Disconnecting neighbors")
+            for uid in self.barrier:
+                self.communication.send(uid, {"BYE": self.uid, "CHANNEL": "DISCONNECT"})
+            self.communication.send(
+                self.peer_sampler_uid, {"BYE": self.uid, "CHANNEL": "SERVER_REQUEST"}
+            )
+            self.sent_disconnections = True
+
+            self.barrier.remove(self.peer_sampler_uid)
+
+            while len(self.barrier):
+                sender, _ = self.receive_disconnect()
+                self.barrier.remove(sender)
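
The exchange with the peer sampler is implicit in get_neighbors() and receive_neighbors() above; the sketch below makes the two message shapes (plus the final BYE from disconnect_neighbors()) explicit. All uids and the neighbour list are made-up example values.

# Sketch of the node <-> peer sampler exchange; values are illustrative.
node_uid = 3  # hypothetical uid of this node

# Node -> sampler (sent to self.peer_sampler_uid by get_neighbors()).
neighbor_request = {
    "CHANNEL": "SERVER_REQUEST",
    "REQUEST_NEIGHBORS": node_uid,
    "iteration": 12,
}

# Sampler -> node (consumed by receive_neighbors() via the PEERS channel).
neighbor_reply = {
    "CHANNEL": "PEERS",
    "NEIGHBORS": [1, 4, 5],  # neighbours drawn for this round
}

# Node -> sampler, sent once when disconnecting.
goodbye = {"CHANNEL": "SERVER_REQUEST", "BYE": node_uid}
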
diff --git a/src/decentralizepy/node/FederatedParameterServer.py b/src/decentralizepy/node/FederatedParameterServer.py
new file mode 100644
index 0000000000000000000000000000000000000000..75fb11134ac2289a8857f07d4a8139ed99a1316b
--- /dev/null
+++ b/src/decentralizepy/node/FederatedParameterServer.py
@@ -0,0 +1,483 @@
+import importlib
+import json
+import logging
+import math
+import os
+import random
+from collections import deque
+
+from matplotlib import pyplot as plt
+
+from decentralizepy import utils
+from decentralizepy.graphs.Graph import Graph
+from decentralizepy.mappings.Mapping import Mapping
+from decentralizepy.node.Node import Node
+
+
+class FederatedParameterServer(Node):
+    """
+    This class defines the parameter-serving service
+
+    """
+
+    def save_plot(self, l, label, title, xlabel, filename):
+        """
+        Save Matplotlib plot. Clears previous plots.
+
+        Parameters
+        ----------
+        l : dict
+            dict of x -> y. `x` must be castable to int.
+        label : str
+            label of the plot. Used for legend.
+        title : str
+            Header
+        xlabel : str
+            x-axis label
+        filename : str
+            Name of file to save the plot as.
+
+        """
+        plt.clf()
+        y_axis = [l[key] for key in l.keys()]
+        x_axis = list(map(int, l.keys()))
+        plt.plot(x_axis, y_axis, label=label)
+        plt.xlabel(xlabel)
+        plt.title(title)
+        plt.savefig(filename)
+
+    def init_log(self, log_dir, log_level, force=True):
+        """
+        Instantiate Logging.
+
+        Parameters
+        ----------
+        log_dir : str
+            Logging directory
+        log_level : logging.Level
+            One of DEBUG, INFO, WARNING, ERROR, CRITICAL
+        force : bool
+            Argument to logging.basicConfig()
+
+        """
+        log_file = os.path.join(log_dir, "ParameterServer.log")
+        logging.basicConfig(
+            filename=log_file,
+            format="[%(asctime)s][%(module)s][%(levelname)s] %(message)s",
+            level=log_level,
+            force=force,
+        )
+
+    def cache_fields(
+        self,
+        rank,
+        machine_id,
+        mapping,
+        graph,
+        iterations,
+        log_dir,
+        weights_store_dir,
+        test_after,
+        train_evaluate_after,
+    ):
+        """
+        Instantiate object fields with the given arguments.
+
+        Parameters
+        ----------
+        rank : int
+            Rank of process local to the machine
+        machine_id : int
+            Machine ID on which the process is running
+        mapping : decentralizepy.mappings
+            The object containing the mapping rank <--> uid
+        graph : decentralizepy.graphs
+            The object containing the global graph
+        iterations : int
+            Number of iterations (communication steps) for which the model should be trained
+        log_dir : str
+            Logging directory
+        weights_store_dir : str
+            Directory in which to store model weights
+        test_after : int
+            Number of iterations after which the test loss and accuracy are calculated
+        train_evaluate_after : int
+            Number of iterations after which the train loss is calculated
+
+        """
+        self.rank = rank
+        self.machine_id = machine_id
+        self.graph = graph
+        self.mapping = mapping
+        self.uid = self.mapping.get_uid(rank, machine_id)
+        self.log_dir = log_dir
+        self.iterations = iterations
+        self.sent_disconnections = False
+        self.weights_store_dir = weights_store_dir
+        self.test_after = test_after
+        self.train_evaluate_after = train_evaluate_after
+
+        logging.info("Rank: %d", self.rank)
+        logging.info("type(graph): %s", str(type(self.rank)))
+        logging.info("type(mapping): %s", str(type(self.mapping)))
+
+    def init_comm(self, comm_configs):
+        """
+        Instantiate communication module from config.
+
+        Parameters
+        ----------
+        comm_configs : dict
+            Python dict containing communication config params
+
+        """
+        comm_module = importlib.import_module(comm_configs["comm_package"])
+        comm_class = getattr(comm_module, comm_configs["comm_class"])
+        comm_params = utils.remove_keys(comm_configs, ["comm_package", "comm_class"])
+        self.addresses_filepath = comm_params.get("addresses_filepath", None)
+        self.communication = comm_class(
+            self.rank, self.machine_id, self.mapping, self.graph.n_procs, **comm_params
+        )
+
+    def instantiate(
+        self,
+        rank: int,
+        machine_id: int,
+        mapping: Mapping,
+        graph: Graph,
+        config,
+        iterations=1,
+        log_dir=".",
+        weights_store_dir=".",
+        log_level=logging.INFO,
+        test_after=5,
+        train_evaluate_after=1,
+        *args
+    ):
+        """
+        Construct objects.
+
+        Parameters
+        ----------
+        rank : int
+            Rank of process local to the machine
+        machine_id : int
+            Machine ID on which the process is running
+        mapping : decentralizepy.mappings
+            The object containing the mapping rank <--> uid
+        graph : decentralizepy.graphs
+            The object containing the global graph
+        config : dict
+            A dictionary of configurations.
+        iterations : int
+            Number of iterations (communication steps) for which the model should be trained
+        log_dir : str
+            Logging directory
+        weights_store_dir : str
+            Directory in which to store model weights
+        log_level : logging.Level
+            One of DEBUG, INFO, WARNING, ERROR, CRITICAL
+        test_after : int
+            Number of iterations after which the test loss and accuracy are calculated
+        train_evaluate_after : int
+            Number of iterations after which the train loss is calculated
+        args : optional
+            Other arguments
+
+        """
+        logging.info("Started process.")
+
+        self.init_log(log_dir, log_level)
+
+        self.cache_fields(
+            rank,
+            machine_id,
+            mapping,
+            graph,
+            iterations,
+            log_dir,
+            weights_store_dir,
+            test_after,
+            train_evaluate_after,
+        )
+
+        self.message_queue = dict()
+
+        self.barrier = set()
+
+        self.peer_deques = dict()
+
+        self.init_dataset_model(config["DATASET"])
+        self.init_comm(config["COMMUNICATION"])
+        self.init_optimizer(config["OPTIMIZER_PARAMS"])
+        self.init_trainer(config["TRAIN_PARAMS"])
+
+        self.my_neighbors = self.graph.get_all_nodes()
+        self.connect_neighbors()
+
+        self.init_sharing(config["SHARING"])
+
+    def received_from_all(self):
+        """
+        Check if all current workers have sent the current iteration
+
+        Returns
+        -------
+        bool
+            True if required data has been received, False otherwise
+
+        """
+        for k in self.current_workers:
+            if (k not in self.peer_deques) or len(self.peer_deques[k]) == 0:
+                return False
+        return True
+
+    def disconnect_neighbors(self):
+        """
+        Disconnects all neighbors.
+
+        Raises
+        ------
+        RuntimeError
+            If received another message while waiting for BYEs
+
+        """
+        if not self.sent_disconnections:
+            logging.info("Disconnecting neighbors")
+
+            for neighbor in self.my_neighbors:
+                self.communication.send(
+                    neighbor, {"BYE": self.uid, "CHANNEL": "WORKER_REQUEST"}
+                )
+                self.barrier.remove(neighbor)
+
+            self.sent_disconnections = True
+
+    def get_working_nodes(self):
+        """
+        Randomly select a subset of nodes to act as workers for the current iteration
+
+        """
+        k = int(math.ceil(len(self.my_neighbors) * self.working_fraction))
+        return random.sample(self.my_neighbors, k)
+
+    def run(self):
+        """
+        Start the federated parameter-serving service.
+
+        """
+        self.testset = self.dataset.get_testset()
+        rounds_to_test = self.test_after
+        rounds_to_train_evaluate = self.train_evaluate_after
+        global_epoch = 1
+        change = 1
+
+        to_send = dict()
+
+        for iteration in range(self.iterations):
+            self.iteration = iteration
+            # reset deques after each iteration
+            self.peer_deques = dict()
+
+            # Get workers for this iteration
+            self.current_workers = self.get_working_nodes()
+
+            # Params to send to workers
+            to_send["params"] = self.model.state_dict()
+            to_send["CHANNEL"] = "WORKER_REQUEST"
+            to_send["iteration"] = iteration
+
+            # Notify workers
+            for worker in self.current_workers:
+                self.communication.send(worker, to_send)
+
+            # Receive updates from current workers
+            while not self.received_from_all():
+                sender, data = self.receive_channel("DPSGD")
+                if sender not in self.peer_deques:
+                    self.peer_deques[sender] = deque()
+                self.peer_deques[sender].append(data)
+
+            logging.info("Received from all current workers")
+
+            # Average received updates
+            averaging_deque = dict()
+            for worker in self.current_workers:
+                averaging_deque[worker] = self.peer_deques[worker]
+
+            self.sharing._pre_step()
+            self.sharing._averaging_server(averaging_deque)
+
+            if iteration:
+                with open(
+                    os.path.join(self.log_dir, "{}_results.json".format(self.rank)),
+                    "r",
+                ) as inf:
+                    results_dict = json.load(inf)
+            else:
+                results_dict = {
+                    "train_loss": {},
+                    "test_loss": {},
+                    "test_acc": {},
+                    "total_bytes": {},
+                    "total_meta": {},
+                    "total_data_per_n": {},
+                }
+
+            results_dict["total_bytes"][iteration + 1] = self.communication.total_bytes
+
+            if hasattr(self.communication, "total_meta"):
+                results_dict["total_meta"][
+                    iteration + 1
+                ] = self.communication.total_meta
+            if hasattr(self.communication, "total_data"):
+                results_dict["total_data_per_n"][
+                    iteration + 1
+                ] = self.communication.total_data
+
+            rounds_to_train_evaluate -= 1
+
+            if rounds_to_train_evaluate == 0:
+                logging.info("Evaluating on train set.")
+                rounds_to_train_evaluate = self.train_evaluate_after * change
+                loss_after_sharing = self.trainer.eval_loss(self.dataset)
+                results_dict["train_loss"][iteration + 1] = loss_after_sharing
+                self.save_plot(
+                    results_dict["train_loss"],
+                    "train_loss",
+                    "Training Loss",
+                    "Communication Rounds",
+                    os.path.join(self.log_dir, "{}_train_loss.png".format(self.rank)),
+                )
+
+            rounds_to_test -= 1
+
+            if self.dataset.__testing__ and rounds_to_test == 0:
+                rounds_to_test = self.test_after * change
+                logging.info("Evaluating on test set.")
+                ta, tl = self.dataset.test(self.model, self.loss)
+                results_dict["test_acc"][iteration + 1] = ta
+                results_dict["test_loss"][iteration + 1] = tl
+
+                if global_epoch == 49:
+                    change *= 2
+
+                global_epoch += change
+
+            with open(
+                os.path.join(self.log_dir, "{}_results.json".format(self.rank)), "w"
+            ) as of:
+                json.dump(results_dict, of)
+
+        if self.model.shared_parameters_counter is not None:
+            logging.info("Saving the shared parameter counts")
+            with open(
+                os.path.join(
+                    self.log_dir, "{}_shared_parameters.json".format(self.rank)
+                ),
+                "w",
+            ) as of:
+                json.dump(self.model.shared_parameters_counter.numpy().tolist(), of)
+
+        self.disconnect_neighbors()
+        logging.info("Storing final weight")
+        self.model.dump_weights(self.weights_store_dir, self.uid, iteration)
+        logging.info("All neighbors disconnected. Process complete!")
+
+    def __init__(
+        self,
+        rank: int,
+        machine_id: int,
+        mapping: Mapping,
+        graph: Graph,
+        config,
+        iterations=1,
+        log_dir=".",
+        weights_store_dir=".",
+        log_level=logging.INFO,
+        test_after=5,
+        train_evaluate_after=1,
+        working_fraction=1.0,
+        *args
+    ):
+        """
+        Constructor
+
+        Parameters
+        ----------
+        rank : int
+            Rank of process local to the machine
+        machine_id : int
+            Machine ID on which the process is running
+        mapping : decentralizepy.mappings
+            The object containing the mapping rank <--> uid
+        graph : decentralizepy.graphs
+            The object containing the global graph
+        config : dict
+            A dictionary of configurations. Must contain the following:
+            [DATASET]
+                dataset_package
+                dataset_class
+                model_class
+            [OPTIMIZER_PARAMS]
+                optimizer_package
+                optimizer_class
+            [TRAIN_PARAMS]
+                training_package = decentralizepy.training.Training
+                training_class = Training
+                epochs_per_round = 25
+                batch_size = 64
+        iterations : int
+            Number of iterations (communication steps) for which the model should be trained
+        log_dir : str
+            Logging directory
+        weights_store_dir : str
+            Directory in which to store model weights
+        log_level : logging.Level
+            One of DEBUG, INFO, WARNING, ERROR, CRITICAL
+        test_after : int
+            Number of iterations after which the test loss and accuracy are calculated
+        train_evaluate_after : int
+            Number of iterations after which the train loss is calculated
+        working_fraction : float
+            Fraction of nodes participating in one global iteration
+        args : optional
+            Other arguments
+
+        """
+        super().__init__(
+            rank,
+            machine_id,
+            mapping,
+            graph,
+            config,
+            iterations,
+            log_dir,
+            log_level,
+            *args
+        )
+
+        self.instantiate(
+            rank,
+            machine_id,
+            mapping,
+            graph,
+            config,
+            iterations,
+            log_dir,
+            weights_store_dir,
+            log_level,
+            test_after,
+            train_evaluate_after,
+            *args
+        )
+
+        self.working_fraction = working_fraction
+
+        random.seed(self.mapping.get_uid(self.rank, self.machine_id))
+
+        self.run()
+
+        logging.info("Parameter Server exiting")
diff --git a/src/decentralizepy/node/Node.py b/src/decentralizepy/node/Node.py
index 91f34e5ff4cb8540315c23e22fa951f2df7eac76..145b36297909314c663cc3402e9cd4c0205c1dc5 100644
--- a/src/decentralizepy/node/Node.py
+++ b/src/decentralizepy/node/Node.py
@@ -1,18 +1,14 @@
 import importlib
-import json
 import logging
 import math
 import os
+from collections import deque
 
 import torch
-from matplotlib import pyplot as plt
 
 from decentralizepy import utils
-from decentralizepy.communication.TCP import TCP
 from decentralizepy.graphs.Graph import Graph
-from decentralizepy.graphs.Star import Star
 from decentralizepy.mappings.Mapping import Mapping
-from decentralizepy.train_test_evaluation import TrainTestHelper
 
 
 class Node:
@@ -21,31 +17,103 @@ class Node:
 
     """
 
-    def save_plot(self, l, label, title, xlabel, filename):
+    def connect_neighbor(self, neighbor):
         """
-        Save Matplotlib plot. Clears previous plots.
+        Connects given neighbor. Sends HELLO.
 
-        Parameters
-        ----------
-        l : dict
-            dict of x -> y. `x` must be castable to int.
-        label : str
-            label of the plot. Used for legend.
-        title : str
-            Header
-        xlabel : str
-            x-axis label
-        filename : str
-            Name of file to save the plot as.
+        """
+        logging.info("Sending connection request to {}".format(neighbor))
+        self.communication.init_connection(neighbor)
+        self.communication.send(neighbor, {"HELLO": self.uid, "CHANNEL": "CONNECT"})
+
+    def receive_channel(self, channel):
+        if channel not in self.message_queue:
+            self.message_queue[channel] = deque()
+
+        if len(self.message_queue[channel]) > 0:
+            return self.message_queue[channel].popleft()
+        else:
+            sender, recv = self.communication.receive()
+            logging.info(
+                "Received some message from {} with CHANNEL: {}".format(
+                    sender, recv["CHANNEL"]
+                )
+            )
+            assert "CHANNEL" in recv
+            while recv["CHANNEL"] != channel:
+                if recv["CHANNEL"] not in self.message_queue:
+                    self.message_queue[recv["CHANNEL"]] = deque()
+                self.message_queue[recv["CHANNEL"]].append((sender, recv))
+                sender, recv = self.communication.receive()
+                logging.info(
+                    "Received some message from {} with CHANNEL: {}".format(
+                        sender, recv["CHANNEL"]
+                    )
+                )
+            return (sender, recv)
+
+    def receive_hello(self):
+        return self.receive_channel("CONNECT")
+
+    def wait_for_hello(self, neighbor):
+        """
+        Waits for HELLO.
+        Caches any data received while waiting for HELLOs.
+
+        Raises
+        ------
+        RuntimeError
+            If received BYE while waiting for HELLO
+
+        """
+        while neighbor not in self.barrier:
+            logging.info("Waiting HELLO from {}".format(neighbor))
+            sender, _ = self.receive_hello()
+            logging.info("Received HELLO from {}".format(sender))
+            self.barrier.add(sender)
+
+    def connect_neighbors(self):
+        """
+        Connects all neighbors. Sends HELLO. Waits for HELLO.
+        Caches any data received while waiting for HELLOs.
+
+        Raises
+        ------
+        RuntimeError
+            If received BYE while waiting for HELLO
+
+        """
+        logging.info("Sending connection request to all neighbors")
+        wait_acknowledgements = []
+        for neighbor in self.my_neighbors:
+            if not self.communication.already_connected(neighbor):
+                self.connect_neighbor(neighbor)
+                wait_acknowledgements.append(neighbor)
+
+        for neighbor in wait_acknowledgements:
+            self.wait_for_hello(neighbor)
+
+    def receive_disconnect(self):
+        return self.receive_channel("DISCONNECT")
+
+    def disconnect_neighbors(self):
+        """
+        Disconnects all neighbors.
+
+        Raises
+        ------
+        RuntimeError
+            If received another message while waiting for BYEs
 
         """
-        plt.clf()
-        y_axis = [l[key] for key in l.keys()]
-        x_axis = list(map(int, l.keys()))
-        plt.plot(x_axis, y_axis, label=label)
-        plt.xlabel(xlabel)
-        plt.title(title)
-        plt.savefig(filename)
+        if not self.sent_disconnections:
+            logging.info("Disconnecting neighbors")
+            for uid in self.barrier:
+                self.communication.send(uid, {"BYE": self.uid, "CHANNEL": "DISCONNECT"})
+            self.sent_disconnections = True
+            while len(self.barrier):
+                sender, _ = self.receive_disconnect()
+                self.barrier.remove(sender)
 
     def init_log(self, log_dir, rank, log_level, force=True):
         """
@@ -68,7 +136,7 @@ class Node:
             filename=log_file,
             format="[%(asctime)s][%(module)s][%(levelname)s] %(message)s",
             level=log_level,
-            force=True,
+            force=force,
         )
 
     def cache_fields(
@@ -79,12 +147,6 @@ class Node:
         graph,
         iterations,
         log_dir,
-        weights_store_dir,
-        test_after,
-        train_evaluate_after,
-        reset_optimizer,
-        centralized_train_eval,
-        centralized_test_eval,
     ):
         """
         Instantiate object field with arguments.
@@ -103,18 +165,6 @@ class Node:
             Number of iterations (communication steps) for which the model should be trained
         log_dir : str
             Logging directory
-        weights_store_dir : str
-            Directory in which to store model weights
-        test_after : int
-            Number of iterations after which the test loss and accuracy arecalculated
-        train_evaluate_after : int
-            Number of iterations after which the train loss is calculated
-        reset_optimizer : int
-            1 if optimizer should be reset every communication round, else 0
-        centralized_train_eval : bool
-            If set the train set evaluation happens at the node with uid 0
-        centralized_test_eval : bool
-            If set the train set evaluation happens at the node with uid 0
         """
         self.rank = rank
         self.machine_id = machine_id
@@ -122,19 +172,12 @@ class Node:
         self.mapping = mapping
         self.uid = self.mapping.get_uid(rank, machine_id)
         self.log_dir = log_dir
-        self.weights_store_dir = weights_store_dir
         self.iterations = iterations
-        self.test_after = test_after
-        self.train_evaluate_after = train_evaluate_after
-        self.reset_optimizer = reset_optimizer
-        self.centralized_train_eval = centralized_train_eval
-        self.centralized_test_eval = centralized_test_eval
-
-        logging.debug("Rank: %d", self.rank)
-        logging.debug("type(graph): %s", str(type(self.rank)))
-        logging.debug("type(mapping): %s", str(type(self.mapping)))
+        self.sent_disconnections = False
 
-        self.star = Star(self.mapping.get_n_procs())
+        logging.info("Rank: %d", self.rank)
+        logging.info("type(graph): %s", str(type(self.rank)))
+        logging.info("type(mapping): %s", str(type(self.mapping)))
 
     def init_dataset_model(self, dataset_configs):
         """
@@ -243,17 +286,6 @@ class Node:
         comm_class = getattr(comm_module, comm_configs["comm_class"])
         comm_params = utils.remove_keys(comm_configs, ["comm_package", "comm_class"])
         self.addresses_filepath = comm_params.get("addresses_filepath", None)
-        if self.centralized_test_eval:
-            self.testing_comm = TCP(
-                self.rank,
-                self.machine_id,
-                self.mapping,
-                self.star.n_procs,
-                self.addresses_filepath,
-                offset=self.star.n_procs,
-            )
-            self.testing_comm.connect_neighbors(self.star.neighbors(self.uid))
-
         self.communication = comm_class(
             self.rank, self.machine_id, self.mapping, self.graph.n_procs, **comm_params
         )
@@ -294,13 +326,7 @@ class Node:
         config,
         iterations=1,
         log_dir=".",
-        weights_store_dir=".",
         log_level=logging.INFO,
-        test_after=5,
-        train_evaluate_after=1,
-        reset_optimizer=1,
-        centralized_train_eval=False,
-        centralized_test_eval=True,
         *args
     ):
         """
@@ -322,26 +348,16 @@ class Node:
             Number of iterations (communication steps) for which the model should be trained
         log_dir : str
             Logging directory
-        weights_store_dir : str
-            Directory in which to store model weights
         log_level : logging.Level
             One of DEBUG, INFO, WARNING, ERROR, CRITICAL
-        test_after : int
-            Number of iterations after which the test loss and accuracy arecalculated
-        train_evaluate_after : int
-            Number of iterations after which the train loss is calculated
-        reset_optimizer : int
-            1 if optimizer should be reset every communication round, else 0
-        centralized_train_eval : bool
-            If set the train set evaluation happens at the node with uid 0
-        centralized_test_eval : bool
-            If set the train set evaluation happens at the node with uid 0
         args : optional
             Other arguments
 
         """
         logging.info("Started process.")
 
+        self.init_log(log_dir, rank, log_level)
+
         self.cache_fields(
             rank,
             machine_id,
@@ -349,18 +365,17 @@ class Node:
             graph,
             iterations,
             log_dir,
-            weights_store_dir,
-            test_after,
-            train_evaluate_after,
-            reset_optimizer,
-            centralized_train_eval,
-            centralized_test_eval,
         )
-        self.init_log(log_dir, rank, log_level)
         self.init_dataset_model(config["DATASET"])
         self.init_optimizer(config["OPTIMIZER_PARAMS"])
         self.init_trainer(config["TRAIN_PARAMS"])
         self.init_comm(config["COMMUNICATION"])
+
+        self.message_queue = dict()
+
+        self.barrier = set()
+        self.my_neighbors = self.graph.neighbors(self.uid)
+
         self.init_sharing(config["SHARING"])
 
     def run(self):
@@ -368,146 +383,7 @@ class Node:
         Start the decentralized learning
 
         """
-        self.testset = self.dataset.get_testset()
-        self.communication.connect_neighbors(self.graph.neighbors(self.uid))
-        rounds_to_test = self.test_after
-        rounds_to_train_evaluate = self.train_evaluate_after
-        global_epoch = 1
-        change = 1
-        if self.uid == 0:
-            dataset = self.dataset
-            if self.centralized_train_eval:
-                dataset_params_copy = self.dataset_params.copy()
-                if "sizes" in dataset_params_copy:
-                    del dataset_params_copy["sizes"]
-                self.whole_dataset = self.dataset_class(
-                    self.rank,
-                    self.machine_id,
-                    self.mapping,
-                    sizes=[1.0],
-                    **dataset_params_copy
-                )
-                dataset = self.whole_dataset
-            if self.centralized_test_eval:
-                tthelper = TrainTestHelper(
-                    dataset,  # self.whole_dataset,
-                    # self.model_test, # todo: this only works if eval_train is set to false
-                    self.model,
-                    self.loss,
-                    self.weights_store_dir,
-                    self.mapping.get_n_procs(),
-                    self.trainer,
-                    self.testing_comm,
-                    self.star,
-                    self.threads_per_proc,
-                    eval_train=self.centralized_train_eval,
-                )
-
-        for iteration in range(self.iterations):
-            logging.info("Starting training iteration: %d", iteration)
-            self.trainer.train(self.dataset)
-
-            self.sharing.step()
-
-            if self.reset_optimizer:
-                self.optimizer = self.optimizer_class(
-                    self.model.parameters(), **self.optimizer_params
-                )  # Reset optimizer state
-                self.trainer.reset_optimizer(self.optimizer)
-
-            if iteration:
-                with open(
-                    os.path.join(self.log_dir, "{}_results.json".format(self.rank)),
-                    "r",
-                ) as inf:
-                    results_dict = json.load(inf)
-            else:
-                results_dict = {
-                    "train_loss": {},
-                    "test_loss": {},
-                    "test_acc": {},
-                    "total_bytes": {},
-                    "total_meta": {},
-                    "total_data_per_n": {},
-                    "grad_mean": {},
-                    "grad_std": {},
-                }
-
-            results_dict["total_bytes"][iteration + 1] = self.communication.total_bytes
-
-            if hasattr(self.communication, "total_meta"):
-                results_dict["total_meta"][
-                    iteration + 1
-                ] = self.communication.total_meta
-            if hasattr(self.communication, "total_data"):
-                results_dict["total_data_per_n"][
-                    iteration + 1
-                ] = self.communication.total_data
-            if hasattr(self.sharing, "mean"):
-                results_dict["grad_mean"][iteration + 1] = self.sharing.mean
-            if hasattr(self.sharing, "std"):
-                results_dict["grad_std"][iteration + 1] = self.sharing.std
-
-            rounds_to_train_evaluate -= 1
-
-            if rounds_to_train_evaluate == 0 and not self.centralized_train_eval:
-                logging.info("Evaluating on train set.")
-                rounds_to_train_evaluate = self.train_evaluate_after * change
-                loss_after_sharing = self.trainer.eval_loss(self.dataset)
-                results_dict["train_loss"][iteration + 1] = loss_after_sharing
-                self.save_plot(
-                    results_dict["train_loss"],
-                    "train_loss",
-                    "Training Loss",
-                    "Communication Rounds",
-                    os.path.join(self.log_dir, "{}_train_loss.png".format(self.rank)),
-                )
-
-            rounds_to_test -= 1
-
-            if self.dataset.__testing__ and rounds_to_test == 0:
-                rounds_to_test = self.test_after * change
-                # ta, tl = self.dataset.test(self.model, self.loss)
-                # self.model.dump_weights(self.weights_store_dir, self.uid, iteration)
-                if self.centralized_test_eval:
-                    if self.uid == 0:
-                        ta, tl, trl = tthelper.train_test_evaluation(iteration)
-                        results_dict["test_acc"][iteration + 1] = ta
-                        results_dict["test_loss"][iteration + 1] = tl
-                        if trl is not None:
-                            results_dict["train_loss"][iteration + 1] = trl
-                    else:
-                        self.testing_comm.send(0, self.model.get_weights())
-                        sender, data = self.testing_comm.receive()
-                        assert sender == 0 and data == "finished"
-                else:
-                    logging.info("Evaluating on test set.")
-                    ta, tl = self.dataset.test(self.model, self.loss)
-                    results_dict["test_acc"][iteration + 1] = ta
-                    results_dict["test_loss"][iteration + 1] = tl
-
-                if global_epoch == 49:
-                    change *= 2
-
-                global_epoch += change
-
-            with open(
-                os.path.join(self.log_dir, "{}_results.json".format(self.rank)), "w"
-            ) as of:
-                json.dump(results_dict, of)
-        if self.model.shared_parameters_counter is not None:
-            logging.info("Saving the shared parameter counts")
-            with open(
-                os.path.join(
-                    self.log_dir, "{}_shared_parameters.json".format(self.rank)
-                ),
-                "w",
-            ) as of:
-                json.dump(self.model.shared_parameters_counter.numpy().tolist(), of)
-        self.communication.disconnect_neighbors()
-        logging.info("Storing final weight")
-        self.model.dump_weights(self.weights_store_dir, self.uid, iteration)
-        logging.info("All neighbors disconnected. Process complete!")
+        raise NotImplementedError
 
     def __init__(
         self,
@@ -518,13 +394,7 @@ class Node:
         config,
         iterations=1,
         log_dir=".",
-        weights_store_dir=".",
         log_level=logging.INFO,
-        test_after=5,
-        train_evaluate_after=1,
-        reset_optimizer=1,
-        centralized_train_eval=0,
-        centralized_test_eval=1,
         *args
     ):
         """
@@ -559,28 +429,12 @@ class Node:
         log_dir : str
             Logging directory
-        weights_store_dir : str
-            Directory in which to store model weights
         log_level : logging.Level
             One of DEBUG, INFO, WARNING, ERROR, CRITICAL
-        test_after : int
-            Number of iterations after which the test loss and accuracy arecalculated
-        train_evaluate_after : int
-            Number of iterations after which the train loss is calculated
-        reset_optimizer : int
-            1 if optimizer should be reset every communication round, else 0
-        centralized_train_eval : int
-            If set then the train set evaluation happens at the node with uid 0.
-            Note: If it is True then centralized_test_eval needs to be true as well!
-        centralized_test_eval : int
-            If set then the trainset evaluation happens at the node with uid 0
         args : optional
             Other arguments
 
         """
-        centralized_train_eval = centralized_train_eval == 1
-        centralized_test_eval = centralized_test_eval == 1
-        # If centralized_train_eval is True then centralized_test_eval needs to be true as well!
-        assert not centralized_train_eval or centralized_test_eval
 
         total_threads = os.cpu_count()
         self.threads_per_proc = max(
@@ -588,25 +442,17 @@ class Node:
         )
         torch.set_num_threads(self.threads_per_proc)
         torch.set_num_interop_threads(1)
-        self.instantiate(
-            rank,
-            machine_id,
-            mapping,
-            graph,
-            config,
-            iterations,
-            log_dir,
-            weights_store_dir,
-            log_level,
-            test_after,
-            train_evaluate_after,
-            reset_optimizer,
-            centralized_train_eval == 1,
-            centralized_test_eval == 1,
-            *args
-        )
+        # self.instantiate(
+        #     rank,
+        #     machine_id,
+        #     mapping,
+        #     graph,
+        #     config,
+        #     iterations,
+        #     log_dir,
+        #     log_level,
+        #     *args
+        # )
         logging.info(
             "Each proc uses %d threads out of %d.", self.threads_per_proc, total_threads
         )
-
-        self.run()
diff --git a/src/decentralizepy/node/PeerSampler.py b/src/decentralizepy/node/PeerSampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..1311a26706b2f44303ab3a291850f2641b870b68
--- /dev/null
+++ b/src/decentralizepy/node/PeerSampler.py
@@ -0,0 +1,262 @@
+import importlib
+import logging
+import os
+from collections import deque
+
+from decentralizepy import utils
+from decentralizepy.graphs.Graph import Graph
+from decentralizepy.mappings.Mapping import Mapping
+from decentralizepy.node.Node import Node
+
+
+class PeerSampler(Node):
+    """
+    This class defines the peer sampling service
+
+    """
+
+    def init_log(self, log_dir, log_level, force=True):
+        """
+        Instantiate Logging.
+
+        Parameters
+        ----------
+        log_dir : str
+            Logging directory
+        rank : int
+            Rank of process local to the machine
+        log_level : logging.Level
+            One of DEBUG, INFO, WARNING, ERROR, CRITICAL
+        force : bool
+            Argument to logging.basicConfig()
+
+        """
+        log_file = os.path.join(log_dir, "PeerSampler.log")
+        logging.basicConfig(
+            filename=log_file,
+            format="[%(asctime)s][%(module)s][%(levelname)s] %(message)s",
+            level=log_level,
+            force=force,
+        )
+
+    def cache_fields(
+        self,
+        rank,
+        machine_id,
+        mapping,
+        graph,
+        iterations,
+        log_dir,
+    ):
+        """
+        Instantiate object fields with arguments.
+
+        Parameters
+        ----------
+        rank : int
+            Rank of process local to the machine
+        machine_id : int
+            Machine ID on which the process is running
+        mapping : decentralizepy.mappings
+            The object containing the mapping rank <--> uid
+        graph : decentralizepy.graphs
+            The object containing the global graph
+        iterations : int
+            Number of iterations (communication steps) for which the model should be trained
+        log_dir : str
+            Logging directory
+
+        """
+        self.rank = rank
+        self.machine_id = machine_id
+        self.graph = graph
+        self.mapping = mapping
+        self.uid = self.mapping.get_uid(rank, machine_id)
+        self.log_dir = log_dir
+        self.iterations = iterations
+        self.sent_disconnections = False
+
+        logging.info("Rank: %d", self.rank)
+        logging.info("type(graph): %s", str(type(self.rank)))
+        logging.info("type(mapping): %s", str(type(self.mapping)))
+
+    def init_comm(self, comm_configs):
+        """
+        Instantiate communication module from config.
+
+        Parameters
+        ----------
+        comm_configs : dict
+            Python dict containing communication config params
+
+        """
+        comm_module = importlib.import_module(comm_configs["comm_package"])
+        comm_class = getattr(comm_module, comm_configs["comm_class"])
+        comm_params = utils.remove_keys(comm_configs, ["comm_package", "comm_class"])
+        self.addresses_filepath = comm_params.get("addresses_filepath", None)
+        self.communication = comm_class(
+            self.rank, self.machine_id, self.mapping, self.graph.n_procs, **comm_params
+        )
+
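+    # A hedged example of the [COMMUNICATION] section consumed above (values
+    # are illustrative and mirror tutorial/config_celeba_sharing.ini):
+    #   comm_package = decentralizepy.communication.TCP
+    #   comm_class = TCP
+    #   addresses_filepath = <path to an ip.json mapping machine ids to IPs>
+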
+    def instantiate(
+        self,
+        rank: int,
+        machine_id: int,
+        mapping: Mapping,
+        graph: Graph,
+        config,
+        iterations=1,
+        log_dir=".",
+        log_level=logging.INFO,
+        *args
+    ):
+        """
+        Construct objects.
+
+        Parameters
+        ----------
+        rank : int
+            Rank of process local to the machine
+        machine_id : int
+            Machine ID on which the process is running
+        mapping : decentralizepy.mappings
+            The object containing the mapping rank <--> uid
+        graph : decentralizepy.graphs
+            The object containing the global graph
+        config : dict
+            A dictionary of configurations.
+        iterations : int
+            Number of iterations (communication steps) for which the model should be trained
+        log_dir : str
+            Logging directory
+        log_level : logging.Level
+            One of DEBUG, INFO, WARNING, ERROR, CRITICAL
+        args : optional
+            Other arguments
+
+        """
+        logging.info("Started process.")
+
+        self.init_log(log_dir, log_level)
+
+        self.cache_fields(
+            rank,
+            machine_id,
+            mapping,
+            graph,
+            iterations,
+            log_dir,
+        )
+
+        self.message_queue = dict()
+
+        self.barrier = set()
+
+        self.init_comm(config["COMMUNICATION"])
+        self.my_neighbors = self.graph.get_all_nodes()
+        self.connect_neighbors()
+
+    def get_neighbors(self, node, iteration=None):
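+        """
+        Return the neighbors of node in the static graph. The iteration
+        argument is unused here; dynamic subclasses (e.g. PeerSamplerDynamic)
+        use it to pick a per-iteration topology.
+
+        """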
+        return self.graph.neighbors(node)
+
+    def receive_server_request(self):
+        return self.receive_channel("SERVER_REQUEST")
+
+    def run(self):
+        """
+        Start the peer-sampling service.
+
+        """
+        while len(self.barrier) > 0:
+            sender, data = self.receive_server_request()
+            if "BYE" in data:
+                logging.debug("Received {} from {}".format("BYE", sender))
+                self.barrier.remove(sender)
+
+            elif "REQUEST_NEIGHBORS" in data:
+                logging.debug("Received {} from {}".format("Request", sender))
+                if "iteration" in data:
+                    resp = {
+                        "NEIGHBORS": self.get_neighbors(sender, data["iteration"]),
+                        "CHANNEL": "PEERS",
+                    }
+                else:
+                    resp = {"NEIGHBORS": self.get_neighbors(sender), "CHANNEL": "PEERS"}
+                self.communication.send(sender, resp)
+
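+    # Request/response sketch for run() above (the exact keys a worker puts in
+    # its request are an assumption; the reply shape is taken from the code):
+    #   worker -> sampler, channel "SERVER_REQUEST":
+    #       {"REQUEST_NEIGHBORS": ..., "iteration": <optional int>}  or  {"BYE": ...}
+    #   sampler -> worker:
+    #       {"NEIGHBORS": <list of node uids>, "CHANNEL": "PEERS"}
+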
+    def __init__(
+        self,
+        rank: int,
+        machine_id: int,
+        mapping: Mapping,
+        graph: Graph,
+        config,
+        iterations=1,
+        log_dir=".",
+        log_level=logging.INFO,
+        *args
+    ):
+        """
+        Constructor
+
+        Parameters
+        ----------
+        rank : int
+            Rank of process local to the machine
+        machine_id : int
+            Machine ID on which the process in running
+        mapping : decentralizepy.mappings
+            The object containing the mapping rank <--> uid
+        graph : decentralizepy.graphs
+            The object containing the global graph
+        config : dict
+            A dictionary of configurations. Only the [COMMUNICATION] section
+            is used here; it must contain:
+                comm_package
+                comm_class
+            plus any parameters of the chosen communication class
+            (e.g. addresses_filepath).
+        iterations : int
+            Number of iterations (communication steps) for which the model should be trained
+        log_dir : str
+            Logging directory
+        log_level : logging.Level
+            One of DEBUG, INFO, WARNING, ERROR, CRITICAL
+        args : optional
+            Other arguments
+
+        """
+        super().__init__(
+            rank,
+            machine_id,
+            mapping,
+            graph,
+            config,
+            iterations,
+            log_dir,
+            log_level,
+            *args
+        )
+
+        self.instantiate(
+            rank,
+            machine_id,
+            mapping,
+            graph,
+            config,
+            iterations,
+            log_dir,
+            log_level,
+            *args
+        )
+
+        self.run()
+
+        logging.info("Peer Sampler exiting")
diff --git a/src/decentralizepy/node/PeerSamplerDynamic.py b/src/decentralizepy/node/PeerSamplerDynamic.py
new file mode 100644
index 0000000000000000000000000000000000000000..ffee60a526a97d77d9d7d633b86bbccad7568cc7
--- /dev/null
+++ b/src/decentralizepy/node/PeerSamplerDynamic.py
@@ -0,0 +1,100 @@
+import logging
+from collections import deque
+
+from decentralizepy.graphs.Graph import Graph
+from decentralizepy.graphs.Regular import Regular
+from decentralizepy.mappings.Mapping import Mapping
+from decentralizepy.node.PeerSampler import PeerSampler
+
+
+class PeerSamplerDynamic(PeerSampler):
+    """
+    This class defines the peer sampling service for dynamically changing graphs
+
+    """
+
+    def get_neighbors(self, node, iteration=None):
+        if iteration is not None:
+            if iteration > self.iteration:
+                logging.info(
+                    "iteration, self.iteration: {}, {}".format(
+                        iteration, self.iteration
+                    )
+                )
+                assert iteration == self.iteration + 1
+                self.iteration = iteration
+                self.graphs.append(Regular(self.graph.n_procs, self.graph_degree))
+            return self.graphs[iteration].neighbors(node)
+        else:
+            return self.graph.neighbors(node)
+
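+    # Each time a request arrives for a new iteration, a fresh Regular graph of
+    # degree self.graph_degree is sampled and cached in self.graphs, so every
+    # node asking about the same iteration is answered from the same topology.
+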
+    def __init__(
+        self,
+        rank: int,
+        machine_id: int,
+        mapping: Mapping,
+        graph: Graph,
+        config,
+        iterations=1,
+        log_dir=".",
+        log_level=logging.INFO,
+        graph_degree=4,
+        *args
+    ):
+        """
+        Constructor
+
+        Parameters
+        ----------
+        rank : int
+            Rank of process local to the machine
+        machine_id : int
+            Machine ID on which the process is running
+        mapping : decentralizepy.mappings
+            The object containing the mapping rank <--> uid
+        graph : decentralizepy.graphs
+            The object containing the global graph
+        config : dict
+            A dictionary of configurations. Only the [COMMUNICATION] section
+            is used here; it must contain:
+                comm_package
+                comm_class
+            plus any parameters of the chosen communication class
+            (e.g. addresses_filepath).
+        iterations : int
+            Number of iterations (communication steps) for which the model should be trained
+        log_dir : str
+            Logging directory
+        log_level : logging.Level
+            One of DEBUG, INFO, WARNING, ERROR, CRITICAL
+        args : optional
+            Other arguments
+
+        """
+
+        self.iteration = -1
+        self.graphs = []
+        self.graph_degree = graph_degree
+
+        self.instantiate(
+            rank,
+            machine_id,
+            mapping,
+            graph,
+            config,
+            iterations,
+            log_dir,
+            log_level,
+            *args
+        )
+
+        self.run()
+
+        logging.info("Peer Sampler exiting")
diff --git a/src/decentralizepy/sharing/FFT.py b/src/decentralizepy/sharing/FFT.py
index 17650c1dd49aa220eeeeb9ff2526555b1797cf14..4babfc285f771fb681daeafe3a4c11dbdddbde3e 100644
--- a/src/decentralizepy/sharing/FFT.py
+++ b/src/decentralizepy/sharing/FFT.py
@@ -1,8 +1,6 @@
 import json
 import logging
 import os
-from pathlib import Path
-from time import time
 
 import numpy as np
 import torch
@@ -53,6 +51,9 @@ class FFT(PartialModel):
         save_accumulated="",
         accumulation=True,
         accumulate_averaging_changes=False,
+        compress=False,
+        compression_package=None,
+        compression_class=None,
     ):
         """
         Constructor
@@ -111,6 +112,9 @@ class FFT(PartialModel):
             save_accumulated,
             change_transformer_fft,
             accumulate_averaging_changes,
+            compress,
+            compression_package,
+            compression_class,
         )
         self.change_based_selection = change_based_selection
 
@@ -163,7 +167,7 @@ class FFT(PartialModel):
                 self.model.accumulated_changes = torch.zeros_like(
                     self.model.accumulated_changes
                 )
-            return m
+            return self.compress_data(m)
 
         with torch.no_grad():
             topk, indices = self.apply_fft()
@@ -199,7 +203,7 @@ class FFT(PartialModel):
             m["indices"] = indices.numpy().astype(np.int32)
             m["send_partial"] = True
 
-        return m
+        return self.compress_data(m)
 
     def deserialized_model(self, m):
         """
@@ -216,6 +220,8 @@ class FFT(PartialModel):
             state_dict of received
 
         """
+        m = self.decompress_data(m)
+
         ret = dict()
         if "send_partial" not in m:
             params = m["params"]
@@ -237,7 +243,7 @@ class FFT(PartialModel):
             ret["send_partial"] = True
         return ret
 
-    def _averaging(self):
+    def _averaging(self, peer_deques):
         """
         Averages the received model with the local model
 
@@ -251,8 +257,12 @@ class FFT(PartialModel):
             pre_share_model = torch.cat(tensors_to_cat, dim=0)
             flat_fft = self.change_transformer(pre_share_model)
 
-            for i, n in enumerate(self.peer_deques):
-                degree, iteration, data = self.peer_deques[n].popleft()
+            for i, n in enumerate(peer_deques):
+                data = peer_deques[n].popleft()
+                degree, iteration = data["degree"], data["iteration"]
+                del data["degree"]
+                del data["iteration"]
+                del data["CHANNEL"]
                 logging.debug(
                     "Averaging model from neighbor {} of iteration {}".format(
                         n, iteration
@@ -268,7 +278,7 @@ class FFT(PartialModel):
                 else:
                     topkf = params
 
-                weight = 1 / (max(len(self.peer_deques), degree) + 1)  # Metro-Hastings
+                weight = 1 / (max(len(peer_deques), degree) + 1)  # Metro-Hastings
                 weight_total += weight
                 if total is None:
                     total = weight * topkf
@@ -289,3 +299,5 @@ class FFT(PartialModel):
                 start_index = end_index
 
         self.model.load_state_dict(std_dict)
+        self._post_step()
+        self.communication_round += 1
diff --git a/src/decentralizepy/sharing/GrowingAlpha.py b/src/decentralizepy/sharing/GrowingAlpha.py
index 7fe7bf5d86992804b290db1bb0ac55cda1f47299..a13a869b8909a819098ce9371625ac0e5d1da2f9 100644
--- a/src/decentralizepy/sharing/GrowingAlpha.py
+++ b/src/decentralizepy/sharing/GrowingAlpha.py
@@ -1,3 +1,4 @@
+# Deprecated
 import logging
 
 from decentralizepy.sharing.PartialModel import PartialModel
@@ -25,6 +26,9 @@ class GrowingAlpha(PartialModel):
         dict_ordered=True,
         save_shared=False,
         metadata_cap=1.0,
+        compress=False,
+        compression_package=None,
+        compression_class=None,
     ):
         """
         Constructor
@@ -74,12 +78,15 @@ class GrowingAlpha(PartialModel):
             dict_ordered,
             save_shared,
             metadata_cap,
+            compress,
+            compression_package,
+            compression_class,
         )
         self.init_alpha = init_alpha
         self.max_alpha = max_alpha
         self.k = k
 
-    def step(self):
+    def get_data_to_send(self):
         """
         Perform a sharing step. Implements D-PSGD with alpha increasing as a linear function.
 
@@ -93,4 +100,4 @@ class GrowingAlpha(PartialModel):
             self.communication_round += 1
             return
 
-        super().step()
+        return super().get_data_to_send()
diff --git a/src/decentralizepy/sharing/LowerBoundTopK.py b/src/decentralizepy/sharing/LowerBoundTopK.py
index 6ac532960c2386e45c090a6b5d168f8c77a0365f..9d227b22ba106bf1d3d08b2ed40492f083fcf9c7 100644
--- a/src/decentralizepy/sharing/LowerBoundTopK.py
+++ b/src/decentralizepy/sharing/LowerBoundTopK.py
@@ -24,6 +24,9 @@ class LowerBoundTopK(PartialModel):
         log_dir,
         lower_bound=0.1,
         metro_hastings=True,
+        compress=False,
+        compression_package=None,
+        compression_class=None,
         **kwargs,
     ):
         """
@@ -81,7 +84,9 @@ class LowerBoundTopK(PartialModel):
             model,
             dataset,
             log_dir,
-            **kwargs,
+            compress,
+            compression_package,
+            compression_class,
+            **kwargs,
         )
         self.lower_bound = lower_bound
         self.metro_hastings = metro_hastings
@@ -154,6 +159,8 @@ class LowerBoundTopK(PartialModel):
         if "send_partial" not in m:
             return super().deserialized_model(m)
 
+        m = self.decompress_data(m)
+
         with torch.no_grad():
             state_dict = self.model.state_dict()
 
@@ -169,7 +176,7 @@ class LowerBoundTopK(PartialModel):
 
             return T, index_tensor
 
-    def _averaging(self):
+    def _averaging(self, peer_deques):
         """
         Averages the received model with the local model
 
@@ -187,8 +194,12 @@ class LowerBoundTopK(PartialModel):
                 weight_total = 0
                 weight_vector = torch.ones_like(self.init_model)
                 datas = []
-                for i, n in enumerate(self.peer_deques):
-                    degree, iteration, data = self.peer_deques[n].popleft()
+                for i, n in enumerate(peer_deques):
+                    data = peer_deques[n].popleft()
+                    degree, iteration = data["degree"], data["iteration"]
+                    del data["degree"]
+                    del data["iteration"]
+                    del data["CHANNEL"]
                     logging.debug(
                         "Averaging model from neighbor {} of iteration {}".format(
                             n, iteration
@@ -215,3 +226,5 @@ class LowerBoundTopK(PartialModel):
 
             logging.info("new averaging")
             self.model.load_state_dict(total)
+            self._post_step()
+            self.communication_round += 1
diff --git a/src/decentralizepy/sharing/ManualAdapt.py b/src/decentralizepy/sharing/ManualAdapt.py
index dcb94cf1dcc626e1b01b325539e41fd0891da030..9a54eb7ea04e9cebd4109a71be77229e150b2bba 100644
--- a/src/decentralizepy/sharing/ManualAdapt.py
+++ b/src/decentralizepy/sharing/ManualAdapt.py
@@ -1,3 +1,4 @@
+# Deprecated
 import logging
 
 from decentralizepy.sharing.PartialModel import PartialModel
@@ -24,6 +25,9 @@ class ManualAdapt(PartialModel):
         dict_ordered=True,
         save_shared=False,
         metadata_cap=1.0,
+        compress=False,
+        compression_package=None,
+        compression_class=None,
     ):
         """
         Constructor
@@ -81,11 +85,14 @@ class ManualAdapt(PartialModel):
             dict_ordered,
             save_shared,
             metadata_cap,
+            compress,
+            compression_package,
+            compression_class,
         )
         self.change_alpha = change_alpha[1:]
         self.change_rounds = change_rounds
 
-    def step(self):
+    def get_data_to_send(self):
         """
         Perform a sharing step. Implements D-PSGD with alpha manually given.
 
@@ -101,6 +108,6 @@ class ManualAdapt(PartialModel):
         if self.alpha == 0.0:
             logging.info("Not sending/receiving data (alpha=0.0)")
             self.communication_round += 1
-            return
+            return dict()
 
-        super().step()
+        return super().get_data_to_send()
diff --git a/src/decentralizepy/sharing/PartialModel.py b/src/decentralizepy/sharing/PartialModel.py
index 3111e82ce9af9ad8c6aba27311219c823df6e135..b302f5df01d006a5386eb08fe488580ef81e3916 100644
--- a/src/decentralizepy/sharing/PartialModel.py
+++ b/src/decentralizepy/sharing/PartialModel.py
@@ -34,6 +34,9 @@ class PartialModel(Sharing):
         save_accumulated="",
         change_transformer=identity,
         accumulate_averaging_changes=False,
+        compress=False,
+        compression_package=None,
+        compression_class=None,
     ):
         """
         Constructor
@@ -76,7 +79,17 @@ class PartialModel(Sharing):
 
         """
         super().__init__(
-            rank, machine_id, communication, mapping, graph, model, dataset, log_dir
+            rank,
+            machine_id,
+            communication,
+            mapping,
+            graph,
+            model,
+            dataset,
+            log_dir,
+            compress,
+            compression_package,
+            compression_class,
         )
         self.alpha = alpha
         self.dict_ordered = dict_ordered
@@ -129,6 +142,23 @@ class PartialModel(Sharing):
             self.change_transformer(self.init_model).shape[0], dtype=torch.int32
         )
 
+    def compress_data(self, data):
+        result = dict(data)
+        if self.compress:
+            if "indices" in result:
+                result["indices"] = self.compressor.compress(result["indices"])
+            if "params" in result:
+                result["params"] = self.compressor.compress_float(result["params"])
+        return result
+
+    def decompress_data(self, data):
+        if self.compress:
+            if "indices" in data:
+                data["indices"] = self.compressor.decompress(data["indices"])
+            if "params" in data:
+                data["params"] = self.compressor.decompress_float(data["params"])
+        return data
+
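+    # "indices" go through the compressor's integer codec (compress/decompress)
+    # while "params" use the float codec (compress_float/decompress_float),
+    # which may be lossy depending on the compression_class supplied.
+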
     def extract_top_gradients(self):
         """
         Extract the indices and values of the topK gradients.
@@ -220,7 +250,7 @@ class PartialModel(Sharing):
 
             logging.info("Converted dictionary to pickle")
 
-            return m
+            return self.compress_data(m)
 
     def deserialized_model(self, m):
         """
@@ -241,6 +271,8 @@ class PartialModel(Sharing):
             return super().deserialized_model(m)
 
         with torch.no_grad():
+            m = self.decompress_data(m)
+
             state_dict = self.model.state_dict()
 
             if not self.dict_ordered:
diff --git a/src/decentralizepy/sharing/RandomAlpha.py b/src/decentralizepy/sharing/RandomAlpha.py
index 1956c29a2b8c81c191bcb46ab2bfcd4e61913262..3bac634263940e986d0fb3160db277a0b94747be 100644
--- a/src/decentralizepy/sharing/RandomAlpha.py
+++ b/src/decentralizepy/sharing/RandomAlpha.py
@@ -28,6 +28,9 @@ class RandomAlpha(PartialModel):
         save_accumulated="",
         change_transformer=identity,
         accumulate_averaging_changes=False,
+        compress=False,
+        compression_package=None,
+        compression_class=None,
     ):
         """
         Constructor
@@ -75,14 +78,17 @@ class RandomAlpha(PartialModel):
             save_accumulated,
             change_transformer,
             accumulate_averaging_changes,
+            compress,
+            compression_package,
+            compression_class,
         )
         self.alpha_list = eval(alpha_list)
         random.seed(self.mapping.get_uid(self.rank, self.machine_id))
 
-    def step(self):
+    def get_data_to_send(self):
         """
         Perform a sharing step. Implements D-PSGD with alpha randomly chosen.
 
         """
         self.alpha = random.choice(self.alpha_list)
-        super().step()
+        return super().get_data_to_send()
diff --git a/src/decentralizepy/sharing/RandomAlphaIncremental.py b/src/decentralizepy/sharing/RandomAlphaIncremental.py
index c3b7c0d6fc5339fa74493df15d71e43285ec73f9..96ead3d947336e311c018e9f06d0874cf5418a76 100644
--- a/src/decentralizepy/sharing/RandomAlphaIncremental.py
+++ b/src/decentralizepy/sharing/RandomAlphaIncremental.py
@@ -1,3 +1,4 @@
+# Deprecated
 import random
 
 from decentralizepy.sharing.PartialModel import PartialModel
@@ -24,6 +25,9 @@ class RandomAlphaIncremental(PartialModel):
         metadata_cap=1.0,
         range_start=0.1,
         range_end=0.2,
+        compress=False,
+        compression_package=None,
+        compression_class=None,
     ):
         """
         Constructor
@@ -67,16 +71,19 @@ class RandomAlphaIncremental(PartialModel):
             dict_ordered,
             save_shared,
             metadata_cap,
+            compress,
+            compression_package,
+            compression_class,
         )
         random.seed(self.mapping.get_uid(self.rank, self.machine_id))
         self.range_start = range_start
         self.range_end = range_end
 
-    def step(self):
+    def get_data_to_send(self):
         """
         Perform a sharing step. Implements D-PSGD with alpha randomly chosen from an increasing range.
 
         """
         self.alpha = round(random.uniform(self.range_start, self.range_end), 2)
         self.range_end = min(1.0, self.range_end + round(random.uniform(0.0, 0.1), 2))
-        super().step()
+        return super().get_data_to_send()
diff --git a/src/decentralizepy/sharing/RandomAlphaWavelet.py b/src/decentralizepy/sharing/RandomAlphaWavelet.py
index 44ea3364bc913042931583d22bbcba78f0fef5be..de2a5e61fff3d0190678c883102fd897bb6f9b21 100644
--- a/src/decentralizepy/sharing/RandomAlphaWavelet.py
+++ b/src/decentralizepy/sharing/RandomAlphaWavelet.py
@@ -29,6 +29,9 @@ class RandomAlpha(Wavelet):
         save_accumulated="",
         accumulation=False,
         accumulate_averaging_changes=False,
+        compress=False,
+        compression_package=None,
+        compression_class=None,
     ):
         """
         Constructor
@@ -78,14 +81,17 @@ class RandomAlpha(Wavelet):
             save_accumulated,
             accumulation,
             accumulate_averaging_changes,
+            compress,
+            compression_package,
+            compression_class,
         )
         self.alpha_list = eval(alpha_list)
         random.seed(self.mapping.get_uid(self.rank, self.machine_id))
 
-    def step(self):
+    def get_data_to_send(self):
         """
         Perform a sharing step. Implements D-PSGD with alpha randomly chosen.
 
         """
         self.alpha = random.choice(self.alpha_list)
-        super().step()
+        return super().get_data_to_send()
diff --git a/src/decentralizepy/sharing/RoundRobinPartial.py b/src/decentralizepy/sharing/RoundRobinPartial.py
index c5288a563740df27dab25028e4958201a3fd8875..fbe0179c82f45243febacd516a63c460b78d2734 100644
--- a/src/decentralizepy/sharing/RoundRobinPartial.py
+++ b/src/decentralizepy/sharing/RoundRobinPartial.py
@@ -25,6 +25,9 @@ class RoundRobinPartial(Sharing):
         dataset,
         log_dir,
         alpha=1.0,
+        compress=False,
+        compression_package=None,
+        compression_class=None,
     ):
         """
         Constructor
@@ -52,7 +55,17 @@ class RoundRobinPartial(Sharing):
 
         """
         super().__init__(
-            rank, machine_id, communication, mapping, graph, model, dataset, log_dir
+            rank,
+            machine_id,
+            communication,
+            mapping,
+            graph,
+            model,
+            dataset,
+            log_dir,
+            compress,
+            compression_package,
+            compression_class,
         )
         self.alpha = alpha
         random.seed(self.mapping.get_uid(rank, machine_id))
@@ -104,7 +117,7 @@ class RoundRobinPartial(Sharing):
 
             logging.info("Converted dictionary to json")
             self.total_data += len(self.communication.encrypt(m["params"]))
-            return m
+            return self.compress_data(m)
 
     def deserialized_model(self, m):
         """
@@ -121,9 +134,9 @@ class RoundRobinPartial(Sharing):
             state_dict of received
 
         """
+        m = self.decompress_data(m)
         with torch.no_grad():
             state_dict = self.model.state_dict()
-
             shapes = []
             lens = []
             tensors_to_cat = []
diff --git a/src/decentralizepy/sharing/Sharing.py b/src/decentralizepy/sharing/Sharing.py
index 0ad3927a6d7fb80acfa01e120fde1ef8db4a8d77..1becb58cf339f11d03833f846c354ba780de677a 100644
--- a/src/decentralizepy/sharing/Sharing.py
+++ b/src/decentralizepy/sharing/Sharing.py
@@ -1,5 +1,5 @@
+import importlib
 import logging
-from collections import deque
 
 import torch
 
@@ -11,7 +11,18 @@ class Sharing:
     """
 
     def __init__(
-        self, rank, machine_id, communication, mapping, graph, model, dataset, log_dir
+        self,
+        rank,
+        machine_id,
+        communication,
+        mapping,
+        graph,
+        model,
+        dataset,
+        log_dir,
+        compress=False,
+        compression_package=None,
+        compression_class=None,
     ):
         """
         Constructor
@@ -47,11 +58,6 @@ class Sharing:
         self.communication_round = 0
         self.log_dir = log_dir
 
-        self.peer_deques = dict()
-        self.my_neighbors = self.graph.neighbors(self.uid)
-        for n in self.my_neighbors:
-            self.peer_deques[n] = deque()
-
         self.shapes = []
         self.lens = []
         with torch.no_grad():
@@ -60,38 +66,28 @@ class Sharing:
                 t = v.flatten().numpy()
                 self.lens.append(t.shape[0])
 
-    def received_from_all(self):
-        """
-        Check if all neighbors have sent the current iteration
-
-        Returns
-        -------
-        bool
-            True if required data has been received, False otherwise
-
-        """
-        for _, i in self.peer_deques.items():
-            if len(i) == 0:
-                return False
-        return True
-
-    def get_neighbors(self, neighbors):
-        """
-        Choose which neighbors to share with
-
-        Parameters
-        ----------
-        neighbors : list(int)
-            List of all neighbors
-
-        Returns
-        -------
-        list(int)
-            Neighbors to share with
-
-        """
-        # modify neighbors here
-        return neighbors
+        self.compress = compress
+
+        if compression_package and compression_class:
+            compressor_module = importlib.import_module(compression_package)
+            compressor_class = getattr(compressor_module, compression_class)
+            self.compressor = compressor_class()
+            logging.info(f"Using the {compressor_class} to compress the data")
+        else:
+            assert not self.compress
+
+    def compress_data(self, data):
+        result = dict(data)
+        if self.compress:
+            if "params" in result:
+                result["params"] = self.compressor.compress_float(result["params"])
+        return result
+
+    def decompress_data(self, data):
+        if self.compress:
+            if "params" in data:
+                data["params"] = self.compressor.decompress_float(data["params"])
+        return data
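+    # Only compress_float/decompress_float are assumed of the compression_class
+    # here; subclasses such as PartialModel additionally expect compress and
+    # decompress for index tensors. Any object exposing that interface can be
+    # supplied via compression_package/compression_class.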
 
     def serialized_model(self):
         """
@@ -111,7 +107,7 @@ class Sharing:
         flat = torch.cat(to_cat)
         data = dict()
         data["params"] = flat.numpy()
-        return data
+        return self.compress_data(data)
 
     def deserialized_model(self, m):
         """
@@ -129,11 +125,14 @@ class Sharing:
 
         """
         state_dict = dict()
+        m = self.decompress_data(m)
         T = m["params"]
         start_index = 0
         for i, key in enumerate(self.model.state_dict()):
             end_index = start_index + self.lens[i]
-            state_dict[key] = torch.from_numpy(T[start_index:end_index].reshape(self.shapes[i]))
+            state_dict[key] = torch.from_numpy(
+                T[start_index:end_index].reshape(self.shapes[i])
+            )
             start_index = end_index
 
         return state_dict
@@ -152,7 +151,7 @@ class Sharing:
         """
         pass
 
-    def _averaging(self):
+    def _averaging(self, peer_deques):
         """
         Averages the received model with the local model
 
@@ -160,15 +159,20 @@ class Sharing:
         with torch.no_grad():
             total = dict()
             weight_total = 0
-            for i, n in enumerate(self.peer_deques):
-                degree, iteration, data = self.peer_deques[n].popleft()
+            for i, n in enumerate(peer_deques):
+                data = peer_deques[n].popleft()
+                degree, iteration = data["degree"], data["iteration"]
+                del data["degree"]
+                del data["iteration"]
+                del data["CHANNEL"]
                 logging.debug(
                     "Averaging model from neighbor {} of iteration {}".format(
                         n, iteration
                     )
                 )
                 data = self.deserialized_model(data)
-                weight = 1 / (max(len(self.peer_deques), degree) + 1)  # Metro-Hastings
+                # Metro-Hastings
+                weight = 1 / (max(len(peer_deques), degree) + 1)
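+                # e.g. with 3 entries in peer_deques and a sender of degree 4,
+                # weight = 1 / (max(3, 4) + 1) = 1 / 5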
                 weight_total += weight
                 for key, value in data.items():
                     if key in total:
@@ -180,41 +184,45 @@ class Sharing:
                 total[key] += (1 - weight_total) * value  # Metro-Hastings
 
         self.model.load_state_dict(total)
+        self._post_step()
+        self.communication_round += 1
 
-    def step(self):
-        """
-        Perform a sharing step. Implements D-PSGD.
-
-        """
+    def get_data_to_send(self):
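+        """
+        Serialize the local model and attach the degree and iteration metadata
+        that receivers strip off again before averaging.
+
+        Returns
+        -------
+        dict
+            Serialized model with "degree" and "iteration" entries
+
+        """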
         self._pre_step()
         data = self.serialized_model()
         my_uid = self.mapping.get_uid(self.rank, self.machine_id)
         all_neighbors = self.graph.neighbors(my_uid)
-        iter_neighbors = self.get_neighbors(all_neighbors)
         data["degree"] = len(all_neighbors)
         data["iteration"] = self.communication_round
-        encrypted = self.communication.encrypt(data)
-        for neighbor in iter_neighbors:
-            self.communication.send(neighbor, encrypted, encrypt=False)
-
-        logging.info("Waiting for messages from neighbors")
-        while not self.received_from_all():
-            sender, data = self.communication.receive()
-            logging.debug("Received model from {}".format(sender))
-            degree = data["degree"]
-            iteration = data["iteration"]
-            del data["degree"]
-            del data["iteration"]
-            self.peer_deques[sender].append((degree, iteration, data))
-            logging.info(
-                "Deserialized received model from {} of iteration {}".format(
-                    sender, iteration
-                )
-            )
+        return data
 
-        logging.info("Starting model averaging after receiving from all neighbors")
-        self._averaging()
-        logging.info("Model averaging complete")
+    def _averaging_server(self, peer_deques):
+        """
+        Averages the received models of all working nodes
 
-        self.communication_round += 1
+        """
+        with torch.no_grad():
+            total = dict()
+            for i, n in enumerate(peer_deques):
+                data = peer_deques[n].popleft()
+                degree, iteration = data["degree"], data["iteration"]
+                del data["degree"]
+                del data["iteration"]
+                del data["CHANNEL"]
+                logging.debug(
+                    "Averaging model from neighbor {} of iteration {}".format(
+                        n, iteration
+                    )
+                )
+                data = self.deserialized_model(data)
+                weight = 1 / len(peer_deques)
+                for key, value in data.items():
+                    if key in total:
+                        total[key] += weight * value
+                    else:
+                        total[key] = weight * value
+
+        self.model.load_state_dict(total)
         self._post_step()
+        self.communication_round += 1
+        return total
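+
+    # _averaging applies Metropolis-Hastings weights over a node's neighbors,
+    # whereas _averaging_server (presumably driven by a central aggregator
+    # role) weights every received model uniformly by 1 / len(peer_deques).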
diff --git a/src/decentralizepy/sharing/SharingCentrality.py b/src/decentralizepy/sharing/SharingCentrality.py
index f933a0e6e002b7064eccbaa88f92280bdff3f488..8b10f3cc231af4fea49ca2b8a3dbfc0704783ccf 100644
--- a/src/decentralizepy/sharing/SharingCentrality.py
+++ b/src/decentralizepy/sharing/SharingCentrality.py
@@ -1,3 +1,4 @@
+# Deprecated
 import logging
 from collections import deque
 
@@ -188,6 +189,7 @@ class Sharing:
             iteration = data["iteration"]
             del data["degree"]
             del data["iteration"]
+            del data["CHANNEL"]
             self.peer_deques[sender].append((degree, iteration, data))
             logging.info(
                 "Deserialized received model from {} of iteration {}".format(
diff --git a/src/decentralizepy/sharing/SubSampling.py b/src/decentralizepy/sharing/SubSampling.py
index b51cb07ce0345ee339be0fe2e338ffd9ab61b63e..7201d338e2cca8eeab5076e03eec9fb39e425780 100644
--- a/src/decentralizepy/sharing/SubSampling.py
+++ b/src/decentralizepy/sharing/SubSampling.py
@@ -31,6 +31,9 @@ class SubSampling(Sharing):
         metadata_cap=1.0,
         pickle=True,
         layerwise=False,
+        compress=False,
+        compression_package=None,
+        compression_class=None,
     ):
         """
         Constructor
@@ -66,7 +69,17 @@ class SubSampling(Sharing):
 
         """
         super().__init__(
-            rank, machine_id, communication, mapping, graph, model, dataset, log_dir
+            rank,
+            machine_id,
+            communication,
+            mapping,
+            graph,
+            model,
+            dataset,
+            log_dir,
+            compress,
+            compression_package,
+            compression_class,
         )
         self.alpha = alpha
         self.dict_ordered = dict_ordered
@@ -215,7 +228,7 @@ class SubSampling(Sharing):
             m["alpha"] = alpha
             m["params"] = subsample.numpy()
 
-            return m
+            return self.compress_data(m)
 
     def deserialized_model(self, m):
         """
@@ -235,6 +248,8 @@ class SubSampling(Sharing):
         if self.alpha > self.metadata_cap:  # Share fully
             return super().deserialized_model(m)
 
+        m = self.decompress_data(m)
+
         with torch.no_grad():
             state_dict = self.model.state_dict()
 
diff --git a/src/decentralizepy/sharing/Synchronous.py b/src/decentralizepy/sharing/Synchronous.py
index 2c2d5e76cfa328260b14fcb9cbf2614e7101c751..7fc1c353b577f7d39a8321968c8f930be24d30c6 100644
--- a/src/decentralizepy/sharing/Synchronous.py
+++ b/src/decentralizepy/sharing/Synchronous.py
@@ -1,3 +1,4 @@
+# Deprecated
 import logging
 from collections import deque
 
diff --git a/src/decentralizepy/sharing/TopKNormalized.py b/src/decentralizepy/sharing/TopKNormalized.py
index 15a3caff239e46e48ca5408658354ec3d3d127a5..b281294079e56137d16070acc17904ebc4d24f22 100644
--- a/src/decentralizepy/sharing/TopKNormalized.py
+++ b/src/decentralizepy/sharing/TopKNormalized.py
@@ -31,6 +31,9 @@ class TopKNormalized(PartialModel):
         change_transformer=identity,
         accumulate_averaging_changes=False,
         epsilon=0.01,
+        compress=False,
+        compression_package=None,
+        compression_class=None,
     ):
         """
         Constructor
@@ -91,6 +94,9 @@ class TopKNormalized(PartialModel):
             save_accumulated,
             change_transformer,
             accumulate_averaging_changes,
+            compress,
+            compression_package,
+            compression_class,
         )
         self.epsilon = epsilon
 
diff --git a/src/decentralizepy/sharing/TopKParams.py b/src/decentralizepy/sharing/TopKParams.py
index f1881798e91ff7cacb114f4071acc97ba81530e7..c2b0e3fcd24856aa63bc9c09241f8e728a997dfe 100644
--- a/src/decentralizepy/sharing/TopKParams.py
+++ b/src/decentralizepy/sharing/TopKParams.py
@@ -29,6 +29,9 @@ class TopKParams(Sharing):
         dict_ordered=True,
         save_shared=False,
         metadata_cap=1.0,
+        compress=False,
+        compression_package=None,
+        compression_class=None,
     ):
         """
         Constructor
@@ -62,7 +65,17 @@ class TopKParams(Sharing):
 
         """
         super().__init__(
-            rank, machine_id, communication, mapping, graph, model, dataset, log_dir
+            rank,
+            machine_id,
+            communication,
+            mapping,
+            graph,
+            model,
+            dataset,
+            log_dir,
+            compress,
+            compression_package,
+            compression_class,
         )
         self.alpha = alpha
         self.dict_ordered = dict_ordered
@@ -171,7 +184,7 @@ class TopKParams(Sharing):
 
             logging.info("Converted dictionary to json")
 
-            return m
+            return self.compress_data(m)
 
     def deserialized_model(self, m):
         """
@@ -191,6 +204,8 @@ class TopKParams(Sharing):
         if self.alpha > self.metadata_cap:  # Share fully
             return super().deserialized_model(m)
 
+        m = self.decompress_data(m)
+
         with torch.no_grad():
             state_dict = self.model.state_dict()
 
diff --git a/src/decentralizepy/sharing/TopKPlusRandom.py b/src/decentralizepy/sharing/TopKPlusRandom.py
index 728d5bfa48d71037a6d165d6396343f46dfa4e3e..8962933b4d94505ce43d00e53b19857129cf1f7f 100644
--- a/src/decentralizepy/sharing/TopKPlusRandom.py
+++ b/src/decentralizepy/sharing/TopKPlusRandom.py
@@ -26,6 +26,9 @@ class TopKPlusRandom(PartialModel):
         dict_ordered=True,
         save_shared=False,
         metadata_cap=1.0,
+        compress=False,
+        compression_package=None,
+        compression_class=None,
     ):
         """
         Constructor
@@ -71,6 +74,9 @@ class TopKPlusRandom(PartialModel):
             dict_ordered,
             save_shared,
             metadata_cap,
+            compress,
+            compression_package,
+            compression_class,
         )
 
     def extract_top_gradients(self):
diff --git a/src/decentralizepy/sharing/Wavelet.py b/src/decentralizepy/sharing/Wavelet.py
index 91c97d0a5d71f20c9eac79405b066f97f087f0bc..7fff1641650025dd8c0480957056e130bc7cab99 100644
--- a/src/decentralizepy/sharing/Wavelet.py
+++ b/src/decentralizepy/sharing/Wavelet.py
@@ -1,8 +1,6 @@
 import json
 import logging
 import os
-from pathlib import Path
-from time import time
 
 import numpy as np
 import pywt
@@ -61,6 +59,9 @@ class Wavelet(PartialModel):
         save_accumulated="",
         accumulation=False,
         accumulate_averaging_changes=False,
+        compress=False,
+        compression_package=None,
+        compression_class=None,
     ):
         """
         Constructor
@@ -125,6 +126,9 @@ class Wavelet(PartialModel):
             save_accumulated,
             lambda x: change_transformer_wavelet(x, wavelet, level),
             accumulate_averaging_changes,
+            compress,
+            compression_package,
+            compression_class,
         )
 
         self.change_based_selection = change_based_selection
@@ -185,7 +189,7 @@ class Wavelet(PartialModel):
                 self.model.accumulated_changes = torch.zeros_like(
                     self.model.accumulated_changes
                 )
-            return m
+            return self.compress_data(m)
 
         with torch.no_grad():
             topk, indices = self.apply_wavelet()
@@ -199,7 +203,8 @@ class Wavelet(PartialModel):
                     shapes[k] = list(v.shape)
                 shared_params["shapes"] = shapes
 
-                shared_params[self.communication_round] = indices.tolist()  # is slow
+                # is slow
+                shared_params[self.communication_round] = indices.tolist()
 
                 shared_params["alpha"] = self.alpha
 
@@ -223,7 +228,7 @@ class Wavelet(PartialModel):
 
             m["send_partial"] = True
 
-            return m
+            return self.compress_data(m)
 
     def deserialized_model(self, m):
         """
@@ -240,6 +245,7 @@ class Wavelet(PartialModel):
             state_dict of received
 
         """
+        m = self.decompress_data(m)
         ret = dict()
         if "send_partial" not in m:
             params = m["params"]
@@ -260,7 +266,7 @@ class Wavelet(PartialModel):
             ret["send_partial"] = True
         return ret
 
-    def _averaging(self):
+    def _averaging(self, peer_deques):
         """
         Averages the received model with the local model
 
@@ -269,8 +275,12 @@ class Wavelet(PartialModel):
             total = None
             weight_total = 0
             wt_params = self.pre_share_model_transformed
-            for i, n in enumerate(self.peer_deques):
-                degree, iteration, data = self.peer_deques[n].popleft()
+            for i, n in enumerate(peer_deques):
+                data = peer_deques[n].popleft()
+                degree, iteration = data["degree"], data["iteration"]
+                del data["degree"]
+                del data["iteration"]
+                del data["CHANNEL"]
                 logging.debug(
                     "Averaging model from neighbor {} of iteration {}".format(
                         n, iteration
@@ -287,7 +297,8 @@ class Wavelet(PartialModel):
                 else:
                     topkwf = params.reshape(self.wt_shape)
 
-                weight = 1 / (max(len(self.peer_deques), degree) + 1)  # Metro-Hastings
+                # Metro-Hastings
+                weight = 1 / (max(len(peer_deques), degree) + 1)
                 weight_total += weight
                 if total is None:
                     total = weight * topkwf
@@ -314,3 +325,61 @@ class Wavelet(PartialModel):
                 start_index = end_index
 
         self.model.load_state_dict(std_dict)
+        self._post_step()
+        self.communication_round += 1
+
+    def _averaging_server(self, peer_deques):
+        """
+        Averages the received models of all working nodes
+
+        """
+        with torch.no_grad():
+            total = None
+            wt_params = self.pre_share_model_transformed
+            for i, n in enumerate(peer_deques):
+                data = peer_deques[n].popleft()
+                degree, iteration = data["degree"], data["iteration"]
+                del data["degree"]
+                del data["iteration"]
+                del data["CHANNEL"]
+                logging.debug(
+                    "Averaging model from neighbor {} of iteration {}".format(
+                        n, iteration
+                    )
+                )
+                data = self.deserialized_model(data)
+                params = data["params"]
+                if "indices" in data:
+                    indices = data["indices"]
+                    # use local data to complement
+                    topkwf = wt_params.clone().detach()
+                    topkwf[indices] = params
+                    topkwf = topkwf.reshape(self.wt_shape)
+                else:
+                    topkwf = params.reshape(self.wt_shape)
+
+                weight = 1 / len(peer_deques)
+                if total is None:
+                    total = weight * topkwf
+                else:
+                    total += weight * topkwf
+
+            avg_wf_params = pywt.array_to_coeffs(
+                total.numpy(), self.coeff_slices, output_format="wavedec"
+            )
+            reverse_total = torch.from_numpy(
+                pywt.waverec(avg_wf_params, wavelet=self.wavelet)
+            )
+
+            start_index = 0
+            std_dict = {}
+            for i, key in enumerate(self.model.state_dict()):
+                end_index = start_index + self.lens[i]
+                std_dict[key] = reverse_total[start_index:end_index].reshape(
+                    self.shapes[i]
+                )
+                start_index = end_index
+
+        self.model.load_state_dict(std_dict)
+        self._post_step()
+        self.communication_round += 1
diff --git a/src/decentralizepy/train_test_evaluation.py b/src/decentralizepy/train_test_evaluation.py
deleted file mode 100644
index 319d308f49dc95c44c4a46478348d45077bd62f6..0000000000000000000000000000000000000000
--- a/src/decentralizepy/train_test_evaluation.py
+++ /dev/null
@@ -1,95 +0,0 @@
-import logging
-import os
-import pickle
-from pathlib import Path
-
-import numpy as np
-import torch
-
-from decentralizepy.graphs import Graph
-
-
-class TrainTestHelper:
-    def __init__(
-        self,
-        dataset,
-        model,
-        loss,
-        dir,
-        n_procs,
-        trainer,
-        comm,
-        graph: Graph,
-        threads_per_proc,
-        eval_train=False,
-    ):
-        self.dataset = dataset
-        self.model = model
-        self.loss = loss
-        self.dir = Path(dir)
-        self.n_procs = n_procs
-        self.trainer = trainer
-        self.comm = comm
-        self.star = graph
-        self.threads_per_proc = threads_per_proc
-        self.eval_train = eval_train
-
-    def train_test_evaluation(self, iteration):
-        with torch.no_grad():
-            self.model.eval()
-            total_threads = os.cpu_count()
-            torch.set_num_threads(total_threads)
-
-            neighbors = self.star.neighbors(0)
-            state_dict_copy = {}
-            shapes = []
-            lens = []
-            to_cat = []
-            for key, val in self.model.state_dict().items():
-                shapes.append(val.shape)
-                clone_val = val.clone().detach()
-                state_dict_copy[key] = clone_val
-                flat = clone_val.flatten()
-                to_cat.append(flat)
-                lens.append(flat.shape[0])
-
-            my_weight = torch.cat(to_cat)
-            weights = [my_weight]
-            # TODO: add weight of node 0
-            for i in neighbors:
-                sender, data = self.comm.receive()
-                logging.info(f"Received weight from {sender}")
-                weights.append(data)
-
-            # averaging
-            average_weight = np.mean([w.numpy() for w in weights], axis=0)
-
-            start_index = 0
-            average_weight_dict = {}
-            for i, key in enumerate(state_dict_copy):
-                end_index = start_index + lens[i]
-                average_weight_dict[key] = torch.from_numpy(
-                    average_weight[start_index:end_index].reshape(shapes[i])
-                )
-                start_index = end_index
-            self.model.load_state_dict(average_weight_dict)
-            if self.eval_train:
-                logging.info("Evaluating on train set.")
-                trl = self.trainer.eval_loss(self.dataset)
-            else:
-                trl = None
-            logging.info("Evaluating on test set.")
-            ta, tl = self.dataset.test(self.model, self.loss)
-            # reload old weight
-            self.model.load_state_dict(state_dict_copy)
-
-            if trl is not None:
-                print(iteration, ta, tl, trl)
-            else:
-                print(iteration, ta, tl)
-
-            torch.set_num_threads(self.threads_per_proc)
-            for neighbor in neighbors:
-                self.comm.send(neighbor, "finished")
-            self.model.train()
-        return ta, tl, trl
diff --git a/src/decentralizepy/utils.py b/src/decentralizepy/utils.py
index 3c8c8581a4d14144d00280f7c0369030b7688026..0d5bda39874ad348d63bc592c702b17a5e5f2290 100644
--- a/src/decentralizepy/utils.py
+++ b/src/decentralizepy/utils.py
@@ -83,8 +83,9 @@ def get_args():
     parser.add_argument("-ta", "--test_after", type=int, default=5)
     parser.add_argument("-tea", "--train_evaluate_after", type=int, default=1)
     parser.add_argument("-ro", "--reset_optimizer", type=int, default=1)
-    parser.add_argument("-ctr", "--centralized_train_eval", type=int, default=0)
-    parser.add_argument("-cte", "--centralized_test_eval", type=int, default=1)
+    parser.add_argument("-sm", "--server_machine", type=int, default=0)
+    parser.add_argument("-sr", "--server_rank", type=int, default=-1)
+    parser.add_argument("-wr", "--working_rate", type=float, default=1.0)
 
     args = parser.parse_args()
     return args
@@ -116,8 +117,7 @@ def write_args(args, path):
         "test_after": args.test_after,
         "train_evaluate_after": args.train_evaluate_after,
         "reset_optimizer": args.reset_optimizer,
-        "centralized_train_eval": args.centralized_train_eval,
-        "centralized_test_eval": args.centralized_test_eval,
+        "working_rate": args.working_rate,
     }
     with open(os.path.join(path, "args.json"), "w") as of:
         json.dump(data, of)
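
The two centralized-evaluation flags are replaced by three new ones: --server_machine and --server_rank identify the process that acts as the federated server, and --working_rate gives the fraction of clients expected to take part in each round. A minimal sketch of reading them through get_args() (the function defined in this file); how the values are consumed downstream is an assumption for illustration, not behaviour shown in this diff:

    # Hypothetical launcher fragment using decentralizepy/utils.py (see hunk above).
    from decentralizepy import utils

    args = utils.get_args()

    # server_rank defaults to -1, here read as "the rank reserved for the server",
    # mirroring how the PeerSampler process is launched with rank -1 (assumption).
    print("server process:", args.server_machine, args.server_rank)
    print("fraction of clients active per round:", args.working_rate)
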
diff --git a/tutorial/config_celeba_sharing.ini b/tutorial/config_celeba_sharing.ini
new file mode 100644
index 0000000000000000000000000000000000000000..020d006bef446c5a642f4d29972664bc17a141c7
--- /dev/null
+++ b/tutorial/config_celeba_sharing.ini
@@ -0,0 +1,35 @@
+[DATASET]
+dataset_package = decentralizepy.datasets.CIFAR10
+dataset_class = CIFAR10
+model_class = LeNet
+train_dir = /mnt/nfs/shared/CIFAR
+test_dir = /mnt/nfs/shared/CIFAR
+; python list of fractions below
+sizes = 
+random_seed = 90
+partition_niid = True
+shards = 4
+
+[OPTIMIZER_PARAMS]
+optimizer_package = torch.optim
+optimizer_class = SGD
+lr = 0.01
+
+[TRAIN_PARAMS]
+training_package = decentralizepy.training.Training
+training_class = Training
+rounds = 3
+full_epochs = False
+batch_size = 8
+shuffle = True
+loss_package = torch.nn
+loss_class = CrossEntropyLoss
+
+[COMMUNICATION]
+comm_package = decentralizepy.communication.TCP
+comm_class = TCP
+addresses_filepath = /mnt/nfs/risharma/Gitlab/tutorial/ip.json
+
+[SHARING]
+sharing_package = decentralizepy.sharing.Sharing
+sharing_class = Sharing
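
This tutorial config wires together the dataset, optimizer, training loop, communication backend, and sharing strategy; despite its name it points at the CIFAR-10 dataset classes and at NFS paths that need to be adapted to the local setup. The project has its own config loading, so the following is only an illustrative sketch with the standard configparser, showing how the sections and keys above fit together:

    # Illustrative only: parse the tutorial .ini and resolve the dataset class.
    import importlib
    from configparser import ConfigParser

    config = ConfigParser()
    config.read("tutorial/config_celeba_sharing.ini")

    dataset_cfg = config["DATASET"]
    dataset_module = importlib.import_module(dataset_cfg["dataset_package"])
    dataset_cls = getattr(dataset_module, dataset_cfg["dataset_class"])

    lr = config.getfloat("OPTIMIZER_PARAMS", "lr")            # 0.01
    batch_size = config.getint("TRAIN_PARAMS", "batch_size")  # 8
    print(dataset_cls, lr, batch_size)
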
diff --git a/tutorial/ip.json b/tutorial/ip.json
new file mode 100644
index 0000000000000000000000000000000000000000..15d6591df53574707ac03627fa19c9ecd749b1e3
--- /dev/null
+++ b/tutorial/ip.json
@@ -0,0 +1,3 @@
+{
+    "0": "127.0.0.1"
+}
\ No newline at end of file
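
ip.json maps machine IDs to the addresses their processes listen on; with a single machine everything stays on localhost, and the run scripts below recover the local machine ID by matching the host's interface address against this file. Purely as an illustration (hypothetical addresses), a two-machine mapping could be written like this:

    # Hypothetical two-machine ip.json; the keys must match the machine ids
    # passed via -mid and the addresses must be mutually reachable.
    import json

    addresses = {"0": "10.90.41.1", "1": "10.90.41.2"}
    with open("tutorial/ip.json", "w") as f:
        json.dump(addresses, f, indent=4)
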
diff --git a/tutorial/regular_16.txt b/tutorial/regular_16.txt
new file mode 100644
index 0000000000000000000000000000000000000000..32800c8d53f7102ffd494d4ab2971743b7697703
--- /dev/null
+++ b/tutorial/regular_16.txt
@@ -0,0 +1,49 @@
+16
+0 12
+0 14
+0 15
+1 8
+1 3
+1 6
+2 9
+2 10
+2 5
+3 1
+3 11
+3 9
+4 9
+4 12
+4 13
+5 2
+5 6
+5 7
+6 1
+6 5
+6 7
+7 5
+7 6
+7 14
+8 1
+8 13
+8 14
+9 2
+9 3
+9 4
+10 2
+10 11
+10 13
+11 10
+11 3
+11 15
+12 0
+12 4
+12 15
+13 8
+13 10
+13 4
+14 0
+14 8
+14 7
+15 0
+15 11
+15 12
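
regular_16.txt describes a 3-regular graph on 16 nodes in the edge-list format shown above: the first line holds the number of nodes, and every following line is a "node neighbour" pair, with each undirected edge listed once per endpoint. A small reader for this format, inferred from the file contents rather than taken from the library's own graph loader:

    # Sketch of a reader for the edge-list format shown above.
    from collections import defaultdict

    def read_edge_list(path):
        adjacency = defaultdict(set)
        with open(path) as f:
            n_nodes = int(f.readline())
            for line in f:
                if not line.strip():
                    continue
                node, neighbour = map(int, line.split())
                adjacency[node].add(neighbour)
        return n_nodes, adjacency

    n_nodes, adjacency = read_edge_list("tutorial/regular_16.txt")
    assert n_nodes == 16
    assert all(len(neighbours) == 3 for neighbours in adjacency.values())
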
diff --git a/tutorial/run_decentralized.sh b/tutorial/run_decentralized.sh
new file mode 100755
index 0000000000000000000000000000000000000000..692bb430b2032bbd7e1109a27a97aa3225165f1e
--- /dev/null
+++ b/tutorial/run_decentralized.sh
@@ -0,0 +1,24 @@
+#!/bin/bash
+
+decpy_path=/mnt/nfs/risharma/Gitlab/decentralizepy/eval
+cd $decpy_path
+
+env_python=~/miniconda3/envs/decpy/bin/python3
+graph=/mnt/nfs/risharma/Gitlab/tutorial/regular_16.txt
+original_config=/mnt/nfs/risharma/Gitlab/tutorial/config_celeba_sharing.ini
+config_file=~/tmp/config.ini
+procs_per_machine=16
+machines=1
+iterations=80
+test_after=20
+eval_file=testingPeerSampler.py
+log_level=INFO
+
+m=`cat $(grep addresses_filepath $original_config | awk '{print $3}') | grep $(/sbin/ifconfig ens785 | grep 'inet ' | awk '{print $2}') | cut -d'"' -f2`
+echo M is $m
+log_dir=$(date '+%Y-%m-%dT%H:%M')/machine$m
+mkdir -p $log_dir
+
+cp $original_config $config_file
+# echo "alpha = 0.10" >> $config_file
+$env_python $eval_file -ro 0 -tea $test_after -ld $log_dir -mid $m -ps $procs_per_machine -ms $machines -is $iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level -wsd $log_dir
\ No newline at end of file
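
run_decentralized.sh launches 16 fully decentralized processes on a single machine: the one-liner assigned to m derives this machine's ID by looking up the address of the ens785 interface in the ip.json referenced by the config, and logs land in a timestamped per-machine directory. Roughly the same machine-ID lookup, sketched in Python (socket-based address discovery is only an approximation of the ifconfig call in the script):

    # Approximate Python equivalent of the shell one-liner that derives $m.
    import json
    import socket

    with open("/mnt/nfs/risharma/Gitlab/tutorial/ip.json") as f:
        addresses = json.load(f)

    local_ip = socket.gethostbyname(socket.gethostname())
    machine_id = next((mid for mid, ip in addresses.items() if ip == local_ip), None)
    print("machine id:", machine_id)
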
diff --git a/tutorial/run_federated.sh b/tutorial/run_federated.sh
new file mode 100755
index 0000000000000000000000000000000000000000..5113ed79073cf94fb99a2b66597bd600007fd45c
--- /dev/null
+++ b/tutorial/run_federated.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+
+decpy_path=/mnt/nfs/risharma/Gitlab/decentralizepy/eval
+cd $decpy_path
+
+env_python=~/miniconda3/envs/decpy/bin/python3
+graph=/mnt/nfs/risharma/Gitlab/tutorial/regular_16.txt
+original_config=/mnt/nfs/risharma/Gitlab/tutorial/config_celeba_sharing.ini
+config_file=~/tmp/config.ini
+procs_per_machine=16
+machines=1
+iterations=80
+test_after=20
+eval_file=testingFederated.py
+log_level=INFO
+server_rank=-1
+server_machine=0
+working_rate=0.5
+
+m=`cat $(grep addresses_filepath $original_config | awk '{print $3}') | grep $(/sbin/ifconfig ens785 | grep 'inet ' | awk '{print $2}') | cut -d'"' -f2`
+echo M is $m
+log_dir=$(date '+%Y-%m-%dT%H:%M')/machine$m
+mkdir -p $log_dir
+
+cp $original_config $config_file
+# echo "alpha = 0.10" >> $config_file
+$env_python $eval_file -ro 0 -tea $test_after -ld $log_dir -mid $m -ps $procs_per_machine -ms $machines -is $iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level -wsd $log_dir -sm $server_machine -sr $server_rank -wr $working_rate
\ No newline at end of file
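
run_federated.sh differs from the decentralized script in three ways: it runs testingFederated.py, it designates a server process via --server_machine/--server_rank, and it sets --working_rate 0.5 so roughly half of the clients take part in each round. The actual participation logic lives in testingFederated.py and is not part of this diff; one plausible reading, sketched under that assumption:

    # Hedged sketch: each client flips a biased coin per round, so on average
    # working_rate of the clients are active in any given round.
    import random

    def is_active(working_rate: float, rng: random.Random) -> bool:
        return rng.random() < working_rate

    rng = random.Random(90)  # fixed seed, echoing the config's random_seed, for illustration
    active = [rank for rank in range(16) if is_active(0.5, rng)]
    print("active clients this round:", active)
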