From 3dc5b1745dfd55cbd6a7e5e5456dfe7264c313b3 Mon Sep 17 00:00:00 2001 From: Jeffrey Wigger <jeffrey.wigger@epfl.ch> Date: Mon, 9 May 2022 17:14:17 +0200 Subject: [PATCH] sharing works now with data compression --- src/decentralizepy/compression/EliasFpzip.py | 2 +- .../compression/EliasFpzipLossy.py | 2 +- src/decentralizepy/sharing/Sharing.py | 28 +++++++++++++++---- 3 files changed, 24 insertions(+), 8 deletions(-) diff --git a/src/decentralizepy/compression/EliasFpzip.py b/src/decentralizepy/compression/EliasFpzip.py index dc1413a..0c82560 100644 --- a/src/decentralizepy/compression/EliasFpzip.py +++ b/src/decentralizepy/compression/EliasFpzip.py @@ -49,4 +49,4 @@ class EliasFpzip(Elias): decompressed data as array """ - return fpzip.decompress(bytes, order="C") + return fpzip.decompress(bytes, order="C").squeeze() diff --git a/src/decentralizepy/compression/EliasFpzipLossy.py b/src/decentralizepy/compression/EliasFpzipLossy.py index 30e0111..617a78b 100644 --- a/src/decentralizepy/compression/EliasFpzipLossy.py +++ b/src/decentralizepy/compression/EliasFpzipLossy.py @@ -49,4 +49,4 @@ class EliasFpzipLossy(Elias): decompressed data as array """ - return fpzip.decompress(bytes, order="C") + return fpzip.decompress(bytes, order="C").squeeze() diff --git a/src/decentralizepy/sharing/Sharing.py b/src/decentralizepy/sharing/Sharing.py index 7dc8852..0ad3927 100644 --- a/src/decentralizepy/sharing/Sharing.py +++ b/src/decentralizepy/sharing/Sharing.py @@ -52,6 +52,14 @@ class Sharing: for n in self.my_neighbors: self.peer_deques[n] = deque() + self.shapes = [] + self.lens = [] + with torch.no_grad(): + for _, v in self.model.state_dict().items(): + self.shapes.append(v.shape) + t = v.flatten().numpy() + self.lens.append(t.shape[0]) + def received_from_all(self): """ Check if all neighbors have sent the current iteration @@ -95,11 +103,14 @@ class Sharing: Model converted to dict """ - m = dict() - for key, val in self.model.state_dict().items(): - m[key] = val.numpy() + to_cat = [] + with torch.no_grad(): + for _, v in self.model.state_dict().items(): + t = v.flatten() + to_cat.append(t) + flat = torch.cat(to_cat) data = dict() - data["params"] = m + data["params"] = flat.numpy() return data def deserialized_model(self, m): @@ -118,8 +129,13 @@ class Sharing: """ state_dict = dict() - for key, value in m["params"].items(): - state_dict[key] = torch.from_numpy(value) + T = m["params"] + start_index = 0 + for i, key in enumerate(self.model.state_dict()): + end_index = start_index + self.lens[i] + state_dict[key] = torch.from_numpy(T[start_index:end_index].reshape(self.shapes[i])) + start_index = end_index + return state_dict def _pre_step(self): -- GitLab