Compare revisions
......@@ -26,14 +26,19 @@ class Node:
self.communication.init_connection(neighbor)
self.communication.send(neighbor, {"HELLO": self.uid, "CHANNEL": "CONNECT"})
def receive_channel(self, channel):
def receive_channel(self, channel, block=True):
if channel not in self.message_queue:
self.message_queue[channel] = deque()
if len(self.message_queue[channel]) > 0:
return self.message_queue[channel].popleft()
else:
sender, recv = self.communication.receive()
x = self.communication.receive(block=block)
if x is None:
assert not block
return None
sender, recv = x
logging.info(
"Received some message from {} with CHANNEL: {}".format(
sender, recv["CHANNEL"]
......@@ -44,7 +49,11 @@ class Node:
if recv["CHANNEL"] not in self.message_queue:
self.message_queue[recv["CHANNEL"]] = deque()
self.message_queue[recv["CHANNEL"]].append((sender, recv))
sender, recv = self.communication.receive()
x = self.communication.receive(block=block)
if x is None:
assert not block
return None
sender, recv = x
logging.info(
"Received some message from {} with CHANNEL: {}".format(
sender, recv["CHANNEL"]
......
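The new `block` flag lets callers poll a channel without stalling: with `block=False`, `receive_channel` returns None when nothing is queued and nothing has arrived, instead of waiting. A minimal polling sketch, assuming `node` is any Node subclass with this change applied (the channel name is illustrative):

    pending = []
    while True:
        msg = node.receive_channel("DPSGD", block=False)
        if msg is None:
            # nothing queued for this channel and nothing arrived right now
            break
        sender, data = msg
        pending.append((sender, data))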
......@@ -9,17 +9,14 @@ import torch
from matplotlib import pyplot as plt
from decentralizepy import utils
from decentralizepy.communication.TCP import TCP
from decentralizepy.graphs.Graph import Graph
from decentralizepy.graphs.Star import Star
from decentralizepy.mappings.Mapping import Mapping
from decentralizepy.node.Node import Node
from decentralizepy.train_test_evaluation import TrainTestHelper
class DPSGDNodeWithParameterServer(Node):
class OverlayNode(Node):
"""
This class defines the node for DPSGD
This class defines the node on an overlay graph
"""
......@@ -49,6 +46,15 @@ class DPSGDNodeWithParameterServer(Node):
plt.title(title)
plt.savefig(filename)
def get_neighbors(self, node=None):
return self.my_neighbors
def receive_DPSGD(self):
return self.receive_channel("DPSGD")
def build_topology(self):
pass
def run(self):
"""
Start the decentralized learning
......@@ -59,51 +65,57 @@ class DPSGDNodeWithParameterServer(Node):
rounds_to_train_evaluate = self.train_evaluate_after
global_epoch = 1
change = 1
if self.uid == 0:
dataset = self.dataset
if self.centralized_train_eval:
dataset_params_copy = self.dataset_params.copy()
if "sizes" in dataset_params_copy:
del dataset_params_copy["sizes"]
self.whole_dataset = self.dataset_class(
self.rank,
self.machine_id,
self.mapping,
sizes=[1.0],
**dataset_params_copy
)
dataset = self.whole_dataset
if self.centralized_test_eval:
tthelper = TrainTestHelper(
dataset, # self.whole_dataset,
# self.model_test, # todo: this only works if eval_train is set to false
self.model,
self.loss,
self.weights_store_dir,
self.mapping.get_n_procs(),
self.trainer,
self.testing_comm,
self.star,
self.threads_per_proc,
eval_train=self.centralized_train_eval,
)
self.connect_neighbors()
logging.info("Connected to all neighbors")
self.build_topology()
logging.info("OutNodes: {}".format(self.out_edges))
logging.info("InNodes: {}".format(self.in_edges))
logging.info("Unifying edges")
self.out_edges = self.out_edges.union(self.in_edges)
self.my_neighbors = self.in_edges = set(self.out_edges)
logging.info("Total number of neighbor: {}".format(len(self.my_neighbors)))
for iteration in range(self.iterations):
logging.info("Starting training iteration: %d", iteration)
rounds_to_train_evaluate -= 1
rounds_to_test -= 1
self.iteration = iteration
self.trainer.train(self.dataset)
to_send = self.sharing.get_data_to_send()
to_send["CHANNEL"] = "DPSGD"
to_send["degree"] = len(self.in_edges)
assert len(self.out_edges) != 0
assert len(self.in_edges) != 0
self.communication.send(self.parameter_server_uid, to_send)
for neighbor in self.out_edges:
self.communication.send(neighbor, to_send)
sender, data = self.receive_channel("GRADS")
del data["CHANNEL"]
while not self.received_from_all():
sender, data = self.receive_DPSGD()
logging.info(
"Received Model from {} of iteration {}".format(
sender, data["iteration"]
)
)
if sender not in self.peer_deques:
self.peer_deques[sender] = deque()
self.peer_deques[sender].append(data)
averaging_deque = dict()
for neighbor in self.in_edges:
averaging_deque[neighbor] = self.peer_deques[neighbor]
self.model.load_state_dict(data)
self.sharing._post_step()
self.sharing.communication_round += 1
self.sharing._averaging(averaging_deque)
if self.reset_optimizer:
self.optimizer = self.optimizer_class(
......@@ -113,8 +125,7 @@ class DPSGDNodeWithParameterServer(Node):
if iteration:
with open(
os.path.join(
self.log_dir, "{}_results.json".format(self.rank)),
os.path.join(self.log_dir, "{}_results.json".format(self.rank)),
"r",
) as inf:
results_dict = json.load(inf)
......@@ -126,12 +137,9 @@ class DPSGDNodeWithParameterServer(Node):
"total_bytes": {},
"total_meta": {},
"total_data_per_n": {},
"grad_mean": {},
"grad_std": {},
}
results_dict["total_bytes"][iteration
+ 1] = self.communication.total_bytes
results_dict["total_bytes"][iteration + 1] = self.communication.total_bytes
if hasattr(self.communication, "total_meta"):
results_dict["total_meta"][
......@@ -141,14 +149,8 @@ class DPSGDNodeWithParameterServer(Node):
results_dict["total_data_per_n"][
iteration + 1
] = self.communication.total_data
if hasattr(self.sharing, "mean"):
results_dict["grad_mean"][iteration + 1] = self.sharing.mean
if hasattr(self.sharing, "std"):
results_dict["grad_std"][iteration + 1] = self.sharing.std
rounds_to_train_evaluate -= 1
if rounds_to_train_evaluate == 0 and not self.centralized_train_eval:
if rounds_to_train_evaluate == 0:
logging.info("Evaluating on train set.")
rounds_to_train_evaluate = self.train_evaluate_after * change
loss_after_sharing = self.trainer.eval_loss(self.dataset)
......@@ -158,30 +160,15 @@ class DPSGDNodeWithParameterServer(Node):
"train_loss",
"Training Loss",
"Communication Rounds",
os.path.join(
self.log_dir, "{}_train_loss.png".format(self.rank)),
os.path.join(self.log_dir, "{}_train_loss.png".format(self.rank)),
)
rounds_to_test -= 1
if self.dataset.__testing__ and rounds_to_test == 0:
rounds_to_test = self.test_after * change
if self.centralized_test_eval:
if self.uid == 0:
ta, tl, trl = tthelper.train_test_evaluation(iteration)
results_dict["test_acc"][iteration + 1] = ta
results_dict["test_loss"][iteration + 1] = tl
if trl is not None:
results_dict["train_loss"][iteration + 1] = trl
else:
self.testing_comm.send(0, self.model.get_weights())
sender, data = self.testing_comm.receive()
assert sender == 0 and data == "finished"
else:
logging.info("Evaluating on test set.")
ta, tl = self.dataset.test(self.model, self.loss)
results_dict["test_acc"][iteration + 1] = ta
results_dict["test_loss"][iteration + 1] = tl
logging.info("Evaluating on test set.")
ta, tl = self.dataset.test(self.model, self.loss)
results_dict["test_acc"][iteration + 1] = ta
results_dict["test_loss"][iteration + 1] = tl
if global_epoch == 49:
change *= 2
......@@ -189,24 +176,22 @@ class DPSGDNodeWithParameterServer(Node):
global_epoch += change
with open(
os.path.join(
self.log_dir, "{}_results.json".format(self.rank)), "w"
os.path.join(self.log_dir, "{}_results.json".format(self.rank)), "w"
) as of:
json.dump(results_dict, of)
if self.model.shared_parameters_counter is not None:
logging.info("Saving the shared parameter counts")
with open(
os.path.join(
self.log_dir, "{}_shared_parameters.json".format(self.rank)
),
"w",
) as of:
json.dump(
self.model.shared_parameters_counter.numpy().tolist(), of)
self.disconnect_parameter_server()
logging.info("Storing final weight")
self.model.dump_weights(self.weights_store_dir, self.uid, iteration)
logging.info("Server disconnected. Process complete!")
# if self.model.shared_parameters_counter is not None:
# logging.info("Saving the shared parameter counts")
# with open(
# os.path.join(
# self.log_dir, "{}_shared_parameters.json".format(self.rank)
# ),
# "w",
# ) as of:
# json.dump(self.model.shared_parameters_counter.numpy().tolist(), of)
self.disconnect_neighbors()
# logging.info("Storing final weight")
# self.model.dump_weights(self.weights_store_dir, self.uid, iteration)
logging.info("All neighbors disconnected. Process complete!")
def cache_fields(
self,
......@@ -220,8 +205,6 @@ class DPSGDNodeWithParameterServer(Node):
test_after,
train_evaluate_after,
reset_optimizer,
centralized_train_eval,
centralized_test_eval,
):
"""
Instantiate object field with arguments.
......@@ -248,10 +231,6 @@ class DPSGDNodeWithParameterServer(Node):
Number of iterations after which the train loss is calculated
reset_optimizer : int
1 if optimizer should be reset every communication round, else 0
centralized_train_eval : bool
If set the train set evaluation happens at the node with uid 0
centralized_test_eval : bool
If set the train set evaluation happens at the node with uid 0
"""
self.rank = rank
self.machine_id = machine_id
......@@ -264,17 +243,12 @@ class DPSGDNodeWithParameterServer(Node):
self.test_after = test_after
self.train_evaluate_after = train_evaluate_after
self.reset_optimizer = reset_optimizer
self.centralized_train_eval = centralized_train_eval
self.centralized_test_eval = centralized_test_eval
self.sent_disconnections = False
logging.info("Rank: %d", self.rank)
logging.info("type(graph): %s", str(type(self.rank)))
logging.info("type(mapping): %s", str(type(self.mapping)))
if centralized_test_eval or centralized_train_eval:
self.star = Star(self.mapping.get_n_procs())
def init_comm(self, comm_configs):
"""
Instantiate communication module from config.
......@@ -287,20 +261,8 @@ class DPSGDNodeWithParameterServer(Node):
"""
comm_module = importlib.import_module(comm_configs["comm_package"])
comm_class = getattr(comm_module, comm_configs["comm_class"])
comm_params = utils.remove_keys(
comm_configs, ["comm_package", "comm_class"])
comm_params = utils.remove_keys(comm_configs, ["comm_package", "comm_class"])
self.addresses_filepath = comm_params.get("addresses_filepath", None)
if self.centralized_test_eval:
self.testing_comm = TCP(
self.rank,
self.machine_id,
self.mapping,
self.star.n_procs,
self.addresses_filepath,
offset=self.star.n_procs,
)
self.testing_comm.connect_neighbors(self.star.neighbors(self.uid))
self.communication = comm_class(
self.rank, self.machine_id, self.mapping, self.graph.n_procs, **comm_params
)
......@@ -319,8 +281,6 @@ class DPSGDNodeWithParameterServer(Node):
test_after=5,
train_evaluate_after=1,
reset_optimizer=1,
centralized_train_eval=False,
centralized_test_eval=True,
*args
):
"""
......@@ -352,10 +312,6 @@ class DPSGDNodeWithParameterServer(Node):
Number of iterations after which the train loss is calculated
reset_optimizer : int
1 if optimizer should be reset every communication round, else 0
centralized_train_eval : bool
If set the train set evaluation happens at the node with uid 0
centralized_test_eval : bool
If set the train set evaluation happens at the node with uid 0
args : optional
Other arguments
......@@ -375,8 +331,6 @@ class DPSGDNodeWithParameterServer(Node):
test_after,
train_evaluate_after,
reset_optimizer,
centralized_train_eval,
centralized_test_eval,
)
self.init_dataset_model(config["DATASET"])
self.init_optimizer(config["OPTIMIZER_PARAMS"])
......@@ -389,6 +343,23 @@ class DPSGDNodeWithParameterServer(Node):
self.my_neighbors = self.graph.neighbors(self.uid)
self.init_sharing(config["SHARING"])
self.peer_deques = dict()
self.connect_neighbors()
def received_from_all(self):
"""
Check if all neighbors have sent the current iteration
Returns
-------
bool
True if required data has been received, False otherwise
"""
for k in self.in_edges:
if (k not in self.peer_deques) or len(self.peer_deques[k]) == 0:
return False
return True
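For orientation, one OverlayNode communication round (see run() above) boils down to the following condensed sketch; it only uses methods and fields defined in this class:

    self.trainer.train(self.dataset)                      # local training step
    to_send = self.sharing.get_data_to_send()
    to_send["CHANNEL"] = "DPSGD"
    to_send["degree"] = len(self.in_edges)
    for neighbor in self.out_edges:                       # push along out-edges
        self.communication.send(neighbor, to_send)
    while not self.received_from_all():                   # collect along in-edges
        sender, data = self.receive_DPSGD()
        self.peer_deques.setdefault(sender, deque()).append(data)
    averaging_deque = {n: self.peer_deques[n] for n in self.in_edges}
    self.sharing._averaging(averaging_deque)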
def __init__(
self,
......@@ -404,9 +375,6 @@ class DPSGDNodeWithParameterServer(Node):
test_after=5,
train_evaluate_after=1,
reset_optimizer=1,
centralized_train_eval=0,
centralized_test_eval=1,
parameter_server_uid=-1,
*args
):
"""
......@@ -450,21 +418,10 @@ class DPSGDNodeWithParameterServer(Node):
Number of iterations after which the train loss is calculated
reset_optimizer : int
1 if optimizer should be reset every communication round, else 0
centralized_train_eval : int
If set then the train set evaluation happens at the node with uid 0.
Note: If it is True then centralized_test_eval needs to be true as well!
centralized_test_eval : int
If set then the trainset evaluation happens at the node with uid 0
parameter_server_uid: int
The parameter server's uid
args : optional
Other arguments
"""
centralized_train_eval = centralized_train_eval == 1
centralized_test_eval = centralized_test_eval == 1
# If centralized_train_eval is True then centralized_test_eval needs to be true as well!
assert not centralized_train_eval or centralized_test_eval
total_threads = os.cpu_count()
self.threads_per_proc = max(
......@@ -485,42 +442,13 @@ class DPSGDNodeWithParameterServer(Node):
test_after,
train_evaluate_after,
reset_optimizer,
centralized_train_eval == 1,
centralized_test_eval == 1,
*args
)
self.in_edges = set()
self.out_edges = set()
logging.info(
"Each proc uses %d threads out of %d.", self.threads_per_proc, total_threads
)
self.message_queue["PEERS"] = deque()
self.parameter_server_uid = parameter_server_uid
self.connect_neighbor(self.parameter_server_uid)
self.wait_for_hello(self.parameter_server_uid)
self.run()
def disconnect_parameter_server(self):
"""
Disconnects from the parameter server. Sends BYE.
Raises
------
RuntimeError
If received another message while waiting for BYEs
"""
if not self.sent_disconnections:
logging.info("Disconnecting parameter server.")
self.communication.send(
self.parameter_server_uid, {
"BYE": self.uid, "CHANNEL": "SERVER_REQUEST"}
)
self.sent_disconnections = True
self.barrier.remove(self.parameter_server_uid)
while len(self.barrier):
sender, _ = self.receive_disconnect()
self.barrier.remove(sender)
import importlib
import logging
import os
from collections import deque
from decentralizepy import utils
from decentralizepy.graphs.Graph import Graph
from decentralizepy.mappings.Mapping import Mapping
from decentralizepy.node.Node import Node
class ParameterServer(Node):
"""
This class defines the parameter serving service
"""
def init_log(self, log_dir, log_level, force=True):
"""
Instantiate Logging.
Parameters
----------
log_dir : str
Logging directory
rank : int
Rank of process local to the machine
log_level : logging.Level
One of DEBUG, INFO, WARNING, ERROR, CRITICAL
force : bool
Argument to logging.basicConfig()
"""
log_file = os.path.join(log_dir, "ParameterServer.log")
logging.basicConfig(
filename=log_file,
format="[%(asctime)s][%(module)s][%(levelname)s] %(message)s",
level=log_level,
force=force,
)
def cache_fields(
self,
rank,
machine_id,
mapping,
graph,
iterations,
log_dir,
):
"""
Instantiate object field with arguments.
Parameters
----------
rank : int
Rank of process local to the machine
machine_id : int
Machine ID on which the process is running
mapping : decentralizepy.mappings
The object containing the mapping rank <--> uid
graph : decentralizepy.graphs
The object containing the global graph
iterations : int
Number of iterations (communication steps) for which the model should be trained
log_dir : str
Logging directory
"""
self.rank = rank
self.machine_id = machine_id
self.graph = graph
self.mapping = mapping
self.uid = self.mapping.get_uid(rank, machine_id)
self.log_dir = log_dir
self.iterations = iterations
self.sent_disconnections = False
logging.info("Rank: %d", self.rank)
logging.info("type(graph): %s", str(type(self.rank)))
logging.info("type(mapping): %s", str(type(self.mapping)))
def init_comm(self, comm_configs):
"""
Instantiate communication module from config.
Parameters
----------
comm_configs : dict
Python dict containing communication config params
"""
comm_module = importlib.import_module(comm_configs["comm_package"])
comm_class = getattr(comm_module, comm_configs["comm_class"])
comm_params = utils.remove_keys(
comm_configs, ["comm_package", "comm_class"])
self.addresses_filepath = comm_params.get("addresses_filepath", None)
self.communication = comm_class(
self.rank, self.machine_id, self.mapping, self.graph.n_procs, **comm_params
)
def instantiate(
self,
rank: int,
machine_id: int,
mapping: Mapping,
graph: Graph,
config,
iterations=1,
log_dir=".",
log_level=logging.INFO,
*args
):
"""
Construct objects.
Parameters
----------
rank : int
Rank of process local to the machine
machine_id : int
Machine ID on which the process is running
mapping : decentralizepy.mappings
The object containing the mapping rank <--> uid
graph : decentralizepy.graphs
The object containing the global graph
config : dict
A dictionary of configurations.
iterations : int
Number of iterations (communication steps) for which the model should be trained
log_dir : str
Logging directory
log_level : logging.Level
One of DEBUG, INFO, WARNING, ERROR, CRITICAL
args : optional
Other arguments
"""
logging.info("Started process.")
self.init_log(log_dir, log_level)
self.cache_fields(
rank,
machine_id,
mapping,
graph,
iterations,
log_dir,
)
self.message_queue = dict()
self.barrier = set()
self.peer_deques = dict()
self.init_dataset_model(config["DATASET"])
self.init_comm(config["COMMUNICATION"])
self.my_neighbors = self.graph.get_all_nodes()
self.connect_neighbors()
self.init_sharing(config["SHARING"])
def receive_server_request(self):
return self.receive_channel("SERVER_REQUEST")
def received_from_all(self):
"""
Check if all neighbors have sent the current iteration
Returns
-------
bool
True if required data has been received, False otherwise
"""
for k in self.my_neighbors:
if (k not in self.peer_deques) or len(self.peer_deques[k]) == 0:
return False
return True
def run(self):
"""
Start the parameter-serving service.
"""
for iteration in range(self.iterations):
self.iteration = iteration
# reset deques after each iteration
self.peer_deques = dict()
while not self.received_from_all():
sender, data = self.receive_channel("DPSGD")
if sender not in self.peer_deques:
self.peer_deques[sender] = deque()
self.peer_deques[sender].append(data)
logging.info("Received from everybody")
averaging_deque = dict()
total = dict()
for neighbor in self.my_neighbors:
averaging_deque[neighbor] = self.peer_deques[neighbor]
for i, n in enumerate(averaging_deque):
data = averaging_deque[n].popleft()
degree, iteration = data["degree"], data["iteration"]
del data["degree"]
del data["iteration"]
del data["CHANNEL"]
data = self.sharing.deserialized_model(data)
for key, value in data.items():
if key in total:
total[key] += value
else:
total[key] = value
for key, value in total.items():
total[key] = total[key] / len(averaging_deque)
to_send = total
to_send["CHANNEL"] = "GRADS"
for neighbor in self.my_neighbors:
self.communication.send(neighbor, to_send)
while len(self.barrier):
sender, data = self.receive_server_request()
if "BYE" in data:
logging.debug("Received {} from {}".format("BYE", sender))
self.barrier.remove(sender)
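For context, a worker talking to this server sends its model on the DPSGD channel each round and then waits for the averaged parameters on GRADS. A hedged sketch of that client side, mirroring the node code removed in this revision:

    to_send = self.sharing.get_data_to_send()
    to_send["CHANNEL"] = "DPSGD"
    to_send["degree"] = len(self.in_edges)
    self.communication.send(self.parameter_server_uid, to_send)
    sender, data = self.receive_channel("GRADS")   # averaged model from the server
    del data["CHANNEL"]
    self.model.load_state_dict(data)
    # on shutdown, the worker sends:
    # {"BYE": self.uid, "CHANNEL": "SERVER_REQUEST"}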
def __init__(
self,
rank: int,
machine_id: int,
mapping: Mapping,
graph: Graph,
config,
iterations=1,
log_dir=".",
log_level=logging.INFO,
*args
):
"""
Constructor
Parameters
----------
rank : int
Rank of process local to the machine
machine_id : int
Machine ID on which the process is running
mapping : decentralizepy.mappings
The object containing the mapping rank <--> uid
graph : decentralizepy.graphs
The object containing the global graph
config : dict
A dictionary of configurations. Must contain the following:
[DATASET]
dataset_package
dataset_class
model_class
[OPTIMIZER_PARAMS]
optimizer_package
optimizer_class
[TRAIN_PARAMS]
training_package = decentralizepy.training.Training
training_class = Training
epochs_per_round = 25
batch_size = 64
iterations : int
Number of iterations (communication steps) for which the model should be trained
log_dir : str
Logging directory
log_level : logging.Level
One of DEBUG, INFO, WARNING, ERROR, CRITICAL
args : optional
Other arguments
"""
super().__init__(
rank,
machine_id,
mapping,
graph,
config,
iterations,
log_dir,
log_level,
*args
)
self.instantiate(
rank,
machine_id,
mapping,
graph,
config,
iterations,
log_dir,
log_level,
*args
)
self.run()
logging.info("Parameter Server exiting")
import logging
from collections import OrderedDict
import torch
from decentralizepy.sharing.Sharing import Sharing
def zeros_like_state_dict(state_dict):
"""
Creates a new state dictionary such that it has the same
layers (names and sizes) as the input state dictionary, but all values
are zero
Parameters
----------
state_dict: dict[str, torch.Tensor]
"""
result_dict = OrderedDict()
for tensor_name, tensor_values in state_dict.items():
result_dict[tensor_name] = torch.zeros_like(tensor_values)
return result_dict
def get_dict_keys_and_check_matching(dict_1, dict_2):
"""
Checks if the keys of the two dictionaries match and
returns them if they do, otherwise raises ValueError
Parameters
----------
dict_1: dict
dict_2: dict
Raises
------
ValueError
If the keys of the dictionaries don't match
"""
keys = dict_1.keys()
if set(keys).difference(set(dict_2.keys())):
raise ValueError("Dictionaries must have matching keys")
return keys
def subtract_state_dicts(_1, _2):
"""
Subtracts one state dictionary from another
Parameters
----------
_1: dict[str, torch.Tensor]
Minuend state dictionary
_2: dict[str, torch.Tensor]
Subtrahend state dictionary
Raises
------
ValueError
If the keys of the state dictionaries don't match
"""
keys = get_dict_keys_and_check_matching(_1, _2)
result_dict = OrderedDict()
for key in keys:
# Size checking is done by torch during the subtraction
result_dict[key] = _1[key] - _2[key]
return result_dict
def self_add_state_dict(_1, _2, constant=1.0):
"""
Scales one state dictionary by a constant and
adds it in place to another, minimizing the copies
created. Equivalent to the operation `_1 += constant * _2`
Parameters
----------
_1: dict[str, torch.Tensor]
State dictionary
_2: dict[str, torch.Tensor]
State dictionary
constant: float
Constant to scale _2 with
Raises
------
ValueError
If the keys of the state dictionaries don't match
"""
keys = get_dict_keys_and_check_matching(_1, _2)
for key in keys:
# Size checking is done by torch during the addition
_1[key] += constant * _2[key]
def flatten_state_dict(state_dict):
"""
Transforms state dictionary into a flat tensor
by flattening and concatenating tensors of the
state dictionary.
Note: changes made to the result won't affect the state dictionary
Parameters
----------
state_dict : OrderedDict[str, torch.tensor]
A state dictionary to flatten
"""
return torch.cat([tensor.flatten() for tensor in state_dict.values()], axis=0)
def unflatten_state_dict(flat_tensor, reference_state_dict):
"""
Transforms a flat tensor into a state dictionary
by using another state dictionary as a reference
for size and names of the tensors. Assumes
that the number of elements of the flat tensor
is the same as the number of elements in the
reference state dictionary.
This operation is the inverse of flatten_state_dict
Note: changes made to the result will affect the flat tensor
Parameters
----------
flat_tensor : torch.tensor
A 1-dim tensor
reference_state_dict : OrderedDict[str, torch.tensor]
A state dictionary used as a reference for tensor names
and shapes of the result
"""
result = OrderedDict()
start_index = 0
for tensor_name, tensor in reference_state_dict.items():
end_index = start_index + tensor.numel()
result[tensor_name] = flat_tensor[start_index:end_index].reshape(tensor.shape)
start_index = end_index
return result
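A small illustrative roundtrip for the two helpers above (the toy state dictionary is hypothetical):

    ref = OrderedDict(
        [("w", torch.arange(6.0).reshape(2, 3)), ("b", torch.zeros(3))]
    )
    flat = flatten_state_dict(ref)             # 1-dim tensor with 9 elements
    rebuilt = unflatten_state_dict(flat, ref)  # same keys and shapes as ref
    assert all(torch.equal(ref[k], rebuilt[k]) for k in ref)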
def serialize_sparse_tensor(tensor):
"""
Serializes a sparse tensor by flattening it and
returning the values and indices of its non-zero entries
Parameters
----------
tensor: torch.Tensor
"""
flat = tensor.flatten()
indices = flat.nonzero(as_tuple=True)[0]
values = flat[indices]
return values, indices
def deserialize_sparse_tensor(values, indices, shape):
"""
Deserializes a tensor from the non-zero values and indices of its
flattened form, given the original shape of the tensor.
Parameters
----------
values: torch.Tensor
Non-zero entries of flattened original tensor
indices: torch.Tensor
Respective indices of non-zero entries of flattened original tensor
shape: torch.Size or tuple[*int]
Shape of the original tensor
"""
result = torch.zeros(size=shape)
if len(indices):
flat_result = result.flatten()
flat_result[indices] = values
return result
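An illustrative roundtrip for these two functions (toy values):

    t = torch.tensor([[0.0, 1.5], [0.0, -2.0]])
    values, indices = serialize_sparse_tensor(t)   # values [1.5, -2.0], indices [1, 3]
    restored = deserialize_sparse_tensor(values, indices, t.shape)
    assert torch.equal(restored, t)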
def topk_sparsification_tensor(tensor, alpha):
"""
Performs topk sparsification of a tensor and returns
the input tensor, transformed in place.
Note: no copies are created; the input tensor is modified directly
Parameters
----------
tensor : torch.tensor
A tensor to perform the sparsification on
alpha : float
Fraction of top-magnitude values to keep
"""
tensor_abs = tensor.abs()
flat_abs_tensor = tensor_abs.flatten()
numel_to_keep = round(alpha * flat_abs_tensor.numel())
if numel_to_keep > 0:
cutoff_value, _ = torch.kthvalue(-flat_abs_tensor, numel_to_keep)
tensor[tensor_abs < -cutoff_value] = 0
return tensor
def topk_sparsification(state_dict, alpha):
"""
Performs topk sparsification of a state_dict
applying it over all elements together.
Note: the changes made to the result won't affect
the input state dictionary
Parameters
----------
state_dict : OrderedDict[str, torch.tensor]
A state dictionary to perform the sparsification on
alpha : float
Fraction of top-magnitude values to keep
"""
flat_tensor = flatten_state_dict(state_dict)
return unflatten_state_dict(
topk_sparsification_tensor(flat_tensor, alpha), state_dict
)
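A quick illustrative example of the joint sparsification: with this hypothetical one-layer state dictionary, alpha = 0.5 keeps the two largest-magnitude entries and zeroes the rest.

    sd = OrderedDict([("w", torch.tensor([0.1, -3.0, 0.2, 4.0]))])
    sparse_sd = topk_sparsification(sd, alpha=0.5)
    # sparse_sd["w"] is tensor([ 0., -3.,  0.,  4.])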
def serialize_sparse_state_dict(state_dict):
with torch.no_grad():
concatted_tensors = torch.cat(
[tensor.flatten() for tensor in state_dict.values()], axis=0
)
return serialize_sparse_tensor(concatted_tensors)
def deserialize_sparse_state_dict(values, indices, reference_state_dict):
with torch.no_grad():
keys = []
lens = []
shapes = []
for k, v in reference_state_dict.items():
keys.append(k)
shapes.append(v.shape)
lens.append(v.numel())
total_num_el = sum(lens)
T = deserialize_sparse_tensor(values, indices, (total_num_el,))
result_state_dict = OrderedDict()
start_index = 0
for i, k in enumerate(keys):
end_index = start_index + lens[i]
result_state_dict[k] = T[start_index:end_index].reshape(shapes[i])
start_index = end_index
return result_state_dict
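An illustrative roundtrip, as used by the Choco class below (toy state dictionary):

    ref = OrderedDict([("w", torch.tensor([[1.0, 0.0], [0.0, 2.0]]))])
    values, indices = serialize_sparse_state_dict(ref)
    restored = deserialize_sparse_state_dict(values, indices, ref)
    assert torch.equal(restored["w"], ref["w"])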
class Choco(Sharing):
"""
API defining who to share with and what, and what to do on receiving
"""
def __init__(
self,
rank,
machine_id,
communication,
mapping,
graph,
model,
dataset,
log_dir,
step_size,
alpha,
compress=False,
compression_package=None,
compression_class=None,
):
"""
Constructor
Parameters
----------
rank : int
Local rank
machine_id : int
Global machine id
communication : decentralizepy.communication.Communication
Communication module used to send and receive messages
mapping : decentralizepy.mappings.Mapping
Mapping (rank, machine_id) -> uid
graph : decentralizepy.graphs.Graph
Graph representing neighbors
model : decentralizepy.models.Model
Model to train
dataset : decentralizepy.datasets.Dataset
Dataset for sharing data. Not implemented yet! TODO
log_dir : str
Location to write shared_params (only writing for 2 procs per machine)
step_size : float
Step size from the formulation of Choco
alpha : float
Fraction of elements to keep during topk sparsification
"""
super().__init__(
rank,
machine_id,
communication,
mapping,
graph,
model,
dataset,
log_dir,
compress,
compression_package,
compression_class,
)
self.step_size = step_size
self.alpha = alpha
logging.info(
"type(step_size): %s, value: %s",
str(type(self.step_size)),
str(self.step_size),
)
logging.info(
"type(alpha): %s, value: %s", str(type(self.alpha)), str(self.alpha)
)
model_state_dict = model.state_dict()
self.model_hat = zeros_like_state_dict(model_state_dict)
self.s = zeros_like_state_dict(model_state_dict)
self.my_q = None
def compress_data(self, data):
result = dict(data)
if self.compress:
if "indices" in result:
result["indices"] = self.compressor.compress(result["indices"])
if "params" in result:
result["params"] = self.compressor.compress_float(result["params"])
return result
def decompress_data(self, data):
if self.compress:
if "indices" in data:
data["indices"] = self.compressor.decompress(data["indices"])
if "params" in data:
data["params"] = self.compressor.decompress_float(data["params"])
return data
def _compress(self, x):
return topk_sparsification(x, self.alpha)
def _pre_step(self):
"""
Called at the beginning of step.
"""
with torch.no_grad():
self.my_q = self._compress(
subtract_state_dicts(self.model.state_dict(), self.model_hat)
)
def serialized_model(self):
"""
Convert self q to a dictionary. Here we can choose how much to share
Returns
-------
dict
Model converted to dict
"""
values, indices = serialize_sparse_state_dict(self.my_q)
data = dict()
data["params"] = values.numpy()
data["indices"] = indices.numpy()
data["send_partial"] = True
return self.compress_data(data)
def deserialized_model(self, m):
"""
Convert received dict to state_dict.
Parameters
----------
m : dict
received dict
Returns
-------
state_dict
state_dict of received
"""
if "send_partial" not in m:
return super().deserialized_model(m)
with torch.no_grad():
m = self.decompress_data(m)
indices = torch.tensor(m["indices"], dtype=torch.long)
values = torch.tensor(m["params"])
return deserialize_sparse_state_dict(
values, indices, self.model.state_dict()
)
def _averaging(self, peer_deques):
"""
Averages the received model with the local model
"""
with torch.no_grad():
self_add_state_dict(self.model_hat, self.my_q) # x_hat = q_self + x_hat
weight_total = 0
for i, n in enumerate(peer_deques):
data = peer_deques[n].popleft()
degree, iteration = data["degree"], data["iteration"]
del data["degree"]
del data["iteration"]
del data["CHANNEL"]
logging.debug(
"Averaging model from neighbor {} of iteration {}".format(
n, iteration
)
)
data = self.deserialized_model(data)
# Metro-Hastings
weight = 1 / (max(len(peer_deques), degree) + 1)
weight_total += weight
for key, value in data.items():
if key in self.s:
self.s[key] += value * weight
# else:
# self.s[key] = value * weight
for key, value in self.my_q.items():
self.s[key] += (1 - weight_total) * value # Metro-Hastings
total = self.model.state_dict().copy()
self_add_state_dict(
total,
subtract_state_dicts(self.s, self.model_hat),
constant=self.step_size,
) # x = x + gamma * (s - x_hat)
self.model.load_state_dict(total)
self._post_step()
self.communication_round += 1
def _averaging_server(self, peer_deques):
"""
Averages the received models of all working nodes
"""
raise NotImplementedError()
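Putting the methods above together, one CHOCO communication round on node i roughly performs the following (a condensed reading of _pre_step and _averaging; w_ij are the Metropolis-Hastings weights and gamma is step_size):

    # q_i      = topk(x_i - x_hat_i)                       (_pre_step)
    # x_hat_i += q_i                                       (first line of _averaging)
    # s_i     += sum_j w_ij * q_j   (j over in-neighbors and self)
    # x_i      = x_i + gamma * (s_i - x_hat_i)             (then load_state_dict)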
......@@ -49,7 +49,7 @@ class FFT(PartialModel):
metadata_cap=1.0,
change_based_selection=True,
save_accumulated="",
accumulation=True,
accumulation=False,
accumulate_averaging_changes=False,
compress=False,
compression_package=None,
......
......@@ -171,7 +171,8 @@ class Sharing:
)
)
data = self.deserialized_model(data)
weight = 1 / (max(len(peer_deques), degree) + 1) # Metro-Hastings
# Metro-Hastings
weight = 1 / (max(len(peer_deques), degree) + 1)
weight_total += weight
for key, value in data.items():
if key in total:
......@@ -194,3 +195,34 @@ class Sharing:
data["degree"] = len(all_neighbors)
data["iteration"] = self.communication_round
return data
def _averaging_server(self, peer_deques):
"""
Averages the received models of all working nodes
"""
with torch.no_grad():
total = dict()
for i, n in enumerate(peer_deques):
data = peer_deques[n].popleft()
degree, iteration = data["degree"], data["iteration"]
del data["degree"]
del data["iteration"]
del data["CHANNEL"]
logging.debug(
"Averaging model from neighbor {} of iteration {}".format(
n, iteration
)
)
data = self.deserialized_model(data)
weight = 1 / len(peer_deques)
for key, value in data.items():
if key in total:
total[key] += weight * value
else:
total[key] = weight * value
self.model.load_state_dict(total)
self._post_step()
self.communication_round += 1
return total
......@@ -203,7 +203,8 @@ class Wavelet(PartialModel):
shapes[k] = list(v.shape)
shared_params["shapes"] = shapes
shared_params[self.communication_round] = indices.tolist() # is slow
# is slow
shared_params[self.communication_round] = indices.tolist()
shared_params["alpha"] = self.alpha
......@@ -296,7 +297,8 @@ class Wavelet(PartialModel):
else:
topkwf = params.reshape(self.wt_shape)
weight = 1 / (max(len(peer_deques), degree) + 1) # Metro-Hastings
# Metro-Hastings
weight = 1 / (max(len(peer_deques), degree) + 1)
weight_total += weight
if total is None:
total = weight * topkwf
......@@ -325,3 +327,59 @@ class Wavelet(PartialModel):
self.model.load_state_dict(std_dict)
self._post_step()
self.communication_round += 1
def _averaging_server(self, peer_deques):
"""
Averages the received models of all working nodes
"""
with torch.no_grad():
total = None
wt_params = self.pre_share_model_transformed
for i, n in enumerate(peer_deques):
data = peer_deques[n].popleft()
degree, iteration = data["degree"], data["iteration"]
del data["degree"]
del data["iteration"]
del data["CHANNEL"]
logging.debug(
"Averaging model from neighbor {} of iteration {}".format(
n, iteration
)
)
data = self.deserialized_model(data)
params = data["params"]
if "indices" in data:
indices = data["indices"]
# use local data to complement
topkwf = wt_params.clone().detach()
topkwf[indices] = params
topkwf = topkwf.reshape(self.wt_shape)
else:
topkwf = params.reshape(self.wt_shape)
weight = 1 / len(peer_deques)
if total is None:
total = weight * topkwf
else:
total += weight * topkwf
avg_wf_params = pywt.array_to_coeffs(
total.numpy(), self.coeff_slices, output_format="wavedec"
)
reverse_total = torch.from_numpy(
pywt.waverec(avg_wf_params, wavelet=self.wavelet)
)
start_index = 0
std_dict = {}
for i, key in enumerate(self.model.state_dict()):
end_index = start_index + self.lens[i]
std_dict[key] = reverse_total[start_index:end_index].reshape(
self.shapes[i]
)
start_index = end_index
self.model.load_state_dict(std_dict)
self._post_step()
self.communication_round += 1
import logging
import os
from pathlib import Path
import numpy as np
import torch
from decentralizepy.graphs import Graph
class TrainTestHelper:
def __init__(
self,
dataset,
model,
loss,
dir,
n_procs,
trainer,
comm,
graph: Graph,
threads_per_proc,
eval_train=False,
):
self.dataset = dataset
self.model = model
self.loss = loss
self.dir = Path(dir)
self.n_procs = n_procs
self.trainer = trainer
self.comm = comm
self.star = graph
self.threads_per_proc = threads_per_proc
self.eval_train = eval_train
def train_test_evaluation(self, iteration):
with torch.no_grad():
self.model.eval()
total_threads = os.cpu_count()
torch.set_num_threads(total_threads)
neighbors = self.star.neighbors(0)
state_dict_copy = {}
shapes = []
lens = []
to_cat = []
for key, val in self.model.state_dict().items():
shapes.append(val.shape)
clone_val = val.clone().detach()
state_dict_copy[key] = clone_val
flat = clone_val.flatten()
to_cat.append(flat)
lens.append(flat.shape[0])
my_weight = torch.cat(to_cat)
weights = [my_weight]
# TODO: add weight of node 0
for i in neighbors:
sender, data = self.comm.receive()
logging.info(f"Received weight from {sender}")
weights.append(data)
# averaging
average_weight = np.mean([w.numpy() for w in weights], axis=0)
start_index = 0
average_weight_dict = {}
for i, key in enumerate(state_dict_copy):
end_index = start_index + lens[i]
average_weight_dict[key] = torch.from_numpy(
average_weight[start_index:end_index].reshape(shapes[i])
)
start_index = end_index
self.model.load_state_dict(average_weight_dict)
if self.eval_train:
logging.info("Evaluating on train set.")
trl = self.trainer.eval_loss(self.dataset)
else:
trl = None
logging.info("Evaluating on test set.")
ta, tl = self.dataset.test(self.model, self.loss)
# reload old weight
self.model.load_state_dict(state_dict_copy)
if trl is not None:
print(iteration, ta, tl, trl)
else:
print(iteration, ta, tl)
torch.set_num_threads(self.threads_per_proc)
for neighbor in neighbors:
self.comm.send(neighbor, "finished")
self.model.train()
return ta, tl, trl
......@@ -83,10 +83,9 @@ def get_args():
parser.add_argument("-ta", "--test_after", type=int, default=5)
parser.add_argument("-tea", "--train_evaluate_after", type=int, default=1)
parser.add_argument("-ro", "--reset_optimizer", type=int, default=1)
parser.add_argument("-ctr", "--centralized_train_eval", type=int, default=0)
parser.add_argument("-cte", "--centralized_test_eval", type=int, default=0)
parser.add_argument("-sm", "--server_machine", type=int, default=0)
parser.add_argument("-sr", "--server_rank", type=int, default=-1)
parser.add_argument("-wr", "--working_rate", type=float, default=1.0)
args = parser.parse_args()
return args
......@@ -118,8 +117,7 @@ def write_args(args, path):
"test_after": args.test_after,
"train_evaluate_after": args.train_evaluate_after,
"reset_optimizer": args.reset_optimizer,
"centralized_train_eval": args.centralized_train_eval,
"centralized_test_eval": args.centralized_test_eval,
"working_rate": args.working_rate,
}
with open(os.path.join(path, "args.json"), "w") as of:
json.dump(data, of)
......
[DATASET]
dataset_package = decentralizepy.datasets.Celeba
dataset_class = Celeba
model_class = CNN
images_dir = /mnt/nfs/shared/leaf/data/celeba/data/raw/img_align_celeba
train_dir = /mnt/nfs/shared/leaf/data/celeba/per_user_data/train
test_dir = /mnt/nfs/shared/leaf/data/celeba/data/test
dataset_package = decentralizepy.datasets.CIFAR10
dataset_class = CIFAR10
model_class = LeNet
train_dir = /mnt/nfs/shared/CIFAR
test_dir = /mnt/nfs/shared/CIFAR
; python list of fractions below
sizes =
random_seed = 90
partition_niid = True
shards = 4
[OPTIMIZER_PARAMS]
optimizer_package = torch.optim
optimizer_class = SGD
lr = 0.001
lr = 0.01
[TRAIN_PARAMS]
training_package = decentralizepy.training.Training
training_class = Training
rounds = 4
rounds = 3
full_epochs = False
batch_size = 16
batch_size = 8
shuffle = True
loss_package = torch.nn
loss_class = CrossEntropyLoss
......@@ -26,14 +28,8 @@ loss_class = CrossEntropyLoss
[COMMUNICATION]
comm_package = decentralizepy.communication.TCP
comm_class = TCP
addresses_filepath = /mnt/nfs/kirsten/Gitlab/tutorial/ip.json
addresses_filepath = /mnt/nfs/risharma/Gitlab/tutorial/ip.json
[SHARING]
sharing_package = decentralizepy.sharing.Sharing
sharing_class = Sharing
;sharing_package = decentralizepy.sharing.PartialModel
;sharing_class = PartialModel
;alpha = 0.1
;accumulation = True
;accumulate_averaging_changes = True
{
"0": "127.0.0.1"
}
\ No newline at end of file
16
0 12
0 14
0 15
1 8
1 3
1 6
2 9
2 10
2 5
3 1
3 11
3 9
4 9
4 12
4 13
5 2
5 6
5 7
6 1
6 5
6 7
7 5
7 6
7 14
8 1
8 13
8 14
9 2
9 3
9 4
10 2
10 11
10 13
11 10
11 3
11 15
12 0
12 4
12 15
13 8
13 10
13 4
14 0
14 8
14 7
15 0
15 11
15 12
#!/bin/bash
decpy_path=/mnt/nfs/risharma/Gitlab/decentralizepy/eval
cd $decpy_path
env_python=~/miniconda3/envs/decpy/bin/python3
graph=/mnt/nfs/risharma/Gitlab/tutorial/96_regular.edges
original_config=/mnt/nfs/risharma/Gitlab/tutorial/config_celeba_sharing.ini
config_file=~/tmp/config.ini
procs_per_machine=16
machines=1
iterations=80
test_after=20
eval_file=testingPeerSampler.py
log_level=INFO
m=`cat $(grep addresses_filepath $original_config | awk '{print $3}') | grep $(/sbin/ifconfig ens785 | grep 'inet ' | awk '{print $2}') | cut -d'"' -f2`
echo M is $m
log_dir=$(date '+%Y-%m-%dT%H:%M')/machine$m
mkdir -p $log_dir
cp $original_config $config_file
# echo "alpha = 0.10" >> $config_file
$env_python $eval_file -ro 0 -tea $test_after -ld $log_dir -mid $m -ps $procs_per_machine -ms $machines -is $iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level -wsd $log_dir
\ No newline at end of file
#!/bin/bash
decpy_path=/mnt/nfs/risharma/Gitlab/decentralizepy/eval
cd $decpy_path
env_python=~/miniconda3/envs/decpy/bin/python3
graph=/mnt/nfs/risharma/Gitlab/tutorial/96_regular.edges
original_config=/mnt/nfs/risharma/Gitlab/tutorial/config_celeba_sharing.ini
config_file=~/tmp/config.ini
procs_per_machine=16
machines=1
iterations=80
test_after=20
eval_file=testingFederated.py
log_level=INFO
server_rank=-1
server_machine=0
working_rate=0.5
m=`cat $(grep addresses_filepath $original_config | awk '{print $3}') | grep $(/sbin/ifconfig ens785 | grep 'inet ' | awk '{print $2}') | cut -d'"' -f2`
echo M is $m
log_dir=$(date '+%Y-%m-%dT%H:%M')/machine$m
mkdir -p $log_dir
cp $original_config $config_file
# echo "alpha = 0.10" >> $config_file
$env_python $eval_file -ro 0 -tea $test_after -ld $log_dir -mid $m -ps $procs_per_machine -ms $machines -is $iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level -ctr 0 -cte 0 -wsd $log_dir -sm $server_machine -sr $server_rank -wr $working_rate
\ No newline at end of file