From 0fa9ba103b62257693b6f0b8c8f61bc5f283f27f Mon Sep 17 00:00:00 2001
From: Jeffrey Wigger <jeffrey.wigger@epfl.ch>
Date: Wed, 9 Mar 2022 15:57:16 +0100
Subject: [PATCH] Encode indices as np.int32

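The index tensors selected for sharing are int64 by default, so the index
metadata tracked in total_meta is twice as large as it needs to be. Cast the
index arrays to np.int32 before serialization and convert them back to
torch.long on deserialization, since PyTorch advanced indexing requires int64
indices. Also drop the now-unused base64 and pickle imports and the dead
tensor-reconstruction block in the FFT deserialization path.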
---
 src/decentralizepy/sharing/FFT.py         | 18 +++---------------
 src/decentralizepy/sharing/SubSampling.py |  5 ++---
 src/decentralizepy/sharing/TopK.py        |  5 +++--
 src/decentralizepy/sharing/TopKParams.py  |  5 +++--
 src/decentralizepy/sharing/Wavelet.py     |  7 +++----
 5 files changed, 14 insertions(+), 26 deletions(-)
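
A minimal sketch (not part of the patch) of the encode/decode round trip the
hunks below implement, using a toy top-k tensor; the dict layout mirrors the
m dict built in the serialization code:

    import numpy as np
    import torch

    # Sender side: torch.topk returns int64 indices; casting to int32
    # roughly halves the serialized index metadata.
    values, indices = torch.topk(torch.randn(10), k=3)
    m = {
        "params": values.numpy(),
        "indices": indices.numpy().astype(np.int32),
    }

    # Receiver side: PyTorch advanced indexing requires torch.long (int64),
    # so the int32 indices are promoted back when rebuilding the tensor.
    T = torch.zeros(10)
    T[torch.tensor(m["indices"], dtype=torch.long)] = torch.tensor(m["params"])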

diff --git a/src/decentralizepy/sharing/FFT.py b/src/decentralizepy/sharing/FFT.py
index 4a3ee36..1cc8382 100644
--- a/src/decentralizepy/sharing/FFT.py
+++ b/src/decentralizepy/sharing/FFT.py
@@ -1,13 +1,12 @@
-import base64
 import json
 import logging
 import os
-import pickle
 from pathlib import Path
 from time import time
 
+import numpy as np
 import torch
 import torch.fft as fft
 
 from decentralizepy.sharing.Sharing import Sharing
 
@@ -182,7 +181,7 @@ class FFT(Sharing):
 
             m["alpha"] = self.alpha
             m["params"] = topk.numpy()
-            m["indices"] = indices.numpy()
+            m["indices"] = indices.numpy().astype(np.int32)
 
             self.total_data += len(self.communication.encrypt(m["params"]))
             self.total_meta += len(self.communication.encrypt(m["indices"])) + len(
@@ -215,23 +214,12 @@ class FFT(Sharing):
             if not self.dict_ordered:
                 raise NotImplementedError
 
-            shapes = []
-            lens = []
-            tensors_to_cat = []
-            for _, v in state_dict.items():
-                shapes.append(v.shape)
-                t = v.flatten()
-                lens.append(t.shape[0])
-                tensors_to_cat.append(t)
-
-            T = torch.cat(tensors_to_cat, dim=0)
-
             indices = m["indices"]
             alpha = m["alpha"]
             params = m["params"]
 
             params_tensor = torch.tensor(params)
-            indices_tensor = torch.tensor(indices)
+            indices_tensor = torch.tensor(indices, dtype=torch.long)
             ret = dict()
             ret["indices"] = indices_tensor
             ret["params"] = params_tensor
diff --git a/src/decentralizepy/sharing/SubSampling.py b/src/decentralizepy/sharing/SubSampling.py
index 6fe3f93..1e956cd 100644
--- a/src/decentralizepy/sharing/SubSampling.py
+++ b/src/decentralizepy/sharing/SubSampling.py
@@ -1,11 +1,10 @@
-import base64
 import json
 import logging
 import os
-import pickle
 from pathlib import Path
 
+import numpy as np
 import torch
 
 from decentralizepy.sharing.Sharing import Sharing
 
@@ -203,7 +202,7 @@ class SubSampling(Sharing):
 
             m["seed"] = seed
             m["alpha"] = alpha
-            m["params"] = subsample.numpy()
+            m["params"] = subsample.numpy().astype(np.int32)
 
             # logging.info("Converted dictionary to json")
             self.total_data += len(self.communication.encrypt(m["params"]))
diff --git a/src/decentralizepy/sharing/TopK.py b/src/decentralizepy/sharing/TopK.py
index 47b4151..f50ba7e 100644
--- a/src/decentralizepy/sharing/TopK.py
+++ b/src/decentralizepy/sharing/TopK.py
@@ -3,6 +3,7 @@ import logging
 import os
 from pathlib import Path
 
+import numpy as np
 import torch
 
 from decentralizepy.sharing.Sharing import Sharing
@@ -166,7 +167,7 @@ class TopK(Sharing):
             if not self.dict_ordered:
                 raise NotImplementedError
 
-            m["indices"] = G_topk.numpy()
+            m["indices"] = G_topk.numpy().astype(np.int32)
             m["params"] = T_topk.numpy()
 
             assert len(m["indices"]) == len(m["params"])
@@ -214,7 +215,7 @@ class TopK(Sharing):
                 tensors_to_cat.append(t)
 
             T = torch.cat(tensors_to_cat, dim=0)
-            index_tensor = torch.tensor(m["indices"])
+            index_tensor = torch.tensor(m["indices"], dtype=torch.long)
             logging.debug("Original tensor: {}".format(T[index_tensor]))
             T[index_tensor] = torch.tensor(m["params"])
             logging.debug("Final tensor: {}".format(T[index_tensor]))
diff --git a/src/decentralizepy/sharing/TopKParams.py b/src/decentralizepy/sharing/TopKParams.py
index 3beb10f..c6535ce 100644
--- a/src/decentralizepy/sharing/TopKParams.py
+++ b/src/decentralizepy/sharing/TopKParams.py
@@ -3,6 +3,7 @@ import logging
 import os
 from pathlib import Path
 
+import numpy as np
 import torch
 
 from decentralizepy.sharing.Sharing import Sharing
@@ -157,7 +158,7 @@ class TopKParams(Sharing):
             if not self.dict_ordered:
                 raise NotImplementedError
 
-            m["indices"] = index.numpy()
+            m["indices"] = index.numpy().astype(np.int32)
             m["params"] = values.numpy()
             m["offsets"] = offsets
 
@@ -206,7 +207,7 @@ class TopKParams(Sharing):
             tensors_to_cat = []
             offsets = m["offsets"]
             params = torch.tensor(m["params"])
-            indices = torch.tensor(m["indices"])
+            indices = torch.tensor(m["indices"], dtype=torch.long)
 
             for i, (_, v) in enumerate(state_dict.items()):
                 shapes.append(v.shape)
diff --git a/src/decentralizepy/sharing/Wavelet.py b/src/decentralizepy/sharing/Wavelet.py
index a6cccaf..774dfe0 100644
--- a/src/decentralizepy/sharing/Wavelet.py
+++ b/src/decentralizepy/sharing/Wavelet.py
@@ -1,11 +1,10 @@
-import base64
 import json
 import logging
 import os
-import pickle
 from pathlib import Path
 from time import time
 
+import numpy as np
 import pywt
 import torch
 
@@ -206,7 +205,7 @@ class Wavelet(Sharing):
 
             m["params"] = topk.numpy()
 
-            m["indices"] = indices.numpy()
+            m["indices"] = indices.numpy().astype(np.int32)
 
             self.total_data += len(self.communication.encrypt(m["params"]))
             self.total_meta += len(self.communication.encrypt(m["indices"])) + len(
@@ -255,7 +254,7 @@ class Wavelet(Sharing):
             params = m["params"]
 
             params_tensor = torch.tensor(params)
-            indices_tensor = torch.tensor(indices)
+            indices_tensor = torch.tensor(indices, dtype=torch.long)
             ret = dict()
             ret["indices"] = indices_tensor
             ret["params"] = params_tensor
-- 
GitLab