From e36ca70d26b054d9305711a0bafe58587bfd5df0 Mon Sep 17 00:00:00 2001 From: Rishi Sharma <rishi.sharma@epfl.ch> Date: Wed, 23 Feb 2022 21:39:23 +0100 Subject: [PATCH] Add ManualAdapt --- src/decentralizepy/sharing/ManualAdapt.py | 96 +++++++++++++++++++++++ 1 file changed, 96 insertions(+) create mode 100644 src/decentralizepy/sharing/ManualAdapt.py diff --git a/src/decentralizepy/sharing/ManualAdapt.py b/src/decentralizepy/sharing/ManualAdapt.py new file mode 100644 index 0000000..8c653a6 --- /dev/null +++ b/src/decentralizepy/sharing/ManualAdapt.py @@ -0,0 +1,96 @@ +import logging + +from decentralizepy.sharing.PartialModel import PartialModel + + +class ManualAdapt(PartialModel): + """ + This class implements the basic growing partial model sharing provided when and what alpha to set. + + """ + + def __init__( + self, + rank, + machine_id, + communication, + mapping, + graph, + model, + dataset, + log_dir, + change_alpha: list, + change_rounds: list, + dict_ordered=True, + save_shared=False, + metadata_cap=1.0, + ): + """ + Constructor + + Parameters + ---------- + rank : int + Local rank + machine_id : int + Global machine id + communication : decentralizepy.communication.Communication + Communication module used to send and receive messages + mapping : decentralizepy.mappings.Mapping + Mapping (rank, machine_id) -> uid + graph : decentralizepy.graphs.Graph + Graph reprensenting neighbors + model : decentralizepy.models.Model + Model to train + dataset : decentralizepy.datasets.Dataset + Dataset for sharing data. Not implemented yet! TODO + log_dir : str + Location to write shared_params (only writing for 2 procs per machine) + change_alpha : list + List of alphas to set. change_alpha[0] must be initial alpha. + change_rounds : list + List of iterations to change alpha. len(change_alpha) = len(change_rounds) + 1. + dict_ordered : bool + Specifies if the python dict maintains the order of insertion + save_shared : bool + Specifies if the indices of shared parameters should be logged + metadata_cap : float + Share full model when self.alpha > metadata_cap + + """ + super().__init__( + rank, + machine_id, + communication, + mapping, + graph, + model, + dataset, + log_dir, + change_alpha[0], + dict_ordered, + save_shared, + metadata_cap, + ) + self.change_alpha = change_alpha[1:] + self.change_rounds = change_rounds + + def step(self): + """ + Perform a sharing step. Implements D-PSGD with alpha manually given. + + """ + if ( + len(self.change_rounds) + and (self.communication_round + 1) == self.change_rounds[0] + ): + self.alpha = min(self.change_alpha[0], 1.00) + self.change_alpha = self.change_alpha[1:] + self.change_rounds = self.change_rounds[1:] + + if self.alpha == 0.0: + logging.info("Not sending/receiving data (alpha=0.0)") + self.communication_round += 1 + return + + super().step() -- GitLab