Skip to content
Snippets Groups Projects
Commit 1b7936b1 authored by Jeffrey Wigger's avatar Jeffrey Wigger
Browse files

wavelet and fft fix

parent 2efaec4a
No related branches found
No related tags found
1 merge request: !3 "FFT Wavelets and more"
...@@ -34,5 +34,5 @@ sharing_class = Wavelet ...@@ -34,5 +34,5 @@ sharing_class = Wavelet
change_based_selection = True change_based_selection = True
alpha = 0.1 alpha = 0.1
wavelet=sym2 wavelet=sym2
level= None level= 4
accumulation = True accumulation = True
...@@ -31,3 +31,4 @@ addresses_filepath = ip_addr_6Machines.json ...@@ -31,3 +31,4 @@ addresses_filepath = ip_addr_6Machines.json
[SHARING] [SHARING]
sharing_package = decentralizepy.sharing.PartialModel sharing_package = decentralizepy.sharing.PartialModel
sharing_class = PartialModel sharing_class = PartialModel
alpha=0.1
...@@ -35,5 +35,5 @@ sharing_class = Wavelet ...@@ -35,5 +35,5 @@ sharing_class = Wavelet
change_based_selection = True change_based_selection = True
alpha = 0.1 alpha = 0.1
wavelet=sym2 wavelet=sym2
level= None level= 4
accumulation = True accumulation = True
...@@ -224,8 +224,11 @@ class FFT(PartialModel): ...@@ -224,8 +224,11 @@ class FFT(PartialModel):
with torch.no_grad(): with torch.no_grad():
total = None total = None
weight_total = 0 weight_total = 0
tensors_to_cat = [
flat_fft = self.change_transformer(self.init_model) v.data.flatten() for _, v in self.model.state_dict().items()
]
pre_share_model = torch.cat(tensors_to_cat, dim=0)
flat_fft = self.change_transformer(pre_share_model)
for i, n in enumerate(self.peer_deques): for i, n in enumerate(self.peer_deques):
degree, iteration, data = self.peer_deques[n].popleft() degree, iteration, data = self.peer_deques[n].popleft()
......
...@@ -155,7 +155,7 @@ class PartialModel(Sharing): ...@@ -155,7 +155,7 @@ class PartialModel(Sharing):
Model converted to a dict Model converted to a dict
""" """
if self.alpha > self.metadata_cap: # Share fully if self.alpha >= self.metadata_cap: # Share fully
return super().serialized_model() return super().serialized_model()
with torch.no_grad(): with torch.no_grad():
......
...@@ -257,7 +257,11 @@ class Wavelet(PartialModel): ...@@ -257,7 +257,11 @@ class Wavelet(PartialModel):
with torch.no_grad(): with torch.no_grad():
total = None total = None
weight_total = 0 weight_total = 0
wt_params = self.change_transformer(self.init_model) tensors_to_cat = [
v.data.flatten() for _, v in self.model.state_dict().items()
]
pre_share_model = torch.cat(tensors_to_cat, dim=0)
wt_params = self.change_transformer(pre_share_model)
for i, n in enumerate(self.peer_deques): for i, n in enumerate(self.peer_deques):
degree, iteration, data = self.peer_deques[n].popleft() degree, iteration, data = self.peer_deques[n].popleft()
logging.debug( logging.debug(
......
0% — Loading, or view the raw diff.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment