diff --git a/eval/plot.py b/eval/plot.py
index 552934fa3231e5b9621acfa33190b86462a570da..9fdccded166afbb51b57a4bc4f37ebda819773a2 100644
--- a/eval/plot.py
+++ b/eval/plot.py
@@ -1,6 +1,7 @@
 import json
 import os
 import sys
+from pathlib import Path
 
 import numpy as np
 import pandas as pd
@@ -35,7 +36,7 @@ def plot(means, stdevs, mins, maxs, title, label, loc):
     plt.legend(loc=loc)
 
 
-def plot_results(path):
+def plot_results(path, data_machine="machine0", data_node=0):
     folders = os.listdir(path)
     folders.sort()
     print("Reading folders from: ", path)
@@ -44,8 +45,8 @@ def plot_results(path):
     meta_means, meta_stdevs = {}, {}
     data_means, data_stdevs = {}, {}
     for folder in folders:
-        folder_path = os.path.join(path, folder)
-        if not os.path.isdir(folder_path):
+        folder_path = Path(os.path.join(path, folder))
+        if not folder_path.is_dir() or "weights" == folder_path.name:
             continue
         results = []
         machine_folders = os.listdir(folder_path)
@@ -59,6 +60,10 @@ def plot_results(path):
                 filepath = os.path.join(mf_path, f)
                 with open(filepath, "r") as inf:
                     results.append(json.load(inf))
+
+        with open(folder_path / data_machine / f"{data_node}_results.json", "r") as f:
+            main_data = json.load(f)
+        main_data = [main_data]
         # Plot Training loss
         plt.figure(1)
         means, stdevs, mins, maxs = get_stats([x["train_loss"] for x in results])
@@ -77,7 +82,7 @@ def plot_results(path):
         )
         # Plot Testing loss
         plt.figure(2)
-        means, stdevs, mins, maxs = get_stats([x["test_loss"] for x in results])
+        means, stdevs, mins, maxs = get_stats([x["test_loss"] for x in main_data])
         plot(means, stdevs, mins, maxs, "Testing Loss", folder, "upper right")
         df = pd.DataFrame(
             {
@@ -93,7 +98,7 @@ def plot_results(path):
         )
         # Plot Testing Accuracy
         plt.figure(3)
-        means, stdevs, mins, maxs = get_stats([x["test_acc"] for x in results])
+        means, stdevs, mins, maxs = get_stats([x["test_acc"] for x in main_data])
         plot(means, stdevs, mins, maxs, "Testing Accuracy", folder, "lower right")
         df = pd.DataFrame(
             {
diff --git a/eval/run_xtimes_celeba.sh b/eval/run_xtimes_celeba.sh
index 65c195abf7844eac17cc7bc70eefe78e10a7379f..c6d3e060c4fd60115e40e8b99bb39560b4cd1fe8 100755
--- a/eval/run_xtimes_celeba.sh
+++ b/eval/run_xtimes_celeba.sh
@@ -41,7 +41,7 @@ graph=96_regular.edges
 config_file=~/tmp/config.ini
 procs_per_machine=16
 machines=6
-global_epochs=120
+global_epochs=150
 eval_file=testing.py
 log_level=INFO
 
@@ -91,9 +91,12 @@ do
     echo $i
     IFS='_' read -ra NAMES <<< $i
     IFS='.' read -ra NAME <<< ${NAMES[-1]}
-    log_dir=$nfs_home$logs_subfolder/${NAME[0]}:lr=$lr:r=$comm_rounds_per_global_epoch:b=$batchsize:$(date '+%Y-%m-%dT%H:%M')/machine$m
-    echo results are stored in: $log_dir
+    log_dir_base=$nfs_home$logs_subfolder/${NAME[0]}:lr=$lr:r=$comm_rounds_per_global_epoch:b=$batchsize:$(date '+%Y-%m-%dT%H:%M')
+    echo results are stored in: $log_dir_base
+    log_dir=$log_dir_base/machine$m
     mkdir -p $log_dir
+    weight_store_dir=$log_dir_base/weights
+    mkdir -p $weight_store_dir
     cp $i $config_file
     # changing the config files to reflect the values of the current grid search state
     $python_bin/crudini --set $config_file COMMUNICATION addresses_filepath $ip_machines
@@ -101,7 +104,7 @@ do
     $python_bin/crudini --set $config_file TRAIN_PARAMS rounds $batches_per_comm_round
     $python_bin/crudini --set $config_file TRAIN_PARAMS batch_size $batchsize
     $python_bin/crudini --set $config_file DATASET random_seed $seed
-    $env_python $eval_file -ro 0 -tea $test_after -ld $log_dir -mid $m -ps $procs_per_machine -ms $machines -is $new_iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level
+    $env_python $eval_file -ro 0 -tea $test_after -ld $log_dir -wsd $weight_store_dir -mid $m -ps $procs_per_machine -ms $machines -is $new_iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level
     echo $i is done
     sleep 200
     echo end of sleep
diff --git a/eval/run_xtimes_cifar.sh b/eval/run_xtimes_cifar.sh
index 1939348b2f0d0e7141b8c0dcdf4af69cc98bd0de..0c545c651e5150f42332ebf81cc2baae4c0f5ef6 100755
--- a/eval/run_xtimes_cifar.sh
+++ b/eval/run_xtimes_cifar.sh
@@ -91,9 +91,12 @@ do
     echo $i
     IFS='_' read -ra NAMES <<< $i
     IFS='.' read -ra NAME <<< ${NAMES[-1]}
-    log_dir=$nfs_home$logs_subfolder/${NAME[0]}:lr=$lr:r=$comm_rounds_per_global_epoch:b=$batchsize:$(date '+%Y-%m-%dT%H:%M')/machine$m
-    echo results are stored in: $log_dir
+    log_dir_base=$nfs_home$logs_subfolder/${NAME[0]}:lr=$lr:r=$comm_rounds_per_global_epoch:b=$batchsize:$(date '+%Y-%m-%dT%H:%M')
+    echo results are stored in: $log_dir_base
+    log_dir=$log_dir_base/machine$m
     mkdir -p $log_dir
+    weight_store_dir=$log_dir_base/weights
+    mkdir -p $weight_store_dir
     cp $i $config_file
     # changing the config files to reflect the values of the current grid search state
     $python_bin/crudini --set $config_file COMMUNICATION addresses_filepath $ip_machines
@@ -101,7 +104,7 @@ do
     $python_bin/crudini --set $config_file TRAIN_PARAMS rounds $batches_per_comm_round
     $python_bin/crudini --set $config_file TRAIN_PARAMS batch_size $batchsize
     $python_bin/crudini --set $config_file DATASET random_seed $seed
-    $env_python $eval_file -ro 0 -tea $test_after -ld $log_dir -mid $m -ps $procs_per_machine -ms $machines -is $new_iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level
+    $env_python $eval_file -ro 0 -tea $test_after -ld $log_dir -wsd $weight_store_dir -mid $m -ps $procs_per_machine -ms $machines -is $new_iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level
     echo $i is done
     sleep 200
     echo end of sleep
diff --git a/eval/run_xtimes_femnist.sh b/eval/run_xtimes_femnist.sh
index 4b1799b954edc85b602974d4d9dba3b181b3c8e2..c5e18a2ab958997e08b96fc4682c1b64f021cb25 100755
--- a/eval/run_xtimes_femnist.sh
+++ b/eval/run_xtimes_femnist.sh
@@ -41,7 +41,7 @@ graph=96_regular.edges
 config_file=~/tmp/config.ini
 procs_per_machine=16
 machines=6
-global_epochs=70
+global_epochs=80
 eval_file=testing.py
 log_level=INFO
 
@@ -91,9 +91,12 @@ do
     echo $i
     IFS='_' read -ra NAMES <<< $i
     IFS='.' read -ra NAME <<< ${NAMES[-1]}
-    log_dir=$nfs_home$logs_subfolder/${NAME[0]}:lr=$lr:r=$comm_rounds_per_global_epoch:b=$batchsize:$(date '+%Y-%m-%dT%H:%M')/machine$m
-    echo results are stored in: $log_dir
+    log_dir_base=$nfs_home$logs_subfolder/${NAME[0]}:lr=$lr:r=$comm_rounds_per_global_epoch:b=$batchsize:$(date '+%Y-%m-%dT%H:%M')
+    echo results are stored in: $log_dir_base
+    log_dir=$log_dir_base/machine$m
     mkdir -p $log_dir
+    weight_store_dir=$log_dir_base/weights
+    mkdir -p $weight_store_dir
     cp $i $config_file
     # changing the config files to reflect the values of the current grid search state
     $python_bin/crudini --set $config_file COMMUNICATION addresses_filepath $ip_machines
@@ -101,7 +104,7 @@ do
     $python_bin/crudini --set $config_file TRAIN_PARAMS rounds $batches_per_comm_round
     $python_bin/crudini --set $config_file TRAIN_PARAMS batch_size $batchsize
     $python_bin/crudini --set $config_file DATASET random_seed $seed
-    $env_python $eval_file -ro 0 -tea $test_after -ld $log_dir -mid $m -ps $procs_per_machine -ms $machines -is $new_iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level
+    $env_python $eval_file -ro 0 -tea $test_after -ld $log_dir -wsd $weight_store_dir -mid $m -ps $procs_per_machine -ms $machines -is $new_iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level
     echo $i is done
     sleep 200
     echo end of sleep
diff --git a/eval/run_xtimes_reddit.sh b/eval/run_xtimes_reddit.sh
index 4ecf899810352385dcacac471bf21c85282ec23a..589f52a2978107a3efac300b592048adb09f1717 100755
--- a/eval/run_xtimes_reddit.sh
+++ b/eval/run_xtimes_reddit.sh
@@ -91,9 +91,12 @@ do
     echo $i
     IFS='_' read -ra NAMES <<< $i
     IFS='.' read -ra NAME <<< ${NAMES[-1]}
-    log_dir=$nfs_home$logs_subfolder/${NAME[0]}:lr=$lr:r=$comm_rounds_per_global_epoch:b=$batchsize:$(date '+%Y-%m-%dT%H:%M')/machine$m
-    echo results are stored in: $log_dir
+    log_dir_base=$nfs_home$logs_subfolder/${NAME[0]}:lr=$lr:r=$comm_rounds_per_global_epoch:b=$batchsize:$(date '+%Y-%m-%dT%H:%M')
+    echo results are stored in: $log_dir_base
+    log_dir=$log_dir_base/machine$m
     mkdir -p $log_dir
+    weight_store_dir=$log_dir_base/weights
+    mkdir -p $weight_store_dir
     cp $i $config_file
     # changing the config files to reflect the values of the current grid search state
     $python_bin/crudini --set $config_file COMMUNICATION addresses_filepath $ip_machines
@@ -101,7 +104,7 @@ do
     $python_bin/crudini --set $config_file TRAIN_PARAMS rounds $batches_per_comm_round
     $python_bin/crudini --set $config_file TRAIN_PARAMS batch_size $batchsize
     $python_bin/crudini --set $config_file DATASET random_seed $seed
-    $env_python $eval_file -ro 0 -tea $test_after -ld $log_dir -mid $m -ps $procs_per_machine -ms $machines -is $new_iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level
+    $env_python $eval_file -ro 0 -tea $test_after -ld $log_dir -wsd $weight_store_dir -mid $m -ps $procs_per_machine -ms $machines -is $new_iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level
     echo $i is done
     sleep 200
     echo end of sleep
diff --git a/eval/run_xtimes_shakespeare.sh b/eval/run_xtimes_shakespeare.sh
index 1fd7b20f383cf96b67f4c19c16076bca18e1c6a7..8e268a18248eb789bc91a6fb68d2183d6a1b96b3 100755
--- a/eval/run_xtimes_shakespeare.sh
+++ b/eval/run_xtimes_shakespeare.sh
@@ -41,7 +41,7 @@ graph=96_regular.edges
 config_file=~/tmp/config.ini
 procs_per_machine=16
 machines=6
-global_epochs=80
+global_epochs=200
 eval_file=testing.py
 log_level=INFO
 
@@ -60,7 +60,7 @@ batchsize="16"
 comm_rounds_per_global_epoch="25"
 procs=`expr $procs_per_machine \* $machines`
 echo procs: $procs
-dataset_size=678696
+dataset_size=97545 # sub96, for sub: 678696
 # Calculating the number of samples that each user/proc will have on average
 samples_per_user=`expr $dataset_size / $procs`
 echo samples per user: $samples_per_user
@@ -91,9 +91,12 @@ do
     echo $i
     IFS='_' read -ra NAMES <<< $i
     IFS='.' read -ra NAME <<< ${NAMES[-1]}
-    log_dir=$nfs_home$logs_subfolder/${NAME[0]}:lr=$lr:r=$comm_rounds_per_global_epoch:b=$batchsize:$(date '+%Y-%m-%dT%H:%M')/machine$m
-    echo results are stored in: $log_dir
+    log_dir_base=$nfs_home$logs_subfolder/${NAME[0]}:lr=$lr:r=$comm_rounds_per_global_epoch:b=$batchsize:$(date '+%Y-%m-%dT%H:%M')
+    echo results are stored in: $log_dir_base
+    log_dir=$log_dir_base/machine$m
     mkdir -p $log_dir
+    weight_store_dir=$log_dir_base/weights
+    mkdir -p $weight_store_dir
     cp $i $config_file
     # changing the config files to reflect the values of the current grid search state
     $python_bin/crudini --set $config_file COMMUNICATION addresses_filepath $ip_machines
@@ -101,9 +104,9 @@ do
     $python_bin/crudini --set $config_file TRAIN_PARAMS rounds $batches_per_comm_round
     $python_bin/crudini --set $config_file TRAIN_PARAMS batch_size $batchsize
     $python_bin/crudini --set $config_file DATASET random_seed $seed
-    $env_python $eval_file -ro 0 -tea $test_after -ld $log_dir -mid $m -ps $procs_per_machine -ms $machines -is $new_iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level
+    $env_python $eval_file -ro 0 -tea $test_after -ld $log_dir -wsd $weight_store_dir -mid $m -ps $procs_per_machine -ms $machines -is $new_iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level
     echo $i is done
-    sleep 500
+    sleep 200
     echo end of sleep
     done
 done
diff --git a/eval/step_configs/config_shakespeare_partialmodel.ini b/eval/step_configs/config_shakespeare_partialmodel.ini
index 453815fd1de984f58a11d930dc5457d5b3b94281..9072903d3a48a50c2822f67711190d862f5852a7 100644
--- a/eval/step_configs/config_shakespeare_partialmodel.ini
+++ b/eval/step_configs/config_shakespeare_partialmodel.ini
@@ -3,8 +3,8 @@ dataset_package = decentralizepy.datasets.Shakespeare
 dataset_class = Shakespeare
 random_seed = 97
 model_class = LSTM
-train_dir = /mnt/nfs/shared/leaf/data/shakespeare_sub/per_user_data/train
-test_dir = /mnt/nfs/shared/leaf/data/shakespeare_sub/data/test
+train_dir = /mnt/nfs/shared/leaf/data/shakespeare_sub96/per_user_data/train
+test_dir = /mnt/nfs/shared/leaf/data/shakespeare_sub96/data/test
 ; python list of fractions below
 sizes =
 
diff --git a/eval/step_configs/config_shakespeare_sharing.ini b/eval/step_configs/config_shakespeare_sharing.ini
index 525e928c4b967eeabd52c5b15525d8a97ce39595..ee3811ba892adf227e7ea1edb62a9a4b5f5b5da4 100644
--- a/eval/step_configs/config_shakespeare_sharing.ini
+++ b/eval/step_configs/config_shakespeare_sharing.ini
@@ -2,8 +2,8 @@
 dataset_package = decentralizepy.datasets.Shakespeare
 dataset_class = Shakespeare
 model_class = LSTM
-train_dir = /mnt/nfs/shared/leaf/data/shakespeare_sub/per_user_data/train
-test_dir = /mnt/nfs/shared/leaf/data/shakespeare_sub/data/test
+train_dir = /mnt/nfs/shared/leaf/data/shakespeare_sub96/per_user_data/train
+test_dir = /mnt/nfs/shared/leaf/data/shakespeare_sub96/data/test
 ; python list of fractions below
 sizes =
 
diff --git a/eval/step_configs/config_shakespeare_subsampling.ini b/eval/step_configs/config_shakespeare_subsampling.ini
index 7bdef90365e4a0ea2b963d1b83ba581946cbddf8..b3c145a696cf255e44ba391a28f17c66f4fdd3e8 100644
--- a/eval/step_configs/config_shakespeare_subsampling.ini
+++ b/eval/step_configs/config_shakespeare_subsampling.ini
@@ -3,8 +3,8 @@ dataset_package = decentralizepy.datasets.Shakespeare
 dataset_class = Shakespeare
 random_seed = 97
 model_class = LSTM
-train_dir = /mnt/nfs/shared/leaf/data/shakespeare_sub/per_user_data/train
-test_dir = /mnt/nfs/shared/leaf/data/shakespeare_sub/data/test
+train_dir = /mnt/nfs/shared/leaf/data/shakespeare_sub96/per_user_data/train
+test_dir = /mnt/nfs/shared/leaf/data/shakespeare_sub96/data/test
 ; python list of fractions below
 sizes =
 
diff --git a/eval/step_configs/config_shakespeare_topkacc.ini b/eval/step_configs/config_shakespeare_topkacc.ini
index 15838fb2259161dd06c5ce7d981b891403b0f837..bd00a6e3f732408346cf8298c10d4b9b402a6f73 100644
--- a/eval/step_configs/config_shakespeare_topkacc.ini
+++ b/eval/step_configs/config_shakespeare_topkacc.ini
@@ -3,8 +3,8 @@ dataset_package = decentralizepy.datasets.Shakespeare
 dataset_class = Shakespeare
 random_seed = 97
 model_class = LSTM
-train_dir = /mnt/nfs/shared/leaf/data/shakespeare_sub/per_user_data/train
-test_dir = /mnt/nfs/shared/leaf/data/shakespeare_sub/data/test
+train_dir = /mnt/nfs/shared/leaf/data/shakespeare_sub96/per_user_data/train
+test_dir = /mnt/nfs/shared/leaf/data/shakespeare_sub96/data/test
 ; python list of fractions below
 sizes =
 
diff --git a/eval/step_configs/config_shakespeare_wavelet.ini b/eval/step_configs/config_shakespeare_wavelet.ini
index 94cc840d7901f909b56edacd03cf4bd98ec69eab..c48ba7d90c4ab8e849f0bc43cd1c2a94e18b39cc 100644
--- a/eval/step_configs/config_shakespeare_wavelet.ini
+++ b/eval/step_configs/config_shakespeare_wavelet.ini
@@ -3,8 +3,8 @@ dataset_package = decentralizepy.datasets.Shakespeare
 dataset_class = Shakespeare
 random_seed = 97
 model_class = LSTM
-train_dir = /mnt/nfs/shared/leaf/data/shakespeare_sub/per_user_data/train
-test_dir = /mnt/nfs/shared/leaf/data/shakespeare_sub/data/test
+train_dir = /mnt/nfs/shared/leaf/data/shakespeare_sub96/per_user_data/train
+test_dir = /mnt/nfs/shared/leaf/data/shakespeare_sub96/data/test
 ; python list of fractions below
 sizes =
 
diff --git a/eval/testing.py b/eval/testing.py
index efb80dfa107be747ef083888741d3d413c30361c..9125828cc2ffbf5e9ccf7178355c136c77cebdb8 100644
--- a/eval/testing.py
+++ b/eval/testing.py
@@ -60,9 +60,12 @@ if __name__ == "__main__":
             my_config,
             args.iterations,
             args.log_dir,
+            args.weights_store_dir,
             log_level[args.log_level],
             args.test_after,
             args.train_evaluate_after,
             args.reset_optimizer,
+            args.centralized_train_eval,
+            args.centralized_test_eval,
         ],
     )
diff --git a/src/decentralizepy/communication/TCP.py b/src/decentralizepy/communication/TCP.py
index a81e5bff2dd8ec4fa0d06d992de8b01a3e0dac06..fed93011f307ce9676492a6b93659d656f7da6dd 100644
--- a/src/decentralizepy/communication/TCP.py
+++ b/src/decentralizepy/communication/TCP.py
@@ -36,11 +36,18 @@ class TCP(Communication):
 
         """
         machine_addr = self.ip_addrs[str(machine_id)]
-        port = rank + 20000
+        port = rank + self.offset
         return "tcp://{}:{}".format(machine_addr, port)
 
     def __init__(
-        self, rank, machine_id, mapping, total_procs, addresses_filepath, compress=False
+        self,
+        rank,
+        machine_id,
+        mapping,
+        total_procs,
+        addresses_filepath,
+        compress=False,
+        offset=20000,
     ):
         """
         Constructor
@@ -68,6 +75,7 @@ class TCP(Communication):
         self.rank = rank
         self.machine_id = machine_id
         self.mapping = mapping
+        self.offset = 20000 + offset
         self.uid = mapping.get_uid(rank, machine_id)
         self.identity = str(self.uid).encode()
         self.context = zmq.Context()
diff --git a/src/decentralizepy/datasets/Shakespeare.py b/src/decentralizepy/datasets/Shakespeare.py
index 60ba148e27637ea4bdc58d25524d8083f3cce630..0c0293208717385fe69ee41bd8355cc15b6997e1 100644
--- a/src/decentralizepy/datasets/Shakespeare.py
+++ b/src/decentralizepy/datasets/Shakespeare.py
@@ -295,11 +295,8 @@ class Shakespeare(Dataset):
 
         """
         if self.__training__:
-            # Only using a subset of the training set. The full set is too large.
-            thirstiest = torch.arange(0, self.test_x.shape[0], 30)
             return DataLoader(
-                Data(self.test_x[thirstiest], self.test_y[thirstiest]),
-                batch_size=self.test_batch_size,
+                Data(self.train_x, self.train_y), batch_size=batch_size, shuffle=shuffle
             )
         raise RuntimeError("Training set not initialized!")
 
@@ -318,8 +315,10 @@ class Shakespeare(Dataset):
 
         """
         if self.__testing__:
+            thirstiest = torch.arange(0, self.test_x.shape[0], 30)
             return DataLoader(
-                Data(self.test_x, self.test_y), batch_size=self.test_batch_size
+                Data(self.test_x[thirstiest], self.test_y[thirstiest]),
+                batch_size=self.test_batch_size,
             )
         raise RuntimeError("Test set not initialized!")
 
diff --git a/src/decentralizepy/graphs/Graph.py b/src/decentralizepy/graphs/Graph.py
index 7e4c0635bd0ef1762217f594bded826f547b70be..689d2dc62544d603ea492807bc6790b9e90c5f95 100644
--- a/src/decentralizepy/graphs/Graph.py
+++ b/src/decentralizepy/graphs/Graph.py
@@ -1,6 +1,7 @@
 import networkx as nx
 import numpy as np
 
+
 class Graph:
     """
     This class defines the graph topology.
@@ -151,16 +152,16 @@ class Graph:
     def centr(self):
         my_adj = {x: list(adj) for x, adj in enumerate(self.adj_list)}
         nxGraph = nx.Graph(my_adj)
-        a=nx.to_numpy_matrix(nxGraph)
+        a = nx.to_numpy_matrix(nxGraph)
         self.averaging_weights = np.ones((self.n_procs, self.n_procs), dtype=float)
-        centrality= nx.betweenness_centrality(nxGraph)
+        centrality = nx.betweenness_centrality(nxGraph)
         for i in range(len(centrality)):
-            centrality[i]+=0.01
+            centrality[i] += 0.01
         for i in range(self.averaging_weights.shape[0]):
-            s=0
+            s = 0
             for j in range(self.averaging_weights.shape[0]):
-                self.averaging_weights[i,j] = 1.0/centrality[j]
-                s += self.averaging_weights[i,j]
+                self.averaging_weights[i, j] = 1.0 / centrality[j]
+                s += self.averaging_weights[i, j]
             for j in range(self.averaging_weights.shape[0]):
-                self.averaging_weights[i,j]=self.averaging_weights[i,j]/s
+                self.averaging_weights[i, j] = self.averaging_weights[i, j] / s
         return self.averaging_weights
diff --git a/src/decentralizepy/graphs/Star.py b/src/decentralizepy/graphs/Star.py
new file mode 100644
index 0000000000000000000000000000000000000000..968615e559d60cef8fcd7b0b2087bdf4ce3f4e1a
--- /dev/null
+++ b/src/decentralizepy/graphs/Star.py
@@ -0,0 +1,31 @@
+import networkx as nx
+
+from decentralizepy.graphs.Graph import Graph
+
+
+class Star(Graph):
+    """
+    The class for generating a Star topology
+    Adapted from ./Regular.py
+
+    """
+
+    def __init__(self, n_procs):
+        """
+        Constructor. Generates a Star graph
+
+        Parameters
+        ----------
+        n_procs : int
+            total number of nodes in the graph
+
+        """
+        super().__init__(n_procs)
+        G = nx.star_graph(n_procs - 1)
+        adj = G.adjacency()
+        for i, l in adj:
+            self.adj_list[i] = set()  # new set
+            for k in l:
+                self.adj_list[i].add(k)
+        if not nx.is_connected(G):
+            self.connect_graph()
diff --git a/src/decentralizepy/models/Model.py b/src/decentralizepy/models/Model.py
index 5edba7f75815ffb0fbbd67749ef3e43fcfa2cc28..9518e22c4241b9c59c1a0aa3d011457543a48863 100644
--- a/src/decentralizepy/models/Model.py
+++ b/src/decentralizepy/models/Model.py
@@ -1,3 +1,7 @@
+import pickle
+from pathlib import Path
+
+import torch
 from torch import nn
 
 
@@ -58,3 +62,39 @@ class Model(nn.Module):
         """
         if self.accumulated_changes is not None:
             self.accumulated_changes[indices] = 0.0
+
+    def dump_weights(self, directory, uid, round):
+        """
+        Dumps the current model weights as a pickle file into the specified directory
+
+        Parameters
+        ----------
+        directory : str
+            directory in which the weights are dumped
+        uid : int
+            uid of the node, will be used to give the weight a unique name
+        round : int
+            current round, will be used to give the weight a unique name
+
+        """
+        with torch.no_grad():
+            tensors_to_cat = []
+            for _, v in self.state_dict().items():
+                tensors_to_cat.append(v.flatten())
+            flat = torch.cat(tensors_to_cat)
+
+        with open(Path(directory) / f"{round}_weight_{uid}.pk", "wb") as f:
+            pickle.dump(flat, f)
+
+    def get_weights(self):
+        """
+        Flattens the current weights (all state_dict tensors) into a single 1-D tensor and returns it
+
+        """
+        with torch.no_grad():
+            tensors_to_cat = []
+            for _, v in self.state_dict().items():
+                tensors_to_cat.append(v.flatten())
+            flat = torch.cat(tensors_to_cat)
+
+        return flat
diff --git a/src/decentralizepy/node/Node.py b/src/decentralizepy/node/Node.py
index 1aa0f04d527be4aa4856dbabf2628c09b97e9586..e387c58c2005c306db679de38cce7ddcbec708c6 100644
--- a/src/decentralizepy/node/Node.py
+++ b/src/decentralizepy/node/Node.py
@@ -8,8 +8,11 @@ import torch
 from matplotlib import pyplot as plt
 
 from decentralizepy import utils
+from decentralizepy.communication.TCP import TCP
 from decentralizepy.graphs.Graph import Graph
+from decentralizepy.graphs.Star import Star
 from decentralizepy.mappings.Mapping import Mapping
+from decentralizepy.train_test_evaluation import TrainTestHelper
 
 
 class Node:
@@ -76,9 +79,12 @@ class Node:
         graph,
         iterations,
         log_dir,
+        weights_store_dir,
         test_after,
         train_evaluate_after,
         reset_optimizer,
+        centralized_train_eval,
+        centralized_test_eval,
     ):
         """
         Instantiate object field with arguments.
@@ -97,13 +103,18 @@ class Node:
             Number of iterations (communication steps) for which the model should be trained
         log_dir : str
             Logging directory
+        weights_store_dir : str
+            Directory in which to store model weights
         test_after : int
             Number of iterations after which the test loss and accuracy arecalculated
         train_evaluate_after : int
             Number of iterations after which the train loss is calculated
         reset_optimizer : int
             1 if optimizer should be reset every communication round, else 0
-
+        centralized_train_eval : bool
+            If set the train set evaluation happens at the node with uid 0
+        centralized_test_eval : bool
+            If set the test set evaluation happens at the node with uid 0
         """
         self.rank = rank
         self.machine_id = machine_id
@@ -111,15 +122,20 @@ class Node:
         self.mapping = mapping
         self.uid = self.mapping.get_uid(rank, machine_id)
         self.log_dir = log_dir
+        self.weights_store_dir = weights_store_dir
         self.iterations = iterations
         self.test_after = test_after
         self.train_evaluate_after = train_evaluate_after
         self.reset_optimizer = reset_optimizer
+        self.centralized_train_eval = centralized_train_eval
+        self.centralized_test_eval = centralized_test_eval
 
         logging.debug("Rank: %d", self.rank)
         logging.debug("type(graph): %s", str(type(self.rank)))
         logging.debug("type(mapping): %s", str(type(self.mapping)))
 
+        self.star = Star(self.mapping.get_n_procs())
+
     def init_dataset_model(self, dataset_configs):
         """
         Instantiate dataset and model from config.
@@ -226,6 +242,7 @@ class Node:
         comm_module = importlib.import_module(comm_configs["comm_package"])
         comm_class = getattr(comm_module, comm_configs["comm_class"])
         comm_params = utils.remove_keys(comm_configs, ["comm_package", "comm_class"])
+        self.addresses_filepath = comm_params.get("addresses_filepath", None)
         self.communication = comm_class(
             self.rank, self.machine_id, self.mapping, self.graph.n_procs, **comm_params
         )
@@ -266,10 +283,13 @@ class Node:
         config,
         iterations=1,
         log_dir=".",
+        weights_store_dir=".",
         log_level=logging.INFO,
         test_after=5,
         train_evaluate_after=1,
         reset_optimizer=1,
+        centralized_train_eval=False,
+        centralized_test_eval=True,
         *args
     ):
         """
@@ -291,6 +311,8 @@ class Node:
             Number of iterations (communication steps) for which the model should be trained
         log_dir : str
             Logging directory
+        weights_store_dir : str
+            Directory in which to store model weights
         log_level : logging.Level
             One of DEBUG, INFO, WARNING, ERROR, CRITICAL
         test_after : int
@@ -299,6 +321,10 @@ class Node:
             Number of iterations after which the train loss is calculated
         reset_optimizer : int
             1 if optimizer should be reset every communication round, else 0
+        centralized_train_eval : bool
+            If set the train set evaluation happens at the node with uid 0
+        centralized_test_eval : bool
+            If set the test set evaluation happens at the node with uid 0
         args : optional
             Other arguments
 
@@ -312,9 +338,12 @@ class Node:
             graph,
             iterations,
             log_dir,
+            weights_store_dir,
             test_after,
             train_evaluate_after,
             reset_optimizer,
+            centralized_train_eval,
+            centralized_test_eval,
         )
         self.init_log(log_dir, rank, log_level)
         self.init_dataset_model(config["DATASET"])
@@ -331,7 +360,46 @@ class Node:
         self.testset = self.dataset.get_testset()
         self.communication.connect_neighbors(self.graph.neighbors(self.uid))
         rounds_to_test = self.test_after
+
+        testing_comm = TCP(
+            self.rank,
+            self.machine_id,
+            self.mapping,
+            self.star.n_procs,
+            self.addresses_filepath,
+            offset=self.star.n_procs,
+        )
+        testing_comm.connect_neighbors(self.star.neighbors(self.uid))
         rounds_to_train_evaluate = self.train_evaluate_after
+        global_epoch = 1
+        change = 1
+        if self.uid == 0:
+            dataset = self.dataset
+            if self.centralized_train_eval:
+                dataset_params_copy = self.dataset_params.copy()
+                if "sizes" in dataset_params_copy:
+                    del dataset_params_copy["sizes"]
+                self.whole_dataset = self.dataset_class(
+                    self.rank,
+                    self.machine_id,
+                    self.mapping,
+                    sizes=[1.0],
+                    **dataset_params_copy
+                )
+                dataset = self.whole_dataset
+            tthelper = TrainTestHelper(
+                dataset,  # self.whole_dataset,
+                # self.model_test, # todo: this only works if eval_train is set to false
+                self.model,
+                self.loss,
+                self.weights_store_dir,
+                self.mapping.get_n_procs(),
+                self.trainer,
+                testing_comm,
+                self.star,
+                self.threads_per_proc,
+                eval_train=self.centralized_train_eval,
+            )
 
         for iteration in range(self.iterations):
             logging.info("Starting training iteration: %d", iteration)
@@ -378,9 +446,9 @@ class Node:
 
             rounds_to_train_evaluate -= 1
 
-            if rounds_to_train_evaluate == 0:
+            if rounds_to_train_evaluate == 0 and not self.centralized_train_eval:
                 logging.info("Evaluating on train set.")
-                rounds_to_train_evaluate = self.train_evaluate_after
+                rounds_to_train_evaluate = self.train_evaluate_after * change
                 loss_after_sharing = self.trainer.eval_loss(self.dataset)
                 results_dict["train_loss"][iteration + 1] = loss_after_sharing
                 self.save_plot(
@@ -394,26 +462,30 @@ class Node:
             rounds_to_test -= 1
 
             if self.dataset.__testing__ and rounds_to_test == 0:
-                logging.info("Evaluating on test set.")
-                rounds_to_test = self.test_after
-                ta, tl = self.dataset.test(self.model, self.loss)
-                results_dict["test_acc"][iteration + 1] = ta
-                results_dict["test_loss"][iteration + 1] = tl
-
-                self.save_plot(
-                    results_dict["test_loss"],
-                    "test_loss",
-                    "Testing Loss",
-                    "Communication Rounds",
-                    os.path.join(self.log_dir, "{}_test_loss.png".format(self.rank)),
-                )
-                self.save_plot(
-                    results_dict["test_acc"],
-                    "test_acc",
-                    "Testing Accuracy",
-                    "Communication Rounds",
-                    os.path.join(self.log_dir, "{}_test_acc.png".format(self.rank)),
-                )
+                rounds_to_test = self.test_after * change
+                # ta, tl = self.dataset.test(self.model, self.loss)
+                # self.model.dump_weights(self.weights_store_dir, self.uid, iteration)
+                if self.centralized_test_eval:
+                    if self.uid == 0:
+                        ta, tl, trl = tthelper.train_test_evaluation(iteration)
+                        results_dict["test_acc"][iteration + 1] = ta
+                        results_dict["test_loss"][iteration + 1] = tl
+                        if trl is not None:
+                            results_dict["train_loss"][iteration + 1] = trl
+                    else:
+                        testing_comm.send(0, self.model.get_weights())
+                        sender, data = testing_comm.receive()
+                        assert sender == 0 and data == "finished"
+                else:
+                    logging.info("Evaluating on test set.")
+                    ta, tl = self.dataset.test(self.model, self.loss)
+                    results_dict["test_acc"][iteration + 1] = ta
+                    results_dict["test_loss"][iteration + 1] = tl
+
+                if global_epoch == 49:
+                    change *= 2
+
+                global_epoch += change
 
             with open(
                 os.path.join(self.log_dir, "{}_results.json".format(self.rank)), "w"
@@ -429,6 +501,8 @@ class Node:
             ) as of:
                 json.dump(self.model.shared_parameters_counter.numpy().tolist(), of)
         self.communication.disconnect_neighbors()
+        logging.info("Storing final weight")
+        self.model.dump_weights(self.weights_store_dir, self.uid, iteration)
         logging.info("All neighbors disconnected. Process complete!")
 
     def __init__(
@@ -440,10 +514,13 @@ class Node:
         config,
         iterations=1,
         log_dir=".",
+        weights_store_dir=".",
         log_level=logging.INFO,
         test_after=5,
         train_evaluate_after=1,
         reset_optimizer=1,
+        centralized_train_eval=0,
+        centralized_test_eval=1,
         *args
     ):
         """
@@ -477,6 +554,8 @@ class Node:
             Number of iterations (communication steps) for which the model should be trained
         log_dir : str
             Logging directory
+        weights_store_dir : str
+            Directory in which to store model weights
         log_level : logging.Level
             One of DEBUG, INFO, WARNING, ERROR, CRITICAL
         test_after : int
@@ -485,13 +564,25 @@ class Node:
             Number of iterations after which the train loss is calculated
         reset_optimizer : int
             1 if optimizer should be reset every communication round, else 0
+        centralized_train_eval : int
+            If set then the train set evaluation happens at the node with uid 0.
+            Note: If it is True then centralized_test_eval needs to be True as well!
+        centralized_test_eval : int
+            If set then the test set evaluation happens at the node with uid 0
         args : optional
             Other arguments
 
         """
+        centralized_train_eval = centralized_train_eval == 1
+        centralized_test_eval = centralized_test_eval == 1
+        # If centralized_train_eval is True then centralized_test_eval needs to be true as well!
+        assert not centralized_train_eval or centralized_test_eval
+
         total_threads = os.cpu_count()
-        threads_per_proc = max(math.floor(total_threads / mapping.procs_per_machine), 1)
-        torch.set_num_threads(threads_per_proc)
+        self.threads_per_proc = max(
+            math.floor(total_threads / mapping.procs_per_machine), 1
+        )
+        torch.set_num_threads(self.threads_per_proc)
         torch.set_num_interop_threads(1)
         self.instantiate(
             rank,
@@ -501,14 +592,17 @@ class Node:
             config,
             iterations,
             log_dir,
+            weights_store_dir,
             log_level,
             test_after,
             train_evaluate_after,
             reset_optimizer,
+            centralized_train_eval == 1,
+            centralized_test_eval == 1,
             *args
         )
         logging.info(
-            "Each proc uses %d threads out of %d.", threads_per_proc, total_threads
+            "Each proc uses %d threads out of %d.", self.threads_per_proc, total_threads
         )
 
         self.run()
diff --git a/src/decentralizepy/sharing/FFT.py b/src/decentralizepy/sharing/FFT.py
index f86f73c50c34b0492f8278987a78253bc49e5440..cef38732bdd1504e6c896c66056f9dfa2cd62ceb 100644
--- a/src/decentralizepy/sharing/FFT.py
+++ b/src/decentralizepy/sharing/FFT.py
@@ -161,10 +161,11 @@ class FFT(PartialModel):
             m["params"] = data.numpy()
             self.total_data += len(self.communication.encrypt(m["params"]))
             if self.model.accumulated_changes is not None:
-                self.model.accumulated_changes = torch.zeros_like(self.model.accumulated_changes)
+                self.model.accumulated_changes = torch.zeros_like(
+                    self.model.accumulated_changes
+                )
             return m
 
-
         with torch.no_grad():
             topk, indices = self.apply_fft()
             self.model.shared_parameters_counter[indices] += 1
diff --git a/src/decentralizepy/sharing/PartialModel.py b/src/decentralizepy/sharing/PartialModel.py
index 75673199b8d3b736533a01e5edaad5d7d9decfc3..018f8951ff5c3ab282a8f641d3f91d1e541d9d23 100644
--- a/src/decentralizepy/sharing/PartialModel.py
+++ b/src/decentralizepy/sharing/PartialModel.py
@@ -163,7 +163,9 @@ class PartialModel(Sharing):
         """
         if self.alpha >= self.metadata_cap:  # Share fully
             if self.model.accumulated_changes is not None:
-                self.model.accumulated_changes = torch.zeros_like(self.model.accumulated_changes)
+                self.model.accumulated_changes = torch.zeros_like(
+                    self.model.accumulated_changes
+                )
             return super().serialized_model()
 
         with torch.no_grad():
@@ -278,7 +280,9 @@ class PartialModel(Sharing):
             ]
             self.pre_share_model = torch.cat(tensors_to_cat, dim=0)
             # Would only need one of the transforms
-            self.pre_share_model_transformed = self.change_transformer(self.pre_share_model)
+            self.pre_share_model_transformed = self.change_transformer(
+                self.pre_share_model
+            )
             change = self.change_transformer(self.pre_share_model - self.init_model)
             if self.accumulation:
                 if not self.accumulate_averaging_changes:
diff --git a/src/decentralizepy/sharing/RandomAlpha.py b/src/decentralizepy/sharing/RandomAlpha.py
index a91ba1d488bb0c1c63c4af14afdac076b62924fa..1956c29a2b8c81c191bcb46ab2bfcd4e61913262 100644
--- a/src/decentralizepy/sharing/RandomAlpha.py
+++ b/src/decentralizepy/sharing/RandomAlpha.py
@@ -20,7 +20,7 @@ class RandomAlpha(PartialModel):
         model,
         dataset,
         log_dir,
-        alpha_list=[0.1,0.2,0.3,0.4,1.0],
+        alpha_list=[0.1, 0.2, 0.3, 0.4, 1.0],
         dict_ordered=True,
         save_shared=False,
         metadata_cap=1.0,
@@ -74,12 +74,10 @@ class RandomAlpha(PartialModel):
             accumulation,
             save_accumulated,
             change_transformer,
-            accumulate_averaging_changes
+            accumulate_averaging_changes,
         )
         self.alpha_list = eval(alpha_list)
-        random.seed(
-            self.mapping.get_uid(self.rank, self.machine_id)
-        )
+        random.seed(self.mapping.get_uid(self.rank, self.machine_id))
 
     def step(self):
         """
diff --git a/src/decentralizepy/sharing/RandomAlphaWavelet.py b/src/decentralizepy/sharing/RandomAlphaWavelet.py
index 61c17bba667f6882341936d596ab26b3e5284a37..62bc51f4513012eb08925b98b582ecca38b12e94 100644
--- a/src/decentralizepy/sharing/RandomAlphaWavelet.py
+++ b/src/decentralizepy/sharing/RandomAlphaWavelet.py
@@ -19,7 +19,7 @@ class RandomAlpha(Wavelet):
         model,
         dataset,
         log_dir,
-        alpha_list=[0.1,0.2,0.3,0.4,1.0],
+        alpha_list=[0.1, 0.2, 0.3, 0.4, 1.0],
         dict_ordered=True,
         save_shared=False,
         metadata_cap=1.0,
@@ -80,9 +80,7 @@ class RandomAlpha(Wavelet):
             accumulate_averaging_changes,
         )
         self.alpha_list = eval(alpha_list)
-        random.seed(
-            self.mapping.get_uid(self.rank, self.machine_id)
-        )
+        random.seed(self.mapping.get_uid(self.rank, self.machine_id))
 
     def step(self):
         """
diff --git a/src/decentralizepy/sharing/SharingCentrality.py b/src/decentralizepy/sharing/SharingCentrality.py
index 52d79fae1c7d9005e8ec7b5d8697d4d502b63572..580ce2aacc6505d7c713cfab763540c7484cd609 100644
--- a/src/decentralizepy/sharing/SharingCentrality.py
+++ b/src/decentralizepy/sharing/SharingCentrality.py
@@ -153,7 +153,7 @@ class Sharing:
                     )
                 )
                 data = self.deserialized_model(data)
-                weight = self.averaging_weights[self.uid,n]
+                weight = self.averaging_weights[self.uid, n]
                 for key, value in data.items():
                     if key in total:
                         total[key] += value * weight
@@ -161,7 +161,9 @@ class Sharing:
                         total[key] = value * weight
 
             for key, value in self.model.state_dict().items():
-                total[key] += self.averaging_weights[self.uid,self.uid] * value  # Metro-Hastings
+                total[key] += (
+                    self.averaging_weights[self.uid, self.uid] * value
+                )  # Metro-Hastings
 
         self.model.load_state_dict(total)
 
diff --git a/src/decentralizepy/sharing/Wavelet.py b/src/decentralizepy/sharing/Wavelet.py
index 81b9a85373e98468f67c81c0ecaf81a443aee22d..407714d794bbd7830e0348ea44c084cf32b300ef 100644
--- a/src/decentralizepy/sharing/Wavelet.py
+++ b/src/decentralizepy/sharing/Wavelet.py
@@ -183,7 +183,9 @@ class Wavelet(PartialModel):
             m["params"] = data.numpy()
             self.total_data += len(self.communication.encrypt(m["params"]))
             if self.model.accumulated_changes is not None:
-                self.model.accumulated_changes = torch.zeros_like(self.model.accumulated_changes)
+                self.model.accumulated_changes = torch.zeros_like(
+                    self.model.accumulated_changes
+                )
             return m
 
         with torch.no_grad():
diff --git a/src/decentralizepy/train_test_evaluation.py b/src/decentralizepy/train_test_evaluation.py
new file mode 100644
index 0000000000000000000000000000000000000000..316319dfdc355b021c9fbcbd2f7566f0ea780e6f
--- /dev/null
+++ b/src/decentralizepy/train_test_evaluation.py
@@ -0,0 +1,93 @@
+import logging
+import os
+import pickle
+from pathlib import Path
+
+import numpy as np
+import torch
+
+from decentralizepy.graphs import Graph
+
+
+class TrainTestHelper:
+    def __init__(
+        self,
+        dataset,
+        model,
+        loss,
+        dir,
+        n_procs,
+        trainer,
+        comm,
+        graph: Graph,
+        threads_per_proc,
+        eval_train=False,
+    ):
+        self.dataset = dataset
+        self.model = model
+        self.loss = loss
+        self.dir = Path(dir)
+        self.n_procs = n_procs
+        self.trainer = trainer
+        self.comm = comm
+        self.star = graph
+        self.threads_per_proc = threads_per_proc
+        self.eval_train = eval_train
+
+    def train_test_evaluation(self, iteration):
+        with torch.no_grad():
+            total_threads = os.cpu_count()
+            torch.set_num_threads(total_threads)
+
+            neighbors = self.star.neighbors(0)
+            state_dict_copy = {}
+            shapes = []
+            lens = []
+            to_cat = []
+            for key, val in self.model.state_dict().items():
+                shapes.append(val.shape)
+                clone_val = val.clone().detach()
+                state_dict_copy[key] = clone_val
+                flat = clone_val.flatten()
+                to_cat.append(clone_val.flatten())
+                lens.append(flat.shape[0])
+
+            my_weight = torch.cat(to_cat)
+            weights = [my_weight]
+            # TODO: add weight of node 0
+            for i in neighbors:
+                sender, data = self.comm.receive()
+                logging.info(f"Received weight from {sender}")
+                weights.append(data)
+
+            # averaging
+            average_weight = np.mean([w.numpy() for w in weights], axis=0)
+
+            start_index = 0
+            average_weight_dict = {}
+            for i, key in enumerate(state_dict_copy):
+                end_index = start_index + lens[i]
+                average_weight_dict[key] = torch.from_numpy(
+                    average_weight[start_index:end_index].reshape(shapes[i])
+                )
+                start_index = end_index
+            self.model.load_state_dict(average_weight_dict)
+            if self.eval_train:
+                logging.info("Evaluating on train set.")
+                trl = self.trainer.eval_loss(self.dataset)
+            else:
+                trl = None
+            logging.info("Evaluating on test set.")
+            ta, tl = self.dataset.test(self.model, self.loss)
+            # reload old weight
+            self.model.load_state_dict(state_dict_copy)
+
+            if trl is not None:
+                print(iteration, ta, tl, trl)
+            else:
+                print(iteration, ta, tl)
+
+            torch.set_num_threads(self.threads_per_proc)
+            for neighbor in neighbors:
+                self.comm.send(neighbor, "finished")
+            return ta, tl, trl
diff --git a/src/decentralizepy/utils.py b/src/decentralizepy/utils.py
index 4298ec37f9e17fd6d94153739776ecbbae9dc323..3c8c8581a4d14144d00280f7c0369030b7688026 100644
--- a/src/decentralizepy/utils.py
+++ b/src/decentralizepy/utils.py
@@ -69,6 +69,12 @@ def get_args():
         type=str,
         default="./{}".format(datetime.datetime.now().isoformat(timespec="minutes")),
     )
+    parser.add_argument(
+        "-wsd",
+        "--weights_store_dir",
+        type=str,
+        default="./{}_ws".format(datetime.datetime.now().isoformat(timespec="minutes")),
+    )
     parser.add_argument("-is", "--iterations", type=int, default=1)
     parser.add_argument("-cf", "--config_file", type=str, default="config.ini")
     parser.add_argument("-ll", "--log_level", type=str, default="INFO")
@@ -77,6 +83,8 @@ def get_args():
     parser.add_argument("-ta", "--test_after", type=int, default=5)
     parser.add_argument("-tea", "--train_evaluate_after", type=int, default=1)
     parser.add_argument("-ro", "--reset_optimizer", type=int, default=1)
+    parser.add_argument("-ctr", "--centralized_train_eval", type=int, default=0)
+    parser.add_argument("-cte", "--centralized_test_eval", type=int, default=1)
 
     args = parser.parse_args()
     return args
@@ -99,6 +107,7 @@ def write_args(args, path):
         "procs_per_machine": args.procs_per_machine,
         "machines": args.machines,
         "log_dir": args.log_dir,
+        "weights_store_dir": args.weights_store_dir,
         "iterations": args.iterations,
         "config_file": args.config_file,
         "log_level": args.log_level,
@@ -107,6 +116,8 @@ def write_args(args, path):
         "test_after": args.test_after,
         "train_evaluate_after": args.train_evaluate_after,
         "reset_optimizer": args.reset_optimizer,
+        "centralized_train_eval": args.centralized_train_eval,
+        "centralized_test_eval": args.centralized_test_eval,
     }
     with open(os.path.join(path, "args.json"), "w") as of:
         json.dump(data, of)