diff --git a/config.ini b/config.ini
index 0b04725c7169b3b98f8dbcc09d51b2f3ad3d6a1f..5c0f3b3de704f225b8e7ec12f42c61f5d71a39dc 100644
--- a/config.ini
+++ b/config.ini
@@ -5,30 +5,32 @@ graph_class = SmallWorld
 [DATASET]
 dataset_package = decentralizepy.datasets.Femnist
 dataset_class = Femnist
-model_class = LogisticRegression
-n_procs = 1.0
+model_class = CNN
+n_procs = 1
 train_dir = leaf/data/femnist/data/train
-test_dir = 
+test_dir = leaf/data/femnist/data/test
 ; python list of fractions below
-sizes = [0.4, 0.2, 0.3, 0.1]
+sizes = 
 
 [OPTIMIZER_PARAMS]
 optimizer_package = torch.optim
-optimizer_class = SGD
-lr = 0.1
+optimizer_class = Adam
+lr = 0.01
 
 [TRAIN_PARAMS]
 training_package = decentralizepy.training.Training
 training_class = Training
 epochs_per_round = 25
-batch_size = 64
-shuffle = False
-loss_package = torch.nn.functional
-loss = nll_loss
+batch_size = 512
+shuffle = True
+loss_package = torch.nn
+loss_class = CrossEntropyLoss
 
 [COMMUNICATION]
 comm_package = decentralizepy.communication.Communication
-comm_class = Communcation
+comm_class = Communication
+addresses_filepath = ip_addr.json
+total_procs = 4
 
 [SHARING]
 sharing_package = decentralizepy.sharing.Sharing
diff --git a/ip_addr.json b/ip_addr.json
new file mode 100644
index 0000000000000000000000000000000000000000..187f5434557d80e8028197b3edf69fe7f37cfa5c
--- /dev/null
+++ b/ip_addr.json
@@ -0,0 +1,4 @@
+{
+    "0": "labostrex131",
+    "1": "labostrex132"
+}
\ No newline at end of file
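The two new [COMMUNICATION] keys feed the Communication constructor added further down in this diff (Node does not wire it up yet). A minimal sketch of how they line up, assuming a Linear(2, 2) mapping to match the two machines in ip_addr.json at two processes each:

    from decentralizepy.communication.Communication import Communication
    from decentralizepy.mappings.Linear import Linear

    mapping = Linear(2, 2)  # 2 machines x 2 procs per machine = total_procs 4
    comm = Communication(
        rank=0,
        machine_id=0,
        total_procs=4,                      # [COMMUNICATION] total_procs
        addresses_filepath="ip_addr.json",  # [COMMUNICATION] addresses_filepath
        mapping=mapping,
    )
    # Binds a ROUTER socket at tcp://labostrex131:20000 (port = rank + 20000).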
diff --git a/main.ipynb b/main.ipynb
index 36d80f40f1e80327e53f4bddface4a6c7e77a330..a45d562955bffe3e44ba9d9c6ae02e050f810e7d 100644
--- a/main.ipynb
+++ b/main.ipynb
@@ -170,8 +170,12 @@
     "    for key, value in config.items(section):\n",
     "        print((key, value))\n",
     "    print(dict(config.items('DATASET')))\n",
+    "    return config\n",
     "    \n",
-    "read_ini(\"config.ini\")"
+    "config = read_ini(\"config.ini\")\n",
+    "for section in config:\n",
+    "    print(section)\n",
+    "#d = dict(config.sections())"
    ]
   },
   {
@@ -216,9 +220,12 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "f1 = Femnist(1, 'leaf/data/femnist/data/train')\n",
-    "f1.instantiate_dataset()\n",
-    "f1.train_x.shape"
+    "from decentralizepy.datasets.Femnist import Femnist\n",
+    "f1 = Femnist(0, 1, 'leaf/data/femnist/data/train')\n",
+    "ts = f1.get_trainset(1)\n",
+    "for data, target in ts:\n",
+    "    print(data)\n",
+    "    break"
    ]
   },
   {
@@ -252,7 +259,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -261,33 +268,9 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Linear(in_features=1, out_features=1, bias=True)\n",
-      "1 OrderedDict([('weight', tensor([[0.9654]])), ('bias', tensor([-0.2141]))])\n",
-      "1 [{'params': [Parameter containing:\n",
-      "tensor([[0.9654]], requires_grad=True), Parameter containing:\n",
-      "tensor([-0.2141], requires_grad=True)], 'lr': 0.6, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}]\n",
-      "1 OrderedDict([('weight', tensor([[0.]])), ('bias', tensor([-0.2141]))])\n",
-      "1 [{'params': [Parameter containing:\n",
-      "tensor([[0.]], requires_grad=True), Parameter containing:\n",
-      "tensor([-0.2141], requires_grad=True)], 'lr': 0.6, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}]\n",
-      "0 OrderedDict([('weight', tensor([[0.]])), ('bias', tensor([-0.2141]))])\n",
-      "0 [{'params': [Parameter containing:\n",
-      "tensor([[0.]], requires_grad=True), Parameter containing:\n",
-      "tensor([-0.2141], requires_grad=True)], 'lr': 0.6, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}]\n",
-      "0 OrderedDict([('weight', tensor([[0.]])), ('bias', tensor([-0.2141]))])\n",
-      "0 [{'params': [Parameter containing:\n",
-      "tensor([[0.]], requires_grad=True), Parameter containing:\n",
-      "tensor([-0.2141], requires_grad=True)], 'lr': 0.6, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}]\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "from torch import multiprocessing as mp\n",
     "import torch\n",
@@ -342,6 +325,94 @@
     "m1.state_dict()"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "loss = getattr(torch.nn.functional, 'nll_loss')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "loss"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%matplotlib inline\n",
+    "\n",
+    "from decentralizepy.node.Node import Node\n",
+    "from decentralizepy.graphs.SmallWorld import SmallWorld\n",
+    "from decentralizepy.mappings.Linear import Linear\n",
+    "from torch import multiprocessing as mp\n",
+    "import torch\n",
+    "import logging\n",
+    "\n",
+    "from localconfig import LocalConfig\n",
+    "\n",
+    "def read_ini(file_path):\n",
+    "    config = LocalConfig(file_path)\n",
+    "    for section in config:\n",
+    "        print(\"Section: \", section)\n",
+    "        for key, value in config.items(section):\n",
+    "            print((key, value))\n",
+    "    print(dict(config.items('DATASET')))\n",
+    "    return config\n",
+    "    \n",
+    "config = read_ini(\"config.ini\")\n",
+    "my_config = dict()\n",
+    "for section in config:\n",
+    "    my_config[section] = dict(config.items(section))\n",
+    "\n",
+    "#f = Femnist(2, 'leaf/data/femnist/data/train', sizes=[0.6, 0.4])\n",
+    "g = SmallWorld(4, 1, 0.5)\n",
+    "print(g)\n",
+    "l = Linear(2, 2)\n",
+    "\n",
+    "#Node(0, 0, l, g, my_config, 20, \"results\", logging.DEBUG)\n",
+    "\n",
+    "#mp.spawn(fn = Node, nprocs = 1, args=[0,l,g,my_config,20,\"results\",logging.DEBUG])\n",
+    "\n",
+    "# mp.spawn(fn = Node, args = [l, g, config, 10, \"results\", logging.DEBUG], nprocs=2)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "NameError",
+     "evalue": "name 'mp' is not defined",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
+      "\u001b[0;32m/tmp/ipykernel_1457289/353106489.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0ml\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mLinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0mmp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mspawn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnprocs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"ip_addr.json\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ml\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
+      "\u001b[0;31mNameError\u001b[0m: name 'mp' is not defined"
+     ]
+    }
+   ],
+   "source": [
+    "from decentralizepy.mappings.Linear import Linear\n",
+    "from testing import f\n",
+    "from torch import multiprocessing as mp\n",
+    "\n",
+    "l = Linear(1, 2)\n",
+    "mp.spawn(fn = f, nprocs = 2, args = [0, 2, \"ip_addr.json\", l])\n"
+   ]
+  },
  {
   "cell_type": "code",
   "execution_count": null,
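The %matplotlib cell above flattens the LocalConfig object into a plain nested dict before handing it to Node; the loop is equivalent to a one-line comprehension:

    from localconfig import LocalConfig

    config = LocalConfig("config.ini")
    my_config = {section: dict(config.items(section)) for section in config}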
diff --git a/src/decentralizepy/communication/Communication.py b/src/decentralizepy/communication/Communication.py
index 18f7b94b8f4fd9e8e8b97ccc931b2505f38960ff..9240190b7901e3d877b0c4e51bb98b22be36ad0f 100644
--- a/src/decentralizepy/communication/Communication.py
+++ b/src/decentralizepy/communication/Communication.py
@@ -1,6 +1,105 @@
+import json
+import logging
+from collections import deque
+
+import zmq
+
+HELLO = b"HELLO"
+BYE = b"BYE"
+
+
 class Communication:
     """
     Communcation API
     """
-    def __init__():
-        raise NotImplementedError
\ No newline at end of file
+
+    def addr(self, rank, machine_id):
+        machine_addr = self.ip_addrs[str(machine_id)]
+        port = rank + 20000
+        return "tcp://{}:{}".format(machine_addr, port)
+
+    def __init__(self, rank, machine_id, total_procs, addresses_filepath, mapping):
+        with open(addresses_filepath) as addrs:
+            self.ip_addrs = json.load(addrs)
+
+        self.total_procs = total_procs
+        self.rank = rank
+        self.machine_id = machine_id
+        self.mapping = mapping
+        self.uid = mapping.get_uid(rank, machine_id)
+        self.identity = str(self.uid).encode()
+        self.context = zmq.Context()
+        self.router = self.context.socket(zmq.ROUTER)
+        self.router.setsockopt(zmq.IDENTITY, self.identity)
+        self.router.bind(self.addr(rank, machine_id))
+        self.sent_disconnections = False
+
+        self.peer_deque = deque()
+        self.peer_sockets = dict()
+        self.barrier = set()
+
+    def encrypt(self, data):
+        return json.dumps(data).encode("utf8")
+
+    def decrypt(self, sender, data):
+        sender = int(sender.decode())
+        data = json.loads(data.decode("utf8"))
+        return sender, data
+
+    def connect_neighbours(self, neighbours):
+        for uid in neighbours:
+            id = str(uid).encode()
+            req = self.context.socket(zmq.DEALER)
+            req.setsockopt(zmq.IDENTITY, self.identity)
+            req.connect(self.addr(*self.mapping.get_machine_and_rank(uid)))
+            self.peer_sockets[id] = req
+            req.send(HELLO)
+
+        num_neighbours = len(neighbours)
+        while len(self.barrier) < num_neighbours:
+            sender, recv = self.router.recv_multipart()
+
+            if recv == HELLO:
+                logging.info("Received {} from {}".format(HELLO, sender))
+                self.barrier.add(sender)
+            elif recv == BYE:
+                logging.info("Received {} from {}".format(BYE, sender))
+                raise RuntimeError(
+                    "A neighbour wants to disconnect before training started!"
+                )
+            else:
+                logging.info(
+                    "Received message from {} @ connect_neighbours".format(sender)
+                )
+
+                self.peer_deque.append(self.decrypt(sender, recv))
+
+    def receive(self):
+        if len(self.peer_deque) != 0:
+            resp = self.peer_deque[0]
+            self.peer_deque.popleft()
+            return resp
+
+        sender, recv = self.router.recv_multipart()
+
+        if recv == HELLO:
+            logging.info("Received {} from {}".format(HELLO, sender))
+            raise RuntimeError(
+                "A neighbour wants to connect when everyone is connected!"
+            )
+        elif recv == BYE:
+            logging.info("Received {} from {}".format(BYE, sender))
+            self.barrier.remove(sender)
+            if not self.sent_disconnections:
+                for sock in self.peer_sockets.values():
+                    sock.send(BYE)
+                self.sent_disconnections = True
+        else:
+            logging.info("Received message from {}".format(sender))
+            return self.decrypt(sender, recv)
+
+    def send(self, uid, data):
+        to_send = self.encrypt(data)
+        id = str(uid).encode()
+        self.peer_sockets[id].send(to_send)
+        logging.debug("Message sent")
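Note that encrypt()/decrypt() round-trip through JSON, so anything passed to send() must be JSON-serialisable; torch tensors are not. Sharing model state will therefore need a conversion step along these lines (hypothetical helpers, not part of this diff):

    import torch

    def state_dict_payload(state_dict):
        # torch.Tensor is not JSON-serialisable; nested lists are.
        return {key: value.tolist() for key, value in state_dict.items()}

    def payload_state_dict(payload):
        return {key: torch.tensor(value) for key, value in payload.items()}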
diff --git a/src/decentralizepy/datasets/Dataset.py b/src/decentralizepy/datasets/Dataset.py
index e4d176ce49341c940a00c2060a62da6bbce53259..065b86cdb065054c83f0d0645482eade756377bf 100644
--- a/src/decentralizepy/datasets/Dataset.py
+++ b/src/decentralizepy/datasets/Dataset.py
@@ -1,5 +1,6 @@
 from decentralizepy import utils
 
+
 class Dataset:
     """
     This class defines the Dataset API.
@@ -33,7 +34,6 @@ class Dataset:
 
         if type(self.sizes) == str:
             self.sizes = eval(self.sizes)
-
         if train_dir:
             self.__training__ = True
         else:
diff --git a/src/decentralizepy/datasets/Femnist.py b/src/decentralizepy/datasets/Femnist.py
index d1bbe310411ce62a10962aa00d485ff5c9380da8..bbc067e53ff570c818b463ee9b00673175f9b911 100644
--- a/src/decentralizepy/datasets/Femnist.py
+++ b/src/decentralizepy/datasets/Femnist.py
@@ -3,10 +3,12 @@ import logging
 import os
 from collections import defaultdict
 
+import matplotlib.pyplot as plt
 import numpy as np
 import torch
 import torch.nn.functional as F
+import torchvision
 from torch import nn
 from torch.utils.data import DataLoader
 
 from decentralizepy.datasets.Data import Data
@@ -109,17 +112,21 @@
             logging.debug("train_y.shape: %s", str(self.train_y.shape))
 
         if self.__testing__:
-            logging.info("Loading training set.")
+            logging.info("Loading testing set.")
             _, _, test_data = self.__read_dir__(self.test_dir)
-            test_data = test_data.values()
+            test_x = []
+            test_y = []
+            for test_data in test_data.values():
+                for x in test_data["x"]:
+                    test_x.append(x)
+                for y in test_data["y"]:
+                    test_y.append(y)
             self.test_x = (
-                np.array(test_data["x"], dtype=np.dtype("float32"))
+                np.array(test_x, dtype=np.dtype("float32"))
                 .reshape(-1, 28, 28, 1)
                 .transpose(0, 3, 1, 2)
             )
-            self.test_y = np.array(test_data["y"], dtype=np.dtype("int64")).reshape(
-                -1
-            )
+            self.test_y = np.array(test_y, dtype=np.dtype("int64")).reshape(-1)
             logging.debug("test_x.shape: %s", str(self.test_x.shape))
             logging.debug("test_y.shape: %s", str(self.test_y.shape))
 
@@ -158,12 +165,12 @@
 
         raise IndexError("i is out of bounds!")
 
-    def get_trainset(self, batch_size, shuffle = False):
+    def get_trainset(self, batch_size=1, shuffle=False):
         """
         Function to get the training set
         Parameters
         ----------
-        batch_size : int
+        batch_size : int, optional
             Batch size for learning
         Returns
         -------
@@ -174,7 +181,9 @@
             If the training set was not initialized
         """
         if self.__training__:
-            return DataLoader(Data(self.train_x, self.train_y), batch_size = batch_size, shuffle = shuffle)
+            return DataLoader(
+                Data(self.train_x, self.train_y), batch_size=batch_size, shuffle=shuffle
+            )
         raise RuntimeError("Training set not initialized!")
 
     def get_testset(self):
@@ -189,9 +198,43 @@
             If the test set was not initialized
         """
         if self.__testing__:
-            return Data(self.test_x, self.test_y)
+            return DataLoader(Data(self.test_x, self.test_y))
         raise RuntimeError("Test set not initialized!")
 
+    def imshow(self, img):
+        npimg = img.numpy()
+        plt.imshow(np.transpose(npimg, (1, 2, 0)))
+        plt.show()
+
+    def test(self, model):
+        testloader = self.get_testset()
+        # dataiter = iter(testloader)
+        # images, labels = dataiter.next()
+        # self.imshow(torchvision.utils.make_grid(images))
+        # plt.savefig(' '.join('%5s' % j for j in labels) + ".png")
+        # print(' '.join('%5s' % j for j in labels))
+
+        correct_pred = [0 for _ in range(NUM_CLASSES)]
+        total_pred = [0 for _ in range(NUM_CLASSES)]
+        with torch.no_grad():
+            for elems, labels in testloader:
+                outputs = model(elems)
+                _, predictions = torch.max(outputs, 1)
+                for label, prediction in zip(labels, predictions):
+                    if label == prediction:
+                        correct_pred[label] += 1
+                    total_pred[label] += 1
+
+        total_correct = 0
+
+        for key, value in enumerate(correct_pred):
+            accuracy = 100 * float(value) / total_pred[key]
+            total_correct += value
+            logging.debug("Accuracy for class {} is: {:.1f} %".format(key, accuracy))
+
+        accuracy = 100 * float(total_correct) / len(testloader)
+        logging.info("Overall accuracy is: {:.1f} %".format(accuracy))
+
 
 class LogisticRegression(nn.Module):
     """
@@ -220,4 +263,22 @@ class LogisticRegression(nn.Module):
         """
         x = torch.flatten(x, start_dim=1)
         x = self.fc1(x)
-        return F.log_softmax(x, dim=1)
+        return x
+
+
+class CNN(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 32, 5, padding=2)
+        self.pool = nn.MaxPool2d(2, 2)
+        self.conv2 = nn.Conv2d(32, 64, 5, padding=2)
+        self.fc1 = nn.Linear(7 * 7 * 64, 512)
+        self.fc2 = nn.Linear(512, NUM_CLASSES)
+
+    def forward(self, x):
+        x = self.pool(F.relu(self.conv1(x)))
+        x = self.pool(F.relu(self.conv2(x)))
+        x = torch.flatten(x, 1)
+        x = F.relu(self.fc1(x))
+        x = self.fc2(x)
+        return x
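The fc1 in_features of the new CNN follows from the 28x28 FEMNIST input: each 5x5 convolution keeps the spatial size (padding=2) and each pool halves it, so 28 -> 14 -> 7 and the flattened map is 64 * 7 * 7 = 3136. A quick shape check, assuming NUM_CLASSES (used unqualified above, so module-level in Femnist.py) is 62 for FEMNIST:

    import torch
    from decentralizepy.datasets.Femnist import CNN, NUM_CLASSES

    net = CNN()
    out = net(torch.zeros(1, 1, 28, 28))  # one dummy single-channel image
    assert out.shape == (1, NUM_CLASSES)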
diff --git a/src/decentralizepy/mappings/Linear.py b/src/decentralizepy/mappings/Linear.py
index c70a1de97391b3b8c575424d5394889551651018..b43bf860f7c0835e74336d4f13d3b687bfb5a1d0 100644
--- a/src/decentralizepy/mappings/Linear.py
+++ b/src/decentralizepy/mappings/Linear.py
@@ -47,6 +47,6 @@ class Linear(Mapping):
         Returns
         -------
         2-tuple
-            a tuple of machine_id and rank
+            a tuple of rank and machine_id
         """
-        return (uid // self.procs_per_machine), (uid % self.procs_per_machine)
+        return (uid % self.procs_per_machine), (uid // self.procs_per_machine)
diff --git a/src/decentralizepy/mappings/Mapping.py b/src/decentralizepy/mappings/Mapping.py
index 3307f49a294414ad87b7f40ffe09f64a33c3566d..cf454e607dbfea265acbf0d6fe551bf99f9a1026 100644
--- a/src/decentralizepy/mappings/Mapping.py
+++ b/src/decentralizepy/mappings/Mapping.py
@@ -38,7 +38,7 @@ class Mapping:
         Returns
         -------
         2-tuple
-            a tuple of machine_id and rank
+            a tuple of rank and machine_id
         """
 
         raise NotImplementedError
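With the corrected ordering, get_machine_and_rank is the exact inverse of get_uid under the row-major layout uid = machine_id * procs_per_machine + rank (an assumption; get_uid itself is not shown in this diff), and it now feeds Communication.addr(rank, machine_id) in the right order. The method name still says machine-then-rank, though, which is worth a rename later. Round-trip check:

    from decentralizepy.mappings.Linear import Linear

    l = Linear(2, 2)  # 2 machines, 2 processes per machine
    for uid in range(4):
        rank, machine_id = l.get_machine_and_rank(uid)
        assert l.get_uid(rank, machine_id) == uid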
""" + def __init__( self, rank: int, @@ -21,7 +19,7 @@ class Node: mapping: Mapping, graph: Graph, config, - iterations = 1, + iterations=1, log_dir=".", log_level=logging.INFO, *args @@ -78,12 +76,14 @@ class Node: logging.debug("Rank: %d", self.rank) logging.debug("type(graph): %s", str(type(self.rank))) logging.debug("type(mapping): %s", str(type(self.mapping))) - + dataset_configs = config["DATASET"] dataset_module = importlib.import_module(dataset_configs["dataset_package"]) dataset_class = getattr(dataset_module, dataset_configs["dataset_class"]) - dataset_params = utils.remove_keys(dataset_configs, ["dataset_package", "dataset_class", "model_class"]) - self.dataset = dataset_class(rank, **dataset_params) + dataset_params = utils.remove_keys( + dataset_configs, ["dataset_package", "dataset_class", "model_class"] + ) + self.dataset = dataset_class(rank, **dataset_params) logging.info("Dataset instantiation complete.") @@ -91,9 +91,15 @@ class Node: self.model = model_class() optimizer_configs = config["OPTIMIZER_PARAMS"] - optimizer_module = importlib.import_module(optimizer_configs["optimizer_package"]) - optimizer_class = getattr(optimizer_module, optimizer_configs["optimizer_class"]) - optimizer_params = utils.remove_keys(optimizer_configs, ["optimizer_package", "optimizer_class"]) + optimizer_module = importlib.import_module( + optimizer_configs["optimizer_package"] + ) + optimizer_class = getattr( + optimizer_module, optimizer_configs["optimizer_class"] + ) + optimizer_params = utils.remove_keys( + optimizer_configs, ["optimizer_package", "optimizer_class"] + ) self.optimizer = optimizer_class(self.model.parameters(), **optimizer_params) train_configs = config["TRAIN_PARAMS"] @@ -101,11 +107,28 @@ class Node: train_class = getattr(train_module, train_configs["training_class"]) loss_package = importlib.import_module(train_configs["loss_package"]) - loss = getattr(loss_package, train_configs["loss"]) + if "loss_class" in train_configs.keys(): + loss_class = getattr(loss_package, train_configs["loss_class"]) + loss = loss_class() + else: + loss = getattr(loss_package, train_configs["loss"]) - train_params = utils.remove_keys(train_configs, ["training_package", "training_class", "loss", "loss_package"]) + train_params = utils.remove_keys( + train_configs, + [ + "training_package", + "training_class", + "loss", + "loss_package", + "loss_class", + ], + ) self.trainer = train_class(self.model, self.optimizer, loss, **train_params) + self.testset = self.dataset.get_trainset() + for iteration in range(iterations): logging.info("Starting training iteration: %d", iteration) - self.trainer.train(self.dataset) \ No newline at end of file + self.trainer.train(self.dataset) + if self.dataset.__testing__: + self.dataset.test(self.model) diff --git a/src/decentralizepy/sharing/Sharing.py b/src/decentralizepy/sharing/Sharing.py index 58261b67e349ca2639f05d1b2c22aefb88ef94d4..5c90e5f527c17f0e676a21d636de8d731ce5e4bb 100644 --- a/src/decentralizepy/sharing/Sharing.py +++ b/src/decentralizepy/sharing/Sharing.py @@ -2,5 +2,6 @@ class Sharing: """ API defining who to share with and what, and what to do on receiving """ + def __init__(): - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/src/decentralizepy/training/Training.py b/src/decentralizepy/training/Training.py index 48a035e156fdac324f0656f7396260da8784efa6..54de9d5dbc0f99d198446b442ccebbf8e9fc9bcf 100644 --- a/src/decentralizepy/training/Training.py +++ b/src/decentralizepy/training/Training.py @@ 
-1,11 +1,21 @@
+import logging
+
+import matplotlib.pyplot as plt
+import numpy as np
 import torch
+import torchvision
+
 from decentralizepy import utils
-import logging
+
+
 class Training:
     """
     This class implements the training module for a single node.
     """
-    def __init__(self, model, optimizer, loss, epochs_per_round = "", batch_size = "", shuffle = ""):
+
+    def __init__(
+        self, model, optimizer, loss, epochs_per_round="", batch_size="", shuffle=""
+    ):
         """
         Constructor
         Parameters
@@ -24,10 +34,15 @@ class Training:
         self.model = model
         self.optimizer = optimizer
         self.loss = loss
-        self.epochs_per_round = utils.conditional_value(epochs_per_round, "", 1)
-        self.batch_size = utils.conditional_value(batch_size, "", 1)
+        self.epochs_per_round = utils.conditional_value(epochs_per_round, "", int(1))
+        self.batch_size = utils.conditional_value(batch_size, "", int(1))
         self.shuffle = utils.conditional_value(shuffle, "", False)
 
+    def imshow(self, img):
+        npimg = img.numpy()
+        plt.imshow(np.transpose(npimg, (1, 2, 0)))
+        plt.show()
+
     def train(self, dataset):
         """
         One training iteration
@@ -37,8 +52,16 @@ class Training:
            The training dataset. Should implement get_trainset(batch_size, shuffle)
         """
         trainset = dataset.get_trainset(self.batch_size, self.shuffle)
+
+        # dataiter = iter(trainset)
+        # images, labels = dataiter.next()
+        # self.imshow(torchvision.utils.make_grid(images[:16]))
+        # plt.savefig(' '.join('%5s' % j for j in labels) + ".png")
+        # print(' '.join('%5s' % j for j in labels[:16]))
+
         for epoch in range(self.epochs_per_round):
             epoch_loss = 0.0
+            count = 0
             for data, target in trainset:
                 self.model.zero_grad()
                 output = self.model(data)
@@ -46,4 +69,5 @@ class Training:
                 epoch_loss += loss_val.item()
                 loss_val.backward()
                 self.optimizer.step()
-            logging.info("Epoch_loss: %d", epoch_loss)
+                count += 1
+            logging.info("Epoch: {} loss: {}".format(epoch, epoch_loss / count))
diff --git a/src/decentralizepy/utils.py b/src/decentralizepy/utils.py
index eee6db02506b5938de09174f6c75d66c97e437f1..ff562697648fbb7a8bd968125dabe2ac346d538f 100644
--- a/src/decentralizepy/utils.py
+++ b/src/decentralizepy/utils.py
@@ -4,5 +4,6 @@ def conditional_value(var, nul, default):
     else:
         return default
 
+
 def remove_keys(d, keys_to_remove):
-    return {key: d[key] for key in d if key not in keys_to_remove}
\ No newline at end of file
+    return {key: d[key] for key in d if key not in keys_to_remove}
diff --git a/testing.py b/testing.py
new file mode 100644
index 0000000000000000000000000000000000000000..fafca298f2a279192dcf62218201e9cf39e44ee0
--- /dev/null
+++ b/testing.py
@@ -0,0 +1,19 @@
+from torch import multiprocessing as mp
+
+from decentralizepy.communication.Communication import Communication
+from decentralizepy.mappings.Linear import Linear
+
+
+def f(rank, m_id, total_procs, filePath, mapping):
+    c = Communication(rank, m_id, total_procs, filePath, mapping)
+    c.connect_neighbours([i for i in range(total_procs) if i != c.uid])
+    send = {}
+    send["message"] = "Hi I am rank {}".format(rank)
+    c.send((c.uid + 1) % total_procs, send)
+    print(rank, c.receive())
+
+
+if __name__ == "__main__":
+    l = Linear(2, 2)
+    m_id = int(input())
+    mp.spawn(fn=f, nprocs=2, args=[m_id, 4, "ip_addr.json", l])
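With the uid-based neighbour list, each of the four processes sends to (uid + 1) % 4, so every process receives exactly one message and receive() returns without blocking. testing.py reads its machine id from stdin and spawns two local processes, so it has to be started once per host listed in ip_addr.json; a hypothetical launcher (not part of this diff):

    import subprocess

    # Run with "0\n" on labostrex131 and "1\n" on labostrex132; with only one
    # of the two machines up, connect_neighbours() blocks waiting for HELLOs.
    subprocess.run(["python", "testing.py"], input="0\n", text=True, check=True)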