# File: eval/testingKNN.py
import logging
from pathlib import Path
from shutil import copy

from localconfig import LocalConfig
from torch import multiprocessing as mp

from decentralizepy import utils
from decentralizepy.graphs.Graph import Graph
from decentralizepy.mappings.Linear import Linear
from decentralizepy.node.KNN import KNN


def read_ini(file_path):
    """
    Load an ini file with LocalConfig and echo its contents to stdout.

    Parameters
    ----------
    file_path : str
        Path of the ini configuration file

    Returns
    -------
    LocalConfig
        The loaded configuration object

    """
    parsed = LocalConfig(file_path)
    for section_name in parsed:
        print("Section: ", section_name)
        for option, setting in parsed.items(section_name):
            print((option, setting))
    # The DATASET section is printed once more as a plain dict.
    print(dict(parsed.items("DATASET")))
    return parsed


def main():
    """
    Launch one KNN node process per local rank and wait for all of them.

    """
    args = utils.get_args()

    Path(args.log_dir).mkdir(parents=True, exist_ok=True)

    # Map the textual log level given on the command line to the logging constant.
    level_names = ("INFO", "DEBUG", "WARNING", "ERROR", "CRITICAL")
    log_level = {name: getattr(logging, name) for name in level_names}

    config = read_ini(args.config_file)
    my_config = {section: dict(config.items(section)) for section in config}

    # Keep a copy of the inputs next to the logs.
    copy(args.config_file, args.log_dir)
    copy(args.graph_file, args.log_dir)
    utils.write_args(args, args.log_dir)

    graph = Graph()
    graph.read_graph_from_file(args.graph_file, args.graph_type)
    n_machines = args.machines
    procs_per_machine = args.procs_per_machine
    mapping = Linear(n_machines, procs_per_machine)
    m_id = args.machine_id

    processes = [
        mp.Process(
            target=KNN,
            args=[
                rank,
                m_id,
                mapping,
                graph,
                my_config,
                args.iterations,
                args.log_dir,
                args.weights_store_dir,
                log_level[args.log_level],
                args.test_after,
                args.train_evaluate_after,
                args.reset_optimizer,
            ],
        )
        for rank in range(procs_per_machine)
    ]

    for proc in processes:
        proc.start()

    for proc in processes:
        proc.join()


if __name__ == "__main__":
    main()
# File: src/decentralizepy/node/KNN.py
import logging
import math
import os
import queue
from random import Random
from threading import Lock, Thread

import numpy as np
import torch
from numpy.linalg import norm

from decentralizepy import utils
from decentralizepy.graphs.Graph import Graph
from decentralizepy.mappings.Mapping import Mapping
from decentralizepy.node.OverlayNode import OverlayNode


class KNN(OverlayNode):
    """
    This class defines the node for KNN Learning Node

    """

    # Keys that carry protocol metadata rather than per-node information.
    _META_KEYS = ["RESPONSE", "INIT", "CHANNEL", "KNNRound"]

    def similarityMetric(self, candidate):
        """
        Cosine similarity between this node's info vector and a candidate's.

        Parameters
        ----------
        candidate : hashable
            Key of the candidate in ``self.othersInfo``

        Returns
        -------
        float
            ``dot(A, B) / (norm(A) * norm(B))``; 0.0 if either vector has
            zero norm (instead of NaN from a division by zero)

        """
        logging.debug("A: {}".format(self.othersInfo[self.uid]))
        logging.debug("B: {}".format(self.othersInfo[candidate]))
        A = np.array(self.othersInfo[self.uid])
        B = np.array(self.othersInfo[candidate])
        denominator = norm(A) * norm(B)
        if denominator == 0:
            # NaN scores would become distinct dict keys and break the sort
            # in get_most_similar, so treat a zero vector as "no similarity".
            return 0.0
        return np.dot(A, B) / denominator

    def get_most_similar(self, candidates, to_keep=4):
        """
        Select at most `to_keep` candidates ranked by `similarityMetric`.

        Candidates are bucketed by their score rounded to 3 decimals. Buckets
        are consumed in ascending score order; a bucket that does not fit
        entirely is sampled at random with ``self.rng``.

        Parameters
        ----------
        candidates : set
            Candidate keys; each must be present in ``self.othersInfo``
        to_keep : int
            Maximum number of candidates to return

        Returns
        -------
        set
            The selected candidates. When ``len(candidates) <= to_keep`` the
            input object itself is returned.

        """
        if len(candidates) <= to_keep:
            return candidates

        buckets = dict()
        for candidate in candidates:
            score = round(self.similarityMetric(candidate), 3)
            buckets.setdefault(score, []).append(candidate)

        # NOTE(review): ascending order keeps the candidates with the LOWEST
        # cosine score first. Confirm this is intended (e.g. a preference for
        # complementary label distributions); otherwise sort with reverse=True.
        left_to_keep = to_keep
        selected = set()
        for score in sorted(buckets):
            bucket = buckets[score]
            if left_to_keep >= len(bucket):
                selected.update(bucket)
                left_to_keep -= len(bucket)
            elif left_to_keep > 0:
                # Tie on the rounded score: break it randomly.
                selected.update(self.rng.sample(bucket, left_to_keep))
                left_to_keep = 0
                break
            else:
                break

        return selected

    def create_message_to_send(
        self,
        channel="KNNConstr",
        boolean_flags=(),
        add_my_info=False,
        add_neighbor_info=False,
    ):
        """
        Build a KNN protocol message.

        Parameters
        ----------
        channel : str
            Value stored under the "CHANNEL" key
        boolean_flags : iterable of str
            Each flag is added to the message as a key mapped to True
        add_my_info : bool
            If True, include this node's entry of ``self.othersInfo``
        add_neighbor_info : bool
            If True, include the known info of every node in ``self.out_edges``

        Returns
        -------
        dict
            The message

        """
        message = {"CHANNEL": channel, "KNNRound": self.knn_round}
        for flag in boolean_flags:
            message[flag] = True
        if add_my_info:
            message[self.uid] = self.othersInfo[self.uid]
        if add_neighbor_info:
            for neighbor in self.out_edges:
                if neighbor in self.othersInfo:
                    message[neighbor] = self.othersInfo[neighbor]
        return message

    def receive_KNN_message(self):
        """
        Non-blocking receive on the "KNNConstr" channel.

        Returns
        -------
        object
            Whatever ``receive_channel`` returns; the callers treat None as
            "no message available" and otherwise index it as (sender, data).

        """
        return self.receive_channel("KNNConstr", block=False)

    def process_init_receive(self, message):
        """
        Handle an INIT message (request or response) under the mutex.

        A response increments ``self.num_initializations``; a request is
        answered with an INIT+RESPONSE message carrying this node's info.
        In both cases the sender's payload is merged into ``self.othersInfo``.

        Parameters
        ----------
        message : tuple
            (sender, data) pair

        """
        with self.mutex:
            if "RESPONSE" in message[1]:
                self.num_initializations += 1
            else:
                self.communication.send(
                    message[0],
                    self.create_message_to_send(
                        boolean_flags=["INIT", "RESPONSE"], add_my_info=True
                    ),
                )
            self.othersInfo.update(self.remove_meta_from_message(message)[1])

    def remove_meta_from_message(self, message):
        """
        Strip the protocol metadata keys from a message payload.

        Parameters
        ----------
        message : tuple
            (sender, data) pair

        Returns
        -------
        tuple
            (sender, data without "RESPONSE", "INIT", "CHANNEL", "KNNRound")

        """
        return (
            message[0],
            utils.remove_keys(message[1], self._META_KEYS),
        )

    def process_candidates_without_lock(self, current_candidates, message):
        """
        Merge a message's node info and recompute ``self.out_edges``.

        The caller must hold ``self.mutex``. Does nothing once
        ``self.exit_receiver`` is set.

        Parameters
        ----------
        current_candidates : set
            Candidates known before the message is taken into account
        message : tuple
            (sender, data) pair whose data maps node keys to their info

        """
        if self.exit_receiver:
            return

        payload = self.remove_meta_from_message(message)[1]
        self.othersInfo.update(payload)
        current_candidates = current_candidates.union(payload.keys())
        # A node is never its own neighbor.
        current_candidates.discard(self.uid)
        self.out_edges = self.get_most_similar(current_candidates)

    def send_response(self, message, add_neighbor_info=False, process_candidates=False):
        """
        Answer a message with a RESPONSE carrying this node's info.

        Parameters
        ----------
        message : tuple
            (sender, data) pair being answered
        add_neighbor_info : bool
            If True, also include the info of the current out-neighbors
        process_candidates : bool
            If True, use the incoming message to update ``self.out_edges``

        """
        with self.mutex:
            logging.debug("Responding to {}".format(message[0]))
            self.communication.send(
                message[0],
                self.create_message_to_send(
                    boolean_flags=["RESPONSE"],
                    add_my_info=True,
                    add_neighbor_info=add_neighbor_info,
                ),
            )
            if process_candidates:
                self.process_candidates_without_lock(set(self.out_edges), message)

    def receiver_thread(self):
        """
        Receive and dispatch KNN messages until the protocol is over.

        The loop polls the "KNNConstr" channel without blocking. It returns
        once a KNNBYE was received from every other process and
        ``self.exit_receiver`` is set. Until ``self.initial_neighbors`` INIT
        responses have arrived, non-INIT messages are parked in a local queue
        and replayed afterwards.

        """
        knnBYEs = set()
        self.num_initializations = 0
        waiting_queue = queue.Queue()
        while True:
            if len(knnBYEs) == self.mapping.get_n_procs() - 1:
                with self.mutex:
                    should_exit = self.exit_receiver
                if should_exit:
                    logging.debug("Exiting thread")
                    return

            if self.num_initializations < self.initial_neighbors:
                # Initialization phase: only INIT messages are processed.
                x = self.receive_KNN_message()
                if x is None:
                    continue
                elif "INIT" in x[1]:
                    self.process_init_receive(x)
                else:
                    waiting_queue.put(x)
            else:
                logging.debug("Waiting for messages")
                # Parked messages take precedence over new ones.
                if waiting_queue.empty():
                    x = self.receive_KNN_message()
                    if x is None:
                        continue
                else:
                    x = waiting_queue.get()

                if "INIT" in x[1]:
                    logging.debug("A past INIT Message received from {}".format(x[0]))
                    self.process_init_receive(x)
                elif "RESPONSE" in x[1]:
                    logging.debug(
                        "A response message received from {} from KNNRound {}".format(
                            x[0], x[1]["KNNRound"]
                        )
                    )
                    x = self.remove_meta_from_message(x)
                    self.responseQueue.put(x)
                elif "RANDOM_DISCOVERY" in x[1]:
                    logging.debug(
                        "A Random Discovery message received from {} from KNNRound {}".format(
                            x[0], x[1]["KNNRound"]
                        )
                    )
                    self.send_response(
                        x, add_neighbor_info=False, process_candidates=False
                    )
                elif "KNNBYE" in x[1]:
                    with self.mutex:
                        knnBYEs.add(x[0])
                        logging.debug("{} KNN Byes received".format(knnBYEs))
                        # The sender lists its final out-neighbors in "CLOSE".
                        if self.uid in x[1]["CLOSE"]:
                            self.in_edges.add(x[0])
                else:
                    logging.debug(
                        "A KNN sharing message received from {} from KNNRound {}".format(
                            x[0], x[1]["KNNRound"]
                        )
                    )
                    self.send_response(
                        x, add_neighbor_info=True, process_candidates=True
                    )

    def build_topology(self, rounds=30, random_nodes=4):
        """
        Run the KNN rounds that determine ``self.out_edges``.

        Starts the receiver thread, sends INIT messages to the initial
        out-neighbors, then in every round exchanges neighbor info with one
        random out-neighbor and probes randomly chosen nodes. Finally sends a
        KNNBYE with the final out-neighbors to every other process and joins
        the receiver thread.

        Parameters
        ----------
        rounds : int
            Number of KNN rounds
        random_nodes : int
            Number of nodes sampled per round for random discovery (clamped
            to the number of processes)

        """
        self.knn_round = 0
        self.exit_receiver = False
        t = Thread(target=self.receiver_thread)

        t.start()

        # Initializations : Send my dataset info to others

        with self.mutex:
            initial_KNN_message = self.create_message_to_send(
                boolean_flags=["INIT"], add_my_info=True
            )
            for x in self.out_edges:
                self.communication.send(x, initial_KNN_message)

        n_procs = self.mapping.get_n_procs()

        for knn_round in range(rounds):
            self.knn_round = knn_round
            logging.info("Starting KNN Round {}".format(knn_round))
            with self.mutex:
                rand_neighbor = self.rng.choice(list(self.out_edges))
                logging.debug("Random neighbor: {}".format(rand_neighbor))
                self.communication.send(
                    rand_neighbor,
                    self.create_message_to_send(
                        add_my_info=True, add_neighbor_info=True
                    ),
                )

            logging.debug("Waiting for knn response from {}".format(rand_neighbor))

            response = self.responseQueue.get(block=True)

            logging.debug("Got response from random neighbor")

            with self.mutex:
                # random.sample raises ValueError when asked for more items
                # than the population holds, hence the clamp.
                random_candidates = set(
                    self.rng.sample(
                        list(range(n_procs)), min(random_nodes, n_procs)
                    )
                )

                req_responses = 0
                for rc in random_candidates:
                    logging.debug("Current random discovery: {}".format(rc))
                    if rc not in self.othersInfo and rc != self.uid:
                        logging.debug("Sending discovery request to {}".format(rc))
                        self.communication.send(
                            rc,
                            self.create_message_to_send(
                                boolean_flags=["RANDOM_DISCOVERY"]
                            ),
                        )
                        req_responses += 1

            while req_responses > 0:
                logging.debug(
                    "Waiting for {} random discovery responses.".format(req_responses)
                )
                req_responses -= 1
                random_discovery_response = self.responseQueue.get(block=True)
                logging.debug(
                    "Received discovery response from {}".format(
                        random_discovery_response[0]
                    )
                )
                with self.mutex:
                    self.othersInfo.update(random_discovery_response[1])

            with self.mutex:
                self.process_candidates_without_lock(
                    random_candidates.union(self.out_edges), response
                )

            logging.info("Completed KNN Round {}".format(knn_round))

            logging.debug("OutNodes: {}".format(self.out_edges))

        # Send out_edges and BYE to all

        to_send = self.create_message_to_send(boolean_flags=["KNNBYE"])
        logging.info("Sending KNNByes")
        with self.mutex:
            self.exit_receiver = True
            to_send["CLOSE"] = list(self.out_edges)  # Optimize to only send Yes/No
            for receiver in range(n_procs):
                if receiver != self.uid:
                    self.communication.send(receiver, to_send)
        logging.info("KNNByes Sent")
        t.join()
        logging.info("Receiver Thread Returned")

    def __init__(
        self,
        rank: int,
        machine_id: int,
        mapping: Mapping,
        graph: Graph,
        config,
        iterations=1,
        log_dir=".",
        weights_store_dir=".",
        log_level=logging.INFO,
        test_after=5,
        train_evaluate_after=1,
        reset_optimizer=1,
        initial_neighbors=4,
        *args
    ):
        """
        Constructor

        Parameters
        ----------
        rank : int
            Rank of process local to the machine
        machine_id : int
            Machine ID on which the process in running
        mapping : decentralizepy.mappings
            The object containing the mapping rank <--> uid
        graph : decentralizepy.graphs
            The object containing the global graph
        config : dict
            A dictionary of configurations. Must contain the following:
            [DATASET]
                dataset_package
                dataset_class
                model_class
            [OPTIMIZER_PARAMS]
                optimizer_package
                optimizer_class
            [TRAIN_PARAMS]
                training_package = decentralizepy.training.Training
                training_class = Training
                epochs_per_round = 25
                batch_size = 64
        iterations : int
            Number of iterations (communication steps) for which the model should be trained
        log_dir : str
            Logging directory
        weights_store_dir : str
            Directory in which to store model weights
        log_level : logging.Level
            One of DEBUG, INFO, WARNING, ERROR, CRITICAL
        test_after : int
            Number of iterations after which the test loss and accuracy are calculated
        train_evaluate_after : int
            Number of iterations after which the train loss is calculated
        reset_optimizer : int
            1 if optimizer should be reset every communication round, else 0
        initial_neighbors : int
            Number of graph neighbors sampled as the initial out-edges; the
            receiver thread waits for this many INIT responses
        args : optional
            Other arguments

        """

        # os.cpu_count() may return None when the count is undetermined.
        total_threads = os.cpu_count() or 1
        self.threads_per_proc = max(
            math.floor(total_threads / mapping.procs_per_machine), 1
        )
        torch.set_num_threads(self.threads_per_proc)
        torch.set_num_interop_threads(1)
        self.instantiate(
            rank,
            machine_id,
            mapping,
            graph,
            config,
            iterations,
            log_dir,
            weights_store_dir,
            log_level,
            test_after,
            train_evaluate_after,
            reset_optimizer,
            *args
        )

        # Per-node RNG seeded from the uid.
        self.rng = Random()
        self.rng.seed(self.uid + 100)

        self.initial_neighbors = initial_neighbors
        self.in_edges = set()
        self.out_edges = set(
            self.rng.sample(
                list(self.graph.neighbors(self.uid)), self.initial_neighbors
            )
        )
        self.responseQueue = queue.Queue()
        self.mutex = Lock()
        self.othersInfo = {self.uid: list(self.dataset.get_label_distribution())}
        # ld = self.dataset.get_label_distribution()
        # ld_keys = sorted(list(ld.keys()))
        # self.othersInfo = {self.uid: []}
        # for key in range(max(ld_keys) + 1):
        #     if key in ld:
        #         self.othersInfo[self.uid].append(ld[key])
        #     else:
        #         self.othersInfo[self.uid].append(0)
        logging.info("Label Distributions: {}".format(self.othersInfo))

        logging.info(
            "Each proc uses %d threads out of %d.", self.threads_per_proc, total_threads
        )
        self.run()


# File: src/decentralizepy/node/OverlayNode.py
import importlib
import json
import logging
import math
import os
from collections import deque

import torch
from matplotlib import pyplot as plt

from decentralizepy import utils
from decentralizepy.graphs.Graph import Graph
from decentralizepy.mappings.Mapping import Mapping
from decentralizepy.node.Node import Node


class OverlayNode(Node):
    """
    This class defines the node on overlay graph

    """

    def save_plot(self, l, label, title, xlabel, filename):
        """
        Save Matplotlib plot. Clears previous plots.

        Parameters
        ----------
        l : dict
            dict of x -> y. `x` must be castable to int.
        label : str
            label of the plot. Attached to the plotted line; no legend is drawn here.
        title : str
            Header
        xlabel : str
            x-axis label
        filename : str
            Name of file to save the plot as.

        """
        plt.clf()
        y_axis = list(l.values())
        x_axis = [int(key) for key in l]
        plt.plot(x_axis, y_axis, label=label)
        plt.xlabel(xlabel)
        plt.title(title)
        plt.savefig(filename)

    def get_neighbors(self, node=None):
        """
        Return this node's neighbors (``self.my_neighbors``).

        Parameters
        ----------
        node : optional
            Ignored

        """
        return self.my_neighbors

    def receive_DPSGD(self):
        """
        Receive the next message on the "DPSGD" channel.

        """
        return self.receive_channel("DPSGD")

    def build_topology(self):
        """
        Hook called by `run` before training; a no-op here.

        Subclasses override it to populate ``self.out_edges`` and
        ``self.in_edges``.

        """
        pass

    def run(self):
        """
        Start the decentralized learning

        """
        self.testset = self.dataset.get_testset()
        rounds_to_test = self.test_after
        rounds_to_train_evaluate = self.train_evaluate_after
        global_epoch = 1
        change = 1

        # NOTE(review): connect_neighbors() is also called at the end of
        # instantiate(); confirm that calling it twice is intended/harmless.
        self.connect_neighbors()
        logging.info("Connected to all neighbors")

        self.build_topology()

        logging.info("OutNodes: {}".format(self.out_edges))

        logging.info("InNodes: {}".format(self.in_edges))

        logging.info("Unifying edges")

        # Make the overlay symmetric: every edge is used in both directions.
        self.out_edges = self.out_edges.union(self.in_edges)
        self.my_neighbors = self.in_edges = set(self.out_edges)

        logging.info("Total number of neighbor: {}".format(len(self.my_neighbors)))

        results_path = os.path.join(self.log_dir, "{}_results.json".format(self.rank))

        for iteration in range(self.iterations):
            logging.info("Starting training iteration: %d", iteration)
            rounds_to_train_evaluate -= 1
            rounds_to_test -= 1

            self.iteration = iteration
            self.trainer.train(self.dataset)

            to_send = self.sharing.get_data_to_send()
            to_send["CHANNEL"] = "DPSGD"
            to_send["degree"] = len(self.in_edges)

            assert len(self.out_edges) != 0
            assert len(self.in_edges) != 0

            for neighbor in self.out_edges:
                self.communication.send(neighbor, to_send)

            # Block until every in-neighbor has at least one queued model.
            while not self.received_from_all():
                sender, data = self.receive_DPSGD()
                logging.info(
                    "Received Model from {} of iteration {}".format(
                        sender, data["iteration"]
                    )
                )
                if sender not in self.peer_deques:
                    self.peer_deques[sender] = deque()
                self.peer_deques[sender].append(data)

            averaging_deque = {
                neighbor: self.peer_deques[neighbor] for neighbor in self.in_edges
            }

            self.sharing._averaging(averaging_deque)

            if self.reset_optimizer:
                self.optimizer = self.optimizer_class(
                    self.model.parameters(), **self.optimizer_params
                )  # Reset optimizer state
                self.trainer.reset_optimizer(self.optimizer)

            # Results are reloaded from disk on every iteration after the first.
            if iteration:
                with open(results_path, "r") as inf:
                    results_dict = json.load(inf)
            else:
                results_dict = {
                    "train_loss": {},
                    "test_loss": {},
                    "test_acc": {},
                    "total_bytes": {},
                    "total_meta": {},
                    "total_data_per_n": {},
                }

            results_dict["total_bytes"][iteration + 1] = self.communication.total_bytes

            if hasattr(self.communication, "total_meta"):
                results_dict["total_meta"][
                    iteration + 1
                ] = self.communication.total_meta
            if hasattr(self.communication, "total_data"):
                results_dict["total_data_per_n"][
                    iteration + 1
                ] = self.communication.total_data

            if rounds_to_train_evaluate == 0:
                logging.info("Evaluating on train set.")
                rounds_to_train_evaluate = self.train_evaluate_after * change
                loss_after_sharing = self.trainer.eval_loss(self.dataset)
                results_dict["train_loss"][iteration + 1] = loss_after_sharing
                self.save_plot(
                    results_dict["train_loss"],
                    "train_loss",
                    "Training Loss",
                    "Communication Rounds",
                    os.path.join(self.log_dir, "{}_train_loss.png".format(self.rank)),
                )

            if self.dataset.__testing__ and rounds_to_test == 0:
                rounds_to_test = self.test_after * change
                logging.info("Evaluating on test set.")
                ta, tl = self.dataset.test(self.model, self.loss)
                results_dict["test_acc"][iteration + 1] = ta
                results_dict["test_loss"][iteration + 1] = tl

                # `change` scales the evaluation intervals set above.
                if global_epoch == 49:
                    change *= 2

                global_epoch += change

            with open(results_path, "w") as of:
                json.dump(results_dict, of)
        # if self.model.shared_parameters_counter is not None:
        #     logging.info("Saving the shared parameter counts")
        #     with open(
        #         os.path.join(
        #             self.log_dir, "{}_shared_parameters.json".format(self.rank)
        #         ),
        #         "w",
        #     ) as of:
        #         json.dump(self.model.shared_parameters_counter.numpy().tolist(), of)
        self.disconnect_neighbors()
        # logging.info("Storing final weight")
        # self.model.dump_weights(self.weights_store_dir, self.uid, iteration)
        logging.info("All neighbors disconnected. Process complete!")

    def cache_fields(
        self,
        rank,
        machine_id,
        mapping,
        graph,
        iterations,
        log_dir,
        weights_store_dir,
        test_after,
        train_evaluate_after,
        reset_optimizer,
    ):
        """
        Instantiate object field with arguments.

        Parameters
        ----------
        rank : int
            Rank of process local to the machine
        machine_id : int
            Machine ID on which the process in running
        mapping : decentralizepy.mappings
            The object containing the mapping rank <--> uid
        graph : decentralizepy.graphs
            The object containing the global graph
        iterations : int
            Number of iterations (communication steps) for which the model should be trained
        log_dir : str
            Logging directory
        weights_store_dir : str
            Directory in which to store model weights
        test_after : int
            Number of iterations after which the test loss and accuracy are calculated
        train_evaluate_after : int
            Number of iterations after which the train loss is calculated
        reset_optimizer : int
            1 if optimizer should be reset every communication round, else 0
        """
        self.rank = rank
        self.machine_id = machine_id
        self.graph = graph
        self.mapping = mapping
        self.uid = self.mapping.get_uid(rank, machine_id)
        self.log_dir = log_dir
        self.weights_store_dir = weights_store_dir
        self.iterations = iterations
        self.test_after = test_after
        self.train_evaluate_after = train_evaluate_after
        self.reset_optimizer = reset_optimizer
        self.sent_disconnections = False

        logging.info("Rank: %d", self.rank)
        # Log the type of the graph itself (the rank's type was logged here before).
        logging.info("type(graph): %s", str(type(self.graph)))
        logging.info("type(mapping): %s", str(type(self.mapping)))

    def init_comm(self, comm_configs):
        """
        Instantiate communication module from config.

        Parameters
        ----------
        comm_configs : dict
            Python dict containing communication config params

        """
        comm_module = importlib.import_module(comm_configs["comm_package"])
        comm_class = getattr(comm_module, comm_configs["comm_class"])
        # Everything except the package/class selectors is forwarded as kwargs.
        comm_params = utils.remove_keys(comm_configs, ["comm_package", "comm_class"])
        self.addresses_filepath = comm_params.get("addresses_filepath", None)
        self.communication = comm_class(
            self.rank, self.machine_id, self.mapping, self.graph.n_procs, **comm_params
        )

    def instantiate(
        self,
        rank: int,
        machine_id: int,
        mapping: Mapping,
        graph: Graph,
        config,
        iterations=1,
        log_dir=".",
        weights_store_dir=".",
        log_level=logging.INFO,
        test_after=5,
        train_evaluate_after=1,
        reset_optimizer=1,
        *args
    ):
        """
        Construct objects.

        Parameters
        ----------
        rank : int
            Rank of process local to the machine
        machine_id : int
            Machine ID on which the process in running
        mapping : decentralizepy.mappings
            The object containing the mapping rank <--> uid
        graph : decentralizepy.graphs
            The object containing the global graph
        config : dict
            A dictionary of configurations.
        iterations : int
            Number of iterations (communication steps) for which the model should be trained
        log_dir : str
            Logging directory
        weights_store_dir : str
            Directory in which to store model weights
        log_level : logging.Level
            One of DEBUG, INFO, WARNING, ERROR, CRITICAL
        test_after : int
            Number of iterations after which the test loss and accuracy are calculated
        train_evaluate_after : int
            Number of iterations after which the train loss is calculated
        reset_optimizer : int
            1 if optimizer should be reset every communication round, else 0
        args : optional
            Other arguments

        """
        logging.info("Started process.")

        self.init_log(log_dir, rank, log_level)

        self.cache_fields(
            rank,
            machine_id,
            mapping,
            graph,
            iterations,
            log_dir,
            weights_store_dir,
            test_after,
            train_evaluate_after,
            reset_optimizer,
        )
        self.init_dataset_model(config["DATASET"])
        self.init_optimizer(config["OPTIMIZER_PARAMS"])
        self.init_trainer(config["TRAIN_PARAMS"])
        self.init_comm(config["COMMUNICATION"])

        self.message_queue = dict()

        self.barrier = set()
        self.my_neighbors = self.graph.neighbors(self.uid)

        self.init_sharing(config["SHARING"])
        self.peer_deques = dict()
        self.connect_neighbors()

    def received_from_all(self):
        """
        Check if all neighbors have sent the current iteration

        Returns
        -------
        bool
            True if required data has been received, False otherwise

        """
        return all(
            k in self.peer_deques and len(self.peer_deques[k]) != 0
            for k in self.in_edges
        )

    def __init__(
        self,
        rank: int,
        machine_id: int,
        mapping: Mapping,
        graph: Graph,
        config,
        iterations=1,
        log_dir=".",
        weights_store_dir=".",
        log_level=logging.INFO,
        test_after=5,
        train_evaluate_after=1,
        reset_optimizer=1,
        *args
    ):
        """
        Constructor

        Parameters
        ----------
        rank : int
            Rank of process local to the machine
        machine_id : int
            Machine ID on which the process in running
        mapping : decentralizepy.mappings
            The object containing the mapping rank <--> uid
        graph : decentralizepy.graphs
            The object containing the global graph
        config : dict
            A dictionary of configurations. Must contain the following:
            [DATASET]
                dataset_package
                dataset_class
                model_class
            [OPTIMIZER_PARAMS]
                optimizer_package
                optimizer_class
            [TRAIN_PARAMS]
                training_package = decentralizepy.training.Training
                training_class = Training
                epochs_per_round = 25
                batch_size = 64
        iterations : int
            Number of iterations (communication steps) for which the model should be trained
        log_dir : str
            Logging directory
        weights_store_dir : str
            Directory in which to store model weights
        log_level : logging.Level
            One of DEBUG, INFO, WARNING, ERROR, CRITICAL
        test_after : int
            Number of iterations after which the test loss and accuracy are calculated
        train_evaluate_after : int
            Number of iterations after which the train loss is calculated
        reset_optimizer : int
            1 if optimizer should be reset every communication round, else 0
        args : optional
            Other arguments

        """

        # os.cpu_count() may return None when the count is undetermined.
        total_threads = os.cpu_count() or 1
        self.threads_per_proc = max(
            math.floor(total_threads / mapping.procs_per_machine), 1
        )
        torch.set_num_threads(self.threads_per_proc)
        torch.set_num_interop_threads(1)
        self.instantiate(
            rank,
            machine_id,
            mapping,
            graph,
            config,
            iterations,
            log_dir,
            weights_store_dir,
            log_level,
            test_after,
            train_evaluate_after,
            reset_optimizer,
            *args
        )

        self.in_edges = set()
        self.out_edges = set()

        logging.info(
            "Each proc uses %d threads out of %d.", self.threads_per_proc, total_threads
        )
        self.run()