import logging
from pathlib import Path
from shutil import copy

from localconfig import LocalConfig
from torch import multiprocessing as mp

from decentralizepy import utils
from decentralizepy.graphs.Graph import Graph
from decentralizepy.mappings.Linear import Linear
from decentralizepy.node.SecureCompressedAggregation import SecureCompressedAggregation


def read_ini(file_path):
    config = LocalConfig(file_path)
    for section in config:
        print("Section: ", section)
        for key, value in config.items(section):
            print((key, value))
    print(dict(config.items("DATASET")))
    return config


if __name__ == "__main__":
    args = utils.get_args()

    Path(args.log_dir).mkdir(parents=True, exist_ok=True)

    log_level = {
        "INFO": logging.INFO,
        "DEBUG": logging.DEBUG,
        "WARNING": logging.WARNING,
        "ERROR": logging.ERROR,
        "CRITICAL": logging.CRITICAL,
    }

    config = read_ini(args.config_file)
    my_config = dict()
    for section in config:
        my_config[section] = dict(config.items(section))

    copy(args.config_file, args.log_dir)
    copy(args.graph_file, args.log_dir)
    utils.write_args(args, args.log_dir)

    g = Graph()
    g.read_graph_from_file(args.graph_file, args.graph_type)
    n_machines = args.machines
    procs_per_machine = args.procs_per_machine
    l = Linear(n_machines, procs_per_machine)
    m_id = args.machine_id

    processes = []
    for r in range(procs_per_machine):
        processes.append(
            mp.Process(
                target=SecureCompressedAggregation,
                args=[
                    r,
                    m_id,
                    l,
                    g,
                    my_config,
                    args.iterations,
                    args.log_dir,
                    args.weights_store_dir,
                    log_level[args.log_level],
                    args.test_after,
                    args.train_evaluate_after,
                    args.reset_optimizer,
                ],
            )
        )

    for p in processes:
        p.start()

    for p in processes:
        p.join()
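
# Deployment sketch (illustrative; the exact flag names are defined by
# utils.get_args and omitted here): every machine runs this script with its
# own machine_id, and the Linear mapping assigns each (machine_id, local
# rank r) pair a globally unique node uid. Assuming it computes
# uid = machine_id * procs_per_machine + r, machine 1 / rank 2 with
# 4 procs per machine becomes node 6 of an 8-node graph.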
from multiprocessing.sharedctypes import Value
import networkx as nx
import numpy as np

@@ -152,6 +153,33 @@ class Graph:
"""
return self.adj_list[uid]
def nodes_at_distance(self, uid, distance):
"""
Returns the set of all vertices at the given
distance from a node.
Parameters
----------
uid : int
globally unique identifier of the node
Returns
-------
set(int)
a set of nodes at the given distance from uid
"""
if distance == 1:
return self.adj_list[uid]
elif distance == 0:
return set(uid)
elif distance < 0:
raise ValueError('Negative distance')
vertex_set = set()
for neighbor in self.adj_list[uid]:
vertex_set |= self.nodes_at_distance(neighbor, distance - 1)
return vertex_set
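
    # Usage sketch (illustration, not part of the diff): on a triangle graph
    # with edges 0-1, 1-2, 2-0, nodes_at_distance(0, 2) returns {0, 1, 2},
    # because length-2 walks may loop back to the start node. Callers that
    # want a strict "2 hops away" set must remove uid themselves, as
    # SecureCompressedAggregation.get_distance2_neighbors does later in
    # this diff.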

    def centr(self):
        my_adj = {x: list(adj) for x, adj in enumerate(self.adj_list)}
        nxGraph = nx.Graph(my_adj)
......

import importlib
import json
import logging
import math
import os
from collections import OrderedDict, deque

import numpy as np
import torch
from matplotlib import pyplot as plt

from decentralizepy.node.DPSGDNode import DPSGDNode
from decentralizepy.random import RandomState, temp_seed

def flatten_state_dict(state_dict):
    """
    Transforms a state dictionary into a flat tensor
    by flattening and concatenating the tensors of the
    state dictionary.

    Note: changes made to the result won't affect the state dictionary

    Parameters
    ----------
    state_dict : OrderedDict[str, torch.tensor]
        A state dictionary to flatten

    """
    return torch.cat([tensor.flatten() for tensor in state_dict.values()], dim=0)

def get_number_of_elements(state_dict):
    """
    Returns the number of parameters in the state dictionary
    of a model.

    Parameters
    ----------
    state_dict : OrderedDict[str, torch.tensor]
        The state dictionary of the model

    """
    return sum(v.numel() for v in state_dict.values())

def unflatten_state_dict(flat_tensor, reference_state_dict):
    """
    Transforms a flat tensor into a state dictionary
    by using another state dictionary as a reference
    for the sizes and names of the tensors. Assumes
    that the number of elements of the flat tensor
    is the same as the number of elements in the
    reference state dictionary.

    This operation is the inverse of flatten_state_dict.

    Note: changes made to the result will affect the flat tensor

    Parameters
    ----------
    flat_tensor : torch.tensor
        A 1-dim tensor
    reference_state_dict : OrderedDict[str, torch.tensor]
        A state dictionary used as a reference for the tensor names
        and shapes of the result

    """
    result = OrderedDict()
    start_index = 0
    for tensor_name, tensor in reference_state_dict.items():
        end_index = start_index + tensor.numel()
        result[tensor_name] = flat_tensor[start_index:end_index].reshape(tensor.shape)
        start_index = end_index
    return result
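
# Round-trip sketch (illustration, not part of the diff): unflattening a
# flattened state dict against the same reference reproduces every tensor.
#
#   sd = OrderedDict(w=torch.ones(2, 3), b=torch.zeros(3))
#   flat = flatten_state_dict(sd)        # torch.Size([9])
#   sd2 = unflatten_state_dict(flat, sd)
#   assert all(torch.equal(sd[k], sd2[k]) for k in sd)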

def top_k(state_dict, alpha):
    flat_sd = flatten_state_dict(state_dict)
    num_el_to_keep = int(flat_sd.numel() * alpha)
    _, indices = torch.topk(flat_sd.abs(), num_el_to_keep, largest=True)
    return flat_sd, indices


def layerwise_topk(state_dict, alpha):
    indice_list, params_list = [], []
    numel_so_far = 0
    for _, v in state_dict.items():
        flat_tensor = v.flatten()
        num_el_to_keep = int(flat_tensor.numel() * alpha)
        _, indices = torch.topk(flat_tensor, num_el_to_keep, largest=True, sorted=True)
        indices, _ = torch.sort(indices)
        # Shift layer-local indices into the global flat-tensor index space
        indices += numel_so_far
        indice_list.append(indices)
        params_list.append(flat_tensor)
        numel_so_far += flat_tensor.numel()
    selected_indices = torch.cat(indice_list)
    flat_params = torch.cat(params_list)
    return flat_params, selected_indices
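
# Contrast note (not part of the diff): top_k ranks the whole flat model at
# once, so a layer with large-magnitude entries can crowd out the others,
# while layerwise_topk keeps the top alpha fraction of each layer, so every
# layer stays represented in the shared indices. Also note that, as written,
# layerwise_topk ranks raw values rather than absolute values.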

class SecureCompressedAggregation(DPSGDNode):
    """
    This class defines the node for secure compressed aggregation

    """

    def get_neighbors(self, node=None):
        if node is None:
            node = self.uid
        return self.graph.neighbors(node)

    def get_distance2_neighbors(self, start_node=None):
        """
        Get all nodes 2 hops away from the start node, excluding the start node itself
        """
        if start_node is None:
            start_node = self.uid
        nodes = self.graph.nodes_at_distance(start_node, 2)
        nodes.remove(start_node)
        return nodes

    def connect_to_nodes(self, set_of_nodes):
        """
        Connects to all nodes in the given set. Sends HELLO. Waits for HELLO.
        Caches any data received while waiting for HELLOs.

        Raises
        ------
        RuntimeError
            If received BYE while waiting for HELLO

        """
        logging.info("Sending connection request to all neighbors")
        wait_acknowledgements = []
        for node in set_of_nodes:
            if not self.communication.already_connected(node):
                self.connect_neighbor(node)
                wait_acknowledgements.append(node)

        for node in wait_acknowledgements:
            self.wait_for_hello(node)

    def _pseudo_pre_step(self):
        pre_share_model = flatten_state_dict(self.model.state_dict()).clone()
        change = pre_share_model - self.init_model
        self.model.accumulated_changes += change
        change = self.model.accumulated_changes.clone().detach()
        self.model.model_change = change

    def _pseudo_post_step(self):
        post_share_model = flatten_state_dict(self.model.state_dict()).clone()
        self.init_model = post_share_model
        self.model.accumulated_changes += self.init_model - self.prev
        self.prev = self.init_model
        self.model.model_change = None

    def top_k_changed(self, state_dict, alpha):
        flat_sd = flatten_state_dict(state_dict)
        flat_changes = torch.abs(self.model.model_change)
        num_el_to_keep = int(flat_sd.numel() * alpha)
        _, indices = torch.topk(flat_changes, num_el_to_keep, largest=True)
        return flat_sd, indices

    def random_subsampling(self, state_dict, alpha):
        flat_sd = flatten_state_dict(state_dict)
        # torch.seed() reseeds the generator and returns the new seed,
        # which is logged here for reproducibility
        logging.info("Subsampling mask seed: %d", torch.seed())
        keep_mask = torch.rand(flat_sd.shape) < alpha
        indices = keep_mask.nonzero(as_tuple=True)[0]
        return flat_sd, indices
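
    # Selection-size note (illustration only): random_subsampling keeps each
    # coordinate independently with probability alpha, so the number of shared
    # indices is only approximately alpha * numel, whereas top_k_changed keeps
    # exactly int(numel * alpha) coordinates every round.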

    def aggregate_models(self, parameters, indices, iteration):
        distance2_nodes = self.get_distance2_neighbors()
        logging.info("Neighbors: {}".format(self.get_neighbors()))
        logging.info("Distance 2 nodes: {}".format(distance2_nodes))
        self.connect_to_nodes(distance2_nodes)
        compressed_indices = self.sharing.compressor.compress(indices.numpy())

        # Generating and sending pairwise masks
        sent_masks = {}
        for node in distance2_nodes:
            mask_seed = np.random.randint(1000000)
            sent_masks[node] = mask_seed
            self.communication.send(node, {
                "seed": mask_seed,
                "indices": compressed_indices,
                "iteration": iteration,
                "CHANNEL": "PRE-SECURE-AGG-STEP",
            })
            logging.info("Sent mask to %d", node)

        # Receiving pairwise masks and indices
        received_data = {}
        waiting_mask_from = distance2_nodes.copy()
        # Processing masks received before the given round
        for sender, mask_data in self.masks_received_early:
            if mask_data["iteration"] != iteration:
                raise ValueError("Mask iterations don't match")
            del mask_data["iteration"]
            received_data[sender] = mask_data
            received_data[sender]["indices"] = torch.tensor(
                self.sharing.compressor.decompress(received_data[sender]["indices"]),
                dtype=torch.long,
            )
            waiting_mask_from.remove(sender)
        self.masks_received_early = []

        # Waiting for the remaining masks
        while waiting_mask_from:
            sender, data = self.receive_channel("PRE-SECURE-AGG-STEP")
            del data["CHANNEL"]
            if sender in waiting_mask_from:
                if data["iteration"] != iteration:
                    raise ValueError("Mask iterations don't match")
                del data["iteration"]
                received_data[sender] = data
                received_data[sender]["indices"] = torch.tensor(
                    self.sharing.compressor.decompress(received_data[sender]["indices"]),
                    dtype=torch.long,
                )
                waiting_mask_from.remove(sender)
            else:
                self.masks_received_early.append((sender, data))

        # Building masks
        pairwise_mask_difference = {}
        indices_size = indices.size()[0]
        logging.info("Indices intended to share: %s", indices_size)
        for node, data in received_data.items():
            # sortednp.intersect supports intersection of sorted arrays (make
            # sure to cast tensors to numpy arrays); torch.topk doesn't return
            # indices sorted, hence the explicit sort before sharing.
            _, my_indices_pos, _ = np.intersect1d(indices, data["indices"], return_indices=True)
            mask_shape = my_indices_pos.shape
            logging.info("Index intersection size: %s", my_indices_pos.size)
            pairwise_mask_difference[node] = {
                "value": (self.generate_mask(sent_masks[node], mask_shape)
                          - self.generate_mask(data["seed"], mask_shape)).double(),
                "indices": my_indices_pos,
            }
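
        # Cancellation sketch (illustration, not part of the diff): for a pair
        # (i, j), node i applies m_ij - m_ji while node j applies m_ji - m_ij,
        # the exact negation, where m_xy = generate_mask(seed sent by x to y).
        # An aggregator that sums the contributions of both nodes of the pair
        # sees the masks cancel:
        #   (x_i + m_ij - m_ji) + (x_j + m_ji - m_ij) = x_i + x_j
        # so the sum is preserved while the individual models stay hidden.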

        # Sending models to neighbors
        self.my_neighbors = self.get_neighbors()
        self.connect_to_nodes(self.my_neighbors)
        for neighbor in self.my_neighbors:
            neighbors_neighbors = self.get_neighbors(neighbor)
            perturbated_model = parameters.clone()
            masking_count = torch.zeros_like(indices)
            # Add mask differences from every pair that also talks to this neighbor
            for pairing_node in neighbors_neighbors:
                if self.uid == pairing_node:
                    continue
                pair_mask = pairwise_mask_difference[pairing_node]["value"]
                pair_indices_pos = pairwise_mask_difference[pairing_node]["indices"]
                pair_indices = indices[pair_indices_pos]
                perturbated_model[pair_indices] += pair_mask
                masking_count[pair_indices_pos] += 1
            # Only send coordinates covered by at least one pairwise mask
            non_zero_indices = masking_count.nonzero(as_tuple=True)[0]
            indices_to_send = indices[non_zero_indices]
            # Send the masked values; sending raw `parameters` here would leak
            # the model and make the pairwise masks pointless
            parameters_to_send = perturbated_model[indices_to_send]
            # Debug path that skips the protocol (delete later):
            # parameters_to_send = parameters[indices]
            # indices_to_send = indices
            logging.info("Sending indices: %d", indices_to_send.shape[0])
            compressed_parameters = self.sharing.compressor.compress_float(parameters_to_send.numpy())
            compressed_indices = self.sharing.compressor.compress(indices_to_send.numpy())
            self.communication.send(neighbor, {
                "params": compressed_parameters,
                "indices": compressed_indices,
                "iteration": iteration,
                "CHANNEL": "SECURE_MODEL_CHANNEL",
            })
            logging.info("Sent model to %d", neighbor)

        # Receiving models from neighbors
        received_models = {}
        waiting_models_from = self.my_neighbors.copy()
        # Processing models received before the given round
        for sender, model_data in self.models_received_early:
            if model_data["iteration"] != iteration:
                raise ValueError("Model iterations don't match")
            del model_data["iteration"]
            received_models[sender] = model_data
            received_models[sender]["indices"] = torch.tensor(
                self.sharing.compressor.decompress(received_models[sender]["indices"]),
                dtype=torch.long,
            )
            received_models[sender]["params"] = torch.tensor(
                self.sharing.compressor.decompress_float(received_models[sender]["params"])
            )
            waiting_models_from.remove(sender)
        self.models_received_early = []

        while waiting_models_from:
            sender, data = self.receive_channel("SECURE_MODEL_CHANNEL")
            del data["CHANNEL"]
            if sender in waiting_models_from:
                if data["iteration"] != iteration:
                    raise ValueError("Model iterations don't match")
                del data["iteration"]
                received_models[sender] = data
                received_models[sender]["indices"] = torch.tensor(
                    self.sharing.compressor.decompress(received_models[sender]["indices"]),
                    dtype=torch.long,
                )
                received_models[sender]["params"] = torch.tensor(
                    self.sharing.compressor.decompress_float(received_models[sender]["params"])
                )
                waiting_models_from.remove(sender)
            else:
                self.models_received_early.append((sender, data))

        # Averaging: equal weight for self and every neighbor
        weight = 1 / (len(self.my_neighbors) + 1)
        preshare_model = flatten_state_dict(self.model.state_dict())
        new_flat_model = weight * preshare_model.clone()
        for data in received_models.values():
            params = data["params"]
            indices = data["indices"]
            # Coordinates a neighbor didn't send fall back to our own values
            recovered_model = preshare_model.clone()
            recovered_model[indices] = params
            new_flat_model += weight * recovered_model

        # Loading the new state dictionary
        logging.info("L0=" + str((parameters - new_flat_model).abs().sum()))
        logging.info("model_L0=" + str(parameters.abs().sum()))
        new_state_dict = unflatten_state_dict(new_flat_model, self.model.state_dict())
        self.model.load_state_dict(new_state_dict)

    def generate_mask(self, seed, size):
        with temp_seed(seed):
            # TODO: figure out the best distribution for the masks
            return torch.Tensor(np.random.uniform(-10000000, 20000000, size=size))
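
    # Determinism sketch (illustration only): generate_mask is a pure function
    # of (seed, size), so both endpoints of a pair can reconstruct each other's
    # mask from the exchanged seeds alone:
    #   assert torch.equal(self.generate_mask(42, (3,)), self.generate_mask(42, (3,)))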

    def extract_top_gradients(self):
        """
        Extract the indices and values of the topK gradients.
        The gradients must have been accumulated.

        Returns
        -------
        tuple
            (a, b). a: The magnitudes of the topK gradients, b: Their indices.

        """
        logging.info("Returning topk gradients")
        G_topk = torch.abs(self.model.model_change)
        std, mean = torch.std_mean(G_topk, unbiased=False)
        self.std = std.item()
        self.mean = mean.item()
        _, index = torch.topk(
            G_topk, round(self.alpha * G_topk.shape[0]), dim=0, sorted=True
        )
        # NOTE: after this re-sort, _ holds the sort permutation rather than
        # the topk magnitudes; callers use only the indices
        index, _ = torch.sort(index)
        return _, index

    def run(self):
        """
        Start the decentralized learning

        """
        with torch.no_grad():
            self.init_model = flatten_state_dict(self.model.state_dict())
            self.model.accumulated_changes = torch.zeros_like(self.init_model)
            self.prev = self.init_model
        self.sec_agg_state = RandomState(self.uid)

        self.testset = self.dataset.get_testset()
        rounds_to_test = self.test_after
        rounds_to_train_evaluate = self.train_evaluate_after
        global_epoch = 1
        change = 1

        self.old_model_holder = flatten_state_dict(self.model.state_dict()).clone()
        self.model.accumulated_changes = torch.zeros_like(self.old_model_holder)
        self.masks_received_early = []
        self.models_received_early = []

        logging.info(
            "Number of parameters in model: %d",
            get_number_of_elements(self.model.state_dict()),
        )

        for iteration in range(self.iterations):
            if self.uid == 0:
                print("Iteration", iteration)
            logging.info("Starting training iteration: %d", iteration)
            rounds_to_train_evaluate -= 1
            rounds_to_test -= 1

            self.iteration = iteration
            self.trainer.train(self.dataset)

            self._pseudo_pre_step()
            # Alternative selection strategies:
            # self.aggregate_models(*top_k(self.model.state_dict(), 0.3), iteration)
            # self.aggregate_models(*self.random_subsampling(self.model.state_dict(), 0.3), iteration)
            # flat_model, indices_to_share = self.top_k_changed(self.model.state_dict(), 0.3)
            flat_model, indices_to_share = self.top_k_changed(self.model.state_dict(), 1)
            self.model.shared_parameters_counter[indices_to_share] += 1
            self.model.rewind_accumulation(indices_to_share)
            with self.sec_agg_state.activate():
                self.aggregate_models(flat_model, indices_to_share, iteration)
            self._pseudo_post_step()

            if self.reset_optimizer:
                self.optimizer = self.optimizer_class(
                    self.model.parameters(), **self.optimizer_params
                )  # Reset optimizer state
                self.trainer.reset_optimizer(self.optimizer)

            if iteration:
                with open(
                    os.path.join(self.log_dir, "{}_results.json".format(self.rank)),
                    "r",
                ) as inf:
                    results_dict = json.load(inf)
            else:
                results_dict = {
                    "train_loss": {},
                    "test_loss": {},
                    "test_acc": {},
                    "total_bytes": {},
                    "total_meta": {},
                    "total_data_per_n": {},
                }

            results_dict["total_bytes"][iteration + 1] = self.communication.total_bytes
            if hasattr(self.communication, "total_meta"):
                results_dict["total_meta"][
                    iteration + 1
                ] = self.communication.total_meta
            if hasattr(self.communication, "total_data"):
                results_dict["total_data_per_n"][
                    iteration + 1
                ] = self.communication.total_data

            if rounds_to_train_evaluate == 0:
                logging.info("Evaluating on train set.")
                rounds_to_train_evaluate = self.train_evaluate_after * change
                loss_after_sharing = self.trainer.eval_loss(self.dataset)
                results_dict["train_loss"][iteration + 1] = loss_after_sharing
                self.save_plot(
                    results_dict["train_loss"],
                    "train_loss",
                    "Training Loss",
                    "Communication Rounds",
                    os.path.join(self.log_dir, "{}_train_loss.png".format(self.rank)),
                )

            if self.dataset.__testing__ and rounds_to_test == 0:
                rounds_to_test = self.test_after * change
                logging.info("Evaluating on test set.")
                ta, tl = self.dataset.test(self.model, self.loss)
                results_dict["test_acc"][iteration + 1] = ta
                results_dict["test_loss"][iteration + 1] = tl

                if global_epoch == 49:
                    change *= 2
                global_epoch += change

            with open(
                os.path.join(self.log_dir, "{}_results.json".format(self.rank)), "w"
            ) as of:
                json.dump(results_dict, of)

        if self.model.shared_parameters_counter is not None:
            logging.info("Saving the shared parameter counts")
            with open(
                os.path.join(
                    self.log_dir, "{}_shared_parameters.json".format(self.rank)
                ),
                "w",
            ) as of:
                json.dump(self.model.shared_parameters_counter.numpy().tolist(), of)
        self.disconnect_neighbors()
        logging.info("Storing final weight")
        self.model.dump_weights(self.weights_store_dir, self.uid, iteration)
        logging.info("All neighbors disconnected. Process complete!")
import contextlib
import random

import numpy as np
import torch


@contextlib.contextmanager
def temp_seed(seed):
    """
    Creates a context with the seeds set to the given value. Restores the
    previous random state afterwards.

    Note: based on the torch implementation, CUDA may cause issues with the
    correctness of this function. torch.rand() works fine in testing, as its
    results are generated on the CPU regardless of whether CUDA is used for
    other things.
    """
    random_state = random.getstate()
    np_old_state = np.random.get_state()
    torch_old_state = torch.random.get_rng_state()
    random.seed(seed)
    np.random.seed(seed)
    torch.random.manual_seed(seed)
    try:
        yield
    finally:
        random.setstate(random_state)
        np.random.set_state(np_old_state)
        torch.random.set_rng_state(torch_old_state)
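
# Usage sketch (illustration, not part of the diff): within the context all
# three generators are seeded deterministically; afterwards the outer random
# stream resumes where it left off.
#
#   with temp_seed(0):
#       a = torch.rand(3)
#   with temp_seed(0):
#       b = torch.rand(3)
#   assert torch.equal(a, b)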

class RandomState:
    """
    Creates a state that affects random number generation on
    torch and numpy and whose context can be activated at will
    """

    def __init__(self, seed):
        with temp_seed(seed):
            self.__refresh_states()

    def __refresh_states(self):
        self.__random_state = random.getstate()
        self.__np_state = np.random.get_state()
        self.__torch_state = torch.random.get_rng_state()

    def __set_states(self):
        random.setstate(self.__random_state)
        np.random.set_state(self.__np_state)
        torch.random.set_rng_state(self.__torch_state)

    @contextlib.contextmanager
    def activate(self):
        """
        Activates this state in the given context for torch and
        numpy. The previous state is restored when the context
        is finished
        """
        random_state = random.getstate()
        np_old_state = np.random.get_state()
        torch_old_state = torch.random.get_rng_state()
        self.__set_states()
        try:
            yield
        finally:
            self.__refresh_states()
            random.setstate(random_state)
            np.random.set_state(np_old_state)
            torch.random.set_rng_state(torch_old_state)
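
# Usage sketch (illustration only): unlike temp_seed, a RandomState advances
# across activations, giving each node a private but reproducible stream; this
# is how SecureCompressedAggregation isolates its secure-aggregation randomness.
#
#   state = RandomState(7)
#   with state.activate():
#       a = torch.rand(2)  # draws from the private stream
#   with state.activate():
#       b = torch.rand(2)  # continues that same stream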

@@ -310,9 +310,9 @@ class Choco(Sharing):
             model,
             dataset,
             log_dir,
-            compress=False,
-            compression_package=None,
-            compression_class=None
+            compress,
+            compression_package,
+            compression_class
         )
         self.step_size = step_size
         self.alpha = alpha
......