From 44f0d32d16079c02366189e60a347f70734d5386 Mon Sep 17 00:00:00 2001 From: Rishi Sharma <rishi.sharma@epfl.ch> Date: Thu, 17 Feb 2022 22:58:51 +0100 Subject: [PATCH] Add Roundrobin, fix formatting --- src/decentralizepy/sharing/RoundRobinPartial.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/decentralizepy/sharing/RoundRobinPartial.py b/src/decentralizepy/sharing/RoundRobinPartial.py index beca088..6b4f517 100644 --- a/src/decentralizepy/sharing/RoundRobinPartial.py +++ b/src/decentralizepy/sharing/RoundRobinPartial.py @@ -57,8 +57,11 @@ class RoundRobinPartial(Sharing): self.alpha = alpha random.seed(self.mapping.get_uid(rank, machine_id)) n_params = self.model.count_params() + logging.info("Total number of parameters: {}".format(n_params)) self.block_size = math.ceil(self.alpha * n_params) - self.num_blocks = n_params // self.block_size + logging.info("Block_size: {}".format(self.block_size)) + self.num_blocks = math.ceil(n_params / self.block_size) + logging.info("Total number of blocks: {}".format(n_params)) self.current_block = random.randint(0, self.num_blocks - 1) def serialized_model(self): @@ -81,7 +84,7 @@ class RoundRobinPartial(Sharing): block_end = min(T.shape[0], (self.current_block + 1) * self.block_size) self.current_block = (self.current_block + 1) % self.num_blocks T_send = T[block_start:block_end] - + logging.info("Range sending: {}-{}".format(block_start, block_end)) logging.info("Generating dictionary to send") m = dict() -- GitLab