From 786dcb98dfa7ad5d080561c4c4fab77447f4bd4f Mon Sep 17 00:00:00 2001
From: Jeffrey Wigger <jeffrey.wigger@epfl.ch>
Date: Mon, 9 May 2022 16:17:50 +0200
Subject: [PATCH] Remove total_data and total_meta from sharing classes;
 fix subsampling with compression

Drop the per-instance total_data and total_meta byte counters, together
with the extra communication.encrypt() calls that fed them, from the
sharing implementations. In TCP, meta_len now defaults to 0 when a
message carries no "indices" entry, and the base Sharing class wraps
the serialized state dict under a "params" key so that the float
compressor is applied to it.

---
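Notes:

The TCP hunk fixes a branch where meta_len was bound only when the
message contained subsampling indices. A minimal standalone sketch of
the send path (pickle and compress_float appear in the diff context
below; the function shape and the compressor argument are illustrative
assumptions, not the actual TCP.py signature):

    import pickle

    def serialize(message, compressor):
        # Before this patch, meta_len was assigned only in the
        # "indices" branch, leaving it unbound for the statistics
        # bookkeeping on full-model messages.
        if "indices" in message:
            meta_len = len(pickle.dumps(message["indices"]))  # stats only
        else:
            meta_len = 0
        if "params" in message:
            # Float payloads are compressed in place before pickling.
            message["params"] = compressor.compress_float(message["params"])
        return pickle.dumps(message), meta_len
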
 src/decentralizepy/communication/TCP.py         | 2 ++
 src/decentralizepy/sharing/FFT.py               | 6 ------
 src/decentralizepy/sharing/PartialModel.py      | 1 -
 src/decentralizepy/sharing/Sharing.py           | 8 ++++----
 src/decentralizepy/sharing/SharingCentrality.py | 2 --
 src/decentralizepy/sharing/SubSampling.py       | 7 -------
 src/decentralizepy/sharing/Synchronous.py       | 3 ---
 src/decentralizepy/sharing/TopKParams.py        | 5 -----
 src/decentralizepy/sharing/Wavelet.py           | 1 -
 9 files changed, 6 insertions(+), 29 deletions(-)
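
The base Sharing class now exchanges a {"params": ...} envelope, so the
"params" branch sketched above also fires for full-model messages. A
minimal sketch of the new serialize/deserialize roundtrip, written as
free functions over a toy state dict (the real methods operate on
self.model.state_dict()):

    import torch

    def serialized_model(state_dict):
        # Wrap the per-layer numpy arrays under "params" so the
        # communication layer can locate and compress the payload.
        return {"params": {k: v.numpy() for k, v in state_dict.items()}}

    def deserialized_model(message):
        # Inverse transform: unwrap "params" and rebuild tensors.
        return {k: torch.from_numpy(v) for k, v in message["params"].items()}

    sd = {"w": torch.ones(3), "b": torch.zeros(1)}
    rt = deserialized_model(serialized_model(sd))
    assert all(torch.equal(sd[k], v) for k, v in rt.items())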

diff --git a/src/decentralizepy/communication/TCP.py b/src/decentralizepy/communication/TCP.py
index c609699..58cc55d 100644
--- a/src/decentralizepy/communication/TCP.py
+++ b/src/decentralizepy/communication/TCP.py
@@ -134,6 +134,8 @@ class TCP(Communication):
                 meta_len = len(
                     pickle.dumps(data["indices"])
                 )  # ONLY necessary for the statistics
+            else:
+                meta_len = 0
             if "params" in data:
                 data["params"] = self.compressor.compress_float(data["params"])
             output = pickle.dumps(data)
diff --git a/src/decentralizepy/sharing/FFT.py b/src/decentralizepy/sharing/FFT.py
index ba7b841..17650c1 100644
--- a/src/decentralizepy/sharing/FFT.py
+++ b/src/decentralizepy/sharing/FFT.py
@@ -159,7 +159,6 @@ class FFT(PartialModel):
         if self.alpha >= self.metadata_cap:  # Share fully
             data = self.pre_share_model_transformed
             m["params"] = data.numpy()
-            self.total_data += len(self.communication.encrypt(m["params"]))
             if self.model.accumulated_changes is not None:
                 self.model.accumulated_changes = torch.zeros_like(
                     self.model.accumulated_changes
@@ -200,11 +199,6 @@ class FFT(PartialModel):
             m["indices"] = indices.numpy().astype(np.int32)
             m["send_partial"] = True
 
-            self.total_data += len(self.communication.encrypt(m["params"]))
-            self.total_meta += len(self.communication.encrypt(m["indices"])) + len(
-                self.communication.encrypt(m["alpha"])
-            )
-
         return m
 
     def deserialized_model(self, m):
diff --git a/src/decentralizepy/sharing/PartialModel.py b/src/decentralizepy/sharing/PartialModel.py
index 1f1feca..3111e82 100644
--- a/src/decentralizepy/sharing/PartialModel.py
+++ b/src/decentralizepy/sharing/PartialModel.py
@@ -82,7 +82,6 @@ class PartialModel(Sharing):
         self.dict_ordered = dict_ordered
         self.save_shared = save_shared
         self.metadata_cap = metadata_cap
-        self.total_meta = 0
         self.accumulation = accumulation
         self.save_accumulated = conditional_value(save_accumulated, "", False)
         self.change_transformer = change_transformer
diff --git a/src/decentralizepy/sharing/Sharing.py b/src/decentralizepy/sharing/Sharing.py
index 22b4de9..7dc8852 100644
--- a/src/decentralizepy/sharing/Sharing.py
+++ b/src/decentralizepy/sharing/Sharing.py
@@ -46,7 +46,6 @@ class Sharing:
         self.dataset = dataset
         self.communication_round = 0
         self.log_dir = log_dir
-        self.total_data = 0
 
         self.peer_deques = dict()
         self.my_neighbors = self.graph.neighbors(self.uid)
@@ -99,8 +98,9 @@ class Sharing:
         m = dict()
         for key, val in self.model.state_dict().items():
             m[key] = val.numpy()
-            self.total_data += len(self.communication.encrypt(m[key]))
-        return m
+        data = dict()
+        data["params"] = m
+        return data
 
     def deserialized_model(self, m):
         """
@@ -118,7 +118,7 @@ class Sharing:
 
         """
         state_dict = dict()
-        for key, value in m.items():
+        for key, value in m["params"].items():
             state_dict[key] = torch.from_numpy(value)
         return state_dict
 
diff --git a/src/decentralizepy/sharing/SharingCentrality.py b/src/decentralizepy/sharing/SharingCentrality.py
index 580ce2a..f933a0e 100644
--- a/src/decentralizepy/sharing/SharingCentrality.py
+++ b/src/decentralizepy/sharing/SharingCentrality.py
@@ -46,7 +46,6 @@ class Sharing:
         self.dataset = dataset
         self.communication_round = 0
         self.log_dir = log_dir
-        self.total_data = 0
 
         self.peer_deques = dict()
         my_neighbors = self.graph.neighbors(self.uid)
@@ -101,7 +100,6 @@ class Sharing:
         m = dict()
         for key, val in self.model.state_dict().items():
             m[key] = val.numpy()
-            self.total_data += len(self.communication.encrypt(m[key]))
         return m
 
     def deserialized_model(self, m):
diff --git a/src/decentralizepy/sharing/SubSampling.py b/src/decentralizepy/sharing/SubSampling.py
index f8c8f50..b51cb07 100644
--- a/src/decentralizepy/sharing/SubSampling.py
+++ b/src/decentralizepy/sharing/SubSampling.py
@@ -72,7 +72,6 @@ class SubSampling(Sharing):
         self.dict_ordered = dict_ordered
         self.save_shared = save_shared
         self.metadata_cap = metadata_cap
-        self.total_meta = 0
 
         # self.random_seed_generator = torch.Generator()
         # # Will use the random device if supported by CPU, else uses the system time
@@ -216,12 +215,6 @@ class SubSampling(Sharing):
             m["alpha"] = alpha
             m["params"] = subsample.numpy()
 
-            # logging.info("Converted dictionary to json")
-            self.total_data += len(self.communication.encrypt(m["params"]))
-            self.total_meta += len(self.communication.encrypt(m["seed"])) + len(
-                self.communication.encrypt(m["alpha"])
-            )
-
             return m
 
     def deserialized_model(self, m):
diff --git a/src/decentralizepy/sharing/Synchronous.py b/src/decentralizepy/sharing/Synchronous.py
index 29d7f62..2c2d5e7 100644
--- a/src/decentralizepy/sharing/Synchronous.py
+++ b/src/decentralizepy/sharing/Synchronous.py
@@ -46,7 +46,6 @@ class Synchronous:
         self.dataset = dataset
         self.communication_round = 0
         self.log_dir = log_dir
-        self.total_data = 0
 
         self.peer_deques = dict()
         self.my_neighbors = self.graph.neighbors(self.uid)
@@ -104,7 +103,6 @@ class Synchronous:
         m = dict()
         for key, val in self.model.state_dict().items():
             m[key] = val - self.init_model[key]  # this is -lr*gradient
-        self.total_data += len(self.communication.encrypt(m))
         return m
 
     def serialized_model(self):
@@ -120,7 +118,6 @@ class Synchronous:
         m = dict()
         for key, val in self.model.state_dict().items():
             m[key] = val.clone().detach()
-        self.total_data += len(self.communication.encrypt(m))
         return m
 
     def deserialized_model(self, m):
diff --git a/src/decentralizepy/sharing/TopKParams.py b/src/decentralizepy/sharing/TopKParams.py
index 02531f1..f188179 100644
--- a/src/decentralizepy/sharing/TopKParams.py
+++ b/src/decentralizepy/sharing/TopKParams.py
@@ -68,7 +68,6 @@ class TopKParams(Sharing):
         self.dict_ordered = dict_ordered
         self.save_shared = save_shared
         self.metadata_cap = metadata_cap
-        self.total_meta = 0
 
         if self.save_shared:
             # Only save for 2 procs: Save space
@@ -171,10 +170,6 @@ class TopKParams(Sharing):
             #    m[key] = json.dumps(m[key])
 
             logging.info("Converted dictionary to json")
-            self.total_data += len(self.communication.encrypt(m["params"]))
-            self.total_meta += len(self.communication.encrypt(m["indices"])) + len(
-                self.communication.encrypt(m["offsets"])
-            )
 
             return m
 
diff --git a/src/decentralizepy/sharing/Wavelet.py b/src/decentralizepy/sharing/Wavelet.py
index b864f1f..91c97d0 100644
--- a/src/decentralizepy/sharing/Wavelet.py
+++ b/src/decentralizepy/sharing/Wavelet.py
@@ -181,7 +181,6 @@ class Wavelet(PartialModel):
         if self.alpha >= self.metadata_cap:  # Share fully
             data = self.pre_share_model_transformed
             m["params"] = data.numpy()
-            self.total_data += len(self.communication.encrypt(m["params"]))
             if self.model.accumulated_changes is not None:
                 self.model.accumulated_changes = torch.zeros_like(
                     self.model.accumulated_changes
-- 
GitLab