import logging
from pathlib import Path
from shutil import copy

from localconfig import LocalConfig
from torch import multiprocessing as mp

from decentralizepy import utils
from decentralizepy.graphs.Graph import Graph
from decentralizepy.mappings.Linear import Linear
from decentralizepy.node.SecureCompressedAggregation import SecureCompressedAggregation


def read_ini(file_path):
    config = LocalConfig(file_path)
    for section in config:
        print("Section: ", section)
        for key, value in config.items(section):
            print((key, value))
    print(dict(config.items("DATASET")))
    return config


if __name__ == "__main__":
    args = utils.get_args()

    Path(args.log_dir).mkdir(parents=True, exist_ok=True)

    log_level = {
        "INFO": logging.INFO,
        "DEBUG": logging.DEBUG,
        "WARNING": logging.WARNING,
        "ERROR": logging.ERROR,
        "CRITICAL": logging.CRITICAL,
    }

    config = read_ini(args.config_file)
    my_config = dict()
    for section in config:
        my_config[section] = dict(config.items(section))

    copy(args.config_file, args.log_dir)
    copy(args.graph_file, args.log_dir)
    utils.write_args(args, args.log_dir)

    g = Graph()
    g.read_graph_from_file(args.graph_file, args.graph_type)
    n_machines = args.machines
    procs_per_machine = args.procs_per_machine
    l = Linear(n_machines, procs_per_machine)
    m_id = args.machine_id

    # Spawn one process per local node; each runs a SecureCompressedAggregation node.
    processes = []
    for r in range(procs_per_machine):
        processes.append(
            mp.Process(
                target=SecureCompressedAggregation,
                args=[
                    r,
                    m_id,
                    l,
                    g,
                    my_config,
                    args.iterations,
                    args.log_dir,
                    args.weights_store_dir,
                    log_level[args.log_level],
                    args.test_after,
                    args.train_evaluate_after,
                    args.reset_optimizer,
                ],
            )
        )

    for p in processes:
        p.start()

    for p in processes:
        p.join()
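For context, `read_ini` iterates over a standard INI file via localconfig. A minimal sketch of exercising it follows; the `DATASET` section name comes from `read_ini` above, while the keys inside it are hypothetical placeholders.

# Sketch: exercise read_ini on a tiny INI file. The DATASET section name
# appears in read_ini above; the keys inside it are hypothetical.
import tempfile

sample = "[DATASET]\ndataset_package = some.dataset.module\ndataset_class = SomeDataset\n"
with tempfile.NamedTemporaryFile("w", suffix=".ini", delete=False) as f:
    f.write(sample)

config = read_ini(f.name)  # prints each section and its (key, value) pairs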
from multiprocessing.sharedctypes import Value

import networkx as nx
import numpy as np

@@ -152,6 +153,33 @@ class Graph:
"""
return self.adj_list[uid]
def nodes_at_distance(self, uid, distance):
"""
Returns the set of all vertices at the given
distance from a node.
Parameters
----------
uid : int
globally unique identifier of the node
Returns
-------
set(int)
a set of nodes at the given distance from uid
"""
if distance == 1:
return self.adj_list[uid]
elif distance == 0:
return set(uid)
elif distance < 0:
raise ValueError('Negative distance')
vertex_set = set()
for neighbor in self.adj_list[uid]:
vertex_set |= self.nodes_at_distance(neighbor, distance - 1)
return vertex_set
    def centr(self):
        my_adj = {x: list(adj) for x, adj in enumerate(self.adj_list)}
        nxGraph = nx.Graph(my_adj)
...
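To make the walk semantics above concrete, here is a standalone sketch of the same recursion on a small path graph; it uses a plain dict-of-sets adjacency list since the Graph constructor is outside this hunk.

# Standalone sketch of the nodes_at_distance recursion on the path graph
# 0 - 1 - 2 - 3, using a plain adjacency list instead of the Graph class.
adj = {0: {1}, 1: {0, 2}, 2: {1, 3}, 3: {2}}

def nodes_at_distance(uid, distance):
    if distance == 1:
        return adj[uid]
    elif distance == 0:
        return {uid}
    elif distance < 0:
        raise ValueError("Negative distance")
    vertex_set = set()
    for neighbor in adj[uid]:
        vertex_set |= nodes_at_distance(neighbor, distance - 1)
    return vertex_set

print(nodes_at_distance(1, 2))  # {1, 3}: endpoints of walks of length 2 from node 1
# Node 1 itself is included via the walk 1 -> 0 -> 1; callers such as
# get_distance2_neighbors remove the start node afterwards.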
import contextlib
import importlib
import json
import logging
import math
import os
from collections import OrderedDict, deque

import numpy as np
import torch
from matplotlib import pyplot as plt

from decentralizepy import utils
from decentralizepy.graphs.Graph import Graph
from decentralizepy.mappings.Mapping import Mapping
from decentralizepy.node.Node import Node
from decentralizepy.node.DPSGDNode import DPSGDNode
def flatten_state_dict(state_dict):
    """
    Transforms a state dictionary into a flat tensor by flattening
    and concatenating the tensors of the state dictionary.

    Note: changes made to the result won't affect the state dictionary

    Parameters
    ----------
    state_dict : OrderedDict[str, torch.tensor]
        A state dictionary to flatten
    """
    return torch.cat([tensor.flatten() for tensor in state_dict.values()], dim=0)
def unflatten_state_dict(flat_tensor, reference_state_dict):
    """
    Transforms a flat tensor into a state dictionary by using another
    state dictionary as a reference for the sizes and names of the
    tensors. Assumes that the number of elements of the flat tensor
    equals the number of elements in the reference state dictionary.

    This operation is the inverse of flatten_state_dict.

    Note: changes made to the result will affect the flat tensor

    Parameters
    ----------
    flat_tensor : torch.tensor
        A 1-dim tensor
    reference_state_dict : OrderedDict[str, torch.tensor]
        A state dictionary used as a reference for the tensor names
        and shapes of the result
    """
    result = OrderedDict()
    start_index = 0
    for tensor_name, tensor in reference_state_dict.items():
        end_index = start_index + tensor.numel()
        result[tensor_name] = flat_tensor[start_index:end_index].reshape(tensor.shape)
        start_index = end_index
    return result
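A quick round-trip sketch for the two helpers above; the layer names and shapes are arbitrary.

# Round-trip sketch for flatten_state_dict / unflatten_state_dict;
# the layer names and shapes here are arbitrary.
sd = OrderedDict([("fc.weight", torch.randn(3, 2)), ("fc.bias", torch.randn(3))])
flat = flatten_state_dict(sd)             # 1-dim tensor with 3 * 2 + 3 = 9 elements
restored = unflatten_state_dict(flat, sd)
assert all(torch.equal(sd[k], restored[k]) for k in sd)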
def top_k(state_dict, alpha):
    # Keep the top alpha fraction of entries of the flattened state
    # dictionary, selected by value (not by absolute magnitude).
    flat_sd = flatten_state_dict(state_dict)
    num_el_to_keep = int(flat_sd.numel() * alpha)
    parameters, indices = torch.topk(flat_sd, num_el_to_keep, largest=True)
    return parameters, indices
@contextlib.contextmanager
def temp_seed(seed):
    # Temporarily seed numpy's global RNG; the previous state is
    # restored when the context exits.
    state = np.random.get_state()
    np.random.seed(seed)
    try:
        yield
    finally:
        np.random.set_state(state)
class SecureCompressedAggregation(DPSGDNode):
    """
    This class defines the node for secure compressed aggregation

    """

    def get_neighbors(self, node=None):
        if node is None:
            node = self.uid
        return self.graph.neighbors(node)

    def get_distance2_neighbors(self, start_node=None):
        """
        Get all nodes 2 hops away from the start node, excluding the
        start node itself
        """
        if start_node is None:
            start_node = self.uid
        nodes = self.graph.nodes_at_distance(start_node, 2)
        nodes.remove(start_node)
        return nodes

    def receive_DPSGD(self):
        return self.receive_channel("DPSGD")
    def connect_to_nodes(self, set_of_nodes):
        """
        Connects to the given set of nodes. Sends HELLO. Waits for HELLO.
        Caches any data received while waiting for HELLOs.

        Raises
        ------
        RuntimeError
            If received BYE while waiting for HELLO
        """
        logging.info("Sending connection request to all nodes in the set")
        wait_acknowledgements = []
        for node in set_of_nodes:
            if not self.communication.already_connected(node):
                self.connect_neighbor(node)
                wait_acknowledgements.append(node)

        for node in wait_acknowledgements:
            self.wait_for_hello(node)
    def aggregate_models(self, parameters, indices):
        distance2_nodes = self.get_distance2_neighbors()
        self.connect_to_nodes(distance2_nodes)
        compressed_indices = self.sharing.compressor.compress(indices.numpy())

        # Generating and sending pairwise mask seeds
        sent_masks = {}
        for node in distance2_nodes:
            mask_seed = np.random.randint(1000000)
            sent_masks[node] = mask_seed
            self.communication.send(node, {
                "seed": mask_seed,
                "indices": compressed_indices,
                "CHANNEL": "PRE-SECURE-AGG-STEP",
            })

        # Receiving pairwise mask seeds and indices
        received_data = {}
        waiting_mask_from = distance2_nodes.copy()
        while waiting_mask_from:
            sender, data = self.receive_channel("PRE-SECURE-AGG-STEP")
            if sender in waiting_mask_from:
                del data["CHANNEL"]
                received_data[sender] = data
                received_data[sender]["indices"] = torch.tensor(
                    self.sharing.compressor.decompress(received_data[sender]["indices"]),
                    dtype=torch.long,
                )
                waiting_mask_from.remove(sender)

        # Building the pairwise mask differences on the shared coordinates
        pairwise_mask_difference = {}
        for node, data in received_data.items():
            # Positions (in our index list) of the coordinates shared with
            # this node; np.intersect1d converts the tensors to numpy arrays
            _, my_indices_pos, _ = np.intersect1d(indices, data["indices"], return_indices=True)
            mask_shape = my_indices_pos.shape
            pairwise_mask_difference[node] = {
                "value": (
                    self.generate_mask(sent_masks[node], mask_shape)
                    - self.generate_mask(data["seed"], mask_shape)
                ).double(),
                "indices": my_indices_pos,
            }
        # Sending masked models to neighbors
        self.my_neighbors = self.get_neighbors()
        self.connect_to_nodes(self.my_neighbors)
        for neighbor in self.my_neighbors:
            neighbors_neighbors = self.get_neighbors(neighbor)
            perturbed_model = parameters.clone()
            masking_count = torch.zeros_like(indices)
            # Add the pairwise mask differences for every other node in
            # this neighbor's aggregation neighborhood
            for pairing_node in neighbors_neighbors:
                if self.uid == pairing_node:
                    continue
                pair_mask = pairwise_mask_difference[pairing_node]["value"]
                pair_indices = pairwise_mask_difference[pairing_node]["indices"]
                perturbed_model[pair_indices] += pair_mask
                masking_count[pair_indices] += 1
            # Send only the coordinates covered by at least one mask, and
            # send the masked values (sending the raw parameters here would
            # bypass the masking entirely)
            non_zero_indices = masking_count.nonzero(as_tuple=True)[0]
            indices_to_send = indices[non_zero_indices]
            parameters_to_send = perturbed_model[non_zero_indices]
            compressed_parameters = self.sharing.compressor.compress_float(parameters_to_send.numpy())
            compressed_indices = self.sharing.compressor.compress(indices_to_send.numpy())
            self.communication.send(neighbor, {
                "params": compressed_parameters,
                "indices": compressed_indices,
                "CHANNEL": "SECURE_MODEL_CHANNEL",
            })
        # Receiving masked models from neighbors
        received_models = {}
        waiting_models_from = self.my_neighbors.copy()
        while waiting_models_from:
            sender, data = self.receive_channel("SECURE_MODEL_CHANNEL")
            if sender in waiting_models_from:
                del data["CHANNEL"]
                received_models[sender] = data
                received_models[sender]["indices"] = torch.tensor(
                    self.sharing.compressor.decompress(received_models[sender]["indices"]),
                    dtype=torch.long,
                )
                received_models[sender]["params"] = torch.tensor(
                    self.sharing.compressor.decompress_float(received_models[sender]["params"])
                )
                waiting_models_from.remove(sender)

        # Averaging: each received sparse model is completed with the local
        # pre-sharing model on the coordinates it does not cover
        weight = 1 / (len(self.my_neighbors) + 1)
        preshare_model = flatten_state_dict(self.model.state_dict())
        new_flat_model = weight * preshare_model.clone()
        for data in received_models.values():
            params = data["params"]
            indices = data["indices"]
            recovered_model = preshare_model.clone()
            recovered_model[indices] = params
            new_flat_model += weight * recovered_model

        # Loading the new state dictionary
        new_state_dict = unflatten_state_dict(new_flat_model, self.model.state_dict())
        self.model.load_state_dict(new_state_dict)
    def generate_mask(self, seed, size):
        with temp_seed(seed):
            # TODO: figure out the best distribution for the masks
            return torch.Tensor(np.random.normal(0, 100000, size=size))
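The aggregation above relies on pairwise masks cancelling in the sum: for a pair of nodes A and B that exchanged seeds, A adds generate_mask(seed_A) - generate_mask(seed_B) on the shared coordinates while B adds the exact negative, so the aggregate is unchanged. A minimal sketch, assuming both sides call the same deterministic generator; the seed values are arbitrary.

# Sketch: pairwise masks cancel in the aggregate. The mask is deterministic
# given a seed (via temp_seed), so A's offset is the exact negative of B's
# on the coordinates they share. Seed values are arbitrary.
seed_a, seed_b = 42, 1337

def make_mask(seed, size):
    # Same logic as generate_mask above, outside the class
    with temp_seed(seed):
        return torch.Tensor(np.random.normal(0, 100000, size=size))

offset_a = make_mask(seed_a, (4,)) - make_mask(seed_b, (4,))  # what A adds
offset_b = make_mask(seed_b, (4,)) - make_mask(seed_a, (4,))  # what B adds
assert torch.allclose(offset_a + offset_b, torch.zeros(4))    # masks cancel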
    def run(self):
        """
        Start the decentralized learning

        """
        torch.manual_seed(self.uid)
        np.random.seed(self.uid)
        self.testset = self.dataset.get_testset()
        rounds_to_test = self.test_after
        rounds_to_train_evaluate = self.train_evaluate_after
        global_epoch = 1
        change = 1

        for iteration in range(self.iterations):
            logging.info("Starting training iteration: %d", iteration)
            rounds_to_train_evaluate -= 1
            rounds_to_test -= 1

            self.iteration = iteration
            self.trainer.train(self.dataset)

            # Securely aggregate only the top 30% of model entries
            self.aggregate_models(*top_k(self.model.state_dict(), 0.3))

            if self.reset_optimizer:
                self.optimizer = self.optimizer_class(
                    self.model.parameters(), **self.optimizer_params
                )  # Reset optimizer state
                self.trainer.reset_optimizer(self.optimizer)

            if iteration:
                with open(
                    os.path.join(self.log_dir, "{}_results.json".format(self.rank)),
                    "r",
                ) as inf:
                    results_dict = json.load(inf)
            else:
                results_dict = {
                    "train_loss": {},
                    "test_loss": {},
                    "test_acc": {},
                    "total_bytes": {},
                    "total_meta": {},
                    "total_data_per_n": {},
                }

            results_dict["total_bytes"][iteration + 1] = self.communication.total_bytes

            if hasattr(self.communication, "total_meta"):
                results_dict["total_meta"][
                    iteration + 1
                ] = self.communication.total_meta
            if hasattr(self.communication, "total_data"):
                results_dict["total_data_per_n"][
                    iteration + 1
                ] = self.communication.total_data

            if rounds_to_train_evaluate == 0:
                logging.info("Evaluating on train set.")
                rounds_to_train_evaluate = self.train_evaluate_after * change
                loss_after_sharing = self.trainer.eval_loss(self.dataset)
                results_dict["train_loss"][iteration + 1] = loss_after_sharing
                self.save_plot(
                    results_dict["train_loss"],
                    "train_loss",
                    "Training Loss",
                    "Communication Rounds",
                    os.path.join(self.log_dir, "{}_train_loss.png".format(self.rank)),
                )

            if self.dataset.__testing__ and rounds_to_test == 0:
                rounds_to_test = self.test_after * change
                logging.info("Evaluating on test set.")
                ta, tl = self.dataset.test(self.model, self.loss)
                results_dict["test_acc"][iteration + 1] = ta
                results_dict["test_loss"][iteration + 1] = tl

                if global_epoch == 49:
                    change *= 2
                global_epoch += change

            with open(
                os.path.join(self.log_dir, "{}_results.json".format(self.rank)), "w"
            ) as of:
                json.dump(results_dict, of)

        if self.model.shared_parameters_counter is not None:
            logging.info("Saving the shared parameter counts")
            with open(
                os.path.join(
                    self.log_dir, "{}_shared_parameters.json".format(self.rank)
                ),
                "w",
            ) as of:
                json.dump(self.model.shared_parameters_counter.numpy().tolist(), of)
        self.disconnect_neighbors()
        logging.info("Storing final weights")
        self.model.dump_weights(self.weights_store_dir, self.uid, iteration)
        logging.info("All neighbors disconnected. Process complete!")
import contextlib
import random

import numpy as np
import torch
@contextlib.contextmanager
def temp_seed(seed):
    """
    Creates a context in which the seeds are set to the given value,
    returning to the previous state afterwards.

    Note: based on the torch implementation, CUDA generators may cause
    correctness issues for this function. torch.rand() works fine in
    testing, as its results are generated on the CPU regardless of
    whether CUDA is used for other things.
    """
    random_state = random.getstate()
    np_old_state = np.random.get_state()
    torch_old_state = torch.random.get_rng_state()
    random.seed(seed)
    np.random.seed(seed)
    torch.random.manual_seed(seed)
    try:
        yield
    finally:
        random.setstate(random_state)
        np.random.set_state(np_old_state)
        torch.random.set_rng_state(torch_old_state)
class RandomState:
    """
    Creates a state that affects random number generation of the
    random, numpy and torch generators, and whose context can be
    activated at will
    """

    def __init__(self, seed):
        with temp_seed(seed):
            self.__refresh_states()

    def __refresh_states(self):
        self.__random_state = random.getstate()
        self.__np_state = np.random.get_state()
        self.__torch_state = torch.random.get_rng_state()

    def __set_states(self):
        random.setstate(self.__random_state)
        np.random.set_state(self.__np_state)
        torch.random.set_rng_state(self.__torch_state)

    @contextlib.contextmanager
    def activate(self):
        """
        Activates this state in the given context for random, numpy
        and torch. The previous state is restored when the context
        is finished
        """
        random_state = random.getstate()
        np_old_state = np.random.get_state()
        torch_old_state = torch.random.get_rng_state()
        self.__set_states()
        try:
            yield
        finally:
            self.__refresh_states()
            random.setstate(random_state)
            np.random.set_state(np_old_state)
            torch.random.set_rng_state(torch_old_state)
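A usage sketch for RandomState: two independent streams can be interleaved, each resuming where it left off, while the ambient generator state is restored after every context.

# Usage sketch: interleave two independent random streams. Each state
# resumes its own stream, and the ambient RNG state is untouched outside
# the contexts.
state_a = RandomState(1)
state_b = RandomState(2)

with state_a.activate():
    a_first = torch.rand(2)
with state_b.activate():
    b_first = torch.rand(2)
with state_a.activate():
    a_second = torch.rand(2)  # continues state_a's stream, unaffected by state_b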
@@ -310,9 +310,9 @@ class Choco(Sharing):
             model,
             dataset,
             log_dir,
-            compress=False,
-            compression_package=None,
-            compression_class=None
+            compress,
+            compression_package,
+            compression_class
         )
         self.step_size = step_size
         self.alpha = alpha
...