diff --git a/eval/plot.py b/eval/plot.py
index 35f16fca45143895fb766ae351293849d4697f74..f119053db0f2ce90f3a09791a13c5f83d4fd7504 100644
--- a/eval/plot.py
+++ b/eval/plot.py
@@ -67,7 +67,10 @@ def plot_results(path, centralized, data_machine="machine0", data_node=0):
                 filepath = os.path.join(mf_path, f)
                 with open(filepath, "r") as inf:
                     results.append(json.load(inf))
-
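+        # Assumption: federated runs (folder names starting with "FL") store the
+        # server's aggregated results under node id -1, decentralized runs under node 0.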
+        if folder.startswith("FL"):
+            data_node = -1
+        else:
+            data_node = 0
         with open(folder_path / data_machine / f"{data_node}_results.json", "r") as f:
             main_data = json.load(f)
         main_data = [main_data]
@@ -258,5 +261,6 @@ if __name__ == "__main__":
     # The args are:
     # 1: the folder with the data
     # 2: True/False: If True then the evaluation on the test set was centralized
+    # For federated learning, the folder name must start with "FL"!
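+    # Example (hypothetical log folder): python plot.py ../logs/FL_celeba_run True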
     plot_results(sys.argv[1], sys.argv[2])
     # plot_parameters(sys.argv[1])
diff --git a/eval/run.sh b/eval/run.sh
index ff6392f81283aa075f7db62ba6656c30d714455e..b96f8b8079fd3825d25ac6790a4e31035c1e8b94 100755
--- a/eval/run.sh
+++ b/eval/run.sh
@@ -1,17 +1,18 @@
 #!/bin/bash
 
-decpy_path=/mnt/nfs/risharma/Gitlab/decentralizepy/eval
+decpy_path=/mnt/nfs/kirsten/Gitlab/jac_decentralizepy/decentralizepy/eval
 cd $decpy_path
 
 env_python=~/miniconda3/envs/decpy/bin/python3
-graph=/mnt/nfs/risharma/Gitlab/tutorial/regular_16.txt
-original_config=/mnt/nfs/risharma/Gitlab/tutorial/config_celeba_sharing.ini
-config_file=~/tmp/config.ini
+graph=/mnt/nfs/kirsten/Gitlab/tutorial/regular_16.txt
+original_config=/mnt/nfs/kirsten/Gitlab/tutorial/config_celeba_sharing.ini
+config_file=~/tmp/config_celeba_sharing.ini
 procs_per_machine=8
 machines=2
 iterations=5
 test_after=2
-eval_file=testing.py
+eval_file=testingFederated.py
+#eval_file=testingPeerSampler.py
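+# testingFederated.py launches a FederatedParameterServer plus DPSGDNodeFederated workers;
+# testingPeerSampler.py, as modified in this change, launches a ParameterServer plus DPSGDNodeWithParameterServer workers.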
 log_level=INFO
 
 m=`cat $(grep addresses_filepath $original_config | awk '{print $3}') | grep $(/sbin/ifconfig ens785 | grep 'inet ' | awk '{print $2}') | cut -d'"' -f2`
diff --git a/eval/run_xtimes.sh b/eval/run_xtimes.sh
index 3176f5357cc53ae6d251cb8136821e09caa4cbf4..4ca373a942b989d046933610e385d3ab95d40c81 100755
--- a/eval/run_xtimes.sh
+++ b/eval/run_xtimes.sh
@@ -1,5 +1,6 @@
 #!/bin/bash
 # Documentation
+# Note: the documentation below was written before this run file was modified, so the actual behaviour may differ
 # This bash file takes three inputs. The first argument (nfs_home) is the path to the nfs home directory.
 # The second one (python_bin) is the path to the python bin folder.
 # The last argument (logs_subfolder) is the path to the logs folder with respect to the nfs home directory.
@@ -18,8 +19,10 @@
 # Each node needs a folder called 'tmp' in the user's home directory
 #
 # Note:
-# - The script does not change the optimizer. All configs are writen to use SGD.
-# - The script will set '--test_after' and '--train_evaluate_after' such that it happens at the end of a global epoch.
+# - The script does not change the optimizer. All configs are written to use Adam.
+#   For SGD these need to be changed manually.
+# - The script will set '--test_after' and '--train_evaluate_after' to comm_rounds_per_global_epoch, i.e., the evaluation
+#   on the train set and on the test set is carried out every global epoch.
 # - The '--reset_optimizer' option is set to 0, i.e., the optimizer is not reset after a communication round (only
 #   relevant for Adams and other optimizers with internal state)
 #
@@ -37,41 +40,40 @@ decpy_path=$nfs_home/decentralizepy/eval
 cd $decpy_path
 
 env_python=$python_bin/python3
-graph=96_regular.edges
 config_file=~/tmp/config.ini
 procs_per_machine=16
 machines=6
 global_epochs=150
-eval_file=testing.py
+eval_file=testingPeerSampler.py
 log_level=INFO
-
-ip_machines=$nfs_home/configs/ip_addr_6Machines.json
-
+ip_machines=$nfs_home/$logs_subfolder/ip_addr_6Machines.json
 m=`cat $ip_machines | grep $(/sbin/ifconfig ens785 | grep 'inet ' | awk '{print $2}') | cut -d'"' -f2`
 export PYTHONFAULTHANDLER=1
-
 # Base configs for which the gird search is done
-tests=("step_configs/config_femnist_partialmodel.ini" "step_configs/config_femnist_topkacc.ini" "step_configs/config_femnist_wavelet.ini")
+tests="$nfs_home/$logs_subfolder/config.ini"
+#tests=("$nfs_home/$logs_subfolder/config_cifar_sharing.ini" "$nfs_home/$logs_subfolder/config_cifar_partialmodel.ini" "$nfs_home/$logs_subfolder/config_cifar_topkacc.ini" "$nfs_home/$logs_subfolder/config_cifar_topkaccRandomAlpha.ini" "$nfs_home/$logs_subfolder/config_cifar_subsampling.ini" "$nfs_home/$logs_subfolder/config_cifar_wavelet.ini" "$nfs_home/$logs_subfolder/config_cifar_waveletRandomAlpha.ini")
+#tests=("$nfs_home/$logs_subfolder/config_cifar_partialmodel.ini" "$nfs_home/$logs_subfolder/config_cifar_topkacc.ini" "$nfs_home/$logs_subfolder/config_cifar_topkaccRandomAlpha.ini" "$nfs_home/$logs_subfolder/config_cifar_subsampling.ini" "$nfs_home/$logs_subfolder/config_cifar_wavelet.ini" "$nfs_home/$logs_subfolder/config_cifar_waveletRandomAlpha.ini")
+#tests=("$nfs_home/$logs_subfolder/config_cifar_subsampling.ini" "$nfs_home/$logs_subfolder/config_cifar_sharing.ini" "$nfs_home/$logs_subfolder/config_cifar_waveletRandomAlpha.ini")
+#tests=("$nfs_home/$logs_subfolder/config_cifar_waveletRandomAlpha.ini")
 # Learning rates
-lr="0.001"
+lr="0.01"
 # Batch size
-batchsize="16"
+batchsize="8"
 # The number of communication rounds per global epoch
-comm_rounds_per_global_epoch="1"
+comm_rounds_per_global_epoch="20"
 procs=`expr $procs_per_machine \* $machines`
 echo procs: $procs
 # Celeba has 63741 samples
 # Reddit has 70642
 # Femnist 734463
-# Shakespeares 3678451, subsampled 678696
-# cifar 50000
-dataset_size=734463
+# Shakespeare 3678451
+# cifar 50000
+dataset_size=50000
 # Calculating the number of samples that each user/proc will have on average
 samples_per_user=`expr $dataset_size / $procs`
 echo samples per user: $samples_per_user
-
 # random_seeds for which to rerun the experiments
-random_seeds=("97")
+# random_seeds=("90" "91" "92" "93" "94")
+random_seeds=("94")
 # random_seed = 97
 echo batchsize: $batchsize
 echo communication rounds per global epoch: $comm_rounds_per_global_epoch
@@ -85,10 +87,10 @@ echo iterations: $iterations
 batches_per_comm_round=$($env_python -c "from math import floor; x = floor($batches_per_epoch / $comm_rounds_per_global_epoch); print(1 if x==0 else x)")
 # since the batches per communication round were rounded down we need to change the number of iterations to reflect that
 new_iterations=$($env_python -c "from math import floor; tmp = floor($batches_per_epoch / $comm_rounds_per_global_epoch); x = 1 if tmp == 0 else tmp; y = floor((($batches_per_epoch / $comm_rounds_per_global_epoch)/x)*$iterations); print($iterations if y<$iterations else y)")
-echo batches per communication round: $batches_per_comm_round
-echo corrected iterations: $new_iterations
 test_after=$(($new_iterations / $global_epochs))
 echo test after: $test_after
+echo batches per communication round: $batches_per_comm_round
+echo corrected iterations: $new_iterations
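+# Rough worked example of the arithmetic above, assuming batches_per_epoch = samples_per_user / batchsize
+# and iterations = global_epochs * comm_rounds_per_global_epoch (both computed earlier in this script):
+#   dataset_size=50000, procs=96     -> samples_per_user = 520
+#   batchsize=8                      -> batches_per_epoch = 65
+#   comm_rounds_per_global_epoch=20  -> batches_per_comm_round = floor(65 / 20) = 3
+#   iterations=3000 (150 * 20)       -> new_iterations = floor((65 / 20) / 3 * 3000) = 3250
+#   global_epochs=150                -> test_after = 3250 / 150 = 21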
 for i in "${tests[@]}"
 do
   for seed in "${random_seeds[@]}"
@@ -96,9 +98,14 @@ do
     echo $i
     IFS='_' read -ra NAMES <<< $i
     IFS='.' read -ra NAME <<< ${NAMES[-1]}
-    log_dir=$nfs_home$logs_subfolder/${NAME[0]}:lr=$lr:r=$comm_rounds_per_global_epoch:b=$batchsize:$(date '+%Y-%m-%dT%H:%M')/machine$m
-    echo results are stored in: $log_dir
+    #log_dir_base=$nfs_home$logs_subfolder/${NAME[0]}:lr=$lr:r=$comm_rounds_per_global_epoch:b=$batchsize:$(date '+%Y-%m-%dT%H:%M')
+    log_dir_base=$nfs_home/$logs_subfolder/lr=$lr:r=$comm_rounds_per_global_epoch:b=$batchsize:$(date '+%Y-%m-%dT%H:%M')
+    echo results are stored in: $log_dir_base
+    log_dir=$log_dir_base/machine$m
     mkdir -p $log_dir
+    weight_store_dir=$log_dir_base/weights
+    mkdir -p $weight_store_dir
+    graph=$nfs_home/decentralizepy/eval/96_regular.edges
     cp $i $config_file
     # changing the config files to reflect the values of the current grid search state
     $python_bin/crudini --set $config_file COMMUNICATION addresses_filepath $ip_machines
@@ -106,10 +113,14 @@ do
     $python_bin/crudini --set $config_file TRAIN_PARAMS rounds $batches_per_comm_round
     $python_bin/crudini --set $config_file TRAIN_PARAMS batch_size $batchsize
     $python_bin/crudini --set $config_file DATASET random_seed $seed
+    $python_bin/crudini --set $config_file COMMUNICATION addresses_filepath $ip_machines
+    $python_bin/crudini --set $config_file COMMUNICATION offset 10720
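+    # 'offset' is forwarded to the TCP communication layer; it is assumed to shift the port
+    # range so this run does not collide with other experiments on the same machines.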
+    # $env_python $eval_file -cte 0 -ro 0 -tea $test_after -ld $log_dir -wsd $weight_store_dir -mid $m -ps $procs_per_machine -ms $machines -is $new_iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level
     $env_python $eval_file -ro 0 -tea $test_after -ld $log_dir -mid $m -ps $procs_per_machine -ms $machines -is $new_iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level
+
     echo $i is done
     sleep 200
     echo end of sleep
     done
 done
-#
\ No newline at end of file
+#
diff --git a/eval/run_xtimes_celeba.sh b/eval/run_xtimes_celeba.sh
index c6d3e060c4fd60115e40e8b99bb39560b4cd1fe8..7eab580ffeda45780772b2f287b4e0efd6e3f588 100755
--- a/eval/run_xtimes_celeba.sh
+++ b/eval/run_xtimes_celeba.sh
@@ -38,11 +38,11 @@ cd $decpy_path
 
 env_python=$python_bin/python3
 graph=96_regular.edges
-config_file=~/tmp/config.ini
+config_file=~/tmp/config_celeba_sharing.ini
 procs_per_machine=16
 machines=6
 global_epochs=150
-eval_file=testing.py
+eval_file=testingFederated.py
 log_level=INFO
 
 ip_machines=$nfs_home/configs/ip_addr_6Machines.json
@@ -51,7 +51,8 @@ m=`cat $ip_machines | grep $(/sbin/ifconfig ens785 | grep 'inet ' | awk '{print
 export PYTHONFAULTHANDLER=1
 
 # Base configs for which the gird search is done
-tests=("step_configs/config_celeba_sharing.ini" "step_configs/config_celeba_partialmodel.ini" "step_configs/config_celeba_topkacc.ini" "step_configs/config_celeba_subsampling.ini" "step_configs/config_celeba_wavelet.ini")
+#tests=("step_configs/config_celeba_sharing.ini" "step_configs/config_celeba_partialmodel.ini" "step_configs/config_celeba_topkacc.ini" "step_configs/config_celeba_subsampling.ini" "step_configs/config_celeba_wavelet.ini")
+tests=("step_configs/config_celeba_sharing.ini")
 # Learning rates
 lr="0.001"
 # Batch size
@@ -66,7 +67,8 @@ samples_per_user=`expr $dataset_size / $procs`
 echo samples per user: $samples_per_user
 
 # random_seeds for which to rerun the experiments
-random_seeds=("90" "91" "92" "93" "94")
+#random_seeds=("90" "91" "92" "93" "94")
+random_seeds=("90")
 # random_seed = 97
 echo batchsize: $batchsize
 echo communication rounds per global epoch: $comm_rounds_per_global_epoch
@@ -91,7 +93,7 @@ do
     echo $i
     IFS='_' read -ra NAMES <<< $i
     IFS='.' read -ra NAME <<< ${NAMES[-1]}
-    log_dir_base=$nfs_home$logs_subfolder/${NAME[0]}:lr=$lr:r=$comm_rounds_per_global_epoch:b=$batchsize:$(date '+%Y-%m-%dT%H:%M')
+    log_dir_base=$nfs_home/$logs_subfolder/${NAME[0]}:lr=$lr:r=$comm_rounds_per_global_epoch:b=$batchsize:$(date '+%Y-%m-%dT%H:%M')
     echo results are stored in: $log_dir_base
     log_dir=$log_dir_base/machine$m
     mkdir -p $log_dir
diff --git a/eval/run_xtimes_cifar.sh b/eval/run_xtimes_cifar.sh
index 0c545c651e5150f42332ebf81cc2baae4c0f5ef6..038fef102862f53fb3f89cef3612b23a05b46dba 100755
--- a/eval/run_xtimes_cifar.sh
+++ b/eval/run_xtimes_cifar.sh
@@ -41,8 +41,8 @@ graph=96_regular.edges
 config_file=~/tmp/config.ini
 procs_per_machine=16
 machines=6
-global_epochs=300
-eval_file=testing.py
+global_epochs=100
+eval_file=testingFederated.py
 log_level=INFO
 
 ip_machines=$nfs_home/configs/ip_addr_6Machines.json
@@ -51,7 +51,7 @@ m=`cat $ip_machines | grep $(/sbin/ifconfig ens785 | grep 'inet ' | awk '{print
 export PYTHONFAULTHANDLER=1
 
 # Base configs for which the gird search is done
-tests=("step_configs/config_cifar_sharing.ini" "step_configs/config_cifar_partialmodel.ini" "step_configs/config_cifar_topkacc.ini" "step_configs/config_cifar_subsampling.ini" "step_configs/config_cifar_wavelet.ini")
+tests=("step_configs/config_cifar_sharing.ini")
 # Learning rates
 lr="0.01"
 # Batch size
@@ -66,7 +66,7 @@ samples_per_user=`expr $dataset_size / $procs`
 echo samples per user: $samples_per_user
 
 # random_seeds for which to rerun the experiments
-random_seeds=("90" "91" "92" "93" "94")
+random_seeds=("90")
 # random_seed = 97
 echo batchsize: $batchsize
 echo communication rounds per global epoch: $comm_rounds_per_global_epoch
@@ -91,7 +91,7 @@ do
     echo $i
     IFS='_' read -ra NAMES <<< $i
     IFS='.' read -ra NAME <<< ${NAMES[-1]}
-    log_dir_base=$nfs_home$logs_subfolder/${NAME[0]}:lr=$lr:r=$comm_rounds_per_global_epoch:b=$batchsize:$(date '+%Y-%m-%dT%H:%M')
+    log_dir_base=$nfs_home/$logs_subfolder/${NAME[0]}:lr=$lr:r=$comm_rounds_per_global_epoch:b=$batchsize:$(date '+%Y-%m-%dT%H:%M')
     echo results are stored in: $log_dir_base
     log_dir=$log_dir_base/machine$m
     mkdir -p $log_dir
diff --git a/eval/run_xtimes_femnist.sh b/eval/run_xtimes_femnist.sh
index c5e18a2ab958997e08b96fc4682c1b64f021cb25..cd65a1297dfc5173cc22e5cad3987e5bedad0551 100755
--- a/eval/run_xtimes_femnist.sh
+++ b/eval/run_xtimes_femnist.sh
@@ -38,11 +38,11 @@ cd $decpy_path
 
 env_python=$python_bin/python3
 graph=96_regular.edges
-config_file=~/tmp/config.ini
+config_file=~/tmp/config_femnist_sharing.ini
 procs_per_machine=16
 machines=6
 global_epochs=80
-eval_file=testing.py
+eval_file=testingFederated.py
 log_level=INFO
 
 ip_machines=$nfs_home/configs/ip_addr_6Machines.json
@@ -51,7 +51,8 @@ m=`cat $ip_machines | grep $(/sbin/ifconfig ens785 | grep 'inet ' | awk '{print
 export PYTHONFAULTHANDLER=1
 
 # Base configs for which the gird search is done
-tests=("step_configs/config_femnist_sharing.ini" "step_configs/config_femnist_partialmodel.ini" "step_configs/config_femnist_topkacc.ini" "step_configs/config_femnist_subsampling.ini" "step_configs/config_femnist_wavelet.ini")
+#tests=("step_configs/config_femnist_sharing.ini" "step_configs/config_femnist_partialmodel.ini" "step_configs/config_femnist_topkacc.ini" "step_configs/config_femnist_subsampling.ini" "step_configs/config_femnist_wavelet.ini")
+tests=("step_configs/config_femnist_sharing.ini")
 # Learning rates
 lr="0.01"
 # Batch size
@@ -66,7 +67,8 @@ samples_per_user=`expr $dataset_size / $procs`
 echo samples per user: $samples_per_user
 
 # random_seeds for which to rerun the experiments
-random_seeds=("90" "91" "92" "93" "94")
+#random_seeds=("90" "91" "92" "93" "94")
+random_seeds=("90")
 # random_seed = 97
 echo batchsize: $batchsize
 echo communication rounds per global epoch: $comm_rounds_per_global_epoch
@@ -91,7 +93,7 @@ do
     echo $i
     IFS='_' read -ra NAMES <<< $i
     IFS='.' read -ra NAME <<< ${NAMES[-1]}
-    log_dir_base=$nfs_home$logs_subfolder/${NAME[0]}:lr=$lr:r=$comm_rounds_per_global_epoch:b=$batchsize:$(date '+%Y-%m-%dT%H:%M')
+    log_dir_base=$nfs_home/$logs_subfolder/${NAME[0]}:lr=$lr:r=$comm_rounds_per_global_epoch:b=$batchsize:$(date '+%Y-%m-%dT%H:%M')
     echo results are stored in: $log_dir_base
     log_dir=$log_dir_base/machine$m
     mkdir -p $log_dir
diff --git a/eval/run_xtimes_reddit.sh b/eval/run_xtimes_reddit.sh
index 589f52a2978107a3efac300b592048adb09f1717..2f738c0c60ad8d949a1242481b4e918c2286d099 100755
--- a/eval/run_xtimes_reddit.sh
+++ b/eval/run_xtimes_reddit.sh
@@ -38,11 +38,11 @@ cd $decpy_path
 
 env_python=$python_bin/python3
 graph=96_regular.edges
-config_file=~/tmp/config.ini
+config_file=~/tmp/config_reddit_sharing.ini
 procs_per_machine=16
 machines=6
 global_epochs=50
-eval_file=testing.py
+eval_file=testingFederated.py
 log_level=INFO
 
 ip_machines=$nfs_home/configs/ip_addr_6Machines.json
@@ -51,7 +51,8 @@ m=`cat $ip_machines | grep $(/sbin/ifconfig ens785 | grep 'inet ' | awk '{print
 export PYTHONFAULTHANDLER=1
 
 # Base configs for which the gird search is done
-tests=("step_configs/config_reddit_sharing.ini" "step_configs/config_reddit_partialmodel.ini" "step_configs/config_reddit_topkacc.ini" "step_configs/config_reddit_subsampling.ini" "step_configs/config_reddit_wavelet.ini")
+#tests=("step_configs/config_reddit_sharing.ini" "step_configs/config_reddit_partialmodel.ini" "step_configs/config_reddit_topkacc.ini" "step_configs/config_reddit_subsampling.ini" "step_configs/config_reddit_wavelet.ini")
+tests=("step_configs/config_reddit_sharing.ini")
 # Learning rates
 lr="1"
 # Batch size
@@ -66,7 +67,8 @@ samples_per_user=`expr $dataset_size / $procs`
 echo samples per user: $samples_per_user
 
 # random_seeds for which to rerun the experiments
-random_seeds=("90" "91" "92" "93" "94")
+#random_seeds=("90" "91" "92" "93" "94")
+random_seeds=("90")
 # random_seed = 97
 echo batchsize: $batchsize
 echo communication rounds per global epoch: $comm_rounds_per_global_epoch
@@ -91,7 +93,7 @@ do
     echo $i
     IFS='_' read -ra NAMES <<< $i
     IFS='.' read -ra NAME <<< ${NAMES[-1]}
-    log_dir_base=$nfs_home$logs_subfolder/${NAME[0]}:lr=$lr:r=$comm_rounds_per_global_epoch:b=$batchsize:$(date '+%Y-%m-%dT%H:%M')
+    log_dir_base=$nfs_home/$logs_subfolder/${NAME[0]}:lr=$lr:r=$comm_rounds_per_global_epoch:b=$batchsize:$(date '+%Y-%m-%dT%H:%M')
     echo results are stored in: $log_dir_base
     log_dir=$log_dir_base/machine$m
     mkdir -p $log_dir
diff --git a/eval/run_xtimes_shakespeare.sh b/eval/run_xtimes_shakespeare.sh
index 8e268a18248eb789bc91a6fb68d2183d6a1b96b3..b6c6d6c0174f0ec76809adb42797df612991338c 100755
--- a/eval/run_xtimes_shakespeare.sh
+++ b/eval/run_xtimes_shakespeare.sh
@@ -38,11 +38,11 @@ cd $decpy_path
 
 env_python=$python_bin/python3
 graph=96_regular.edges
-config_file=~/tmp/config.ini
+config_file=~/tmp/config_shakespeare_sharing.ini
 procs_per_machine=16
 machines=6
-global_epochs=200
-eval_file=testing.py
+global_epochs=100
+eval_file=testingFederated.py
 log_level=INFO
 
 ip_machines=$nfs_home/configs/ip_addr_6Machines.json
@@ -51,7 +51,8 @@ m=`cat $ip_machines | grep $(/sbin/ifconfig ens785 | grep 'inet ' | awk '{print
 export PYTHONFAULTHANDLER=1
 
 # Base configs for which the gird search is done
-tests=("step_configs/config_shakespeare_sharing.ini" "step_configs/config_shakespeare_partialmodel.ini" "step_configs/config_shakespeare_topkacc.ini" "step_configs/config_shakespeare_subsampling.ini" "step_configs/config_shakespeare_wavelet.ini")
+#tests=("step_configs/config_shakespeare_sharing.ini" "step_configs/config_shakespeare_partialmodel.ini" "step_configs/config_shakespeare_topkacc.ini" "step_configs/config_shakespeare_subsampling.ini" "step_configs/config_shakespeare_wavelet.ini")
+tests=("step_configs/config_shakespeare_sharing.ini")
 # Learning rates
 lr="0.5"
 # Batch size
@@ -66,7 +67,8 @@ samples_per_user=`expr $dataset_size / $procs`
 echo samples per user: $samples_per_user
 
 # random_seeds for which to rerun the experiments
-random_seeds=("90" "91" "92" "93" "94")
+#random_seeds=("90" "91" "92" "93" "94")
+random_seeds=("90")
 # random_seed = 97
 echo batchsize: $batchsize
 echo communication rounds per global epoch: $comm_rounds_per_global_epoch
@@ -91,7 +93,7 @@ do
     echo $i
     IFS='_' read -ra NAMES <<< $i
     IFS='.' read -ra NAME <<< ${NAMES[-1]}
-    log_dir_base=$nfs_home$logs_subfolder/${NAME[0]}:lr=$lr:r=$comm_rounds_per_global_epoch:b=$batchsize:$(date '+%Y-%m-%dT%H:%M')
+    log_dir_base=$nfs_home/$logs_subfolder/${NAME[0]}:lr=$lr:r=$comm_rounds_per_global_epoch:b=$batchsize:$(date '+%Y-%m-%dT%H:%M')
     echo results are stored in: $log_dir_base
     log_dir=$log_dir_base/machine$m
     mkdir -p $log_dir
diff --git a/eval/testingFederated.py b/eval/testingFederated.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9cdeca90cf72982d6566b5c3ddcc9c361b3e569
--- /dev/null
+++ b/eval/testingFederated.py
@@ -0,0 +1,107 @@
+import logging
+from pathlib import Path
+from shutil import copy
+
+from localconfig import LocalConfig
+from torch import multiprocessing as mp
+
+from decentralizepy import utils
+from decentralizepy.graphs.Graph import Graph
+from decentralizepy.mappings.Linear import Linear
+from decentralizepy.node.FederatedParameterServer import FederatedParameterServer
+from decentralizepy.node.DPSGDNodeFederated import DPSGDNodeFederated
+
+
+def read_ini(file_path):
+    config = LocalConfig(file_path)
+    for section in config:
+        print("Section: ", section)
+        for key, value in config.items(section):
+            print((key, value))
+    print(dict(config.items("DATASET")))
+    return config
+
+
+if __name__ == "__main__":
+    args = utils.get_args()
+
+    Path(args.log_dir).mkdir(parents=True, exist_ok=True)
+
+    log_level = {
+        "INFO": logging.INFO,
+        "DEBUG": logging.DEBUG,
+        "WARNING": logging.WARNING,
+        "ERROR": logging.ERROR,
+        "CRITICAL": logging.CRITICAL,
+    }
+
+    config = read_ini(args.config_file)
+    my_config = dict()
+    for section in config:
+        my_config[section] = dict(config.items(section))
+
+    copy(args.config_file, args.log_dir)
+    copy(args.graph_file, args.log_dir)
+    utils.write_args(args, args.log_dir)
+
+    g = Graph()
+    g.read_graph_from_file(args.graph_file, args.graph_type)
+    n_machines = args.machines
+    procs_per_machine = args.procs_per_machine
+    l = Linear(n_machines, procs_per_machine)
+    m_id = args.machine_id
+
+    sm = args.server_machine
+    sr = args.server_rank
+
+    # TODO
+    working_fraction = 1.0
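+    # Assumption: working_fraction is the fraction of workers the server samples each round
+    # (forwarded to FederatedParameterServer below); 1.0 means every worker trains every round.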
+
+    processes = []
+    if sm == m_id:
+        processes.append(
+            mp.Process(
+                target=FederatedParameterServer,
+                args=[
+                    sr,
+                    m_id,
+                    l,
+                    g,
+                    my_config,
+                    args.iterations,
+                    args.log_dir,
+                    args.weights_store_dir,
+                    log_level[args.log_level],
+                    args.test_after,
+                    args.train_evaluate_after,
+                    working_fraction,
+                ],
+            )
+        )
+
+    for r in range(0, procs_per_machine):
+        processes.append(
+            mp.Process(
+                target=DPSGDNodeFederated,
+                args=[
+                    r,
+                    m_id,
+                    l,
+                    g,
+                    my_config,
+                    args.iterations,
+                    args.log_dir,
+                    args.weights_store_dir,
+                    log_level[args.log_level],
+                    args.test_after,
+                    args.train_evaluate_after,
+                    args.reset_optimizer,
+                ],
+            )
+        )
+
+    for p in processes:
+        p.start()
+
+    for p in processes:
+        p.join()
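+
+# Example invocation (mirrors the call in eval/run_xtimes.sh; flags are parsed by decentralizepy.utils.get_args):
+#   python testingFederated.py -ro 0 -tea 21 -ld <log_dir> -wsd <weights_dir> -mid 0 \
+#       -ps 16 -ms 6 -is 3250 -gf 96_regular.edges -ta 21 -cf ~/tmp/config.ini -ll INFO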
diff --git a/eval/testingPeerSampler.py b/eval/testingPeerSampler.py
index ecc26365407227a086f71220d0e7a2661f0ba169..decedf2b68ac4af4f8926f37a7ad21201136cd64 100644
--- a/eval/testingPeerSampler.py
+++ b/eval/testingPeerSampler.py
@@ -10,6 +10,9 @@ from decentralizepy.graphs.Graph import Graph
 from decentralizepy.mappings.Linear import Linear
 from decentralizepy.node.DPSGDWithPeerSampler import DPSGDWithPeerSampler
 from decentralizepy.node.PeerSamplerDynamic import PeerSamplerDynamic
+from decentralizepy.node.PeerSampler import PeerSampler
+from decentralizepy.node.ParameterServer import ParameterServer
+from decentralizepy.node.DPSGDNodeWithParameterServer import DPSGDNodeWithParameterServer
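+# Note: the server target chosen in __main__ below should match the worker class
+# (ParameterServer pairs with DPSGDNodeWithParameterServer; PeerSampler/PeerSamplerDynamic with DPSGDWithPeerSampler).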
 
 
 def read_ini(file_path):
@@ -58,7 +61,9 @@ if __name__ == "__main__":
     if sm == m_id:
         processes.append(
             mp.Process(
-                target=PeerSamplerDynamic,
+                # target=PeerSamplerDynamic,
+                target=ParameterServer,
+                # target=PeerSampler,
                 args=[
                     sr,
                     m_id,
@@ -75,7 +80,8 @@ if __name__ == "__main__":
     for r in range(0, procs_per_machine):
         processes.append(
             mp.Process(
-                target=DPSGDWithPeerSampler,
+                target=DPSGDNodeWithParameterServer,
+                # target=DPSGDWithPeerSampler,
                 args=[
                     r,
                     m_id,
diff --git a/logs/config_celeba_sharing.ini b/logs/config_celeba_sharing.ini
new file mode 100644
index 0000000000000000000000000000000000000000..c5302ae575b7f96a604d1e04b8e78854718795dd
--- /dev/null
+++ b/logs/config_celeba_sharing.ini
@@ -0,0 +1,39 @@
+[DATASET]
+dataset_package = decentralizepy.datasets.Celeba
+dataset_class = Celeba
+model_class = CNN
+images_dir = /mnt/nfs/shared/leaf/data/celeba/data/raw/img_align_celeba
+train_dir = /mnt/nfs/shared/leaf/data/celeba/per_user_data/train
+test_dir = /mnt/nfs/shared/leaf/data/celeba/data/test
+; python list of fractions below
+sizes = 
+
+[OPTIMIZER_PARAMS]
+optimizer_package = torch.optim
+optimizer_class = SGD
+lr = 0.001
+
+[TRAIN_PARAMS]
+training_package = decentralizepy.training.Training
+training_class = Training
+rounds = 4
+full_epochs = False
+batch_size = 16
+shuffle = True
+loss_package = torch.nn
+loss_class = CrossEntropyLoss
+
+[COMMUNICATION]
+comm_package = decentralizepy.communication.TCP
+comm_class = TCP
+addresses_filepath = /mnt/nfs/kirsten/Gitlab/tutorial/ip.json
+
+[SHARING]
+sharing_package = decentralizepy.sharing.Sharing
+sharing_class = Sharing
+
+;sharing_package = decentralizepy.sharing.PartialModel
+;sharing_class = PartialModel
+;alpha = 0.1
+;accumulation = True
+;accumulate_averaging_changes = True
diff --git a/src/decentralizepy/node/DPSGDNodeFederated.py b/src/decentralizepy/node/DPSGDNodeFederated.py
new file mode 100644
index 0000000000000000000000000000000000000000..d77e35d409d86fa66ffc210e6e0ff0bd62f9f865
--- /dev/null
+++ b/src/decentralizepy/node/DPSGDNodeFederated.py
@@ -0,0 +1,325 @@
+import importlib
+import logging
+import math
+import os
+from collections import deque
+
+import torch
+
+from decentralizepy import utils
+from decentralizepy.graphs.Graph import Graph
+from decentralizepy.mappings.Mapping import Mapping
+from decentralizepy.node.Node import Node
+
+
+class DPSGDNodeFederated(Node):
+    """
+    This class defines the node for federated DPSGD
+
+    """
+
+    def run(self):
+        """
+        Start the federated training loop at this worker node
+
+        """
+        while len(self.barrier):
+            sender, data = self.receive_channel("WORKER_REQUEST")
+
+            if "BYE" in data:
+                logging.info("Received {} from {}".format("BYE", sender))
+                self.barrier.remove(sender)
+                break
+
+            iteration = data["iteration"]
+            del data["iteration"]
+            del data["CHANNEL"]
+
+            if iteration == 0:
+                del data["degree"]
+                data = self.sharing.deserialized_model(data)
+
+            self.model.load_state_dict(data)
+            self.sharing._post_step()
+            self.sharing.communication_round += 1
+
+            logging.info("Received worker request at node {}, global iteration {}, local round {}".format(
+                self.uid,
+                iteration,
+                self.participated
+            ))
+
+            if self.reset_optimizer:
+                self.optimizer = self.optimizer_class(
+                    self.model.parameters(), **self.optimizer_params
+                )  # Reset optimizer state
+                self.trainer.reset_optimizer(self.optimizer)
+
+            # Perform iteration
+            logging.info("Starting training iteration")
+            self.trainer.train(self.dataset)
+
+            # Send update to server
+            to_send = self.sharing.get_data_to_send()
+            to_send["CHANNEL"] = "DPSGD"
+            self.communication.send(self.parameter_server_uid, to_send)
+
+            self.participated += 1
+
+        # only if has participated in learning
+        if self.participated > 0:
+            logging.info("Storing final weight")
+            self.model.dump_weights(self.weights_store_dir, self.uid, iteration)
+
+        logging.info("Server disconnected. Process complete!")
+
+    def cache_fields(
+        self,
+        rank,
+        machine_id,
+        mapping,
+        graph,
+        iterations,
+        log_dir,
+        weights_store_dir,
+        test_after,
+        train_evaluate_after,
+        reset_optimizer
+    ):
+        """
+        Instantiate object field with arguments.
+
+        Parameters
+        ----------
+        rank : int
+            Rank of process local to the machine
+        machine_id : int
+            Machine ID on which the process is running
+        mapping : decentralizepy.mappings
+            The object containing the mapping rank <--> uid
+        graph : decentralizepy.graphs
+            The object containing the global graph
+        iterations : int
+            Number of iterations (communication steps) for which the model should be trained
+        log_dir : str
+            Logging directory
+        weights_store_dir : str
+            Directory in which to store model weights
+        test_after : int
+            Number of iterations after which the test loss and accuracy are calculated
+        train_evaluate_after : int
+            Number of iterations after which the train loss is calculated
+        reset_optimizer : int
+            1 if optimizer should be reset every communication round, else 0
+        """
+        self.rank = rank
+        self.machine_id = machine_id
+        self.graph = graph
+        self.mapping = mapping
+        self.uid = self.mapping.get_uid(rank, machine_id)
+        self.log_dir = log_dir
+        self.weights_store_dir = weights_store_dir
+        self.iterations = iterations
+        self.test_after = test_after
+        self.train_evaluate_after = train_evaluate_after
+        self.reset_optimizer = reset_optimizer
+        self.sent_disconnections = False
+
+        logging.info("Rank: %d", self.rank)
+        logging.info("type(graph): %s", str(type(self.rank)))
+        logging.info("type(mapping): %s", str(type(self.mapping)))
+
+    def init_comm(self, comm_configs):
+        """
+        Instantiate communication module from config.
+
+        Parameters
+        ----------
+        comm_configs : dict
+            Python dict containing communication config params
+
+        """
+        comm_module = importlib.import_module(comm_configs["comm_package"])
+        comm_class = getattr(comm_module, comm_configs["comm_class"])
+        comm_params = utils.remove_keys(
+            comm_configs, ["comm_package", "comm_class"])
+        self.addresses_filepath = comm_params.get("addresses_filepath", None)
+        self.communication = comm_class(
+            self.rank, self.machine_id, self.mapping, self.graph.n_procs, **comm_params
+        )
+
+    def instantiate(
+        self,
+        rank: int,
+        machine_id: int,
+        mapping: Mapping,
+        graph: Graph,
+        config,
+        iterations=1,
+        log_dir=".",
+        weights_store_dir=".",
+        log_level=logging.INFO,
+        test_after=5,
+        train_evaluate_after=1,
+        reset_optimizer=1,
+        *args
+    ):
+        """
+        Construct objects.
+
+        Parameters
+        ----------
+        rank : int
+            Rank of process local to the machine
+        machine_id : int
+            Machine ID on which the process is running
+        mapping : decentralizepy.mappings
+            The object containing the mapping rank <--> uid
+        graph : decentralizepy.graphs
+            The object containing the global graph
+        config : dict
+            A dictionary of configurations.
+        iterations : int
+            Number of iterations (communication steps) for which the model should be trained
+        log_dir : str
+            Logging directory
+        weights_store_dir : str
+            Directory in which to store model weights
+        log_level : logging.Level
+            One of DEBUG, INFO, WARNING, ERROR, CRITICAL
+        test_after : int
+            Number of iterations after which the test loss and accuracy are calculated
+        train_evaluate_after : int
+            Number of iterations after which the train loss is calculated
+        reset_optimizer : int
+            1 if optimizer should be reset every communication round, else 0
+        args : optional
+            Other arguments
+
+        """
+        logging.info("Started process.")
+
+        self.init_log(log_dir, rank, log_level)
+
+        self.cache_fields(
+            rank,
+            machine_id,
+            mapping,
+            graph,
+            iterations,
+            log_dir,
+            weights_store_dir,
+            test_after,
+            train_evaluate_after,
+            reset_optimizer,
+        )
+        self.init_dataset_model(config["DATASET"])
+        self.init_optimizer(config["OPTIMIZER_PARAMS"])
+        self.init_trainer(config["TRAIN_PARAMS"])
+        self.init_comm(config["COMMUNICATION"])
+
+        self.message_queue = dict()
+
+        self.barrier = set()
+
+        self.participated = 0
+
+        self.init_sharing(config["SHARING"])
+
+    def __init__(
+        self,
+        rank: int,
+        machine_id: int,
+        mapping: Mapping,
+        graph: Graph,
+        config,
+        iterations=1,
+        log_dir=".",
+        weights_store_dir=".",
+        log_level=logging.INFO,
+        test_after=5,
+        train_evaluate_after=1,
+        reset_optimizer=1,
+        parameter_server_uid=-1,
+        *args
+    ):
+        """
+        Constructor
+
+        Parameters
+        ----------
+        rank : int
+            Rank of process local to the machine
+        machine_id : int
+            Machine ID on which the process is running
+        mapping : decentralizepy.mappings
+            The object containing the mapping rank <--> uid
+        graph : decentralizepy.graphs
+            The object containing the global graph
+        config : dict
+            A dictionary of configurations. Must contain the following:
+            [DATASET]
+                dataset_package
+                dataset_class
+                model_class
+            [OPTIMIZER_PARAMS]
+                optimizer_package
+                optimizer_class
+            [TRAIN_PARAMS]
+                training_package = decentralizepy.training.Training
+                training_class = Training
+                epochs_per_round = 25
+                batch_size = 64
+        iterations : int
+            Number of iterations (communication steps) for which the model should be trained
+        log_dir : str
+            Logging directory
+        weights_store_dir : str
+            Directory in which to store model weights
+        log_level : logging.Level
+            One of DEBUG, INFO, WARNING, ERROR, CRITICAL
+        test_after : int
+            Number of iterations after which the test loss and accuracy are calculated
+        train_evaluate_after : int
+            Number of iterations after which the train loss is calculated
+        reset_optimizer : int
+            1 if optimizer should be reset every communication round, else 0
+        parameter_server_uid: int
+            The parameter server's uid
+        args : optional
+            Other arguments
+
+        """
+
+        total_threads = os.cpu_count()
+        self.threads_per_proc = max(
+            math.floor(total_threads / mapping.procs_per_machine), 1
+        )
+        torch.set_num_threads(self.threads_per_proc)
+        torch.set_num_interop_threads(1)
+        self.instantiate(
+            rank,
+            machine_id,
+            mapping,
+            graph,
+            config,
+            iterations,
+            log_dir,
+            weights_store_dir,
+            log_level,
+            test_after,
+            train_evaluate_after,
+            reset_optimizer,
+            *args
+        )
+        logging.info(
+            "Each proc uses %d threads out of %d.", self.threads_per_proc, total_threads
+        )
+
+        self.message_queue["PEERS"] = deque()
+
+        self.parameter_server_uid = parameter_server_uid
+        self.connect_neighbor(self.parameter_server_uid)
+        self.wait_for_hello(self.parameter_server_uid)
+
+        self.run()
diff --git a/src/decentralizepy/node/DPSGDNodeWithParameterServer.py b/src/decentralizepy/node/DPSGDNodeWithParameterServer.py
new file mode 100644
index 0000000000000000000000000000000000000000..85fa862e0c61438533c3c764b566bdbd25e1cb46
--- /dev/null
+++ b/src/decentralizepy/node/DPSGDNodeWithParameterServer.py
@@ -0,0 +1,526 @@
+import importlib
+import json
+import logging
+import math
+import os
+from collections import deque
+
+import torch
+from matplotlib import pyplot as plt
+
+from decentralizepy import utils
+from decentralizepy.communication.TCP import TCP
+from decentralizepy.graphs.Graph import Graph
+from decentralizepy.graphs.Star import Star
+from decentralizepy.mappings.Mapping import Mapping
+from decentralizepy.node.Node import Node
+from decentralizepy.train_test_evaluation import TrainTestHelper
+
+
+class DPSGDNodeWithParameterServer(Node):
+    """
+    This class defines the DPSGD node that trains with a central parameter server
+
+    """
+
+    def save_plot(self, l, label, title, xlabel, filename):
+        """
+        Save Matplotlib plot. Clears previous plots.
+
+        Parameters
+        ----------
+        l : dict
+            dict of x -> y. `x` must be castable to int.
+        label : str
+            label of the plot. Used for legend.
+        title : str
+            Header
+        xlabel : str
+            x-axis label
+        filename : str
+            Name of file to save the plot as.
+
+        """
+        plt.clf()
+        y_axis = [l[key] for key in l.keys()]
+        x_axis = list(map(int, l.keys()))
+        plt.plot(x_axis, y_axis, label=label)
+        plt.xlabel(xlabel)
+        plt.title(title)
+        plt.savefig(filename)
+
+    def run(self):
+        """
+        Start the training, exchanging model updates with the parameter server each round
+
+        """
+        self.testset = self.dataset.get_testset()
+        rounds_to_test = self.test_after
+        rounds_to_train_evaluate = self.train_evaluate_after
+        global_epoch = 1
+        change = 1
+        if self.uid == 0:
+            dataset = self.dataset
+            if self.centralized_train_eval:
+                dataset_params_copy = self.dataset_params.copy()
+                if "sizes" in dataset_params_copy:
+                    del dataset_params_copy["sizes"]
+                self.whole_dataset = self.dataset_class(
+                    self.rank,
+                    self.machine_id,
+                    self.mapping,
+                    sizes=[1.0],
+                    **dataset_params_copy
+                )
+                dataset = self.whole_dataset
+            if self.centralized_test_eval:
+                tthelper = TrainTestHelper(
+                    dataset,  # self.whole_dataset,
+                    # self.model_test, # todo: this only works if eval_train is set to false
+                    self.model,
+                    self.loss,
+                    self.weights_store_dir,
+                    self.mapping.get_n_procs(),
+                    self.trainer,
+                    self.testing_comm,
+                    self.star,
+                    self.threads_per_proc,
+                    eval_train=self.centralized_train_eval,
+                )
+
+        for iteration in range(self.iterations):
+            logging.info("Starting training iteration: %d", iteration)
+            self.iteration = iteration
+            self.trainer.train(self.dataset)
+
+            to_send = self.sharing.get_data_to_send()
+            to_send["CHANNEL"] = "DPSGD"
+
+            self.communication.send(self.parameter_server_uid, to_send)
+
+            sender, data = self.receive_channel("GRADS")
+            del data["CHANNEL"]
+
+            self.model.load_state_dict(data)
+            self.sharing._post_step()
+            self.sharing.communication_round += 1
+
+            if self.reset_optimizer:
+                self.optimizer = self.optimizer_class(
+                    self.model.parameters(), **self.optimizer_params
+                )  # Reset optimizer state
+                self.trainer.reset_optimizer(self.optimizer)
+
+            if iteration:
+                with open(
+                    os.path.join(
+                        self.log_dir, "{}_results.json".format(self.rank)),
+                    "r",
+                ) as inf:
+                    results_dict = json.load(inf)
+            else:
+                results_dict = {
+                    "train_loss": {},
+                    "test_loss": {},
+                    "test_acc": {},
+                    "total_bytes": {},
+                    "total_meta": {},
+                    "total_data_per_n": {},
+                    "grad_mean": {},
+                    "grad_std": {},
+                }
+
+            results_dict["total_bytes"][iteration
+                                        + 1] = self.communication.total_bytes
+
+            if hasattr(self.communication, "total_meta"):
+                results_dict["total_meta"][
+                    iteration + 1
+                ] = self.communication.total_meta
+            if hasattr(self.communication, "total_data"):
+                results_dict["total_data_per_n"][
+                    iteration + 1
+                ] = self.communication.total_data
+            if hasattr(self.sharing, "mean"):
+                results_dict["grad_mean"][iteration + 1] = self.sharing.mean
+            if hasattr(self.sharing, "std"):
+                results_dict["grad_std"][iteration + 1] = self.sharing.std
+
+            rounds_to_train_evaluate -= 1
+
+            if rounds_to_train_evaluate == 0 and not self.centralized_train_eval:
+                logging.info("Evaluating on train set.")
+                rounds_to_train_evaluate = self.train_evaluate_after * change
+                loss_after_sharing = self.trainer.eval_loss(self.dataset)
+                results_dict["train_loss"][iteration + 1] = loss_after_sharing
+                self.save_plot(
+                    results_dict["train_loss"],
+                    "train_loss",
+                    "Training Loss",
+                    "Communication Rounds",
+                    os.path.join(
+                        self.log_dir, "{}_train_loss.png".format(self.rank)),
+                )
+
+            rounds_to_test -= 1
+
+            if self.dataset.__testing__ and rounds_to_test == 0:
+                rounds_to_test = self.test_after * change
+                if self.centralized_test_eval:
+                    if self.uid == 0:
+                        ta, tl, trl = tthelper.train_test_evaluation(iteration)
+                        results_dict["test_acc"][iteration + 1] = ta
+                        results_dict["test_loss"][iteration + 1] = tl
+                        if trl is not None:
+                            results_dict["train_loss"][iteration + 1] = trl
+                    else:
+                        self.testing_comm.send(0, self.model.get_weights())
+                        sender, data = self.testing_comm.receive()
+                        assert sender == 0 and data == "finished"
+                else:
+                    logging.info("Evaluating on test set.")
+                    ta, tl = self.dataset.test(self.model, self.loss)
+                    results_dict["test_acc"][iteration + 1] = ta
+                    results_dict["test_loss"][iteration + 1] = tl
+
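+                # After global epoch 49 the evaluation interval is doubled once,
+                # so testing then happens every second global epoch.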
+                if global_epoch == 49:
+                    change *= 2
+
+                global_epoch += change
+
+            with open(
+                os.path.join(
+                    self.log_dir, "{}_results.json".format(self.rank)), "w"
+            ) as of:
+                json.dump(results_dict, of)
+        if self.model.shared_parameters_counter is not None:
+            logging.info("Saving the shared parameter counts")
+            with open(
+                os.path.join(
+                    self.log_dir, "{}_shared_parameters.json".format(self.rank)
+                ),
+                "w",
+            ) as of:
+                json.dump(
+                    self.model.shared_parameters_counter.numpy().tolist(), of)
+        self.disconnect_parameter_server()
+        logging.info("Storing final weight")
+        self.model.dump_weights(self.weights_store_dir, self.uid, iteration)
+        logging.info("Server disconnected. Process complete!")
+
+    def cache_fields(
+        self,
+        rank,
+        machine_id,
+        mapping,
+        graph,
+        iterations,
+        log_dir,
+        weights_store_dir,
+        test_after,
+        train_evaluate_after,
+        reset_optimizer,
+        centralized_train_eval,
+        centralized_test_eval,
+    ):
+        """
+        Instantiate object field with arguments.
+
+        Parameters
+        ----------
+        rank : int
+            Rank of process local to the machine
+        machine_id : int
+            Machine ID on which the process in running
+        mapping : decentralizepy.mappings
+            The object containing the mapping rank <--> uid
+        graph : decentralizepy.graphs
+            The object containing the global graph
+        iterations : int
+            Number of iterations (communication steps) for which the model should be trained
+        log_dir : str
+            Logging directory
+        weights_store_dir : str
+            Directory in which to store model weights
+        test_after : int
+            Number of iterations after which the test loss and accuracy are calculated
+        train_evaluate_after : int
+            Number of iterations after which the train loss is calculated
+        reset_optimizer : int
+            1 if optimizer should be reset every communication round, else 0
+        centralized_train_eval : bool
+            If set the train set evaluation happens at the node with uid 0
+        centralized_test_eval : bool
+            If set, the test set evaluation happens at the node with uid 0
+        """
+        self.rank = rank
+        self.machine_id = machine_id
+        self.graph = graph
+        self.mapping = mapping
+        self.uid = self.mapping.get_uid(rank, machine_id)
+        self.log_dir = log_dir
+        self.weights_store_dir = weights_store_dir
+        self.iterations = iterations
+        self.test_after = test_after
+        self.train_evaluate_after = train_evaluate_after
+        self.reset_optimizer = reset_optimizer
+        self.centralized_train_eval = centralized_train_eval
+        self.centralized_test_eval = centralized_test_eval
+        self.sent_disconnections = False
+
+        logging.info("Rank: %d", self.rank)
+        logging.info("type(graph): %s", str(type(self.rank)))
+        logging.info("type(mapping): %s", str(type(self.mapping)))
+
+        if centralized_test_eval or centralized_train_eval:
+            self.star = Star(self.mapping.get_n_procs())
+
+    def init_comm(self, comm_configs):
+        """
+        Instantiate communication module from config.
+
+        Parameters
+        ----------
+        comm_configs : dict
+            Python dict containing communication config params
+
+        """
+        comm_module = importlib.import_module(comm_configs["comm_package"])
+        comm_class = getattr(comm_module, comm_configs["comm_class"])
+        comm_params = utils.remove_keys(
+            comm_configs, ["comm_package", "comm_class"])
+        self.addresses_filepath = comm_params.get("addresses_filepath", None)
+        if self.centralized_test_eval:
+            self.testing_comm = TCP(
+                self.rank,
+                self.machine_id,
+                self.mapping,
+                self.star.n_procs,
+                self.addresses_filepath,
+                offset=self.star.n_procs,
+            )
+            self.testing_comm.connect_neighbors(self.star.neighbors(self.uid))
+
+        self.communication = comm_class(
+            self.rank, self.machine_id, self.mapping, self.graph.n_procs, **comm_params
+        )
+
+    def instantiate(
+        self,
+        rank: int,
+        machine_id: int,
+        mapping: Mapping,
+        graph: Graph,
+        config,
+        iterations=1,
+        log_dir=".",
+        weights_store_dir=".",
+        log_level=logging.INFO,
+        test_after=5,
+        train_evaluate_after=1,
+        reset_optimizer=1,
+        centralized_train_eval=False,
+        centralized_test_eval=True,
+        *args
+    ):
+        """
+        Construct objects.
+
+        Parameters
+        ----------
+        rank : int
+            Rank of process local to the machine
+        machine_id : int
+            Machine ID on which the process in running
+        mapping : decentralizepy.mappings
+            The object containing the mapping rank <--> uid
+        graph : decentralizepy.graphs
+            The object containing the global graph
+        config : dict
+            A dictionary of configurations.
+        iterations : int
+            Number of iterations (communication steps) for which the model should be trained
+        log_dir : str
+            Logging directory
+        weights_store_dir : str
+            Directory in which to store model weights
+        log_level : logging.Level
+            One of DEBUG, INFO, WARNING, ERROR, CRITICAL
+        test_after : int
+            Number of iterations after which the test loss and accuracy are calculated
+        train_evaluate_after : int
+            Number of iterations after which the train loss is calculated
+        reset_optimizer : int
+            1 if optimizer should be reset every communication round, else 0
+        centralized_train_eval : bool
+            If set the train set evaluation happens at the node with uid 0
+        centralized_test_eval : bool
+            If set, the test set evaluation happens at the node with uid 0
+        args : optional
+            Other arguments
+
+        """
+        logging.info("Started process.")
+
+        self.init_log(log_dir, rank, log_level)
+
+        self.cache_fields(
+            rank,
+            machine_id,
+            mapping,
+            graph,
+            iterations,
+            log_dir,
+            weights_store_dir,
+            test_after,
+            train_evaluate_after,
+            reset_optimizer,
+            centralized_train_eval,
+            centralized_test_eval,
+        )
+        self.init_dataset_model(config["DATASET"])
+        self.init_optimizer(config["OPTIMIZER_PARAMS"])
+        self.init_trainer(config["TRAIN_PARAMS"])
+        self.init_comm(config["COMMUNICATION"])
+
+        self.message_queue = dict()
+
+        self.barrier = set()
+        self.my_neighbors = self.graph.neighbors(self.uid)
+
+        self.init_sharing(config["SHARING"])
+
+    def __init__(
+        self,
+        rank: int,
+        machine_id: int,
+        mapping: Mapping,
+        graph: Graph,
+        config,
+        iterations=1,
+        log_dir=".",
+        weights_store_dir=".",
+        log_level=logging.INFO,
+        test_after=5,
+        train_evaluate_after=1,
+        reset_optimizer=1,
+        centralized_train_eval=0,
+        centralized_test_eval=1,
+        parameter_server_uid=-1,
+        *args
+    ):
+        """
+        Constructor
+
+        Parameters
+        ----------
+        rank : int
+            Rank of process local to the machine
+        machine_id : int
+            Machine ID on which the process in running
+        mapping : decentralizepy.mappings
+            The object containing the mapping rank <--> uid
+        graph : decentralizepy.graphs
+            The object containing the global graph
+        config : dict
+            A dictionary of configurations. Must contain the following:
+            [DATASET]
+                dataset_package
+                dataset_class
+                model_class
+            [OPTIMIZER_PARAMS]
+                optimizer_package
+                optimizer_class
+            [TRAIN_PARAMS]
+                training_package = decentralizepy.training.Training
+                training_class = Training
+                epochs_per_round = 25
+                batch_size = 64
+        iterations : int
+            Number of iterations (communication steps) for which the model should be trained
+        log_dir : str
+            Logging directory
+        weights_store_dir : str
+            Directory in which to store model weights
+        log_level : logging.Level
+            One of DEBUG, INFO, WARNING, ERROR, CRITICAL
+        test_after : int
+            Number of iterations after which the test loss and accuracy are calculated
+        train_evaluate_after : int
+            Number of iterations after which the train loss is calculated
+        reset_optimizer : int
+            1 if optimizer should be reset every communication round, else 0
+        centralized_train_eval : int
+            If set then the train set evaluation happens at the node with uid 0.
+            Note: If it is True then centralized_test_eval needs to be true as well!
+        centralized_test_eval : int
+            If set then the test set evaluation happens at the node with uid 0
+        parameter_server_uid: int
+            The parameter server's uid
+        args : optional
+            Other arguments
+
+        """
+        centralized_train_eval = centralized_train_eval == 1
+        centralized_test_eval = centralized_test_eval == 1
+        # If centralized_train_eval is True then centralized_test_eval needs to be true as well!
+        assert not centralized_train_eval or centralized_test_eval
+
+        total_threads = os.cpu_count()
+        self.threads_per_proc = max(
+            math.floor(total_threads / mapping.procs_per_machine), 1
+        )
+        torch.set_num_threads(self.threads_per_proc)
+        torch.set_num_interop_threads(1)
+        self.instantiate(
+            rank,
+            machine_id,
+            mapping,
+            graph,
+            config,
+            iterations,
+            log_dir,
+            weights_store_dir,
+            log_level,
+            test_after,
+            train_evaluate_after,
+            reset_optimizer,
+            centralized_train_eval,
+            centralized_test_eval,
+            *args
+        )
+        logging.info(
+            "Each proc uses %d threads out of %d.", self.threads_per_proc, total_threads
+        )
+
+        self.message_queue["PEERS"] = deque()
+
+        self.parameter_server_uid = parameter_server_uid
+        self.connect_neighbor(self.parameter_server_uid)
+        self.wait_for_hello(self.parameter_server_uid)
+
+        self.run()
+
+    def disconnect_parameter_server(self):
+        """
+        Disconnects from the parameter server. Sends BYE.
+
+        Raises
+        ------
+        RuntimeError
+            If received another message while waiting for BYEs
+
+        """
+        if not self.sent_disconnections:
+            logging.info("Disconnecting parameter server.")
+            self.communication.send(
+                self.parameter_server_uid, {
+                    "BYE": self.uid, "CHANNEL": "SERVER_REQUEST"}
+            )
+            self.sent_disconnections = True
+
+            self.barrier.remove(self.parameter_server_uid)
+
+            while len(self.barrier):
+                sender, _ = self.receive_disconnect()
+                self.barrier.remove(sender)
diff --git a/src/decentralizepy/node/FederatedParameterServer.py b/src/decentralizepy/node/FederatedParameterServer.py
new file mode 100644
index 0000000000000000000000000000000000000000..312fa6d79187e7cec8b1433d46fb84fc7d296554
--- /dev/null
+++ b/src/decentralizepy/node/FederatedParameterServer.py
@@ -0,0 +1,516 @@
+import importlib
+import json
+import logging
+import math
+import os
+import random
+from collections import deque
+from matplotlib import pyplot as plt
+
+from decentralizepy import utils
+from decentralizepy.graphs.Graph import Graph
+from decentralizepy.mappings.Mapping import Mapping
+from decentralizepy.node.Node import Node
+
+
+class FederatedParameterServer(Node):
+    """
+    This class defines the parameter serving service
+
+    """
+
+    def save_plot(self, l, label, title, xlabel, filename):
+        """
+        Save Matplotlib plot. Clears previous plots.
+
+        Parameters
+        ----------
+        l : dict
+            dict of x -> y. `x` must be castable to int.
+        label : str
+            label of the plot. Used for legend.
+        title : str
+            Header
+        xlabel : str
+            x-axis label
+        filename : str
+            Name of file to save the plot as.
+
+        """
+        plt.clf()
+        y_axis = [l[key] for key in l.keys()]
+        x_axis = list(map(int, l.keys()))
+        plt.plot(x_axis, y_axis, label=label)
+        plt.xlabel(xlabel)
+        plt.title(title)
+        plt.savefig(filename)
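+
+    # A minimal usage sketch for save_plot (the values and file name below are
+    # illustrative only); run() builds such dicts keyed by communication round:
+    #
+    #   self.save_plot(
+    #       {"1": 0.92, "2": 0.81, "3": 0.75},
+    #       "train_loss",
+    #       "Training Loss",
+    #       "Communication Rounds",
+    #       os.path.join(self.log_dir, "0_train_loss.png"),
+    #   )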
+
+    def init_log(self, log_dir, log_level, force=True):
+        """
+        Instantiate Logging.
+
+        Parameters
+        ----------
+        log_dir : str
+            Logging directory
+        log_level : logging.Level
+            One of DEBUG, INFO, WARNING, ERROR, CRITICAL
+        force : bool
+            Argument to logging.basicConfig()
+
+        """
+        log_file = os.path.join(log_dir, "ParameterServer.log")
+        logging.basicConfig(
+            filename=log_file,
+            format="[%(asctime)s][%(module)s][%(levelname)s] %(message)s",
+            level=log_level,
+            force=force,
+        )
+
+    def cache_fields(
+        self,
+        rank,
+        machine_id,
+        mapping,
+        graph,
+        iterations,
+        log_dir,
+        weights_store_dir,
+        test_after,
+        train_evaluate_after,
+    ):
+        """
+        Instantiate object field with arguments.
+
+        Parameters
+        ----------
+        rank : int
+            Rank of process local to the machine
+        machine_id : int
+            Machine ID on which the process is running
+        mapping : decentralizepy.mappings
+            The object containing the mapping rank <--> uid
+        graph : decentralizepy.graphs
+            The object containing the global graph
+        iterations : int
+            Number of iterations (communication steps) for which the model should be trained
+        log_dir : str
+            Logging directory
+        weights_store_dir : str
+            Directory in which to store model weights
+        test_after : int
+            Number of iterations after which the test loss and accuracy are calculated
+        train_evaluate_after : int
+            Number of iterations after which the train loss is calculated
+
+        """
+        self.rank = rank
+        self.machine_id = machine_id
+        self.graph = graph
+        self.mapping = mapping
+        self.uid = self.mapping.get_uid(rank, machine_id)
+        self.log_dir = log_dir
+        self.iterations = iterations
+        self.sent_disconnections = False
+        self.weights_store_dir = weights_store_dir
+        self.test_after = test_after
+        self.train_evaluate_after = train_evaluate_after
+
+        logging.info("Rank: %d", self.rank)
+        logging.info("type(graph): %s", str(type(self.rank)))
+        logging.info("type(mapping): %s", str(type(self.mapping)))
+
+    def init_comm(self, comm_configs):
+        """
+        Instantiate communication module from config.
+
+        Parameters
+        ----------
+        comm_configs : dict
+            Python dict containing communication config params
+
+        """
+        comm_module = importlib.import_module(comm_configs["comm_package"])
+        comm_class = getattr(comm_module, comm_configs["comm_class"])
+        comm_params = utils.remove_keys(
+            comm_configs, ["comm_package", "comm_class"])
+        self.addresses_filepath = comm_params.get("addresses_filepath", None)
+        self.communication = comm_class(
+            self.rank, self.machine_id, self.mapping, self.graph.n_procs, **comm_params
+        )
+
+    def instantiate(
+        self,
+        rank: int,
+        machine_id: int,
+        mapping: Mapping,
+        graph: Graph,
+        config,
+        iterations=1,
+        log_dir=".",
+        weights_store_dir=".",
+        log_level=logging.INFO,
+        test_after=5,
+        train_evaluate_after=1,
+        *args
+    ):
+        """
+        Construct objects.
+
+        Parameters
+        ----------
+        rank : int
+            Rank of process local to the machine
+        machine_id : int
+            Machine ID on which the process is running
+        mapping : decentralizepy.mappings
+            The object containing the mapping rank <--> uid
+        graph : decentralizepy.graphs
+            The object containing the global graph
+        config : dict
+            A dictionary of configurations.
+        iterations : int
+            Number of iterations (communication steps) for which the model should be trained
+        log_dir : str
+            Logging directory
+        weights_store_dir : str
+            Directory in which to store model weights
+        log_level : logging.Level
+            One of DEBUG, INFO, WARNING, ERROR, CRITICAL
+        test_after : int
+            Number of iterations after which the test loss and accuracy are calculated
+        train_evaluate_after : int
+            Number of iterations after which the train loss is calculated
+        args : optional
+            Other arguments
+
+        """
+        logging.info("Started process.")
+
+        self.init_log(log_dir, log_level)
+
+        self.cache_fields(
+            rank,
+            machine_id,
+            mapping,
+            graph,
+            iterations,
+            log_dir,
+            weights_store_dir,
+            test_after,
+            train_evaluate_after,
+        )
+
+        self.message_queue = dict()
+
+        self.barrier = set()
+
+        self.peer_deques = dict()
+
+        self.init_dataset_model(config["DATASET"])
+        self.init_comm(config["COMMUNICATION"])
+        self.init_optimizer(config["OPTIMIZER_PARAMS"])
+        self.init_trainer(config["TRAIN_PARAMS"])
+
+        self.my_neighbors = self.graph.get_all_nodes()
+        self.connect_neighbors()
+
+        self.init_sharing(config["SHARING"])
+
+    def received_from_all(self):
+        """
+        Check if all current workers have sent the current iteration
+
+        Returns
+        -------
+        bool
+            True if required data has been received, False otherwise
+
+        """
+        for k in self.current_workers:
+            if (k not in self.peer_deques) or len(self.peer_deques[k]) == 0:
+                return False
+        return True
+
+    def disconnect_neighbors(self):
+        """
+        Disconnects all workers by sending each of them a BYE on the
+        WORKER_REQUEST channel and removing them from the barrier.
+
+        """
+        if not self.sent_disconnections:
+            logging.info("Disconnecting neighbors")
+
+            for neighbor in self.my_neighbors:
+                self.communication.send(
+                    neighbor, {"BYE": self.uid, "CHANNEL": "WORKER_REQUEST"}
+                )
+                self.barrier.remove(neighbor)
+
+            self.sent_disconnections = True
+
+    def get_working_nodes(self):
+        """
+        Randomly select the set of workers (clients) for the current iteration.
+
+        Returns
+        -------
+        list
+            The uids of the workers sampled for this round
+
+        """
+        k = int(math.ceil(len(self.my_neighbors) * self.working_fraction))
+        # Cast to a list so sampling also works when my_neighbors is a set
+        # (random.sample no longer accepts sets on newer Python versions).
+        return random.sample(list(self.my_neighbors), k)
+
+    def run(self):
+        """
+        Start the federated parameter-serving service.
+
+        """
+        self.testset = self.dataset.get_testset()
+        rounds_to_test = self.test_after
+        rounds_to_train_evaluate = self.train_evaluate_after
+        global_epoch = 1
+        change = 1
+
+        to_send = dict()
+
+        for iteration in range(self.iterations):
+            self.iteration = iteration
+            # reset deques after each iteration
+            self.peer_deques = dict()
+
+            # Get workers for this iteration
+            self.current_workers = self.get_working_nodes()
+
+            # Params to send to workers
+            # if this is the first iteration, use the init parameters, else use averaged params from last iteration
+            if iteration == 0:
+                to_send = self.sharing.get_data_to_send()
+
+            to_send["CHANNEL"] = "WORKER_REQUEST"
+            to_send["iteration"] = iteration
+
+            # Notify workers
+            for worker in self.current_workers:
+                self.communication.send(worker, to_send)
+
+            # Receive updates from current workers
+            while not self.received_from_all():
+                sender, data = self.receive_channel("DPSGD")
+                if sender not in self.peer_deques:
+                    self.peer_deques[sender] = deque()
+                self.peer_deques[sender].append(data)
+
+            logging.info("Received from all current workers")
+
+            # Average received updates
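+            # Note that this is a plain, unweighted average: every sampled
+            # worker contributes equally, regardless of its local dataset size.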
+            averaging_deque = dict()
+            total = dict()
+            for worker in self.current_workers:
+                averaging_deque[worker] = self.peer_deques[worker]
+
+            for i, n in enumerate(averaging_deque):
+                data = averaging_deque[n].popleft()
+                del data["degree"]
+                del data["iteration"]
+                del data["CHANNEL"]
+                data = self.sharing.deserialized_model(data)
+                for key, value in data.items():
+                    if key in total:
+                        total[key] += value
+                    else:
+                        total[key] = value
+
+            for key, value in total.items():
+                total[key] = total[key] / len(averaging_deque)
+
+            self.model.load_state_dict(total)
+
+            to_send = total
+
+            if iteration:
+                with open(
+                    os.path.join(
+                        self.log_dir, "{}_results.json".format(self.rank)),
+                    "r",
+                ) as inf:
+                    results_dict = json.load(inf)
+            else:
+                results_dict = {
+                    "train_loss": {},
+                    "test_loss": {},
+                    "test_acc": {},
+                    "total_bytes": {},
+                    "total_meta": {},
+                    "total_data_per_n": {},
+                    "grad_mean": {},
+                    "grad_std": {},
+                }
+
+            results_dict["total_bytes"][iteration
+                                        + 1] = self.communication.total_bytes
+
+            if hasattr(self.communication, "total_meta"):
+                results_dict["total_meta"][
+                    iteration + 1
+                ] = self.communication.total_meta
+            if hasattr(self.communication, "total_data"):
+                results_dict["total_data_per_n"][
+                    iteration + 1
+                ] = self.communication.total_data
+            if hasattr(self.sharing, "mean"):
+                results_dict["grad_mean"][iteration + 1] = self.sharing.mean
+            if hasattr(self.sharing, "std"):
+                results_dict["grad_std"][iteration + 1] = self.sharing.std
+
+            rounds_to_train_evaluate -= 1
+
+            if rounds_to_train_evaluate == 0:
+                logging.info("Evaluating on train set.")
+                rounds_to_train_evaluate = self.train_evaluate_after * change
+                loss_after_sharing = self.trainer.eval_loss(self.dataset)
+                results_dict["train_loss"][iteration + 1] = loss_after_sharing
+                self.save_plot(
+                    results_dict["train_loss"],
+                    "train_loss",
+                    "Training Loss",
+                    "Communication Rounds",
+                    os.path.join(
+                        self.log_dir, "{}_train_loss.png".format(self.rank)),
+                )
+
+            rounds_to_test -= 1
+
+            if self.dataset.__testing__ and rounds_to_test == 0:
+                rounds_to_test = self.test_after * change
+                logging.info("Evaluating on test set.")
+                ta, tl = self.dataset.test(self.model, self.loss)
+                results_dict["test_acc"][iteration + 1] = ta
+                results_dict["test_loss"][iteration + 1] = tl
+
+                if global_epoch == 49:
+                    change *= 2
+
+                global_epoch += change
+
+            with open(
+                os.path.join(
+                    self.log_dir, "{}_results.json".format(self.rank)), "w"
+            ) as of:
+                json.dump(results_dict, of)
+
+        if self.model.shared_parameters_counter is not None:
+            logging.info("Saving the shared parameter counts")
+            with open(
+                os.path.join(
+                    self.log_dir, "{}_shared_parameters.json".format(self.rank)
+                ),
+                "w",
+            ) as of:
+                json.dump(
+                    self.model.shared_parameters_counter.numpy().tolist(), of)
+
+        self.disconnect_neighbors()
+        logging.info("Storing final weight")
+        self.model.dump_weights(self.weights_store_dir, self.uid, iteration)
+        logging.info("All neighbors disconnected. Process complete!")
+
+    def __init__(
+        self,
+        rank: int,
+        machine_id: int,
+        mapping: Mapping,
+        graph: Graph,
+        config,
+        iterations=1,
+        log_dir=".",
+        weights_store_dir=".",
+        log_level=logging.INFO,
+        test_after=5,
+        train_evaluate_after=1,
+        working_fraction=1.0,
+        *args
+    ):
+        """
+        Constructor
+
+        Parameters
+        ----------
+        rank : int
+            Rank of process local to the machine
+        machine_id : int
+            Machine ID on which the process is running
+        mapping : decentralizepy.mappings
+            The object containing the mapping rank <--> uid
+        graph : decentralizepy.graphs
+            The object containing the global graph
+        config : dict
+            A dictionary of configurations. Must contain the following:
+            [DATASET]
+                dataset_package
+                dataset_class
+                model_class
+            [OPTIMIZER_PARAMS]
+                optimizer_package
+                optimizer_class
+            [TRAIN_PARAMS]
+                training_package = decentralizepy.training.Training
+                training_class = Training
+                epochs_per_round = 25
+                batch_size = 64
+            [COMMUNICATION]
+                comm_package
+                comm_class
+            [SHARING]
+                sharing_package
+                sharing_class
+        iterations : int
+            Number of iterations (communication steps) for which the model should be trained
+        log_dir : str
+            Logging directory
+        weights_store_dir : str
+            Directory in which to store model weights
+        log_level : logging.Level
+            One of DEBUG, INFO, WARNING, ERROR, CRITICAL
+        test_after : int
+            Number of iterations after which the test loss and accuracy are calculated
+        train_evaluate_after : int
+            Number of iterations after which the train loss is calculated
+        working_fraction : float
+            Fraction (between 0 and 1) of workers participating in each global iteration
+        args : optional
+            Other arguments
+
+        """
+        super().__init__(
+            rank,
+            machine_id,
+            mapping,
+            graph,
+            config,
+            iterations,
+            log_dir,
+            log_level,
+            *args
+        )
+
+        self.instantiate(
+            rank,
+            machine_id,
+            mapping,
+            graph,
+            config,
+            iterations,
+            log_dir,
+            weights_store_dir,
+            log_level,
+            test_after,
+            train_evaluate_after,
+            *args
+        )
+
+        self.working_fraction = working_fraction
+
+        random.seed(self.mapping.get_uid(self.rank, self.machine_id))
+
+        self.run()
+
+        logging.info("Parameter Server exiting")
diff --git a/src/decentralizepy/node/ParameterServer.py b/src/decentralizepy/node/ParameterServer.py
new file mode 100644
index 0000000000000000000000000000000000000000..616138ad581cb1f82251dc144dc8ea90d7639447
--- /dev/null
+++ b/src/decentralizepy/node/ParameterServer.py
@@ -0,0 +1,308 @@
+import importlib
+import logging
+import os
+from collections import deque
+
+from decentralizepy import utils
+from decentralizepy.graphs.Graph import Graph
+from decentralizepy.mappings.Mapping import Mapping
+from decentralizepy.node.Node import Node
+
+
+class ParameterServer(Node):
+    """
+    This class defines the parameter serving service
+
+    """
+
+    def init_log(self, log_dir, log_level, force=True):
+        """
+        Instantiate Logging.
+
+        Parameters
+        ----------
+        log_dir : str
+            Logging directory
+        log_level : logging.Level
+            One of DEBUG, INFO, WARNING, ERROR, CRITICAL
+        force : bool
+            Argument to logging.basicConfig()
+
+        """
+        log_file = os.path.join(log_dir, "ParameterServer.log")
+        logging.basicConfig(
+            filename=log_file,
+            format="[%(asctime)s][%(module)s][%(levelname)s] %(message)s",
+            level=log_level,
+            force=force,
+        )
+
+    def cache_fields(
+        self,
+        rank,
+        machine_id,
+        mapping,
+        graph,
+        iterations,
+        log_dir,
+    ):
+        """
+        Instantiate object field with arguments.
+
+        Parameters
+        ----------
+        rank : int
+            Rank of process local to the machine
+        machine_id : int
+            Machine ID on which the process is running
+        mapping : decentralizepy.mappings
+            The object containing the mapping rank <--> uid
+        graph : decentralizepy.graphs
+            The object containing the global graph
+        iterations : int
+            Number of iterations (communication steps) for which the model should be trained
+        log_dir : str
+            Logging directory
+
+        """
+        self.rank = rank
+        self.machine_id = machine_id
+        self.graph = graph
+        self.mapping = mapping
+        self.uid = self.mapping.get_uid(rank, machine_id)
+        self.log_dir = log_dir
+        self.iterations = iterations
+        self.sent_disconnections = False
+
+        logging.info("Rank: %d", self.rank)
+        logging.info("type(graph): %s", str(type(self.rank)))
+        logging.info("type(mapping): %s", str(type(self.mapping)))
+
+    def init_comm(self, comm_configs):
+        """
+        Instantiate communication module from config.
+
+        Parameters
+        ----------
+        comm_configs : dict
+            Python dict containing communication config params
+
+        """
+        comm_module = importlib.import_module(comm_configs["comm_package"])
+        comm_class = getattr(comm_module, comm_configs["comm_class"])
+        comm_params = utils.remove_keys(
+            comm_configs, ["comm_package", "comm_class"])
+        self.addresses_filepath = comm_params.get("addresses_filepath", None)
+        self.communication = comm_class(
+            self.rank, self.machine_id, self.mapping, self.graph.n_procs, **comm_params
+        )
+
+    def instantiate(
+        self,
+        rank: int,
+        machine_id: int,
+        mapping: Mapping,
+        graph: Graph,
+        config,
+        iterations=1,
+        log_dir=".",
+        log_level=logging.INFO,
+        *args
+    ):
+        """
+        Construct objects.
+
+        Parameters
+        ----------
+        rank : int
+            Rank of process local to the machine
+        machine_id : int
+            Machine ID on which the process is running
+        mapping : decentralizepy.mappings
+            The object containing the mapping rank <--> uid
+        graph : decentralizepy.graphs
+            The object containing the global graph
+        config : dict
+            A dictionary of configurations.
+        iterations : int
+            Number of iterations (communication steps) for which the model should be trained
+        log_dir : str
+            Logging directory
+        log_level : logging.Level
+            One of DEBUG, INFO, WARNING, ERROR, CRITICAL
+        args : optional
+            Other arguments
+
+        """
+        logging.info("Started process.")
+
+        self.init_log(log_dir, log_level)
+
+        self.cache_fields(
+            rank,
+            machine_id,
+            mapping,
+            graph,
+            iterations,
+            log_dir,
+        )
+
+        self.message_queue = dict()
+
+        self.barrier = set()
+
+        self.peer_deques = dict()
+
+        self.init_dataset_model(config["DATASET"])
+        self.init_comm(config["COMMUNICATION"])
+        self.my_neighbors = self.graph.get_all_nodes()
+        self.connect_neighbors()
+        self.init_sharing(config["SHARING"])
+
+    def receive_server_request(self):
+        """Receive the next message sent on the SERVER_REQUEST channel."""
+        return self.receive_channel("SERVER_REQUEST")
+
+    def received_from_all(self):
+        """
+        Check if all neighbors have sent the current iteration
+
+        Returns
+        -------
+        bool
+            True if required data has been received, False otherwise
+
+        """
+        for k in self.my_neighbors:
+            if (k not in self.peer_deques) or len(self.peer_deques[k]) == 0:
+                return False
+        return True
+
+    def run(self):
+        """
+        Start the parameter-serving service.
+
+        """
+        for iteration in range(self.iterations):
+            self.iteration = iteration
+            # reset deques after each iteration
+            self.peer_deques = dict()
+
+            while not self.received_from_all():
+                sender, data = self.receive_channel("DPSGD")
+                if sender not in self.peer_deques:
+                    self.peer_deques[sender] = deque()
+                self.peer_deques[sender].append(data)
+
+            logging.info("Received from everybody")
+
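+            # Aggregate the received models with a plain unweighted average and
+            # broadcast the result back to every worker on the GRADS channel.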
+            averaging_deque = dict()
+            total = dict()
+            for neighbor in self.my_neighbors:
+                averaging_deque[neighbor] = self.peer_deques[neighbor]
+
+            for n in averaging_deque:
+                data = averaging_deque[n].popleft()
+                # Strip the metadata keys before deserializing the model
+                del data["degree"]
+                del data["iteration"]
+                del data["CHANNEL"]
+                data = self.sharing.deserialized_model(data)
+                for key, value in data.items():
+                    if key in total:
+                        total[key] += value
+                    else:
+                        total[key] = value
+
+            for key, value in total.items():
+                total[key] = total[key] / len(averaging_deque)
+
+            to_send = total
+            to_send["CHANNEL"] = "GRADS"
+
+            for neighbor in self.my_neighbors:
+                self.communication.send(neighbor, to_send)
+
+        while len(self.barrier):
+            sender, data = self.receive_server_request()
+            if "BYE" in data:
+                logging.debug("Received {} from {}".format("BYE", sender))
+                self.barrier.remove(sender)
+
+    def __init__(
+        self,
+        rank: int,
+        machine_id: int,
+        mapping: Mapping,
+        graph: Graph,
+        config,
+        iterations=1,
+        log_dir=".",
+        log_level=logging.INFO,
+        *args
+    ):
+        """
+        Constructor
+
+        Parameters
+        ----------
+        rank : int
+            Rank of process local to the machine
+        machine_id : int
+            Machine ID on which the process is running
+        mapping : decentralizepy.mappings
+            The object containing the mapping rank <--> uid
+        graph : decentralizepy.graphs
+            The object containing the global graph
+        config : dict
+            A dictionary of configurations. Must contain the following:
+            [DATASET]
+                dataset_package
+                dataset_class
+                model_class
+            [OPTIMIZER_PARAMS]
+                optimizer_package
+                optimizer_class
+            [TRAIN_PARAMS]
+                training_package = decentralizepy.training.Training
+                training_class = Training
+                epochs_per_round = 25
+                batch_size = 64
+            [COMMUNICATION]
+                comm_package
+                comm_class
+            [SHARING]
+                sharing_package
+                sharing_class
+        iterations : int
+            Number of iterations (communication steps) for which the model should be trained
+        log_dir : str
+            Logging directory
+        log_level : logging.Level
+            One of DEBUG, INFO, WARNING, ERROR, CRITICAL
+        args : optional
+            Other arguments
+
+        """
+        super().__init__(
+            rank,
+            machine_id,
+            mapping,
+            graph,
+            config,
+            iterations,
+            log_dir,
+            log_level,
+            *args
+        )
+
+        self.instantiate(
+            rank,
+            machine_id,
+            mapping,
+            graph,
+            config,
+            iterations,
+            log_dir,
+            log_level,
+            *args
+        )
+
+        self.run()
+
+        logging.info("Parameter Server exiting")