Skip to content
Snippets Groups Projects
Commit 5c5bfb64 authored by Elisabeth Kirsten's avatar Elisabeth Kirsten
Browse files

refactors averaging at server

parent 833cfbe3
No related branches found
No related tags found
1 merge request!15Refactor and add federated + parameter server + central peer sampling
......@@ -35,10 +35,6 @@ class DPSGDNodeFederated(Node):
del data["iteration"]
del data["CHANNEL"]
if iteration == 0:
del data["degree"]
data = self.sharing.deserialized_model(data)
self.model.load_state_dict(data)
self.sharing._post_step()
self.sharing.communication_round += 1
......
......@@ -285,9 +285,7 @@ class FederatedParameterServer(Node):
self.current_workers = self.get_working_nodes()
# Params to send to workers
# if this is the first iteration, use the init parameters, else use averaged params from last iteration
if iteration == 0:
to_send = self.sharing.get_data_to_send()
to_send = self.model.state_dict()
to_send["CHANNEL"] = "WORKER_REQUEST"
to_send["iteration"] = iteration
......@@ -309,28 +307,11 @@ class FederatedParameterServer(Node):
# Average received updates
averaging_deque = dict()
total = dict()
for worker in self.current_workers:
averaging_deque[worker] = self.peer_deques[worker]
for i, n in enumerate(averaging_deque):
data = averaging_deque[n].popleft()
del data["degree"]
del data["iteration"]
del data["CHANNEL"]
data = self.sharing.deserialized_model(data)
for key, value in data.items():
if key in total:
total[key] += value
else:
total[key] = value
for key, value in total.items():
total[key] = total[key] / len(averaging_deque)
self.model.load_state_dict(total)
to_send = total
self.sharing._pre_step()
self.sharing._averaging_server(averaging_deque)
if iteration:
with open(
......
......@@ -80,13 +80,15 @@ class Sharing:
result = dict(data)
if self.compress:
if "params" in result:
result["params"] = self.compressor.compress_float(result["params"])
result["params"] = self.compressor.compress_float(
result["params"])
return result
def decompress_data(self, data):
    """
    Reverse the lossy float compression applied by ``compress_data``.

    Only the "params" entry is decompressed, and only when this sharing
    instance was configured with compression enabled; every other entry of
    *data* passes through untouched.
    """
    if self.compress and "params" in data:
        data["params"] = self.compressor.decompress_float(data["params"])
    return data
def serialized_model(self):
......@@ -171,7 +173,8 @@ class Sharing:
)
)
data = self.deserialized_model(data)
weight = 1 / (max(len(peer_deques), degree) + 1) # Metro-Hastings
# Metro-Hastings
weight = 1 / (max(len(peer_deques), degree) + 1)
weight_total += weight
for key, value in data.items():
if key in total:
......@@ -194,3 +197,34 @@ class Sharing:
data["degree"] = len(all_neighbors)
data["iteration"] = self.communication_round
return data
def _averaging_server(self, peer_deques):
"""
Averages the received models of all working nodes
"""
with torch.no_grad():
total = dict()
for i, n in enumerate(peer_deques):
data = peer_deques[n].popleft()
degree, iteration = data["degree"], data["iteration"]
del data["degree"]
del data["iteration"]
del data["CHANNEL"]
logging.debug(
"Averaging model from neighbor {} of iteration {}".format(
n, iteration
)
)
data = self.deserialized_model(data)
weight = 1 / len(peer_deques)
for key, value in data.items():
if key in total:
total[key] += weight * value
else:
total[key] = weight * value
self.model.load_state_dict(total)
self._post_step()
self.communication_round += 1
return total
......@@ -134,7 +134,8 @@ class Wavelet(PartialModel):
self.change_based_selection = change_based_selection
# Do a dummy transform to get the shape and coefficents slices
coeff = pywt.wavedec(self.init_model.numpy(), self.wavelet, level=self.level)
coeff = pywt.wavedec(self.init_model.numpy(),
self.wavelet, level=self.level)
data, coeff_slices = pywt.coeffs_to_array(coeff)
self.wt_shape = data.shape
self.coeff_slices = coeff_slices
......@@ -203,14 +204,16 @@ class Wavelet(PartialModel):
shapes[k] = list(v.shape)
shared_params["shapes"] = shapes
shared_params[self.communication_round] = indices.tolist() # is slow
# is slow
shared_params[self.communication_round] = indices.tolist()
shared_params["alpha"] = self.alpha
with open(
os.path.join(
self.folder_path,
"{}_shared_params.json".format(self.communication_round + 1),
"{}_shared_params.json".format(
self.communication_round + 1),
),
"w",
) as of:
......@@ -296,7 +299,8 @@ class Wavelet(PartialModel):
else:
topkwf = params.reshape(self.wt_shape)
weight = 1 / (max(len(peer_deques), degree) + 1) # Metro-Hastings
# Metro-Hastings
weight = 1 / (max(len(peer_deques), degree) + 1)
weight_total += weight
if total is None:
total = weight * topkwf
......@@ -325,3 +329,59 @@ class Wavelet(PartialModel):
self.model.load_state_dict(std_dict)
self._post_step()
self.communication_round += 1
def _averaging_server(self, peer_deques):
    """
    Averages the received models of all working nodes

    Each worker's message carries wavelet-domain parameters (optionally a
    sparse top-k subset with an "indices" entry); the uniform-weighted
    average is inverted back to parameter space and loaded into the model.
    NOTE(review): assumes peer_deques is non-empty — with no workers,
    ``total`` stays None and ``total.numpy()`` would raise.
    """
    with torch.no_grad():
        total = None
        # Local wavelet-transformed parameters, used to fill in the
        # coefficients a worker did not send.
        wt_params = self.pre_share_model_transformed
        for i, n in enumerate(peer_deques):
            data = peer_deques[n].popleft()
            # degree is unpacked for parity with the message format but is
            # not used in the server-side average.
            degree, iteration = data["degree"], data["iteration"]
            del data["degree"]
            del data["iteration"]
            del data["CHANNEL"]
            logging.debug(
                "Averaging model from neighbor {} of iteration {}".format(
                    n, iteration
                )
            )
            data = self.deserialized_model(data)
            params = data["params"]
            if "indices" in data:
                indices = data["indices"]
                # use local data to complement
                topkwf = wt_params.clone().detach()
                topkwf[indices] = params
                topkwf = topkwf.reshape(self.wt_shape)
            else:
                # Dense message: params already covers the full
                # wavelet-coefficient array.
                topkwf = params.reshape(self.wt_shape)
            # Uniform weight per worker (unlike the Metro-Hastings weight
            # used in decentralized averaging elsewhere in this class).
            weight = 1 / len(peer_deques)
            if total is None:
                total = weight * topkwf
            else:
                total += weight * topkwf
        # Invert the wavelet transform: coefficient array -> coefficient
        # list -> flat parameter vector.
        avg_wf_params = pywt.array_to_coeffs(
            total.numpy(), self.coeff_slices, output_format="wavedec"
        )
        reverse_total = torch.from_numpy(
            pywt.waverec(avg_wf_params, wavelet=self.wavelet)
        )
        # Re-slice the flat vector into per-tensor entries of the model's
        # state dict; self.lens/self.shapes are assumed to be aligned with
        # state_dict() iteration order — TODO confirm against PartialModel.
        start_index = 0
        std_dict = {}
        for i, key in enumerate(self.model.state_dict()):
            end_index = start_index + self.lens[i]
            std_dict[key] = reverse_total[start_index:end_index].reshape(
                self.shapes[i]
            )
            start_index = end_index
        self.model.load_state_dict(std_dict)
    self._post_step()
    self.communication_round += 1
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.