diff --git a/eval/testing_secure.py b/eval/testing_secure.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d03d318d0f48865c9d64c19c624a64660e11012
--- /dev/null
+++ b/eval/testing_secure.py
@@ -0,0 +1,79 @@
+import logging
+from pathlib import Path
+from shutil import copy
+
+from localconfig import LocalConfig
+from torch import multiprocessing as mp
+
+from decentralizepy import utils
+from decentralizepy.graphs.Graph import Graph
+from decentralizepy.mappings.Linear import Linear
+from decentralizepy.node.SecureCompressedAggregatopn import SecureCompressedAggregatopn
+
+
+def read_ini(file_path):
+    config = LocalConfig(file_path)
+    for section in config:
+        print("Section: ", section)
+        for key, value in config.items(section):
+            print((key, value))
+    print(dict(config.items("DATASET")))
+    return config
+
+
+if __name__ == "__main__":
+    args = utils.get_args()
+
+    Path(args.log_dir).mkdir(parents=True, exist_ok=True)
+
+    log_level = {
+        "INFO": logging.INFO,
+        "DEBUG": logging.DEBUG,
+        "WARNING": logging.WARNING,
+        "ERROR": logging.ERROR,
+        "CRITICAL": logging.CRITICAL,
+    }
+
+    config = read_ini(args.config_file)
+    my_config = dict()
+    for section in config:
+        my_config[section] = dict(config.items(section))
+
+    copy(args.config_file, args.log_dir)
+    copy(args.graph_file, args.log_dir)
+    utils.write_args(args, args.log_dir)
+
+    g = Graph()
+    g.read_graph_from_file(args.graph_file, args.graph_type)
+    n_machines = args.machines
+    procs_per_machine = args.procs_per_machine
+    l = Linear(n_machines, procs_per_machine)
+    m_id = args.machine_id
+
+    processes = []
+    for r in range(procs_per_machine):
+        processes.append(
+            mp.Process(
+                target=SecureCompressedAggregatopn,
+                args=[
+                    r,
+                    m_id,
+                    l,
+                    g,
+                    my_config,
+                    args.iterations,
+                    args.log_dir,
+                    args.weights_store_dir,
+                    log_level[args.log_level],
+                    args.test_after,
+                    args.train_evaluate_after,
+                    args.reset_optimizer,
+                ],
+            )
+        )
+
+    for p in processes:
+        p.start()
+
+    for p in processes:
+        p.join()
diff --git a/src/decentralizepy/graphs/Graph.py b/src/decentralizepy/graphs/Graph.py
index dc66eef2c3b0f5cefbec3b5d5b05400c965eaf90..08890529168a6f04f16e4d31e5b8f2ff9592b847 100644
--- a/src/decentralizepy/graphs/Graph.py
+++ b/src/decentralizepy/graphs/Graph.py
@@ -152,6 +152,35 @@ class Graph:
         """
         return self.adj_list[uid]
 
+    def nodes_at_distance(self, uid, distance):
+        """
+        Returns the set of all vertices at the given
+        distance from a node.
+
+        Parameters
+        ----------
+        uid : int
+            globally unique identifier of the node
+        distance : int
+            number of hops from uid
+
+        Returns
+        -------
+        set(int)
+            a set of nodes at the given distance from uid
+
+        """
+        if distance == 1:
+            return self.adj_list[uid]
+        elif distance == 0:
+            return {uid}
+        elif distance < 0:
+            raise ValueError("Negative distance")
+        vertex_set = set()
+        for neighbor in self.adj_list[uid]:
+            vertex_set |= self.nodes_at_distance(neighbor, distance - 1)
+        return vertex_set
+
     def centr(self):
         my_adj = {x: list(adj) for x, adj in enumerate(self.adj_list)}
         nxGraph = nx.Graph(my_adj)
diff --git a/src/decentralizepy/node/SecureCompressedAggregatopn.py b/src/decentralizepy/node/SecureCompressedAggregatopn.py
new file mode 100644
index 0000000000000000000000000000000000000000..465be84667608b8944c0b46c060ba68b07c79204
--- /dev/null
+++ b/src/decentralizepy/node/SecureCompressedAggregatopn.py
@@ -0,0 +1,330 @@
+import importlib
+import json
+import logging
+import math
+import os
+from collections import deque
+
+import torch
+from matplotlib import pyplot as plt
+from collections import OrderedDict
+
+import numpy as np
+import contextlib
+
+from decentralizepy import utils
+from decentralizepy.graphs.Graph import Graph
+from decentralizepy.mappings.Mapping import Mapping
+from decentralizepy.node.Node import Node
+from decentralizepy.node.DPSGDNode import DPSGDNode
+
+def flatten_state_dict(state_dict):
+    """
+    Transforms a state dictionary into a flat tensor
+    by flattening and concatenating the tensors of the
+    state dictionary.
+
+    Note: changes made to the result won't affect the state dictionary
+
+    Parameters
+    ----------
+    state_dict : OrderedDict[str, torch.tensor]
+        A state dictionary to flatten
+
+    """
+    return torch.cat([
+        tensor.flatten()
+        for tensor in state_dict.values()
+    ], axis=0)
+
+def unflatten_state_dict(flat_tensor, reference_state_dict):
+    """
+    Transforms a flat tensor into a state dictionary
+    by using another state dictionary as a reference
+    for the sizes and names of the tensors. Assumes
+    that the number of elements of the flat tensor
+    is the same as the number of elements in the
+    reference state dictionary.
+
+    This operation is the inverse of flatten_state_dict.
+
+    Note: changes made to the result will affect the flat tensor
+
+    Parameters
+    ----------
+    flat_tensor : torch.tensor
+        A 1-dim tensor
+    reference_state_dict : OrderedDict[str, torch.tensor]
+        A state dictionary used as a reference for the tensor names
+        and shapes of the result
+
+    """
+    result = OrderedDict()
+    start_index = 0
+    for tensor_name, tensor in reference_state_dict.items():
+        end_index = start_index + tensor.numel()
+        result[tensor_name] = flat_tensor[start_index:end_index].reshape(
+            tensor.shape)
+        start_index = end_index
+    return result
+
+def top_k(state_dict, alpha):
+    flat_sd = flatten_state_dict(state_dict)
+    num_el_to_keep = int(flat_sd.numel() * alpha)
+    parameters, indices = torch.topk(flat_sd, num_el_to_keep, largest=True)
+    return parameters, indices
+
+
+@contextlib.contextmanager
+def temp_seed(seed):
+    state = np.random.get_state()
+    np.random.seed(seed)
+    try:
+        yield
+    finally:
+        np.random.set_state(state)
+
+class SecureCompressedAggregatopn(DPSGDNode):
+    """
+    This class defines the node for secure compressed aggregation
+
+    """
+
+    def get_neighbors(self, node=None):
+        if node is None:
+            node = self.uid
+        return self.graph.neighbors(node)
+
+    def get_distance2_neighbors(self, start_node=None):
+        """
+        Get all nodes 2 hops away from the start node, excluding the start node itself
+
+        """
+        if start_node is None:
+            start_node = self.uid
+        nodes = self.graph.nodes_at_distance(start_node, 2)
+        nodes.remove(start_node)
+        return nodes
+
+    def receive_DPSGD(self):
+        return self.receive_channel("DPSGD")
+
+    def connect_to_nodes(self, set_of_nodes):
+        """
+        Connects to all nodes in the given set. Sends HELLO. Waits for HELLO.
+        Caches any data received while waiting for HELLOs.
+
+        Raises
+        ------
+        RuntimeError
+            If received BYE while waiting for HELLO
+
+        """
+        logging.info("Sending connection request to all neighbors")
+        wait_acknowledgements = []
+        for node in set_of_nodes:
+            if not self.communication.already_connected(node):
+                self.connect_neighbor(node)
+                wait_acknowledgements.append(node)
+        for node in wait_acknowledgements:
+            self.wait_for_hello(node)
+
+    def aggregate_models(self, parameters, indices):
+        distance2_nodes = self.get_distance2_neighbors()
+        self.connect_to_nodes(distance2_nodes)
+        compressed_indices = self.sharing.compressor.compress(indices.numpy())
+        # Generating and sending pairwise masks
+        sent_masks = {}
+        for node in distance2_nodes:
+            mask_seed = np.random.randint(1000000)
+            sent_masks[node] = mask_seed
+            self.communication.send(node, {
+                "seed": mask_seed,
+                "indices": compressed_indices,
+                "CHANNEL": "PRE-SECURE-AGG-STEP"
+            })
+        # Receiving pairwise masks and indices
+        received_data = {}
+        waiting_mask_from = distance2_nodes.copy()
+        while waiting_mask_from:
+            sender, data = self.receive_channel("PRE-SECURE-AGG-STEP")
+            if sender in waiting_mask_from:
+                # print('Seed from', sender, 'is', data["seed"])
+                del data["CHANNEL"]
+                received_data[sender] = data
+                received_data[sender]["indices"] = torch.tensor(
+                    self.sharing.compressor.decompress(received_data[sender]["indices"]), dtype=torch.long)
+                waiting_mask_from.remove(sender)
+        # Building masks
+        pairwise_mask_difference = {}
+        for node, data in received_data.items():
+            # positions (within our own top-k indices) of the coordinates shared with this node
+            _, my_indices_pos, _ = np.intersect1d(indices, data["indices"], return_indices=True)
+            mask_shape = my_indices_pos.shape
+
+            pairwise_mask_difference[node] = {
+                "value": (self.generate_mask(sent_masks[node], mask_shape) - self.generate_mask(data["seed"], mask_shape)).double(),
+                "indices": my_indices_pos
+            }
+        # Sending models to neighbors
+        self.my_neighbors = self.get_neighbors()
+        self.connect_to_nodes(self.my_neighbors)
+        for neighbor in self.my_neighbors:
+            neighbors_neighbors = self.get_neighbors(neighbor)
+            perturbated_model = parameters.clone()
+            masking_count = torch.zeros_like(indices)
+            # Add masks from pairs
+            for pairing_node in neighbors_neighbors:
+                if self.uid == pairing_node:
+                    continue
+                pair_mask = pairwise_mask_difference[pairing_node]["value"]
+                pair_indices = pairwise_mask_difference[pairing_node]["indices"]
+                # print(perturbated_model.shape, pair_indices.shape, pair_mask.shape)
+                # print(perturbated_model.dtype, pair_mask.dtype)
+                perturbated_model[pair_indices] += pair_mask
+                masking_count[pair_indices] += 1
+            non_zero_indices = masking_count.nonzero(as_tuple=True)[0]
+            indices_to_send = indices[non_zero_indices]
+            parameters_to_send = perturbated_model[non_zero_indices]  # send the masked values, not the raw parameters
+
+
+            compressed_parameters = self.sharing.compressor.compress_float(parameters_to_send.numpy())
+            compressed_indices = self.sharing.compressor.compress(indices_to_send.numpy())
+
+            self.communication.send(neighbor, {
+                "params": compressed_parameters,
+                "indices": compressed_indices,
+                "CHANNEL": "SECURE_MODEL_CHANNEL"
+            })
+        # Receiving models from neighbors
+        received_models = {}
+        waiting_models_from = self.my_neighbors.copy()
+        while waiting_models_from:
+            sender, data = self.receive_channel("SECURE_MODEL_CHANNEL")
+            if sender in waiting_models_from:
+                del data["CHANNEL"]
+                received_models[sender] = data
+                received_models[sender]["indices"] = torch.tensor(
+                    self.sharing.compressor.decompress(received_models[sender]["indices"]), dtype=torch.long)
+                received_models[sender]["params"] = torch.tensor(
+                    self.sharing.compressor.decompress_float(received_models[sender]["params"]))
+                waiting_models_from.remove(sender)
+        # Averaging
+        weight = 1 / (len(self.my_neighbors) + 1)
+        preshare_model = flatten_state_dict(self.model.state_dict())
+        new_flat_model = weight * preshare_model.clone()
+        for data in received_models.values():
+            recv_params = data["params"]
+            recv_indices = data["indices"]
+            recovered_model = preshare_model.clone()
+            recovered_model[recv_indices] = recv_params
+            new_flat_model += weight * recovered_model
+        # Loading the new state dictionary
+        new_state_dict = unflatten_state_dict(new_flat_model, self.model.state_dict())
+        self.model.load_state_dict(new_state_dict)
+
+    def generate_mask(self, seed, size):
+        with temp_seed(seed):
+            return torch.Tensor(np.random.uniform(1, 10, size=size))
+
+    def run(self):
+        """
+        Start the decentralized learning
+
+        """
+        torch.manual_seed(self.uid)
+        np.random.seed(self.uid)
+
+        self.testset = self.dataset.get_testset()
+        rounds_to_test = self.test_after
+        rounds_to_train_evaluate = self.train_evaluate_after
+        global_epoch = 1
+        change = 1
+
+        for iteration in range(self.iterations):
+            logging.info("Starting training iteration: %d", iteration)
+            print("Starting training iteration:", iteration)
+            rounds_to_train_evaluate -= 1
+            rounds_to_test -= 1
+
+            self.iteration = iteration
+            self.trainer.train(self.dataset)
+
+            self.aggregate_models(*top_k(self.model.state_dict(), 0.3))
+
+            if self.reset_optimizer:
+                self.optimizer = self.optimizer_class(
+                    self.model.parameters(), **self.optimizer_params
+                )  # Reset optimizer state
+                self.trainer.reset_optimizer(self.optimizer)
+
+            if iteration:
+                with open(
+                    os.path.join(self.log_dir, "{}_results.json".format(self.rank)),
+                    "r",
+                ) as inf:
+                    results_dict = json.load(inf)
+            else:
+                results_dict = {
+                    "train_loss": {},
+                    "test_loss": {},
+                    "test_acc": {},
+                    "total_bytes": {},
+                    "total_meta": {},
+                    "total_data_per_n": {},
+                }
+
+            results_dict["total_bytes"][iteration + 1] = self.communication.total_bytes
+
+            if hasattr(self.communication, "total_meta"):
+                results_dict["total_meta"][
+                    iteration + 1
+                ] = self.communication.total_meta
+            if hasattr(self.communication, "total_data"):
+                results_dict["total_data_per_n"][
+                    iteration + 1
+                ] = self.communication.total_data
+
+            if rounds_to_train_evaluate == 0:
+                logging.info("Evaluating on train set.")
+                rounds_to_train_evaluate = self.train_evaluate_after * change
+                loss_after_sharing = self.trainer.eval_loss(self.dataset)
+                results_dict["train_loss"][iteration + 1] = loss_after_sharing
+                self.save_plot(
+                    results_dict["train_loss"],
+                    "train_loss",
+                    "Training Loss",
+                    "Communication Rounds",
+                    os.path.join(self.log_dir, "{}_train_loss.png".format(self.rank)),
+                )
+
+            if self.dataset.__testing__ and rounds_to_test == 0:
+                rounds_to_test = self.test_after * change
+                logging.info("Evaluating on test set.")
+                ta, tl = self.dataset.test(self.model, self.loss)
+                results_dict["test_acc"][iteration + 1] = ta
+                results_dict["test_loss"][iteration + 1] = tl
+
+                if global_epoch == 49:
+                    change *= 2
+
+                global_epoch += change
+
+            with open(
+                os.path.join(self.log_dir, "{}_results.json".format(self.rank)), "w"
+            ) as of:
+                json.dump(results_dict, of)
+        if self.model.shared_parameters_counter is not None:
+            logging.info("Saving the shared parameter counts")
+            with open(
+                os.path.join(
+                    self.log_dir, "{}_shared_parameters.json".format(self.rank)
+                ),
+                "w",
+            ) as of:
+                json.dump(self.model.shared_parameters_counter.numpy().tolist(), of)
+        self.disconnect_neighbors()
+        logging.info("Storing final weight")
+        self.model.dump_weights(self.weights_store_dir, self.uid, iteration)
+        logging.info("All neighbors disconnected. Process complete!")