Skip to content
Snippets Groups Projects
Commit 1b7936b1 authored by Jeffrey Wigger's avatar Jeffrey Wigger
Browse files

wavelet and fft fix

parent 2efaec4a
No related branches found
No related tags found
No related merge requests found
......@@ -34,5 +34,5 @@ sharing_class = Wavelet
change_based_selection = True
alpha = 0.1
wavelet=sym2
level= None
level= 4
accumulation = True
......@@ -31,3 +31,4 @@ addresses_filepath = ip_addr_6Machines.json
[SHARING]
sharing_package = decentralizepy.sharing.PartialModel
sharing_class = PartialModel
alpha=0.1
......@@ -35,5 +35,5 @@ sharing_class = Wavelet
change_based_selection = True
alpha = 0.1
wavelet=sym2
level= None
level= 4
accumulation = True
......@@ -224,8 +224,11 @@ class FFT(PartialModel):
with torch.no_grad():
total = None
weight_total = 0
flat_fft = self.change_transformer(self.init_model)
tensors_to_cat = [
v.data.flatten() for _, v in self.model.state_dict().items()
]
pre_share_model = torch.cat(tensors_to_cat, dim=0)
flat_fft = self.change_transformer(pre_share_model)
for i, n in enumerate(self.peer_deques):
degree, iteration, data = self.peer_deques[n].popleft()
......
......@@ -155,7 +155,7 @@ class PartialModel(Sharing):
Model converted to a dict
"""
if self.alpha > self.metadata_cap: # Share fully
if self.alpha >= self.metadata_cap: # Share fully
return super().serialized_model()
with torch.no_grad():
......
......@@ -257,7 +257,11 @@ class Wavelet(PartialModel):
with torch.no_grad():
total = None
weight_total = 0
wt_params = self.change_transformer(self.init_model)
tensors_to_cat = [
v.data.flatten() for _, v in self.model.state_dict().items()
]
pre_share_model = torch.cat(tensors_to_cat, dim=0)
wt_params = self.change_transformer(pre_share_model)
for i, n in enumerate(self.peer_deques):
degree, iteration, data = self.peer_deques[n].popleft()
logging.debug(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment