From 0fa9ba103b62257693b6f0b8c8f61bc5f283f27f Mon Sep 17 00:00:00 2001 From: Jeffrey Wigger <jeffrey.wigger@epfl.ch> Date: Wed, 9 Mar 2022 15:57:16 +0100 Subject: [PATCH] encoding indices as np.int32 --- src/decentralizepy/sharing/FFT.py | 18 +++--------------- src/decentralizepy/sharing/SubSampling.py | 5 ++--- src/decentralizepy/sharing/TopK.py | 5 +++-- src/decentralizepy/sharing/TopKParams.py | 5 +++-- src/decentralizepy/sharing/Wavelet.py | 7 +++---- 5 files changed, 14 insertions(+), 26 deletions(-) diff --git a/src/decentralizepy/sharing/FFT.py b/src/decentralizepy/sharing/FFT.py index 4a3ee36..1cc8382 100644 --- a/src/decentralizepy/sharing/FFT.py +++ b/src/decentralizepy/sharing/FFT.py @@ -1,13 +1,12 @@ -import base64 import json import logging import os -import pickle from pathlib import Path from time import time import torch import torch.fft as fft +import numpy as np from decentralizepy.sharing.Sharing import Sharing @@ -182,7 +181,7 @@ class FFT(Sharing): m["alpha"] = self.alpha m["params"] = topk.numpy() - m["indices"] = indices.numpy() + m["indices"] = indices.numpy().astype(np.int32) self.total_data += len(self.communication.encrypt(m["params"])) self.total_meta += len(self.communication.encrypt(m["indices"])) + len( @@ -215,23 +214,12 @@ class FFT(Sharing): if not self.dict_ordered: raise NotImplementedError - shapes = [] - lens = [] - tensors_to_cat = [] - for _, v in state_dict.items(): - shapes.append(v.shape) - t = v.flatten() - lens.append(t.shape[0]) - tensors_to_cat.append(t) - - T = torch.cat(tensors_to_cat, dim=0) - indices = m["indices"] alpha = m["alpha"] params = m["params"] params_tensor = torch.tensor(params) - indices_tensor = torch.tensor(indices) + indices_tensor = torch.tensor(indices, dtype=torch.long) ret = dict() ret["indices"] = indices_tensor ret["params"] = params_tensor diff --git a/src/decentralizepy/sharing/SubSampling.py b/src/decentralizepy/sharing/SubSampling.py index 6fe3f93..1e956cd 100644 --- a/src/decentralizepy/sharing/SubSampling.py +++ b/src/decentralizepy/sharing/SubSampling.py @@ -1,11 +1,10 @@ -import base64 import json import logging import os -import pickle from pathlib import Path import torch +import numpy as np from decentralizepy.sharing.Sharing import Sharing @@ -203,7 +202,7 @@ class SubSampling(Sharing): m["seed"] = seed m["alpha"] = alpha - m["params"] = subsample.numpy() + m["params"] = subsample.numpy().astype(np.int32) # logging.info("Converted dictionary to json") self.total_data += len(self.communication.encrypt(m["params"])) diff --git a/src/decentralizepy/sharing/TopK.py b/src/decentralizepy/sharing/TopK.py index 47b4151..f50ba7e 100644 --- a/src/decentralizepy/sharing/TopK.py +++ b/src/decentralizepy/sharing/TopK.py @@ -3,6 +3,7 @@ import logging import os from pathlib import Path +import numpy as np import torch from decentralizepy.sharing.Sharing import Sharing @@ -166,7 +167,7 @@ class TopK(Sharing): if not self.dict_ordered: raise NotImplementedError - m["indices"] = G_topk.numpy() + m["indices"] = G_topk.numpy().astype(np.int32) m["params"] = T_topk.numpy() assert len(m["indices"]) == len(m["params"]) @@ -214,7 +215,7 @@ class TopK(Sharing): tensors_to_cat.append(t) T = torch.cat(tensors_to_cat, dim=0) - index_tensor = torch.tensor(m["indices"]) + index_tensor = torch.tensor(m["indices"], dtype=torch.long) logging.debug("Original tensor: {}".format(T[index_tensor])) T[index_tensor] = torch.tensor(m["params"]) logging.debug("Final tensor: {}".format(T[index_tensor])) diff --git a/src/decentralizepy/sharing/TopKParams.py b/src/decentralizepy/sharing/TopKParams.py index 3beb10f..c6535ce 100644 --- a/src/decentralizepy/sharing/TopKParams.py +++ b/src/decentralizepy/sharing/TopKParams.py @@ -3,6 +3,7 @@ import logging import os from pathlib import Path +import numpy as np import torch from decentralizepy.sharing.Sharing import Sharing @@ -157,7 +158,7 @@ class TopKParams(Sharing): if not self.dict_ordered: raise NotImplementedError - m["indices"] = index.numpy() + m["indices"] = index.numpy().astype(np.int32) m["params"] = values.numpy() m["offsets"] = offsets @@ -206,7 +207,7 @@ class TopKParams(Sharing): tensors_to_cat = [] offsets = m["offsets"] params = torch.tensor(m["params"]) - indices = torch.tensor(m["indices"]) + indices = torch.tensor(m["indices"], dtype=torch.long) for i, (_, v) in enumerate(state_dict.items()): shapes.append(v.shape) diff --git a/src/decentralizepy/sharing/Wavelet.py b/src/decentralizepy/sharing/Wavelet.py index a6cccaf..774dfe0 100644 --- a/src/decentralizepy/sharing/Wavelet.py +++ b/src/decentralizepy/sharing/Wavelet.py @@ -1,11 +1,10 @@ -import base64 import json import logging import os -import pickle from pathlib import Path from time import time +import numpy as np import pywt import torch @@ -206,7 +205,7 @@ class Wavelet(Sharing): m["params"] = topk.numpy() - m["indices"] = indices.numpy() + m["indices"] = indices.numpy().astype(np.int32) self.total_data += len(self.communication.encrypt(m["params"])) self.total_meta += len(self.communication.encrypt(m["indices"])) + len( @@ -255,7 +254,7 @@ class Wavelet(Sharing): params = m["params"] params_tensor = torch.tensor(params) - indices_tensor = torch.tensor(indices) + indices_tensor = torch.tensor(indices, dtype=torch.long) ret = dict() ret["indices"] = indices_tensor ret["params"] = params_tensor -- GitLab