Commit b0cfe04e authored by Rishi Sharma

Improve Node

parent 922fdb47
@@ -15,7 +15,7 @@ Setting up decentralizepy
pip3 install --upgrade pip
pip install --upgrade pip
* Install decentralizepy for development/ ::
* Install decentralizepy for development. ::
pip3 install --editable .\[dev\]
......
This diff is collapsed.
@@ -15,6 +15,7 @@ class TCP(Communication):
TCP Communication API
"""
def addr(self, rank, machine_id):
"""
Returns TCP address of the process.
......
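The addr helper maps a (rank, machine_id) pair onto a transport endpoint for a process. Only the signature and the start of the docstring appear in this hunk; a minimal sketch of such a mapping, assuming a per-machine IP table and a rank-based port scheme (both are assumptions, not part of this commit):

    def addr(self, rank, machine_id):
        """
        Returns TCP address of the process.
        """
        # ip_addrs and the port offset are illustrative assumptions.
        machine_addr = self.ip_addrs[str(machine_id)]
        port = rank + 20000
        return "tcp://{}:{}".format(machine_addr, port)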
@@ -138,10 +138,13 @@ class Celeba(Dataset):
os.path.join(self.train_dir, cur_file)
)
for cur_client in clients:
logging.debug("Got data of client: {}".format(cur_client))
self.clients.append(cur_client)
my_train_data["x"].extend(self.process_x(train_data[cur_client]["x"]))
my_train_data["y"].extend(train_data[cur_client]["y"])
self.num_samples.append(len(train_data[cur_client]["y"]))
logging.debug("Initial shape of x: {}".format(np.array(my_train_data["x"], dtype=np.dtype("float32")).shape))
self.train_x = (
np.array(my_train_data["x"], dtype=np.dtype("float32"))
.reshape(-1, IMAGE_DIM, IMAGE_DIM, CHANNELS)
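The new debug line logs the shape of x just before the reshape that follows it. To illustrate that reshape, a small self-contained example, assuming IMAGE_DIM = 84 and CHANNELS = 3 (plausible LEAF Celeba preprocessing values; the constants themselves are not shown in this hunk):

    import numpy as np

    IMAGE_DIM, CHANNELS = 84, 3  # assumed values
    n_samples = 10
    # Each sample arrives as a flat vector of IMAGE_DIM * IMAGE_DIM * CHANNELS floats.
    flat = np.zeros((n_samples, IMAGE_DIM * IMAGE_DIM * CHANNELS), dtype=np.dtype("float32"))
    imgs = flat.reshape(-1, IMAGE_DIM, IMAGE_DIM, CHANNELS)
    print(imgs.shape)  # (10, 84, 84, 3)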
@@ -409,6 +412,7 @@ class CNN(Model):
Class for a CNN Model for Celeba
"""
def __init__(self):
"""
Constructor. Instantiates the CNN Model
......
@@ -413,6 +413,7 @@ class CNN(Model):
Class for a CNN Model for FEMNIST
"""
def __init__(self):
"""
Constructor. Instantiates the CNN Model
......
@@ -7,6 +7,7 @@ class Model(nn.Module):
More fields can be added here
"""
def __init__(self):
"""
Constructor
......
@@ -42,42 +42,20 @@ class Node:
plt.title(title)
plt.savefig(filename)
def instantiate(
self,
rank: int,
machine_id: int,
mapping: Mapping,
graph: Graph,
config,
iterations=1,
log_dir=".",
log_level=logging.INFO,
test_after=5,
*args
):
def init_log(self, log_dir, rank, log_level, force=True):
"""
Construct objects.
Instantiate Logging.
Parameters
----------
rank : int
Rank of process local to the machine
machine_id : int
Machine ID on which the process is running
n_procs_local : int
Number of processes on current machine
mapping : decentralizepy.mappings
The object containing the mapping rank <--> uid
graph : decentralizepy.graphs
The object containing the global graph
config : dict
A dictionary of configurations.
log_dir : str
Logging directory
rank : int
Rank of process local to the machine
log_level : logging.Level
One of DEBUG, INFO, WARNING, ERROR, CRITICAL
args : optional
Other arguments
force : bool
Argument to logging.basicConfig()
"""
log_file = os.path.join(log_dir, str(rank) + ".log")
@@ -88,20 +66,49 @@
force=True,
)
logging.info("Started process.")
def cache_fields(
self, rank, machine_id, mapping, graph, iterations, log_dir, test_after
):
"""
Instantiate object fields with arguments.
Parameters
----------
rank : int
Rank of process local to the machine
machine_id : int
Machine ID on which the process is running
mapping : decentralizepy.mappings
The object containing the mapping rank <--> uid
graph : decentralizepy.graphs
The object containing the global graph
iterations : int
Number of iterations to run
log_dir : str
Logging directory
test_after : int
Test the model after every test_after iterations
"""
self.rank = rank
self.machine_id = machine_id
self.graph = graph
self.mapping = mapping
self.uid = self.mapping.get_uid(rank, machine_id)
self.log_dir = log_dir
self.iterations = iterations
self.test_after = test_after
logging.debug("Rank: %d", self.rank)
logging.debug("type(graph): %s", str(type(self.rank)))
logging.debug("type(mapping): %s", str(type(self.mapping)))
dataset_configs = config["DATASET"]
def init_dataset_model(self, dataset_configs):
"""
Instantiate dataset and model from config.
Parameters
----------
dataset_configs : dict
Python dict containing dataset config params
"""
dataset_module = importlib.import_module(dataset_configs["dataset_package"])
self.dataset_class = getattr(dataset_module, dataset_configs["dataset_class"])
self.dataset_params = utils.remove_keys(
@@ -116,7 +123,16 @@ class Node:
self.model_class = getattr(dataset_module, dataset_configs["model_class"])
self.model = self.model_class()
optimizer_configs = config["OPTIMIZER_PARAMS"]
def init_optimizer(self, optimizer_configs):
"""
Instantiate optimizer from config.
Parameters
----------
optimizer_configs : dict
Python dict containing optimizer config params
"""
optimizer_module = importlib.import_module(
optimizer_configs["optimizer_package"]
)
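Every init_* helper repeats the same config-driven pattern: import the package named in the section, resolve the class with getattr, strip the bookkeeping keys via utils.remove_keys, and pass whatever remains as constructor kwargs. A standalone sketch of that pattern (the function name and usage are illustrative, not part of this commit):

    import importlib

    def load_class_from_config(configs, package_key, class_key):
        # Resolve the class named in a config section.
        module = importlib.import_module(configs[package_key])
        cls = getattr(module, configs[class_key])
        # Remaining keys become constructor kwargs (mirrors utils.remove_keys).
        params = {k: v for k, v in configs.items() if k not in (package_key, class_key)}
        return cls, params

    # Hypothetical usage for the optimizer section:
    # cls, params = load_class_from_config(optimizer_configs, "optimizer_package", "optimizer_class")
    # self.optimizer = cls(self.model.parameters(), **params)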
@@ -130,7 +146,16 @@ class Node:
self.model.parameters(), **self.optimizer_params
)
train_configs = config["TRAIN_PARAMS"]
def init_trainer(self, train_configs):
"""
Instantiate training module and loss from config.
Parameters
----------
train_configs : dict
Python dict containing training config params
"""
train_module = importlib.import_module(train_configs["training_package"])
train_class = getattr(train_module, train_configs["training_class"])
@@ -155,7 +180,16 @@ class Node:
self.model, self.optimizer, self.loss, **train_params
)
comm_configs = config["COMMUNICATION"]
def init_comm(self, comm_configs):
"""
Instantiate communication module from config.
Parameters
----------
comm_configs : dict
Python dict containing communication config params
"""
comm_module = importlib.import_module(comm_configs["comm_package"])
comm_class = getattr(comm_module, comm_configs["comm_class"])
comm_params = utils.remove_keys(comm_configs, ["comm_package", "comm_class"])
@@ -163,7 +197,16 @@ class Node:
self.rank, self.machine_id, self.mapping, self.graph.n_procs, **comm_params
)
sharing_configs = config["SHARING"]
def init_sharing(self, sharing_configs):
"""
Instantiate sharing module from config.
Parameters
----------
sharing_configs : dict
Python dict containing sharing config params
"""
sharing_package = importlib.import_module(sharing_configs["sharing_package"])
sharing_class = getattr(sharing_package, sharing_configs["sharing_class"])
sharing_params = utils.remove_keys(
@@ -181,9 +224,53 @@ class Node:
**sharing_params
)
self.iterations = iterations
self.test_after = test_after
self.log_dir = log_dir
def instantiate(
self,
rank: int,
machine_id: int,
mapping: Mapping,
graph: Graph,
config,
iterations=1,
log_dir=".",
log_level=logging.INFO,
test_after=5,
*args
):
"""
Construct objects.
Parameters
----------
rank : int
Rank of process local to the machine
machine_id : int
Machine ID on which the process is running
mapping : decentralizepy.mappings
The object containing the mapping rank <--> uid
graph : decentralizepy.graphs
The object containing the global graph
config : dict
A dictionary of configurations.
iterations : int
Number of iterations to run
log_dir : str
Logging directory
log_level : logging.Level
One of DEBUG, INFO, WARNING, ERROR, CRITICAL
test_after : int
Test the model after every test_after iterations
args : optional
Other arguments
"""
logging.info("Started process.")
self.cache_fields(
rank, machine_id, mapping, graph, iterations, log_dir, test_after
)
self.init_log(log_dir, rank, log_level)
self.init_dataset_model(config["DATASET"])
self.init_optimizer(config["OPTIMIZER_PARAMS"])
self.init_trainer(config["TRAIN_PARAMS"])
self.init_comm(config["COMMUNICATION"])
self.init_sharing(config["SHARING"])
def run(self):
"""
......
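Taken together, the refactored instantiate dispatches one config section to each helper. A hypothetical config, built here with configparser so the section and key names line up with the lookups above (only the section and key names come from this commit; every value is illustrative):

    import configparser

    config = configparser.ConfigParser()
    config.read_string("""
    [DATASET]
    dataset_package = decentralizepy.datasets.Femnist
    dataset_class = Femnist
    model_class = CNN

    [OPTIMIZER_PARAMS]
    optimizer_package = torch.optim
    optimizer_class = Adam

    [TRAIN_PARAMS]
    training_package = decentralizepy.training
    training_class = Training

    [COMMUNICATION]
    comm_package = decentralizepy.communication.TCP
    comm_class = TCP

    [SHARING]
    sharing_package = decentralizepy.sharing
    sharing_class = Sharing
    """)

    # instantiate() then hands each section to its helper:
    # node.init_dataset_model(config["DATASET"])
    # node.init_optimizer(config["OPTIMIZER_PARAMS"])
    # node.init_trainer(config["TRAIN_PARAMS"])
    # node.init_comm(config["COMMUNICATION"])
    # node.init_sharing(config["SHARING"])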
@@ -8,6 +8,7 @@ class GrowingAlpha(PartialModel):
This class implements basic growing partial model sharing using a linear function.
"""
def __init__(
self,
rank,
......
@@ -13,6 +13,7 @@ class PartialModel(Sharing):
This class implements the vanilla version of partial model sharing.
"""
def __init__(
self,
rank,
......
@@ -8,6 +8,7 @@ class GradientAccumulator(Training):
This class implements the training module, which also accumulates the gradients from each step in a list.
"""
def __init__(
self,
model,
......
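Per its docstring, GradientAccumulator extends Training by keeping each step's gradients in a list. The body is elided in this view; one plausible shape, assuming a trainstep override and an accumulated_gradients list created in __init__ (all names beyond the class itself are assumptions):

    class GradientAccumulator(Training):
        def trainstep(self, data, target):
            # Run the normal step, then snapshot this step's gradients
            # before the next zero_grad() discards them.
            loss_val = super().trainstep(data, target)
            self.accumulated_gradients.append(
                {k: v.grad.clone() for k, v in self.model.named_parameters() if v.grad is not None}
            )
            return loss_val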
@@ -109,7 +109,7 @@ class Training:
self.optimizer.step()
return loss_val.item()
def train_full(self, trainset):
def train_full(self, dataset):
"""
One training iteration that goes through the entire dataset
@@ -120,9 +120,12 @@
"""
for epoch in range(self.rounds):
trainset = dataset.get_trainset(self.batch_size, self.shuffle)
epoch_loss = 0.0
count = 0
for data, target in trainset:
logging.info("Starting minibatch {} with num_samples: {}".format(count, len(data)))
logging.info("Classes: {}".format(target))
epoch_loss += self.trainstep(data, target)
count += 1
logging.info("Epoch: {} loss: {}".format(epoch, epoch_loss / count))
@@ -137,13 +140,13 @@
The training dataset. Should implement get_trainset(batch_size, shuffle)
"""
trainset = dataset.get_trainset(self.batch_size, self.shuffle)
if self.full_epochs:
self.train_full(trainset)
self.train_full(dataset)
else:
iter_loss = 0.0
count = 0
trainset = dataset.get_trainset(self.batch_size, self.shuffle)
while count < self.rounds:
for data, target in trainset:
iter_loss += self.trainstep(data, target)
......
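The substantive change in this file: train_full now receives the dataset object rather than a pre-built trainset, and calls get_trainset once per epoch. That guarantees a fresh iterator for every epoch, which matters if get_trainset returns a one-shot iterable that is exhausted after a single pass, and re-applies shuffling each time. A minimal self-contained illustration with a hypothetical dataset wrapper:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    class ToyDataset:
        # Hypothetical stand-in implementing get_trainset(batch_size, shuffle).
        def __init__(self):
            self.data = TensorDataset(torch.arange(8.0).unsqueeze(1), torch.zeros(8))

        def get_trainset(self, batch_size, shuffle):
            return DataLoader(self.data, batch_size=batch_size, shuffle=shuffle)

    dataset = ToyDataset()
    for epoch in range(2):
        # Rebuilt (and reshuffled) every epoch, as in the new train_full.
        trainset = dataset.get_trainset(batch_size=4, shuffle=True)
        for data, target in trainset:
            pass  # self.trainstep(data, target) in the real loop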
@@ -16,7 +16,7 @@ def conditional_value(var, nul, default):
The null value. Assigns default if var == nul
default : any
The default value
Returns
-------
type(var)
......
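From this docstring, conditional_value falls back to default whenever var equals the null sentinel and returns var otherwise; the body is elided here, but it is presumably equivalent to:

    def conditional_value(var, nul, default):
        # Use default when var holds the designated null value.
        if var != nul:
            return var
        return default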