Commit a21933e8 authored by Milos Vujasinovic

Added prototype for secure aggregation

parent f80fed3c
import logging
from pathlib import Path
from shutil import copy
from localconfig import LocalConfig
from torch import multiprocessing as mp
from decentralizepy import utils
from decentralizepy.graphs.Graph import Graph
from decentralizepy.mappings.Linear import Linear
from decentralizepy.node.SecureCompressedAggregatopn import SecureCompressedAggregatopn
def read_ini(file_path):
config = LocalConfig(file_path)
for section in config:
print("Section: ", section)
for key, value in config.items(section):
print((key, value))
print(dict(config.items("DATASET")))
return config
if __name__ == "__main__":
args = utils.get_args()
Path(args.log_dir).mkdir(parents=True, exist_ok=True)
log_level = {
"INFO": logging.INFO,
"DEBUG": logging.DEBUG,
"WARNING": logging.WARNING,
"ERROR": logging.ERROR,
"CRITICAL": logging.CRITICAL,
}
config = read_ini(args.config_file)
my_config = dict()
for section in config:
my_config[section] = dict(config.items(section))
copy(args.config_file, args.log_dir)
copy(args.graph_file, args.log_dir)
utils.write_args(args, args.log_dir)
g = Graph()
g.read_graph_from_file(args.graph_file, args.graph_type)
n_machines = args.machines
procs_per_machine = args.procs_per_machine
l = Linear(n_machines, procs_per_machine)
m_id = args.machine_id
processes = []
for r in range(procs_per_machine):
processes.append(
mp.Process(
target=SecureCompressedAggregatopn,
args=[
r,
m_id,
l,
g,
my_config,
args.iterations,
args.log_dir,
args.weights_store_dir,
log_level[args.log_level],
args.test_after,
args.train_evaluate_after,
args.reset_optimizer,
],
)
)
for p in processes:
p.start()
for p in processes:
p.join()
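# Usage sketch for the launcher above (hypothetical script and flag names;
# the real ones are whatever decentralizepy's utils.get_args defines):
#
#   python launcher.py --machine_id 0 --machines 1 --procs_per_machine 4 \
#       --config_file config.ini --graph_file regular_16.edges \
#       --graph_type edges --iterations 80 --log_dir ./logs
#
# Each spawned process runs one SecureCompressedAggregatopn node; the
# Linear mapping derives its globally unique uid from (machine_id, rank).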
import networkx as nx
import numpy as np
@@ -152,6 +153,33 @@ class Graph:
"""
return self.adj_list[uid]
    def nodes_at_distance(self, uid, distance):
        """
        Returns the set of all vertices reachable from a node in
        exactly `distance` hops. Note that cycles can make this set
        include nodes that are also closer, including the node itself.

        Parameters
        ----------
        uid : int
            globally unique identifier of the node
        distance : int
            number of hops; must be non-negative

        Returns
        -------
        set(int)
            a set of nodes at the given distance from uid
        """
        if distance == 1:
            # Return a copy so callers can safely mutate the result.
            return set(self.adj_list[uid])
        elif distance == 0:
            return {uid}
        elif distance < 0:
            raise ValueError("Negative distance")
        vertex_set = set()
        for neighbor in self.adj_list[uid]:
            vertex_set |= self.nodes_at_distance(neighbor, distance - 1)
        return vertex_set
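    # Worked example (hypothetical 5-cycle 0-1-2-3-4-0):
    #   nodes_at_distance(0, 2) recurses into neighbors 1 and 4 with
    #   distance 1 and returns adj(1) | adj(4) = {0, 2} | {0, 3} = {0, 2, 3}.
    # The start node reappears whenever a 2-hop walk leads back to it,
    # which is why callers such as get_distance2_neighbors remove it
    # afterwards.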
def centr(self):
my_adj = {x: list(adj) for x, adj in enumerate(self.adj_list)}
nxGraph = nx.Graph(my_adj)
......
import importlib
import json
import logging
import math
import os
import contextlib
from collections import OrderedDict, deque

import numpy as np
import torch
from matplotlib import pyplot as plt
from decentralizepy import utils
from decentralizepy.graphs.Graph import Graph
from decentralizepy.mappings.Mapping import Mapping
from decentralizepy.node.Node import Node
from decentralizepy.node.DPSGDNode import DPSGDNode
def flatten_state_dict(state_dict):
"""
Transforms state dictionary into a flat tensor
by flattening and concatenating tensors of the
state dictionary.
Note: changes made to the result won't affect state dictionary
Parameters
----------
state_dict : OrderedDict[str, torch.tensor]
A state dictionary to flatten
"""
return torch.cat([
tensor.flatten()\
for tensor in state_dict.values()
], axis=0)
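# Layout sketch (hypothetical shapes):
#   sd = OrderedDict([("w", torch.ones(2, 3)), ("b", torch.zeros(3))])
#   flatten_state_dict(sd)
#   # -> a 9-element tensor: the 6 entries of "w" (row-major), then the
#   #    3 entries of "b", following dict insertion order.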
def unflatten_state_dict(flat_tensor, reference_state_dict):
"""
Transforms a falt tensor into a state dictionary
by using another state dictionary as a reference
for size and names of the tensors. Assumes
that the number of elements of the flat tensor
is the same as the number of elements in the
reference state dictionary.
This operation is inverse operation to flatten_state_dict
Note: changes made to the result will affect the flat tensor
Parameters
----------
flat_tensor : torch.tensor
A 1-dim tensor
reference_state_dict : OrderedDict[str, torch.tensor]
A state dictionary used as a reference for tensor names
and shapes of the result
"""
result = OrderedDict()
start_index = 0
for tensor_name, tensor in reference_state_dict.items():
end_index = start_index + tensor.numel()
result[tensor_name] = flat_tensor[start_index:end_index].reshape(
tensor.shape)
start_index = end_index
return result
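# Round-trip sketch: unflatten_state_dict(flatten_state_dict(sd), sd)
# rebuilds a dictionary with the same keys and shapes as sd. The rebuilt
# tensors are views into the flat tensor, so writing into the flat
# tensor also updates the rebuilt dictionary, and vice versa.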
def top_k(state_dict, alpha):
    """
    Selects the fraction alpha of entries with the largest values
    (not the largest magnitudes) from the flattened state dictionary.

    Returns
    -------
    (torch.tensor, torch.tensor)
        The selected parameter values and their positions in the
        flattened model
    """
    flat_sd = flatten_state_dict(state_dict)
    num_el_to_keep = int(flat_sd.numel() * alpha)
    parameters, indices = torch.topk(flat_sd, num_el_to_keep, largest=True)
    return parameters, indices
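# Example with alpha = 0.5 (assumed toy values):
#   sd = OrderedDict([("w", torch.tensor([0.1, -3.0, 2.0, 0.4]))])
#   top_k(sd, 0.5)  # -> values tensor([2.0, 0.4]), indices tensor([2, 3])
# Large negative weights such as -3.0 are never selected; ranking by
# flat_sd.abs() instead would give magnitude-based top-k.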
@contextlib.contextmanager
def temp_seed(seed):
    """
    Temporarily seeds numpy's global RNG, restoring the previous
    RNG state when the context exits.
    """
    state = np.random.get_state()
    np.random.seed(seed)
try:
yield
finally:
np.random.set_state(state)
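# Usage sketch: the same seed reproduces the same draw, and the global
# RNG state outside the block is left untouched.
#   with temp_seed(42):
#       a = np.random.uniform(1, 10, size=3)
#   with temp_seed(42):
#       b = np.random.uniform(1, 10, size=3)
#   assert (a == b).all()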
class SecureCompressedAggregatopn(DPSGDNode):
"""
This class defines the node for secure compressed aggregation
"""
def get_neighbors(self, node=None):
if node is None:
node = self.uid
return self.graph.neighbors(node)
def get_distance2_neighbors(self, start_node=None):
"""
Get all nodes 2 hops away from the start node exluding itself
"""
if start_node is None:
start_node = self.uid
nodes = self.graph.nodes_at_distance(start_node, 2)
nodes.remove(start_node)
return nodes
def receive_DPSGD(self):
return self.receive_channel("DPSGD")
def connect_to_nodes(self, set_of_nodes):
"""
Connects all neighbors. Sends HELLO. Waits for HELLO.
Caches any data received while waiting for HELLOs.
Raises
------
RuntimeError
If received BYE while waiting for HELLO
"""
logging.info("Sending connection request to all neighbors")
wait_acknowledgements = []
for node in set_of_nodes:
if not self.communication.already_connected(node):
self.connect_neighbor(node)
wait_acknowledgements.append(node)
for node in wait_acknowledgements:
self.wait_for_hello(node)
def aggregate_models(self, parameters, indices):
distance2_nodes = self.get_distance2_neighbors()
self.connect_to_nodes(distance2_nodes)
compressed_indices = self.sharing.compressor.compress(indices.numpy())
# Generating and sending pairwise masks
sent_masks = {}
for node in distance2_nodes:
mask_seed = np.random.randint(1000000)
sent_masks[node] = mask_seed
self.communication.send(node, {
"seed": mask_seed,
"indices": compressed_indices,
"CHANNEL": "PRE-SECURE-AGG-STEP"
})
# Receiving pairwise masks and indices
received_data = {}
waiting_mask_from = distance2_nodes.copy()
while waiting_mask_from:
sender, data = self.receive_channel("PRE-SECURE-AGG-STEP")
if sender in waiting_mask_from:
                del data["CHANNEL"]
received_data[sender] = data
received_data[sender]["indices"] = torch.tensor(
self.sharing.compressor.decompress(received_data[sender]["indices"]), dtype=torch.long)
waiting_mask_from.remove(sender)
# Building masks
pairwise_mask_difference = {}
for node, data in received_data.items():
            # np.intersect1d returns the sorted common values plus their
            # positions in each input array; the torch index tensors are
            # converted to numpy arrays implicitly.
_, my_indices_pos, _ = np.intersect1d(indices, data["indices"], return_indices=True)
mask_shape = my_indices_pos.shape
pairwise_mask_difference[node] = {
"value": (self.generate_mask(sent_masks[node], mask_shape) - self.generate_mask(data["seed"], mask_shape)).double(),
"indices": my_indices_pos
}
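        # Why the differences cancel: suppose nodes A and B share a common
        # neighbor C, A sent B the seed s_ab, and B sent A the seed s_ba.
        # On their common top-k coordinates A adds mask(s_ab) - mask(s_ba)
        # while B adds mask(s_ba) - mask(s_ab) to the models they send to C.
        # Both draw their masks over the same sorted index intersection, so
        # the perturbations cancel exactly when C sums the two received
        # models, yet C never observes either unmasked model.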
# Sending models to neighbors
self.my_neighbors = self.get_neighbors()
self.connect_to_nodes(self.my_neighbors)
        for neighbor in self.my_neighbors:
            neighbors_neighbors = self.get_neighbors(neighbor)
            perturbed_model = parameters.clone()
            masking_count = torch.zeros_like(indices)
            # Add the pairwise mask differences for every node that shares
            # this neighbor with us; they cancel once the neighbor sums the
            # models it receives.
            for pairing_node in neighbors_neighbors:
                if self.uid == pairing_node:
                    continue
                pair_mask = pairwise_mask_difference[pairing_node]["value"]
                pair_indices = pairwise_mask_difference[pairing_node]["indices"]
                perturbed_model[pair_indices] += pair_mask
                masking_count[pair_indices] += 1
            # Send only the coordinates that received at least one mask:
            # unmasked coordinates would leak the raw parameter values.
            non_zero_indices = masking_count.nonzero(as_tuple=True)[0]
            indices_to_send = indices[non_zero_indices]
            # Send the masked model, not the raw parameters (sending
            # `parameters` here would bypass the masking entirely).
            parameters_to_send = perturbed_model[non_zero_indices]
            compressed_parameters = self.sharing.compressor.compress_float(
                parameters_to_send.numpy()
            )
            compressed_indices = self.sharing.compressor.compress(
                indices_to_send.numpy()
            )
self.communication.send(neighbor, {
"params": compressed_parameters,
"indices": compressed_indices,
"CHANNEL": "SECURE_MODEL_CHANNEL"
})
# Receiving models from neighbors
received_models = {}
waiting_models_from = self.my_neighbors.copy()
while waiting_models_from:
sender, data = self.receive_channel("SECURE_MODEL_CHANNEL")
if sender in waiting_models_from:
                del data["CHANNEL"]
received_models[sender] = data
received_models[sender]["indices"] = torch.tensor(
self.sharing.compressor.decompress(received_models[sender]["indices"]), dtype=torch.long)
received_models[sender]["params"] = torch.tensor(
self.sharing.compressor.decompress_float(received_models[sender]["params"]))
waiting_models_from.remove(sender)
# Averaging
weight = 1 / (len(self.my_neighbors) + 1)
preshare_model = flatten_state_dict(self.model.state_dict())
new_flat_model = weight * preshare_model.clone()
        for data in received_models.values():
            recv_params = data["params"]
            recv_indices = data["indices"]
            # Coordinates a neighbor did not send fall back to our own
            # pre-sharing values.
            recovered_model = preshare_model.clone()
            recovered_model[recv_indices] = recv_params
            new_flat_model += weight * recovered_model
        # Load the new state dictionary into the model
new_state_dict = unflatten_state_dict(new_flat_model, self.model.state_dict())
self.model.load_state_dict(new_state_dict)
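    # Averaging sketch: with n neighbors every model gets equal weight
    # w = 1 / (n + 1), i.e. new = w * own + sum_i (w * recovered_i).
    # Coordinates that no neighbor sent are recovered from our own
    # pre-sharing model in every term, so they keep their value with
    # full weight instead of being dragged toward zero.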
    def generate_mask(self, seed, size):
        """
        Deterministically generates a random mask from a seed, without
        disturbing the global numpy RNG state.
        """
        with temp_seed(seed):
            return torch.Tensor(np.random.uniform(1, 10, size=size))
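    # Note: the masks are uniform on [1, 10) and not zero-mean; this is
    # safe only because every mask enters once with + and once with -, so
    # the pairwise differences cancel regardless of the distribution.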
def run(self):
"""
Start the decentralized learning
"""
torch.manual_seed(self.uid)
np.random.seed(self.uid)
self.testset = self.dataset.get_testset()
rounds_to_test = self.test_after
rounds_to_train_evaluate = self.train_evaluate_after
global_epoch = 1
change = 1
for iteration in range(self.iterations):
logging.info("Starting training iteration: %d", iteration)
print("Starting training iteration:", iteration)
rounds_to_train_evaluate -= 1
rounds_to_test -= 1
self.iteration = iteration
self.trainer.train(self.dataset)
            # Securely share and average the top 30% of model parameters.
            self.aggregate_models(*top_k(self.model.state_dict(), 0.3))
if self.reset_optimizer:
self.optimizer = self.optimizer_class(
self.model.parameters(), **self.optimizer_params
) # Reset optimizer state
self.trainer.reset_optimizer(self.optimizer)
if iteration:
with open(
os.path.join(self.log_dir, "{}_results.json".format(self.rank)),
"r",
) as inf:
results_dict = json.load(inf)
else:
results_dict = {
"train_loss": {},
"test_loss": {},
"test_acc": {},
"total_bytes": {},
"total_meta": {},
"total_data_per_n": {},
}
results_dict["total_bytes"][iteration + 1] = self.communication.total_bytes
if hasattr(self.communication, "total_meta"):
results_dict["total_meta"][
iteration + 1
] = self.communication.total_meta
if hasattr(self.communication, "total_data"):
results_dict["total_data_per_n"][
iteration + 1
] = self.communication.total_data
if rounds_to_train_evaluate == 0:
logging.info("Evaluating on train set.")
rounds_to_train_evaluate = self.train_evaluate_after * change
loss_after_sharing = self.trainer.eval_loss(self.dataset)
results_dict["train_loss"][iteration + 1] = loss_after_sharing
self.save_plot(
results_dict["train_loss"],
"train_loss",
"Training Loss",
"Communication Rounds",
os.path.join(self.log_dir, "{}_train_loss.png".format(self.rank)),
)
if self.dataset.__testing__ and rounds_to_test == 0:
rounds_to_test = self.test_after * change
logging.info("Evaluating on test set.")
ta, tl = self.dataset.test(self.model, self.loss)
results_dict["test_acc"][iteration + 1] = ta
results_dict["test_loss"][iteration + 1] = tl
if global_epoch == 49:
change *= 2
global_epoch += change
with open(
os.path.join(self.log_dir, "{}_results.json".format(self.rank)), "w"
) as of:
json.dump(results_dict, of)
if self.model.shared_parameters_counter is not None:
logging.info("Saving the shared parameter counts")
with open(
os.path.join(
self.log_dir, "{}_shared_parameters.json".format(self.rank)
),
"w",
) as of:
json.dump(self.model.shared_parameters_counter.numpy().tolist(), of)
self.disconnect_neighbors()
logging.info("Storing final weight")
self.model.dump_weights(self.weights_store_dir, self.uid, iteration)
logging.info("All neighbors disconnected. Process complete!")