From f1d5035f3e83c1c11571216f6b304155ecbcb9a4 Mon Sep 17 00:00:00 2001 From: Rishi Sharma <rishi.sharma@epfl.ch> Date: Fri, 22 Jul 2022 16:55:43 +0200 Subject: [PATCH] Add peer sampler, refactor everything --- eval/main.ipynb | 50 +- eval/plot.py | 2 +- eval/plotting_from_csv.py | 20 +- eval/testing.py | 4 +- setup.cfg | 2 +- src/decentralizepy/communication/TCP.py | 151 +----- src/decentralizepy/compression/Compression.py | 3 - src/decentralizepy/compression/Elias.py | 1 - src/decentralizepy/compression/EliasFpzip.py | 1 - .../compression/EliasFpzipLossy.py | 1 - src/decentralizepy/datasets/MovieLens.py | 1 - src/decentralizepy/datasets/Shakespeare.py | 1 - src/decentralizepy/graphs/Graph.py | 3 + src/decentralizepy/mappings/Linear.py | 7 +- src/decentralizepy/node/DPSGDNode.py | 513 ++++++++++++++++++ src/decentralizepy/node/Node.py | 384 ++++--------- src/decentralizepy/node/PeerSampler.py | 221 ++++++++ src/decentralizepy/sharing/FFT.py | 27 +- src/decentralizepy/sharing/GrowingAlpha.py | 11 +- src/decentralizepy/sharing/LowerBoundTopK.py | 20 +- src/decentralizepy/sharing/ManualAdapt.py | 13 +- src/decentralizepy/sharing/PartialModel.py | 36 +- src/decentralizepy/sharing/RandomAlpha.py | 10 +- .../sharing/RandomAlphaIncremental.py | 11 +- .../sharing/RandomAlphaWavelet.py | 10 +- .../sharing/RoundRobinPartial.py | 19 +- src/decentralizepy/sharing/Sharing.py | 127 ++--- .../sharing/SharingCentrality.py | 1 + src/decentralizepy/sharing/SubSampling.py | 19 +- src/decentralizepy/sharing/Synchronous.py | 1 + src/decentralizepy/sharing/TopKNormalized.py | 6 + src/decentralizepy/sharing/TopKParams.py | 19 +- src/decentralizepy/sharing/TopKPlusRandom.py | 6 + src/decentralizepy/sharing/Wavelet.py | 26 +- src/decentralizepy/train_test_evaluation.py | 1 - 35 files changed, 1183 insertions(+), 545 deletions(-) create mode 100644 src/decentralizepy/node/DPSGDNode.py create mode 100644 src/decentralizepy/node/PeerSampler.py diff --git a/eval/main.ipynb b/eval/main.ipynb index 80daae6..0873005 100644 --- a/eval/main.ipynb +++ b/eval/main.ipynb @@ -5709,6 +5709,41 @@ "print(i)" ] }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from torch import multiprocessing as mp\n", + "from decentralizepy.node.PeerSampler import PeerSampler\n", + "from decentralizepy.node.Node import Node\n", + "from decentralizepy.mappings.Linear import Linear\n", + "from decentralizepy.graphs.Regular import Regular\n", + "\n", + "l = Linear(1, 6)\n", + "g = Regular(6, 2)\n", + "processes = [mp.Process(target = PeerSampler, args=[-1, 0, l, g, None]),\n", + " mp.Process(target = Node, args=[1, 0, l, g, None]),\n", + " mp.Process(target = Node, args=[2, 0, l, g, None]),\n", + " mp.Process(target = Node, args=[3, 0, l, g, None]),\n", + " mp.Process(target = Node, args=[4, 0, l, g, None]),\n", + " mp.Process(target = Node, args=[5, 0, l, g, None]),\n", + " mp.Process(target = Node, args=[6, 0, l, g, None]),\n", + " ]\n", + "\n", + "for p in processes:\n", + " p.start()\n", + "\n", + "for p in processes:\n", + " p.join()\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ] + }, { "cell_type": "code", "execution_count": null, @@ -5718,11 +5753,9 @@ } ], "metadata": { - "interpreter": { - "hash": "996934296aa9d79be6c3d800a38d8fdb7dfa8fe7bb07df178f1397cde2cb8742" - }, "kernelspec": { - "display_name": "Python 3.9.7 64-bit ('tff': conda)", + "display_name": "Python 3.9.7 ('decpy')", + "language": "python", "name": "python3" }, "language_info": { @@ -5735,9 +5768,14 @@ "name": "python", 
"nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.11" + "version": "3.9.12" }, - "orig_nbformat": 4 + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "70be49349d3cda3718db277e01495433e35b5db6f514174958763e3b43682235" + } + } }, "nbformat": 4, "nbformat_minor": 2 diff --git a/eval/plot.py b/eval/plot.py index 601d8e8..8a9bcc1 100644 --- a/eval/plot.py +++ b/eval/plot.py @@ -38,7 +38,7 @@ def plot(means, stdevs, mins, maxs, title, label, loc): def plot_results(path, centralized, data_machine="machine0", data_node=0): folders = os.listdir(path) - if centralized.lower() in ['true', '1', 't', 'y', 'yes']: + if centralized.lower() in ["true", "1", "t", "y", "yes"]: centralized = True print("Centralized") else: diff --git a/eval/plotting_from_csv.py b/eval/plotting_from_csv.py index b8d4320..dbd1c1a 100644 --- a/eval/plotting_from_csv.py +++ b/eval/plotting_from_csv.py @@ -23,7 +23,7 @@ def plot(x_axis, means, stdevs, pos, nb_plots, title, label, loc, xlabel): def plot_results(path, epochs, global_epochs="True"): - if global_epochs.lower() in ['true', '1', 't', 'y', 'yes']: + if global_epochs.lower() in ["true", "1", "t", "y", "yes"]: global_epochs = True else: global_epochs = False @@ -52,10 +52,12 @@ def plot_results(path, epochs, global_epochs="True"): if global_epochs: rounds = results_csv["rounds"].iloc[0] print("Rounds: ", rounds) - results_cr = results_csv[results_csv.rounds <= epochs*rounds] + results_cr = results_csv[results_csv.rounds <= epochs * rounds] means = results_cr["mean"].to_numpy() stdevs = results_cr["std"].to_numpy() - x_axis = results_cr["rounds"].to_numpy() / rounds # list(np.arange(0, len(means), 1)) + x_axis = ( + results_cr["rounds"].to_numpy() / rounds + ) # list(np.arange(0, len(means), 1)) x_label = "global epochs" else: results_cr = results_csv[results_csv.rounds <= epochs] @@ -85,10 +87,12 @@ def plot_results(path, epochs, global_epochs="True"): if global_epochs: rounds = results_csv["rounds"].iloc[0] print("Rounds: ", rounds) - results_cr = results_csv[results_csv.rounds <= epochs*rounds] + results_cr = results_csv[results_csv.rounds <= epochs * rounds] means = results_cr["mean"].to_numpy() stdevs = results_cr["std"].to_numpy() - x_axis = results_cr["rounds"].to_numpy() / rounds # list(np.arange(0, len(means), 1)) + x_axis = ( + results_cr["rounds"].to_numpy() / rounds + ) # list(np.arange(0, len(means), 1)) x_label = "global epochs" else: results_cr = results_csv[results_csv.rounds <= epochs] @@ -120,10 +124,12 @@ def plot_results(path, epochs, global_epochs="True"): if global_epochs: rounds = results_csv["rounds"].iloc[0] print("Rounds: ", rounds) - results_cr = results_csv[results_csv.rounds <= epochs*rounds] + results_cr = results_csv[results_csv.rounds <= epochs * rounds] means = results_cr["mean"].to_numpy() stdevs = results_cr["std"].to_numpy() - x_axis = results_cr["rounds"].to_numpy() / rounds # list(np.arange(0, len(means), 1)) + x_axis = ( + results_cr["rounds"].to_numpy() / rounds + ) # list(np.arange(0, len(means), 1)) x_label = "global epochs" else: results_cr = results_csv[results_csv.rounds <= epochs] diff --git a/eval/testing.py b/eval/testing.py index 9125828..9d67b28 100644 --- a/eval/testing.py +++ b/eval/testing.py @@ -8,7 +8,7 @@ from torch import multiprocessing as mp from decentralizepy import utils from decentralizepy.graphs.Graph import Graph from decentralizepy.mappings.Linear import Linear -from decentralizepy.node.Node import Node +from decentralizepy.node.DPSGDNode import DPSGDNode def 
read_ini(file_path): @@ -51,7 +51,7 @@ if __name__ == "__main__": m_id = args.machine_id mp.spawn( - fn=Node, + fn=DPSGDNode, nprocs=procs_per_machine, args=[ m_id, diff --git a/setup.cfg b/setup.cfg index 1b3f6c7..2df457a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -54,7 +54,7 @@ python_requires = >=3.6 where = src [options.extras_require] dev = - black + black>22.3.0 coverage isort pytest diff --git a/src/decentralizepy/communication/TCP.py b/src/decentralizepy/communication/TCP.py index 6d72e1c..8c8587b 100644 --- a/src/decentralizepy/communication/TCP.py +++ b/src/decentralizepy/communication/TCP.py @@ -1,4 +1,3 @@ -import importlib import json import logging import pickle @@ -36,7 +35,8 @@ class TCP(Communication): """ machine_addr = self.ip_addrs[str(machine_id)] - port = rank + self.offset + port = (2 * rank + 1) + self.offset + assert port > 0 return "tcp://{}:{}".format(machine_addr, port) def __init__( @@ -46,10 +46,7 @@ class TCP(Communication): mapping, total_procs, addresses_filepath, - compress=False, - offset=20000, - compression_package=None, - compression_class=None, + offset=9000, ): """ Constructor @@ -81,30 +78,19 @@ class TCP(Communication): self.rank = rank self.machine_id = machine_id self.mapping = mapping - self.offset = 20000 + offset + self.offset = offset self.uid = mapping.get_uid(rank, machine_id) self.identity = str(self.uid).encode() self.context = zmq.Context() self.router = self.context.socket(zmq.ROUTER) self.router.setsockopt(zmq.IDENTITY, self.identity) self.router.bind(self.addr(rank, machine_id)) - self.sent_disconnections = False - self.compress = compress - - if compression_package and compression_class: - compressor_module = importlib.import_module(compression_package) - compressor_class = getattr(compressor_module, compression_class) - self.compressor = compressor_class() - logging.info(f"Using the {compressor_class} to compress the data") - else: - assert not self.compress self.total_data = 0 self.total_meta = 0 self.peer_deque = deque() self.peer_sockets = dict() - self.barrier = set() def __del__(self): """ @@ -128,26 +114,12 @@ class TCP(Communication): Encoded data """ - if self.compress: - if "indices" in data: - data["indices"] = self.compressor.compress(data["indices"]) - - assert "params" in data - data["params"] = self.compressor.compress_float(data["params"]) + data_len = 0 + if "params" in data: data_len = len(pickle.dumps(data["params"])) - output = pickle.dumps(data) - - # the compressed meta data gets only a few bytes smaller after pickling - self.total_meta += len(output) - data_len - self.total_data += data_len - else: - output = pickle.dumps(data) - # centralized testing uses its own instance - if type(data) == dict: - assert "params" in data - data_len = len(pickle.dumps(data["params"])) - self.total_meta += len(output) - data_len - self.total_data += data_len + output = pickle.dumps(data) + self.total_meta += len(output) - data_len + self.total_data += data_len return output def decrypt(self, sender, data): @@ -168,63 +140,25 @@ class TCP(Communication): """ sender = int(sender.decode()) - if self.compress: - data = pickle.loads(data) - if "indices" in data: - data["indices"] = self.compressor.decompress(data["indices"]) - if "params" in data: - data["params"] = self.compressor.decompress_float(data["params"]) - else: - data = pickle.loads(data) + data = pickle.loads(data) return sender, data - def connect_neighbors(self, neighbors): + def init_connection(self, neighbor): """ - Connects all neighbors. Sends HELLO. 
Waits for HELLO. - Caches any data received while waiting for HELLOs. + Initiates a socket to a given node. Parameters ---------- - neighbors : list(int) - List of neighbors - - Raises - ------ - RuntimeError - If received BYE while waiting for HELLO + neighbor : int + neighbor to connect to """ - logging.info("Sending connection request to neighbors") - for uid in neighbors: - logging.debug("Connecting to my neighbour: {}".format(uid)) - id = str(uid).encode() - req = self.context.socket(zmq.DEALER) - req.setsockopt(zmq.IDENTITY, self.identity) - req.connect(self.addr(*self.mapping.get_machine_and_rank(uid))) - self.peer_sockets[id] = req - req.send(HELLO) - - num_neighbors = len(neighbors) - while len(self.barrier) < num_neighbors: - sender, recv = self.router.recv_multipart() - - if recv == HELLO: - logging.debug("Received {} from {}".format(HELLO, sender)) - self.barrier.add(sender) - elif recv == BYE: - logging.debug("Received {} from {}".format(BYE, sender)) - raise RuntimeError( - "A neighbour wants to disconnect before training started!" - ) - else: - logging.debug( - "Received message from {} @ connect_neighbors".format(sender) - ) - - self.peer_deque.append(self.decrypt(sender, recv)) - - logging.info("Connected to all neighbors") - self.initialized = True + logging.debug("Connecting to my neighbour: {}".format(neighbor)) + id = str(neighbor).encode() + req = self.context.socket(zmq.DEALER) + req.setsockopt(zmq.IDENTITY, self.identity) + req.connect(self.addr(*self.mapping.get_machine_and_rank(neighbor))) + self.peer_sockets[id] = req def receive(self): """ @@ -241,25 +175,9 @@ class TCP(Communication): If received HELLO """ - assert self.initialized == True - if len(self.peer_deque) != 0: - resp = self.peer_deque.popleft() - return resp sender, recv = self.router.recv_multipart() - - if recv == HELLO: - logging.debug("Received {} from {}".format(HELLO, sender)) - raise RuntimeError( - "A neighbour wants to connect when everyone is connected!" - ) - elif recv == BYE: - logging.debug("Received {} from {}".format(BYE, sender)) - self.barrier.remove(sender) - return self.receive() - else: - logging.debug("Received message from {}".format(sender)) - return self.decrypt(sender, recv) + return self.decrypt(sender, recv) def send(self, uid, data, encrypt=True): """ @@ -273,7 +191,6 @@ class TCP(Communication): Message as a Python dictionary """ - assert self.initialized == True if encrypt: to_send = self.encrypt(data) else: @@ -283,28 +200,4 @@ class TCP(Communication): id = str(uid).encode() self.peer_sockets[id].send(to_send) logging.debug("{} sent the message to {}.".format(self.uid, uid)) - logging.info("Sent this round: {}".format(data_size)) - - def disconnect_neighbors(self): - """ - Disconnects all neighbors. 
- - """ - assert self.initialized == True - if not self.sent_disconnections: - logging.info("Disconnecting neighbors") - for sock in self.peer_sockets.values(): - sock.send(BYE) - self.sent_disconnections = True - while len(self.barrier): - sender, recv = self.router.recv_multipart() - if recv == BYE: - logging.debug("Received {} from {}".format(BYE, sender)) - self.barrier.remove(sender) - else: - logging.critical( - "Received unexpected {} from {}".format(recv, sender) - ) - raise RuntimeError( - "Received a message when expecting BYE from {}".format(sender) - ) + logging.info("Sent message size: {}".format(data_size)) diff --git a/src/decentralizepy/compression/Compression.py b/src/decentralizepy/compression/Compression.py index 0924caf..b45e641 100644 --- a/src/decentralizepy/compression/Compression.py +++ b/src/decentralizepy/compression/Compression.py @@ -1,6 +1,3 @@ -import numpy as np - - class Compression: """ Compression API diff --git a/src/decentralizepy/compression/Elias.py b/src/decentralizepy/compression/Elias.py index 235cf00..0d408d8 100644 --- a/src/decentralizepy/compression/Elias.py +++ b/src/decentralizepy/compression/Elias.py @@ -1,6 +1,5 @@ # elias implementation: taken from this stack overflow post: # https://stackoverflow.com/questions/62843156/python-fast-compression-of-large-amount-of-numbers-with-elias-gamma -import fpzip import numpy as np from decentralizepy.compression.Compression import Compression diff --git a/src/decentralizepy/compression/EliasFpzip.py b/src/decentralizepy/compression/EliasFpzip.py index 0c82560..0142dd9 100644 --- a/src/decentralizepy/compression/EliasFpzip.py +++ b/src/decentralizepy/compression/EliasFpzip.py @@ -1,7 +1,6 @@ # elias implementation: taken from this stack overflow post: # https://stackoverflow.com/questions/62843156/python-fast-compression-of-large-amount-of-numbers-with-elias-gamma import fpzip -import numpy as np from decentralizepy.compression.Elias import Elias diff --git a/src/decentralizepy/compression/EliasFpzipLossy.py b/src/decentralizepy/compression/EliasFpzipLossy.py index 617a78b..0b60307 100644 --- a/src/decentralizepy/compression/EliasFpzipLossy.py +++ b/src/decentralizepy/compression/EliasFpzipLossy.py @@ -1,7 +1,6 @@ # elias implementation: taken from this stack overflow post: # https://stackoverflow.com/questions/62843156/python-fast-compression-of-large-amount-of-numbers-with-elias-gamma import fpzip -import numpy as np from decentralizepy.compression.Elias import Elias diff --git a/src/decentralizepy/datasets/MovieLens.py b/src/decentralizepy/datasets/MovieLens.py index dafb4ce..95e55cc 100644 --- a/src/decentralizepy/datasets/MovieLens.py +++ b/src/decentralizepy/datasets/MovieLens.py @@ -3,7 +3,6 @@ import math import os import zipfile -import numpy as np import pandas as pd import requests import torch diff --git a/src/decentralizepy/datasets/Shakespeare.py b/src/decentralizepy/datasets/Shakespeare.py index 0c02932..c7ede74 100644 --- a/src/decentralizepy/datasets/Shakespeare.py +++ b/src/decentralizepy/datasets/Shakespeare.py @@ -1,7 +1,6 @@ import json import logging import os -import re from collections import defaultdict import numpy as np diff --git a/src/decentralizepy/graphs/Graph.py b/src/decentralizepy/graphs/Graph.py index 689d2dc..dc66eef 100644 --- a/src/decentralizepy/graphs/Graph.py +++ b/src/decentralizepy/graphs/Graph.py @@ -22,6 +22,9 @@ class Graph: self.n_procs = n_procs self.adj_list = [set() for i in range(self.n_procs)] + def get_all_nodes(self): + return [i for i in 
range(self.n_procs)] + def __insert_adj__(self, node, neighbours): """ Inserts `neighbours` into the adjacency list of `node` diff --git a/src/decentralizepy/mappings/Linear.py b/src/decentralizepy/mappings/Linear.py index 9419fbd..f166dc9 100644 --- a/src/decentralizepy/mappings/Linear.py +++ b/src/decentralizepy/mappings/Linear.py @@ -8,7 +8,7 @@ class Linear(Mapping): """ - def __init__(self, n_machines, procs_per_machine): + def __init__(self, n_machines, procs_per_machine, global_service_machine=0): """ Constructor @@ -23,6 +23,7 @@ class Linear(Mapping): super().__init__(n_machines * procs_per_machine) self.n_machines = n_machines self.procs_per_machine = procs_per_machine + self.global_service_machine = global_service_machine def get_uid(self, rank: int, machine_id: int): """ @@ -41,6 +42,8 @@ class Linear(Mapping): the unique identifier """ + if rank < 0: + return rank return machine_id * self.procs_per_machine + rank def get_machine_and_rank(self, uid: int): @@ -58,6 +61,8 @@ class Linear(Mapping): a tuple of rank and machine_id """ + if uid < 0: + return uid, self.global_service_machine return (uid % self.procs_per_machine), (uid // self.procs_per_machine) def get_local_procs_count(self): diff --git a/src/decentralizepy/node/DPSGDNode.py b/src/decentralizepy/node/DPSGDNode.py new file mode 100644 index 0000000..964f103 --- /dev/null +++ b/src/decentralizepy/node/DPSGDNode.py @@ -0,0 +1,513 @@ +import importlib +import json +import logging +import math +import os +from collections import deque + +import torch +from matplotlib import pyplot as plt + +from decentralizepy import utils +from decentralizepy.communication.TCP import TCP +from decentralizepy.graphs.Graph import Graph +from decentralizepy.graphs.Star import Star +from decentralizepy.mappings.Mapping import Mapping +from decentralizepy.node.Node import Node +from decentralizepy.train_test_evaluation import TrainTestHelper + + +class DPSGDNode(Node): + """ + This class defines the node for DPSGD + + """ + + def save_plot(self, l, label, title, xlabel, filename): + """ + Save Matplotlib plot. Clears previous plots. + + Parameters + ---------- + l : dict + dict of x -> y. `x` must be castable to int. + label : str + label of the plot. Used for legend. + title : str + Header + xlabel : str + x-axis label + filename : str + Name of file to save the plot as. 
+ + """ + plt.clf() + y_axis = [l[key] for key in l.keys()] + x_axis = list(map(int, l.keys())) + plt.plot(x_axis, y_axis, label=label) + plt.xlabel(xlabel) + plt.title(title) + plt.savefig(filename) + + def run(self): + """ + Start the decentralized learning + + """ + self.testset = self.dataset.get_testset() + rounds_to_test = self.test_after + rounds_to_train_evaluate = self.train_evaluate_after + global_epoch = 1 + change = 1 + if self.uid == 0: + dataset = self.dataset + if self.centralized_train_eval: + dataset_params_copy = self.dataset_params.copy() + if "sizes" in dataset_params_copy: + del dataset_params_copy["sizes"] + self.whole_dataset = self.dataset_class( + self.rank, + self.machine_id, + self.mapping, + sizes=[1.0], + **dataset_params_copy + ) + dataset = self.whole_dataset + if self.centralized_test_eval: + tthelper = TrainTestHelper( + dataset, # self.whole_dataset, + # self.model_test, # todo: this only works if eval_train is set to false + self.model, + self.loss, + self.weights_store_dir, + self.mapping.get_n_procs(), + self.trainer, + self.testing_comm, + self.star, + self.threads_per_proc, + eval_train=self.centralized_train_eval, + ) + + for iteration in range(self.iterations): + logging.info("Starting training iteration: %d", iteration) + self.trainer.train(self.dataset) + to_send = self.sharing.get_data_to_send() + + for neighbor in self.my_neighbors: + self.communication.send(neighbor, to_send) + + while not self.received_from_all(): + sender, data = self.receive() + + if "HELLO" in data: + logging.critical( + "Received unexpected {} from {}".format("HELLO", sender) + ) + raise RuntimeError("A neighbour wants to connect during training!") + elif "BYE" in data: + logging.debug("Received {} from {}".format("BYE", sender)) + self.barrier.remove(sender) + else: + logging.debug("Received message from {}".format(sender)) + self.peer_deques[sender].append(data) + + self.sharing._averaging(self.peer_deques) + + if self.reset_optimizer: + self.optimizer = self.optimizer_class( + self.model.parameters(), **self.optimizer_params + ) # Reset optimizer state + self.trainer.reset_optimizer(self.optimizer) + + if iteration: + with open( + os.path.join(self.log_dir, "{}_results.json".format(self.rank)), + "r", + ) as inf: + results_dict = json.load(inf) + else: + results_dict = { + "train_loss": {}, + "test_loss": {}, + "test_acc": {}, + "total_bytes": {}, + "total_meta": {}, + "total_data_per_n": {}, + "grad_mean": {}, + "grad_std": {}, + } + + results_dict["total_bytes"][iteration + 1] = self.communication.total_bytes + + if hasattr(self.communication, "total_meta"): + results_dict["total_meta"][ + iteration + 1 + ] = self.communication.total_meta + if hasattr(self.communication, "total_data"): + results_dict["total_data_per_n"][ + iteration + 1 + ] = self.communication.total_data + if hasattr(self.sharing, "mean"): + results_dict["grad_mean"][iteration + 1] = self.sharing.mean + if hasattr(self.sharing, "std"): + results_dict["grad_std"][iteration + 1] = self.sharing.std + + rounds_to_train_evaluate -= 1 + + if rounds_to_train_evaluate == 0 and not self.centralized_train_eval: + logging.info("Evaluating on train set.") + rounds_to_train_evaluate = self.train_evaluate_after * change + loss_after_sharing = self.trainer.eval_loss(self.dataset) + results_dict["train_loss"][iteration + 1] = loss_after_sharing + self.save_plot( + results_dict["train_loss"], + "train_loss", + "Training Loss", + "Communication Rounds", + os.path.join(self.log_dir, 
"{}_train_loss.png".format(self.rank)), + ) + + rounds_to_test -= 1 + + if self.dataset.__testing__ and rounds_to_test == 0: + rounds_to_test = self.test_after * change + if self.centralized_test_eval: + if self.uid == 0: + ta, tl, trl = tthelper.train_test_evaluation(iteration) + results_dict["test_acc"][iteration + 1] = ta + results_dict["test_loss"][iteration + 1] = tl + if trl is not None: + results_dict["train_loss"][iteration + 1] = trl + else: + self.testing_comm.send(0, self.model.get_weights()) + sender, data = self.testing_comm.receive() + assert sender == 0 and data == "finished" + else: + logging.info("Evaluating on test set.") + ta, tl = self.dataset.test(self.model, self.loss) + results_dict["test_acc"][iteration + 1] = ta + results_dict["test_loss"][iteration + 1] = tl + + if global_epoch == 49: + change *= 2 + + global_epoch += change + + with open( + os.path.join(self.log_dir, "{}_results.json".format(self.rank)), "w" + ) as of: + json.dump(results_dict, of) + if self.model.shared_parameters_counter is not None: + logging.info("Saving the shared parameter counts") + with open( + os.path.join( + self.log_dir, "{}_shared_parameters.json".format(self.rank) + ), + "w", + ) as of: + json.dump(self.model.shared_parameters_counter.numpy().tolist(), of) + self.disconnect_neighbors() + logging.info("Storing final weight") + self.model.dump_weights(self.weights_store_dir, self.uid, iteration) + logging.info("All neighbors disconnected. Process complete!") + + def cache_fields( + self, + rank, + machine_id, + mapping, + graph, + iterations, + log_dir, + weights_store_dir, + test_after, + train_evaluate_after, + reset_optimizer, + centralized_train_eval, + centralized_test_eval, + ): + """ + Instantiate object field with arguments. + + Parameters + ---------- + rank : int + Rank of process local to the machine + machine_id : int + Machine ID on which the process in running + mapping : decentralizepy.mappings + The object containing the mapping rank <--> uid + graph : decentralizepy.graphs + The object containing the global graph + iterations : int + Number of iterations (communication steps) for which the model should be trained + log_dir : str + Logging directory + weights_store_dir : str + Directory in which to store model weights + test_after : int + Number of iterations after which the test loss and accuracy arecalculated + train_evaluate_after : int + Number of iterations after which the train loss is calculated + reset_optimizer : int + 1 if optimizer should be reset every communication round, else 0 + centralized_train_eval : bool + If set the train set evaluation happens at the node with uid 0 + centralized_test_eval : bool + If set the train set evaluation happens at the node with uid 0 + """ + self.rank = rank + self.machine_id = machine_id + self.graph = graph + self.mapping = mapping + self.uid = self.mapping.get_uid(rank, machine_id) + self.log_dir = log_dir + self.weights_store_dir = weights_store_dir + self.iterations = iterations + self.test_after = test_after + self.train_evaluate_after = train_evaluate_after + self.reset_optimizer = reset_optimizer + self.centralized_train_eval = centralized_train_eval + self.centralized_test_eval = centralized_test_eval + self.sent_disconnections = False + + logging.info("Rank: %d", self.rank) + logging.info("type(graph): %s", str(type(self.rank))) + logging.info("type(mapping): %s", str(type(self.mapping))) + + if centralized_test_eval or centralized_train_eval: + self.star = Star(self.mapping.get_n_procs()) + + def init_comm(self, 
comm_configs): + """ + Instantiate communication module from config. + + Parameters + ---------- + comm_configs : dict + Python dict containing communication config params + + """ + comm_module = importlib.import_module(comm_configs["comm_package"]) + comm_class = getattr(comm_module, comm_configs["comm_class"]) + comm_params = utils.remove_keys(comm_configs, ["comm_package", "comm_class"]) + self.addresses_filepath = comm_params.get("addresses_filepath", None) + if self.centralized_test_eval: + self.testing_comm = TCP( + self.rank, + self.machine_id, + self.mapping, + self.star.n_procs, + self.addresses_filepath, + offset=self.star.n_procs, + ) + self.testing_comm.connect_neighbors(self.star.neighbors(self.uid)) + + self.communication = comm_class( + self.rank, self.machine_id, self.mapping, self.graph.n_procs, **comm_params + ) + + def instantiate( + self, + rank: int, + machine_id: int, + mapping: Mapping, + graph: Graph, + config, + iterations=1, + log_dir=".", + weights_store_dir=".", + log_level=logging.INFO, + test_after=5, + train_evaluate_after=1, + reset_optimizer=1, + centralized_train_eval=False, + centralized_test_eval=True, + *args + ): + """ + Construct objects. + + Parameters + ---------- + rank : int + Rank of process local to the machine + machine_id : int + Machine ID on which the process in running + mapping : decentralizepy.mappings + The object containing the mapping rank <--> uid + graph : decentralizepy.graphs + The object containing the global graph + config : dict + A dictionary of configurations. + iterations : int + Number of iterations (communication steps) for which the model should be trained + log_dir : str + Logging directory + weights_store_dir : str + Directory in which to store model weights + log_level : logging.Level + One of DEBUG, INFO, WARNING, ERROR, CRITICAL + test_after : int + Number of iterations after which the test loss and accuracy arecalculated + train_evaluate_after : int + Number of iterations after which the train loss is calculated + reset_optimizer : int + 1 if optimizer should be reset every communication round, else 0 + centralized_train_eval : bool + If set the train set evaluation happens at the node with uid 0 + centralized_test_eval : bool + If set the train set evaluation happens at the node with uid 0 + args : optional + Other arguments + + """ + logging.info("Started process.") + + self.init_log(log_dir, rank, log_level) + + self.cache_fields( + rank, + machine_id, + mapping, + graph, + iterations, + log_dir, + weights_store_dir, + test_after, + train_evaluate_after, + reset_optimizer, + centralized_train_eval, + centralized_test_eval, + ) + self.init_dataset_model(config["DATASET"]) + self.init_optimizer(config["OPTIMIZER_PARAMS"]) + self.init_trainer(config["TRAIN_PARAMS"]) + self.init_comm(config["COMMUNICATION"]) + + self.message_queue = deque() + self.barrier = set() + self.my_neighbors = self.graph.neighbors(self.uid) + + self.init_sharing(config["SHARING"]) + self.peer_deques = dict() + for n in self.my_neighbors: + self.peer_deques[n] = deque() + + self.connect_neighbors() + + def received_from_all(self): + """ + Check if all neighbors have sent the current iteration + + Returns + ------- + bool + True if required data has been received, False otherwise + + """ + for k in self.my_neighbors: + if len(self.peer_deques[k]) == 0: + return False + return True + + def __init__( + self, + rank: int, + machine_id: int, + mapping: Mapping, + graph: Graph, + config, + iterations=1, + log_dir=".", + weights_store_dir=".", + 
log_level=logging.INFO, + test_after=5, + train_evaluate_after=1, + reset_optimizer=1, + centralized_train_eval=0, + centralized_test_eval=1, + *args + ): + """ + Constructor + + Parameters + ---------- + rank : int + Rank of process local to the machine + machine_id : int + Machine ID on which the process in running + mapping : decentralizepy.mappings + The object containing the mapping rank <--> uid + graph : decentralizepy.graphs + The object containing the global graph + config : dict + A dictionary of configurations. Must contain the following: + [DATASET] + dataset_package + dataset_class + model_class + [OPTIMIZER_PARAMS] + optimizer_package + optimizer_class + [TRAIN_PARAMS] + training_package = decentralizepy.training.Training + training_class = Training + epochs_per_round = 25 + batch_size = 64 + iterations : int + Number of iterations (communication steps) for which the model should be trained + log_dir : str + Logging directory + weights_store_dir : str + Directory in which to store model weights + log_level : logging.Level + One of DEBUG, INFO, WARNING, ERROR, CRITICAL + test_after : int + Number of iterations after which the test loss and accuracy arecalculated + train_evaluate_after : int + Number of iterations after which the train loss is calculated + reset_optimizer : int + 1 if optimizer should be reset every communication round, else 0 + centralized_train_eval : int + If set then the train set evaluation happens at the node with uid 0. + Note: If it is True then centralized_test_eval needs to be true as well! + centralized_test_eval : int + If set then the trainset evaluation happens at the node with uid 0 + args : optional + Other arguments + + """ + centralized_train_eval = centralized_train_eval == 1 + centralized_test_eval = centralized_test_eval == 1 + # If centralized_train_eval is True then centralized_test_eval needs to be true as well! + assert not centralized_train_eval or centralized_test_eval + + total_threads = os.cpu_count() + self.threads_per_proc = max( + math.floor(total_threads / mapping.procs_per_machine), 1 + ) + torch.set_num_threads(self.threads_per_proc) + torch.set_num_interop_threads(1) + self.instantiate( + rank, + machine_id, + mapping, + graph, + config, + iterations, + log_dir, + weights_store_dir, + log_level, + test_after, + train_evaluate_after, + reset_optimizer, + centralized_train_eval == 1, + centralized_test_eval == 1, + *args + ) + logging.info( + "Each proc uses %d threads out of %d.", self.threads_per_proc, total_threads + ) + self.run() diff --git a/src/decentralizepy/node/Node.py b/src/decentralizepy/node/Node.py index 91f34e5..67ee659 100644 --- a/src/decentralizepy/node/Node.py +++ b/src/decentralizepy/node/Node.py @@ -1,18 +1,14 @@ import importlib -import json import logging import math import os +from collections import deque import torch -from matplotlib import pyplot as plt from decentralizepy import utils -from decentralizepy.communication.TCP import TCP from decentralizepy.graphs.Graph import Graph -from decentralizepy.graphs.Star import Star from decentralizepy.mappings.Mapping import Mapping -from decentralizepy.train_test_evaluation import TrainTestHelper class Node: @@ -21,31 +17,96 @@ class Node: """ - def save_plot(self, l, label, title, xlabel, filename): + def connect_neighbor(self, neighbor): """ - Save Matplotlib plot. Clears previous plots. + Connects given neighbor. Sends HELLO. - Parameters - ---------- - l : dict - dict of x -> y. `x` must be castable to int. - label : str - label of the plot. 
Used for legend. - title : str - Header - xlabel : str - x-axis label - filename : str - Name of file to save the plot as. + """ + logging.info("Sending connection request to {}".format(neighbor)) + self.communication.init_connection(neighbor) + self.communication.send(neighbor, {"HELLO": self.uid}) + + def wait_for_hello(self, neighbor): + """ + Waits for HELLO. + Caches any data received while waiting for HELLOs. + + Raises + ------ + RuntimeError + If received BYE while waiting for HELLO + + """ + + while neighbor not in self.barrier: + sender, recv = self.communication.receive() + + if "HELLO" in recv: + logging.debug("Received {} from {}".format("HELLO", sender)) + self.barrier.add(sender) + elif "BYE" in recv: + logging.debug("Received {} from {}".format("BYE", sender)) + raise RuntimeError( + "A neighbour wants to disconnect before training started!" + ) + else: + logging.debug( + "Received message from {} @ connect_neighbors".format(sender) + ) + self.message_queue.append((sender, recv)) + def receive(self): + if len(self.message_queue) > 0: + resp = self.message_queue.popleft() + else: + resp = self.communication.receive() + return resp + + def connect_neighbors(self): """ - plt.clf() - y_axis = [l[key] for key in l.keys()] - x_axis = list(map(int, l.keys())) - plt.plot(x_axis, y_axis, label=label) - plt.xlabel(xlabel) - plt.title(title) - plt.savefig(filename) + Connects all neighbors. Sends HELLO. Waits for HELLO. + Caches any data received while waiting for HELLOs. + + Raises + ------ + RuntimeError + If received BYE while waiting for HELLO + + """ + logging.info("Sending connection request to all neighbors") + for neighbor in self.my_neighbors: + self.connect_neighbor(neighbor) + + for neighbor in self.my_neighbors: + self.wait_for_hello(neighbor) + + def disconnect_neighbors(self): + """ + Disconnects all neighbors. + + Raises + ------ + RuntimeError + If received another message while waiting for BYEs + + """ + if not self.sent_disconnections: + logging.info("Disconnecting neighbors") + for uid in self.my_neighbors: + self.communication.send(uid, {"BYE": self.uid}) + self.sent_disconnections = True + while len(self.barrier): + sender, recv = self.receive() + if "BYE" in recv: + logging.debug("Received {} from {}".format("BYE", sender)) + self.barrier.remove(sender) + else: + logging.critical( + "Received unexpected {} from {}".format(recv, sender) + ) + raise RuntimeError( + "Received a message when expecting BYE from {}".format(sender) + ) def init_log(self, log_dir, rank, log_level, force=True): """ @@ -68,7 +129,7 @@ class Node: filename=log_file, format="[%(asctime)s][%(module)s][%(levelname)s] %(message)s", level=log_level, - force=True, + force=force, ) def cache_fields( @@ -79,12 +140,6 @@ class Node: graph, iterations, log_dir, - weights_store_dir, - test_after, - train_evaluate_after, - reset_optimizer, - centralized_train_eval, - centralized_test_eval, ): """ Instantiate object field with arguments. 
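
The HELLO/BYE handshake that used to live in TCP.connect_neighbors() is now driven by Node itself: connect_neighbor() opens the socket and sends HELLO, wait_for_hello() blocks until the peer's HELLO arrives, and any data message that races ahead of the handshake is parked in message_queue so that receive() replays it later. A minimal, self-contained sketch of that barrier logic, assuming a hypothetical `channel` object with the same receive() contract as the TCP class (illustrative only, not part of this patch):

    from collections import deque

    barrier, message_queue = set(), deque()

    def wait_for_hello(neighbor, channel):
        # Mirrors Node.wait_for_hello above: block until `neighbor` says HELLO,
        # parking any early data messages for replay by Node.receive().
        while neighbor not in barrier:
            sender, recv = channel.receive()
            if "HELLO" in recv:
                barrier.add(sender)
            elif "BYE" in recv:
                raise RuntimeError("A neighbour wants to disconnect before training started!")
            else:
                message_queue.append((sender, recv))
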
@@ -103,18 +158,6 @@ class Node: Number of iterations (communication steps) for which the model should be trained log_dir : str Logging directory - weights_store_dir : str - Directory in which to store model weights - test_after : int - Number of iterations after which the test loss and accuracy arecalculated - train_evaluate_after : int - Number of iterations after which the train loss is calculated - reset_optimizer : int - 1 if optimizer should be reset every communication round, else 0 - centralized_train_eval : bool - If set the train set evaluation happens at the node with uid 0 - centralized_test_eval : bool - If set the train set evaluation happens at the node with uid 0 """ self.rank = rank self.machine_id = machine_id @@ -122,19 +165,12 @@ class Node: self.mapping = mapping self.uid = self.mapping.get_uid(rank, machine_id) self.log_dir = log_dir - self.weights_store_dir = weights_store_dir self.iterations = iterations - self.test_after = test_after - self.train_evaluate_after = train_evaluate_after - self.reset_optimizer = reset_optimizer - self.centralized_train_eval = centralized_train_eval - self.centralized_test_eval = centralized_test_eval + self.sent_disconnections = False - logging.debug("Rank: %d", self.rank) - logging.debug("type(graph): %s", str(type(self.rank))) - logging.debug("type(mapping): %s", str(type(self.mapping))) - - self.star = Star(self.mapping.get_n_procs()) + logging.info("Rank: %d", self.rank) + logging.info("type(graph): %s", str(type(self.rank))) + logging.info("type(mapping): %s", str(type(self.mapping))) def init_dataset_model(self, dataset_configs): """ @@ -243,17 +279,6 @@ class Node: comm_class = getattr(comm_module, comm_configs["comm_class"]) comm_params = utils.remove_keys(comm_configs, ["comm_package", "comm_class"]) self.addresses_filepath = comm_params.get("addresses_filepath", None) - if self.centralized_test_eval: - self.testing_comm = TCP( - self.rank, - self.machine_id, - self.mapping, - self.star.n_procs, - self.addresses_filepath, - offset=self.star.n_procs, - ) - self.testing_comm.connect_neighbors(self.star.neighbors(self.uid)) - self.communication = comm_class( self.rank, self.machine_id, self.mapping, self.graph.n_procs, **comm_params ) @@ -294,13 +319,7 @@ class Node: config, iterations=1, log_dir=".", - weights_store_dir=".", log_level=logging.INFO, - test_after=5, - train_evaluate_after=1, - reset_optimizer=1, - centralized_train_eval=False, - centralized_test_eval=True, *args ): """ @@ -322,26 +341,16 @@ class Node: Number of iterations (communication steps) for which the model should be trained log_dir : str Logging directory - weights_store_dir : str - Directory in which to store model weights log_level : logging.Level One of DEBUG, INFO, WARNING, ERROR, CRITICAL - test_after : int - Number of iterations after which the test loss and accuracy arecalculated - train_evaluate_after : int - Number of iterations after which the train loss is calculated - reset_optimizer : int - 1 if optimizer should be reset every communication round, else 0 - centralized_train_eval : bool - If set the train set evaluation happens at the node with uid 0 - centralized_test_eval : bool - If set the train set evaluation happens at the node with uid 0 args : optional Other arguments """ logging.info("Started process.") + self.init_log(log_dir, rank, log_level) + self.cache_fields( rank, machine_id, @@ -349,18 +358,16 @@ class Node: graph, iterations, log_dir, - weights_store_dir, - test_after, - train_evaluate_after, - reset_optimizer, - 
centralized_train_eval, - centralized_test_eval, ) - self.init_log(log_dir, rank, log_level) self.init_dataset_model(config["DATASET"]) self.init_optimizer(config["OPTIMIZER_PARAMS"]) self.init_trainer(config["TRAIN_PARAMS"]) self.init_comm(config["COMMUNICATION"]) + + self.message_queue = deque() + self.barrier = set() + self.my_neighbors = self.graph.neighbors(self.uid) + self.init_sharing(config["SHARING"]) def run(self): @@ -368,146 +375,7 @@ class Node: Start the decentralized learning """ - self.testset = self.dataset.get_testset() - self.communication.connect_neighbors(self.graph.neighbors(self.uid)) - rounds_to_test = self.test_after - rounds_to_train_evaluate = self.train_evaluate_after - global_epoch = 1 - change = 1 - if self.uid == 0: - dataset = self.dataset - if self.centralized_train_eval: - dataset_params_copy = self.dataset_params.copy() - if "sizes" in dataset_params_copy: - del dataset_params_copy["sizes"] - self.whole_dataset = self.dataset_class( - self.rank, - self.machine_id, - self.mapping, - sizes=[1.0], - **dataset_params_copy - ) - dataset = self.whole_dataset - if self.centralized_test_eval: - tthelper = TrainTestHelper( - dataset, # self.whole_dataset, - # self.model_test, # todo: this only works if eval_train is set to false - self.model, - self.loss, - self.weights_store_dir, - self.mapping.get_n_procs(), - self.trainer, - self.testing_comm, - self.star, - self.threads_per_proc, - eval_train=self.centralized_train_eval, - ) - - for iteration in range(self.iterations): - logging.info("Starting training iteration: %d", iteration) - self.trainer.train(self.dataset) - - self.sharing.step() - - if self.reset_optimizer: - self.optimizer = self.optimizer_class( - self.model.parameters(), **self.optimizer_params - ) # Reset optimizer state - self.trainer.reset_optimizer(self.optimizer) - - if iteration: - with open( - os.path.join(self.log_dir, "{}_results.json".format(self.rank)), - "r", - ) as inf: - results_dict = json.load(inf) - else: - results_dict = { - "train_loss": {}, - "test_loss": {}, - "test_acc": {}, - "total_bytes": {}, - "total_meta": {}, - "total_data_per_n": {}, - "grad_mean": {}, - "grad_std": {}, - } - - results_dict["total_bytes"][iteration + 1] = self.communication.total_bytes - - if hasattr(self.communication, "total_meta"): - results_dict["total_meta"][ - iteration + 1 - ] = self.communication.total_meta - if hasattr(self.communication, "total_data"): - results_dict["total_data_per_n"][ - iteration + 1 - ] = self.communication.total_data - if hasattr(self.sharing, "mean"): - results_dict["grad_mean"][iteration + 1] = self.sharing.mean - if hasattr(self.sharing, "std"): - results_dict["grad_std"][iteration + 1] = self.sharing.std - - rounds_to_train_evaluate -= 1 - - if rounds_to_train_evaluate == 0 and not self.centralized_train_eval: - logging.info("Evaluating on train set.") - rounds_to_train_evaluate = self.train_evaluate_after * change - loss_after_sharing = self.trainer.eval_loss(self.dataset) - results_dict["train_loss"][iteration + 1] = loss_after_sharing - self.save_plot( - results_dict["train_loss"], - "train_loss", - "Training Loss", - "Communication Rounds", - os.path.join(self.log_dir, "{}_train_loss.png".format(self.rank)), - ) - - rounds_to_test -= 1 - - if self.dataset.__testing__ and rounds_to_test == 0: - rounds_to_test = self.test_after * change - # ta, tl = self.dataset.test(self.model, self.loss) - # self.model.dump_weights(self.weights_store_dir, self.uid, iteration) - if self.centralized_test_eval: - if self.uid == 0: - 
ta, tl, trl = tthelper.train_test_evaluation(iteration) - results_dict["test_acc"][iteration + 1] = ta - results_dict["test_loss"][iteration + 1] = tl - if trl is not None: - results_dict["train_loss"][iteration + 1] = trl - else: - self.testing_comm.send(0, self.model.get_weights()) - sender, data = self.testing_comm.receive() - assert sender == 0 and data == "finished" - else: - logging.info("Evaluating on test set.") - ta, tl = self.dataset.test(self.model, self.loss) - results_dict["test_acc"][iteration + 1] = ta - results_dict["test_loss"][iteration + 1] = tl - - if global_epoch == 49: - change *= 2 - - global_epoch += change - - with open( - os.path.join(self.log_dir, "{}_results.json".format(self.rank)), "w" - ) as of: - json.dump(results_dict, of) - if self.model.shared_parameters_counter is not None: - logging.info("Saving the shared parameter counts") - with open( - os.path.join( - self.log_dir, "{}_shared_parameters.json".format(self.rank) - ), - "w", - ) as of: - json.dump(self.model.shared_parameters_counter.numpy().tolist(), of) - self.communication.disconnect_neighbors() - logging.info("Storing final weight") - self.model.dump_weights(self.weights_store_dir, self.uid, iteration) - logging.info("All neighbors disconnected. Process complete!") + raise NotImplementedError def __init__( self, @@ -518,13 +386,7 @@ class Node: config, iterations=1, log_dir=".", - weights_store_dir=".", log_level=logging.INFO, - test_after=5, - train_evaluate_after=1, - reset_optimizer=1, - centralized_train_eval=0, - centralized_test_eval=1, *args ): """ @@ -559,28 +421,12 @@ class Node: log_dir : str Logging directory weights_store_dir : str - Directory in which to store model weights log_level : logging.Level One of DEBUG, INFO, WARNING, ERROR, CRITICAL - test_after : int - Number of iterations after which the test loss and accuracy arecalculated - train_evaluate_after : int - Number of iterations after which the train loss is calculated - reset_optimizer : int - 1 if optimizer should be reset every communication round, else 0 - centralized_train_eval : int - If set then the train set evaluation happens at the node with uid 0. - Note: If it is True then centralized_test_eval needs to be true as well! - centralized_test_eval : int - If set then the trainset evaluation happens at the node with uid 0 args : optional Other arguments """ - centralized_train_eval = centralized_train_eval == 1 - centralized_test_eval = centralized_test_eval == 1 - # If centralized_train_eval is True then centralized_test_eval needs to be true as well! 
-        assert not centralized_train_eval or centralized_test_eval
 
         total_threads = os.cpu_count()
         self.threads_per_proc = max(
@@ -588,25 +434,17 @@ class Node:
         )
         torch.set_num_threads(self.threads_per_proc)
         torch.set_num_interop_threads(1)
-        self.instantiate(
-            rank,
-            machine_id,
-            mapping,
-            graph,
-            config,
-            iterations,
-            log_dir,
-            weights_store_dir,
-            log_level,
-            test_after,
-            train_evaluate_after,
-            reset_optimizer,
-            centralized_train_eval == 1,
-            centralized_test_eval == 1,
-            *args
-        )
+        # self.instantiate(
+        #     rank,
+        #     machine_id,
+        #     mapping,
+        #     graph,
+        #     config,
+        #     iterations,
+        #     log_dir,
+        #     log_level,
+        #     *args
+        # )
         logging.info(
             "Each proc uses %d threads out of %d.", self.threads_per_proc, total_threads
         )
-
-        self.run()
diff --git a/src/decentralizepy/node/PeerSampler.py b/src/decentralizepy/node/PeerSampler.py
new file mode 100644
index 0000000..8f8db6f
--- /dev/null
+++ b/src/decentralizepy/node/PeerSampler.py
@@ -0,0 +1,216 @@
+import importlib
+import logging
+from collections import deque
+
+from decentralizepy import utils
+from decentralizepy.graphs.Graph import Graph
+from decentralizepy.mappings.Mapping import Mapping
+from decentralizepy.node.Node import Node
+
+
+class PeerSampler(Node):
+    """
+    This class defines the peer sampling service
+
+    """
+
+    def cache_fields(
+        self,
+        rank,
+        machine_id,
+        mapping,
+        graph,
+        iterations,
+        log_dir,
+    ):
+        """
+        Instantiate object field with arguments.
+
+        Parameters
+        ----------
+        rank : int
+            Rank of process local to the machine
+        machine_id : int
+            Machine ID on which the process is running
+        mapping : decentralizepy.mappings
+            The object containing the mapping rank <--> uid
+        graph : decentralizepy.graphs
+            The object containing the global graph
+        iterations : int
+            Number of iterations (communication steps) for which the model should be trained
+        log_dir : str
+            Logging directory
+
+        """
+        self.rank = rank
+        self.machine_id = machine_id
+        self.graph = graph
+        self.mapping = mapping
+        self.uid = self.mapping.get_uid(rank, machine_id)
+        self.log_dir = log_dir
+        self.iterations = iterations
+        self.sent_disconnections = False
+
+        logging.info("Rank: %d", self.rank)
+        logging.info("type(graph): %s", str(type(self.graph)))
+        logging.info("type(mapping): %s", str(type(self.mapping)))
+
+    def init_comm(self, comm_configs):
+        """
+        Instantiate communication module from config.
+
+        Parameters
+        ----------
+        comm_configs : dict
+            Python dict containing communication config params
+
+        """
+        comm_module = importlib.import_module(comm_configs["comm_package"])
+        comm_class = getattr(comm_module, comm_configs["comm_class"])
+        comm_params = utils.remove_keys(comm_configs, ["comm_package", "comm_class"])
+        self.addresses_filepath = comm_params.get("addresses_filepath", None)
+        self.communication = comm_class(
+            self.rank, self.machine_id, self.mapping, self.graph.n_procs, **comm_params
+        )
+
+    def instantiate(
+        self,
+        rank: int,
+        machine_id: int,
+        mapping: Mapping,
+        graph: Graph,
+        config,
+        iterations=1,
+        log_dir=".",
+        log_level=logging.INFO,
+        *args
+    ):
+        """
+        Construct objects.
+
+        Parameters
+        ----------
+        rank : int
+            Rank of process local to the machine
+        machine_id : int
+            Machine ID on which the process is running
+        mapping : decentralizepy.mappings
+            The object containing the mapping rank <--> uid
+        graph : decentralizepy.graphs
+            The object containing the global graph
+        config : dict
+            A dictionary of configurations.
+        iterations : int
+            Number of iterations (communication steps) for which the model should be trained
+        log_dir : str
+            Logging directory
+        log_level : logging.Level
+            One of DEBUG, INFO, WARNING, ERROR, CRITICAL
+        args : optional
+            Other arguments
+
+        """
+        logging.info("Started process.")
+
+        self.init_log(log_dir, rank, log_level)
+
+        self.cache_fields(
+            rank,
+            machine_id,
+            mapping,
+            graph,
+            iterations,
+            log_dir,
+        )
+
+        self.message_queue = deque()
+        self.barrier = set()
+
+        self.init_comm(config["COMMUNICATION"])
+        self.my_neighbors = self.graph.get_all_nodes()
+        self.connect_neighbors()
+
+    def get_neighbors(self, node):
+        # Static peer sampling: serve the node's neighbourhood in the global graph.
+        return self.graph.neighbors(node)
+
+    def run(self):
+        """
+        Start the peer-sampling service.
+
+        """
+        while len(self.barrier) > 0:
+            sender, data = self.receive()
+            if "BYE" in data:
+                logging.debug("Received {} from {}".format("BYE", sender))
+                self.barrier.remove(sender)
+            else:
+                logging.debug("Received {} from {}".format("Request", sender))
+                resp = {"neighbors": self.get_neighbors(sender)}
+                self.communication.send(sender, resp)
+
+    def __init__(
+        self,
+        rank: int,
+        machine_id: int,
+        mapping: Mapping,
+        graph: Graph,
+        config,
+        iterations=1,
+        log_dir=".",
+        log_level=logging.INFO,
+        *args
+    ):
+        """
+        Constructor
+
+        Parameters
+        ----------
+        rank : int
+            Rank of process local to the machine
+        machine_id : int
+            Machine ID on which the process is running
+        mapping : decentralizepy.mappings
+            The object containing the mapping rank <--> uid
+        graph : decentralizepy.graphs
+            The object containing the global graph
+        config : dict
+            A dictionary of configurations. Must contain the following:
+            [COMMUNICATION]
+                comm_package
+                comm_class
+        iterations : int
+            Number of iterations (communication steps) for which the model should be trained
+        log_dir : str
+            Logging directory
+        log_level : logging.Level
+            One of DEBUG, INFO, WARNING, ERROR, CRITICAL
+        args : optional
+            Other arguments
+
+        """
+        super().__init__(
+            rank,
+            machine_id,
+            mapping,
+            graph,
+            config,
+            iterations,
+            log_dir,
+            log_level,
+            *args
+        )
+
+        self.instantiate(
+            rank,
+            machine_id,
+            mapping,
+            graph,
+            config,
+            iterations,
+            log_dir,
+            log_level,
+            *args
+        )
+
+        self.run()
diff --git a/src/decentralizepy/sharing/FFT.py b/src/decentralizepy/sharing/FFT.py
index 17650c1..60c60ec 100644
--- a/src/decentralizepy/sharing/FFT.py
+++ b/src/decentralizepy/sharing/FFT.py
@@ -1,8 +1,6 @@
 import json
 import logging
 import os
-from pathlib import Path
-from time import time
 
 import numpy as np
 import torch
@@ -53,6 +51,9 @@ class FFT(PartialModel):
         save_accumulated="",
         accumulation=True,
         accumulate_averaging_changes=False,
+        compress=False,
+        compression_package=None,
+        compression_class=None,
     ):
         """
         Constructor
@@ -111,6 +112,9 @@ class FFT(PartialModel):
             save_accumulated,
             change_transformer_fft,
             accumulate_averaging_changes,
+            compress,
+            compression_package,
+            compression_class,
         )
 
         self.change_based_selection = change_based_selection
@@ -163,7 +167,7 @@ class FFT(PartialModel):
                     self.model.accumulated_changes = torch.zeros_like(
                         self.model.accumulated_changes
                    )
-                return m
+                return self.compress_data(m)
 
         with torch.no_grad():
             topk, indices = self.apply_fft()
@@ -199,7 +203,7 @@ class FFT(PartialModel):
             m["indices"] = indices.numpy().astype(np.int32)
 
             m["send_partial"] = True
-        return m
+        return self.compress_data(m)
 
     def deserialized_model(self, m):
         """
 @@ 
-216,6 +220,8 @@ class FFT(PartialModel): state_dict of received """ + m = self.decompress_data(m) + ret = dict() if "send_partial" not in m: params = m["params"] @@ -237,7 +243,7 @@ class FFT(PartialModel): ret["send_partial"] = True return ret - def _averaging(self): + def _averaging(self, peer_deques): """ Averages the received model with the local model @@ -251,8 +257,11 @@ class FFT(PartialModel): pre_share_model = torch.cat(tensors_to_cat, dim=0) flat_fft = self.change_transformer(pre_share_model) - for i, n in enumerate(self.peer_deques): - degree, iteration, data = self.peer_deques[n].popleft() + for i, n in enumerate(peer_deques): + data = peer_deques[n].popleft() + degree, iteration = data["degree"], data["iteration"] + del data["degree"] + del data["iteration"] logging.debug( "Averaging model from neighbor {} of iteration {}".format( n, iteration @@ -268,7 +277,7 @@ class FFT(PartialModel): else: topkf = params - weight = 1 / (max(len(self.peer_deques), degree) + 1) # Metro-Hastings + weight = 1 / (max(len(peer_deques), degree) + 1) # Metro-Hastings weight_total += weight if total is None: total = weight * topkf @@ -289,3 +298,5 @@ class FFT(PartialModel): start_index = end_index self.model.load_state_dict(std_dict) + self._post_step() + self.communication_round += 1 diff --git a/src/decentralizepy/sharing/GrowingAlpha.py b/src/decentralizepy/sharing/GrowingAlpha.py index 7fe7bf5..a13a869 100644 --- a/src/decentralizepy/sharing/GrowingAlpha.py +++ b/src/decentralizepy/sharing/GrowingAlpha.py @@ -1,3 +1,4 @@ +# Deprecated import logging from decentralizepy.sharing.PartialModel import PartialModel @@ -25,6 +26,9 @@ class GrowingAlpha(PartialModel): dict_ordered=True, save_shared=False, metadata_cap=1.0, + compress=False, + compression_package=None, + compression_class=None, ): """ Constructor @@ -74,12 +78,15 @@ class GrowingAlpha(PartialModel): dict_ordered, save_shared, metadata_cap, + compress, + compression_package, + compression_class, ) self.init_alpha = init_alpha self.max_alpha = max_alpha self.k = k - def step(self): + def get_data_to_send(self): """ Perform a sharing step. Implements D-PSGD with alpha increasing as a linear function. 
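
The weight computed in the reworked _averaging() methods above, weight = 1 / (max(len(peer_deques), degree) + 1), is the Metropolis-Hastings rule: the edge between nodes i and j gets weight 1 / (1 + max(d_i, d_j)), and the local model keeps whatever weight remains, so every row of the mixing matrix sums to 1. A scalar sketch of the same computation (illustrative only; the real code applies these weights to flattened model tensors):

    def metro_hastings_average(own_value, own_degree, received):
        # received: list of (value, sender_degree) pairs, one per neighbour
        total, weight_sum = 0.0, 0.0
        for value, sender_degree in received:
            w = 1.0 / (max(own_degree, sender_degree) + 1)
            total += w * value
            weight_sum += w
        # the local model keeps the leftover weight
        return total + (1.0 - weight_sum) * own_value
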
@@ -93,4 +100,4 @@ class GrowingAlpha(PartialModel):
             self.communication_round += 1
-            return
+            return dict()
 
-        super().step()
+        return super().get_data_to_send()
diff --git a/src/decentralizepy/sharing/LowerBoundTopK.py b/src/decentralizepy/sharing/LowerBoundTopK.py
index 6ac5329..86b9c3b 100644
--- a/src/decentralizepy/sharing/LowerBoundTopK.py
+++ b/src/decentralizepy/sharing/LowerBoundTopK.py
@@ -24,6 +24,9 @@ class LowerBoundTopK(PartialModel):
         log_dir,
         lower_bound=0.1,
         metro_hastings=True,
+        compress=False,
+        compression_package=None,
+        compression_class=None,
         **kwargs,
     ):
         """
@@ -81,7 +84,10 @@ class LowerBoundTopK(PartialModel):
             model,
             dataset,
             log_dir,
-            **kwargs,
+            compress=compress,
+            compression_package=compression_package,
+            compression_class=compression_class,
+            **kwargs,
         )
         self.lower_bound = lower_bound
         self.metro_hastings = metro_hastings
@@ -154,6 +159,8 @@ class LowerBoundTopK(PartialModel):
         if "send_partial" not in m:
             return super().deserialized_model(m)
 
+        m = self.decompress_data(m)
+
         with torch.no_grad():
             state_dict = self.model.state_dict()
 
@@ -169,7 +176,7 @@ class LowerBoundTopK(PartialModel):
 
         return T, index_tensor
 
-    def _averaging(self):
+    def _averaging(self, peer_deques):
         """
         Averages the received model with the local model
 
@@ -187,8 +194,11 @@ class LowerBoundTopK(PartialModel):
             weight_total = 0
             weight_vector = torch.ones_like(self.init_model)
             datas = []
-            for i, n in enumerate(self.peer_deques):
-                degree, iteration, data = self.peer_deques[n].popleft()
+            for i, n in enumerate(peer_deques):
+                data = peer_deques[n].popleft()
+                degree, iteration = data["degree"], data["iteration"]
+                del data["degree"]
+                del data["iteration"]
                 logging.debug(
                     "Averaging model from neighbor {} of iteration {}".format(
                         n, iteration
@@ -215,3 +225,5 @@ class LowerBoundTopK(PartialModel):
 
         logging.info("new averaging")
         self.model.load_state_dict(total)
+        self._post_step()
+        self.communication_round += 1
diff --git a/src/decentralizepy/sharing/ManualAdapt.py b/src/decentralizepy/sharing/ManualAdapt.py
index dcb94cf..9a54eb7 100644
--- a/src/decentralizepy/sharing/ManualAdapt.py
+++ b/src/decentralizepy/sharing/ManualAdapt.py
@@ -1,3 +1,4 @@
+# Deprecated
 import logging
 
 from decentralizepy.sharing.PartialModel import PartialModel
@@ -24,6 +25,9 @@ class ManualAdapt(PartialModel):
         dict_ordered=True,
         save_shared=False,
         metadata_cap=1.0,
+        compress=False,
+        compression_package=None,
+        compression_class=None,
     ):
         """
         Constructor
@@ -81,11 +85,14 @@ class ManualAdapt(PartialModel):
             dict_ordered,
             save_shared,
             metadata_cap,
+            compress=compress,
+            compression_package=compression_package,
+            compression_class=compression_class,
         )
         self.change_alpha = change_alpha[1:]
         self.change_rounds = change_rounds
 
-    def step(self):
+    def get_data_to_send(self):
         """
         Perform a sharing step. Implements D-PSGD with alpha manually given. 
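
With step() renamed to get_data_to_send() across the sharing classes, serialization and communication are decoupled: a sharing object only produces and consumes payloads, while the node owns the network round. The per-iteration exchange now reduces to roughly the following (a condensed restatement of the loop in DPSGDNode.run() above, not a new API):

    to_send = self.sharing.get_data_to_send()       # serialize, possibly compressed
    for neighbor in self.my_neighbors:
        self.communication.send(neighbor, to_send)
    while not self.received_from_all():             # buffer arrivals per sender
        sender, data = self.receive()
        self.peer_deques[sender].append(data)
    self.sharing._averaging(self.peer_deques)       # merge models, then _post_step()
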
@@ -101,6 +108,6 @@ class ManualAdapt(PartialModel): if self.alpha == 0.0: logging.info("Not sending/receiving data (alpha=0.0)") self.communication_round += 1 - return + return dict() - super().step() + return super().get_data_to_send() diff --git a/src/decentralizepy/sharing/PartialModel.py b/src/decentralizepy/sharing/PartialModel.py index 3111e82..b302f5d 100644 --- a/src/decentralizepy/sharing/PartialModel.py +++ b/src/decentralizepy/sharing/PartialModel.py @@ -34,6 +34,9 @@ class PartialModel(Sharing): save_accumulated="", change_transformer=identity, accumulate_averaging_changes=False, + compress=False, + compression_package=None, + compression_class=None, ): """ Constructor @@ -76,7 +79,17 @@ class PartialModel(Sharing): """ super().__init__( - rank, machine_id, communication, mapping, graph, model, dataset, log_dir + rank, + machine_id, + communication, + mapping, + graph, + model, + dataset, + log_dir, + compress, + compression_package, + compression_class, ) self.alpha = alpha self.dict_ordered = dict_ordered @@ -129,6 +142,23 @@ class PartialModel(Sharing): self.change_transformer(self.init_model).shape[0], dtype=torch.int32 ) + def compress_data(self, data): + result = dict(data) + if self.compress: + if "indices" in result: + result["indices"] = self.compressor.compress(result["indices"]) + if "params" in result: + result["params"] = self.compressor.compress_float(result["params"]) + return result + + def decompress_data(self, data): + if self.compress: + if "indices" in data: + data["indices"] = self.compressor.decompress(data["indices"]) + if "params" in data: + data["params"] = self.compressor.decompress_float(data["params"]) + return data + def extract_top_gradients(self): """ Extract the indices and values of the topK gradients. @@ -220,7 +250,7 @@ class PartialModel(Sharing): logging.info("Converted dictionary to pickle") - return m + return self.compress_data(m) def deserialized_model(self, m): """ @@ -241,6 +271,8 @@ class PartialModel(Sharing): return super().deserialized_model(m) with torch.no_grad(): + m = self.decompress_data(m) + state_dict = self.model.state_dict() if not self.dict_ordered: diff --git a/src/decentralizepy/sharing/RandomAlpha.py b/src/decentralizepy/sharing/RandomAlpha.py index 1956c29..3bac634 100644 --- a/src/decentralizepy/sharing/RandomAlpha.py +++ b/src/decentralizepy/sharing/RandomAlpha.py @@ -28,6 +28,9 @@ class RandomAlpha(PartialModel): save_accumulated="", change_transformer=identity, accumulate_averaging_changes=False, + compress=False, + compression_package=None, + compression_class=None, ): """ Constructor @@ -75,14 +78,17 @@ class RandomAlpha(PartialModel): save_accumulated, change_transformer, accumulate_averaging_changes, + compress, + compression_package, + compression_class, ) self.alpha_list = eval(alpha_list) random.seed(self.mapping.get_uid(self.rank, self.machine_id)) - def step(self): + def get_data_to_send(self): """ Perform a sharing step. Implements D-PSGD with alpha randomly chosen. 
""" self.alpha = random.choice(self.alpha_list) - super().step() + return super().get_data_to_send() diff --git a/src/decentralizepy/sharing/RandomAlphaIncremental.py b/src/decentralizepy/sharing/RandomAlphaIncremental.py index c3b7c0d..96ead3d 100644 --- a/src/decentralizepy/sharing/RandomAlphaIncremental.py +++ b/src/decentralizepy/sharing/RandomAlphaIncremental.py @@ -1,3 +1,4 @@ +# Deprecated import random from decentralizepy.sharing.PartialModel import PartialModel @@ -24,6 +25,9 @@ class RandomAlphaIncremental(PartialModel): metadata_cap=1.0, range_start=0.1, range_end=0.2, + compress=False, + compression_package=None, + compression_class=None, ): """ Constructor @@ -67,16 +71,19 @@ class RandomAlphaIncremental(PartialModel): dict_ordered, save_shared, metadata_cap, + compress, + compression_package, + compression_class, ) random.seed(self.mapping.get_uid(self.rank, self.machine_id)) self.range_start = range_start self.range_end = range_end - def step(self): + def get_data_to_send(self): """ Perform a sharing step. Implements D-PSGD with alpha randomly chosen from an increasing range. """ self.alpha = round(random.uniform(self.range_start, self.range_end), 2) self.range_end = min(1.0, self.range_end + round(random.uniform(0.0, 0.1), 2)) - super().step() + return super().get_data_to_send() diff --git a/src/decentralizepy/sharing/RandomAlphaWavelet.py b/src/decentralizepy/sharing/RandomAlphaWavelet.py index 44ea336..de2a5e6 100644 --- a/src/decentralizepy/sharing/RandomAlphaWavelet.py +++ b/src/decentralizepy/sharing/RandomAlphaWavelet.py @@ -29,6 +29,9 @@ class RandomAlpha(Wavelet): save_accumulated="", accumulation=False, accumulate_averaging_changes=False, + compress=False, + compression_package=None, + compression_class=None, ): """ Constructor @@ -78,14 +81,17 @@ class RandomAlpha(Wavelet): save_accumulated, accumulation, accumulate_averaging_changes, + compress, + compression_package, + compression_class, ) self.alpha_list = eval(alpha_list) random.seed(self.mapping.get_uid(self.rank, self.machine_id)) - def step(self): + def get_data_to_send(self): """ Perform a sharing step. Implements D-PSGD with alpha randomly chosen. 
""" self.alpha = random.choice(self.alpha_list) - super().step() + return super().get_data_to_send() diff --git a/src/decentralizepy/sharing/RoundRobinPartial.py b/src/decentralizepy/sharing/RoundRobinPartial.py index c5288a5..fbe0179 100644 --- a/src/decentralizepy/sharing/RoundRobinPartial.py +++ b/src/decentralizepy/sharing/RoundRobinPartial.py @@ -25,6 +25,9 @@ class RoundRobinPartial(Sharing): dataset, log_dir, alpha=1.0, + compress=False, + compression_package=None, + compression_class=None, ): """ Constructor @@ -52,7 +55,17 @@ class RoundRobinPartial(Sharing): """ super().__init__( - rank, machine_id, communication, mapping, graph, model, dataset, log_dir + rank, + machine_id, + communication, + mapping, + graph, + model, + dataset, + log_dir, + compress, + compression_package, + compression_class, ) self.alpha = alpha random.seed(self.mapping.get_uid(rank, machine_id)) @@ -104,7 +117,7 @@ class RoundRobinPartial(Sharing): logging.info("Converted dictionary to json") self.total_data += len(self.communication.encrypt(m["params"])) - return m + return self.compress_data(m) def deserialized_model(self, m): """ @@ -121,9 +134,9 @@ class RoundRobinPartial(Sharing): state_dict of received """ + m = self.decompress_data(m) with torch.no_grad(): state_dict = self.model.state_dict() - shapes = [] lens = [] tensors_to_cat = [] diff --git a/src/decentralizepy/sharing/Sharing.py b/src/decentralizepy/sharing/Sharing.py index 0ad3927..7ef18cc 100644 --- a/src/decentralizepy/sharing/Sharing.py +++ b/src/decentralizepy/sharing/Sharing.py @@ -1,5 +1,5 @@ +import importlib import logging -from collections import deque import torch @@ -11,7 +11,18 @@ class Sharing: """ def __init__( - self, rank, machine_id, communication, mapping, graph, model, dataset, log_dir + self, + rank, + machine_id, + communication, + mapping, + graph, + model, + dataset, + log_dir, + compress=False, + compression_package=None, + compression_class=None, ): """ Constructor @@ -47,11 +58,6 @@ class Sharing: self.communication_round = 0 self.log_dir = log_dir - self.peer_deques = dict() - self.my_neighbors = self.graph.neighbors(self.uid) - for n in self.my_neighbors: - self.peer_deques[n] = deque() - self.shapes = [] self.lens = [] with torch.no_grad(): @@ -60,38 +66,28 @@ class Sharing: t = v.flatten().numpy() self.lens.append(t.shape[0]) - def received_from_all(self): - """ - Check if all neighbors have sent the current iteration - - Returns - ------- - bool - True if required data has been received, False otherwise - - """ - for _, i in self.peer_deques.items(): - if len(i) == 0: - return False - return True - - def get_neighbors(self, neighbors): - """ - Choose which neighbors to share with - - Parameters - ---------- - neighbors : list(int) - List of all neighbors - - Returns - ------- - list(int) - Neighbors to share with - - """ - # modify neighbors here - return neighbors + self.compress = compress + + if compression_package and compression_class: + compressor_module = importlib.import_module(compression_package) + compressor_class = getattr(compressor_module, compression_class) + self.compressor = compressor_class() + logging.info(f"Using the {compressor_class} to compress the data") + else: + assert not self.compress + + def compress_data(self, data): + result = dict(data) + if self.compress: + if "params" in result: + result["params"] = self.compressor.compress_float(result["params"]) + return result + + def decompress_data(self, data): + if self.compress: + if "params" in data: + data["params"] = 
self.compressor.decompress_float(data["params"]) + return data def serialized_model(self): """ @@ -111,7 +107,7 @@ class Sharing: flat = torch.cat(to_cat) data = dict() data["params"] = flat.numpy() - return data + return self.compress_data(data) def deserialized_model(self, m): """ @@ -129,11 +125,14 @@ class Sharing: """ state_dict = dict() + m = self.decompress_data(m) T = m["params"] start_index = 0 for i, key in enumerate(self.model.state_dict()): end_index = start_index + self.lens[i] - state_dict[key] = torch.from_numpy(T[start_index:end_index].reshape(self.shapes[i])) + state_dict[key] = torch.from_numpy( + T[start_index:end_index].reshape(self.shapes[i]) + ) start_index = end_index return state_dict @@ -152,7 +151,7 @@ class Sharing: """ pass - def _averaging(self): + def _averaging(self, peer_deques): """ Averages the received model with the local model @@ -160,15 +159,18 @@ class Sharing: with torch.no_grad(): total = dict() weight_total = 0 - for i, n in enumerate(self.peer_deques): - degree, iteration, data = self.peer_deques[n].popleft() + for i, n in enumerate(peer_deques): + data = peer_deques[n].popleft() + degree, iteration = data["degree"], data["iteration"] + del data["degree"] + del data["iteration"] logging.debug( "Averaging model from neighbor {} of iteration {}".format( n, iteration ) ) data = self.deserialized_model(data) - weight = 1 / (max(len(self.peer_deques), degree) + 1) # Metro-Hastings + weight = 1 / (max(len(peer_deques), degree) + 1) # Metro-Hastings weight_total += weight for key, value in data.items(): if key in total: @@ -180,41 +182,14 @@ class Sharing: total[key] += (1 - weight_total) * value # Metro-Hastings self.model.load_state_dict(total) + self._post_step() + self.communication_round += 1 - def step(self): - """ - Perform a sharing step. Implements D-PSGD. 
- - """ + def get_data_to_send(self): self._pre_step() data = self.serialized_model() my_uid = self.mapping.get_uid(self.rank, self.machine_id) all_neighbors = self.graph.neighbors(my_uid) - iter_neighbors = self.get_neighbors(all_neighbors) data["degree"] = len(all_neighbors) data["iteration"] = self.communication_round - encrypted = self.communication.encrypt(data) - for neighbor in iter_neighbors: - self.communication.send(neighbor, encrypted, encrypt=False) - - logging.info("Waiting for messages from neighbors") - while not self.received_from_all(): - sender, data = self.communication.receive() - logging.debug("Received model from {}".format(sender)) - degree = data["degree"] - iteration = data["iteration"] - del data["degree"] - del data["iteration"] - self.peer_deques[sender].append((degree, iteration, data)) - logging.info( - "Deserialized received model from {} of iteration {}".format( - sender, iteration - ) - ) - - logging.info("Starting model averaging after receiving from all neighbors") - self._averaging() - logging.info("Model averaging complete") - - self.communication_round += 1 - self._post_step() + return data diff --git a/src/decentralizepy/sharing/SharingCentrality.py b/src/decentralizepy/sharing/SharingCentrality.py index f933a0e..05986ac 100644 --- a/src/decentralizepy/sharing/SharingCentrality.py +++ b/src/decentralizepy/sharing/SharingCentrality.py @@ -1,3 +1,4 @@ +# Deprecated import logging from collections import deque diff --git a/src/decentralizepy/sharing/SubSampling.py b/src/decentralizepy/sharing/SubSampling.py index b51cb07..7201d33 100644 --- a/src/decentralizepy/sharing/SubSampling.py +++ b/src/decentralizepy/sharing/SubSampling.py @@ -31,6 +31,9 @@ class SubSampling(Sharing): metadata_cap=1.0, pickle=True, layerwise=False, + compress=False, + compression_package=None, + compression_class=None, ): """ Constructor @@ -66,7 +69,17 @@ class SubSampling(Sharing): """ super().__init__( - rank, machine_id, communication, mapping, graph, model, dataset, log_dir + rank, + machine_id, + communication, + mapping, + graph, + model, + dataset, + log_dir, + compress, + compression_package, + compression_class, ) self.alpha = alpha self.dict_ordered = dict_ordered @@ -215,7 +228,7 @@ class SubSampling(Sharing): m["alpha"] = alpha m["params"] = subsample.numpy() - return m + return self.compress_data(m) def deserialized_model(self, m): """ @@ -235,6 +248,8 @@ class SubSampling(Sharing): if self.alpha > self.metadata_cap: # Share fully return super().deserialized_model(m) + m = self.decompress_data(m) + with torch.no_grad(): state_dict = self.model.state_dict() diff --git a/src/decentralizepy/sharing/Synchronous.py b/src/decentralizepy/sharing/Synchronous.py index 2c2d5e7..7fc1c35 100644 --- a/src/decentralizepy/sharing/Synchronous.py +++ b/src/decentralizepy/sharing/Synchronous.py @@ -1,3 +1,4 @@ +# Deprecated import logging from collections import deque diff --git a/src/decentralizepy/sharing/TopKNormalized.py b/src/decentralizepy/sharing/TopKNormalized.py index 15a3caf..b281294 100644 --- a/src/decentralizepy/sharing/TopKNormalized.py +++ b/src/decentralizepy/sharing/TopKNormalized.py @@ -31,6 +31,9 @@ class TopKNormalized(PartialModel): change_transformer=identity, accumulate_averaging_changes=False, epsilon=0.01, + compress=False, + compression_package=None, + compression_class=None, ): """ Constructor @@ -91,6 +94,9 @@ class TopKNormalized(PartialModel): save_accumulated, change_transformer, accumulate_averaging_changes, + compress, + compression_package, + 
compression_class, ) self.epsilon = epsilon diff --git a/src/decentralizepy/sharing/TopKParams.py b/src/decentralizepy/sharing/TopKParams.py index f188179..c2b0e3f 100644 --- a/src/decentralizepy/sharing/TopKParams.py +++ b/src/decentralizepy/sharing/TopKParams.py @@ -29,6 +29,9 @@ class TopKParams(Sharing): dict_ordered=True, save_shared=False, metadata_cap=1.0, + compress=False, + compression_package=None, + compression_class=None, ): """ Constructor @@ -62,7 +65,17 @@ class TopKParams(Sharing): """ super().__init__( - rank, machine_id, communication, mapping, graph, model, dataset, log_dir + rank, + machine_id, + communication, + mapping, + graph, + model, + dataset, + log_dir, + compress, + compression_package, + compression_class, ) self.alpha = alpha self.dict_ordered = dict_ordered @@ -171,7 +184,7 @@ class TopKParams(Sharing): logging.info("Converted dictionary to json") - return m + return self.compress_data(m) def deserialized_model(self, m): """ @@ -191,6 +204,8 @@ class TopKParams(Sharing): if self.alpha > self.metadata_cap: # Share fully return super().deserialized_model(m) + m = self.decompress_data(m) + with torch.no_grad(): state_dict = self.model.state_dict() diff --git a/src/decentralizepy/sharing/TopKPlusRandom.py b/src/decentralizepy/sharing/TopKPlusRandom.py index 728d5bf..8962933 100644 --- a/src/decentralizepy/sharing/TopKPlusRandom.py +++ b/src/decentralizepy/sharing/TopKPlusRandom.py @@ -26,6 +26,9 @@ class TopKPlusRandom(PartialModel): dict_ordered=True, save_shared=False, metadata_cap=1.0, + compress=False, + compression_package=None, + compression_class=None, ): """ Constructor @@ -71,6 +74,9 @@ class TopKPlusRandom(PartialModel): dict_ordered, save_shared, metadata_cap, + compress, + compression_package, + compression_class, ) def extract_top_gradients(self): diff --git a/src/decentralizepy/sharing/Wavelet.py b/src/decentralizepy/sharing/Wavelet.py index 91c97d0..ffc2558 100644 --- a/src/decentralizepy/sharing/Wavelet.py +++ b/src/decentralizepy/sharing/Wavelet.py @@ -1,8 +1,6 @@ import json import logging import os -from pathlib import Path -from time import time import numpy as np import pywt @@ -61,6 +59,9 @@ class Wavelet(PartialModel): save_accumulated="", accumulation=False, accumulate_averaging_changes=False, + compress=False, + compression_package=None, + compression_class=None, ): """ Constructor @@ -125,6 +126,9 @@ class Wavelet(PartialModel): save_accumulated, lambda x: change_transformer_wavelet(x, wavelet, level), accumulate_averaging_changes, + compress, + compression_package, + compression_class, ) self.change_based_selection = change_based_selection @@ -185,7 +189,7 @@ class Wavelet(PartialModel): self.model.accumulated_changes = torch.zeros_like( self.model.accumulated_changes ) - return m + return self.compress_data(m) with torch.no_grad(): topk, indices = self.apply_wavelet() @@ -223,7 +227,7 @@ class Wavelet(PartialModel): m["send_partial"] = True - return m + return self.compress_data(m) def deserialized_model(self, m): """ @@ -240,6 +244,7 @@ class Wavelet(PartialModel): state_dict of received """ + m = self.decompress_data(m) ret = dict() if "send_partial" not in m: params = m["params"] @@ -260,7 +265,7 @@ class Wavelet(PartialModel): ret["send_partial"] = True return ret - def _averaging(self): + def _averaging(self, peer_deques): """ Averages the received model with the local model @@ -269,8 +274,11 @@ class Wavelet(PartialModel): total = None weight_total = 0 wt_params = self.pre_share_model_transformed - for i, n in 
enumerate(self.peer_deques): - degree, iteration, data = self.peer_deques[n].popleft() + for i, n in enumerate(peer_deques): + data = peer_deques[n].popleft() + degree, iteration = data["degree"], data["iteration"] + del data["degree"] + del data["iteration"] logging.debug( "Averaging model from neighbor {} of iteration {}".format( n, iteration @@ -287,7 +295,7 @@ class Wavelet(PartialModel): else: topkwf = params.reshape(self.wt_shape) - weight = 1 / (max(len(self.peer_deques), degree) + 1) # Metro-Hastings + weight = 1 / (max(len(peer_deques), degree) + 1) # Metro-Hastings weight_total += weight if total is None: total = weight * topkwf @@ -314,3 +322,5 @@ class Wavelet(PartialModel): start_index = end_index self.model.load_state_dict(std_dict) + self._post_step() + self.communication_round += 1 diff --git a/src/decentralizepy/train_test_evaluation.py b/src/decentralizepy/train_test_evaluation.py index 319d308..95f407c 100644 --- a/src/decentralizepy/train_test_evaluation.py +++ b/src/decentralizepy/train_test_evaluation.py @@ -1,6 +1,5 @@ import logging import os -import pickle from pathlib import Path import numpy as np -- GitLab
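For reference, Wavelet above passes `lambda x: change_transformer_wavelet(x, wavelet, level)` as the change transformer. A sketch of what such a transformer plausibly does, assuming the intent is a 1-D discrete wavelet transform of the flattened model so that top-k selection happens in coefficient space; this is an illustration, not the actual decentralizepy implementation:

import pywt
import torch

# Assumed semantics: flatten the model, run a 1-D DWT, and return the
# concatenated coefficients so callers can rank and subsample them.
def change_transformer_wavelet(x, wavelet, level):
    coeffs = pywt.wavedec(x.numpy(), wavelet, level=level)
    data, _ = pywt.coeffs_to_array(coeffs)
    return torch.from_numpy(data.ravel())

flat_model = torch.randn(1024)
coefficients = change_transformer_wavelet(flat_model, "haar", 4)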
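Taken together, the sharing-class changes in this patch reduce `step()` to `get_data_to_send()`: serialization and compression stay in the sharing classes, while sending, receiving, and the blocking wait move into the node (the new DPSGDNode). A schematic of the resulting per-round control flow; `send` and `receive` are placeholders rather than the exact decentralizepy.communication API:

# Schematic of the per-round control flow this patch moves out of
# Sharing.step() and into the node classes.
def training_round(sharing, communication, neighbors, peer_deques):
    data = sharing.get_data_to_send()  # serialized and, if enabled, compressed
    for neighbor in neighbors:
        communication.send(neighbor, data)

    # Block until one message from every neighbor has been queued.
    while not all(len(q) > 0 for q in peer_deques.values()):
        sender, message = communication.receive()
        peer_deques[sender].append(message)  # dict carrying degree/iteration

    # _averaging() now also runs _post_step() and advances communication_round.
    sharing._averaging(peer_deques)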