From 3dc5b1745dfd55cbd6a7e5e5456dfe7264c313b3 Mon Sep 17 00:00:00 2001
From: Jeffrey Wigger <jeffrey.wigger@epfl.ch>
Date: Mon, 9 May 2022 17:14:17 +0200
Subject: [PATCH] Make sharing work with data compression

---
 src/decentralizepy/compression/EliasFpzip.py  |  2 +-
 .../compression/EliasFpzipLossy.py            |  2 +-
 src/decentralizepy/sharing/Sharing.py         | 28 +++++++++++++++----
 3 files changed, 24 insertions(+), 8 deletions(-)

diff --git a/src/decentralizepy/compression/EliasFpzip.py b/src/decentralizepy/compression/EliasFpzip.py
index dc1413a..0c82560 100644
--- a/src/decentralizepy/compression/EliasFpzip.py
+++ b/src/decentralizepy/compression/EliasFpzip.py
@@ -49,4 +49,4 @@ class EliasFpzip(Elias):
             decompressed data as array
 
         """
-        return fpzip.decompress(bytes, order="C")
+        return fpzip.decompress(bytes, order="C").squeeze()
diff --git a/src/decentralizepy/compression/EliasFpzipLossy.py b/src/decentralizepy/compression/EliasFpzipLossy.py
index 30e0111..617a78b 100644
--- a/src/decentralizepy/compression/EliasFpzipLossy.py
+++ b/src/decentralizepy/compression/EliasFpzipLossy.py
@@ -49,4 +49,4 @@ class EliasFpzipLossy(Elias):
             decompressed data as array
 
         """
-        return fpzip.decompress(bytes, order="C")
+        return fpzip.decompress(bytes, order="C").squeeze()
diff --git a/src/decentralizepy/sharing/Sharing.py b/src/decentralizepy/sharing/Sharing.py
index 7dc8852..0ad3927 100644
--- a/src/decentralizepy/sharing/Sharing.py
+++ b/src/decentralizepy/sharing/Sharing.py
@@ -52,6 +52,14 @@ class Sharing:
         for n in self.my_neighbors:
             self.peer_deques[n] = deque()
 
+        self.shapes = []
+        self.lens = []
+        with torch.no_grad():
+            for _, v in self.model.state_dict().items():
+                self.shapes.append(v.shape)
+                t = v.flatten().numpy()
+                self.lens.append(t.shape[0])
+
     def received_from_all(self):
         """
         Check if all neighbors have sent the current iteration
@@ -95,11 +103,14 @@ class Sharing:
             Model converted to dict
 
         """
-        m = dict()
-        for key, val in self.model.state_dict().items():
-            m[key] = val.numpy()
+        to_cat = []
+        with torch.no_grad():
+            for _, v in self.model.state_dict().items():
+                t = v.flatten()
+                to_cat.append(t)
+        flat = torch.cat(to_cat)
         data = dict()
-        data["params"] = m
+        data["params"] = flat.numpy()
         return data
 
     def deserialized_model(self, m):
@@ -118,8 +129,13 @@ class Sharing:
 
         """
         state_dict = dict()
-        for key, value in m["params"].items():
-            state_dict[key] = torch.from_numpy(value)
+        T = m["params"]
+        start_index = 0
+        for i, key in enumerate(self.model.state_dict()):
+            end_index = start_index + self.lens[i]
+            state_dict[key] = torch.from_numpy(T[start_index:end_index].reshape(self.shapes[i]))
+            start_index = end_index
+
         return state_dict
 
     def _pre_step(self):
-- 
GitLab