diff --git a/src/decentralizepy/node/DPSGDNodeFederated.py b/src/decentralizepy/node/DPSGDNodeFederated.py
index d77e35d409d86fa66ffc210e6e0ff0bd62f9f865..7d8f5b1efdd4977413610cf114f27fa06cd83bab 100644
--- a/src/decentralizepy/node/DPSGDNodeFederated.py
+++ b/src/decentralizepy/node/DPSGDNodeFederated.py
@@ -35,10 +35,6 @@ class DPSGDNodeFederated(Node):
             del data["iteration"]
             del data["CHANNEL"]
 
-            if iteration == 0:
-                del data["degree"]
-                data = self.sharing.deserialized_model(data)
-
             self.model.load_state_dict(data)
             self.sharing._post_step()
             self.sharing.communication_round += 1
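
With this change the parameter server ships a plain state_dict plus the
"iteration" and "CHANNEL" bookkeeping keys, so the worker no longer needs the
iteration-0 branch that stripped "degree" and deserialized the payload. A
minimal sketch of the load path the worker now relies on, using a toy
torch.nn.Linear in place of the real model:

    import torch

    model = torch.nn.Linear(4, 2)

    # What the server now sends: raw tensors plus two bookkeeping keys.
    data = dict(model.state_dict())
    data["iteration"] = 0
    data["CHANNEL"] = "WORKER_REQUEST"

    # Worker side: drop the bookkeeping keys and load the rest directly.
    del data["iteration"]
    del data["CHANNEL"]
    model.load_state_dict(data)
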
diff --git a/src/decentralizepy/node/FederatedParameterServer.py b/src/decentralizepy/node/FederatedParameterServer.py
index 312fa6d79187e7cec8b1433d46fb84fc7d296554..ee919b73abf34934029b9cfdac874ee183def6ab 100644
--- a/src/decentralizepy/node/FederatedParameterServer.py
+++ b/src/decentralizepy/node/FederatedParameterServer.py
@@ -285,9 +285,7 @@ class FederatedParameterServer(Node):
             self.current_workers = self.get_working_nodes()
 
             # Params to send to workers
-            # if this is the first iteration, use the init parameters, else use averaged params from last iteration
-            if iteration == 0:
-                to_send = self.sharing.get_data_to_send()
+            to_send = self.model.state_dict()
 
             to_send["CHANNEL"] = "WORKER_REQUEST"
             to_send["iteration"] = iteration
@@ -309,28 +307,11 @@ class FederatedParameterServer(Node):
 
             # Average received updates
             averaging_deque = dict()
-            total = dict()
             for worker in self.current_workers:
                 averaging_deque[worker] = self.peer_deques[worker]
 
-            for i, n in enumerate(averaging_deque):
-                data = averaging_deque[n].popleft()
-                del data["degree"]
-                del data["iteration"]
-                del data["CHANNEL"]
-                data = self.sharing.deserialized_model(data)
-                for key, value in data.items():
-                    if key in total:
-                        total[key] += value
-                    else:
-                        total[key] = value
-
-            for key, value in total.items():
-                total[key] = total[key] / len(averaging_deque)
-
-            self.model.load_state_dict(total)
-
-            to_send = total
+            self.sharing._pre_step()
+            self.sharing._averaging_server(averaging_deque)  # average worker updates into the model
 
             if iteration:
                 with open(
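
Two properties this simplified server loop leans on: model.state_dict()
returns a fresh mapping on every call, so adding the "CHANNEL" and
"iteration" keys to to_send never leaks into the model itself, and the
previously inlined averaging loop is now delegated to
sharing._averaging_server (added below), which also loads the averaged
result back into the model. A quick, self-contained check of the first
point with a toy module:

    import torch

    model = torch.nn.Linear(4, 2)
    to_send = model.state_dict()
    to_send["CHANNEL"] = "WORKER_REQUEST"  # mutates only this snapshot
    assert "CHANNEL" not in model.state_dict()
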
diff --git a/src/decentralizepy/sharing/Sharing.py b/src/decentralizepy/sharing/Sharing.py
index a41de33203d72cafd4aecca10a8b14b347625d72..49419214a36260574d8c617d187273dc0501fd25 100644
--- a/src/decentralizepy/sharing/Sharing.py
+++ b/src/decentralizepy/sharing/Sharing.py
@@ -80,13 +80,15 @@ class Sharing:
         result = dict(data)
         if self.compress:
             if "params" in result:
-                result["params"] = self.compressor.compress_float(result["params"])
+                result["params"] = self.compressor.compress_float(
+                    result["params"])
         return result
 
     def decompress_data(self, data):
         if self.compress:
             if "params" in data:
-                data["params"] = self.compressor.decompress_float(data["params"])
+                data["params"] = self.compressor.decompress_float(
+                    data["params"])
         return data
 
     def serialized_model(self):
@@ -171,7 +173,8 @@ class Sharing:
                     )
                 )
                 data = self.deserialized_model(data)
-                weight = 1 / (max(len(peer_deques), degree) + 1)  # Metro-Hastings
+                # Metro-Hastings
+                weight = 1 / (max(len(peer_deques), degree) + 1)
                 weight_total += weight
                 for key, value in data.items():
                     if key in total:
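
For contrast with the server-side helper added below: the decentralized
_averaging path keeps the Metropolis-Hastings weight per neighbor, while
_averaging_server weights every reporting worker equally. A one-line
comparison with made-up numbers (three received models, degree 3):

    received, degree = 3, 3
    mh_weight = 1 / (max(received, degree) + 1)  # 0.25, leaves mass for the local model
    server_weight = 1 / received                 # ~0.33, workers only
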
@@ -194,3 +197,34 @@ class Sharing:
         data["degree"] = len(all_neighbors)
         data["iteration"] = self.communication_round
         return data
+
+    def _averaging_server(self, peer_deques):
+        """
+        Averages the received models of all working nodes
+
+        """
+        with torch.no_grad():
+            total = dict()
+            for i, n in enumerate(peer_deques):
+                data = peer_deques[n].popleft()
+                degree, iteration = data["degree"], data["iteration"]
+                del data["degree"]
+                del data["iteration"]
+                del data["CHANNEL"]
+                logging.debug(
+                    "Averaging model from neighbor {} of iteration {}".format(
+                        n, iteration
+                    )
+                )
+                data = self.deserialized_model(data)
+                weight = 1 / len(peer_deques)  # uniform weight per worker
+                for key, value in data.items():
+                    if key in total:
+                        total[key] += weight * value
+                    else:
+                        total[key] = weight * value
+
+        self.model.load_state_dict(total)
+        self._post_step()
+        self.communication_round += 1
+        return total
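
A toy, self-contained run of the consume pattern _averaging_server
implements: one deque per worker, bookkeeping keys stripped, and a uniform
1 / len(peer_deques) weight applied to every update (the key names and
tensor values below are made up for illustration):

    from collections import deque

    import torch

    peer_deques = {
        worker: deque([{
            "degree": 1,
            "iteration": 3,
            "CHANNEL": "WORKER",
            "fc.weight": torch.full((2,), float(worker)),  # stand-in tensor
        }])
        for worker in (0, 1, 2)
    }

    total = dict()
    weight = 1 / len(peer_deques)
    for worker in peer_deques:
        data = peer_deques[worker].popleft()
        for key in ("degree", "iteration", "CHANNEL"):
            del data[key]
        for key, value in data.items():
            total[key] = total.get(key, 0) + weight * value

    print(total["fc.weight"])  # tensor([1., 1.]), the mean of 0, 1 and 2
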
diff --git a/src/decentralizepy/sharing/Wavelet.py b/src/decentralizepy/sharing/Wavelet.py
index 24db77afea84b7a8424303f2239121423f5f054c..b0a4796fa909cd7cfc3d333bfb4bab45236851a9 100644
--- a/src/decentralizepy/sharing/Wavelet.py
+++ b/src/decentralizepy/sharing/Wavelet.py
@@ -134,7 +134,8 @@ class Wavelet(PartialModel):
         self.change_based_selection = change_based_selection
 
         # Do a dummy transform to get the shape and coefficents slices
-        coeff = pywt.wavedec(self.init_model.numpy(), self.wavelet, level=self.level)
+        coeff = pywt.wavedec(self.init_model.numpy(),
+                             self.wavelet, level=self.level)
         data, coeff_slices = pywt.coeffs_to_array(coeff)
         self.wt_shape = data.shape
         self.coeff_slices = coeff_slices
@@ -203,14 +204,16 @@ class Wavelet(PartialModel):
                     shapes[k] = list(v.shape)
                 shared_params["shapes"] = shapes
 
-                shared_params[self.communication_round] = indices.tolist()  # is slow
+                # is slow
+                shared_params[self.communication_round] = indices.tolist()
 
                 shared_params["alpha"] = self.alpha
 
                 with open(
                     os.path.join(
                         self.folder_path,
-                        "{}_shared_params.json".format(self.communication_round + 1),
+                        "{}_shared_params.json".format(
+                            self.communication_round + 1),
                     ),
                     "w",
                 ) as of:
@@ -296,7 +299,8 @@ class Wavelet(PartialModel):
                 else:
                     topkwf = params.reshape(self.wt_shape)
 
-                weight = 1 / (max(len(peer_deques), degree) + 1)  # Metro-Hastings
+                # Metro-Hastings
+                weight = 1 / (max(len(peer_deques), degree) + 1)
                 weight_total += weight
                 if total is None:
                     total = weight * topkwf
@@ -325,3 +329,59 @@ class Wavelet(PartialModel):
         self.model.load_state_dict(std_dict)
         self._post_step()
         self.communication_round += 1
+
+    def _averaging_server(self, peer_deques):
+        """
+        Averages the received models of all working nodes
+
+        """
+        with torch.no_grad():
+            total = None
+            wt_params = self.pre_share_model_transformed
+            for i, n in enumerate(peer_deques):
+                data = peer_deques[n].popleft()
+                degree, iteration = data["degree"], data["iteration"]
+                del data["degree"]
+                del data["iteration"]
+                del data["CHANNEL"]
+                logging.debug(
+                    "Averaging model from neighbor {} of iteration {}".format(
+                        n, iteration
+                    )
+                )
+                data = self.deserialized_model(data)
+                params = data["params"]
+                if "indices" in data:
+                    indices = data["indices"]
+                    # use local data to complement
+                    topkwf = wt_params.clone().detach()
+                    topkwf[indices] = params
+                    topkwf = topkwf.reshape(self.wt_shape)
+                else:
+                    topkwf = params.reshape(self.wt_shape)
+
+                weight = 1 / len(peer_deques)  # uniform weight per worker
+                if total is None:
+                    total = weight * topkwf
+                else:
+                    total += weight * topkwf
+
+            avg_wf_params = pywt.array_to_coeffs(
+                total.numpy(), self.coeff_slices, output_format="wavedec"
+            )
+            reverse_total = torch.from_numpy(
+                pywt.waverec(avg_wf_params, wavelet=self.wavelet)
+            )
+
+            start_index = 0
+            std_dict = {}
+            for i, key in enumerate(self.model.state_dict()):
+                end_index = start_index + self.lens[i]
+                std_dict[key] = reverse_total[start_index:end_index].reshape(
+                    self.shapes[i]
+                )
+                start_index = end_index
+
+        self.model.load_state_dict(std_dict)
+        self._post_step()
+        self.communication_round += 1
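
The wavelet variant averages in coefficient space and only inverts the
transform at the end. A minimal round trip through the pywt calls the new
method relies on, using a made-up signal and the "haar" wavelet purely for
illustration (the real wavelet and level come from the sharing config):

    import numpy as np
    import pywt
    import torch

    signal = np.arange(8, dtype=np.float64)

    # Forward transform, flattened the way the sharing classes store it.
    coeffs = pywt.wavedec(signal, "haar", level=2)
    flat, coeff_slices = pywt.coeffs_to_array(coeffs)

    # ... averaging across workers would happen here, on `flat` ...

    # Inverse transform back to parameter space, as _averaging_server does.
    rebuilt = pywt.array_to_coeffs(flat, coeff_slices, output_format="wavedec")
    recovered = torch.from_numpy(pywt.waverec(rebuilt, wavelet="haar"))

    assert torch.allclose(recovered, torch.from_numpy(signal))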