From b0cfe04eea2afca02dccbe9af523672f6853c605 Mon Sep 17 00:00:00 2001
From: Rishi Sharma <rishi.sharma@epfl.ch>
Date: Thu, 27 Jan 2022 13:24:11 +0100
Subject: [PATCH] Improve Node

---
 README.rst                                    |    2 +-
 eval/main.ipynb                               | 2009 ++++++++++++++++-
 src/decentralizepy/communication/TCP.py       |    1 +
 src/decentralizepy/datasets/Celeba.py         |    4 +
 src/decentralizepy/datasets/Femnist.py        |    1 +
 src/decentralizepy/models/Model.py            |    1 +
 src/decentralizepy/node/Node.py               |  161 +-
 src/decentralizepy/sharing/GrowingAlpha.py    |    1 +
 src/decentralizepy/sharing/PartialModel.py    |    1 +
 .../training/GradientAccumulator.py           |    1 +
 src/decentralizepy/training/Training.py       |    9 +-
 src/decentralizepy/utils.py                   |    2 +-
 12 files changed, 2142 insertions(+), 51 deletions(-)

diff --git a/README.rst b/README.rst
index d1ef408..38cba02 100644
--- a/README.rst
+++ b/README.rst
@@ -15,7 +15,7 @@ Setting up decentralizepy
     pip3 install --upgrade pip
     pip install --upgrade pip
 
-* Install decentralizepy for development/ ::
+* Install decentralizepy for development. ::
 
     pip3 install --editable .\[dev\]
     
diff --git a/eval/main.ipynb b/eval/main.ipynb
index e25d47c..db66e89 100644
--- a/eval/main.ipynb
+++ b/eval/main.ipynb
@@ -268,33 +268,1971 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 24,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "CNN(\n",
+      "  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)\n",
+      "  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
+      "  (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)\n",
+      "  (conv3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)\n",
+      "  (conv4): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)\n",
+      "  (fc1): Linear(in_features=800, out_features=2, bias=True)\n",
+      ")\n"
+     ]
+    }
+   ],
    "source": [
     "from torch import multiprocessing as mp\n",
     "import torch\n",
-    "m1 = torch.nn.Linear(1,1)\n",
+    "from decentralizepy.datasets.Celeba import CNN\n",
+    "import numpy as np\n",
+    "m1 = CNN()\n",
     "o1 = torch.optim.SGD(m1.parameters(), 0.6)\n",
     "print(m1)\n",
     "\n",
     "\n",
-    "mp.spawn(fn = f, nprocs = 2, args=[m1, o1])\n",
+    "#mp.spawn(fn = f, nprocs = 2, args=[m1, o1])\n",
     "\n"
    ]
   },
   {
-   "cell_type": "markdown",
+   "cell_type": "code",
+   "execution_count": 25,
    "metadata": {},
-   "source": []
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "OrderedDict([('conv1.weight',\n",
+       "              tensor([[[[-0.1261,  0.1833, -0.1406],\n",
+       "                        [ 0.1324, -0.0685,  0.0938],\n",
+       "                        [ 0.0432,  0.1814, -0.0541]],\n",
+       "              \n",
+       "                       [[-0.1776, -0.1839, -0.0111],\n",
+       "                        [ 0.0888,  0.0888, -0.1344],\n",
+       "                        [-0.1838,  0.1737,  0.1584]],\n",
+       "              \n",
+       "                       [[ 0.0417,  0.1064, -0.0156],\n",
+       "                        [ 0.0667,  0.0856, -0.1746],\n",
+       "                        [ 0.0412,  0.1620,  0.0125]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[-0.0530, -0.1273, -0.0797],\n",
+       "                        [ 0.0422,  0.1135,  0.0475],\n",
+       "                        [-0.0244,  0.1691, -0.1383]],\n",
+       "              \n",
+       "                       [[ 0.0822, -0.1317, -0.1692],\n",
+       "                        [ 0.1373,  0.1388,  0.0103],\n",
+       "                        [-0.0481,  0.1105,  0.0631]],\n",
+       "              \n",
+       "                       [[-0.0352,  0.1259, -0.0530],\n",
+       "                        [-0.1394, -0.0281,  0.1844],\n",
+       "                        [ 0.0082,  0.1187,  0.0211]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[ 0.0987,  0.0788, -0.1126],\n",
+       "                        [ 0.1769,  0.0763, -0.1767],\n",
+       "                        [-0.0570,  0.1156,  0.1770]],\n",
+       "              \n",
+       "                       [[ 0.0643, -0.0024, -0.0625],\n",
+       "                        [ 0.0819,  0.0140, -0.1882],\n",
+       "                        [ 0.1325, -0.0632, -0.0202]],\n",
+       "              \n",
+       "                       [[ 0.0053,  0.1042, -0.0058],\n",
+       "                        [-0.1082, -0.1753,  0.1762],\n",
+       "                        [-0.0501,  0.1166,  0.0561]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[ 0.0358, -0.0685, -0.1278],\n",
+       "                        [ 0.0029, -0.1107,  0.1169],\n",
+       "                        [-0.1408,  0.1293,  0.1142]],\n",
+       "              \n",
+       "                       [[-0.0814,  0.0470,  0.0188],\n",
+       "                        [ 0.1538,  0.0137,  0.1128],\n",
+       "                        [-0.1597,  0.1432,  0.1370]],\n",
+       "              \n",
+       "                       [[ 0.1425,  0.1769, -0.0037],\n",
+       "                        [-0.1080, -0.0805, -0.0195],\n",
+       "                        [-0.1335, -0.1666,  0.1399]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[ 0.1117,  0.1918, -0.1666],\n",
+       "                        [-0.1392,  0.0086,  0.0172],\n",
+       "                        [-0.0721, -0.1711,  0.0344]],\n",
+       "              \n",
+       "                       [[ 0.1820, -0.0537, -0.0974],\n",
+       "                        [ 0.0366, -0.0710,  0.1273],\n",
+       "                        [ 0.1132, -0.1594,  0.0878]],\n",
+       "              \n",
+       "                       [[-0.0874, -0.0401,  0.1827],\n",
+       "                        [-0.0301,  0.1205, -0.0396],\n",
+       "                        [-0.1143, -0.1007,  0.1561]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[ 0.1522, -0.0012, -0.1785],\n",
+       "                        [-0.1833, -0.1828, -0.1643],\n",
+       "                        [-0.1765, -0.1757, -0.0608]],\n",
+       "              \n",
+       "                       [[-0.0684,  0.0521,  0.1137],\n",
+       "                        [-0.0028,  0.0616,  0.0758],\n",
+       "                        [-0.1736,  0.0667,  0.1229]],\n",
+       "              \n",
+       "                       [[ 0.1298, -0.1848, -0.1570],\n",
+       "                        [-0.1052, -0.1172, -0.1223],\n",
+       "                        [-0.1389, -0.0095, -0.0410]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[ 0.0213, -0.0975,  0.0964],\n",
+       "                        [ 0.0535, -0.0775,  0.0790],\n",
+       "                        [-0.1796, -0.1468,  0.1036]],\n",
+       "              \n",
+       "                       [[-0.0403,  0.0646, -0.0932],\n",
+       "                        [ 0.1779, -0.1616,  0.0644],\n",
+       "                        [-0.0508, -0.1158, -0.0592]],\n",
+       "              \n",
+       "                       [[-0.1644, -0.1327,  0.0817],\n",
+       "                        [ 0.0320, -0.0213, -0.0946],\n",
+       "                        [-0.1106,  0.1463, -0.1642]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[-0.0985, -0.1160, -0.0915],\n",
+       "                        [ 0.1857,  0.0806,  0.1761],\n",
+       "                        [-0.0817,  0.1095,  0.0896]],\n",
+       "              \n",
+       "                       [[-0.0660, -0.1680,  0.1833],\n",
+       "                        [ 0.0611,  0.0077, -0.0848],\n",
+       "                        [-0.1516,  0.1737,  0.0484]],\n",
+       "              \n",
+       "                       [[ 0.1434, -0.0732, -0.0904],\n",
+       "                        [ 0.0962,  0.1783,  0.0192],\n",
+       "                        [ 0.0915,  0.0006,  0.0334]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[-0.0047,  0.1807, -0.1798],\n",
+       "                        [-0.0164,  0.1119, -0.0805],\n",
+       "                        [ 0.1855, -0.0681, -0.0187]],\n",
+       "              \n",
+       "                       [[-0.0069,  0.0491, -0.1868],\n",
+       "                        [-0.1609, -0.0316,  0.0150],\n",
+       "                        [-0.1605,  0.1506, -0.0074]],\n",
+       "              \n",
+       "                       [[ 0.0851, -0.1732, -0.1777],\n",
+       "                        [ 0.0539, -0.0500, -0.1231],\n",
+       "                        [ 0.1654,  0.0342, -0.1904]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[ 0.0476,  0.0284,  0.1212],\n",
+       "                        [-0.1603, -0.1924,  0.0144],\n",
+       "                        [ 0.0076, -0.0928, -0.1645]],\n",
+       "              \n",
+       "                       [[ 0.0215,  0.1845, -0.1034],\n",
+       "                        [ 0.1574, -0.1577, -0.0438],\n",
+       "                        [-0.1360, -0.0601, -0.1693]],\n",
+       "              \n",
+       "                       [[-0.0720,  0.0619,  0.1405],\n",
+       "                        [ 0.0699, -0.1288,  0.0041],\n",
+       "                        [-0.0381, -0.1697, -0.1568]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[-0.1599,  0.1231, -0.1034],\n",
+       "                        [-0.0314,  0.0105, -0.1449],\n",
+       "                        [-0.0172, -0.0781,  0.0839]],\n",
+       "              \n",
+       "                       [[-0.0676,  0.1185, -0.1559],\n",
+       "                        [-0.1053, -0.1306,  0.1820],\n",
+       "                        [ 0.1584, -0.1370,  0.1828]],\n",
+       "              \n",
+       "                       [[ 0.0658,  0.1412, -0.0537],\n",
+       "                        [-0.1230, -0.1411, -0.0011],\n",
+       "                        [-0.1318, -0.0458,  0.1838]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[-0.0268,  0.1747, -0.1037],\n",
+       "                        [ 0.0515, -0.0228, -0.1024],\n",
+       "                        [-0.1543, -0.0643, -0.0100]],\n",
+       "              \n",
+       "                       [[-0.1572, -0.1530,  0.0026],\n",
+       "                        [ 0.1463, -0.1233,  0.0470],\n",
+       "                        [-0.1595, -0.1108, -0.0654]],\n",
+       "              \n",
+       "                       [[-0.0521, -0.0094,  0.1544],\n",
+       "                        [-0.0505, -0.0332,  0.0048],\n",
+       "                        [ 0.0735,  0.1350,  0.0690]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[ 0.0025,  0.0724,  0.0930],\n",
+       "                        [-0.1885,  0.0475,  0.1100],\n",
+       "                        [-0.1622,  0.0087, -0.0030]],\n",
+       "              \n",
+       "                       [[ 0.1032, -0.1425, -0.0620],\n",
+       "                        [ 0.1515, -0.0736, -0.1888],\n",
+       "                        [-0.1246,  0.1424, -0.0491]],\n",
+       "              \n",
+       "                       [[ 0.1759, -0.1616,  0.1198],\n",
+       "                        [-0.1103,  0.1032,  0.1727],\n",
+       "                        [-0.0601,  0.1635,  0.0034]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[ 0.0301,  0.1517,  0.0657],\n",
+       "                        [-0.1368, -0.1165,  0.1193],\n",
+       "                        [-0.0962,  0.1451,  0.1099]],\n",
+       "              \n",
+       "                       [[ 0.1646, -0.1860, -0.1187],\n",
+       "                        [-0.1367, -0.0911,  0.1337],\n",
+       "                        [-0.0926, -0.0524, -0.0672]],\n",
+       "              \n",
+       "                       [[-0.1509, -0.1231, -0.0855],\n",
+       "                        [ 0.1808, -0.0713,  0.0410],\n",
+       "                        [-0.0621, -0.0506,  0.1871]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[ 0.0888, -0.0874, -0.0826],\n",
+       "                        [ 0.0416, -0.0961,  0.0603],\n",
+       "                        [ 0.1455,  0.0050,  0.0318]],\n",
+       "              \n",
+       "                       [[-0.1633,  0.0070, -0.1537],\n",
+       "                        [-0.0109,  0.1602, -0.0463],\n",
+       "                        [-0.0423, -0.0147, -0.1045]],\n",
+       "              \n",
+       "                       [[ 0.1640, -0.0997, -0.1662],\n",
+       "                        [-0.1074,  0.1549, -0.1905],\n",
+       "                        [-0.1708,  0.1624,  0.0219]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[ 0.0824, -0.1376,  0.1086],\n",
+       "                        [ 0.0836,  0.0135,  0.0351],\n",
+       "                        [-0.1518,  0.0784, -0.1708]],\n",
+       "              \n",
+       "                       [[-0.1636, -0.1571,  0.1032],\n",
+       "                        [-0.1152,  0.0274, -0.1022],\n",
+       "                        [-0.0956, -0.1606, -0.1615]],\n",
+       "              \n",
+       "                       [[ 0.1307,  0.0419,  0.1924],\n",
+       "                        [-0.0599, -0.1296, -0.0448],\n",
+       "                        [ 0.0363,  0.0377, -0.0460]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[-0.1685, -0.1277, -0.0465],\n",
+       "                        [ 0.0922, -0.1011,  0.0742],\n",
+       "                        [ 0.0053, -0.1456,  0.0135]],\n",
+       "              \n",
+       "                       [[ 0.1341,  0.0131,  0.1281],\n",
+       "                        [-0.1020,  0.1069, -0.0631],\n",
+       "                        [-0.0439, -0.1189, -0.1822]],\n",
+       "              \n",
+       "                       [[ 0.1624, -0.1253,  0.0302],\n",
+       "                        [ 0.0709,  0.0767,  0.1453],\n",
+       "                        [ 0.0203,  0.1603, -0.1720]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[-0.1550,  0.1513, -0.1003],\n",
+       "                        [ 0.0370,  0.0367, -0.0233],\n",
+       "                        [ 0.0916, -0.0871,  0.1579]],\n",
+       "              \n",
+       "                       [[-0.1900,  0.0314,  0.0865],\n",
+       "                        [-0.0197,  0.0296, -0.0048],\n",
+       "                        [ 0.0846,  0.1543, -0.0770]],\n",
+       "              \n",
+       "                       [[-0.0016, -0.0978,  0.1826],\n",
+       "                        [-0.0477,  0.0689,  0.1079],\n",
+       "                        [ 0.0400,  0.0880,  0.1674]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[ 0.0145, -0.0447, -0.1742],\n",
+       "                        [ 0.0394,  0.0127, -0.1172],\n",
+       "                        [ 0.1330, -0.1207,  0.0326]],\n",
+       "              \n",
+       "                       [[-0.0155, -0.1602,  0.0023],\n",
+       "                        [ 0.0789,  0.1648,  0.1781],\n",
+       "                        [-0.1468, -0.0481, -0.1260]],\n",
+       "              \n",
+       "                       [[-0.0139,  0.0848, -0.0536],\n",
+       "                        [-0.1581,  0.1130,  0.0717],\n",
+       "                        [ 0.0275, -0.0006, -0.0049]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[-0.0199,  0.0032, -0.1246],\n",
+       "                        [ 0.0479,  0.1418, -0.1295],\n",
+       "                        [-0.1646, -0.1139, -0.1018]],\n",
+       "              \n",
+       "                       [[ 0.1475,  0.1413, -0.0354],\n",
+       "                        [ 0.0612, -0.1652,  0.0801],\n",
+       "                        [-0.1306, -0.0165,  0.1733]],\n",
+       "              \n",
+       "                       [[ 0.1527,  0.0911, -0.1906],\n",
+       "                        [-0.1152,  0.1737,  0.0436],\n",
+       "                        [-0.0213, -0.0314, -0.0319]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[-0.0003, -0.0546, -0.1255],\n",
+       "                        [ 0.0914, -0.1414,  0.0542],\n",
+       "                        [ 0.1139,  0.0132,  0.0815]],\n",
+       "              \n",
+       "                       [[-0.0042,  0.0541,  0.1456],\n",
+       "                        [ 0.0509, -0.0790,  0.0272],\n",
+       "                        [ 0.1419,  0.0992, -0.1448]],\n",
+       "              \n",
+       "                       [[ 0.0496,  0.0013,  0.0838],\n",
+       "                        [-0.0662,  0.0315, -0.1168],\n",
+       "                        [-0.0069, -0.1503,  0.0729]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[ 0.1866,  0.1329, -0.0560],\n",
+       "                        [ 0.0026,  0.1533,  0.0326],\n",
+       "                        [-0.1161, -0.0323,  0.0053]],\n",
+       "              \n",
+       "                       [[-0.0243, -0.1823, -0.1657],\n",
+       "                        [-0.0107, -0.0832,  0.0029],\n",
+       "                        [ 0.0981,  0.1241, -0.1788]],\n",
+       "              \n",
+       "                       [[-0.0400, -0.0577, -0.0757],\n",
+       "                        [-0.0584,  0.0176, -0.1019],\n",
+       "                        [-0.1828,  0.1589, -0.0312]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[-0.1083, -0.1236, -0.0904],\n",
+       "                        [-0.1575,  0.0157,  0.0552],\n",
+       "                        [-0.0839,  0.1704, -0.1457]],\n",
+       "              \n",
+       "                       [[-0.1648, -0.0270, -0.0489],\n",
+       "                        [-0.1122, -0.0288, -0.0073],\n",
+       "                        [-0.1443, -0.1712,  0.0100]],\n",
+       "              \n",
+       "                       [[-0.1142, -0.1552,  0.1568],\n",
+       "                        [ 0.0743, -0.1108, -0.0643],\n",
+       "                        [-0.0394, -0.1345,  0.0992]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[-0.1591,  0.0942, -0.1035],\n",
+       "                        [-0.0781,  0.0725, -0.0888],\n",
+       "                        [ 0.0959,  0.0213,  0.1222]],\n",
+       "              \n",
+       "                       [[ 0.1202, -0.0217, -0.0955],\n",
+       "                        [-0.1748, -0.1133, -0.0704],\n",
+       "                        [-0.0670, -0.1401,  0.1553]],\n",
+       "              \n",
+       "                       [[ 0.0053, -0.0871, -0.0239],\n",
+       "                        [ 0.0961, -0.0547,  0.1741],\n",
+       "                        [-0.0570,  0.0477,  0.1853]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[-0.1115, -0.0183, -0.1302],\n",
+       "                        [ 0.1435, -0.0238, -0.0048],\n",
+       "                        [ 0.1862, -0.1837,  0.1711]],\n",
+       "              \n",
+       "                       [[ 0.1375, -0.1798,  0.0818],\n",
+       "                        [-0.0792,  0.0820,  0.1373],\n",
+       "                        [ 0.1849,  0.0672, -0.1822]],\n",
+       "              \n",
+       "                       [[ 0.1868, -0.0356,  0.0726],\n",
+       "                        [-0.1523, -0.1130,  0.1506],\n",
+       "                        [-0.1046,  0.0178,  0.0990]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[ 0.1321, -0.1641,  0.0411],\n",
+       "                        [ 0.0526,  0.0393,  0.0918],\n",
+       "                        [-0.1345,  0.0750,  0.0859]],\n",
+       "              \n",
+       "                       [[-0.0985,  0.1466,  0.1349],\n",
+       "                        [-0.1461, -0.1742,  0.0941],\n",
+       "                        [-0.1502, -0.1813,  0.0864]],\n",
+       "              \n",
+       "                       [[-0.1039,  0.1179,  0.1499],\n",
+       "                        [-0.0366, -0.0120,  0.0951],\n",
+       "                        [ 0.0087,  0.1212, -0.0183]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[-0.1375,  0.0765, -0.0072],\n",
+       "                        [-0.0041,  0.0379, -0.0243],\n",
+       "                        [-0.1495,  0.1601,  0.1575]],\n",
+       "              \n",
+       "                       [[-0.0454,  0.1642,  0.0720],\n",
+       "                        [-0.0533,  0.0150,  0.0039],\n",
+       "                        [ 0.0194,  0.0113, -0.1194]],\n",
+       "              \n",
+       "                       [[ 0.0527, -0.0886,  0.0359],\n",
+       "                        [ 0.1595,  0.0526, -0.0048],\n",
+       "                        [-0.1790, -0.0458, -0.0324]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[-0.1638,  0.0942,  0.0686],\n",
+       "                        [-0.1082, -0.0675,  0.1892],\n",
+       "                        [-0.1347, -0.1247,  0.0739]],\n",
+       "              \n",
+       "                       [[ 0.0595,  0.1504, -0.1657],\n",
+       "                        [ 0.0733,  0.0529, -0.1599],\n",
+       "                        [ 0.0171, -0.1127, -0.0259]],\n",
+       "              \n",
+       "                       [[-0.0092,  0.0193,  0.1176],\n",
+       "                        [-0.1183,  0.0101,  0.1011],\n",
+       "                        [ 0.0648, -0.1897,  0.0782]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[ 0.0098, -0.1161, -0.0802],\n",
+       "                        [-0.1821,  0.0221, -0.1754],\n",
+       "                        [-0.1218,  0.0525, -0.0480]],\n",
+       "              \n",
+       "                       [[ 0.0770,  0.0477,  0.1514],\n",
+       "                        [ 0.0374, -0.1075, -0.1026],\n",
+       "                        [-0.0581, -0.1011,  0.1241]],\n",
+       "              \n",
+       "                       [[-0.0567, -0.0163,  0.0374],\n",
+       "                        [-0.1739, -0.0579,  0.0704],\n",
+       "                        [ 0.1817,  0.1561,  0.1677]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[-0.0569, -0.0763,  0.0044],\n",
+       "                        [-0.1133,  0.0813,  0.1477],\n",
+       "                        [ 0.0836,  0.0483, -0.1800]],\n",
+       "              \n",
+       "                       [[ 0.1343, -0.1590,  0.1177],\n",
+       "                        [ 0.1071, -0.1647, -0.0646],\n",
+       "                        [ 0.1578, -0.1261,  0.0243]],\n",
+       "              \n",
+       "                       [[-0.0424, -0.0241, -0.0988],\n",
+       "                        [ 0.0023,  0.0029, -0.0291],\n",
+       "                        [ 0.0415, -0.0557,  0.1427]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[-0.1028,  0.1054,  0.1658],\n",
+       "                        [-0.0357,  0.1579,  0.1237],\n",
+       "                        [ 0.0368,  0.0532, -0.1043]],\n",
+       "              \n",
+       "                       [[-0.0369,  0.0575, -0.1023],\n",
+       "                        [ 0.0635,  0.1015,  0.1112],\n",
+       "                        [-0.1235,  0.0467,  0.0908]],\n",
+       "              \n",
+       "                       [[ 0.1380,  0.0633,  0.1087],\n",
+       "                        [-0.1360,  0.0422, -0.1524],\n",
+       "                        [ 0.0819,  0.0918, -0.1624]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[ 0.1584, -0.0218, -0.0236],\n",
+       "                        [ 0.1878, -0.1289,  0.1343],\n",
+       "                        [ 0.0351,  0.1225, -0.1460]],\n",
+       "              \n",
+       "                       [[ 0.0690, -0.1439,  0.0056],\n",
+       "                        [ 0.0272, -0.0058,  0.0125],\n",
+       "                        [ 0.0868, -0.0684, -0.0884]],\n",
+       "              \n",
+       "                       [[ 0.1045,  0.0583, -0.0870],\n",
+       "                        [ 0.0600, -0.0732, -0.1695],\n",
+       "                        [ 0.0953,  0.0246,  0.1245]]]])),\n",
+       "             ('conv1.bias',\n",
+       "              tensor([-1.8698e-01, -7.9379e-06, -1.9277e-02,  5.2182e-02,  7.5716e-02,\n",
+       "                      -3.3830e-03, -9.6565e-02,  1.0241e-01, -8.2457e-02, -1.6224e-01,\n",
+       "                       1.2980e-01, -8.2256e-02, -7.4655e-02, -3.7980e-02,  8.3407e-02,\n",
+       "                      -1.4880e-01,  4.8939e-02,  2.7506e-02,  5.8676e-03, -1.5813e-01,\n",
+       "                      -6.2464e-04,  1.0359e-02, -1.5525e-01,  7.9100e-02,  1.6850e-02,\n",
+       "                      -1.3809e-01, -6.3393e-02, -5.3843e-02, -1.5219e-02, -1.7365e-01,\n",
+       "                       1.7249e-01, -1.1165e-01])),\n",
+       "             ('conv2.weight',\n",
+       "              tensor([[[[ 4.3750e-02,  4.5533e-02, -2.9410e-02],\n",
+       "                        [-4.1395e-02,  5.0397e-04, -1.3265e-02],\n",
+       "                        [-4.9851e-02, -1.0518e-02,  5.7710e-02]],\n",
+       "              \n",
+       "                       [[-5.6332e-02, -4.7168e-03, -4.4627e-02],\n",
+       "                        [ 5.3513e-03, -4.0824e-02,  1.8281e-02],\n",
+       "                        [ 5.0677e-02, -1.5295e-02, -6.1751e-03]],\n",
+       "              \n",
+       "                       [[-2.4984e-02,  1.2784e-02, -4.7123e-02],\n",
+       "                        [-4.3238e-02,  4.7349e-02, -1.5219e-02],\n",
+       "                        [-3.6073e-02,  4.1506e-02, -3.5337e-02]],\n",
+       "              \n",
+       "                       ...,\n",
+       "              \n",
+       "                       [[ 5.6048e-02,  1.9129e-03, -1.7200e-02],\n",
+       "                        [ 5.8869e-02, -5.1520e-02, -5.3205e-02],\n",
+       "                        [-1.3903e-02,  5.1790e-02,  2.2585e-02]],\n",
+       "              \n",
+       "                       [[ 1.1835e-02, -4.9313e-02, -3.1838e-02],\n",
+       "                        [ 7.6813e-03,  4.2715e-02, -5.7404e-02],\n",
+       "                        [-4.1474e-02, -2.3128e-02, -4.7935e-02]],\n",
+       "              \n",
+       "                       [[-2.1860e-02, -2.1817e-02, -3.2578e-02],\n",
+       "                        [ 3.1317e-02,  3.3435e-02,  3.1837e-02],\n",
+       "                        [-2.2399e-03,  3.1600e-02,  4.0183e-02]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[ 1.9610e-04,  5.3780e-02, -4.5810e-02],\n",
+       "                        [ 4.0340e-02,  1.4904e-02, -1.5597e-02],\n",
+       "                        [-4.6080e-02,  5.0714e-02, -5.7445e-03]],\n",
+       "              \n",
+       "                       [[-3.5281e-02,  3.3011e-02,  4.3343e-02],\n",
+       "                        [-4.6263e-02, -5.6184e-02,  5.1245e-03],\n",
+       "                        [ 3.6015e-02, -3.3152e-02,  4.6629e-03]],\n",
+       "              \n",
+       "                       [[ 1.7650e-03, -4.2336e-02,  4.3744e-02],\n",
+       "                        [ 2.1655e-02,  5.3759e-02,  1.3719e-03],\n",
+       "                        [ 4.2005e-02,  5.3998e-02,  1.9009e-02]],\n",
+       "              \n",
+       "                       ...,\n",
+       "              \n",
+       "                       [[ 8.9786e-03, -1.8645e-02, -1.3587e-02],\n",
+       "                        [-5.4167e-02,  1.2335e-02, -3.0384e-02],\n",
+       "                        [-4.8722e-03, -3.7296e-02, -2.6446e-02]],\n",
+       "              \n",
+       "                       [[ 1.7580e-02,  3.8462e-02, -5.0269e-02],\n",
+       "                        [ 2.6601e-03, -1.1462e-02,  4.7459e-02],\n",
+       "                        [-2.8888e-02,  3.4436e-02, -4.9943e-02]],\n",
+       "              \n",
+       "                       [[-5.0206e-02, -5.6025e-02, -3.6346e-02],\n",
+       "                        [-2.4407e-02,  5.3721e-02, -5.4920e-02],\n",
+       "                        [ 5.1835e-02, -3.2396e-02,  3.2373e-02]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[-2.7759e-02, -5.4737e-02, -1.1689e-02],\n",
+       "                        [ 3.9462e-02,  2.8649e-02,  5.1776e-02],\n",
+       "                        [ 2.4253e-02, -2.8318e-02,  2.7402e-02]],\n",
+       "              \n",
+       "                       [[ 1.3045e-02, -1.0456e-02,  2.0426e-02],\n",
+       "                        [ 2.1949e-02,  4.6817e-02, -5.6093e-02],\n",
+       "                        [ 2.7145e-02, -5.5441e-02, -2.0719e-02]],\n",
+       "              \n",
+       "                       [[ 4.4704e-02, -2.4099e-02, -4.7185e-02],\n",
+       "                        [-4.3257e-02, -3.3058e-02, -8.6451e-03],\n",
+       "                        [-3.7283e-02, -3.4569e-02, -7.1049e-03]],\n",
+       "              \n",
+       "                       ...,\n",
+       "              \n",
+       "                       [[-4.1559e-02, -2.9240e-02,  2.7197e-03],\n",
+       "                        [ 2.0770e-02,  5.4479e-02, -4.4845e-02],\n",
+       "                        [-1.1641e-02, -2.9814e-02, -2.4419e-02]],\n",
+       "              \n",
+       "                       [[-1.5743e-02,  1.0854e-02,  3.0878e-02],\n",
+       "                        [ 2.2739e-02,  3.2999e-02, -1.1902e-02],\n",
+       "                        [-3.4837e-02,  1.5305e-02, -8.7552e-03]],\n",
+       "              \n",
+       "                       [[-2.2882e-02,  9.4639e-03,  5.1878e-03],\n",
+       "                        [-2.6344e-02,  2.9063e-02, -1.9337e-02],\n",
+       "                        [-3.4314e-02,  1.5313e-02,  4.1524e-02]]],\n",
+       "              \n",
+       "              \n",
+       "                      ...,\n",
+       "              \n",
+       "              \n",
+       "                      [[[ 9.7454e-03, -3.2903e-03,  1.0696e-02],\n",
+       "                        [-4.0918e-02,  1.6352e-02,  1.4646e-02],\n",
+       "                        [ 1.2516e-02, -2.1804e-02, -2.5489e-02]],\n",
+       "              \n",
+       "                       [[-1.6083e-02,  2.5374e-02,  3.1458e-02],\n",
+       "                        [-3.1497e-02, -1.9513e-02, -2.1223e-02],\n",
+       "                        [ 6.6286e-03,  1.6538e-02, -4.8944e-02]],\n",
+       "              \n",
+       "                       [[ 2.4808e-02, -2.9520e-02, -4.8227e-02],\n",
+       "                        [ 1.7325e-03, -4.7443e-02,  2.3087e-03],\n",
+       "                        [-1.0008e-02, -2.0313e-02,  2.9944e-02]],\n",
+       "              \n",
+       "                       ...,\n",
+       "              \n",
+       "                       [[ 5.5781e-02, -2.0898e-02, -3.1487e-03],\n",
+       "                        [-1.6931e-02,  4.5279e-04, -1.5024e-02],\n",
+       "                        [-5.5885e-02,  2.7140e-02, -8.5434e-03]],\n",
+       "              \n",
+       "                       [[ 1.3970e-02, -3.3131e-02,  4.3112e-02],\n",
+       "                        [-3.4956e-02, -5.0144e-02, -1.6391e-02],\n",
+       "                        [-9.1003e-03, -2.0204e-02, -1.0226e-03]],\n",
+       "              \n",
+       "                       [[-4.0053e-02, -5.0194e-02,  5.0405e-02],\n",
+       "                        [ 5.4107e-02,  4.2185e-02,  3.4359e-02],\n",
+       "                        [ 1.6749e-02, -1.4102e-02,  5.0171e-02]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[ 2.8229e-02,  5.6586e-02, -3.9617e-03],\n",
+       "                        [ 2.9538e-02, -1.2507e-02, -2.5516e-02],\n",
+       "                        [-1.5193e-02, -2.9232e-02, -2.0701e-02]],\n",
+       "              \n",
+       "                       [[-5.8773e-02,  3.3015e-02, -9.4146e-03],\n",
+       "                        [ 2.8957e-02,  5.8666e-02,  2.8679e-02],\n",
+       "                        [ 1.5249e-02, -1.2246e-03,  1.2230e-03]],\n",
+       "              \n",
+       "                       [[ 2.6050e-02, -4.6042e-02, -3.4895e-03],\n",
+       "                        [ 4.9529e-02,  6.6835e-03, -4.1808e-02],\n",
+       "                        [-8.6450e-03, -4.8510e-02, -2.4011e-02]],\n",
+       "              \n",
+       "                       ...,\n",
+       "              \n",
+       "                       [[-5.1427e-02,  2.4528e-02, -5.4878e-02],\n",
+       "                        [-1.8610e-02,  5.4365e-02,  3.5053e-03],\n",
+       "                        [-3.9922e-02,  4.2510e-02, -5.7261e-02]],\n",
+       "              \n",
+       "                       [[ 4.1938e-02, -4.2039e-02, -1.2487e-02],\n",
+       "                        [-1.4090e-02, -3.7895e-02,  1.4394e-02],\n",
+       "                        [ 2.2555e-02, -2.7264e-02,  5.6102e-02]],\n",
+       "              \n",
+       "                       [[ 1.5770e-02,  5.4672e-02, -2.4056e-02],\n",
+       "                        [ 5.2089e-02, -2.8859e-02, -2.6499e-03],\n",
+       "                        [-5.2122e-02, -3.7436e-02,  3.9897e-02]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[ 2.7888e-02,  2.9241e-02, -1.9488e-02],\n",
+       "                        [ 2.8928e-02,  5.3312e-02, -2.9810e-02],\n",
+       "                        [-8.5104e-03,  5.7751e-02, -8.1857e-03]],\n",
+       "              \n",
+       "                       [[ 4.2649e-02,  3.3158e-03,  4.2879e-02],\n",
+       "                        [ 7.7893e-03, -3.2879e-02,  2.7630e-02],\n",
+       "                        [ 5.4706e-03,  4.8019e-02,  1.2420e-02]],\n",
+       "              \n",
+       "                       [[-4.2004e-02, -4.2790e-02,  2.4634e-02],\n",
+       "                        [-5.4641e-02,  3.4600e-02,  2.9071e-03],\n",
+       "                        [ 2.6470e-02,  4.6701e-02,  3.7158e-02]],\n",
+       "              \n",
+       "                       ...,\n",
+       "              \n",
+       "                       [[-1.7641e-02, -2.1205e-02, -5.1504e-02],\n",
+       "                        [-7.4737e-03,  5.5061e-02, -2.6397e-03],\n",
+       "                        [-4.4653e-02, -3.6719e-02,  3.4420e-06]],\n",
+       "              \n",
+       "                       [[ 1.6525e-02,  1.7280e-02,  5.4554e-03],\n",
+       "                        [ 4.0098e-02,  2.7571e-02, -4.4965e-02],\n",
+       "                        [ 6.1493e-03, -5.7754e-02,  1.0513e-02]],\n",
+       "              \n",
+       "                       [[-5.7615e-02,  3.2921e-02, -1.5900e-02],\n",
+       "                        [ 2.0081e-02,  5.4590e-02,  1.1296e-02],\n",
+       "                        [-4.5015e-02,  1.1341e-03,  2.6447e-02]]]])),\n",
+       "             ('conv2.bias',\n",
+       "              tensor([-0.0538, -0.0320, -0.0153,  0.0558,  0.0254,  0.0281, -0.0148,  0.0060,\n",
+       "                      -0.0283, -0.0062,  0.0437, -0.0064,  0.0341,  0.0233, -0.0201,  0.0391,\n",
+       "                       0.0243,  0.0071,  0.0125, -0.0138, -0.0377, -0.0169, -0.0475, -0.0004,\n",
+       "                      -0.0105, -0.0502,  0.0241,  0.0090,  0.0069, -0.0315, -0.0192,  0.0204])),\n",
+       "             ('conv3.weight',\n",
+       "              tensor([[[[-0.0215, -0.0208, -0.0272],\n",
+       "                        [-0.0493, -0.0117, -0.0285],\n",
+       "                        [ 0.0515, -0.0041,  0.0126]],\n",
+       "              \n",
+       "                       [[ 0.0299, -0.0301,  0.0552],\n",
+       "                        [ 0.0450,  0.0449, -0.0583],\n",
+       "                        [-0.0452, -0.0480, -0.0275]],\n",
+       "              \n",
+       "                       [[-0.0262, -0.0338,  0.0505],\n",
+       "                        [ 0.0146, -0.0364, -0.0044],\n",
+       "                        [-0.0102, -0.0051,  0.0017]],\n",
+       "              \n",
+       "                       ...,\n",
+       "              \n",
+       "                       [[-0.0367, -0.0468, -0.0586],\n",
+       "                        [ 0.0126,  0.0037,  0.0191],\n",
+       "                        [-0.0153,  0.0048, -0.0160]],\n",
+       "              \n",
+       "                       [[-0.0050,  0.0364,  0.0582],\n",
+       "                        [ 0.0093, -0.0268, -0.0355],\n",
+       "                        [-0.0125,  0.0500,  0.0009]],\n",
+       "              \n",
+       "                       [[ 0.0237, -0.0211, -0.0130],\n",
+       "                        [-0.0489,  0.0118,  0.0387],\n",
+       "                        [-0.0006,  0.0301,  0.0283]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[-0.0391, -0.0464, -0.0158],\n",
+       "                        [ 0.0201, -0.0054,  0.0422],\n",
+       "                        [ 0.0085, -0.0474, -0.0251]],\n",
+       "              \n",
+       "                       [[-0.0346,  0.0536, -0.0391],\n",
+       "                        [ 0.0244, -0.0263, -0.0073],\n",
+       "                        [ 0.0076,  0.0160,  0.0044]],\n",
+       "              \n",
+       "                       [[-0.0128,  0.0146, -0.0381],\n",
+       "                        [-0.0277, -0.0142,  0.0226],\n",
+       "                        [ 0.0190,  0.0326, -0.0219]],\n",
+       "              \n",
+       "                       ...,\n",
+       "              \n",
+       "                       [[-0.0217,  0.0129,  0.0558],\n",
+       "                        [ 0.0164, -0.0292, -0.0467],\n",
+       "                        [-0.0296,  0.0205, -0.0300]],\n",
+       "              \n",
+       "                       [[ 0.0254, -0.0151, -0.0583],\n",
+       "                        [ 0.0111, -0.0469, -0.0300],\n",
+       "                        [-0.0462,  0.0293, -0.0351]],\n",
+       "              \n",
+       "                       [[ 0.0401, -0.0251,  0.0160],\n",
+       "                        [-0.0160, -0.0195, -0.0065],\n",
+       "                        [-0.0519,  0.0351,  0.0357]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[ 0.0544, -0.0209, -0.0454],\n",
+       "                        [ 0.0287, -0.0205, -0.0294],\n",
+       "                        [-0.0195, -0.0235, -0.0378]],\n",
+       "              \n",
+       "                       [[-0.0294, -0.0380,  0.0301],\n",
+       "                        [ 0.0360,  0.0367,  0.0458],\n",
+       "                        [-0.0189, -0.0017,  0.0145]],\n",
+       "              \n",
+       "                       [[-0.0297,  0.0567,  0.0276],\n",
+       "                        [ 0.0298,  0.0383,  0.0227],\n",
+       "                        [ 0.0262,  0.0063,  0.0131]],\n",
+       "              \n",
+       "                       ...,\n",
+       "              \n",
+       "                       [[ 0.0002,  0.0432, -0.0247],\n",
+       "                        [ 0.0068, -0.0298, -0.0484],\n",
+       "                        [-0.0361, -0.0014,  0.0444]],\n",
+       "              \n",
+       "                       [[-0.0184, -0.0201, -0.0163],\n",
+       "                        [-0.0466,  0.0255, -0.0244],\n",
+       "                        [ 0.0283,  0.0149, -0.0588]],\n",
+       "              \n",
+       "                       [[ 0.0323,  0.0392, -0.0254],\n",
+       "                        [ 0.0560,  0.0137, -0.0401],\n",
+       "                        [-0.0236,  0.0589,  0.0448]]],\n",
+       "              \n",
+       "              \n",
+       "                      ...,\n",
+       "              \n",
+       "              \n",
+       "                      [[[ 0.0213,  0.0204,  0.0574],\n",
+       "                        [-0.0276, -0.0196,  0.0117],\n",
+       "                        [ 0.0569, -0.0158, -0.0502]],\n",
+       "              \n",
+       "                       [[ 0.0452, -0.0038,  0.0502],\n",
+       "                        [ 0.0428, -0.0398, -0.0486],\n",
+       "                        [ 0.0130,  0.0563,  0.0576]],\n",
+       "              \n",
+       "                       [[ 0.0484, -0.0535,  0.0048],\n",
+       "                        [ 0.0268, -0.0290, -0.0390],\n",
+       "                        [ 0.0189, -0.0194, -0.0588]],\n",
+       "              \n",
+       "                       ...,\n",
+       "              \n",
+       "                       [[ 0.0163, -0.0113, -0.0520],\n",
+       "                        [ 0.0288, -0.0547, -0.0544],\n",
+       "                        [ 0.0442,  0.0376,  0.0566]],\n",
+       "              \n",
+       "                       [[-0.0343,  0.0569,  0.0438],\n",
+       "                        [-0.0403, -0.0372, -0.0532],\n",
+       "                        [ 0.0322,  0.0126,  0.0423]],\n",
+       "              \n",
+       "                       [[ 0.0577,  0.0136, -0.0480],\n",
+       "                        [-0.0293, -0.0348,  0.0342],\n",
+       "                        [-0.0510, -0.0078, -0.0042]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[-0.0243,  0.0406,  0.0537],\n",
+       "                        [ 0.0209, -0.0059, -0.0487],\n",
+       "                        [-0.0425,  0.0339,  0.0444]],\n",
+       "              \n",
+       "                       [[ 0.0465, -0.0467,  0.0461],\n",
+       "                        [-0.0389,  0.0144, -0.0502],\n",
+       "                        [ 0.0274,  0.0552,  0.0356]],\n",
+       "              \n",
+       "                       [[-0.0289,  0.0474, -0.0217],\n",
+       "                        [ 0.0472, -0.0135,  0.0164],\n",
+       "                        [-0.0165, -0.0049, -0.0475]],\n",
+       "              \n",
+       "                       ...,\n",
+       "              \n",
+       "                       [[-0.0377,  0.0267,  0.0367],\n",
+       "                        [ 0.0111,  0.0114, -0.0329],\n",
+       "                        [ 0.0031, -0.0223, -0.0280]],\n",
+       "              \n",
+       "                       [[-0.0500, -0.0529,  0.0116],\n",
+       "                        [ 0.0483,  0.0121, -0.0149],\n",
+       "                        [ 0.0328,  0.0201,  0.0402]],\n",
+       "              \n",
+       "                       [[ 0.0463,  0.0157, -0.0332],\n",
+       "                        [ 0.0150,  0.0479,  0.0461],\n",
+       "                        [ 0.0275,  0.0506, -0.0466]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[-0.0478,  0.0274,  0.0500],\n",
+       "                        [-0.0394,  0.0032, -0.0496],\n",
+       "                        [ 0.0381,  0.0391,  0.0330]],\n",
+       "              \n",
+       "                       [[ 0.0184, -0.0560, -0.0345],\n",
+       "                        [-0.0459, -0.0215,  0.0452],\n",
+       "                        [ 0.0049,  0.0537,  0.0544]],\n",
+       "              \n",
+       "                       [[-0.0413, -0.0084,  0.0585],\n",
+       "                        [ 0.0338, -0.0067, -0.0113],\n",
+       "                        [-0.0187, -0.0234, -0.0525]],\n",
+       "              \n",
+       "                       ...,\n",
+       "              \n",
+       "                       [[-0.0389,  0.0325, -0.0538],\n",
+       "                        [ 0.0118,  0.0509,  0.0352],\n",
+       "                        [-0.0351, -0.0341, -0.0506]],\n",
+       "              \n",
+       "                       [[ 0.0136, -0.0349,  0.0082],\n",
+       "                        [ 0.0358, -0.0211,  0.0537],\n",
+       "                        [-0.0183,  0.0390, -0.0267]],\n",
+       "              \n",
+       "                       [[-0.0219, -0.0145, -0.0351],\n",
+       "                        [ 0.0556,  0.0033, -0.0030],\n",
+       "                        [ 0.0075, -0.0425, -0.0365]]]])),\n",
+       "             ('conv3.bias',\n",
+       "              tensor([ 0.0304,  0.0099, -0.0004,  0.0334,  0.0301,  0.0491,  0.0530, -0.0432,\n",
+       "                      -0.0127, -0.0549, -0.0419,  0.0159, -0.0284,  0.0295, -0.0148,  0.0275,\n",
+       "                       0.0554, -0.0056,  0.0389, -0.0264, -0.0383,  0.0126,  0.0320,  0.0312,\n",
+       "                       0.0018,  0.0560, -0.0329, -0.0155, -0.0391, -0.0539, -0.0571, -0.0254])),\n",
+       "             ('conv4.weight',\n",
+       "              tensor([[[[-3.8911e-02, -3.4220e-02,  4.2567e-03],\n",
+       "                        [-4.5321e-02, -5.2531e-02, -8.1722e-03],\n",
+       "                        [-2.2638e-02,  4.4213e-02,  5.6989e-02]],\n",
+       "              \n",
+       "                       [[ 1.8417e-03, -1.4453e-02,  4.9892e-02],\n",
+       "                        [ 5.7762e-02,  9.6610e-03, -3.9509e-02],\n",
+       "                        [ 3.3795e-02,  5.0409e-02, -5.8834e-02]],\n",
+       "              \n",
+       "                       [[ 4.6645e-03, -1.6286e-02,  4.3410e-02],\n",
+       "                        [-3.4043e-02, -2.2207e-02,  4.0967e-02],\n",
+       "                        [ 5.3004e-02, -2.2756e-02, -6.7993e-03]],\n",
+       "              \n",
+       "                       ...,\n",
+       "              \n",
+       "                       [[ 1.1741e-02, -5.5062e-02, -3.3625e-02],\n",
+       "                        [-9.2320e-03, -3.3036e-02,  3.3196e-02],\n",
+       "                        [ 2.3940e-02,  2.0442e-02,  1.4183e-02]],\n",
+       "              \n",
+       "                       [[-2.7139e-02, -3.4129e-03, -1.0090e-02],\n",
+       "                        [ 1.3073e-02, -1.6998e-02, -4.7540e-02],\n",
+       "                        [-2.5758e-02, -1.9363e-02,  2.1905e-02]],\n",
+       "              \n",
+       "                       [[ 5.7593e-02,  1.5013e-02, -5.7894e-02],\n",
+       "                        [ 5.7964e-02, -2.3412e-02,  2.6955e-02],\n",
+       "                        [-3.9814e-02, -4.6015e-02, -5.3240e-02]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[ 5.8211e-02, -4.1118e-02,  2.7704e-02],\n",
+       "                        [ 5.7198e-02,  8.4165e-03, -5.1708e-02],\n",
+       "                        [ 3.1423e-02,  1.5026e-02,  3.5922e-02]],\n",
+       "              \n",
+       "                       [[ 8.8858e-03,  3.2818e-02,  5.4486e-02],\n",
+       "                        [-2.6636e-02,  2.2604e-02,  2.9531e-02],\n",
+       "                        [-1.0327e-03,  2.2348e-03,  2.4103e-02]],\n",
+       "              \n",
+       "                       [[ 3.8683e-02, -5.0057e-03,  5.0224e-02],\n",
+       "                        [ 3.5756e-02, -2.7295e-02, -2.2854e-02],\n",
+       "                        [-3.2043e-02, -3.2415e-02,  4.1034e-02]],\n",
+       "              \n",
+       "                       ...,\n",
+       "              \n",
+       "                       [[-1.9791e-02,  4.3243e-02, -3.5177e-02],\n",
+       "                        [ 2.4554e-02,  4.2845e-03,  4.8009e-02],\n",
+       "                        [ 2.4897e-03,  3.9550e-02, -3.0833e-02]],\n",
+       "              \n",
+       "                       [[ 4.5807e-02,  6.5845e-03,  9.3362e-05],\n",
+       "                        [-1.9411e-02, -2.9161e-02,  5.0828e-02],\n",
+       "                        [ 1.2028e-03,  2.1260e-02, -4.3710e-03]],\n",
+       "              \n",
+       "                       [[-4.8702e-02, -2.0571e-02, -3.5162e-02],\n",
+       "                        [-2.5856e-02, -2.6619e-02,  8.1867e-03],\n",
+       "                        [-2.7671e-02, -9.6651e-03, -5.3279e-02]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[-5.6432e-02,  2.3722e-02, -2.1750e-02],\n",
+       "                        [-4.8247e-03, -2.1226e-02, -1.0829e-02],\n",
+       "                        [-1.9523e-02, -1.8187e-02,  2.2772e-03]],\n",
+       "              \n",
+       "                       [[-1.0907e-02,  3.3984e-02, -1.2088e-02],\n",
+       "                        [-1.8657e-02, -4.8297e-02,  2.3614e-02],\n",
+       "                        [-3.9670e-02,  6.1733e-03, -2.9168e-02]],\n",
+       "              \n",
+       "                       [[-4.2112e-02,  2.8203e-02, -1.7385e-03],\n",
+       "                        [-2.5282e-02, -9.4592e-05,  6.5093e-03],\n",
+       "                        [-4.1745e-02,  4.3988e-03, -1.1622e-02]],\n",
+       "              \n",
+       "                       ...,\n",
+       "              \n",
+       "                       [[ 4.2991e-02,  1.6912e-02, -4.3689e-02],\n",
+       "                        [ 5.1871e-02,  4.8566e-02,  3.6205e-02],\n",
+       "                        [-3.2016e-02, -1.3596e-02, -2.7950e-02]],\n",
+       "              \n",
+       "                       [[-2.8307e-02, -4.0278e-02,  1.5087e-02],\n",
+       "                        [ 4.0443e-02, -3.5727e-02,  3.7196e-02],\n",
+       "                        [-1.4194e-02, -2.7319e-02, -5.1305e-02]],\n",
+       "              \n",
+       "                       [[-2.9962e-02,  2.4693e-02, -4.4912e-02],\n",
+       "                        [ 5.5890e-03,  4.6671e-02,  3.3599e-02],\n",
+       "                        [-3.9949e-02, -4.4716e-02,  2.2345e-02]]],\n",
+       "              \n",
+       "              \n",
+       "                      ...,\n",
+       "              \n",
+       "              \n",
+       "                      [[[ 3.1920e-02, -4.9932e-02, -1.0871e-02],\n",
+       "                        [-3.7500e-02,  4.1638e-02, -1.3246e-02],\n",
+       "                        [ 1.6447e-02, -5.6741e-02, -3.7524e-02]],\n",
+       "              \n",
+       "                       [[ 3.3903e-02, -3.1321e-02, -4.4877e-02],\n",
+       "                        [-2.2473e-02, -2.4225e-02,  4.5838e-02],\n",
+       "                        [-2.0069e-02,  3.8338e-02,  5.8010e-02]],\n",
+       "              \n",
+       "                       [[ 1.7602e-02, -5.2530e-02,  4.9331e-02],\n",
+       "                        [ 2.4509e-02,  2.3943e-02, -2.1774e-02],\n",
+       "                        [-5.7154e-02,  5.7090e-02,  3.7531e-02]],\n",
+       "              \n",
+       "                       ...,\n",
+       "              \n",
+       "                       [[ 2.8630e-02, -4.8644e-04, -5.3822e-02],\n",
+       "                        [-1.1102e-02,  4.8524e-02, -2.7142e-02],\n",
+       "                        [-5.3463e-02,  3.5607e-02, -1.1110e-02]],\n",
+       "              \n",
+       "                       [[ 4.7891e-02, -3.4098e-02,  3.6984e-02],\n",
+       "                        [ 3.9062e-02, -5.1119e-03, -3.3252e-02],\n",
+       "                        [ 5.5029e-02,  3.1092e-03, -5.9391e-03]],\n",
+       "              \n",
+       "                       [[ 9.6775e-03,  2.2903e-02, -1.5971e-02],\n",
+       "                        [-2.6969e-02,  8.5069e-04, -2.5744e-02],\n",
+       "                        [ 2.7311e-03, -2.2119e-02,  2.4367e-03]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[ 1.5206e-02, -5.3504e-02, -1.8013e-02],\n",
+       "                        [-3.0064e-02,  3.2887e-02,  1.7612e-02],\n",
+       "                        [-4.2775e-02, -2.7335e-02,  5.1532e-02]],\n",
+       "              \n",
+       "                       [[ 4.3325e-02,  1.2909e-03,  5.6831e-02],\n",
+       "                        [ 4.8283e-03, -4.4274e-02,  1.7624e-02],\n",
+       "                        [-7.2574e-03,  1.3743e-02, -5.1502e-02]],\n",
+       "              \n",
+       "                       [[-1.4299e-02, -3.0024e-02,  4.9578e-02],\n",
+       "                        [-4.6143e-02, -3.0686e-02, -5.0727e-02],\n",
+       "                        [ 5.4766e-02, -1.4012e-02,  3.3267e-02]],\n",
+       "              \n",
+       "                       ...,\n",
+       "              \n",
+       "                       [[ 5.2773e-02,  4.0294e-02, -2.6113e-02],\n",
+       "                        [-2.5069e-02,  4.3956e-02, -4.7841e-02],\n",
+       "                        [ 3.0924e-02,  1.5174e-02, -4.8323e-02]],\n",
+       "              \n",
+       "                       [[-2.0325e-02,  3.2666e-02,  2.5174e-02],\n",
+       "                        [ 5.3775e-03, -3.2712e-02, -5.2251e-02],\n",
+       "                        [-2.7426e-02, -4.5502e-04, -3.2174e-02]],\n",
+       "              \n",
+       "                       [[ 8.5780e-03, -5.2099e-02,  5.7285e-02],\n",
+       "                        [-5.0897e-02, -5.3995e-03,  4.2719e-02],\n",
+       "                        [-5.2257e-02, -7.9682e-03,  2.1848e-02]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[-2.8660e-02,  1.2392e-02,  2.9940e-02],\n",
+       "                        [-1.1170e-02, -1.7499e-02, -5.3951e-02],\n",
+       "                        [ 4.9738e-02,  2.7240e-02,  2.9588e-03]],\n",
+       "              \n",
+       "                       [[ 3.7728e-02, -3.1084e-02,  4.6459e-02],\n",
+       "                        [-1.4961e-02,  1.6951e-02, -1.3976e-02],\n",
+       "                        [ 2.5768e-02, -3.8991e-02, -2.8738e-02]],\n",
+       "              \n",
+       "                       [[ 1.6084e-02, -7.3174e-03, -4.5839e-02],\n",
+       "                        [-3.6029e-02,  1.5303e-02, -5.4380e-02],\n",
+       "                        [ 2.1913e-02, -4.4792e-02,  4.7973e-02]],\n",
+       "              \n",
+       "                       ...,\n",
+       "              \n",
+       "                       [[ 5.4945e-02, -4.5541e-02, -1.5806e-02],\n",
+       "                        [-4.5216e-02, -2.0338e-02,  3.3373e-02],\n",
+       "                        [-1.8431e-02,  4.3953e-02, -4.8196e-02]],\n",
+       "              \n",
+       "                       [[-3.6129e-02,  5.7705e-02, -1.2229e-02],\n",
+       "                        [-5.2801e-02,  7.0930e-03, -2.6721e-02],\n",
+       "                        [ 3.6720e-02,  7.5540e-04,  5.5401e-02]],\n",
+       "              \n",
+       "                       [[-4.0386e-02,  1.6714e-02,  2.9246e-02],\n",
+       "                        [ 4.7033e-02, -8.7923e-03,  5.1267e-02],\n",
+       "                        [-3.1211e-02, -2.7036e-02,  4.0346e-02]]]])),\n",
+       "             ('conv4.bias',\n",
+       "              tensor([-0.0368, -0.0535, -0.0568, -0.0364,  0.0167,  0.0255,  0.0031,  0.0562,\n",
+       "                      -0.0575,  0.0470, -0.0128,  0.0159, -0.0387,  0.0517, -0.0508, -0.0104,\n",
+       "                       0.0133, -0.0265,  0.0290,  0.0147, -0.0405,  0.0074,  0.0434, -0.0436,\n",
+       "                      -0.0295,  0.0417,  0.0324,  0.0316,  0.0143, -0.0312,  0.0463,  0.0441])),\n",
+       "             ('fc1.weight',\n",
+       "              tensor([[-0.0206, -0.0161, -0.0053,  ..., -0.0257, -0.0314, -0.0105],\n",
+       "                      [-0.0229,  0.0030,  0.0239,  ...,  0.0199, -0.0279,  0.0155]])),\n",
+       "             ('fc1.bias', tensor([-0.0333, -0.0126]))])"
+      ]
+     },
+     "execution_count": 25,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "m1.state_dict()\n"
+   ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 28,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "OrderedDict([('conv1.weight',\n",
+       "              tensor([[[[-0.1261,  0.1833, -0.1406],\n",
+       "                        [ 0.1324, -0.0685,  0.0938],\n",
+       "                        [ 0.0432,  0.1814, -0.0541]],\n",
+       "              \n",
+       "                       [[-0.1776, -0.1839, -0.0111],\n",
+       "                        [ 0.0888,  0.0888, -0.1344],\n",
+       "                        [-0.1838,  0.1737,  0.1584]],\n",
+       "              \n",
+       "                       [[ 0.0417,  0.1064, -0.0156],\n",
+       "                        [ 0.0667,  0.0856, -0.1746],\n",
+       "                        [ 0.0412,  0.1620,  0.0125]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[-0.0530, -0.1273, -0.0797],\n",
+       "                        [ 0.0422,  0.1135,  0.0475],\n",
+       "                        [-0.0244,  0.1691, -0.1383]],\n",
+       "              \n",
+       "                       [[ 0.0822, -0.1317, -0.1692],\n",
+       "                        [ 0.1373,  0.1388,  0.0103],\n",
+       "                        [-0.0481,  0.1105,  0.0631]],\n",
+       "              \n",
+       "                       [[-0.0352,  0.1259, -0.0530],\n",
+       "                        [-0.1394, -0.0281,  0.1844],\n",
+       "                        [ 0.0082,  0.1187,  0.0211]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[ 0.0987,  0.0788, -0.1126],\n",
+       "                        [ 0.1769,  0.0763, -0.1767],\n",
+       "                        [-0.0570,  0.1156,  0.1770]],\n",
+       "              \n",
+       "                       [[ 0.0643, -0.0024, -0.0625],\n",
+       "                        [ 0.0819,  0.0140, -0.1882],\n",
+       "                        [ 0.1325, -0.0632, -0.0202]],\n",
+       "              \n",
+       "                       [[ 0.0053,  0.1042, -0.0058],\n",
+       "                        [-0.1082, -0.1753,  0.1762],\n",
+       "                        [-0.0501,  0.1166,  0.0561]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[ 0.0358, -0.0685, -0.1278],\n",
+       "                        [ 0.0029, -0.1107,  0.1169],\n",
+       "                        [-0.1408,  0.1293,  0.1142]],\n",
+       "              \n",
+       "                       [[-0.0814,  0.0470,  0.0188],\n",
+       "                        [ 0.1538,  0.0137,  0.1128],\n",
+       "                        [-0.1597,  0.1432,  0.1370]],\n",
+       "              \n",
+       "                       [[ 0.1425,  0.1769, -0.0037],\n",
+       "                        [-0.1080, -0.0805, -0.0195],\n",
+       "                        [-0.1335, -0.1666,  0.1399]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[ 0.1117,  0.1918, -0.1666],\n",
+       "                        [-0.1392,  0.0086,  0.0172],\n",
+       "                        [-0.0721, -0.1711,  0.0344]],\n",
+       "              \n",
+       "                       [[ 0.1820, -0.0537, -0.0974],\n",
+       "                        [ 0.0366, -0.0710,  0.1273],\n",
+       "                        [ 0.1132, -0.1594,  0.0878]],\n",
+       "              \n",
+       "                       [[-0.0874, -0.0401,  0.1827],\n",
+       "                        [-0.0301,  0.1205, -0.0396],\n",
+       "                        [-0.1143, -0.1007,  0.1561]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[ 0.1522, -0.0012, -0.1785],\n",
+       "                        [-0.1833, -0.1828, -0.1643],\n",
+       "                        [-0.1765, -0.1757, -0.0608]],\n",
+       "              \n",
+       "                       [[-0.0684,  0.0521,  0.1137],\n",
+       "                        [-0.0028,  0.0616,  0.0758],\n",
+       "                        [-0.1736,  0.0667,  0.1229]],\n",
+       "              \n",
+       "                       [[ 0.1298, -0.1848, -0.1570],\n",
+       "                        [-0.1052, -0.1172, -0.1223],\n",
+       "                        [-0.1389, -0.0095, -0.0410]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[ 0.0213, -0.0975,  0.0964],\n",
+       "                        [ 0.0535, -0.0775,  0.0790],\n",
+       "                        [-0.1796, -0.1468,  0.1036]],\n",
+       "              \n",
+       "                       [[-0.0403,  0.0646, -0.0932],\n",
+       "                        [ 0.1779, -0.1616,  0.0644],\n",
+       "                        [-0.0508, -0.1158, -0.0592]],\n",
+       "              \n",
+       "                       [[-0.1644, -0.1327,  0.0817],\n",
+       "                        [ 0.0320, -0.0213, -0.0946],\n",
+       "                        [-0.1106,  0.1463, -0.1642]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[-0.0985, -0.1160, -0.0915],\n",
+       "                        [ 0.1857,  0.0806,  0.1761],\n",
+       "                        [-0.0817,  0.1095,  0.0896]],\n",
+       "              \n",
+       "                       [[-0.0660, -0.1680,  0.1833],\n",
+       "                        [ 0.0611,  0.0077, -0.0848],\n",
+       "                        [-0.1516,  0.1737,  0.0484]],\n",
+       "              \n",
+       "                       [[ 0.1434, -0.0732, -0.0904],\n",
+       "                        [ 0.0962,  0.1783,  0.0192],\n",
+       "                        [ 0.0915,  0.0006,  0.0334]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[-0.0047,  0.1807, -0.1798],\n",
+       "                        [-0.0164,  0.1119, -0.0805],\n",
+       "                        [ 0.1855, -0.0681, -0.0187]],\n",
+       "              \n",
+       "                       [[-0.0069,  0.0491, -0.1868],\n",
+       "                        [-0.1609, -0.0316,  0.0150],\n",
+       "                        [-0.1605,  0.1506, -0.0074]],\n",
+       "              \n",
+       "                       [[ 0.0851, -0.1732, -0.1777],\n",
+       "                        [ 0.0539, -0.0500, -0.1231],\n",
+       "                        [ 0.1654,  0.0342, -0.1904]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[ 0.0476,  0.0284,  0.1212],\n",
+       "                        [-0.1603, -0.1924,  0.0144],\n",
+       "                        [ 0.0076, -0.0928, -0.1645]],\n",
+       "              \n",
+       "                       [[ 0.0215,  0.1845, -0.1034],\n",
+       "                        [ 0.1574, -0.1577, -0.0438],\n",
+       "                        [-0.1360, -0.0601, -0.1693]],\n",
+       "              \n",
+       "                       [[-0.0720,  0.0619,  0.1405],\n",
+       "                        [ 0.0699, -0.1288,  0.0041],\n",
+       "                        [-0.0381, -0.1697, -0.1568]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[-0.1599,  0.1231, -0.1034],\n",
+       "                        [-0.0314,  0.0105, -0.1449],\n",
+       "                        [-0.0172, -0.0781,  0.0839]],\n",
+       "              \n",
+       "                       [[-0.0676,  0.1185, -0.1559],\n",
+       "                        [-0.1053, -0.1306,  0.1820],\n",
+       "                        [ 0.1584, -0.1370,  0.1828]],\n",
+       "              \n",
+       "                       [[ 0.0658,  0.1412, -0.0537],\n",
+       "                        [-0.1230, -0.1411, -0.0011],\n",
+       "                        [-0.1318, -0.0458,  0.1838]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[-0.0268,  0.1747, -0.1037],\n",
+       "                        [ 0.0515, -0.0228, -0.1024],\n",
+       "                        [-0.1543, -0.0643, -0.0100]],\n",
+       "              \n",
+       "                       [[-0.1572, -0.1530,  0.0026],\n",
+       "                        [ 0.1463, -0.1233,  0.0470],\n",
+       "                        [-0.1595, -0.1108, -0.0654]],\n",
+       "              \n",
+       "                       [[-0.0521, -0.0094,  0.1544],\n",
+       "                        [-0.0505, -0.0332,  0.0048],\n",
+       "                        [ 0.0735,  0.1350,  0.0690]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[ 0.0025,  0.0724,  0.0930],\n",
+       "                        [-0.1885,  0.0475,  0.1100],\n",
+       "                        [-0.1622,  0.0087, -0.0030]],\n",
+       "              \n",
+       "                       [[ 0.1032, -0.1425, -0.0620],\n",
+       "                        [ 0.1515, -0.0736, -0.1888],\n",
+       "                        [-0.1246,  0.1424, -0.0491]],\n",
+       "              \n",
+       "                       [[ 0.1759, -0.1616,  0.1198],\n",
+       "                        [-0.1103,  0.1032,  0.1727],\n",
+       "                        [-0.0601,  0.1635,  0.0034]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[ 0.0301,  0.1517,  0.0657],\n",
+       "                        [-0.1368, -0.1165,  0.1193],\n",
+       "                        [-0.0962,  0.1451,  0.1099]],\n",
+       "              \n",
+       "                       [[ 0.1646, -0.1860, -0.1187],\n",
+       "                        [-0.1367, -0.0911,  0.1337],\n",
+       "                        [-0.0926, -0.0524, -0.0672]],\n",
+       "              \n",
+       "                       [[-0.1509, -0.1231, -0.0855],\n",
+       "                        [ 0.1808, -0.0713,  0.0410],\n",
+       "                        [-0.0621, -0.0506,  0.1871]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[ 0.0888, -0.0874, -0.0826],\n",
+       "                        [ 0.0416, -0.0961,  0.0603],\n",
+       "                        [ 0.1455,  0.0050,  0.0318]],\n",
+       "              \n",
+       "                       [[-0.1633,  0.0070, -0.1537],\n",
+       "                        [-0.0109,  0.1602, -0.0463],\n",
+       "                        [-0.0423, -0.0147, -0.1045]],\n",
+       "              \n",
+       "                       [[ 0.1640, -0.0997, -0.1662],\n",
+       "                        [-0.1074,  0.1549, -0.1905],\n",
+       "                        [-0.1708,  0.1624,  0.0219]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[ 0.0824, -0.1376,  0.1086],\n",
+       "                        [ 0.0836,  0.0135,  0.0351],\n",
+       "                        [-0.1518,  0.0784, -0.1708]],\n",
+       "              \n",
+       "                       [[-0.1636, -0.1571,  0.1032],\n",
+       "                        [-0.1152,  0.0274, -0.1022],\n",
+       "                        [-0.0956, -0.1606, -0.1615]],\n",
+       "              \n",
+       "                       [[ 0.1307,  0.0419,  0.1924],\n",
+       "                        [-0.0599, -0.1296, -0.0448],\n",
+       "                        [ 0.0363,  0.0377, -0.0460]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[-0.1685, -0.1277, -0.0465],\n",
+       "                        [ 0.0922, -0.1011,  0.0742],\n",
+       "                        [ 0.0053, -0.1456,  0.0135]],\n",
+       "              \n",
+       "                       [[ 0.1341,  0.0131,  0.1281],\n",
+       "                        [-0.1020,  0.1069, -0.0631],\n",
+       "                        [-0.0439, -0.1189, -0.1822]],\n",
+       "              \n",
+       "                       [[ 0.1624, -0.1253,  0.0302],\n",
+       "                        [ 0.0709,  0.0767,  0.1453],\n",
+       "                        [ 0.0203,  0.1603, -0.1720]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[-0.1550,  0.1513, -0.1003],\n",
+       "                        [ 0.0370,  0.0367, -0.0233],\n",
+       "                        [ 0.0916, -0.0871,  0.1579]],\n",
+       "              \n",
+       "                       [[-0.1900,  0.0314,  0.0865],\n",
+       "                        [-0.0197,  0.0296, -0.0048],\n",
+       "                        [ 0.0846,  0.1543, -0.0770]],\n",
+       "              \n",
+       "                       [[-0.0016, -0.0978,  0.1826],\n",
+       "                        [-0.0477,  0.0689,  0.1079],\n",
+       "                        [ 0.0400,  0.0880,  0.1674]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[ 0.0145, -0.0447, -0.1742],\n",
+       "                        [ 0.0394,  0.0127, -0.1172],\n",
+       "                        [ 0.1330, -0.1207,  0.0326]],\n",
+       "              \n",
+       "                       [[-0.0155, -0.1602,  0.0023],\n",
+       "                        [ 0.0789,  0.1648,  0.1781],\n",
+       "                        [-0.1468, -0.0481, -0.1260]],\n",
+       "              \n",
+       "                       [[-0.0139,  0.0848, -0.0536],\n",
+       "                        [-0.1581,  0.1130,  0.0717],\n",
+       "                        [ 0.0275, -0.0006, -0.0049]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[-0.0199,  0.0032, -0.1246],\n",
+       "                        [ 0.0479,  0.1418, -0.1295],\n",
+       "                        [-0.1646, -0.1139, -0.1018]],\n",
+       "              \n",
+       "                       [[ 0.1475,  0.1413, -0.0354],\n",
+       "                        [ 0.0612, -0.1652,  0.0801],\n",
+       "                        [-0.1306, -0.0165,  0.1733]],\n",
+       "              \n",
+       "                       [[ 0.1527,  0.0911, -0.1906],\n",
+       "                        [-0.1152,  0.1737,  0.0436],\n",
+       "                        [-0.0213, -0.0314, -0.0319]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[-0.0003, -0.0546, -0.1255],\n",
+       "                        [ 0.0914, -0.1414,  0.0542],\n",
+       "                        [ 0.1139,  0.0132,  0.0815]],\n",
+       "              \n",
+       "                       [[-0.0042,  0.0541,  0.1456],\n",
+       "                        [ 0.0509, -0.0790,  0.0272],\n",
+       "                        [ 0.1419,  0.0992, -0.1448]],\n",
+       "              \n",
+       "                       [[ 0.0496,  0.0013,  0.0838],\n",
+       "                        [-0.0662,  0.0315, -0.1168],\n",
+       "                        [-0.0069, -0.1503,  0.0729]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[ 0.1866,  0.1329, -0.0560],\n",
+       "                        [ 0.0026,  0.1533,  0.0326],\n",
+       "                        [-0.1161, -0.0323,  0.0053]],\n",
+       "              \n",
+       "                       [[-0.0243, -0.1823, -0.1657],\n",
+       "                        [-0.0107, -0.0832,  0.0029],\n",
+       "                        [ 0.0981,  0.1241, -0.1788]],\n",
+       "              \n",
+       "                       [[-0.0400, -0.0577, -0.0757],\n",
+       "                        [-0.0584,  0.0176, -0.1019],\n",
+       "                        [-0.1828,  0.1589, -0.0312]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[-0.1083, -0.1236, -0.0904],\n",
+       "                        [-0.1575,  0.0157,  0.0552],\n",
+       "                        [-0.0839,  0.1704, -0.1457]],\n",
+       "              \n",
+       "                       [[-0.1648, -0.0270, -0.0489],\n",
+       "                        [-0.1122, -0.0288, -0.0073],\n",
+       "                        [-0.1443, -0.1712,  0.0100]],\n",
+       "              \n",
+       "                       [[-0.1142, -0.1552,  0.1568],\n",
+       "                        [ 0.0743, -0.1108, -0.0643],\n",
+       "                        [-0.0394, -0.1345,  0.0992]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[-0.1591,  0.0942, -0.1035],\n",
+       "                        [-0.0781,  0.0725, -0.0888],\n",
+       "                        [ 0.0959,  0.0213,  0.1222]],\n",
+       "              \n",
+       "                       [[ 0.1202, -0.0217, -0.0955],\n",
+       "                        [-0.1748, -0.1133, -0.0704],\n",
+       "                        [-0.0670, -0.1401,  0.1553]],\n",
+       "              \n",
+       "                       [[ 0.0053, -0.0871, -0.0239],\n",
+       "                        [ 0.0961, -0.0547,  0.1741],\n",
+       "                        [-0.0570,  0.0477,  0.1853]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[-0.1115, -0.0183, -0.1302],\n",
+       "                        [ 0.1435, -0.0238, -0.0048],\n",
+       "                        [ 0.1862, -0.1837,  0.1711]],\n",
+       "              \n",
+       "                       [[ 0.1375, -0.1798,  0.0818],\n",
+       "                        [-0.0792,  0.0820,  0.1373],\n",
+       "                        [ 0.1849,  0.0672, -0.1822]],\n",
+       "              \n",
+       "                       [[ 0.1868, -0.0356,  0.0726],\n",
+       "                        [-0.1523, -0.1130,  0.1506],\n",
+       "                        [-0.1046,  0.0178,  0.0990]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[ 0.1321, -0.1641,  0.0411],\n",
+       "                        [ 0.0526,  0.0393,  0.0918],\n",
+       "                        [-0.1345,  0.0750,  0.0859]],\n",
+       "              \n",
+       "                       [[-0.0985,  0.1466,  0.1349],\n",
+       "                        [-0.1461, -0.1742,  0.0941],\n",
+       "                        [-0.1502, -0.1813,  0.0864]],\n",
+       "              \n",
+       "                       [[-0.1039,  0.1179,  0.1499],\n",
+       "                        [-0.0366, -0.0120,  0.0951],\n",
+       "                        [ 0.0087,  0.1212, -0.0183]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[-0.1375,  0.0765, -0.0072],\n",
+       "                        [-0.0041,  0.0379, -0.0243],\n",
+       "                        [-0.1495,  0.1601,  0.1575]],\n",
+       "              \n",
+       "                       [[-0.0454,  0.1642,  0.0720],\n",
+       "                        [-0.0533,  0.0150,  0.0039],\n",
+       "                        [ 0.0194,  0.0113, -0.1194]],\n",
+       "              \n",
+       "                       [[ 0.0527, -0.0886,  0.0359],\n",
+       "                        [ 0.1595,  0.0526, -0.0048],\n",
+       "                        [-0.1790, -0.0458, -0.0324]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[-0.1638,  0.0942,  0.0686],\n",
+       "                        [-0.1082, -0.0675,  0.1892],\n",
+       "                        [-0.1347, -0.1247,  0.0739]],\n",
+       "              \n",
+       "                       [[ 0.0595,  0.1504, -0.1657],\n",
+       "                        [ 0.0733,  0.0529, -0.1599],\n",
+       "                        [ 0.0171, -0.1127, -0.0259]],\n",
+       "              \n",
+       "                       [[-0.0092,  0.0193,  0.1176],\n",
+       "                        [-0.1183,  0.0101,  0.1011],\n",
+       "                        [ 0.0648, -0.1897,  0.0782]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[ 0.0098, -0.1161, -0.0802],\n",
+       "                        [-0.1821,  0.0221, -0.1754],\n",
+       "                        [-0.1218,  0.0525, -0.0480]],\n",
+       "              \n",
+       "                       [[ 0.0770,  0.0477,  0.1514],\n",
+       "                        [ 0.0374, -0.1075, -0.1026],\n",
+       "                        [-0.0581, -0.1011,  0.1241]],\n",
+       "              \n",
+       "                       [[-0.0567, -0.0163,  0.0374],\n",
+       "                        [-0.1739, -0.0579,  0.0704],\n",
+       "                        [ 0.1817,  0.1561,  0.1677]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[-0.0569, -0.0763,  0.0044],\n",
+       "                        [-0.1133,  0.0813,  0.1477],\n",
+       "                        [ 0.0836,  0.0483, -0.1800]],\n",
+       "              \n",
+       "                       [[ 0.1343, -0.1590,  0.1177],\n",
+       "                        [ 0.1071, -0.1647, -0.0646],\n",
+       "                        [ 0.1578, -0.1261,  0.0243]],\n",
+       "              \n",
+       "                       [[-0.0424, -0.0241, -0.0988],\n",
+       "                        [ 0.0023,  0.0029, -0.0291],\n",
+       "                        [ 0.0415, -0.0557,  0.1427]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[-0.1028,  0.1054,  0.1658],\n",
+       "                        [-0.0357,  0.1579,  0.1237],\n",
+       "                        [ 0.0368,  0.0532, -0.1043]],\n",
+       "              \n",
+       "                       [[-0.0369,  0.0575, -0.1023],\n",
+       "                        [ 0.0635,  0.1015,  0.1112],\n",
+       "                        [-0.1235,  0.0467,  0.0908]],\n",
+       "              \n",
+       "                       [[ 0.1380,  0.0633,  0.1087],\n",
+       "                        [-0.1360,  0.0422, -0.1524],\n",
+       "                        [ 0.0819,  0.0918, -0.1624]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[ 0.1584, -0.0218, -0.0236],\n",
+       "                        [ 0.1878, -0.1289,  0.1343],\n",
+       "                        [ 0.0351,  0.1225, -0.1460]],\n",
+       "              \n",
+       "                       [[ 0.0690, -0.1439,  0.0056],\n",
+       "                        [ 0.0272, -0.0058,  0.0125],\n",
+       "                        [ 0.0868, -0.0684, -0.0884]],\n",
+       "              \n",
+       "                       [[ 0.1045,  0.0583, -0.0870],\n",
+       "                        [ 0.0600, -0.0732, -0.1695],\n",
+       "                        [ 0.0953,  0.0246,  0.1245]]]])),\n",
+       "             ('conv1.bias',\n",
+       "              tensor([-1.8698e-01, -7.9379e-06, -1.9277e-02,  5.2182e-02,  7.5716e-02,\n",
+       "                      -3.3830e-03, -9.6565e-02,  1.0241e-01, -8.2457e-02, -1.6224e-01,\n",
+       "                       1.2980e-01, -8.2256e-02, -7.4655e-02, -3.7980e-02,  8.3407e-02,\n",
+       "                      -1.4880e-01,  4.8939e-02,  2.7506e-02,  5.8676e-03, -1.5813e-01,\n",
+       "                      -6.2464e-04,  1.0359e-02, -1.5525e-01,  7.9100e-02,  1.6850e-02,\n",
+       "                      -1.3809e-01, -6.3393e-02, -5.3843e-02, -1.5219e-02, -1.7365e-01,\n",
+       "                       1.7249e-01, -1.1165e-01])),\n",
+       "             ('conv2.weight',\n",
+       "              tensor([[[[ 4.3750e-02,  4.5533e-02, -2.9410e-02],\n",
+       "                        [-4.1395e-02,  5.0397e-04, -1.3265e-02],\n",
+       "                        [-4.9851e-02, -1.0518e-02,  5.7710e-02]],\n",
+       "              \n",
+       "                       [[-5.6332e-02, -4.7168e-03, -4.4627e-02],\n",
+       "                        [ 5.3513e-03, -4.0824e-02,  1.8281e-02],\n",
+       "                        [ 5.0677e-02, -1.5295e-02, -6.1751e-03]],\n",
+       "              \n",
+       "                       [[-2.4984e-02,  1.2784e-02, -4.7123e-02],\n",
+       "                        [-4.3238e-02,  4.7349e-02, -1.5219e-02],\n",
+       "                        [-3.6073e-02,  4.1506e-02, -3.5337e-02]],\n",
+       "              \n",
+       "                       ...,\n",
+       "              \n",
+       "                       [[ 5.6048e-02,  1.9129e-03, -1.7200e-02],\n",
+       "                        [ 5.8869e-02, -5.1520e-02, -5.3205e-02],\n",
+       "                        [-1.3903e-02,  5.1790e-02,  2.2585e-02]],\n",
+       "              \n",
+       "                       [[ 1.1835e-02, -4.9313e-02, -3.1838e-02],\n",
+       "                        [ 7.6813e-03,  4.2715e-02, -5.7404e-02],\n",
+       "                        [-4.1474e-02, -2.3128e-02, -4.7935e-02]],\n",
+       "              \n",
+       "                       [[-2.1860e-02, -2.1817e-02, -3.2578e-02],\n",
+       "                        [ 3.1317e-02,  3.3435e-02,  3.1837e-02],\n",
+       "                        [-2.2399e-03,  3.1600e-02,  4.0183e-02]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[ 1.9610e-04,  5.3780e-02, -4.5810e-02],\n",
+       "                        [ 4.0340e-02,  1.4904e-02, -1.5597e-02],\n",
+       "                        [-4.6080e-02,  5.0714e-02, -5.7445e-03]],\n",
+       "              \n",
+       "                       [[-3.5281e-02,  3.3011e-02,  4.3343e-02],\n",
+       "                        [-4.6263e-02, -5.6184e-02,  5.1245e-03],\n",
+       "                        [ 3.6015e-02, -3.3152e-02,  4.6629e-03]],\n",
+       "              \n",
+       "                       [[ 1.7650e-03, -4.2336e-02,  4.3744e-02],\n",
+       "                        [ 2.1655e-02,  5.3759e-02,  1.3719e-03],\n",
+       "                        [ 4.2005e-02,  5.3998e-02,  1.9009e-02]],\n",
+       "              \n",
+       "                       ...,\n",
+       "              \n",
+       "                       [[ 8.9786e-03, -1.8645e-02, -1.3587e-02],\n",
+       "                        [-5.4167e-02,  1.2335e-02, -3.0384e-02],\n",
+       "                        [-4.8722e-03, -3.7296e-02, -2.6446e-02]],\n",
+       "              \n",
+       "                       [[ 1.7580e-02,  3.8462e-02, -5.0269e-02],\n",
+       "                        [ 2.6601e-03, -1.1462e-02,  4.7459e-02],\n",
+       "                        [-2.8888e-02,  3.4436e-02, -4.9943e-02]],\n",
+       "              \n",
+       "                       [[-5.0206e-02, -5.6025e-02, -3.6346e-02],\n",
+       "                        [-2.4407e-02,  5.3721e-02, -5.4920e-02],\n",
+       "                        [ 5.1835e-02, -3.2396e-02,  3.2373e-02]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[-2.7759e-02, -5.4737e-02, -1.1689e-02],\n",
+       "                        [ 3.9462e-02,  2.8649e-02,  5.1776e-02],\n",
+       "                        [ 2.4253e-02, -2.8318e-02,  2.7402e-02]],\n",
+       "              \n",
+       "                       [[ 1.3045e-02, -1.0456e-02,  2.0426e-02],\n",
+       "                        [ 2.1949e-02,  4.6817e-02, -5.6093e-02],\n",
+       "                        [ 2.7145e-02, -5.5441e-02, -2.0719e-02]],\n",
+       "              \n",
+       "                       [[ 4.4704e-02, -2.4099e-02, -4.7185e-02],\n",
+       "                        [-4.3257e-02, -3.3058e-02, -8.6451e-03],\n",
+       "                        [-3.7283e-02, -3.4569e-02, -7.1049e-03]],\n",
+       "              \n",
+       "                       ...,\n",
+       "              \n",
+       "                       [[-4.1559e-02, -2.9240e-02,  2.7197e-03],\n",
+       "                        [ 2.0770e-02,  5.4479e-02, -4.4845e-02],\n",
+       "                        [-1.1641e-02, -2.9814e-02, -2.4419e-02]],\n",
+       "              \n",
+       "                       [[-1.5743e-02,  1.0854e-02,  3.0878e-02],\n",
+       "                        [ 2.2739e-02,  3.2999e-02, -1.1902e-02],\n",
+       "                        [-3.4837e-02,  1.5305e-02, -8.7552e-03]],\n",
+       "              \n",
+       "                       [[-2.2882e-02,  9.4639e-03,  5.1878e-03],\n",
+       "                        [-2.6344e-02,  2.9063e-02, -1.9337e-02],\n",
+       "                        [-3.4314e-02,  1.5313e-02,  4.1524e-02]]],\n",
+       "              \n",
+       "              \n",
+       "                      ...,\n",
+       "              \n",
+       "              \n",
+       "                      [[[ 9.7454e-03, -3.2903e-03,  1.0696e-02],\n",
+       "                        [-4.0918e-02,  1.6352e-02,  1.4646e-02],\n",
+       "                        [ 1.2516e-02, -2.1804e-02, -2.5489e-02]],\n",
+       "              \n",
+       "                       [[-1.6083e-02,  2.5374e-02,  3.1458e-02],\n",
+       "                        [-3.1497e-02, -1.9513e-02, -2.1223e-02],\n",
+       "                        [ 6.6286e-03,  1.6538e-02, -4.8944e-02]],\n",
+       "              \n",
+       "                       [[ 2.4808e-02, -2.9520e-02, -4.8227e-02],\n",
+       "                        [ 1.7325e-03, -4.7443e-02,  2.3087e-03],\n",
+       "                        [-1.0008e-02, -2.0313e-02,  2.9944e-02]],\n",
+       "              \n",
+       "                       ...,\n",
+       "              \n",
+       "                       [[ 5.5781e-02, -2.0898e-02, -3.1487e-03],\n",
+       "                        [-1.6931e-02,  4.5279e-04, -1.5024e-02],\n",
+       "                        [-5.5885e-02,  2.7140e-02, -8.5434e-03]],\n",
+       "              \n",
+       "                       [[ 1.3970e-02, -3.3131e-02,  4.3112e-02],\n",
+       "                        [-3.4956e-02, -5.0144e-02, -1.6391e-02],\n",
+       "                        [-9.1003e-03, -2.0204e-02, -1.0226e-03]],\n",
+       "              \n",
+       "                       [[-4.0053e-02, -5.0194e-02,  5.0405e-02],\n",
+       "                        [ 5.4107e-02,  4.2185e-02,  3.4359e-02],\n",
+       "                        [ 1.6749e-02, -1.4102e-02,  5.0171e-02]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[ 2.8229e-02,  5.6586e-02, -3.9617e-03],\n",
+       "                        [ 2.9538e-02, -1.2507e-02, -2.5516e-02],\n",
+       "                        [-1.5193e-02, -2.9232e-02, -2.0701e-02]],\n",
+       "              \n",
+       "                       [[-5.8773e-02,  3.3015e-02, -9.4146e-03],\n",
+       "                        [ 2.8957e-02,  5.8666e-02,  2.8679e-02],\n",
+       "                        [ 1.5249e-02, -1.2246e-03,  1.2230e-03]],\n",
+       "              \n",
+       "                       [[ 2.6050e-02, -4.6042e-02, -3.4895e-03],\n",
+       "                        [ 4.9529e-02,  6.6835e-03, -4.1808e-02],\n",
+       "                        [-8.6450e-03, -4.8510e-02, -2.4011e-02]],\n",
+       "              \n",
+       "                       ...,\n",
+       "              \n",
+       "                       [[-5.1427e-02,  2.4528e-02, -5.4878e-02],\n",
+       "                        [-1.8610e-02,  5.4365e-02,  3.5053e-03],\n",
+       "                        [-3.9922e-02,  4.2510e-02, -5.7261e-02]],\n",
+       "              \n",
+       "                       [[ 4.1938e-02, -4.2039e-02, -1.2487e-02],\n",
+       "                        [-1.4090e-02, -3.7895e-02,  1.4394e-02],\n",
+       "                        [ 2.2555e-02, -2.7264e-02,  5.6102e-02]],\n",
+       "              \n",
+       "                       [[ 1.5770e-02,  5.4672e-02, -2.4056e-02],\n",
+       "                        [ 5.2089e-02, -2.8859e-02, -2.6499e-03],\n",
+       "                        [-5.2122e-02, -3.7436e-02,  3.9897e-02]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[ 2.7888e-02,  2.9241e-02, -1.9488e-02],\n",
+       "                        [ 2.8928e-02,  5.3312e-02, -2.9810e-02],\n",
+       "                        [-8.5104e-03,  5.7751e-02, -8.1857e-03]],\n",
+       "              \n",
+       "                       [[ 4.2649e-02,  3.3158e-03,  4.2879e-02],\n",
+       "                        [ 7.7893e-03, -3.2879e-02,  2.7630e-02],\n",
+       "                        [ 5.4706e-03,  4.8019e-02,  1.2420e-02]],\n",
+       "              \n",
+       "                       [[-4.2004e-02, -4.2790e-02,  2.4634e-02],\n",
+       "                        [-5.4641e-02,  3.4600e-02,  2.9071e-03],\n",
+       "                        [ 2.6470e-02,  4.6701e-02,  3.7158e-02]],\n",
+       "              \n",
+       "                       ...,\n",
+       "              \n",
+       "                       [[-1.7641e-02, -2.1205e-02, -5.1504e-02],\n",
+       "                        [-7.4737e-03,  5.5061e-02, -2.6397e-03],\n",
+       "                        [-4.4653e-02, -3.6719e-02,  3.4420e-06]],\n",
+       "              \n",
+       "                       [[ 1.6525e-02,  1.7280e-02,  5.4554e-03],\n",
+       "                        [ 4.0098e-02,  2.7571e-02, -4.4965e-02],\n",
+       "                        [ 6.1493e-03, -5.7754e-02,  1.0513e-02]],\n",
+       "              \n",
+       "                       [[-5.7615e-02,  3.2921e-02, -1.5900e-02],\n",
+       "                        [ 2.0081e-02,  5.4590e-02,  1.1296e-02],\n",
+       "                        [-4.5015e-02,  1.1341e-03,  2.6447e-02]]]])),\n",
+       "             ('conv2.bias',\n",
+       "              tensor([-0.0538, -0.0320, -0.0153,  0.0558,  0.0254,  0.0281, -0.0148,  0.0060,\n",
+       "                      -0.0283, -0.0062,  0.0437, -0.0064,  0.0341,  0.0233, -0.0201,  0.0391,\n",
+       "                       0.0243,  0.0071,  0.0125, -0.0138, -0.0377, -0.0169, -0.0475, -0.0004,\n",
+       "                      -0.0105, -0.0502,  0.0241,  0.0090,  0.0069, -0.0315, -0.0192,  0.0204])),\n",
+       "             ('conv3.weight',\n",
+       "              tensor([[[[-0.0215, -0.0208, -0.0272],\n",
+       "                        [-0.0493, -0.0117, -0.0285],\n",
+       "                        [ 0.0515, -0.0041,  0.0126]],\n",
+       "              \n",
+       "                       [[ 0.0299, -0.0301,  0.0552],\n",
+       "                        [ 0.0450,  0.0449, -0.0583],\n",
+       "                        [-0.0452, -0.0480, -0.0275]],\n",
+       "              \n",
+       "                       [[-0.0262, -0.0338,  0.0505],\n",
+       "                        [ 0.0146, -0.0364, -0.0044],\n",
+       "                        [-0.0102, -0.0051,  0.0017]],\n",
+       "              \n",
+       "                       ...,\n",
+       "              \n",
+       "                       [[-0.0367, -0.0468, -0.0586],\n",
+       "                        [ 0.0126,  0.0037,  0.0191],\n",
+       "                        [-0.0153,  0.0048, -0.0160]],\n",
+       "              \n",
+       "                       [[-0.0050,  0.0364,  0.0582],\n",
+       "                        [ 0.0093, -0.0268, -0.0355],\n",
+       "                        [-0.0125,  0.0500,  0.0009]],\n",
+       "              \n",
+       "                       [[ 0.0237, -0.0211, -0.0130],\n",
+       "                        [-0.0489,  0.0118,  0.0387],\n",
+       "                        [-0.0006,  0.0301,  0.0283]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[-0.0391, -0.0464, -0.0158],\n",
+       "                        [ 0.0201, -0.0054,  0.0422],\n",
+       "                        [ 0.0085, -0.0474, -0.0251]],\n",
+       "              \n",
+       "                       [[-0.0346,  0.0536, -0.0391],\n",
+       "                        [ 0.0244, -0.0263, -0.0073],\n",
+       "                        [ 0.0076,  0.0160,  0.0044]],\n",
+       "              \n",
+       "                       [[-0.0128,  0.0146, -0.0381],\n",
+       "                        [-0.0277, -0.0142,  0.0226],\n",
+       "                        [ 0.0190,  0.0326, -0.0219]],\n",
+       "              \n",
+       "                       ...,\n",
+       "              \n",
+       "                       [[-0.0217,  0.0129,  0.0558],\n",
+       "                        [ 0.0164, -0.0292, -0.0467],\n",
+       "                        [-0.0296,  0.0205, -0.0300]],\n",
+       "              \n",
+       "                       [[ 0.0254, -0.0151, -0.0583],\n",
+       "                        [ 0.0111, -0.0469, -0.0300],\n",
+       "                        [-0.0462,  0.0293, -0.0351]],\n",
+       "              \n",
+       "                       [[ 0.0401, -0.0251,  0.0160],\n",
+       "                        [-0.0160, -0.0195, -0.0065],\n",
+       "                        [-0.0519,  0.0351,  0.0357]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[ 0.0544, -0.0209, -0.0454],\n",
+       "                        [ 0.0287, -0.0205, -0.0294],\n",
+       "                        [-0.0195, -0.0235, -0.0378]],\n",
+       "              \n",
+       "                       [[-0.0294, -0.0380,  0.0301],\n",
+       "                        [ 0.0360,  0.0367,  0.0458],\n",
+       "                        [-0.0189, -0.0017,  0.0145]],\n",
+       "              \n",
+       "                       [[-0.0297,  0.0567,  0.0276],\n",
+       "                        [ 0.0298,  0.0383,  0.0227],\n",
+       "                        [ 0.0262,  0.0063,  0.0131]],\n",
+       "              \n",
+       "                       ...,\n",
+       "              \n",
+       "                       [[ 0.0002,  0.0432, -0.0247],\n",
+       "                        [ 0.0068, -0.0298, -0.0484],\n",
+       "                        [-0.0361, -0.0014,  0.0444]],\n",
+       "              \n",
+       "                       [[-0.0184, -0.0201, -0.0163],\n",
+       "                        [-0.0466,  0.0255, -0.0244],\n",
+       "                        [ 0.0283,  0.0149, -0.0588]],\n",
+       "              \n",
+       "                       [[ 0.0323,  0.0392, -0.0254],\n",
+       "                        [ 0.0560,  0.0137, -0.0401],\n",
+       "                        [-0.0236,  0.0589,  0.0448]]],\n",
+       "              \n",
+       "              \n",
+       "                      ...,\n",
+       "              \n",
+       "              \n",
+       "                      [[[ 0.0213,  0.0204,  0.0574],\n",
+       "                        [-0.0276, -0.0196,  0.0117],\n",
+       "                        [ 0.0569, -0.0158, -0.0502]],\n",
+       "              \n",
+       "                       [[ 0.0452, -0.0038,  0.0502],\n",
+       "                        [ 0.0428, -0.0398, -0.0486],\n",
+       "                        [ 0.0130,  0.0563,  0.0576]],\n",
+       "              \n",
+       "                       [[ 0.0484, -0.0535,  0.0048],\n",
+       "                        [ 0.0268, -0.0290, -0.0390],\n",
+       "                        [ 0.0189, -0.0194, -0.0588]],\n",
+       "              \n",
+       "                       ...,\n",
+       "              \n",
+       "                       [[ 0.0163, -0.0113, -0.0520],\n",
+       "                        [ 0.0288, -0.0547, -0.0544],\n",
+       "                        [ 0.0442,  0.0376,  0.0566]],\n",
+       "              \n",
+       "                       [[-0.0343,  0.0569,  0.0438],\n",
+       "                        [-0.0403, -0.0372, -0.0532],\n",
+       "                        [ 0.0322,  0.0126,  0.0423]],\n",
+       "              \n",
+       "                       [[ 0.0577,  0.0136, -0.0480],\n",
+       "                        [-0.0293, -0.0348,  0.0342],\n",
+       "                        [-0.0510, -0.0078, -0.0042]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[-0.0243,  0.0406,  0.0537],\n",
+       "                        [ 0.0209, -0.0059, -0.0487],\n",
+       "                        [-0.0425,  0.0339,  0.0444]],\n",
+       "              \n",
+       "                       [[ 0.0465, -0.0467,  0.0461],\n",
+       "                        [-0.0389,  0.0144, -0.0502],\n",
+       "                        [ 0.0274,  0.0552,  0.0356]],\n",
+       "              \n",
+       "                       [[-0.0289,  0.0474, -0.0217],\n",
+       "                        [ 0.0472, -0.0135,  0.0164],\n",
+       "                        [-0.0165, -0.0049, -0.0475]],\n",
+       "              \n",
+       "                       ...,\n",
+       "              \n",
+       "                       [[-0.0377,  0.0267,  0.0367],\n",
+       "                        [ 0.0111,  0.0114, -0.0329],\n",
+       "                        [ 0.0031, -0.0223, -0.0280]],\n",
+       "              \n",
+       "                       [[-0.0500, -0.0529,  0.0116],\n",
+       "                        [ 0.0483,  0.0121, -0.0149],\n",
+       "                        [ 0.0328,  0.0201,  0.0402]],\n",
+       "              \n",
+       "                       [[ 0.0463,  0.0157, -0.0332],\n",
+       "                        [ 0.0150,  0.0479,  0.0461],\n",
+       "                        [ 0.0275,  0.0506, -0.0466]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[-0.0478,  0.0274,  0.0500],\n",
+       "                        [-0.0394,  0.0032, -0.0496],\n",
+       "                        [ 0.0381,  0.0391,  0.0330]],\n",
+       "              \n",
+       "                       [[ 0.0184, -0.0560, -0.0345],\n",
+       "                        [-0.0459, -0.0215,  0.0452],\n",
+       "                        [ 0.0049,  0.0537,  0.0544]],\n",
+       "              \n",
+       "                       [[-0.0413, -0.0084,  0.0585],\n",
+       "                        [ 0.0338, -0.0067, -0.0113],\n",
+       "                        [-0.0187, -0.0234, -0.0525]],\n",
+       "              \n",
+       "                       ...,\n",
+       "              \n",
+       "                       [[-0.0389,  0.0325, -0.0538],\n",
+       "                        [ 0.0118,  0.0509,  0.0352],\n",
+       "                        [-0.0351, -0.0341, -0.0506]],\n",
+       "              \n",
+       "                       [[ 0.0136, -0.0349,  0.0082],\n",
+       "                        [ 0.0358, -0.0211,  0.0537],\n",
+       "                        [-0.0183,  0.0390, -0.0267]],\n",
+       "              \n",
+       "                       [[-0.0219, -0.0145, -0.0351],\n",
+       "                        [ 0.0556,  0.0033, -0.0030],\n",
+       "                        [ 0.0075, -0.0425, -0.0365]]]])),\n",
+       "             ('conv3.bias',\n",
+       "              tensor([ 0.0304,  0.0099, -0.0004,  0.0334,  0.0301,  0.0491,  0.0530, -0.0432,\n",
+       "                      -0.0127, -0.0549, -0.0419,  0.0159, -0.0284,  0.0295, -0.0148,  0.0275,\n",
+       "                       0.0554, -0.0056,  0.0389, -0.0264, -0.0383,  0.0126,  0.0320,  0.0312,\n",
+       "                       0.0018,  0.0560, -0.0329, -0.0155, -0.0391, -0.0539, -0.0571, -0.0254])),\n",
+       "             ('conv4.weight',\n",
+       "              tensor([[[[-3.8911e-02, -3.4220e-02,  4.2567e-03],\n",
+       "                        [-4.5321e-02, -5.2531e-02, -8.1722e-03],\n",
+       "                        [-2.2638e-02,  4.4213e-02,  5.6989e-02]],\n",
+       "              \n",
+       "                       [[ 1.8417e-03, -1.4453e-02,  4.9892e-02],\n",
+       "                        [ 5.7762e-02,  9.6610e-03, -3.9509e-02],\n",
+       "                        [ 3.3795e-02,  5.0409e-02, -5.8834e-02]],\n",
+       "              \n",
+       "                       [[ 4.6645e-03, -1.6286e-02,  4.3410e-02],\n",
+       "                        [-3.4043e-02, -2.2207e-02,  4.0967e-02],\n",
+       "                        [ 5.3004e-02, -2.2756e-02, -6.7993e-03]],\n",
+       "              \n",
+       "                       ...,\n",
+       "              \n",
+       "                       [[ 1.1741e-02, -5.5062e-02, -3.3625e-02],\n",
+       "                        [-9.2320e-03, -3.3036e-02,  3.3196e-02],\n",
+       "                        [ 2.3940e-02,  2.0442e-02,  1.4183e-02]],\n",
+       "              \n",
+       "                       [[-2.7139e-02, -3.4129e-03, -1.0090e-02],\n",
+       "                        [ 1.3073e-02, -1.6998e-02, -4.7540e-02],\n",
+       "                        [-2.5758e-02, -1.9363e-02,  2.1905e-02]],\n",
+       "              \n",
+       "                       [[ 5.7593e-02,  1.5013e-02, -5.7894e-02],\n",
+       "                        [ 5.7964e-02, -2.3412e-02,  2.6955e-02],\n",
+       "                        [-3.9814e-02, -4.6015e-02, -5.3240e-02]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[ 5.8211e-02, -4.1118e-02,  2.7704e-02],\n",
+       "                        [ 5.7198e-02,  8.4165e-03, -5.1708e-02],\n",
+       "                        [ 3.1423e-02,  1.5026e-02,  3.5922e-02]],\n",
+       "              \n",
+       "                       [[ 8.8858e-03,  3.2818e-02,  5.4486e-02],\n",
+       "                        [-2.6636e-02,  2.2604e-02,  2.9531e-02],\n",
+       "                        [-1.0327e-03,  2.2348e-03,  2.4103e-02]],\n",
+       "              \n",
+       "                       [[ 3.8683e-02, -5.0057e-03,  5.0224e-02],\n",
+       "                        [ 3.5756e-02, -2.7295e-02, -2.2854e-02],\n",
+       "                        [-3.2043e-02, -3.2415e-02,  4.1034e-02]],\n",
+       "              \n",
+       "                       ...,\n",
+       "              \n",
+       "                       [[-1.9791e-02,  4.3243e-02, -3.5177e-02],\n",
+       "                        [ 2.4554e-02,  4.2845e-03,  4.8009e-02],\n",
+       "                        [ 2.4897e-03,  3.9550e-02, -3.0833e-02]],\n",
+       "              \n",
+       "                       [[ 4.5807e-02,  6.5845e-03,  9.3362e-05],\n",
+       "                        [-1.9411e-02, -2.9161e-02,  5.0828e-02],\n",
+       "                        [ 1.2028e-03,  2.1260e-02, -4.3710e-03]],\n",
+       "              \n",
+       "                       [[-4.8702e-02, -2.0571e-02, -3.5162e-02],\n",
+       "                        [-2.5856e-02, -2.6619e-02,  8.1867e-03],\n",
+       "                        [-2.7671e-02, -9.6651e-03, -5.3279e-02]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[-5.6432e-02,  2.3722e-02, -2.1750e-02],\n",
+       "                        [-4.8247e-03, -2.1226e-02, -1.0829e-02],\n",
+       "                        [-1.9523e-02, -1.8187e-02,  2.2772e-03]],\n",
+       "              \n",
+       "                       [[-1.0907e-02,  3.3984e-02, -1.2088e-02],\n",
+       "                        [-1.8657e-02, -4.8297e-02,  2.3614e-02],\n",
+       "                        [-3.9670e-02,  6.1733e-03, -2.9168e-02]],\n",
+       "              \n",
+       "                       [[-4.2112e-02,  2.8203e-02, -1.7385e-03],\n",
+       "                        [-2.5282e-02, -9.4592e-05,  6.5093e-03],\n",
+       "                        [-4.1745e-02,  4.3988e-03, -1.1622e-02]],\n",
+       "              \n",
+       "                       ...,\n",
+       "              \n",
+       "                       [[ 4.2991e-02,  1.6912e-02, -4.3689e-02],\n",
+       "                        [ 5.1871e-02,  4.8566e-02,  3.6205e-02],\n",
+       "                        [-3.2016e-02, -1.3596e-02, -2.7950e-02]],\n",
+       "              \n",
+       "                       [[-2.8307e-02, -4.0278e-02,  1.5087e-02],\n",
+       "                        [ 4.0443e-02, -3.5727e-02,  3.7196e-02],\n",
+       "                        [-1.4194e-02, -2.7319e-02, -5.1305e-02]],\n",
+       "              \n",
+       "                       [[-2.9962e-02,  2.4693e-02, -4.4912e-02],\n",
+       "                        [ 5.5890e-03,  4.6671e-02,  3.3599e-02],\n",
+       "                        [-3.9949e-02, -4.4716e-02,  2.2345e-02]]],\n",
+       "              \n",
+       "              \n",
+       "                      ...,\n",
+       "              \n",
+       "              \n",
+       "                      [[[ 3.1920e-02, -4.9932e-02, -1.0871e-02],\n",
+       "                        [-3.7500e-02,  4.1638e-02, -1.3246e-02],\n",
+       "                        [ 1.6447e-02, -5.6741e-02, -3.7524e-02]],\n",
+       "              \n",
+       "                       [[ 3.3903e-02, -3.1321e-02, -4.4877e-02],\n",
+       "                        [-2.2473e-02, -2.4225e-02,  4.5838e-02],\n",
+       "                        [-2.0069e-02,  3.8338e-02,  5.8010e-02]],\n",
+       "              \n",
+       "                       [[ 1.7602e-02, -5.2530e-02,  4.9331e-02],\n",
+       "                        [ 2.4509e-02,  2.3943e-02, -2.1774e-02],\n",
+       "                        [-5.7154e-02,  5.7090e-02,  3.7531e-02]],\n",
+       "              \n",
+       "                       ...,\n",
+       "              \n",
+       "                       [[ 2.8630e-02, -4.8644e-04, -5.3822e-02],\n",
+       "                        [-1.1102e-02,  4.8524e-02, -2.7142e-02],\n",
+       "                        [-5.3463e-02,  3.5607e-02, -1.1110e-02]],\n",
+       "              \n",
+       "                       [[ 4.7891e-02, -3.4098e-02,  3.6984e-02],\n",
+       "                        [ 3.9062e-02, -5.1119e-03, -3.3252e-02],\n",
+       "                        [ 5.5029e-02,  3.1092e-03, -5.9391e-03]],\n",
+       "              \n",
+       "                       [[ 9.6775e-03,  2.2903e-02, -1.5971e-02],\n",
+       "                        [-2.6969e-02,  8.5069e-04, -2.5744e-02],\n",
+       "                        [ 2.7311e-03, -2.2119e-02,  2.4367e-03]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[ 1.5206e-02, -5.3504e-02, -1.8013e-02],\n",
+       "                        [-3.0064e-02,  3.2887e-02,  1.7612e-02],\n",
+       "                        [-4.2775e-02, -2.7335e-02,  5.1532e-02]],\n",
+       "              \n",
+       "                       [[ 4.3325e-02,  1.2909e-03,  5.6831e-02],\n",
+       "                        [ 4.8283e-03, -4.4274e-02,  1.7624e-02],\n",
+       "                        [-7.2574e-03,  1.3743e-02, -5.1502e-02]],\n",
+       "              \n",
+       "                       [[-1.4299e-02, -3.0024e-02,  4.9578e-02],\n",
+       "                        [-4.6143e-02, -3.0686e-02, -5.0727e-02],\n",
+       "                        [ 5.4766e-02, -1.4012e-02,  3.3267e-02]],\n",
+       "              \n",
+       "                       ...,\n",
+       "              \n",
+       "                       [[ 5.2773e-02,  4.0294e-02, -2.6113e-02],\n",
+       "                        [-2.5069e-02,  4.3956e-02, -4.7841e-02],\n",
+       "                        [ 3.0924e-02,  1.5174e-02, -4.8323e-02]],\n",
+       "              \n",
+       "                       [[-2.0325e-02,  3.2666e-02,  2.5174e-02],\n",
+       "                        [ 5.3775e-03, -3.2712e-02, -5.2251e-02],\n",
+       "                        [-2.7426e-02, -4.5502e-04, -3.2174e-02]],\n",
+       "              \n",
+       "                       [[ 8.5780e-03, -5.2099e-02,  5.7285e-02],\n",
+       "                        [-5.0897e-02, -5.3995e-03,  4.2719e-02],\n",
+       "                        [-5.2257e-02, -7.9682e-03,  2.1848e-02]]],\n",
+       "              \n",
+       "              \n",
+       "                      [[[-2.8660e-02,  1.2392e-02,  2.9940e-02],\n",
+       "                        [-1.1170e-02, -1.7499e-02, -5.3951e-02],\n",
+       "                        [ 4.9738e-02,  2.7240e-02,  2.9588e-03]],\n",
+       "              \n",
+       "                       [[ 3.7728e-02, -3.1084e-02,  4.6459e-02],\n",
+       "                        [-1.4961e-02,  1.6951e-02, -1.3976e-02],\n",
+       "                        [ 2.5768e-02, -3.8991e-02, -2.8738e-02]],\n",
+       "              \n",
+       "                       [[ 1.6084e-02, -7.3174e-03, -4.5839e-02],\n",
+       "                        [-3.6029e-02,  1.5303e-02, -5.4380e-02],\n",
+       "                        [ 2.1913e-02, -4.4792e-02,  4.7973e-02]],\n",
+       "              \n",
+       "                       ...,\n",
+       "              \n",
+       "                       [[ 5.4945e-02, -4.5541e-02, -1.5806e-02],\n",
+       "                        [-4.5216e-02, -2.0338e-02,  3.3373e-02],\n",
+       "                        [-1.8431e-02,  4.3953e-02, -4.8196e-02]],\n",
+       "              \n",
+       "                       [[-3.6129e-02,  5.7705e-02, -1.2229e-02],\n",
+       "                        [-5.2801e-02,  7.0930e-03, -2.6721e-02],\n",
+       "                        [ 3.6720e-02,  7.5540e-04,  5.5401e-02]],\n",
+       "              \n",
+       "                       [[-4.0386e-02,  1.6714e-02,  2.9246e-02],\n",
+       "                        [ 4.7033e-02, -8.7923e-03,  5.1267e-02],\n",
+       "                        [-3.1211e-02, -2.7036e-02,  4.0346e-02]]]])),\n",
+       "             ('conv4.bias',\n",
+       "              tensor([-0.0368, -0.0535, -0.0568, -0.0364,  0.0167,  0.0255,  0.0031,  0.0562,\n",
+       "                      -0.0575,  0.0470, -0.0128,  0.0159, -0.0387,  0.0517, -0.0508, -0.0104,\n",
+       "                       0.0133, -0.0265,  0.0290,  0.0147, -0.0405,  0.0074,  0.0434, -0.0436,\n",
+       "                      -0.0295,  0.0417,  0.0324,  0.0316,  0.0143, -0.0312,  0.0463,  0.0441])),\n",
+       "             ('fc1.weight',\n",
+       "              tensor([[-0.0206, -0.0161, -0.0053,  ..., -0.0257, -0.0314, -0.0105],\n",
+       "                      [-0.0229,  0.0030,  0.0239,  ...,  0.0199, -0.0279,  0.0155]])),\n",
+       "             ('fc1.bias', tensor([-0.5000, -1.2000]))])"
+      ]
+     },
+     "execution_count": 28,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
    "source": [
-    "o1.param_groups"
+    "m2 = m1.state_dict()\n",
+    "m2['fc1.bias'] = torch.Tensor([-0.5, -1.2])\n",
+    "m1.load_state_dict(m2)\n",
+    "m1.state_dict()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 29,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{'state': {},\n",
+       " 'param_groups': [{'lr': 0.6,\n",
+       "   'momentum': 0,\n",
+       "   'dampening': 0,\n",
+       "   'weight_decay': 0,\n",
+       "   'nesterov': False,\n",
+       "   'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}]}"
+      ]
+     },
+     "execution_count": 29,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "o1.state_dict()"
    ]
   },
   {
@@ -3698,6 +5636,59 @@
     "torch.tensor([1,2,3]).shape[0]"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "sizes:  [0.2857142857142857, 0.2857142857142857, 0.4285714285714286]\n",
+      "2 3 \n",
+      "6 5 \n",
+      "1 4 7 \n"
+     ]
+    }
+   ],
+   "source": [
+    "from decentralizepy.datasets.Partitioner import DataPartitioner\n",
+    "l = [1, 2, 3, 4, 5, 6, 7]\n",
+    "e = len(l) // 3\n",
+    "frac = e / len(l)\n",
+    "sizes = [frac] * 3\n",
+    "sizes[-1] += 1.0 - frac * 3\n",
+    "print(\"sizes: \", sizes)\n",
+    "\n",
+    "for i in range(3):\n",
+    "    myPar = DataPartitioner(l, sizes).use(i)\n",
+    "    for j in range(len(myPar)):\n",
+    "        print(myPar.__getitem__(j), end=' ')\n",
+    "    print()\n",
+    "    \n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "0.5\n"
+     ]
+    }
+   ],
+   "source": [
+    "w = 1\n",
+    "p = 2\n",
+    "i = w/p\n",
+    "print(i)"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
diff --git a/src/decentralizepy/communication/TCP.py b/src/decentralizepy/communication/TCP.py
index 54a1af8..e78a3b1 100644
--- a/src/decentralizepy/communication/TCP.py
+++ b/src/decentralizepy/communication/TCP.py
@@ -15,6 +15,7 @@ class TCP(Communication):
     TCP Communication API
 
     """
+
     def addr(self, rank, machine_id):
         """
         Returns TCP address of the process.
diff --git a/src/decentralizepy/datasets/Celeba.py b/src/decentralizepy/datasets/Celeba.py
index 0b43268..baf3434 100644
--- a/src/decentralizepy/datasets/Celeba.py
+++ b/src/decentralizepy/datasets/Celeba.py
@@ -138,10 +138,13 @@ class Celeba(Dataset):
                 os.path.join(self.train_dir, cur_file)
             )
             for cur_client in clients:
+                logging.debug("Got data of client: {}".format(cur_client))
                 self.clients.append(cur_client)
                 my_train_data["x"].extend(self.process_x(train_data[cur_client]["x"]))
                 my_train_data["y"].extend(train_data[cur_client]["y"])
                 self.num_samples.append(len(train_data[cur_client]["y"]))
+
+        logging.debug("Initial shape of x: {}".format(np.array(my_train_data["x"], dtype=np.dtype("float32")).shape))
         self.train_x = (
             np.array(my_train_data["x"], dtype=np.dtype("float32"))
             .reshape(-1, IMAGE_DIM, IMAGE_DIM, CHANNELS)
@@ -409,6 +412,7 @@ class CNN(Model):
     Class for a CNN Model for Celeba
 
     """
+
     def __init__(self):
         """
         Constructor. Instantiates the CNN Model
diff --git a/src/decentralizepy/datasets/Femnist.py b/src/decentralizepy/datasets/Femnist.py
index 5d0dc27..c0360c5 100644
--- a/src/decentralizepy/datasets/Femnist.py
+++ b/src/decentralizepy/datasets/Femnist.py
@@ -413,6 +413,7 @@ class CNN(Model):
     Class for a CNN Model for FEMNIST
 
     """
+
     def __init__(self):
         """
         Constructor. Instantiates the CNN Model
diff --git a/src/decentralizepy/models/Model.py b/src/decentralizepy/models/Model.py
index da83402..cf073bc 100644
--- a/src/decentralizepy/models/Model.py
+++ b/src/decentralizepy/models/Model.py
@@ -7,6 +7,7 @@ class Model(nn.Module):
     More fields can be added here
 
     """
+
     def __init__(self):
         """
         Constructor
diff --git a/src/decentralizepy/node/Node.py b/src/decentralizepy/node/Node.py
index 183fbfe..eb7200a 100644
--- a/src/decentralizepy/node/Node.py
+++ b/src/decentralizepy/node/Node.py
@@ -42,42 +42,20 @@ class Node:
         plt.title(title)
         plt.savefig(filename)
 
-    def instantiate(
-        self,
-        rank: int,
-        machine_id: int,
-        mapping: Mapping,
-        graph: Graph,
-        config,
-        iterations=1,
-        log_dir=".",
-        log_level=logging.INFO,
-        test_after=5,
-        *args
-    ):
+    def init_log(self, log_dir, rank, log_level, force=True):
         """
-        Construct objects.
+        Instantiate Logging.
 
         Parameters
         ----------
-        rank : int
-            Rank of process local to the machine
-        machine_id : int
-            Machine ID on which the process in running
-        n_procs_local : int
-            Number of processes on current machine
-        mapping : decentralizepy.mappings
-            The object containing the mapping rank <--> uid
-        graph : decentralizepy.graphs
-            The object containing the global graph
-        config : dict
-            A dictionary of configurations.
         log_dir : str
             Logging directory
+        rank : rank : int
+            Rank of process local to the machine
         log_level : logging.Level
             One of DEBUG, INFO, WARNING, ERROR, CRITICAL
-        args : optional
-            Other arguments
+        force : bool
+            Argument to logging.basicConfig()
 
         """
         log_file = os.path.join(log_dir, str(rank) + ".log")
@@ -88,20 +66,49 @@ class Node:
             force=True,
         )
 
-        logging.info("Started process.")
+    def cache_fields(
+        self, rank, machine_id, mapping, graph, iterations, log_dir, test_after
+    ):
+        """
+        Instantiate object field with arguments.
+
+        Parameters
+        ----------
+        rank : int
+            Rank of process local to the machine
+        machine_id : int
+            Machine ID on which the process in running
+        mapping : decentralizepy.mappings
+            The object containing the mapping rank <--> uid
+        graph : decentralizepy.graphs
+            The object containing the global graph
+        log_dir : str
+            Logging directory
 
+        """
         self.rank = rank
         self.machine_id = machine_id
         self.graph = graph
         self.mapping = mapping
         self.uid = self.mapping.get_uid(rank, machine_id)
         self.log_dir = log_dir
+        self.iterations = iterations
+        self.test_after = test_after
 
         logging.debug("Rank: %d", self.rank)
         logging.debug("type(graph): %s", str(type(self.rank)))
         logging.debug("type(mapping): %s", str(type(self.mapping)))
 
-        dataset_configs = config["DATASET"]
+    def init_dataset_model(self, dataset_configs):
+        """
+        Instantiate dataset and model from config.
+
+        Parameters
+        ----------
+        dataset_configs : dict
+            Python dict containing dataset config params
+
+        """
         dataset_module = importlib.import_module(dataset_configs["dataset_package"])
         self.dataset_class = getattr(dataset_module, dataset_configs["dataset_class"])
         self.dataset_params = utils.remove_keys(
@@ -116,7 +123,16 @@ class Node:
         self.model_class = getattr(dataset_module, dataset_configs["model_class"])
         self.model = self.model_class()
 
-        optimizer_configs = config["OPTIMIZER_PARAMS"]
+    def init_optimizer(self, optimizer_configs):
+        """
+        Instantiate optimizer from config.
+
+        Parameters
+        ----------
+        optimizer_configs : dict
+            Python dict containing optimizer config params
+
+        """
         optimizer_module = importlib.import_module(
             optimizer_configs["optimizer_package"]
         )
@@ -130,7 +146,16 @@ class Node:
             self.model.parameters(), **self.optimizer_params
         )
 
-        train_configs = config["TRAIN_PARAMS"]
+    def init_trainer(self, train_configs):
+        """
+        Instantiate training module and loss from config.
+
+        Parameters
+        ----------
+        train_configs : dict
+            Python dict containing training config params
+
+        """
         train_module = importlib.import_module(train_configs["training_package"])
         train_class = getattr(train_module, train_configs["training_class"])
 
@@ -155,7 +180,16 @@ class Node:
             self.model, self.optimizer, self.loss, **train_params
         )
 
-        comm_configs = config["COMMUNICATION"]
+    def init_comm(self, comm_configs):
+        """
+        Instantiate communication module from config.
+
+        Parameters
+        ----------
+        comm_configs : dict
+            Python dict containing communication config params
+
+        """
         comm_module = importlib.import_module(comm_configs["comm_package"])
         comm_class = getattr(comm_module, comm_configs["comm_class"])
         comm_params = utils.remove_keys(comm_configs, ["comm_package", "comm_class"])
@@ -163,7 +197,16 @@ class Node:
             self.rank, self.machine_id, self.mapping, self.graph.n_procs, **comm_params
         )
 
-        sharing_configs = config["SHARING"]
+    def init_sharing(self, sharing_configs):
+        """
+        Instantiate sharing module from config.
+
+        Parameters
+        ----------
+        sharing_configs : dict
+            Python dict containing sharing config params
+
+        """
         sharing_package = importlib.import_module(sharing_configs["sharing_package"])
         sharing_class = getattr(sharing_package, sharing_configs["sharing_class"])
         sharing_params = utils.remove_keys(
@@ -181,9 +224,53 @@ class Node:
             **sharing_params
         )
 
-        self.iterations = iterations
-        self.test_after = test_after
-        self.log_dir = log_dir
+    def instantiate(
+        self,
+        rank: int,
+        machine_id: int,
+        mapping: Mapping,
+        graph: Graph,
+        config,
+        iterations=1,
+        log_dir=".",
+        log_level=logging.INFO,
+        test_after=5,
+        *args
+    ):
+        """
+        Construct objects.
+
+        Parameters
+        ----------
+        rank : int
+            Rank of process local to the machine
+        machine_id : int
+            Machine ID on which the process in running
+        mapping : decentralizepy.mappings
+            The object containing the mapping rank <--> uid
+        graph : decentralizepy.graphs
+            The object containing the global graph
+        config : dict
+            A dictionary of configurations.
+        log_dir : str
+            Logging directory
+        log_level : logging.Level
+            One of DEBUG, INFO, WARNING, ERROR, CRITICAL
+        args : optional
+            Other arguments
+
+        """
+        logging.info("Started process.")
+
+        self.cache_fields(
+            rank, machine_id, mapping, graph, iterations, log_dir, test_after
+        )
+        self.init_log(log_dir, rank, log_level)
+        self.init_dataset_model(config["DATASET"])
+        self.init_optimizer(config["OPTIMIZER_PARAMS"])
+        self.init_trainer(config["TRAIN_PARAMS"])
+        self.init_comm(config["COMMUNICATION"])
+        self.init_sharing(config["SHARING"])
 
     def run(self):
         """
diff --git a/src/decentralizepy/sharing/GrowingAlpha.py b/src/decentralizepy/sharing/GrowingAlpha.py
index 4979828..7fe7bf5 100644
--- a/src/decentralizepy/sharing/GrowingAlpha.py
+++ b/src/decentralizepy/sharing/GrowingAlpha.py
@@ -8,6 +8,7 @@ class GrowingAlpha(PartialModel):
     This class implements the basic growing partial model sharing using a linear function.
 
     """
+
     def __init__(
         self,
         rank,
diff --git a/src/decentralizepy/sharing/PartialModel.py b/src/decentralizepy/sharing/PartialModel.py
index 7addc0e..7eef218 100644
--- a/src/decentralizepy/sharing/PartialModel.py
+++ b/src/decentralizepy/sharing/PartialModel.py
@@ -13,6 +13,7 @@ class PartialModel(Sharing):
     This class implements the vanilla version of partial model sharing.
 
     """
+
     def __init__(
         self,
         rank,
diff --git a/src/decentralizepy/training/GradientAccumulator.py b/src/decentralizepy/training/GradientAccumulator.py
index e4feff2..718e793 100644
--- a/src/decentralizepy/training/GradientAccumulator.py
+++ b/src/decentralizepy/training/GradientAccumulator.py
@@ -8,6 +8,7 @@ class GradientAccumulator(Training):
     This class implements the training module which also accumulates gradients of steps in a list.
 
     """
+
     def __init__(
         self,
         model,
diff --git a/src/decentralizepy/training/Training.py b/src/decentralizepy/training/Training.py
index 52a8e9d..bd80773 100644
--- a/src/decentralizepy/training/Training.py
+++ b/src/decentralizepy/training/Training.py
@@ -109,7 +109,7 @@ class Training:
         self.optimizer.step()
         return loss_val.item()
 
-    def train_full(self, trainset):
+    def train_full(self, dataset):
         """
         One training iteration, goes through the entire dataset
 
@@ -120,9 +120,12 @@ class Training:
 
         """
         for epoch in range(self.rounds):
+            trainset = dataset.get_trainset(self.batch_size, self.shuffle)
             epoch_loss = 0.0
             count = 0
             for data, target in trainset:
+                logging.info("Starting minibatch {} with num_samples: {}".format(count, len(data)))
+                logging.info("Classes: {}".format(target))
                 epoch_loss += self.trainstep(data, target)
                 count += 1
             logging.info("Epoch: {} loss: {}".format(epoch, epoch_loss / count))
@@ -137,13 +140,13 @@ class Training:
             The training dataset. Should implement get_trainset(batch_size, shuffle)
 
         """
-        trainset = dataset.get_trainset(self.batch_size, self.shuffle)
 
         if self.full_epochs:
-            self.train_full(trainset)
+            self.train_full(dataset)
         else:
             iter_loss = 0.0
             count = 0
+            trainset = dataset.get_trainset(self.batch_size, self.shuffle)
             while count < self.rounds:
                 for data, target in trainset:
                     iter_loss += self.trainstep(data, target)
diff --git a/src/decentralizepy/utils.py b/src/decentralizepy/utils.py
index eac1e17..c6bf149 100644
--- a/src/decentralizepy/utils.py
+++ b/src/decentralizepy/utils.py
@@ -16,7 +16,7 @@ def conditional_value(var, nul, default):
         The null value. Assigns default if var == nul
     default : any
         The default value
-    
+
     Returns
     -------
     type(var)
-- 
GitLab