Commit 1d5c53ae authored by Rishi Sharma

Add growing model sharing

parent 59d47662
[DATASET]
dataset_package = decentralizepy.datasets.Celeba
dataset_class = Celeba
model_class = CNN
n_procs = 96
images_dir = /home/risharma/leaf/data/celeba/data/raw/img_align_celeba
train_dir = /home/risharma/leaf/data/celeba/per_user_data/train
test_dir = /home/risharma/leaf/data/celeba/data/test
; Python list of fractions below
sizes =
[OPTIMIZER_PARAMS]
optimizer_package = torch.optim
optimizer_class = Adam
lr = 0.001
[TRAIN_PARAMS]
training_package = decentralizepy.training.GradientAccumulator
training_class = GradientAccumulator
epochs_per_round = 5
batch_size = 512
shuffle = True
loss_package = torch.nn
loss_class = CrossEntropyLoss
[COMMUNICATION]
comm_package = decentralizepy.communication.TCP
comm_class = TCP
addresses_filepath = ip_addr_6Machines.json
[SHARING]
sharing_package = decentralizepy.sharing.GrowingAlpha
sharing_class = GrowingAlpha
init_alpha = 0.0
max_alpha = 1.0
k = 8
metadata_cap = 0.6
[DATASET]
dataset_package = decentralizepy.datasets.Femnist
dataset_class = Femnist
model_class = CNN
n_procs = 16
train_dir = /home/risharma/leaf/data/femnist/per_user_data/train
test_dir = /home/risharma/leaf/data/femnist/data/test
; Python list of fractions below
sizes =
[OPTIMIZER_PARAMS]
optimizer_package = torch.optim
optimizer_class = Adam
lr = 0.001
[TRAIN_PARAMS]
training_package = decentralizepy.training.GradientAccumulator
training_class = GradientAccumulator
epochs_per_round = 5
batch_size = 1024
shuffle = True
loss_package = torch.nn
loss_class = CrossEntropyLoss
[COMMUNICATION]
comm_package = decentralizepy.communication.TCP
comm_class = TCP
addresses_filepath = ip_addr_6Machines.json
[SHARING]
sharing_package = decentralizepy.sharing.GrowingAlpha
sharing_class = GrowingAlpha
init_alpha = 0.0
max_alpha = 1.0
k = 10
metadata_cap = 0.6
\ No newline at end of file
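The two new configuration files above differ only in dataset (Celeba vs. Femnist), process count, batch size, and the growth period k. As a sketch of how the [SHARING] keys map onto GrowingAlpha's keyword arguments, assuming a hypothetical parse_sharing_section helper (decentralizepy's real loader may differ):

import configparser

def parse_sharing_section(path):
    # Hypothetical helper: pull the [SHARING] section out of one of the
    # INI files above and coerce the numeric strings.
    config = configparser.ConfigParser()
    config.read(path)
    section = config["SHARING"]
    return {
        "init_alpha": section.getfloat("init_alpha"),      # start silent at 0.0
        "max_alpha": section.getfloat("max_alpha"),        # grow towards the full model
        "k": section.getint("k"),                          # one growth step every k rounds
        "metadata_cap": section.getfloat("metadata_cap"),  # hand off to full sharing past this
    }

# e.g. parse_sharing_section("celeba.ini")  (filename assumed for illustration)
# -> {'init_alpha': 0.0, 'max_alpha': 1.0, 'k': 8, 'metadata_cap': 0.6}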
@@ -320,6 +320,7 @@ class Celeba(Dataset):
 class CNN(Model):
     def __init__(self):
         super().__init__()
+        # 2.8k parameters
         self.conv1 = nn.Conv2d(CHANNELS, 32, 3, padding="same")
         self.pool = nn.MaxPool2d(2, 2)
         self.conv2 = nn.Conv2d(32, 32, 3, padding="same")
......
@@ -337,6 +337,7 @@ class LogisticRegression(Model):
 class CNN(Model):
     def __init__(self):
         super().__init__()
+        # 1.6 million params
         self.conv1 = nn.Conv2d(1, 32, 5, padding=2)
         self.pool = nn.MaxPool2d(2, 2)
         self.conv2 = nn.Conv2d(32, 64, 5, padding=2)
......
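Both hunks only add a parameter-count note to an existing model. Such counts can be verified by summing tensor sizes; below is a sketch over just the conv layers visible in the Femnist hunk (the elided fully connected layers account for most of the real ~1.6M total):

import torch.nn as nn

# Only the conv layers visible above; the truncated fc layers dominate
# the full model's parameter count.
visible = nn.Sequential(
    nn.Conv2d(1, 32, 5, padding=2),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(32, 64, 5, padding=2),
)

def count_params(model: nn.Module) -> int:
    # Sum the element counts of all trainable tensors.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(count_params(visible))  # 52096 = (1*32*25 + 32) + (32*64*25 + 64)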
import logging

from decentralizepy.sharing.PartialModel import PartialModel
from decentralizepy.sharing.Sharing import Sharing


class GrowingAlpha(PartialModel):
    def __init__(
        self,
        rank,
        machine_id,
        communication,
        mapping,
        graph,
        model,
        dataset,
        log_dir,
        init_alpha=0.0,
        max_alpha=1.0,
        k=10,
        metadata_cap=0.6,
        dict_ordered=True,
        save_shared=False,
    ):
        super().__init__(
            rank,
            machine_id,
            communication,
            mapping,
            graph,
            model,
            dataset,
            log_dir,
            init_alpha,
            dict_ordered,
            save_shared,
        )
        self.init_alpha = init_alpha
        self.max_alpha = max_alpha
        self.k = k
        self.metadata_cap = metadata_cap
        self.base = None

    def step(self):
        # Every k rounds, grow alpha by an equal increment so it reaches
        # max_alpha after k growth steps; clamp so it never overshoots.
        if (self.communication_round + 1) % self.k == 0:
            self.alpha = min(
                self.alpha + (self.max_alpha - self.init_alpha) / self.k,
                self.max_alpha,
            )

        # While alpha is still zero, skip communication entirely.
        if self.alpha == 0.0:
            logging.info("Not sending/receiving data (alpha=0.0)")
            self.communication_round += 1
            return

        # Past metadata_cap, partial sharing's index metadata no longer
        # pays for itself, so delegate to full model sharing.
        if self.alpha > self.metadata_cap:
            if self.base is None:
                self.base = Sharing(
                    self.rank,
                    self.machine_id,
                    self.communication,
                    self.mapping,
                    self.graph,
                    self.model,
                    self.dataset,
                    self.log_dir,
                )
            self.base.communication_round = self.communication_round
            self.base.step()
            # Sharing.step() advances its own counter; mirror it here so
            # the growth schedule keeps moving.
            self.communication_round = self.base.communication_round
        else:
            super().step()
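Taken together, step() has three regimes: the node stays silent while alpha is 0, shares the alpha-fraction selected by PartialModel while 0 < alpha <= metadata_cap, and falls back to plain full-model Sharing once alpha crosses the cap. A standalone sketch of the schedule the Celeba settings produce (sharing_mode is an illustrative helper mirroring the arithmetic above, not part of the commit):

def sharing_mode(round_, init_alpha=0.0, max_alpha=1.0, k=8, metadata_cap=0.6):
    # alpha grows by (max_alpha - init_alpha) / k once every k rounds,
    # exactly as in GrowingAlpha.step().
    growth_steps = (round_ + 1) // k
    alpha = min(init_alpha + growth_steps * (max_alpha - init_alpha) / k, max_alpha)
    if alpha == 0.0:
        return alpha, "silent"
    if alpha > metadata_cap:
        return alpha, "full sharing"
    return alpha, "partial sharing"

for r in (0, 7, 15, 31, 39, 63):
    print(r, sharing_mode(r))
# With k=8 (the Celeba config): silent until round 7, partial sharing up to
# alpha = 0.5, full sharing from round 39 (alpha = 0.625) onward.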
@@ -29,7 +29,6 @@ class PartialModel(Sharing):
         )
         self.alpha = alpha
         self.dict_ordered = dict_ordered
-        self.communication_round = 0
         self.save_shared = save_shared

         # Only save for 2 procs
@@ -106,8 +105,6 @@ class PartialModel(Sharing):
         logging.info("Converted dictionary to json")

-        self.communication_round += 1
-
         return m

     def deserialized_model(self, m):
......
@@ -23,6 +23,7 @@ class Sharing:
         self.model = model
         self.dataset = dataset
         self.log_dir = log_dir
+        self.communication_round = 0
         self.peer_deques = dict()

         my_neighbors = self.graph.neighbors(self.uid)
@@ -89,3 +90,5 @@ class Sharing:
         self.model.load_state_dict(total)

         logging.info("Model averaging complete")
+
+        self.communication_round += 1
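The final two hunks hoist the communication_round counter from PartialModel into the Sharing base class and increment it once per Sharing.step(). This is what lets GrowingAlpha delegate rounds to a plain Sharing instance: wrapper and delegate now track rounds through the same attribute. A stripped-down sketch of that hand-off pattern (SharingStub and GrowingStub are illustrative stand-ins, not the real decentralizepy classes):

class SharingStub:
    # Counter now lives in the base class.
    def __init__(self):
        self.communication_round = 0

    def step(self):
        # ... exchange and average full models ...
        self.communication_round += 1

class GrowingStub(SharingStub):
    def __init__(self, metadata_cap=0.6):
        super().__init__()
        self.metadata_cap = metadata_cap
        self.alpha = 0.7  # past the cap, so rounds are delegated
        self.base = None

    def step(self):
        if self.alpha > self.metadata_cap:
            if self.base is None:
                self.base = SharingStub()
            # Sync the delegate's counter in, step, then mirror it back.
            self.base.communication_round = self.communication_round
            self.base.step()
            self.communication_round = self.base.communication_round
        else:
            super().step()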