diff --git a/eval/config_celeba_grow.ini b/eval/config_celeba_grow.ini new file mode 100644 index 0000000000000000000000000000000000000000..b824b194c868c4473520b16c655b5cb55721d383 --- /dev/null +++ b/eval/config_celeba_grow.ini @@ -0,0 +1,37 @@ +[DATASET] +dataset_package = decentralizepy.datasets.Celeba +dataset_class = Celeba +model_class = CNN +n_procs = 96 +images_dir = /home/risharma/leaf/data/celeba/data/raw/img_align_celeba +train_dir = /home/risharma/leaf/data/celeba/per_user_data/train +test_dir = /home/risharma/leaf/data/celeba/data/test +; python list of fractions below +sizes = + +[OPTIMIZER_PARAMS] +optimizer_package = torch.optim +optimizer_class = Adam +lr = 0.001 + +[TRAIN_PARAMS] +training_package = decentralizepy.training.GradientAccumulator +training_class = GradientAccumulator +epochs_per_round = 5 +batch_size = 512 +shuffle = True +loss_package = torch.nn +loss_class = CrossEntropyLoss + +[COMMUNICATION] +comm_package = decentralizepy.communication.TCP +comm_class = TCP +addresses_filepath = ip_addr_6Machines.json + +[SHARING] +sharing_package = decentralizepy.sharing.GrowingAlpha +sharing_class = GrowingAlpha +init_alpha=0.0 +max_alpha=1.0 +k=8 +metadata_cap=0.6 diff --git a/eval/config_femnist_grow.ini b/eval/config_femnist_grow.ini new file mode 100644 index 0000000000000000000000000000000000000000..c15d98a3081ffc26c1e9c27a88ffc0641c193afd --- /dev/null +++ b/eval/config_femnist_grow.ini @@ -0,0 +1,36 @@ +[DATASET] +dataset_package = decentralizepy.datasets.Femnist +dataset_class = Femnist +model_class = CNN +n_procs = 16 +train_dir = /home/risharma/leaf/data/femnist/per_user_data/train +test_dir = /home/risharma/leaf/data/femnist/data/test +; python list of fractions below +sizes = + +[OPTIMIZER_PARAMS] +optimizer_package = torch.optim +optimizer_class = Adam +lr = 0.001 + +[TRAIN_PARAMS] +training_package = decentralizepy.training.GradientAccumulator +training_class = GradientAccumulator +epochs_per_round = 5 +batch_size = 1024 +shuffle = 
True +loss_package = torch.nn +loss_class = CrossEntropyLoss + +[COMMUNICATION] +comm_package = decentralizepy.communication.TCP +comm_class = TCP +addresses_filepath = ip_addr_6Machines.json + +[SHARING] +sharing_package = decentralizepy.sharing.GrowingAlpha +sharing_class = GrowingAlpha +init_alpha=0.0 +max_alpha=1.0 +k=10 +metadata_cap=0.6 \ No newline at end of file diff --git a/src/decentralizepy/datasets/Celeba.py b/src/decentralizepy/datasets/Celeba.py index 748f59623e47fa5ce82ceb7968eabccc2651ca03..f648a7fa6c40419b9cbfd3631fdad471bfe5baae 100644 --- a/src/decentralizepy/datasets/Celeba.py +++ b/src/decentralizepy/datasets/Celeba.py @@ -320,6 +320,7 @@ class Celeba(Dataset): class CNN(Model): def __init__(self): super().__init__() + # 2.8k parameters self.conv1 = nn.Conv2d(CHANNELS, 32, 3, padding="same") self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(32, 32, 3, padding="same") diff --git a/src/decentralizepy/datasets/Femnist.py b/src/decentralizepy/datasets/Femnist.py index 7370a0d07ab77ce52e6db4dfa1cde8bd7255ba80..ab33880b5d7e2a1ff7e0588e2c27bd0481bf9340 100644 --- a/src/decentralizepy/datasets/Femnist.py +++ b/src/decentralizepy/datasets/Femnist.py @@ -337,6 +337,7 @@ class LogisticRegression(Model): class CNN(Model): def __init__(self): super().__init__() + # 1.6 million params self.conv1 = nn.Conv2d(1, 32, 5, padding=2) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(32, 64, 5, padding=2) diff --git a/src/decentralizepy/sharing/GrowingAlpha.py b/src/decentralizepy/sharing/GrowingAlpha.py new file mode 100644 index 0000000000000000000000000000000000000000..afe86768ededc073de60518373192513184ec03d --- /dev/null +++ b/src/decentralizepy/sharing/GrowingAlpha.py @@ -0,0 +1,68 @@ +import logging + +from decentralizepy.sharing.PartialModel import PartialModel +from decentralizepy.sharing.Sharing import Sharing + + +class GrowingAlpha(PartialModel): + def __init__( + self, + rank, + machine_id, + communication, + mapping, + graph, + model, + 
dataset, + log_dir, + init_alpha=0.0, + max_alpha=1.0, + k=10, + metadata_cap=0.6, + dict_ordered=True, + save_shared=False, + ): + super().__init__( + rank, + machine_id, + communication, + mapping, + graph, + model, + dataset, + log_dir, + init_alpha, + dict_ordered, + save_shared, + ) + self.init_alpha = init_alpha + self.max_alpha = max_alpha + self.k = k + self.metadata_cap = metadata_cap + self.base = None + + def step(self): + if (self.communication_round + 1) % self.k == 0: + self.alpha += (self.max_alpha - self.init_alpha) / self.k + + if self.alpha == 0.0: + logging.info("Not sending/receiving data (alpha=0.0)") + self.communication_round += 1 + return + + if self.alpha > self.metadata_cap: + if self.base is None: + self.base = Sharing( + self.rank, + self.machine_id, + self.communication, + self.mapping, + self.graph, + self.model, + self.dataset, + self.log_dir, + ) + self.base.communication_round = self.communication_round + self.base.step(); self.communication_round += 1 + else: + super().step() diff --git a/src/decentralizepy/sharing/PartialModel.py b/src/decentralizepy/sharing/PartialModel.py index bd01be3e66e0925e86ce7bb2474f8b6e4b0843fe..19fdfd12d58baaf7ea064d308c1db9ac0bc9e22b 100644 --- a/src/decentralizepy/sharing/PartialModel.py +++ b/src/decentralizepy/sharing/PartialModel.py @@ -29,7 +29,6 @@ class PartialModel(Sharing): ) self.alpha = alpha self.dict_ordered = dict_ordered - self.communication_round = 0 self.save_shared = save_shared # Only save for 2 procs @@ -106,8 +105,6 @@ class PartialModel(Sharing): logging.info("Converted dictionary to json") - self.communication_round += 1 - return m def deserialized_model(self, m): diff --git a/src/decentralizepy/sharing/Sharing.py b/src/decentralizepy/sharing/Sharing.py index 03e80fd52a8e4b14c381dd3b463a926e94741041..91d25d06badb20500e1dd8c797f4c1d84cca1b8b 100644 --- a/src/decentralizepy/sharing/Sharing.py +++ b/src/decentralizepy/sharing/Sharing.py @@ -23,6 +23,7 @@ class Sharing: self.model = model self.dataset = dataset 
self.log_dir = log_dir + self.communication_round = 0 self.peer_deques = dict() my_neighbors = self.graph.neighbors(self.uid) @@ -89,3 +90,5 @@ class Sharing: self.model.load_state_dict(total) logging.info("Model averaging complete") + + self.communication_round += 1