From a21933e82a18bd5cf50c021020bb36bef3dc6ddd Mon Sep 17 00:00:00 2001
From: Milos Vujasinovic <milos.vujasinovic@epfl.ch>
Date: Tue, 8 Nov 2022 17:17:43 +0100
Subject: [PATCH] Added prototype for secure aggregation

---
 eval/testing_secure.py                        |  79 +++++
 src/decentralizepy/graphs/Graph.py            |  28 ++
 .../node/SecureCompressedAggregation.py       | 330 ++++++++++++++++++
 3 files changed, 437 insertions(+)
 create mode 100644 eval/testing_secure.py
 create mode 100644 src/decentralizepy/node/SecureCompressedAggregation.py

diff --git a/eval/testing_secure.py b/eval/testing_secure.py
new file mode 100644
index 0000000..6d03d31
--- /dev/null
+++ b/eval/testing_secure.py
@@ -0,0 +1,79 @@
+import logging
+from pathlib import Path
+from shutil import copy
+
+from localconfig import LocalConfig
+from torch import multiprocessing as mp
+
+from decentralizepy import utils
+from decentralizepy.graphs.Graph import Graph
+from decentralizepy.mappings.Linear import Linear
+from decentralizepy.node.SecureCompressedAggregation import SecureCompressedAggregation
+
+
+def read_ini(file_path):
+    config = LocalConfig(file_path)
+    for section in config:
+        print("Section: ", section)
+        for key, value in config.items(section):
+            print((key, value))
+    print(dict(config.items("DATASET")))
+    return config
+
+
+if __name__ == "__main__":
+    args = utils.get_args()
+
+    Path(args.log_dir).mkdir(parents=True, exist_ok=True)
+
+    log_level = {
+        "INFO": logging.INFO,
+        "DEBUG": logging.DEBUG,
+        "WARNING": logging.WARNING,
+        "ERROR": logging.ERROR,
+        "CRITICAL": logging.CRITICAL,
+    }
+
+    config = read_ini(args.config_file)
+    my_config = dict()
+    for section in config:
+        my_config[section] = dict(config.items(section))
+
+    copy(args.config_file, args.log_dir)
+    copy(args.graph_file, args.log_dir)
+    utils.write_args(args, args.log_dir)
+
+    g = Graph()
+    g.read_graph_from_file(args.graph_file, args.graph_type)
+    n_machines = args.machines
+    procs_per_machine = args.procs_per_machine
+    l = Linear(n_machines, procs_per_machine)
+    m_id = args.machine_id
+
+    processes = []
+    for r in range(procs_per_machine):
+        processes.append(
+            mp.Process(
+                target=SecureCompressedAggregation,
+                args=[
+                    r,
+                    m_id,
+                    l,
+                    g,
+                    my_config,
+                    args.iterations,
+                    args.log_dir,
+                    args.weights_store_dir,
+                    log_level[args.log_level],
+                    args.test_after,
+                    args.train_evaluate_after,
+                    args.reset_optimizer,
+                ],
+            )
+        )
+
+    for p in processes:
+        p.start()
+
+    for p in processes:
+        p.join()
diff --git a/src/decentralizepy/graphs/Graph.py b/src/decentralizepy/graphs/Graph.py
index dc66eef..0889052 100644
--- a/src/decentralizepy/graphs/Graph.py
+++ b/src/decentralizepy/graphs/Graph.py
@@ -152,6 +153,33 @@ class Graph:
         """
         return self.adj_list[uid]
 
+    def nodes_at_distance(self, uid, distance):
+        """
+        Returns the set of all vertices reachable from a node by a walk
+        of exactly `distance` edges. On graphs with cycles this may also
+        include vertices that are closer, e.g. the node itself.
+
+        Parameters
+        ----------
+        uid : int
+            globally unique identifier of the node
+        distance : int
+            number of hops (edges) of the walks considered
+
+        Returns
+        -------
+        set(int)
+            a set of nodes at the given distance from uid
+
+        Raises
+        ------
+        ValueError
+            If distance is negative
+
+        """
+        if distance == 1:
+            # Return a copy so callers cannot mutate the adjacency list
+            return set(self.adj_list[uid])
+        elif distance == 0:
+            return {uid}
+        elif distance < 0:
+            raise ValueError("Negative distance")
+        vertex_set = set()
+        for neighbor in self.adj_list[uid]:
+            vertex_set |= self.nodes_at_distance(neighbor, distance - 1)
+        return vertex_set
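+
+    # Illustrative example (path graph 0-1-2-3): nodes_at_distance(1, 2) returns
+    # {1, 3}; vertex 3 is two hops away, and vertex 1 is reached again via the
+    # walk 1 -> 0 -> 1, because walks rather than shortest paths are counted.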
+
     def centr(self):
         my_adj = {x: list(adj) for x, adj in enumerate(self.adj_list)}
         nxGraph = nx.Graph(my_adj)
diff --git a/src/decentralizepy/node/SecureCompressedAggregation.py b/src/decentralizepy/node/SecureCompressedAggregation.py
new file mode 100644
index 0000000..465be84
--- /dev/null
+++ b/src/decentralizepy/node/SecureCompressedAggregation.py
@@ -0,0 +1,330 @@
+import importlib
+import json
+import logging
+import math
+import os
+from collections import deque
+
+import torch
+from matplotlib import pyplot as plt
+from collections import OrderedDict
+
+import numpy as np
+import contextlib
+
+from decentralizepy import utils
+from decentralizepy.graphs.Graph import Graph
+from decentralizepy.mappings.Mapping import Mapping
+from decentralizepy.node.Node import Node
+from decentralizepy.node.DPSGDNode import DPSGDNode
+
+def flatten_state_dict(state_dict):
+    """
+    Transforms a state dictionary into a flat tensor
+    by flattening and concatenating the tensors of the
+    state dictionary.
+
+    Note: changes made to the result won't affect the state dictionary
+
+    Parameters
+    ----------
+    state_dict : OrderedDict[str, torch.Tensor]
+        A state dictionary to flatten
+
+    Returns
+    -------
+    torch.Tensor
+        A 1-dimensional tensor containing all parameters
+
+    """
+    return torch.cat(
+        [tensor.flatten() for tensor in state_dict.values()], dim=0
+    )
+
+def unflatten_state_dict(flat_tensor, reference_state_dict):
+    """
+    Transforms a flat tensor into a state dictionary
+    by using another state dictionary as a reference
+    for the sizes and names of the tensors. Assumes
+    that the number of elements of the flat tensor
+    is the same as the number of elements in the
+    reference state dictionary.
+
+    This operation is the inverse of flatten_state_dict.
+
+    Note: changes made to the result will affect the flat tensor
+
+    Parameters
+    ----------
+    flat_tensor : torch.Tensor
+        A 1-dimensional tensor
+    reference_state_dict : OrderedDict[str, torch.Tensor]
+        A state dictionary used as a reference for the tensor names
+        and shapes of the result
+
+    Returns
+    -------
+    OrderedDict[str, torch.Tensor]
+        A state dictionary with the same keys and shapes as the reference
+
+    """
+    result = OrderedDict()
+    start_index = 0
+    for tensor_name, tensor in reference_state_dict.items():
+        end_index = start_index + tensor.numel()
+        result[tensor_name] = flat_tensor[start_index:end_index].reshape(
+            tensor.shape)
+        start_index = end_index
+    return result
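+
+# Illustrative round trip (not executed): for a reference state dict such as
+# sd = OrderedDict({"w": torch.ones(2, 2), "b": torch.zeros(2)}),
+# unflatten_state_dict(flatten_state_dict(sd), sd) recovers tensors with the
+# same keys, shapes ((2, 2) and (2,)) and values.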
+
+def top_k(state_dict, alpha):
+    # Keeps the alpha fraction of entries of the flattened state dictionary
+    # with the largest raw values (not the largest magnitudes).
+    flat_sd = flatten_state_dict(state_dict)
+    num_el_to_keep = int(flat_sd.numel() * alpha)
+    parameters, indices = torch.topk(flat_sd, num_el_to_keep, largest=True)
+    return parameters, indices
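+
+# Illustrative: with alpha=0.3 and a flattened model of 1,000,000 entries,
+# top_k keeps the values and flat indices of the 300,000 largest entries.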
+
+
+@contextlib.contextmanager
+def temp_seed(seed):
+    """
+    Temporarily seeds the global NumPy RNG and restores
+    the previous RNG state on exit.
+
+    """
+    state = np.random.get_state()
+    np.random.seed(seed)
+    try:
+        yield
+    finally:
+        np.random.set_state(state)
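+
+# Illustrative: two nodes that exchanged the seed 42 both obtain the same mask
+# from np.random.uniform inside temp_seed(42), while the RNG stream used for
+# training outside the with-block is left untouched.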
+
+class SecureCompressedAggregation(DPSGDNode):
+    """
+    This class defines the node for secure compressed aggregation
+
+    """
+
+    def get_neighbors(self, node=None):
+        if node is None:
+            node = self.uid
+        return self.graph.neighbors(node)
+
+    def get_distance2_neighbors(self, start_node=None):
+        """
+        Get all nodes reachable in exactly 2 hops from the start node,
+        excluding the start node itself
+
+        """
+        if start_node is None:
+            start_node = self.uid
+        nodes = self.graph.nodes_at_distance(start_node, 2)
+        # discard rather than remove: an isolated node is not in its own 2-hop set
+        nodes.discard(start_node)
+        return nodes
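+
+    # Illustrative: on a ring 0-1-2-3-4-0, node 0's distance-2 set is {2, 3};
+    # these are exactly the nodes with which node 0 exchanges mask seeds in
+    # aggregate_models below.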
+
+    def receive_DPSGD(self):
+        return self.receive_channel("DPSGD")
+
+    def connect_to_nodes(self, set_of_nodes):
+        """
+        Connects to all nodes in the given set. Sends HELLO. Waits for HELLO.
+        Caches any data received while waiting for HELLOs.
+
+        Parameters
+        ----------
+        set_of_nodes : set(int)
+            Nodes to connect to
+
+        Raises
+        ------
+        RuntimeError
+            If received BYE while waiting for HELLO
+
+        """
+        logging.info("Sending connection requests to all nodes in the given set")
+        wait_acknowledgements = []
+        for node in set_of_nodes:
+            if not self.communication.already_connected(node):
+                self.connect_neighbor(node)
+                wait_acknowledgements.append(node)
+        for node in wait_acknowledgements:
+            self.wait_for_hello(node)
+
+    def aggregate_models(self, parameters, indices):
+        distance2_nodes = self.get_distance2_neighbors()
+        self.connect_to_nodes(distance2_nodes)
+        compressed_indices = self.sharing.compressor.compress(indices.numpy())
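+        # Protocol sketch: every pair of nodes that shares a common neighbor
+        # exchanges a seed; node i later adds mask(seed_i) - mask(seed_j) to the
+        # values it shares through that neighbor, while node j adds
+        # mask(seed_j) - mask(seed_i), so the masks cancel once the common
+        # neighbor averages the models it receives.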
+        # Generating and sending pairwise masks
+        sent_masks = {}
+        for node in distance2_nodes:
+            mask_seed = np.random.randint(1000000)
+            sent_masks[node] = mask_seed
+            self.communication.send(node, {
+                "seed": mask_seed,
+                "indices": compressed_indices,
+                "CHANNEL": "PRE-SECURE-AGG-STEP"
+            })
+        # Receiving pairwise masks and indices
+        received_data = {}
+        waiting_mask_from = distance2_nodes.copy()
+        while waiting_mask_from:
+            sender, data = self.receive_channel("PRE-SECURE-AGG-STEP")
+            if sender in waiting_mask_from:
+                del data["CHANNEL"]
+                received_data[sender] = data
+                received_data[sender]["indices"] = torch.tensor(
+                    self.sharing.compressor.decompress(received_data[sender]["indices"]), dtype=torch.long)
+                waiting_mask_from.remove(sender)
+        # Building masks
+        pairwise_mask_difference = {}
+        for node, data in received_data.items():
+            # np.intersect1d accepts the tensors directly (they are converted to
+            # numpy arrays); sortednp.intersect would be a faster alternative for
+            # sorted index arrays
+            _, my_indices_pos, _ = np.intersect1d(
+                indices, data["indices"], return_indices=True
+            )
+            mask_shape = my_indices_pos.shape
+
+            pairwise_mask_difference[node] = {
+                "value": (
+                    self.generate_mask(sent_masks[node], mask_shape)
+                    - self.generate_mask(data["seed"], mask_shape)
+                ).double(),
+                "indices": my_indices_pos,
+            }
+        # Sending models to neighbors
+        self.my_neighbors = self.get_neighbors()
+        self.connect_to_nodes(self.my_neighbors)
+        for neighbor in self.my_neighbors:
+            neighbors_neighbors = self.get_neighbors(neighbor)
+            perturbed_model = parameters.clone()
+            masking_count = torch.zeros_like(indices)
+            # Add masks for every pair that shares this neighbor
+            for pairing_node in neighbors_neighbors:
+                if self.uid == pairing_node:
+                    continue
+                pair_mask = pairwise_mask_difference[pairing_node]["value"]
+                pair_indices = pairwise_mask_difference[pairing_node]["indices"]
+                perturbed_model[pair_indices] += pair_mask
+                masking_count[pair_indices] += 1
+            # Only entries covered by at least one pairwise mask are sent
+            non_zero_indices = masking_count.nonzero(as_tuple=True)[0]
+            indices_to_send = indices[non_zero_indices]
+            # Send the masked values so that the pairwise masks cancel out
+            # at the receiving neighbor
+            parameters_to_send = perturbed_model[non_zero_indices]
+
+            compressed_parameters = self.sharing.compressor.compress_float(parameters_to_send.numpy())
+            compressed_indices = self.sharing.compressor.compress(indices_to_send.numpy())
+
+            self.communication.send(neighbor, {
+                "params": compressed_parameters,
+                "indices": compressed_indices,
+                "CHANNEL": "SECURE_MODEL_CHANNEL"
+            })
+        # Receiving models from neighbors
+        received_models = {}
+        waiting_models_from = self.my_neighbors.copy()
+        while waiting_models_from:
+            sender, data = self.receive_channel("SECURE_MODEL_CHANNEL")
+            if sender in waiting_models_from:
+                del data["CHANNEL"]
+                received_models[sender] = data
+                received_models[sender]["indices"] = torch.tensor(
+                    self.sharing.compressor.decompress(received_models[sender]["indices"]), dtype=torch.long)
+                received_models[sender]["params"] = torch.tensor(
+                    self.sharing.compressor.decompress_float(received_models[sender]["params"]))
+                waiting_models_from.remove(sender)
+        # Averaging
+        weight = 1 / (len(self.my_neighbors) + 1)
+        preshare_model = flatten_state_dict(self.model.state_dict())
+        new_flat_model = weight * preshare_model.clone()
+        for data in received_models.values():
+            received_params = data["params"]
+            received_indices = data["indices"]
+            # Entries a neighbor did not send fall back to this node's own
+            # pre-sharing values
+            recovered_model = preshare_model.clone()
+            recovered_model[received_indices] = received_params
+            new_flat_model += weight * recovered_model
+        # Loading the new state dictionary
+        new_state_dict = unflatten_state_dict(new_flat_model, self.model.state_dict())
+        self.model.load_state_dict(new_state_dict)
+
+    def generate_mask(self, seed, size):
+        # Deterministic mask derived from the seed; temp_seed keeps the global
+        # NumPy RNG state used elsewhere untouched
+        with temp_seed(seed):
+            return torch.Tensor(np.random.uniform(1, 10, size=size))
+
+    def run(self):
+        """
+        Start the decentralized learning
+
+        """
+        torch.manual_seed(self.uid)
+        np.random.seed(self.uid)
+
+        self.testset = self.dataset.get_testset()
+        rounds_to_test = self.test_after
+        rounds_to_train_evaluate = self.train_evaluate_after
+        global_epoch = 1
+        change = 1
+
+        for iteration in range(self.iterations):
+            logging.info("Starting training iteration: %d", iteration)
+            print("Starting training iteration:", iteration)
+            rounds_to_train_evaluate -= 1
+            rounds_to_test -= 1
+
+            self.iteration = iteration
+            self.trainer.train(self.dataset)
+
+            self.aggregate_models(*top_k(self.model.state_dict(), 0.3))
+
+            if self.reset_optimizer:
+                self.optimizer = self.optimizer_class(
+                    self.model.parameters(), **self.optimizer_params
+                )  # Reset optimizer state
+                self.trainer.reset_optimizer(self.optimizer)
+
+            if iteration:
+                with open(
+                    os.path.join(self.log_dir, "{}_results.json".format(self.rank)),
+                    "r",
+                ) as inf:
+                    results_dict = json.load(inf)
+            else:
+                results_dict = {
+                    "train_loss": {},
+                    "test_loss": {},
+                    "test_acc": {},
+                    "total_bytes": {},
+                    "total_meta": {},
+                    "total_data_per_n": {},
+                }
+
+            results_dict["total_bytes"][iteration + 1] = self.communication.total_bytes
+
+            if hasattr(self.communication, "total_meta"):
+                results_dict["total_meta"][
+                    iteration + 1
+                ] = self.communication.total_meta
+            if hasattr(self.communication, "total_data"):
+                results_dict["total_data_per_n"][
+                    iteration + 1
+                ] = self.communication.total_data
+
+            if rounds_to_train_evaluate == 0:
+                logging.info("Evaluating on train set.")
+                rounds_to_train_evaluate = self.train_evaluate_after * change
+                loss_after_sharing = self.trainer.eval_loss(self.dataset)
+                results_dict["train_loss"][iteration + 1] = loss_after_sharing
+                self.save_plot(
+                    results_dict["train_loss"],
+                    "train_loss",
+                    "Training Loss",
+                    "Communication Rounds",
+                    os.path.join(self.log_dir, "{}_train_loss.png".format(self.rank)),
+                )
+
+            if self.dataset.__testing__ and rounds_to_test == 0:
+                rounds_to_test = self.test_after * change
+                logging.info("Evaluating on test set.")
+                ta, tl = self.dataset.test(self.model, self.loss)
+                results_dict["test_acc"][iteration + 1] = ta
+                results_dict["test_loss"][iteration + 1] = tl
+
+                if global_epoch == 49:
+                    change *= 2
+
+                global_epoch += change
+
+            with open(
+                os.path.join(self.log_dir, "{}_results.json".format(self.rank)), "w"
+            ) as of:
+                json.dump(results_dict, of)
+        if self.model.shared_parameters_counter is not None:
+            logging.info("Saving the shared parameter counts")
+            with open(
+                os.path.join(
+                    self.log_dir, "{}_shared_parameters.json".format(self.rank)
+                ),
+                "w",
+            ) as of:
+                json.dump(self.model.shared_parameters_counter.numpy().tolist(), of)
+        self.disconnect_neighbors()
+        logging.info("Storing final weight")
+        self.model.dump_weights(self.weights_store_dir, self.uid, iteration)
+        logging.info("All neighbors disconnected. Process complete!")
-- 
GitLab