diff --git a/src/decentralizepy/compression/EliasFpzip.py b/src/decentralizepy/compression/EliasFpzip.py index dc1413a22f98f2378401ecf173bc0404a51bc4a5..0c82560aae28efb64bdba2a9cf0abf81d7fcdda7 100644 --- a/src/decentralizepy/compression/EliasFpzip.py +++ b/src/decentralizepy/compression/EliasFpzip.py @@ -49,4 +49,4 @@ class EliasFpzip(Elias): decompressed data as array """ - return fpzip.decompress(bytes, order="C") + return fpzip.decompress(bytes, order="C").squeeze() diff --git a/src/decentralizepy/compression/EliasFpzipLossy.py b/src/decentralizepy/compression/EliasFpzipLossy.py index 30e0111be49451386a760dfc379c0f2f2e1a7c7e..617a78b2b27ff88bd57db29a3a65d71ad1e0a843 100644 --- a/src/decentralizepy/compression/EliasFpzipLossy.py +++ b/src/decentralizepy/compression/EliasFpzipLossy.py @@ -49,4 +49,4 @@ class EliasFpzipLossy(Elias): decompressed data as array """ - return fpzip.decompress(bytes, order="C") + return fpzip.decompress(bytes, order="C").squeeze() diff --git a/src/decentralizepy/sharing/Sharing.py b/src/decentralizepy/sharing/Sharing.py index 7dc8852797af138fb13d0c02ef0dd1e0d1dd6f19..0ad3927a6d7fb80acfa01e120fde1ef8db4a8d77 100644 --- a/src/decentralizepy/sharing/Sharing.py +++ b/src/decentralizepy/sharing/Sharing.py @@ -52,6 +52,14 @@ class Sharing: for n in self.my_neighbors: self.peer_deques[n] = deque() + self.shapes = [] + self.lens = [] + with torch.no_grad(): + for _, v in self.model.state_dict().items(): + self.shapes.append(v.shape) + t = v.flatten().numpy() + self.lens.append(t.shape[0]) + def received_from_all(self): """ Check if all neighbors have sent the current iteration @@ -95,11 +103,14 @@ class Sharing: Model converted to dict """ - m = dict() - for key, val in self.model.state_dict().items(): - m[key] = val.numpy() + to_cat = [] + with torch.no_grad(): + for _, v in self.model.state_dict().items(): + t = v.flatten() + to_cat.append(t) + flat = torch.cat(to_cat) data = dict() - data["params"] = m + data["params"] = flat.numpy() return data def deserialized_model(self, m): @@ -118,8 +129,13 @@ class Sharing: """ state_dict = dict() - for key, value in m["params"].items(): - state_dict[key] = torch.from_numpy(value) + T = m["params"] + start_index = 0 + for i, key in enumerate(self.model.state_dict()): + end_index = start_index + self.lens[i] + state_dict[key] = torch.from_numpy(T[start_index:end_index].reshape(self.shapes[i])) + start_index = end_index + return state_dict def _pre_step(self):