From b132f11767bab7539bf5fb01b69b11e838c48f49 Mon Sep 17 00:00:00 2001 From: Rishi Sharma <rishi.sharma@epfl.ch> Date: Thu, 18 Nov 2021 19:02:22 +0100 Subject: [PATCH] Sharing --- config.ini | 9 +- graph.adj | 7 ++ main.ipynb | 103 +++++++++++++--- .../communication/Communication.py | 96 ++------------- src/decentralizepy/communication/TCP.py | 112 ++++++++++++++++++ src/decentralizepy/graphs/Graph.py | 2 +- src/decentralizepy/node/Node.py | 23 +++- src/decentralizepy/sharing/Sharing.py | 80 ++++++++++++- testing.py | 35 ++++-- 9 files changed, 342 insertions(+), 125 deletions(-) create mode 100644 graph.adj create mode 100644 src/decentralizepy/communication/TCP.py diff --git a/config.ini b/config.ini index 5c0f3b3..e80c1db 100644 --- a/config.ini +++ b/config.ini @@ -6,7 +6,7 @@ graph_class = SmallWorld dataset_package = decentralizepy.datasets.Femnist dataset_class = Femnist model_class = CNN -n_procs = 1 +n_procs = 6 train_dir = leaf/data/femnist/data/train test_dir = leaf/data/femnist/data/test ; python list of fractions below @@ -20,17 +20,16 @@ lr = 0.01 [TRAIN_PARAMS] training_package = decentralizepy.training.Training training_class = Training -epochs_per_round = 25 +epochs_per_round = 5 batch_size = 512 shuffle = True loss_package = torch.nn loss_class = CrossEntropyLoss [COMMUNICATION] -comm_package = decentralizepy.communication.Communication -comm_class = Communication +comm_package = decentralizepy.communication.TCP +comm_class = TCP addresses_filepath = ip_addr.json -total_procs = 4 [SHARING] sharing_package = decentralizepy.sharing.Sharing diff --git a/graph.adj b/graph.adj new file mode 100644 index 0000000..f5d1331 --- /dev/null +++ b/graph.adj @@ -0,0 +1,7 @@ +6 +1 +0 3 4 +3 5 +1 2 5 +1 +2 3 \ No newline at end of file diff --git a/main.ipynb b/main.ipynb index a45d562..8d3f9e1 100644 --- a/main.ipynb +++ b/main.ipynb @@ -346,14 +346,81 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Section: GRAPH\n", + "('package', 'decentralizepy.graphs.SmallWorld')\n", + "('graph_class', 'SmallWorld')\n", + "Section: DATASET\n", + "('dataset_package', 'decentralizepy.datasets.Femnist')\n", + "('dataset_class', 'Femnist')\n", + "('model_class', 'CNN')\n", + "('n_procs', 2)\n", + "('train_dir', 'leaf/data/femnist/data/train')\n", + "('test_dir', 'leaf/data/femnist/data/test')\n", + "('sizes', '')\n", + "Section: OPTIMIZER_PARAMS\n", + "('optimizer_package', 'torch.optim')\n", + "('optimizer_class', 'Adam')\n", + "('lr', 0.01)\n", + "Section: TRAIN_PARAMS\n", + "('training_package', 'decentralizepy.training.Training')\n", + "('training_class', 'Training')\n", + "('epochs_per_round', 1)\n", + "('batch_size', 512)\n", + "('shuffle', True)\n", + "('loss_package', 'torch.nn')\n", + "('loss_class', 'CrossEntropyLoss')\n", + "Section: COMMUNICATION\n", + "('comm_package', 'decentralizepy.communication.TCP')\n", + "('comm_class', 'TCP')\n", + "('addresses_filepath', 'ip_addr.json')\n", + "Section: SHARING\n", + "('sharing_package', 'decentralizepy.sharing.Sharing')\n", + "('sharing_class', 'Sharing')\n", + "{'dataset_package': 'decentralizepy.datasets.Femnist', 'dataset_class': 'Femnist', 'model_class': 'CNN', 'n_procs': 2, 'train_dir': 'leaf/data/femnist/data/train', 'test_dir': 'leaf/data/femnist/data/test', 'sizes': ''}\n", + "n: <class 'int'> 1\n", + "n: <class 'int'> 0\n", + "n: <class 'int'> 1\n", + "n: <class 'int'> 0\n", + "n: <class 'int'> 
1\n", + "n: <class 'int'> 0\n", + "n: <class 'int'> 1\n", + "n: <class 'int'> 0\n", + "n: <class 'int'> 1\n", + "n: <class 'int'> 0\n", + "n: <class 'int'> 1\n", + "n: <class 'int'> 0\n", + "n: <class 'int'> 1\n", + "n: <class 'int'> 0\n", + "n: <class 'int'> 1\n", + "n: <class 'int'> 0\n", + "n: <class 'int'> 1\n", + "n: <class 'int'> 0\n", + "n: <class 'int'> 1\n", + "n: <class 'int'> 0\n", + "n: <class 'int'> 1\n", + "n: <class 'int'> 0\n", + "n: <class 'int'> 1\n", + "n: <class 'int'> 0\n", + "n: <class 'int'> 0\n", + "n: <class 'int'> 1\n", + "n: <class 'int'> 1\n", + "n: <class 'int'> 0\n" + ] + } + ], "source": [ "%matplotlib inline\n", "\n", "from decentralizepy.node.Node import Node\n", "from decentralizepy.graphs.SmallWorld import SmallWorld\n", + "from decentralizepy.graphs.Graph import Graph\n", "from decentralizepy.mappings.Linear import Linear\n", "from torch import multiprocessing as mp\n", "import torch\n", @@ -376,31 +443,37 @@ " my_config[section] = dict(config.items(section))\n", "\n", "#f = Femnist(2, 'leaf/data/femnist/data/train', sizes=[0.6, 0.4])\n", - "g = SmallWorld(4, 1, 0.5)\n", - "print(g)\n", - "l = Linear(2, 2)\n", + "g = Graph()\n", + "g.read_graph_from_file(\"graph.adj\", \"adjacency\")\n", + "l = Linear(1, 6)\n", "\n", "#Node(0, 0, l, g, my_config, 20, \"results\", logging.DEBUG)\n", "\n", - "#mp.spawn(fn = Node, nprocs = 1, args=[0,l,g,my_config,20,\"results\",logging.DEBUG])\n", + "mp.spawn(fn = Node, nprocs = 6, args=[0,l,g,my_config,20,\"results\",logging.DEBUG])\n", "\n", "# mp.spawn(fn = Node, args = [l, g, config, 10, \"results\", logging.DEBUG], nprocs=2)\n" ] }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 4, "metadata": {}, "outputs": [ { - "ename": "NameError", - "evalue": "name 'mp' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m/tmp/ipykernel_1457289/353106489.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0ml\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mLinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0mmp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mspawn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnprocs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"ip_addr.json\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ml\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mNameError\u001b[0m: name 'mp' is not defined" + "name": "stdout", + "output_type": "stream", + "text": [ + "Message sent\n", + "Message sent\n", + "1 (0, {'message': 'Hi I am rank 0'})\n", + "0 (1, {'message': 'Hi I am rank 1'})\n" ] } ], diff --git a/src/decentralizepy/communication/Communication.py b/src/decentralizepy/communication/Communication.py 
index 188af12..bd3004e 100644 --- a/src/decentralizepy/communication/Communication.py +++ b/src/decentralizepy/communication/Communication.py @@ -1,105 +1,25 @@ -import json -import logging -from collections import deque - -import zmq - -HELLO = b"HELLO" -BYE = b"BYE" - - class Communication: """ Communcation API """ - - def addr(self, rank, machine_id): - machine_addr = self.ip_addrs[str(machine_id)] - port = rank + 20000 - return "tcp://{}:{}".format(machine_addr, port) - - def __init__(self, rank, machine_id, total_procs, addresses_filepath, mapping): - with open(addresses_filepath) as addrs: - self.ip_addrs = json.load(addrs) - + def __init__(self, rank, machine_id, mapping, total_procs): self.total_procs = total_procs self.rank = rank self.machine_id = machine_id self.mapping = mapping self.uid = mapping.get_uid(rank, machine_id) - self.identity = str(self.uid).encode() - self.context = zmq.Context() - self.router = self.context.socket(zmq.ROUTER) - self.router.setsockopt(zmq.IDENTITY, self.identity) - self.router.bind(self.addr(rank, machine_id)) - self.sent_disconnections = False - - self.peer_deque = deque() - self.peer_sockets = dict() - self.barrier = set() def encrypt(self, data): - return json.dumps(data).encode("utf8") + raise NotImplementedError def decrypt(self, sender, data): - sender = int(sender.decode()) - data = json.loads(data.decode("utf8")) - return sender, data - - def connect_neighbours(self, neighbours): - for uid in neighbours: - id = str(uid).encode() - req = self.context.socket(zmq.DEALER) - req.setsockopt(zmq.IDENTITY, self.identity) - req.connect(self.addr(*self.mapping.get_machine_and_rank(uid))) - self.peer_sockets[id] = req - req.send(HELLO) - - num_neighbours = len(neighbours) - while len(self.barrier) < num_neighbours: - sender, recv = self.router.recv_multipart() - - if recv == HELLO: - logging.info("Recieved {} from {}".format(HELLO, sender)) - self.barrier.add(sender) - elif recv == BYE: - logging.info("Recieved {} from {}".format(BYE, sender)) - raise RuntimeError( - "A neighbour wants to disconnect before training started!" - ) - else: - logging.info( - "Recieved message from {} @ connect_neighbours".format(sender) - ) - - self.peer_deque.append(self.decrypt(sender, recv)) + raise NotImplementedError + def connect_neighbors(self, neighbors): + raise NotImplementedError + def receive(self): - if len(self.peer_deque) != 0: - resp = self.peer_deque[0] - self.peer_deque.popleft() - return resp - - sender, recv = self.router.recv_multipart() - - if recv == HELLO: - logging.info("Recieved {} from {}".format(HELLO, sender)) - raise RuntimeError( - "A neighbour wants to connect when everyone is connected!" 
- ) - elif recv == BYE: - logging.info("Recieved {} from {}".format(BYE, sender)) - self.barrier.remove(sender) - if not self.sent_disconnections: - for sock in self.peer_sockets.values(): - sock.send(BYE) - self.sent_disconnections = True - else: - logging.info("Recieved message from {}".format(sender)) - return self.decrypt(sender, recv) + raise NotImplementedError def send(self, uid, data): - to_send = self.encrypt(data) - id = str(uid).encode() - self.peer_sockets[id].send(to_send) - print("{} sent the message to {}.".format(self.uid, uid)) + raise NotImplementedError diff --git a/src/decentralizepy/communication/TCP.py b/src/decentralizepy/communication/TCP.py new file mode 100644 index 0000000..898038b --- /dev/null +++ b/src/decentralizepy/communication/TCP.py @@ -0,0 +1,112 @@ +import json +import logging +from collections import deque + +import zmq + +from decentralizepy.communication.Communication import Communication + +HELLO = b"HELLO" +BYE = b"BYE" + + +class TCP(Communication): + """ + TCP Communication API + """ + + def addr(self, rank, machine_id): + machine_addr = self.ip_addrs[str(machine_id)] + port = rank + 20000 + return "tcp://{}:{}".format(machine_addr, port) + + def __init__(self, rank, machine_id, mapping, total_procs, addresses_filepath): + super().__init__(rank, machine_id, mapping, total_procs) + + with open(addresses_filepath) as addrs: + self.ip_addrs = json.load(addrs) + + self.total_procs = total_procs + self.rank = rank + self.machine_id = machine_id + self.mapping = mapping + self.uid = mapping.get_uid(rank, machine_id) + self.identity = str(self.uid).encode() + self.context = zmq.Context() + self.router = self.context.socket(zmq.ROUTER) + self.router.setsockopt(zmq.IDENTITY, self.identity) + self.router.bind(self.addr(rank, machine_id)) + self.sent_disconnections = False + + self.peer_deque = deque() + self.peer_sockets = dict() + self.barrier = set() + + def __del__(self): + self.context.destroy(linger=0) + + def encrypt(self, data): + return json.dumps(data).encode("utf8") + + def decrypt(self, sender, data): + sender = int(sender.decode()) + data = json.loads(data.decode("utf8")) + return sender, data + + def connect_neighbors(self, neighbors): + for uid in neighbors: + id = str(uid).encode() + req = self.context.socket(zmq.DEALER) + req.setsockopt(zmq.IDENTITY, self.identity) + req.connect(self.addr(*self.mapping.get_machine_and_rank(uid))) + self.peer_sockets[id] = req + req.send(HELLO) + + num_neighbors = len(neighbors) + while len(self.barrier) < num_neighbors: + sender, recv = self.router.recv_multipart() + + if recv == HELLO: + logging.info("Recieved {} from {}".format(HELLO, sender)) + self.barrier.add(sender) + elif recv == BYE: + logging.info("Recieved {} from {}".format(BYE, sender)) + raise RuntimeError( + "A neighbour wants to disconnect before training started!" + ) + else: + logging.info( + "Recieved message from {} @ connect_neighbors".format(sender) + ) + + self.peer_deque.append(self.decrypt(sender, recv)) + + def receive(self): + if len(self.peer_deque) != 0: + resp = self.peer_deque[0] + self.peer_deque.popleft() + return resp + + sender, recv = self.router.recv_multipart() + + if recv == HELLO: + logging.info("Recieved {} from {}".format(HELLO, sender)) + raise RuntimeError( + "A neighbour wants to connect when everyone is connected!" 
+ ) + elif recv == BYE: + logging.info("Recieved {} from {}".format(BYE, sender)) + self.barrier.remove(sender) + if not self.sent_disconnections: + for sock in self.peer_sockets.values(): + sock.send(BYE) + self.sent_disconnections = True + else: + logging.info("Recieved message from {}".format(sender)) + return self.decrypt(sender, recv) + + def send(self, uid, data): + to_send = self.encrypt(data) + id = str(uid).encode() + self.peer_sockets[id].send(to_send) + logging.info("{} sent the message to {}.".format(self.uid, uid)) diff --git a/src/decentralizepy/graphs/Graph.py b/src/decentralizepy/graphs/Graph.py index e812c76..d21f97b 100644 --- a/src/decentralizepy/graphs/Graph.py +++ b/src/decentralizepy/graphs/Graph.py @@ -73,7 +73,7 @@ class Graph: elif type == "adjacency": node_id = 0 for line in lines: - neighbours = line.strip().split() + neighbours = map(int, line.strip().split()) self.__insert_adj__(node_id, neighbours) node_id += 1 else: diff --git a/src/decentralizepy/node/Node.py b/src/decentralizepy/node/Node.py index 167cf2e..9a4e9cc 100644 --- a/src/decentralizepy/node/Node.py +++ b/src/decentralizepy/node/Node.py @@ -3,6 +3,7 @@ import logging import os from decentralizepy import utils +from decentralizepy.communication.Communication import Communication from decentralizepy.graphs.Graph import Graph from decentralizepy.mappings.Mapping import Mapping @@ -70,8 +71,10 @@ class Node: logging.info("Started process.") self.rank = rank + self.machine_id = machine_id self.graph = graph self.mapping = mapping + self.uid = self.mapping.get_uid(rank, machine_id) logging.debug("Rank: %d", self.rank) logging.debug("type(graph): %s", str(type(self.rank))) @@ -125,10 +128,28 @@ class Node: ) self.trainer = train_class(self.model, self.optimizer, loss, **train_params) - self.testset = self.dataset.get_trainset() + + comm_configs = config["COMMUNICATION"] + comm_module = importlib.import_module(comm_configs["comm_package"]) + comm_class = getattr(comm_module, comm_configs["comm_class"]) + comm_params = utils.remove_keys(comm_configs, ["comm_package", "comm_class"]) + self.communication = comm_class(self.rank, self.machine_id, self.mapping, self.graph.n_procs, **comm_params) + self.communication.connect_neighbors(self.graph.neighbors(self.uid)) + + sharing_configs = config["SHARING"] + sharing_package = importlib.import_module(sharing_configs["sharing_package"]) + sharing_class = getattr(sharing_package, sharing_configs["sharing_class"]) + self.sharing = sharing_class(self.rank, self.machine_id, self.communication, self.mapping, self.graph, self.model, self.dataset) + + + + self.testset = self.dataset.get_testset() for iteration in range(iterations): logging.info("Starting training iteration: %d", iteration) self.trainer.train(self.dataset) + + self.sharing.step() + if self.dataset.__testing__: self.dataset.test(self.model) diff --git a/src/decentralizepy/sharing/Sharing.py b/src/decentralizepy/sharing/Sharing.py index 5c90e5f..8f09689 100644 --- a/src/decentralizepy/sharing/Sharing.py +++ b/src/decentralizepy/sharing/Sharing.py @@ -1,7 +1,83 @@ +from collections import deque +import json +import logging +import torch +import numpy + class Sharing: """ API defining who to share with and what, and what to do on receiving """ - def __init__(): - raise NotImplementedError + def __init__(self, rank, machine_id, communication, mapping, graph, model, dataset): + self.rank = rank + self.machine_id = machine_id + self.uid = mapping.get_uid(rank, machine_id) + self.communication = communication + 
self.mapping = mapping + self.graph = graph + self.model = model + self.dataset = dataset + + self.peer_deques = dict() + my_neighbors = self.graph.neighbors(self.uid) + for n in my_neighbors: + self.peer_deques[n] = deque() + + def received_from_all(self): + for _, i in self.peer_deques.items(): + if len(i) == 0: + return False + return True + + + def get_neighbors(self, neighbors): + # modify neighbors here + return neighbors + + def serialized_model(self): + m = dict() + for key, val in self.model.state_dict().items(): + m[key] = json.dumps(val.numpy().tolist()) + return m + + def deserialized_model(self, m): + state_dict = dict() + for key, value in m.items(): + state_dict[key] = torch.from_numpy(numpy.array(json.loads(value))) + return state_dict + + + def step(self): + data = self.serialized_model() + my_uid = self.mapping.get_uid(self.rank, self.machine_id) + all_neighbors = self.graph.neighbors(my_uid) + iter_neighbors = self.get_neighbors(all_neighbors) + data['degree'] = len(all_neighbors) + for neighbor in iter_neighbors: + self.communication.send(neighbor, data) + + while not self.received_from_all(): + sender, data = self.communication.receive() + logging.info("Received model from {}".format(sender)) + degree = data["degree"] + del data["degree"] + self.peer_deques[sender].append((degree, self.deserialized_model(data))) + + total = dict() + weight_total = 0 + for n in self.peer_deques: + degree, data = self.peer_deques[n].popleft() + #logging.info("top element: {}".format(d)) + weight = 1/(max(len(self.peer_deques), degree) + 1) # Metro-Hastings + weight_total += weight + for key, value in data.items(): + if key in total: + total[key] += value * weight + else: + total[key] = value * weight + + for key, value in self.model.state_dict().items(): + total[key] += (1 - weight_total) * value # Metro-Hastings + + self.model.load_state_dict(total) diff --git a/testing.py b/testing.py index 6058184..c2cb432 100644 --- a/testing.py +++ b/testing.py @@ -1,20 +1,29 @@ +from decentralizepy.node.Node import Node +from decentralizepy.graphs.Graph import Graph +from decentralizepy.mappings.Linear import Linear from torch import multiprocessing as mp +import logging -from decentralizepy.communication.Communication import Communication -from decentralizepy.mappings.Linear import Linear +from localconfig import LocalConfig +def read_ini(file_path): + config = LocalConfig(file_path) + for section in config: + print("Section: ", section) + for key, value in config.items(section): + print((key, value)) + print(dict(config.items('DATASET'))) + return config -def f(rank, m_id, total_procs, filePath, mapping): - c = Communication(rank, m_id, total_procs, filePath, mapping) - c.connect_neighbours([i for i in range(total_procs) if i != c.uid]) - send = {} - send["message"] = "Hi I am rank {}".format(rank) - c.send((c.uid + 1) % total_procs, send) - print(c.uid, c.receive()) +if __name__ == "__main__": + config = read_ini("config.ini") + my_config = dict() + for section in config: + my_config[section] = dict(config.items(section)) + g = Graph() + g.read_graph_from_file("graph.adj", "adjacency") + l = Linear(1, 6) -if __name__ == "__main__": - l = Linear(2, 2) - m_id = int(input()) - mp.spawn(fn=f, nprocs=2, args=[m_id, 4, "ip_addr.json", l]) + mp.spawn(fn = Node, nprocs = 6, args=[0,l,g,my_config,20,"results",logging.DEBUG]) -- GitLab
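
Note on the adjacency file: the layout of graph.adj is not spelled out in the patch. The sketch below reads it under the assumption that the first line holds the number of processes and that line i after it lists the neighbours of node i, matching what Graph.read_graph_from_file(path, "adjacency") expects after the map(int, ...) fix; the parse_adjacency helper is illustrative only, not part of decentralizepy.

    def parse_adjacency(path):
        # First line: number of nodes; each following line i: the space-separated
        # neighbour ids of node i (graph.adj is symmetric, i.e. undirected).
        with open(path) as f:
            lines = f.read().strip().splitlines()
        n_procs = int(lines[0])
        adj = {node: set(map(int, line.split())) for node, line in enumerate(lines[1:])}
        assert len(adj) == n_procs
        return n_procs, adj

    # For the graph.adj added in this patch: node 0 has neighbour {1},
    # node 1 has neighbours {0, 3, 4}, node 2 has {3, 5}, and so on.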
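
Note on the TCP layer: the notebook output above ("Hi I am rank 0" / "Hi I am rank 1") comes from exercising the new TCP class with two processes. A minimal two-process check in the spirit of the helper removed from testing.py could look like the sketch below; it assumes ip_addr.json maps machine id "0" to a reachable address such as 127.0.0.1, and the function name handshake is illustrative.

    from torch import multiprocessing as mp
    from decentralizepy.communication.TCP import TCP
    from decentralizepy.mappings.Linear import Linear

    def handshake(rank, mapping, total_procs):
        # One TCP endpoint per process; connect_neighbors blocks until every
        # listed neighbour has answered with HELLO.
        comm = TCP(rank, 0, mapping, total_procs, "ip_addr.json")
        comm.connect_neighbors([i for i in range(total_procs) if i != comm.uid])
        comm.send((comm.uid + 1) % total_procs, {"message": "Hi I am rank {}".format(rank)})
        print(comm.uid, comm.receive())  # expect the greeting sent by the other rank

    if __name__ == "__main__":
        mp.spawn(fn=handshake, nprocs=2, args=[Linear(1, 2), 2])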
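
Note on the aggregation: Sharing.step() mixes the local model with the received neighbour models using Metropolis-Hastings weights, 1 / (max(own_degree, neighbour_degree) + 1) per neighbour, with the leftover weight kept on the local model. A self-contained sketch of that averaging rule (function and variable names here are illustrative, not the committed API):

    import torch

    def metropolis_hastings_average(local_params, received):
        # `received` is a list of (neighbour_degree, params) pairs; the local
        # degree is the number of neighbours we heard from, as in Sharing.step().
        own_degree = len(received)
        total, weight_total = {}, 0.0
        for degree, params in received:
            weight = 1.0 / (max(own_degree, degree) + 1)  # Metropolis-Hastings weight
            weight_total += weight
            for key, value in params.items():
                total[key] = total.get(key, 0) + weight * value
        for key, value in local_params.items():
            # Whatever weight remains stays on the local model.
            total[key] = total.get(key, 0) + (1.0 - weight_total) * value
        return total

    # Two neighbours of degree 2 and 3 give weights 1/3 and 1/4, local weight 5/12:
    local = {"w": torch.tensor([1.0, 1.0])}
    received = [(2, {"w": torch.tensor([0.0, 0.0])}), (3, {"w": torch.tensor([4.0, 4.0])})]
    print(metropolis_hastings_average(local, received))  # tensor([1.4167, 1.4167])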