Skip to content
Snippets Groups Projects
Commit 3dc5b174 authored by Jeffrey Wigger's avatar Jeffrey Wigger
Browse files

sharing works now with data compression

parent 786dcb98
No related branches found
No related tags found
1 merge request!12Fixes to the previous PR
...@@ -49,4 +49,4 @@ class EliasFpzip(Elias): ...@@ -49,4 +49,4 @@ class EliasFpzip(Elias):
decompressed data as array decompressed data as array
""" """
return fpzip.decompress(bytes, order="C") return fpzip.decompress(bytes, order="C").squeeze()
...@@ -49,4 +49,4 @@ class EliasFpzipLossy(Elias): ...@@ -49,4 +49,4 @@ class EliasFpzipLossy(Elias):
decompressed data as array decompressed data as array
""" """
return fpzip.decompress(bytes, order="C") return fpzip.decompress(bytes, order="C").squeeze()
...@@ -52,6 +52,14 @@ class Sharing: ...@@ -52,6 +52,14 @@ class Sharing:
for n in self.my_neighbors: for n in self.my_neighbors:
self.peer_deques[n] = deque() self.peer_deques[n] = deque()
self.shapes = []
self.lens = []
with torch.no_grad():
for _, v in self.model.state_dict().items():
self.shapes.append(v.shape)
t = v.flatten().numpy()
self.lens.append(t.shape[0])
def received_from_all(self): def received_from_all(self):
""" """
Check if all neighbors have sent the current iteration Check if all neighbors have sent the current iteration
...@@ -95,11 +103,14 @@ class Sharing: ...@@ -95,11 +103,14 @@ class Sharing:
Model converted to dict Model converted to dict
""" """
m = dict() to_cat = []
for key, val in self.model.state_dict().items(): with torch.no_grad():
m[key] = val.numpy() for _, v in self.model.state_dict().items():
t = v.flatten()
to_cat.append(t)
flat = torch.cat(to_cat)
data = dict() data = dict()
data["params"] = m data["params"] = flat.numpy()
return data return data
def deserialized_model(self, m): def deserialized_model(self, m):
...@@ -118,8 +129,13 @@ class Sharing: ...@@ -118,8 +129,13 @@ class Sharing:
""" """
state_dict = dict() state_dict = dict()
for key, value in m["params"].items(): T = m["params"]
state_dict[key] = torch.from_numpy(value) start_index = 0
for i, key in enumerate(self.model.state_dict()):
end_index = start_index + self.lens[i]
state_dict[key] = torch.from_numpy(T[start_index:end_index].reshape(self.shapes[i]))
start_index = end_index
return state_dict return state_dict
def _pre_step(self): def _pre_step(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment