From b67a85128abd1ce51c8838e93a5d54d0d907424f Mon Sep 17 00:00:00 2001
From: Rishi Sharma <rishi.sharma@epfl.ch>
Date: Wed, 10 Nov 2021 18:18:37 +0100
Subject: [PATCH] Add logging, fix Node and test Femnist

---
 config.ini                              |   2 +
 main.ipynb                              | 334 +++++++++++-------------
 src/decentralizepy/datasets/Data.py     |   2 +-
 src/decentralizepy/datasets/Dataset.py  |  19 +-
 src/decentralizepy/datasets/Femnist.py  | 108 ++++----
 src/decentralizepy/node/Node.py         |  91 ++++---
 src/decentralizepy/training/Training.py |  29 ++
 src/decentralizepy/training/__init__.py |   0
 src/decentralizepy/utils.py             |   8 +
 9 files changed, 322 insertions(+), 271 deletions(-)
 create mode 100644 src/decentralizepy/training/Training.py
 create mode 100644 src/decentralizepy/training/__init__.py
 create mode 100644 src/decentralizepy/utils.py

diff --git a/config.ini b/config.ini
index 3420bba..2b8d4ff 100644
--- a/config.ini
+++ b/config.ini
@@ -18,5 +18,7 @@ optimizer_class = SGD
 lr = 0.1
 
 [TRAIN_PARAMS]
+training_package = decentralizepy.training.Training
+training_class = Training
 epochs_per_round = 25
 batch_size = 64
diff --git a/main.ipynb b/main.ipynb
index f679c22..36d80f4 100644
--- a/main.ipynb
+++ b/main.ipynb
@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -16,20 +16,9 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "datasets.FEMNIST.FEMNIST"
-      ]
-     },
-     "execution_count": 7,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
+   "outputs": [],
    "source": [
     "a = FEMNIST\n",
     "a"
@@ -37,7 +26,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -46,59 +35,18 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "[{1, 2, 5, 9, 10},\n",
-       " {0, 2, 3, 4, 5},\n",
-       " {0, 1, 3, 8, 10},\n",
-       " {1, 2, 4, 6, 7, 8, 10},\n",
-       " {1, 3, 5, 8, 10},\n",
-       " {0, 1, 4, 6, 9},\n",
-       " {1, 3, 5, 7, 8, 10},\n",
-       " {0, 2, 3, 6, 8, 9, 11},\n",
-       " {1, 2, 3, 4, 6, 7, 11},\n",
-       " {0, 2, 4, 5, 7, 10, 11},\n",
-       " {0, 2, 3, 4, 5, 6, 9},\n",
-       " {0, 4, 7, 8, 9}]"
-      ]
-     },
-     "execution_count": 9,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
+   "outputs": [],
    "source": [
     "b.adj_list"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "{1, 2, 5, 9, 10}\n",
-      "{0, 2, 3, 4, 5}\n",
-      "{0, 1, 3, 8, 10}\n",
-      "{1, 2, 4, 6, 7, 8, 10}\n",
-      "{1, 3, 5, 8, 10}\n",
-      "{0, 1, 4, 6, 9}\n",
-      "{1, 3, 5, 7, 8, 10}\n",
-      "{0, 2, 3, 6, 8, 9, 11}\n",
-      "{1, 2, 3, 4, 6, 7, 11}\n",
-      "{0, 2, 4, 5, 7, 10, 11}\n",
-      "{0, 2, 3, 4, 5, 6, 9}\n",
-      "{0, 4, 7, 8, 9}\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "for i in range(12):\n",
     "    print(b.neighbors(i))"
@@ -106,7 +54,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 11,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -115,7 +63,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -125,7 +73,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -137,7 +85,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -146,28 +94,9 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "./leaf/data/femnist/data/train/all_data_6_niid_0_keep_0_train_9.json\n",
-      "Current_Users:  100\n",
-      "total_users:  200\n",
-      "total_users:  100\n",
-      "dict_keys(['f3408_47', 'f3327_11', 'f3417_01', 'f3339_15', 'f3580_22', 'f3414_29', 'f3328_45', 'f3592_19', 'f3516_45', 'f3130_44', 'f3321_36', 'f3284_38', 'f3232_11', 'f3547_04', 'f3265_08', 'f3500_08', 'f3243_44', 'f3349_22', 'f3118_09', 'f3179_39', 'f3381_42', 'f3198_32', 'f3299_12', 'f3237_27', 'f3593_26', 'f3133_33', 'f3591_14', 'f3231_19', 'f3478_49', 'f3447_20', 'f3442_00', 'f3464_12', 'f3293_30', 'f3111_05', 'f3227_14', 'f3146_14', 'f3165_11', 'f3440_33', 'f3379_03', 'f3529_11', 'f3441_24', 'f3253_11', 'f3238_40', 'f3583_09', 'f3256_38', 'f3325_08', 'f3512_31', 'f3214_03', 'f3572_03', 'f3457_40', 'f3419_33', 'f3496_38', 'f3582_25', 'f3205_40', 'f3353_33', 'f3115_25', 'f3517_27', 'f3567_49', 'f3230_21', 'f3336_15', 'f3415_33', 'f3280_34', 'f3294_06', 'f3171_30', 'f3363_42', 'f3105_03', 'f3545_06', 'f3426_23', 'f3102_36', 'f3164_09', 'f3202_01', 'f3365_46', 'f3450_19', 'f3573_02', 'f3290_01', 'f3443_42', 'f3471_02', 'f3136_07', 'f3553_12', 'f3434_00', 'f3537_23', 'f3479_08', 'f3578_27', 'f3286_40', 'f3155_15', 'f3494_34', 'f3460_47', 'f3595_18', 'f3518_46', 'f3433_10', 'f3538_29', 'f3266_12', 'f3375_30', 'f3390_07', 'f3261_00', 'f3221_05', 'f3139_09', 'f3234_23', 'f3341_29', 'f3485_45'])\n",
-      "(155, 784)\n",
-      "(155,)\n",
-      "(164, 784)\n",
-      "(164,)\n",
-      "[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 16 18 19 20 21 22 23 24 25\n",
-      " 26 27 29 30 31 32 33 34 35 36 37 38 39 40 43 44 45 46 47 48 49 50 51 52\n",
-      " 53 54 55 56 57 58 60 61]\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "for f in files:\n",
     "    file_path = os.path.join(datadir, f)\n",
@@ -191,18 +120,9 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "import torch\n",
-      "['import torch.multiprocessing as mp\\n', '\\n', '\\n', 'x = [1, 2]\\n', '\\n', 'def f(id, a):\\n', '    print(id, x)\\n', '    print(id, a)\\n', '\\n', \"if __name__ == '__main__':\\n\", '    x.append(3)\\n', '    mp.spawn(f, nprocs=2, args=(x, ))']\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "file = 'run.py'\n",
     "with open(file, 'r') as inf:\n",
@@ -212,18 +132,9 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "['a', 'a', 'a']\n",
-      "['a', 'a', 'c']\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "def f(l):\n",
     "    l[2] = 'c'\n",
@@ -236,17 +147,9 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "['a', 'b']\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "l = ['a', 'b', 'c']\n",
     "print(l[:-1])"
@@ -254,31 +157,9 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 17,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Section:  GRAPH\n",
-      "('package', 'decentralizepy.graphs.SmallWorld')\n",
-      "('graph_class', 'SmallWorld')\n",
-      "Section:  DATASET\n",
-      "('package', 'decentralizepy.datasets.Femnist')\n",
-      "('dataset_class', 'Femnist')\n",
-      "('model_class', 'LogisticRegression')\n",
-      "('n_procs', 1.0)\n",
-      "('train_dir', '')\n",
-      "('test_dir', '')\n",
-      "('sizes', '[0.4, 0.2, 0.3, 0.1]')\n",
-      "Section:  MODEL_PARAMS\n",
-      "('optimizer_package', 'torch.optim')\n",
-      "('optimizer_class', 'SGD')\n",
-      "('lr', 0.1)\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "from localconfig import LocalConfig\n",
     "\n",
@@ -288,23 +169,16 @@
     "        print(\"Section: \", section)\n",
     "        for key, value in config.items(section):\n",
     "            print((key, value))\n",
+    "    print(dict(config.items('DATASET')))\n",
     " \n",
     "read_ini(\"config.ini\")"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "15\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "def func(a = 1, b = 2, c = 3):\n",
     "    print(a + b + c)\n",
@@ -314,48 +188,158 @@
     "func(*l)"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from torch import multiprocessing as mp\n",
+    "\n",
+    "mp.spawn(fn = func, nprocs = 2, args = [], kwargs = {'a': 4, 'b': 5, 'c': 6})"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "l = '[0.4, 0.2, 0.3, 0.1]'\n",
+    "type(eval(l))\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "f1 = Femnist(1, 'leaf/data/femnist/data/train')\n",
+    "f1.instantiate_dataset()\n",
+    "f1.train_x.shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from decentralizepy.datasets.Femnist import Femnist\n",
+    "from decentralizepy.graphs.SmallWorld import SmallWorld\n",
+    "from decentralizepy.mappings.Linear import Linear\n",
+    "\n",
+    "f = Femnist(2, 'leaf/data/femnist/data/train', sizes=[0.6, 0.4])\n",
+    "g = SmallWorld(4, 1, 0.5)\n",
+    "l = Linear(2, 2)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from decentralizepy.node.Node import Node\n",
+    "from torch import multiprocessing as mp\n",
+    "import logging\n",
+    "n1 = Node(0, l, g, f, \"./results\", logging.DEBUG)\n",
+    "n2 = Node(1, l, g, f, \"./results\", logging.DEBUG)\n",
+    "# mp.spawn(fn = Node, nprocs = 2, args=[l,g,f])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from testing import f"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 3,
    "metadata": {},
    "outputs": [
     {
-     "ename": "TypeError",
-     "evalue": "spawn() got an unexpected keyword argument 'kwargs'",
-     "output_type": "error",
-     "traceback": [
-      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
-      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
-      "\u001b[0;32m/tmp/ipykernel_52405/4231740097.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mmultiprocessing\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mmp\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mmp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mspawn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnprocs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m'a'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;36m4\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'b'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'c'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;36m6\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
-      "\u001b[0;31mTypeError\u001b[0m: spawn() got an unexpected keyword argument 'kwargs'"
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Linear(in_features=1, out_features=1, bias=True)\n",
+      "1 OrderedDict([('weight', tensor([[0.9654]])), ('bias', tensor([-0.2141]))])\n",
+      "1 [{'params': [Parameter containing:\n",
+      "tensor([[0.9654]], requires_grad=True), Parameter containing:\n",
+      "tensor([-0.2141], requires_grad=True)], 'lr': 0.6, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}]\n",
+      "1 OrderedDict([('weight', tensor([[0.]])), ('bias', tensor([-0.2141]))])\n",
+      "1 [{'params': [Parameter containing:\n",
+      "tensor([[0.]], requires_grad=True), Parameter containing:\n",
+      "tensor([-0.2141], requires_grad=True)], 'lr': 0.6, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}]\n",
+      "0 OrderedDict([('weight', tensor([[0.]])), ('bias', tensor([-0.2141]))])\n",
+      "0 [{'params': [Parameter containing:\n",
+      "tensor([[0.]], requires_grad=True), Parameter containing:\n",
+      "tensor([-0.2141], requires_grad=True)], 'lr': 0.6, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}]\n",
+      "0 OrderedDict([('weight', tensor([[0.]])), ('bias', tensor([-0.2141]))])\n",
+      "0 [{'params': [Parameter containing:\n",
+      "tensor([[0.]], requires_grad=True), Parameter containing:\n",
+      "tensor([-0.2141], requires_grad=True)], 'lr': 0.6, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}]\n"
      ]
     }
    ],
    "source": [
     "from torch import multiprocessing as mp\n",
+    "import torch\n",
+    "m1 = torch.nn.Linear(1,1)\n",
+    "o1 = torch.optim.SGD(m1.parameters(), 0.6)\n",
+    "print(m1)\n",
     "\n",
-    "mp.spawn(fn = func, nprocs = 2, args = [], kwargs = {'a': 4, 'b': 5, 'c': 6})"
+    "\n",
+    "mp.spawn(fn = f, nprocs = 2, args=[m1, o1])\n",
+    "\n"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": []
+  },
   {
    "cell_type": "code",
-   "execution_count": 19,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "list"
-      ]
-     },
-     "execution_count": 19,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
+   "outputs": [],
    "source": [
-    "l = '[0.4, 0.2, 0.3, 0.1]'\n",
-    "type(eval(l))"
+    "o1.param_groups"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "with torch.no_grad():\n",
+    "    o1.param_groups[0][\"params\"][0].copy_(torch.zeros(1,))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "o1.param_groups"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "m1.state_dict()"
    ]
   },
   {
diff --git a/src/decentralizepy/datasets/Data.py b/src/decentralizepy/datasets/Data.py
index 2e8b2c2..4b3ac7b 100644
--- a/src/decentralizepy/datasets/Data.py
+++ b/src/decentralizepy/datasets/Data.py
@@ -16,7 +16,7 @@ class Data:
         self.x = x
         self.y = y
 
-    def __get_item__(self, i):
+    def __getitem__(self, i):
         """
         Function to get the item with index i.
         Parameters
diff --git a/src/decentralizepy/datasets/Dataset.py b/src/decentralizepy/datasets/Dataset.py
index b509ce0..4761450 100644
--- a/src/decentralizepy/datasets/Dataset.py
+++ b/src/decentralizepy/datasets/Dataset.py
@@ -1,9 +1,4 @@
-def __conditional_value__(var, nul, default):
-    if var != nul:
-        return var
-    else:
-        return default
-
+from decentralizepy import utils
 
 class Dataset:
     """
@@ -11,7 +6,7 @@ class Dataset:
     All datasets must follow this API.
     """
 
-    def __init__(self, rank='', n_procs='', train_dir='', test_dir='', sizes=''):
+    def __init__(self, rank="", n_procs="", train_dir="", test_dir="", sizes=""):
         """
         Constructor which reads the data files, instantiates and partitions the dataset
         Parameters
@@ -29,11 +24,11 @@ class Dataset:
             A list of fractions specifying how much data to alot each process. Sum of fractions should be 1.0
             By default, each process gets an equal amount.
         """
-        self.rank = __conditional_value__(rank, '', 0)
-        self.n_procs = __conditional_value__(n_procs, '', 1)
-        self.train_dir = __conditional_value__(train_dir, '', None)
-        self.test_dir = __conditional_value__(test_dir, '', None)
-        self.sizes = __conditional_value__(sizes, '', None)
+        self.rank = utils.conditional_value(rank, "", 0)
+        self.n_procs = utils.conditional_value(n_procs, "", 1)
+        self.train_dir = utils.conditional_value(train_dir, "", None)
+        self.test_dir = utils.conditional_value(test_dir, "", None)
+        self.sizes = utils.conditional_value(sizes, "", None)
         if self.sizes:
             if type(self.sizes) == str:
                 self.sizes = eval(self.sizes)
diff --git a/src/decentralizepy/datasets/Femnist.py b/src/decentralizepy/datasets/Femnist.py
index a01680f..41f3e1b 100644
--- a/src/decentralizepy/datasets/Femnist.py
+++ b/src/decentralizepy/datasets/Femnist.py
@@ -1,4 +1,5 @@
 import json
+import logging
 import os
 from collections import defaultdict
 
@@ -16,42 +17,41 @@ IMAGE_SIZE = (28, 28)
 FLAT_SIZE = 28 * 28
 
 
-def __read_dir__(data_dir):
-    """
-    Function to read all the FEMNIST data files in the directory
-    Parameters
-    ----------
-    data_dir : str
-        Path to the folder containing the data files
-    Returns
-    -------
-    3-tuple
-        A tuple containing list of clients, number of samples per client,
-        and the data items per client
-    """
-    clients = []
-    num_samples = []
-    data = defaultdict(lambda: None)
-
-    files = os.listdir(data_dir)
-    files = [f for f in files if f.endswith(".json")]
-    for f in files:
-        file_path = os.path.join(data_dir, f)
-        with open(file_path, "r") as inf:
-            client_data = json.load(inf)
-        clients.extend(client_data["users"])
-        num_samples.extend(client_data["num_samples"])
-        data.update(client_data["user_data"])
-
-    return clients, num_samples, data
-
-
 class Femnist(Dataset):
     """
     Class for the FEMNIST dataset
     """
 
-    def __init__(self, rank='', n_procs='', train_dir='', test_dir='', sizes=''):
+    def __read_dir__(self, data_dir):
+        """
+        Function to read all the FEMNIST data files in the directory
+        Parameters
+        ----------
+        data_dir : str
+            Path to the folder containing the data files
+        Returns
+        -------
+        3-tuple
+            A tuple containing list of clients, number of samples per client,
+            and the data items per client
+        """
+        clients = []
+        num_samples = []
+        data = defaultdict(lambda: None)
+
+        files = os.listdir(data_dir)
+        files = [f for f in files if f.endswith(".json")]
+        for f in files:
+            file_path = os.path.join(data_dir, f)
+            with open(file_path, "r") as inf:
+                client_data = json.load(inf)
+            clients.extend(client_data["users"])
+            num_samples.extend(client_data["num_samples"])
+            data.update(client_data["user_data"])
+
+        return clients, num_samples, data
+
+    def __init__(self, rank, n_procs="", train_dir="", test_dir="", sizes=""):
         """
         Constructor which reads the data files, instantiates and partitions the dataset
         Parameters
@@ -70,42 +70,56 @@ class Femnist(Dataset):
             By default, each process gets an equal amount.
         """
         super().__init__(rank, n_procs, train_dir, test_dir, sizes)
+
         if self.__training__:
-            clients, num_samples, train_data = __read_dir__(train_dir)
-            c_len = len(self.clients)
+            logging.info("Loading training set.")
+            clients, num_samples, train_data = self.__read_dir__(self.train_dir)
+            c_len = len(clients)
 
-            if sizes == None:  # Equal distribution of data among processes
+            if self.sizes == None:  # Equal distribution of data among processes
                 e = c_len // self.n_procs
                 frac = e / c_len
-                sizes = [frac] * self.n_procs
-                sizes[-1] += 1.0 - frac * self.n_procs
+                self.sizes = [frac] * self.n_procs
+                self.sizes[-1] += 1.0 - frac * self.n_procs
+                print(self.sizes)
 
-            my_clients = DataPartitioner(clients, sizes).use(self.rank)
-            my_train_data = []
+            my_clients = DataPartitioner(clients, self.sizes).use(self.rank)
+            my_train_data = {"x": [], "y": []}
             self.clients = []
             self.num_samples = []
+            logging.debug("Clients Length: %d", c_len)
+            logging.debug("My_clients_len: %d", my_clients.__len__())
             for i in range(my_clients.__len__()):
-                cur_client = my_clients.__get_item__(i)
+                cur_client = my_clients.__getitem__(i)
                 self.clients.append(cur_client)
-                my_train_data.extend(train_data[cur_client])
+                my_train_data["x"].extend(train_data[cur_client]["x"])
+                my_train_data["y"].extend(train_data[cur_client]["y"])
                 self.num_samples.append(len(train_data[cur_client]["y"]))
-
-            self.train_x = np.array(
-                my_train_data["x"], dtype=np.dtype("float64")
-            ).reshape(-1, 28, 28, 1)
+            self.train_x = (
+                np.array(my_train_data["x"], dtype=np.dtype("float64"))
+                .reshape(-1, 28, 28, 1)
+                .transpose(0, 3, 1, 2)
+            )
             self.train_y = np.array(
                 my_train_data["y"], dtype=np.dtype("float64")
             ).reshape(-1, 1)
+            logging.debug("train_x.shape: %s", str(self.train_x.shape))
+            logging.debug("train_y.shape: %s", str(self.train_y.shape))
 
         if self.__testing__:
-            _, _, test_data = __read_dir__(test_dir)
+            logging.info("Loading training set.")
+            _, _, test_data = self.__read_dir__(self.test_dir)
             test_data = test_data.values()
-            self.test_x = np.array(test_data["x"], dtype=np.dtype("float64")).reshape(
-                -1, 28, 28, 1
+            self.test_x = (
+                np.array(test_data["x"], dtype=np.dtype("float64"))
+                .reshape(-1, 28, 28, 1)
+                .transpose(0, 3, 1, 2)
             )
             self.test_y = np.array(test_data["y"], dtype=np.dtype("float64")).reshape(
                 -1, 1
             )
+            logging.debug("test_x.shape: %s", str(self.test_x.shape))
+            logging.debug("test_y.shape: %s", str(self.test_y.shape))
 
         # TODO: Add Validation
 
diff --git a/src/decentralizepy/node/Node.py b/src/decentralizepy/node/Node.py
index ed02982..7f203c0 100644
--- a/src/decentralizepy/node/Node.py
+++ b/src/decentralizepy/node/Node.py
@@ -1,16 +1,30 @@
-from torch import utils, optim
-import importlib
+import logging
+import os
+
+from decentralizepy.datasets.Dataset import Dataset
+from decentralizepy.graphs.Graph import Graph
+from decentralizepy.mappings.Mapping import Mapping
+from decentralizepy import utils
 
+from torch import optim
+import importlib
 
-def __remove_keys__(d, keys_to_remove):
-    return {key: d[key] for key in d if key not in keys_to_remove}
 
 class Node:
     """
     This class defines the node (entity that performs learning, sharing and communication).
     """
-
-    def __init__(self, rank, mapping, graph, config, *args):
+    def __init__(
+        self,
+        rank: int,
+        mapping: Mapping,
+        graph: Graph,
+        config,
+        iterations = 1,
+        log_dir=".",
+        log_level=logging.INFO,
+        *args
+    ):
         """
         Constructor
         Parameters
@@ -32,55 +46,60 @@ class Node:
             [OPTIMIZER_PARAMS]
                 optimizer_package
                 optimizer_class
+            [TRAIN_PARAMS]
+                training_package = decentralizepy.training.Training
+                training_class = Training
+                epochs_per_round = 25
+                batch_size = 64
+        log_dir : str
+            Logging directory
+        log_level : logging.Level
+            One of DEBUG, INFO, WARNING, ERROR, CRITICAL
         args : optional
             Other arguments
         """
+        log_file = os.path.join(log_dir, str(rank) + ".log")
+        print(log_file)
+        logging.basicConfig(
+            filename=log_file,
+            format="[%(asctime)s][%(module)s][%(levelname)s] %(message)s",
+            level=log_level,
+            force=True,
+        )
+
+        logging.info("Started process.")
+
         self.rank = rank
         self.graph = graph
         self.mapping = mapping
+
+        logging.debug("Rank: %d", self.rank)
+        logging.debug("type(graph): %s", str(type(self.rank)))
+        logging.debug("type(mapping): %s", str(type(self.mapping)))
         
         dataset_configs = dict(config.items("DATASET"))
         dataset_module = importlib.import_module(dataset_configs["dataset_package"])
         dataset_class = getattr(dataset_module, dataset_configs["dataset_class"])
-        
-        dataset_params = __remove_keys__(dataset_configs, ["dataset_package", "dataset_class", "model_class"])
+        dataset_params = utils.remove_keys(dataset_configs, ["dataset_package", "dataset_class", "model_class"])
         self.dataset =  dataset_class(rank, **dataset_params)
         self.trainset = self.dataset.get_trainset()
 
+        logging.info("Dataset instantiation complete.")
+
         model_class = getattr(dataset_module, dataset_configs["model_class"])
         self.model = model_class()
 
         optimizer_configs = dict(config.items("OPTIMIZER_PARAMS"))
         optimizer_module = importlib.import_module(optimizer_configs["optimizer_package"])
         optimizer_class = getattr(optimizer_module, optimizer_configs["optimizer_class"])
-        
-        optimizer_params = __remove_keys__(optimizer_configs, ["optimizer_package", "optimizer_class"])
+        optimizer_params = utils.remove_keys(optimizer_configs, ["optimizer_package", "optimizer_class"])
         self.optimizer = optimizer_class(self.model.parameters(), **optimizer_params)
 
-        self.run()
-
-    def train_step(self):
-        """
-        The training step
-        """
-        for epoch in self.epochs_per_round: # Instantiate this variable
-            for data, target in self.trainset: 
-                # Perform training step
-                raise NotImplementedError
-
-
-
-
-    def run(self):
-        """
-        The learning loop.
-        """
-        while True:
-            # Receive data
-
-            # Learn
-
-
-            # Send data
-            raise NotImplementedError
+        train_configs = dict(config.items("TRAIN_PARAMS"))
+        train_module = importlib.import_module(train_configs["training_package"])
+        train_class = getattr(train_module, train_configs["training_class"])
+        train_params = utils.remove_keys(train_configs, ["training_package", "training_class"])
+        self.trainer = train_class(self.model, self.optimizer, **train_params)
 
+        for iteration in range(iterations):
+            self.trainer.train(self.trainset)
\ No newline at end of file
diff --git a/src/decentralizepy/training/Training.py b/src/decentralizepy/training/Training.py
new file mode 100644
index 0000000..3662366
--- /dev/null
+++ b/src/decentralizepy/training/Training.py
@@ -0,0 +1,29 @@
+from torch.optim import SGD
+
+from decentralizepy import utils
+class Training:
+    """
+    This class implements the training module for a single node.
+    """
+    def __init__(self, model, optimizer, epochs_per_round = "", batch_size = ""):
+        """
+        Constructor
+        Parameters
+        ----------
+        epochs_per_round : int
+            Number of epochs per training call
+        batch_size : int
+            Number of items to learn over, in one batch
+        """
+        self.epochs_per_round = utils.conditional_value(epochs_per_round, "", 1)
+        self.batch_size = utils.conditional_value(batch_size, "", 1)
+
+    def train(self, trainset):
+        """
+        One training iteration
+        Parameters
+        ----------
+        trainset : decentralizepy.datasets.Data
+            The training dataset. Should implement __getitem__(i)
+        """
+        raise NotImplementedError
diff --git a/src/decentralizepy/training/__init__.py b/src/decentralizepy/training/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/decentralizepy/utils.py b/src/decentralizepy/utils.py
new file mode 100644
index 0000000..eee6db0
--- /dev/null
+++ b/src/decentralizepy/utils.py
@@ -0,0 +1,8 @@
+def conditional_value(var, nul, default):
+    if var != nul:
+        return var
+    else:
+        return default
+
+def remove_keys(d, keys_to_remove):
+        return {key: d[key] for key in d if key not in keys_to_remove}
\ No newline at end of file
-- 
GitLab