From b0cfe04eea2afca02dccbe9af523672f6853c605 Mon Sep 17 00:00:00 2001 From: Rishi Sharma <rishi.sharma@epfl.ch> Date: Thu, 27 Jan 2022 13:24:11 +0100 Subject: [PATCH] Improve Node --- README.rst | 2 +- eval/main.ipynb | 2009 ++++++++++++++++- src/decentralizepy/communication/TCP.py | 1 + src/decentralizepy/datasets/Celeba.py | 4 + src/decentralizepy/datasets/Femnist.py | 1 + src/decentralizepy/models/Model.py | 1 + src/decentralizepy/node/Node.py | 161 +- src/decentralizepy/sharing/GrowingAlpha.py | 1 + src/decentralizepy/sharing/PartialModel.py | 1 + .../training/GradientAccumulator.py | 1 + src/decentralizepy/training/Training.py | 9 +- src/decentralizepy/utils.py | 2 +- 12 files changed, 2142 insertions(+), 51 deletions(-) diff --git a/README.rst b/README.rst index d1ef408..38cba02 100644 --- a/README.rst +++ b/README.rst @@ -15,7 +15,7 @@ Setting up decentralizepy pip3 install --upgrade pip pip install --upgrade pip -* Install decentralizepy for development/ :: +* Install decentralizepy for development. :: pip3 install --editable .\[dev\] diff --git a/eval/main.ipynb b/eval/main.ipynb index e25d47c..db66e89 100644 --- a/eval/main.ipynb +++ b/eval/main.ipynb @@ -268,33 +268,1971 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 24, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CNN(\n", + " (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)\n", + " (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", + " (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)\n", + " (conv3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)\n", + " (conv4): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)\n", + " (fc1): Linear(in_features=800, out_features=2, bias=True)\n", + ")\n" + ] + } + ], "source": [ "from torch import multiprocessing as mp\n", "import torch\n", - "m1 = torch.nn.Linear(1,1)\n", + "from decentralizepy.datasets.Celeba import CNN\n", + "import numpy as np\n", + "m1 = CNN()\n", "o1 = torch.optim.SGD(m1.parameters(), 0.6)\n", "print(m1)\n", "\n", "\n", - "mp.spawn(fn = f, nprocs = 2, args=[m1, o1])\n", + "#mp.spawn(fn = f, nprocs = 2, args=[m1, o1])\n", "\n" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": 25, "metadata": {}, - "source": [] + "outputs": [ + { + "data": { + "text/plain": [ + "OrderedDict([('conv1.weight',\n", + " tensor([[[[-0.1261, 0.1833, -0.1406],\n", + " [ 0.1324, -0.0685, 0.0938],\n", + " [ 0.0432, 0.1814, -0.0541]],\n", + " \n", + " [[-0.1776, -0.1839, -0.0111],\n", + " [ 0.0888, 0.0888, -0.1344],\n", + " [-0.1838, 0.1737, 0.1584]],\n", + " \n", + " [[ 0.0417, 0.1064, -0.0156],\n", + " [ 0.0667, 0.0856, -0.1746],\n", + " [ 0.0412, 0.1620, 0.0125]]],\n", + " \n", + " \n", + " [[[-0.0530, -0.1273, -0.0797],\n", + " [ 0.0422, 0.1135, 0.0475],\n", + " [-0.0244, 0.1691, -0.1383]],\n", + " \n", + " [[ 0.0822, -0.1317, -0.1692],\n", + " [ 0.1373, 0.1388, 0.0103],\n", + " [-0.0481, 0.1105, 0.0631]],\n", + " \n", + " [[-0.0352, 0.1259, -0.0530],\n", + " [-0.1394, -0.0281, 0.1844],\n", + " [ 0.0082, 0.1187, 0.0211]]],\n", + " \n", + " \n", + " [[[ 0.0987, 0.0788, -0.1126],\n", + " [ 0.1769, 0.0763, -0.1767],\n", + " [-0.0570, 0.1156, 0.1770]],\n", + " \n", + " [[ 0.0643, -0.0024, -0.0625],\n", + " [ 0.0819, 0.0140, -0.1882],\n", + " [ 0.1325, -0.0632, -0.0202]],\n", + " \n", + " [[ 0.0053, 0.1042, -0.0058],\n", + " [-0.1082, -0.1753, 0.1762],\n", + " [-0.0501, 0.1166, 0.0561]]],\n", + " \n", + " \n", + " [[[ 0.0358, -0.0685, -0.1278],\n", + " [ 0.0029, -0.1107, 0.1169],\n", + " [-0.1408, 0.1293, 0.1142]],\n", + " \n", + " [[-0.0814, 0.0470, 0.0188],\n", + " [ 0.1538, 0.0137, 0.1128],\n", + " [-0.1597, 0.1432, 0.1370]],\n", + " \n", + " [[ 0.1425, 0.1769, -0.0037],\n", + " [-0.1080, -0.0805, -0.0195],\n", + " [-0.1335, -0.1666, 0.1399]]],\n", + " \n", + " \n", + " [[[ 0.1117, 0.1918, -0.1666],\n", + " [-0.1392, 0.0086, 0.0172],\n", + " [-0.0721, -0.1711, 0.0344]],\n", + " \n", + " [[ 0.1820, -0.0537, -0.0974],\n", + " [ 0.0366, -0.0710, 0.1273],\n", + " [ 0.1132, -0.1594, 0.0878]],\n", + " \n", + " [[-0.0874, -0.0401, 0.1827],\n", + " [-0.0301, 0.1205, -0.0396],\n", + " [-0.1143, -0.1007, 0.1561]]],\n", + " \n", + " \n", + " [[[ 0.1522, -0.0012, -0.1785],\n", + " [-0.1833, -0.1828, -0.1643],\n", + " [-0.1765, -0.1757, -0.0608]],\n", + " \n", + " [[-0.0684, 0.0521, 0.1137],\n", + " [-0.0028, 0.0616, 0.0758],\n", + " [-0.1736, 0.0667, 0.1229]],\n", + " \n", + " [[ 0.1298, -0.1848, -0.1570],\n", + " [-0.1052, -0.1172, -0.1223],\n", + " [-0.1389, -0.0095, -0.0410]]],\n", + " \n", + " \n", + " [[[ 0.0213, -0.0975, 0.0964],\n", + " [ 0.0535, -0.0775, 0.0790],\n", + " [-0.1796, -0.1468, 0.1036]],\n", + " \n", + " [[-0.0403, 0.0646, -0.0932],\n", + " [ 0.1779, -0.1616, 0.0644],\n", + " [-0.0508, -0.1158, -0.0592]],\n", + " \n", + " [[-0.1644, -0.1327, 0.0817],\n", + " [ 0.0320, -0.0213, -0.0946],\n", + " [-0.1106, 0.1463, -0.1642]]],\n", + " \n", + " \n", + " [[[-0.0985, -0.1160, -0.0915],\n", + " [ 0.1857, 0.0806, 0.1761],\n", + " [-0.0817, 0.1095, 0.0896]],\n", + " \n", + " [[-0.0660, -0.1680, 0.1833],\n", + " [ 0.0611, 0.0077, -0.0848],\n", + " [-0.1516, 0.1737, 0.0484]],\n", + " \n", + " [[ 0.1434, -0.0732, -0.0904],\n", + " [ 0.0962, 0.1783, 0.0192],\n", + " [ 0.0915, 0.0006, 0.0334]]],\n", + " \n", + " \n", + " [[[-0.0047, 0.1807, -0.1798],\n", + " [-0.0164, 0.1119, -0.0805],\n", + " [ 0.1855, -0.0681, -0.0187]],\n", + " \n", + " [[-0.0069, 0.0491, -0.1868],\n", + " [-0.1609, -0.0316, 0.0150],\n", + " [-0.1605, 0.1506, -0.0074]],\n", + " \n", + " [[ 0.0851, -0.1732, -0.1777],\n", + " [ 0.0539, -0.0500, -0.1231],\n", + " [ 0.1654, 0.0342, -0.1904]]],\n", + " \n", + " \n", + " [[[ 0.0476, 0.0284, 0.1212],\n", + " [-0.1603, -0.1924, 0.0144],\n", + " [ 0.0076, -0.0928, -0.1645]],\n", + " \n", + " [[ 0.0215, 0.1845, -0.1034],\n", + " [ 0.1574, -0.1577, -0.0438],\n", + " [-0.1360, -0.0601, -0.1693]],\n", + " \n", + " [[-0.0720, 0.0619, 0.1405],\n", + " [ 0.0699, -0.1288, 0.0041],\n", + " [-0.0381, -0.1697, -0.1568]]],\n", + " \n", + " \n", + " [[[-0.1599, 0.1231, -0.1034],\n", + " [-0.0314, 0.0105, -0.1449],\n", + " [-0.0172, -0.0781, 0.0839]],\n", + " \n", + " [[-0.0676, 0.1185, -0.1559],\n", + " [-0.1053, -0.1306, 0.1820],\n", + " [ 0.1584, -0.1370, 0.1828]],\n", + " \n", + " [[ 0.0658, 0.1412, -0.0537],\n", + " [-0.1230, -0.1411, -0.0011],\n", + " [-0.1318, -0.0458, 0.1838]]],\n", + " \n", + " \n", + " [[[-0.0268, 0.1747, -0.1037],\n", + " [ 0.0515, -0.0228, -0.1024],\n", + " [-0.1543, -0.0643, -0.0100]],\n", + " \n", + " [[-0.1572, -0.1530, 0.0026],\n", + " [ 0.1463, -0.1233, 0.0470],\n", + " [-0.1595, -0.1108, -0.0654]],\n", + " \n", + " [[-0.0521, -0.0094, 0.1544],\n", + " [-0.0505, -0.0332, 0.0048],\n", + " [ 0.0735, 0.1350, 0.0690]]],\n", + " \n", + " \n", + " [[[ 0.0025, 0.0724, 0.0930],\n", + " [-0.1885, 0.0475, 0.1100],\n", + " [-0.1622, 0.0087, -0.0030]],\n", + " \n", + " [[ 0.1032, -0.1425, -0.0620],\n", + " [ 0.1515, -0.0736, -0.1888],\n", + " [-0.1246, 0.1424, -0.0491]],\n", + " \n", + " [[ 0.1759, -0.1616, 0.1198],\n", + " [-0.1103, 0.1032, 0.1727],\n", + " [-0.0601, 0.1635, 0.0034]]],\n", + " \n", + " \n", + " [[[ 0.0301, 0.1517, 0.0657],\n", + " [-0.1368, -0.1165, 0.1193],\n", + " [-0.0962, 0.1451, 0.1099]],\n", + " \n", + " [[ 0.1646, -0.1860, -0.1187],\n", + " [-0.1367, -0.0911, 0.1337],\n", + " [-0.0926, -0.0524, -0.0672]],\n", + " \n", + " [[-0.1509, -0.1231, -0.0855],\n", + " [ 0.1808, -0.0713, 0.0410],\n", + " [-0.0621, -0.0506, 0.1871]]],\n", + " \n", + " \n", + " [[[ 0.0888, -0.0874, -0.0826],\n", + " [ 0.0416, -0.0961, 0.0603],\n", + " [ 0.1455, 0.0050, 0.0318]],\n", + " \n", + " [[-0.1633, 0.0070, -0.1537],\n", + " [-0.0109, 0.1602, -0.0463],\n", + " [-0.0423, -0.0147, -0.1045]],\n", + " \n", + " [[ 0.1640, -0.0997, -0.1662],\n", + " [-0.1074, 0.1549, -0.1905],\n", + " [-0.1708, 0.1624, 0.0219]]],\n", + " \n", + " \n", + " [[[ 0.0824, -0.1376, 0.1086],\n", + " [ 0.0836, 0.0135, 0.0351],\n", + " [-0.1518, 0.0784, -0.1708]],\n", + " \n", + " [[-0.1636, -0.1571, 0.1032],\n", + " [-0.1152, 0.0274, -0.1022],\n", + " [-0.0956, -0.1606, -0.1615]],\n", + " \n", + " [[ 0.1307, 0.0419, 0.1924],\n", + " [-0.0599, -0.1296, -0.0448],\n", + " [ 0.0363, 0.0377, -0.0460]]],\n", + " \n", + " \n", + " [[[-0.1685, -0.1277, -0.0465],\n", + " [ 0.0922, -0.1011, 0.0742],\n", + " [ 0.0053, -0.1456, 0.0135]],\n", + " \n", + " [[ 0.1341, 0.0131, 0.1281],\n", + " [-0.1020, 0.1069, -0.0631],\n", + " [-0.0439, -0.1189, -0.1822]],\n", + " \n", + " [[ 0.1624, -0.1253, 0.0302],\n", + " [ 0.0709, 0.0767, 0.1453],\n", + " [ 0.0203, 0.1603, -0.1720]]],\n", + " \n", + " \n", + " [[[-0.1550, 0.1513, -0.1003],\n", + " [ 0.0370, 0.0367, -0.0233],\n", + " [ 0.0916, -0.0871, 0.1579]],\n", + " \n", + " [[-0.1900, 0.0314, 0.0865],\n", + " [-0.0197, 0.0296, -0.0048],\n", + " [ 0.0846, 0.1543, -0.0770]],\n", + " \n", + " [[-0.0016, -0.0978, 0.1826],\n", + " [-0.0477, 0.0689, 0.1079],\n", + " [ 0.0400, 0.0880, 0.1674]]],\n", + " \n", + " \n", + " [[[ 0.0145, -0.0447, -0.1742],\n", + " [ 0.0394, 0.0127, -0.1172],\n", + " [ 0.1330, -0.1207, 0.0326]],\n", + " \n", + " [[-0.0155, -0.1602, 0.0023],\n", + " [ 0.0789, 0.1648, 0.1781],\n", + " [-0.1468, -0.0481, -0.1260]],\n", + " \n", + " [[-0.0139, 0.0848, -0.0536],\n", + " [-0.1581, 0.1130, 0.0717],\n", + " [ 0.0275, -0.0006, -0.0049]]],\n", + " \n", + " \n", + " [[[-0.0199, 0.0032, -0.1246],\n", + " [ 0.0479, 0.1418, -0.1295],\n", + " [-0.1646, -0.1139, -0.1018]],\n", + " \n", + " [[ 0.1475, 0.1413, -0.0354],\n", + " [ 0.0612, -0.1652, 0.0801],\n", + " [-0.1306, -0.0165, 0.1733]],\n", + " \n", + " [[ 0.1527, 0.0911, -0.1906],\n", + " [-0.1152, 0.1737, 0.0436],\n", + " [-0.0213, -0.0314, -0.0319]]],\n", + " \n", + " \n", + " [[[-0.0003, -0.0546, -0.1255],\n", + " [ 0.0914, -0.1414, 0.0542],\n", + " [ 0.1139, 0.0132, 0.0815]],\n", + " \n", + " [[-0.0042, 0.0541, 0.1456],\n", + " [ 0.0509, -0.0790, 0.0272],\n", + " [ 0.1419, 0.0992, -0.1448]],\n", + " \n", + " [[ 0.0496, 0.0013, 0.0838],\n", + " [-0.0662, 0.0315, -0.1168],\n", + " [-0.0069, -0.1503, 0.0729]]],\n", + " \n", + " \n", + " [[[ 0.1866, 0.1329, -0.0560],\n", + " [ 0.0026, 0.1533, 0.0326],\n", + " [-0.1161, -0.0323, 0.0053]],\n", + " \n", + " [[-0.0243, -0.1823, -0.1657],\n", + " [-0.0107, -0.0832, 0.0029],\n", + " [ 0.0981, 0.1241, -0.1788]],\n", + " \n", + " [[-0.0400, -0.0577, -0.0757],\n", + " [-0.0584, 0.0176, -0.1019],\n", + " [-0.1828, 0.1589, -0.0312]]],\n", + " \n", + " \n", + " [[[-0.1083, -0.1236, -0.0904],\n", + " [-0.1575, 0.0157, 0.0552],\n", + " [-0.0839, 0.1704, -0.1457]],\n", + " \n", + " [[-0.1648, -0.0270, -0.0489],\n", + " [-0.1122, -0.0288, -0.0073],\n", + " [-0.1443, -0.1712, 0.0100]],\n", + " \n", + " [[-0.1142, -0.1552, 0.1568],\n", + " [ 0.0743, -0.1108, -0.0643],\n", + " [-0.0394, -0.1345, 0.0992]]],\n", + " \n", + " \n", + " [[[-0.1591, 0.0942, -0.1035],\n", + " [-0.0781, 0.0725, -0.0888],\n", + " [ 0.0959, 0.0213, 0.1222]],\n", + " \n", + " [[ 0.1202, -0.0217, -0.0955],\n", + " [-0.1748, -0.1133, -0.0704],\n", + " [-0.0670, -0.1401, 0.1553]],\n", + " \n", + " [[ 0.0053, -0.0871, -0.0239],\n", + " [ 0.0961, -0.0547, 0.1741],\n", + " [-0.0570, 0.0477, 0.1853]]],\n", + " \n", + " \n", + " [[[-0.1115, -0.0183, -0.1302],\n", + " [ 0.1435, -0.0238, -0.0048],\n", + " [ 0.1862, -0.1837, 0.1711]],\n", + " \n", + " [[ 0.1375, -0.1798, 0.0818],\n", + " [-0.0792, 0.0820, 0.1373],\n", + " [ 0.1849, 0.0672, -0.1822]],\n", + " \n", + " [[ 0.1868, -0.0356, 0.0726],\n", + " [-0.1523, -0.1130, 0.1506],\n", + " [-0.1046, 0.0178, 0.0990]]],\n", + " \n", + " \n", + " [[[ 0.1321, -0.1641, 0.0411],\n", + " [ 0.0526, 0.0393, 0.0918],\n", + " [-0.1345, 0.0750, 0.0859]],\n", + " \n", + " [[-0.0985, 0.1466, 0.1349],\n", + " [-0.1461, -0.1742, 0.0941],\n", + " [-0.1502, -0.1813, 0.0864]],\n", + " \n", + " [[-0.1039, 0.1179, 0.1499],\n", + " [-0.0366, -0.0120, 0.0951],\n", + " [ 0.0087, 0.1212, -0.0183]]],\n", + " \n", + " \n", + " [[[-0.1375, 0.0765, -0.0072],\n", + " [-0.0041, 0.0379, -0.0243],\n", + " [-0.1495, 0.1601, 0.1575]],\n", + " \n", + " [[-0.0454, 0.1642, 0.0720],\n", + " [-0.0533, 0.0150, 0.0039],\n", + " [ 0.0194, 0.0113, -0.1194]],\n", + " \n", + " [[ 0.0527, -0.0886, 0.0359],\n", + " [ 0.1595, 0.0526, -0.0048],\n", + " [-0.1790, -0.0458, -0.0324]]],\n", + " \n", + " \n", + " [[[-0.1638, 0.0942, 0.0686],\n", + " [-0.1082, -0.0675, 0.1892],\n", + " [-0.1347, -0.1247, 0.0739]],\n", + " \n", + " [[ 0.0595, 0.1504, -0.1657],\n", + " [ 0.0733, 0.0529, -0.1599],\n", + " [ 0.0171, -0.1127, -0.0259]],\n", + " \n", + " [[-0.0092, 0.0193, 0.1176],\n", + " [-0.1183, 0.0101, 0.1011],\n", + " [ 0.0648, -0.1897, 0.0782]]],\n", + " \n", + " \n", + " [[[ 0.0098, -0.1161, -0.0802],\n", + " [-0.1821, 0.0221, -0.1754],\n", + " [-0.1218, 0.0525, -0.0480]],\n", + " \n", + " [[ 0.0770, 0.0477, 0.1514],\n", + " [ 0.0374, -0.1075, -0.1026],\n", + " [-0.0581, -0.1011, 0.1241]],\n", + " \n", + " [[-0.0567, -0.0163, 0.0374],\n", + " [-0.1739, -0.0579, 0.0704],\n", + " [ 0.1817, 0.1561, 0.1677]]],\n", + " \n", + " \n", + " [[[-0.0569, -0.0763, 0.0044],\n", + " [-0.1133, 0.0813, 0.1477],\n", + " [ 0.0836, 0.0483, -0.1800]],\n", + " \n", + " [[ 0.1343, -0.1590, 0.1177],\n", + " [ 0.1071, -0.1647, -0.0646],\n", + " [ 0.1578, -0.1261, 0.0243]],\n", + " \n", + " [[-0.0424, -0.0241, -0.0988],\n", + " [ 0.0023, 0.0029, -0.0291],\n", + " [ 0.0415, -0.0557, 0.1427]]],\n", + " \n", + " \n", + " [[[-0.1028, 0.1054, 0.1658],\n", + " [-0.0357, 0.1579, 0.1237],\n", + " [ 0.0368, 0.0532, -0.1043]],\n", + " \n", + " [[-0.0369, 0.0575, -0.1023],\n", + " [ 0.0635, 0.1015, 0.1112],\n", + " [-0.1235, 0.0467, 0.0908]],\n", + " \n", + " [[ 0.1380, 0.0633, 0.1087],\n", + " [-0.1360, 0.0422, -0.1524],\n", + " [ 0.0819, 0.0918, -0.1624]]],\n", + " \n", + " \n", + " [[[ 0.1584, -0.0218, -0.0236],\n", + " [ 0.1878, -0.1289, 0.1343],\n", + " [ 0.0351, 0.1225, -0.1460]],\n", + " \n", + " [[ 0.0690, -0.1439, 0.0056],\n", + " [ 0.0272, -0.0058, 0.0125],\n", + " [ 0.0868, -0.0684, -0.0884]],\n", + " \n", + " [[ 0.1045, 0.0583, -0.0870],\n", + " [ 0.0600, -0.0732, -0.1695],\n", + " [ 0.0953, 0.0246, 0.1245]]]])),\n", + " ('conv1.bias',\n", + " tensor([-1.8698e-01, -7.9379e-06, -1.9277e-02, 5.2182e-02, 7.5716e-02,\n", + " -3.3830e-03, -9.6565e-02, 1.0241e-01, -8.2457e-02, -1.6224e-01,\n", + " 1.2980e-01, -8.2256e-02, -7.4655e-02, -3.7980e-02, 8.3407e-02,\n", + " -1.4880e-01, 4.8939e-02, 2.7506e-02, 5.8676e-03, -1.5813e-01,\n", + " -6.2464e-04, 1.0359e-02, -1.5525e-01, 7.9100e-02, 1.6850e-02,\n", + " -1.3809e-01, -6.3393e-02, -5.3843e-02, -1.5219e-02, -1.7365e-01,\n", + " 1.7249e-01, -1.1165e-01])),\n", + " ('conv2.weight',\n", + " tensor([[[[ 4.3750e-02, 4.5533e-02, -2.9410e-02],\n", + " [-4.1395e-02, 5.0397e-04, -1.3265e-02],\n", + " [-4.9851e-02, -1.0518e-02, 5.7710e-02]],\n", + " \n", + " [[-5.6332e-02, -4.7168e-03, -4.4627e-02],\n", + " [ 5.3513e-03, -4.0824e-02, 1.8281e-02],\n", + " [ 5.0677e-02, -1.5295e-02, -6.1751e-03]],\n", + " \n", + " [[-2.4984e-02, 1.2784e-02, -4.7123e-02],\n", + " [-4.3238e-02, 4.7349e-02, -1.5219e-02],\n", + " [-3.6073e-02, 4.1506e-02, -3.5337e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[ 5.6048e-02, 1.9129e-03, -1.7200e-02],\n", + " [ 5.8869e-02, -5.1520e-02, -5.3205e-02],\n", + " [-1.3903e-02, 5.1790e-02, 2.2585e-02]],\n", + " \n", + " [[ 1.1835e-02, -4.9313e-02, -3.1838e-02],\n", + " [ 7.6813e-03, 4.2715e-02, -5.7404e-02],\n", + " [-4.1474e-02, -2.3128e-02, -4.7935e-02]],\n", + " \n", + " [[-2.1860e-02, -2.1817e-02, -3.2578e-02],\n", + " [ 3.1317e-02, 3.3435e-02, 3.1837e-02],\n", + " [-2.2399e-03, 3.1600e-02, 4.0183e-02]]],\n", + " \n", + " \n", + " [[[ 1.9610e-04, 5.3780e-02, -4.5810e-02],\n", + " [ 4.0340e-02, 1.4904e-02, -1.5597e-02],\n", + " [-4.6080e-02, 5.0714e-02, -5.7445e-03]],\n", + " \n", + " [[-3.5281e-02, 3.3011e-02, 4.3343e-02],\n", + " [-4.6263e-02, -5.6184e-02, 5.1245e-03],\n", + " [ 3.6015e-02, -3.3152e-02, 4.6629e-03]],\n", + " \n", + " [[ 1.7650e-03, -4.2336e-02, 4.3744e-02],\n", + " [ 2.1655e-02, 5.3759e-02, 1.3719e-03],\n", + " [ 4.2005e-02, 5.3998e-02, 1.9009e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[ 8.9786e-03, -1.8645e-02, -1.3587e-02],\n", + " [-5.4167e-02, 1.2335e-02, -3.0384e-02],\n", + " [-4.8722e-03, -3.7296e-02, -2.6446e-02]],\n", + " \n", + " [[ 1.7580e-02, 3.8462e-02, -5.0269e-02],\n", + " [ 2.6601e-03, -1.1462e-02, 4.7459e-02],\n", + " [-2.8888e-02, 3.4436e-02, -4.9943e-02]],\n", + " \n", + " [[-5.0206e-02, -5.6025e-02, -3.6346e-02],\n", + " [-2.4407e-02, 5.3721e-02, -5.4920e-02],\n", + " [ 5.1835e-02, -3.2396e-02, 3.2373e-02]]],\n", + " \n", + " \n", + " [[[-2.7759e-02, -5.4737e-02, -1.1689e-02],\n", + " [ 3.9462e-02, 2.8649e-02, 5.1776e-02],\n", + " [ 2.4253e-02, -2.8318e-02, 2.7402e-02]],\n", + " \n", + " [[ 1.3045e-02, -1.0456e-02, 2.0426e-02],\n", + " [ 2.1949e-02, 4.6817e-02, -5.6093e-02],\n", + " [ 2.7145e-02, -5.5441e-02, -2.0719e-02]],\n", + " \n", + " [[ 4.4704e-02, -2.4099e-02, -4.7185e-02],\n", + " [-4.3257e-02, -3.3058e-02, -8.6451e-03],\n", + " [-3.7283e-02, -3.4569e-02, -7.1049e-03]],\n", + " \n", + " ...,\n", + " \n", + " [[-4.1559e-02, -2.9240e-02, 2.7197e-03],\n", + " [ 2.0770e-02, 5.4479e-02, -4.4845e-02],\n", + " [-1.1641e-02, -2.9814e-02, -2.4419e-02]],\n", + " \n", + " [[-1.5743e-02, 1.0854e-02, 3.0878e-02],\n", + " [ 2.2739e-02, 3.2999e-02, -1.1902e-02],\n", + " [-3.4837e-02, 1.5305e-02, -8.7552e-03]],\n", + " \n", + " [[-2.2882e-02, 9.4639e-03, 5.1878e-03],\n", + " [-2.6344e-02, 2.9063e-02, -1.9337e-02],\n", + " [-3.4314e-02, 1.5313e-02, 4.1524e-02]]],\n", + " \n", + " \n", + " ...,\n", + " \n", + " \n", + " [[[ 9.7454e-03, -3.2903e-03, 1.0696e-02],\n", + " [-4.0918e-02, 1.6352e-02, 1.4646e-02],\n", + " [ 1.2516e-02, -2.1804e-02, -2.5489e-02]],\n", + " \n", + " [[-1.6083e-02, 2.5374e-02, 3.1458e-02],\n", + " [-3.1497e-02, -1.9513e-02, -2.1223e-02],\n", + " [ 6.6286e-03, 1.6538e-02, -4.8944e-02]],\n", + " \n", + " [[ 2.4808e-02, -2.9520e-02, -4.8227e-02],\n", + " [ 1.7325e-03, -4.7443e-02, 2.3087e-03],\n", + " [-1.0008e-02, -2.0313e-02, 2.9944e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[ 5.5781e-02, -2.0898e-02, -3.1487e-03],\n", + " [-1.6931e-02, 4.5279e-04, -1.5024e-02],\n", + " [-5.5885e-02, 2.7140e-02, -8.5434e-03]],\n", + " \n", + " [[ 1.3970e-02, -3.3131e-02, 4.3112e-02],\n", + " [-3.4956e-02, -5.0144e-02, -1.6391e-02],\n", + " [-9.1003e-03, -2.0204e-02, -1.0226e-03]],\n", + " \n", + " [[-4.0053e-02, -5.0194e-02, 5.0405e-02],\n", + " [ 5.4107e-02, 4.2185e-02, 3.4359e-02],\n", + " [ 1.6749e-02, -1.4102e-02, 5.0171e-02]]],\n", + " \n", + " \n", + " [[[ 2.8229e-02, 5.6586e-02, -3.9617e-03],\n", + " [ 2.9538e-02, -1.2507e-02, -2.5516e-02],\n", + " [-1.5193e-02, -2.9232e-02, -2.0701e-02]],\n", + " \n", + " [[-5.8773e-02, 3.3015e-02, -9.4146e-03],\n", + " [ 2.8957e-02, 5.8666e-02, 2.8679e-02],\n", + " [ 1.5249e-02, -1.2246e-03, 1.2230e-03]],\n", + " \n", + " [[ 2.6050e-02, -4.6042e-02, -3.4895e-03],\n", + " [ 4.9529e-02, 6.6835e-03, -4.1808e-02],\n", + " [-8.6450e-03, -4.8510e-02, -2.4011e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[-5.1427e-02, 2.4528e-02, -5.4878e-02],\n", + " [-1.8610e-02, 5.4365e-02, 3.5053e-03],\n", + " [-3.9922e-02, 4.2510e-02, -5.7261e-02]],\n", + " \n", + " [[ 4.1938e-02, -4.2039e-02, -1.2487e-02],\n", + " [-1.4090e-02, -3.7895e-02, 1.4394e-02],\n", + " [ 2.2555e-02, -2.7264e-02, 5.6102e-02]],\n", + " \n", + " [[ 1.5770e-02, 5.4672e-02, -2.4056e-02],\n", + " [ 5.2089e-02, -2.8859e-02, -2.6499e-03],\n", + " [-5.2122e-02, -3.7436e-02, 3.9897e-02]]],\n", + " \n", + " \n", + " [[[ 2.7888e-02, 2.9241e-02, -1.9488e-02],\n", + " [ 2.8928e-02, 5.3312e-02, -2.9810e-02],\n", + " [-8.5104e-03, 5.7751e-02, -8.1857e-03]],\n", + " \n", + " [[ 4.2649e-02, 3.3158e-03, 4.2879e-02],\n", + " [ 7.7893e-03, -3.2879e-02, 2.7630e-02],\n", + " [ 5.4706e-03, 4.8019e-02, 1.2420e-02]],\n", + " \n", + " [[-4.2004e-02, -4.2790e-02, 2.4634e-02],\n", + " [-5.4641e-02, 3.4600e-02, 2.9071e-03],\n", + " [ 2.6470e-02, 4.6701e-02, 3.7158e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[-1.7641e-02, -2.1205e-02, -5.1504e-02],\n", + " [-7.4737e-03, 5.5061e-02, -2.6397e-03],\n", + " [-4.4653e-02, -3.6719e-02, 3.4420e-06]],\n", + " \n", + " [[ 1.6525e-02, 1.7280e-02, 5.4554e-03],\n", + " [ 4.0098e-02, 2.7571e-02, -4.4965e-02],\n", + " [ 6.1493e-03, -5.7754e-02, 1.0513e-02]],\n", + " \n", + " [[-5.7615e-02, 3.2921e-02, -1.5900e-02],\n", + " [ 2.0081e-02, 5.4590e-02, 1.1296e-02],\n", + " [-4.5015e-02, 1.1341e-03, 2.6447e-02]]]])),\n", + " ('conv2.bias',\n", + " tensor([-0.0538, -0.0320, -0.0153, 0.0558, 0.0254, 0.0281, -0.0148, 0.0060,\n", + " -0.0283, -0.0062, 0.0437, -0.0064, 0.0341, 0.0233, -0.0201, 0.0391,\n", + " 0.0243, 0.0071, 0.0125, -0.0138, -0.0377, -0.0169, -0.0475, -0.0004,\n", + " -0.0105, -0.0502, 0.0241, 0.0090, 0.0069, -0.0315, -0.0192, 0.0204])),\n", + " ('conv3.weight',\n", + " tensor([[[[-0.0215, -0.0208, -0.0272],\n", + " [-0.0493, -0.0117, -0.0285],\n", + " [ 0.0515, -0.0041, 0.0126]],\n", + " \n", + " [[ 0.0299, -0.0301, 0.0552],\n", + " [ 0.0450, 0.0449, -0.0583],\n", + " [-0.0452, -0.0480, -0.0275]],\n", + " \n", + " [[-0.0262, -0.0338, 0.0505],\n", + " [ 0.0146, -0.0364, -0.0044],\n", + " [-0.0102, -0.0051, 0.0017]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0367, -0.0468, -0.0586],\n", + " [ 0.0126, 0.0037, 0.0191],\n", + " [-0.0153, 0.0048, -0.0160]],\n", + " \n", + " [[-0.0050, 0.0364, 0.0582],\n", + " [ 0.0093, -0.0268, -0.0355],\n", + " [-0.0125, 0.0500, 0.0009]],\n", + " \n", + " [[ 0.0237, -0.0211, -0.0130],\n", + " [-0.0489, 0.0118, 0.0387],\n", + " [-0.0006, 0.0301, 0.0283]]],\n", + " \n", + " \n", + " [[[-0.0391, -0.0464, -0.0158],\n", + " [ 0.0201, -0.0054, 0.0422],\n", + " [ 0.0085, -0.0474, -0.0251]],\n", + " \n", + " [[-0.0346, 0.0536, -0.0391],\n", + " [ 0.0244, -0.0263, -0.0073],\n", + " [ 0.0076, 0.0160, 0.0044]],\n", + " \n", + " [[-0.0128, 0.0146, -0.0381],\n", + " [-0.0277, -0.0142, 0.0226],\n", + " [ 0.0190, 0.0326, -0.0219]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0217, 0.0129, 0.0558],\n", + " [ 0.0164, -0.0292, -0.0467],\n", + " [-0.0296, 0.0205, -0.0300]],\n", + " \n", + " [[ 0.0254, -0.0151, -0.0583],\n", + " [ 0.0111, -0.0469, -0.0300],\n", + " [-0.0462, 0.0293, -0.0351]],\n", + " \n", + " [[ 0.0401, -0.0251, 0.0160],\n", + " [-0.0160, -0.0195, -0.0065],\n", + " [-0.0519, 0.0351, 0.0357]]],\n", + " \n", + " \n", + " [[[ 0.0544, -0.0209, -0.0454],\n", + " [ 0.0287, -0.0205, -0.0294],\n", + " [-0.0195, -0.0235, -0.0378]],\n", + " \n", + " [[-0.0294, -0.0380, 0.0301],\n", + " [ 0.0360, 0.0367, 0.0458],\n", + " [-0.0189, -0.0017, 0.0145]],\n", + " \n", + " [[-0.0297, 0.0567, 0.0276],\n", + " [ 0.0298, 0.0383, 0.0227],\n", + " [ 0.0262, 0.0063, 0.0131]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.0002, 0.0432, -0.0247],\n", + " [ 0.0068, -0.0298, -0.0484],\n", + " [-0.0361, -0.0014, 0.0444]],\n", + " \n", + " [[-0.0184, -0.0201, -0.0163],\n", + " [-0.0466, 0.0255, -0.0244],\n", + " [ 0.0283, 0.0149, -0.0588]],\n", + " \n", + " [[ 0.0323, 0.0392, -0.0254],\n", + " [ 0.0560, 0.0137, -0.0401],\n", + " [-0.0236, 0.0589, 0.0448]]],\n", + " \n", + " \n", + " ...,\n", + " \n", + " \n", + " [[[ 0.0213, 0.0204, 0.0574],\n", + " [-0.0276, -0.0196, 0.0117],\n", + " [ 0.0569, -0.0158, -0.0502]],\n", + " \n", + " [[ 0.0452, -0.0038, 0.0502],\n", + " [ 0.0428, -0.0398, -0.0486],\n", + " [ 0.0130, 0.0563, 0.0576]],\n", + " \n", + " [[ 0.0484, -0.0535, 0.0048],\n", + " [ 0.0268, -0.0290, -0.0390],\n", + " [ 0.0189, -0.0194, -0.0588]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.0163, -0.0113, -0.0520],\n", + " [ 0.0288, -0.0547, -0.0544],\n", + " [ 0.0442, 0.0376, 0.0566]],\n", + " \n", + " [[-0.0343, 0.0569, 0.0438],\n", + " [-0.0403, -0.0372, -0.0532],\n", + " [ 0.0322, 0.0126, 0.0423]],\n", + " \n", + " [[ 0.0577, 0.0136, -0.0480],\n", + " [-0.0293, -0.0348, 0.0342],\n", + " [-0.0510, -0.0078, -0.0042]]],\n", + " \n", + " \n", + " [[[-0.0243, 0.0406, 0.0537],\n", + " [ 0.0209, -0.0059, -0.0487],\n", + " [-0.0425, 0.0339, 0.0444]],\n", + " \n", + " [[ 0.0465, -0.0467, 0.0461],\n", + " [-0.0389, 0.0144, -0.0502],\n", + " [ 0.0274, 0.0552, 0.0356]],\n", + " \n", + " [[-0.0289, 0.0474, -0.0217],\n", + " [ 0.0472, -0.0135, 0.0164],\n", + " [-0.0165, -0.0049, -0.0475]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0377, 0.0267, 0.0367],\n", + " [ 0.0111, 0.0114, -0.0329],\n", + " [ 0.0031, -0.0223, -0.0280]],\n", + " \n", + " [[-0.0500, -0.0529, 0.0116],\n", + " [ 0.0483, 0.0121, -0.0149],\n", + " [ 0.0328, 0.0201, 0.0402]],\n", + " \n", + " [[ 0.0463, 0.0157, -0.0332],\n", + " [ 0.0150, 0.0479, 0.0461],\n", + " [ 0.0275, 0.0506, -0.0466]]],\n", + " \n", + " \n", + " [[[-0.0478, 0.0274, 0.0500],\n", + " [-0.0394, 0.0032, -0.0496],\n", + " [ 0.0381, 0.0391, 0.0330]],\n", + " \n", + " [[ 0.0184, -0.0560, -0.0345],\n", + " [-0.0459, -0.0215, 0.0452],\n", + " [ 0.0049, 0.0537, 0.0544]],\n", + " \n", + " [[-0.0413, -0.0084, 0.0585],\n", + " [ 0.0338, -0.0067, -0.0113],\n", + " [-0.0187, -0.0234, -0.0525]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0389, 0.0325, -0.0538],\n", + " [ 0.0118, 0.0509, 0.0352],\n", + " [-0.0351, -0.0341, -0.0506]],\n", + " \n", + " [[ 0.0136, -0.0349, 0.0082],\n", + " [ 0.0358, -0.0211, 0.0537],\n", + " [-0.0183, 0.0390, -0.0267]],\n", + " \n", + " [[-0.0219, -0.0145, -0.0351],\n", + " [ 0.0556, 0.0033, -0.0030],\n", + " [ 0.0075, -0.0425, -0.0365]]]])),\n", + " ('conv3.bias',\n", + " tensor([ 0.0304, 0.0099, -0.0004, 0.0334, 0.0301, 0.0491, 0.0530, -0.0432,\n", + " -0.0127, -0.0549, -0.0419, 0.0159, -0.0284, 0.0295, -0.0148, 0.0275,\n", + " 0.0554, -0.0056, 0.0389, -0.0264, -0.0383, 0.0126, 0.0320, 0.0312,\n", + " 0.0018, 0.0560, -0.0329, -0.0155, -0.0391, -0.0539, -0.0571, -0.0254])),\n", + " ('conv4.weight',\n", + " tensor([[[[-3.8911e-02, -3.4220e-02, 4.2567e-03],\n", + " [-4.5321e-02, -5.2531e-02, -8.1722e-03],\n", + " [-2.2638e-02, 4.4213e-02, 5.6989e-02]],\n", + " \n", + " [[ 1.8417e-03, -1.4453e-02, 4.9892e-02],\n", + " [ 5.7762e-02, 9.6610e-03, -3.9509e-02],\n", + " [ 3.3795e-02, 5.0409e-02, -5.8834e-02]],\n", + " \n", + " [[ 4.6645e-03, -1.6286e-02, 4.3410e-02],\n", + " [-3.4043e-02, -2.2207e-02, 4.0967e-02],\n", + " [ 5.3004e-02, -2.2756e-02, -6.7993e-03]],\n", + " \n", + " ...,\n", + " \n", + " [[ 1.1741e-02, -5.5062e-02, -3.3625e-02],\n", + " [-9.2320e-03, -3.3036e-02, 3.3196e-02],\n", + " [ 2.3940e-02, 2.0442e-02, 1.4183e-02]],\n", + " \n", + " [[-2.7139e-02, -3.4129e-03, -1.0090e-02],\n", + " [ 1.3073e-02, -1.6998e-02, -4.7540e-02],\n", + " [-2.5758e-02, -1.9363e-02, 2.1905e-02]],\n", + " \n", + " [[ 5.7593e-02, 1.5013e-02, -5.7894e-02],\n", + " [ 5.7964e-02, -2.3412e-02, 2.6955e-02],\n", + " [-3.9814e-02, -4.6015e-02, -5.3240e-02]]],\n", + " \n", + " \n", + " [[[ 5.8211e-02, -4.1118e-02, 2.7704e-02],\n", + " [ 5.7198e-02, 8.4165e-03, -5.1708e-02],\n", + " [ 3.1423e-02, 1.5026e-02, 3.5922e-02]],\n", + " \n", + " [[ 8.8858e-03, 3.2818e-02, 5.4486e-02],\n", + " [-2.6636e-02, 2.2604e-02, 2.9531e-02],\n", + " [-1.0327e-03, 2.2348e-03, 2.4103e-02]],\n", + " \n", + " [[ 3.8683e-02, -5.0057e-03, 5.0224e-02],\n", + " [ 3.5756e-02, -2.7295e-02, -2.2854e-02],\n", + " [-3.2043e-02, -3.2415e-02, 4.1034e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[-1.9791e-02, 4.3243e-02, -3.5177e-02],\n", + " [ 2.4554e-02, 4.2845e-03, 4.8009e-02],\n", + " [ 2.4897e-03, 3.9550e-02, -3.0833e-02]],\n", + " \n", + " [[ 4.5807e-02, 6.5845e-03, 9.3362e-05],\n", + " [-1.9411e-02, -2.9161e-02, 5.0828e-02],\n", + " [ 1.2028e-03, 2.1260e-02, -4.3710e-03]],\n", + " \n", + " [[-4.8702e-02, -2.0571e-02, -3.5162e-02],\n", + " [-2.5856e-02, -2.6619e-02, 8.1867e-03],\n", + " [-2.7671e-02, -9.6651e-03, -5.3279e-02]]],\n", + " \n", + " \n", + " [[[-5.6432e-02, 2.3722e-02, -2.1750e-02],\n", + " [-4.8247e-03, -2.1226e-02, -1.0829e-02],\n", + " [-1.9523e-02, -1.8187e-02, 2.2772e-03]],\n", + " \n", + " [[-1.0907e-02, 3.3984e-02, -1.2088e-02],\n", + " [-1.8657e-02, -4.8297e-02, 2.3614e-02],\n", + " [-3.9670e-02, 6.1733e-03, -2.9168e-02]],\n", + " \n", + " [[-4.2112e-02, 2.8203e-02, -1.7385e-03],\n", + " [-2.5282e-02, -9.4592e-05, 6.5093e-03],\n", + " [-4.1745e-02, 4.3988e-03, -1.1622e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[ 4.2991e-02, 1.6912e-02, -4.3689e-02],\n", + " [ 5.1871e-02, 4.8566e-02, 3.6205e-02],\n", + " [-3.2016e-02, -1.3596e-02, -2.7950e-02]],\n", + " \n", + " [[-2.8307e-02, -4.0278e-02, 1.5087e-02],\n", + " [ 4.0443e-02, -3.5727e-02, 3.7196e-02],\n", + " [-1.4194e-02, -2.7319e-02, -5.1305e-02]],\n", + " \n", + " [[-2.9962e-02, 2.4693e-02, -4.4912e-02],\n", + " [ 5.5890e-03, 4.6671e-02, 3.3599e-02],\n", + " [-3.9949e-02, -4.4716e-02, 2.2345e-02]]],\n", + " \n", + " \n", + " ...,\n", + " \n", + " \n", + " [[[ 3.1920e-02, -4.9932e-02, -1.0871e-02],\n", + " [-3.7500e-02, 4.1638e-02, -1.3246e-02],\n", + " [ 1.6447e-02, -5.6741e-02, -3.7524e-02]],\n", + " \n", + " [[ 3.3903e-02, -3.1321e-02, -4.4877e-02],\n", + " [-2.2473e-02, -2.4225e-02, 4.5838e-02],\n", + " [-2.0069e-02, 3.8338e-02, 5.8010e-02]],\n", + " \n", + " [[ 1.7602e-02, -5.2530e-02, 4.9331e-02],\n", + " [ 2.4509e-02, 2.3943e-02, -2.1774e-02],\n", + " [-5.7154e-02, 5.7090e-02, 3.7531e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[ 2.8630e-02, -4.8644e-04, -5.3822e-02],\n", + " [-1.1102e-02, 4.8524e-02, -2.7142e-02],\n", + " [-5.3463e-02, 3.5607e-02, -1.1110e-02]],\n", + " \n", + " [[ 4.7891e-02, -3.4098e-02, 3.6984e-02],\n", + " [ 3.9062e-02, -5.1119e-03, -3.3252e-02],\n", + " [ 5.5029e-02, 3.1092e-03, -5.9391e-03]],\n", + " \n", + " [[ 9.6775e-03, 2.2903e-02, -1.5971e-02],\n", + " [-2.6969e-02, 8.5069e-04, -2.5744e-02],\n", + " [ 2.7311e-03, -2.2119e-02, 2.4367e-03]]],\n", + " \n", + " \n", + " [[[ 1.5206e-02, -5.3504e-02, -1.8013e-02],\n", + " [-3.0064e-02, 3.2887e-02, 1.7612e-02],\n", + " [-4.2775e-02, -2.7335e-02, 5.1532e-02]],\n", + " \n", + " [[ 4.3325e-02, 1.2909e-03, 5.6831e-02],\n", + " [ 4.8283e-03, -4.4274e-02, 1.7624e-02],\n", + " [-7.2574e-03, 1.3743e-02, -5.1502e-02]],\n", + " \n", + " [[-1.4299e-02, -3.0024e-02, 4.9578e-02],\n", + " [-4.6143e-02, -3.0686e-02, -5.0727e-02],\n", + " [ 5.4766e-02, -1.4012e-02, 3.3267e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[ 5.2773e-02, 4.0294e-02, -2.6113e-02],\n", + " [-2.5069e-02, 4.3956e-02, -4.7841e-02],\n", + " [ 3.0924e-02, 1.5174e-02, -4.8323e-02]],\n", + " \n", + " [[-2.0325e-02, 3.2666e-02, 2.5174e-02],\n", + " [ 5.3775e-03, -3.2712e-02, -5.2251e-02],\n", + " [-2.7426e-02, -4.5502e-04, -3.2174e-02]],\n", + " \n", + " [[ 8.5780e-03, -5.2099e-02, 5.7285e-02],\n", + " [-5.0897e-02, -5.3995e-03, 4.2719e-02],\n", + " [-5.2257e-02, -7.9682e-03, 2.1848e-02]]],\n", + " \n", + " \n", + " [[[-2.8660e-02, 1.2392e-02, 2.9940e-02],\n", + " [-1.1170e-02, -1.7499e-02, -5.3951e-02],\n", + " [ 4.9738e-02, 2.7240e-02, 2.9588e-03]],\n", + " \n", + " [[ 3.7728e-02, -3.1084e-02, 4.6459e-02],\n", + " [-1.4961e-02, 1.6951e-02, -1.3976e-02],\n", + " [ 2.5768e-02, -3.8991e-02, -2.8738e-02]],\n", + " \n", + " [[ 1.6084e-02, -7.3174e-03, -4.5839e-02],\n", + " [-3.6029e-02, 1.5303e-02, -5.4380e-02],\n", + " [ 2.1913e-02, -4.4792e-02, 4.7973e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[ 5.4945e-02, -4.5541e-02, -1.5806e-02],\n", + " [-4.5216e-02, -2.0338e-02, 3.3373e-02],\n", + " [-1.8431e-02, 4.3953e-02, -4.8196e-02]],\n", + " \n", + " [[-3.6129e-02, 5.7705e-02, -1.2229e-02],\n", + " [-5.2801e-02, 7.0930e-03, -2.6721e-02],\n", + " [ 3.6720e-02, 7.5540e-04, 5.5401e-02]],\n", + " \n", + " [[-4.0386e-02, 1.6714e-02, 2.9246e-02],\n", + " [ 4.7033e-02, -8.7923e-03, 5.1267e-02],\n", + " [-3.1211e-02, -2.7036e-02, 4.0346e-02]]]])),\n", + " ('conv4.bias',\n", + " tensor([-0.0368, -0.0535, -0.0568, -0.0364, 0.0167, 0.0255, 0.0031, 0.0562,\n", + " -0.0575, 0.0470, -0.0128, 0.0159, -0.0387, 0.0517, -0.0508, -0.0104,\n", + " 0.0133, -0.0265, 0.0290, 0.0147, -0.0405, 0.0074, 0.0434, -0.0436,\n", + " -0.0295, 0.0417, 0.0324, 0.0316, 0.0143, -0.0312, 0.0463, 0.0441])),\n", + " ('fc1.weight',\n", + " tensor([[-0.0206, -0.0161, -0.0053, ..., -0.0257, -0.0314, -0.0105],\n", + " [-0.0229, 0.0030, 0.0239, ..., 0.0199, -0.0279, 0.0155]])),\n", + " ('fc1.bias', tensor([-0.0333, -0.0126]))])" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "m1.state_dict()\n" + ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 28, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "OrderedDict([('conv1.weight',\n", + " tensor([[[[-0.1261, 0.1833, -0.1406],\n", + " [ 0.1324, -0.0685, 0.0938],\n", + " [ 0.0432, 0.1814, -0.0541]],\n", + " \n", + " [[-0.1776, -0.1839, -0.0111],\n", + " [ 0.0888, 0.0888, -0.1344],\n", + " [-0.1838, 0.1737, 0.1584]],\n", + " \n", + " [[ 0.0417, 0.1064, -0.0156],\n", + " [ 0.0667, 0.0856, -0.1746],\n", + " [ 0.0412, 0.1620, 0.0125]]],\n", + " \n", + " \n", + " [[[-0.0530, -0.1273, -0.0797],\n", + " [ 0.0422, 0.1135, 0.0475],\n", + " [-0.0244, 0.1691, -0.1383]],\n", + " \n", + " [[ 0.0822, -0.1317, -0.1692],\n", + " [ 0.1373, 0.1388, 0.0103],\n", + " [-0.0481, 0.1105, 0.0631]],\n", + " \n", + " [[-0.0352, 0.1259, -0.0530],\n", + " [-0.1394, -0.0281, 0.1844],\n", + " [ 0.0082, 0.1187, 0.0211]]],\n", + " \n", + " \n", + " [[[ 0.0987, 0.0788, -0.1126],\n", + " [ 0.1769, 0.0763, -0.1767],\n", + " [-0.0570, 0.1156, 0.1770]],\n", + " \n", + " [[ 0.0643, -0.0024, -0.0625],\n", + " [ 0.0819, 0.0140, -0.1882],\n", + " [ 0.1325, -0.0632, -0.0202]],\n", + " \n", + " [[ 0.0053, 0.1042, -0.0058],\n", + " [-0.1082, -0.1753, 0.1762],\n", + " [-0.0501, 0.1166, 0.0561]]],\n", + " \n", + " \n", + " [[[ 0.0358, -0.0685, -0.1278],\n", + " [ 0.0029, -0.1107, 0.1169],\n", + " [-0.1408, 0.1293, 0.1142]],\n", + " \n", + " [[-0.0814, 0.0470, 0.0188],\n", + " [ 0.1538, 0.0137, 0.1128],\n", + " [-0.1597, 0.1432, 0.1370]],\n", + " \n", + " [[ 0.1425, 0.1769, -0.0037],\n", + " [-0.1080, -0.0805, -0.0195],\n", + " [-0.1335, -0.1666, 0.1399]]],\n", + " \n", + " \n", + " [[[ 0.1117, 0.1918, -0.1666],\n", + " [-0.1392, 0.0086, 0.0172],\n", + " [-0.0721, -0.1711, 0.0344]],\n", + " \n", + " [[ 0.1820, -0.0537, -0.0974],\n", + " [ 0.0366, -0.0710, 0.1273],\n", + " [ 0.1132, -0.1594, 0.0878]],\n", + " \n", + " [[-0.0874, -0.0401, 0.1827],\n", + " [-0.0301, 0.1205, -0.0396],\n", + " [-0.1143, -0.1007, 0.1561]]],\n", + " \n", + " \n", + " [[[ 0.1522, -0.0012, -0.1785],\n", + " [-0.1833, -0.1828, -0.1643],\n", + " [-0.1765, -0.1757, -0.0608]],\n", + " \n", + " [[-0.0684, 0.0521, 0.1137],\n", + " [-0.0028, 0.0616, 0.0758],\n", + " [-0.1736, 0.0667, 0.1229]],\n", + " \n", + " [[ 0.1298, -0.1848, -0.1570],\n", + " [-0.1052, -0.1172, -0.1223],\n", + " [-0.1389, -0.0095, -0.0410]]],\n", + " \n", + " \n", + " [[[ 0.0213, -0.0975, 0.0964],\n", + " [ 0.0535, -0.0775, 0.0790],\n", + " [-0.1796, -0.1468, 0.1036]],\n", + " \n", + " [[-0.0403, 0.0646, -0.0932],\n", + " [ 0.1779, -0.1616, 0.0644],\n", + " [-0.0508, -0.1158, -0.0592]],\n", + " \n", + " [[-0.1644, -0.1327, 0.0817],\n", + " [ 0.0320, -0.0213, -0.0946],\n", + " [-0.1106, 0.1463, -0.1642]]],\n", + " \n", + " \n", + " [[[-0.0985, -0.1160, -0.0915],\n", + " [ 0.1857, 0.0806, 0.1761],\n", + " [-0.0817, 0.1095, 0.0896]],\n", + " \n", + " [[-0.0660, -0.1680, 0.1833],\n", + " [ 0.0611, 0.0077, -0.0848],\n", + " [-0.1516, 0.1737, 0.0484]],\n", + " \n", + " [[ 0.1434, -0.0732, -0.0904],\n", + " [ 0.0962, 0.1783, 0.0192],\n", + " [ 0.0915, 0.0006, 0.0334]]],\n", + " \n", + " \n", + " [[[-0.0047, 0.1807, -0.1798],\n", + " [-0.0164, 0.1119, -0.0805],\n", + " [ 0.1855, -0.0681, -0.0187]],\n", + " \n", + " [[-0.0069, 0.0491, -0.1868],\n", + " [-0.1609, -0.0316, 0.0150],\n", + " [-0.1605, 0.1506, -0.0074]],\n", + " \n", + " [[ 0.0851, -0.1732, -0.1777],\n", + " [ 0.0539, -0.0500, -0.1231],\n", + " [ 0.1654, 0.0342, -0.1904]]],\n", + " \n", + " \n", + " [[[ 0.0476, 0.0284, 0.1212],\n", + " [-0.1603, -0.1924, 0.0144],\n", + " [ 0.0076, -0.0928, -0.1645]],\n", + " \n", + " [[ 0.0215, 0.1845, -0.1034],\n", + " [ 0.1574, -0.1577, -0.0438],\n", + " [-0.1360, -0.0601, -0.1693]],\n", + " \n", + " [[-0.0720, 0.0619, 0.1405],\n", + " [ 0.0699, -0.1288, 0.0041],\n", + " [-0.0381, -0.1697, -0.1568]]],\n", + " \n", + " \n", + " [[[-0.1599, 0.1231, -0.1034],\n", + " [-0.0314, 0.0105, -0.1449],\n", + " [-0.0172, -0.0781, 0.0839]],\n", + " \n", + " [[-0.0676, 0.1185, -0.1559],\n", + " [-0.1053, -0.1306, 0.1820],\n", + " [ 0.1584, -0.1370, 0.1828]],\n", + " \n", + " [[ 0.0658, 0.1412, -0.0537],\n", + " [-0.1230, -0.1411, -0.0011],\n", + " [-0.1318, -0.0458, 0.1838]]],\n", + " \n", + " \n", + " [[[-0.0268, 0.1747, -0.1037],\n", + " [ 0.0515, -0.0228, -0.1024],\n", + " [-0.1543, -0.0643, -0.0100]],\n", + " \n", + " [[-0.1572, -0.1530, 0.0026],\n", + " [ 0.1463, -0.1233, 0.0470],\n", + " [-0.1595, -0.1108, -0.0654]],\n", + " \n", + " [[-0.0521, -0.0094, 0.1544],\n", + " [-0.0505, -0.0332, 0.0048],\n", + " [ 0.0735, 0.1350, 0.0690]]],\n", + " \n", + " \n", + " [[[ 0.0025, 0.0724, 0.0930],\n", + " [-0.1885, 0.0475, 0.1100],\n", + " [-0.1622, 0.0087, -0.0030]],\n", + " \n", + " [[ 0.1032, -0.1425, -0.0620],\n", + " [ 0.1515, -0.0736, -0.1888],\n", + " [-0.1246, 0.1424, -0.0491]],\n", + " \n", + " [[ 0.1759, -0.1616, 0.1198],\n", + " [-0.1103, 0.1032, 0.1727],\n", + " [-0.0601, 0.1635, 0.0034]]],\n", + " \n", + " \n", + " [[[ 0.0301, 0.1517, 0.0657],\n", + " [-0.1368, -0.1165, 0.1193],\n", + " [-0.0962, 0.1451, 0.1099]],\n", + " \n", + " [[ 0.1646, -0.1860, -0.1187],\n", + " [-0.1367, -0.0911, 0.1337],\n", + " [-0.0926, -0.0524, -0.0672]],\n", + " \n", + " [[-0.1509, -0.1231, -0.0855],\n", + " [ 0.1808, -0.0713, 0.0410],\n", + " [-0.0621, -0.0506, 0.1871]]],\n", + " \n", + " \n", + " [[[ 0.0888, -0.0874, -0.0826],\n", + " [ 0.0416, -0.0961, 0.0603],\n", + " [ 0.1455, 0.0050, 0.0318]],\n", + " \n", + " [[-0.1633, 0.0070, -0.1537],\n", + " [-0.0109, 0.1602, -0.0463],\n", + " [-0.0423, -0.0147, -0.1045]],\n", + " \n", + " [[ 0.1640, -0.0997, -0.1662],\n", + " [-0.1074, 0.1549, -0.1905],\n", + " [-0.1708, 0.1624, 0.0219]]],\n", + " \n", + " \n", + " [[[ 0.0824, -0.1376, 0.1086],\n", + " [ 0.0836, 0.0135, 0.0351],\n", + " [-0.1518, 0.0784, -0.1708]],\n", + " \n", + " [[-0.1636, -0.1571, 0.1032],\n", + " [-0.1152, 0.0274, -0.1022],\n", + " [-0.0956, -0.1606, -0.1615]],\n", + " \n", + " [[ 0.1307, 0.0419, 0.1924],\n", + " [-0.0599, -0.1296, -0.0448],\n", + " [ 0.0363, 0.0377, -0.0460]]],\n", + " \n", + " \n", + " [[[-0.1685, -0.1277, -0.0465],\n", + " [ 0.0922, -0.1011, 0.0742],\n", + " [ 0.0053, -0.1456, 0.0135]],\n", + " \n", + " [[ 0.1341, 0.0131, 0.1281],\n", + " [-0.1020, 0.1069, -0.0631],\n", + " [-0.0439, -0.1189, -0.1822]],\n", + " \n", + " [[ 0.1624, -0.1253, 0.0302],\n", + " [ 0.0709, 0.0767, 0.1453],\n", + " [ 0.0203, 0.1603, -0.1720]]],\n", + " \n", + " \n", + " [[[-0.1550, 0.1513, -0.1003],\n", + " [ 0.0370, 0.0367, -0.0233],\n", + " [ 0.0916, -0.0871, 0.1579]],\n", + " \n", + " [[-0.1900, 0.0314, 0.0865],\n", + " [-0.0197, 0.0296, -0.0048],\n", + " [ 0.0846, 0.1543, -0.0770]],\n", + " \n", + " [[-0.0016, -0.0978, 0.1826],\n", + " [-0.0477, 0.0689, 0.1079],\n", + " [ 0.0400, 0.0880, 0.1674]]],\n", + " \n", + " \n", + " [[[ 0.0145, -0.0447, -0.1742],\n", + " [ 0.0394, 0.0127, -0.1172],\n", + " [ 0.1330, -0.1207, 0.0326]],\n", + " \n", + " [[-0.0155, -0.1602, 0.0023],\n", + " [ 0.0789, 0.1648, 0.1781],\n", + " [-0.1468, -0.0481, -0.1260]],\n", + " \n", + " [[-0.0139, 0.0848, -0.0536],\n", + " [-0.1581, 0.1130, 0.0717],\n", + " [ 0.0275, -0.0006, -0.0049]]],\n", + " \n", + " \n", + " [[[-0.0199, 0.0032, -0.1246],\n", + " [ 0.0479, 0.1418, -0.1295],\n", + " [-0.1646, -0.1139, -0.1018]],\n", + " \n", + " [[ 0.1475, 0.1413, -0.0354],\n", + " [ 0.0612, -0.1652, 0.0801],\n", + " [-0.1306, -0.0165, 0.1733]],\n", + " \n", + " [[ 0.1527, 0.0911, -0.1906],\n", + " [-0.1152, 0.1737, 0.0436],\n", + " [-0.0213, -0.0314, -0.0319]]],\n", + " \n", + " \n", + " [[[-0.0003, -0.0546, -0.1255],\n", + " [ 0.0914, -0.1414, 0.0542],\n", + " [ 0.1139, 0.0132, 0.0815]],\n", + " \n", + " [[-0.0042, 0.0541, 0.1456],\n", + " [ 0.0509, -0.0790, 0.0272],\n", + " [ 0.1419, 0.0992, -0.1448]],\n", + " \n", + " [[ 0.0496, 0.0013, 0.0838],\n", + " [-0.0662, 0.0315, -0.1168],\n", + " [-0.0069, -0.1503, 0.0729]]],\n", + " \n", + " \n", + " [[[ 0.1866, 0.1329, -0.0560],\n", + " [ 0.0026, 0.1533, 0.0326],\n", + " [-0.1161, -0.0323, 0.0053]],\n", + " \n", + " [[-0.0243, -0.1823, -0.1657],\n", + " [-0.0107, -0.0832, 0.0029],\n", + " [ 0.0981, 0.1241, -0.1788]],\n", + " \n", + " [[-0.0400, -0.0577, -0.0757],\n", + " [-0.0584, 0.0176, -0.1019],\n", + " [-0.1828, 0.1589, -0.0312]]],\n", + " \n", + " \n", + " [[[-0.1083, -0.1236, -0.0904],\n", + " [-0.1575, 0.0157, 0.0552],\n", + " [-0.0839, 0.1704, -0.1457]],\n", + " \n", + " [[-0.1648, -0.0270, -0.0489],\n", + " [-0.1122, -0.0288, -0.0073],\n", + " [-0.1443, -0.1712, 0.0100]],\n", + " \n", + " [[-0.1142, -0.1552, 0.1568],\n", + " [ 0.0743, -0.1108, -0.0643],\n", + " [-0.0394, -0.1345, 0.0992]]],\n", + " \n", + " \n", + " [[[-0.1591, 0.0942, -0.1035],\n", + " [-0.0781, 0.0725, -0.0888],\n", + " [ 0.0959, 0.0213, 0.1222]],\n", + " \n", + " [[ 0.1202, -0.0217, -0.0955],\n", + " [-0.1748, -0.1133, -0.0704],\n", + " [-0.0670, -0.1401, 0.1553]],\n", + " \n", + " [[ 0.0053, -0.0871, -0.0239],\n", + " [ 0.0961, -0.0547, 0.1741],\n", + " [-0.0570, 0.0477, 0.1853]]],\n", + " \n", + " \n", + " [[[-0.1115, -0.0183, -0.1302],\n", + " [ 0.1435, -0.0238, -0.0048],\n", + " [ 0.1862, -0.1837, 0.1711]],\n", + " \n", + " [[ 0.1375, -0.1798, 0.0818],\n", + " [-0.0792, 0.0820, 0.1373],\n", + " [ 0.1849, 0.0672, -0.1822]],\n", + " \n", + " [[ 0.1868, -0.0356, 0.0726],\n", + " [-0.1523, -0.1130, 0.1506],\n", + " [-0.1046, 0.0178, 0.0990]]],\n", + " \n", + " \n", + " [[[ 0.1321, -0.1641, 0.0411],\n", + " [ 0.0526, 0.0393, 0.0918],\n", + " [-0.1345, 0.0750, 0.0859]],\n", + " \n", + " [[-0.0985, 0.1466, 0.1349],\n", + " [-0.1461, -0.1742, 0.0941],\n", + " [-0.1502, -0.1813, 0.0864]],\n", + " \n", + " [[-0.1039, 0.1179, 0.1499],\n", + " [-0.0366, -0.0120, 0.0951],\n", + " [ 0.0087, 0.1212, -0.0183]]],\n", + " \n", + " \n", + " [[[-0.1375, 0.0765, -0.0072],\n", + " [-0.0041, 0.0379, -0.0243],\n", + " [-0.1495, 0.1601, 0.1575]],\n", + " \n", + " [[-0.0454, 0.1642, 0.0720],\n", + " [-0.0533, 0.0150, 0.0039],\n", + " [ 0.0194, 0.0113, -0.1194]],\n", + " \n", + " [[ 0.0527, -0.0886, 0.0359],\n", + " [ 0.1595, 0.0526, -0.0048],\n", + " [-0.1790, -0.0458, -0.0324]]],\n", + " \n", + " \n", + " [[[-0.1638, 0.0942, 0.0686],\n", + " [-0.1082, -0.0675, 0.1892],\n", + " [-0.1347, -0.1247, 0.0739]],\n", + " \n", + " [[ 0.0595, 0.1504, -0.1657],\n", + " [ 0.0733, 0.0529, -0.1599],\n", + " [ 0.0171, -0.1127, -0.0259]],\n", + " \n", + " [[-0.0092, 0.0193, 0.1176],\n", + " [-0.1183, 0.0101, 0.1011],\n", + " [ 0.0648, -0.1897, 0.0782]]],\n", + " \n", + " \n", + " [[[ 0.0098, -0.1161, -0.0802],\n", + " [-0.1821, 0.0221, -0.1754],\n", + " [-0.1218, 0.0525, -0.0480]],\n", + " \n", + " [[ 0.0770, 0.0477, 0.1514],\n", + " [ 0.0374, -0.1075, -0.1026],\n", + " [-0.0581, -0.1011, 0.1241]],\n", + " \n", + " [[-0.0567, -0.0163, 0.0374],\n", + " [-0.1739, -0.0579, 0.0704],\n", + " [ 0.1817, 0.1561, 0.1677]]],\n", + " \n", + " \n", + " [[[-0.0569, -0.0763, 0.0044],\n", + " [-0.1133, 0.0813, 0.1477],\n", + " [ 0.0836, 0.0483, -0.1800]],\n", + " \n", + " [[ 0.1343, -0.1590, 0.1177],\n", + " [ 0.1071, -0.1647, -0.0646],\n", + " [ 0.1578, -0.1261, 0.0243]],\n", + " \n", + " [[-0.0424, -0.0241, -0.0988],\n", + " [ 0.0023, 0.0029, -0.0291],\n", + " [ 0.0415, -0.0557, 0.1427]]],\n", + " \n", + " \n", + " [[[-0.1028, 0.1054, 0.1658],\n", + " [-0.0357, 0.1579, 0.1237],\n", + " [ 0.0368, 0.0532, -0.1043]],\n", + " \n", + " [[-0.0369, 0.0575, -0.1023],\n", + " [ 0.0635, 0.1015, 0.1112],\n", + " [-0.1235, 0.0467, 0.0908]],\n", + " \n", + " [[ 0.1380, 0.0633, 0.1087],\n", + " [-0.1360, 0.0422, -0.1524],\n", + " [ 0.0819, 0.0918, -0.1624]]],\n", + " \n", + " \n", + " [[[ 0.1584, -0.0218, -0.0236],\n", + " [ 0.1878, -0.1289, 0.1343],\n", + " [ 0.0351, 0.1225, -0.1460]],\n", + " \n", + " [[ 0.0690, -0.1439, 0.0056],\n", + " [ 0.0272, -0.0058, 0.0125],\n", + " [ 0.0868, -0.0684, -0.0884]],\n", + " \n", + " [[ 0.1045, 0.0583, -0.0870],\n", + " [ 0.0600, -0.0732, -0.1695],\n", + " [ 0.0953, 0.0246, 0.1245]]]])),\n", + " ('conv1.bias',\n", + " tensor([-1.8698e-01, -7.9379e-06, -1.9277e-02, 5.2182e-02, 7.5716e-02,\n", + " -3.3830e-03, -9.6565e-02, 1.0241e-01, -8.2457e-02, -1.6224e-01,\n", + " 1.2980e-01, -8.2256e-02, -7.4655e-02, -3.7980e-02, 8.3407e-02,\n", + " -1.4880e-01, 4.8939e-02, 2.7506e-02, 5.8676e-03, -1.5813e-01,\n", + " -6.2464e-04, 1.0359e-02, -1.5525e-01, 7.9100e-02, 1.6850e-02,\n", + " -1.3809e-01, -6.3393e-02, -5.3843e-02, -1.5219e-02, -1.7365e-01,\n", + " 1.7249e-01, -1.1165e-01])),\n", + " ('conv2.weight',\n", + " tensor([[[[ 4.3750e-02, 4.5533e-02, -2.9410e-02],\n", + " [-4.1395e-02, 5.0397e-04, -1.3265e-02],\n", + " [-4.9851e-02, -1.0518e-02, 5.7710e-02]],\n", + " \n", + " [[-5.6332e-02, -4.7168e-03, -4.4627e-02],\n", + " [ 5.3513e-03, -4.0824e-02, 1.8281e-02],\n", + " [ 5.0677e-02, -1.5295e-02, -6.1751e-03]],\n", + " \n", + " [[-2.4984e-02, 1.2784e-02, -4.7123e-02],\n", + " [-4.3238e-02, 4.7349e-02, -1.5219e-02],\n", + " [-3.6073e-02, 4.1506e-02, -3.5337e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[ 5.6048e-02, 1.9129e-03, -1.7200e-02],\n", + " [ 5.8869e-02, -5.1520e-02, -5.3205e-02],\n", + " [-1.3903e-02, 5.1790e-02, 2.2585e-02]],\n", + " \n", + " [[ 1.1835e-02, -4.9313e-02, -3.1838e-02],\n", + " [ 7.6813e-03, 4.2715e-02, -5.7404e-02],\n", + " [-4.1474e-02, -2.3128e-02, -4.7935e-02]],\n", + " \n", + " [[-2.1860e-02, -2.1817e-02, -3.2578e-02],\n", + " [ 3.1317e-02, 3.3435e-02, 3.1837e-02],\n", + " [-2.2399e-03, 3.1600e-02, 4.0183e-02]]],\n", + " \n", + " \n", + " [[[ 1.9610e-04, 5.3780e-02, -4.5810e-02],\n", + " [ 4.0340e-02, 1.4904e-02, -1.5597e-02],\n", + " [-4.6080e-02, 5.0714e-02, -5.7445e-03]],\n", + " \n", + " [[-3.5281e-02, 3.3011e-02, 4.3343e-02],\n", + " [-4.6263e-02, -5.6184e-02, 5.1245e-03],\n", + " [ 3.6015e-02, -3.3152e-02, 4.6629e-03]],\n", + " \n", + " [[ 1.7650e-03, -4.2336e-02, 4.3744e-02],\n", + " [ 2.1655e-02, 5.3759e-02, 1.3719e-03],\n", + " [ 4.2005e-02, 5.3998e-02, 1.9009e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[ 8.9786e-03, -1.8645e-02, -1.3587e-02],\n", + " [-5.4167e-02, 1.2335e-02, -3.0384e-02],\n", + " [-4.8722e-03, -3.7296e-02, -2.6446e-02]],\n", + " \n", + " [[ 1.7580e-02, 3.8462e-02, -5.0269e-02],\n", + " [ 2.6601e-03, -1.1462e-02, 4.7459e-02],\n", + " [-2.8888e-02, 3.4436e-02, -4.9943e-02]],\n", + " \n", + " [[-5.0206e-02, -5.6025e-02, -3.6346e-02],\n", + " [-2.4407e-02, 5.3721e-02, -5.4920e-02],\n", + " [ 5.1835e-02, -3.2396e-02, 3.2373e-02]]],\n", + " \n", + " \n", + " [[[-2.7759e-02, -5.4737e-02, -1.1689e-02],\n", + " [ 3.9462e-02, 2.8649e-02, 5.1776e-02],\n", + " [ 2.4253e-02, -2.8318e-02, 2.7402e-02]],\n", + " \n", + " [[ 1.3045e-02, -1.0456e-02, 2.0426e-02],\n", + " [ 2.1949e-02, 4.6817e-02, -5.6093e-02],\n", + " [ 2.7145e-02, -5.5441e-02, -2.0719e-02]],\n", + " \n", + " [[ 4.4704e-02, -2.4099e-02, -4.7185e-02],\n", + " [-4.3257e-02, -3.3058e-02, -8.6451e-03],\n", + " [-3.7283e-02, -3.4569e-02, -7.1049e-03]],\n", + " \n", + " ...,\n", + " \n", + " [[-4.1559e-02, -2.9240e-02, 2.7197e-03],\n", + " [ 2.0770e-02, 5.4479e-02, -4.4845e-02],\n", + " [-1.1641e-02, -2.9814e-02, -2.4419e-02]],\n", + " \n", + " [[-1.5743e-02, 1.0854e-02, 3.0878e-02],\n", + " [ 2.2739e-02, 3.2999e-02, -1.1902e-02],\n", + " [-3.4837e-02, 1.5305e-02, -8.7552e-03]],\n", + " \n", + " [[-2.2882e-02, 9.4639e-03, 5.1878e-03],\n", + " [-2.6344e-02, 2.9063e-02, -1.9337e-02],\n", + " [-3.4314e-02, 1.5313e-02, 4.1524e-02]]],\n", + " \n", + " \n", + " ...,\n", + " \n", + " \n", + " [[[ 9.7454e-03, -3.2903e-03, 1.0696e-02],\n", + " [-4.0918e-02, 1.6352e-02, 1.4646e-02],\n", + " [ 1.2516e-02, -2.1804e-02, -2.5489e-02]],\n", + " \n", + " [[-1.6083e-02, 2.5374e-02, 3.1458e-02],\n", + " [-3.1497e-02, -1.9513e-02, -2.1223e-02],\n", + " [ 6.6286e-03, 1.6538e-02, -4.8944e-02]],\n", + " \n", + " [[ 2.4808e-02, -2.9520e-02, -4.8227e-02],\n", + " [ 1.7325e-03, -4.7443e-02, 2.3087e-03],\n", + " [-1.0008e-02, -2.0313e-02, 2.9944e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[ 5.5781e-02, -2.0898e-02, -3.1487e-03],\n", + " [-1.6931e-02, 4.5279e-04, -1.5024e-02],\n", + " [-5.5885e-02, 2.7140e-02, -8.5434e-03]],\n", + " \n", + " [[ 1.3970e-02, -3.3131e-02, 4.3112e-02],\n", + " [-3.4956e-02, -5.0144e-02, -1.6391e-02],\n", + " [-9.1003e-03, -2.0204e-02, -1.0226e-03]],\n", + " \n", + " [[-4.0053e-02, -5.0194e-02, 5.0405e-02],\n", + " [ 5.4107e-02, 4.2185e-02, 3.4359e-02],\n", + " [ 1.6749e-02, -1.4102e-02, 5.0171e-02]]],\n", + " \n", + " \n", + " [[[ 2.8229e-02, 5.6586e-02, -3.9617e-03],\n", + " [ 2.9538e-02, -1.2507e-02, -2.5516e-02],\n", + " [-1.5193e-02, -2.9232e-02, -2.0701e-02]],\n", + " \n", + " [[-5.8773e-02, 3.3015e-02, -9.4146e-03],\n", + " [ 2.8957e-02, 5.8666e-02, 2.8679e-02],\n", + " [ 1.5249e-02, -1.2246e-03, 1.2230e-03]],\n", + " \n", + " [[ 2.6050e-02, -4.6042e-02, -3.4895e-03],\n", + " [ 4.9529e-02, 6.6835e-03, -4.1808e-02],\n", + " [-8.6450e-03, -4.8510e-02, -2.4011e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[-5.1427e-02, 2.4528e-02, -5.4878e-02],\n", + " [-1.8610e-02, 5.4365e-02, 3.5053e-03],\n", + " [-3.9922e-02, 4.2510e-02, -5.7261e-02]],\n", + " \n", + " [[ 4.1938e-02, -4.2039e-02, -1.2487e-02],\n", + " [-1.4090e-02, -3.7895e-02, 1.4394e-02],\n", + " [ 2.2555e-02, -2.7264e-02, 5.6102e-02]],\n", + " \n", + " [[ 1.5770e-02, 5.4672e-02, -2.4056e-02],\n", + " [ 5.2089e-02, -2.8859e-02, -2.6499e-03],\n", + " [-5.2122e-02, -3.7436e-02, 3.9897e-02]]],\n", + " \n", + " \n", + " [[[ 2.7888e-02, 2.9241e-02, -1.9488e-02],\n", + " [ 2.8928e-02, 5.3312e-02, -2.9810e-02],\n", + " [-8.5104e-03, 5.7751e-02, -8.1857e-03]],\n", + " \n", + " [[ 4.2649e-02, 3.3158e-03, 4.2879e-02],\n", + " [ 7.7893e-03, -3.2879e-02, 2.7630e-02],\n", + " [ 5.4706e-03, 4.8019e-02, 1.2420e-02]],\n", + " \n", + " [[-4.2004e-02, -4.2790e-02, 2.4634e-02],\n", + " [-5.4641e-02, 3.4600e-02, 2.9071e-03],\n", + " [ 2.6470e-02, 4.6701e-02, 3.7158e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[-1.7641e-02, -2.1205e-02, -5.1504e-02],\n", + " [-7.4737e-03, 5.5061e-02, -2.6397e-03],\n", + " [-4.4653e-02, -3.6719e-02, 3.4420e-06]],\n", + " \n", + " [[ 1.6525e-02, 1.7280e-02, 5.4554e-03],\n", + " [ 4.0098e-02, 2.7571e-02, -4.4965e-02],\n", + " [ 6.1493e-03, -5.7754e-02, 1.0513e-02]],\n", + " \n", + " [[-5.7615e-02, 3.2921e-02, -1.5900e-02],\n", + " [ 2.0081e-02, 5.4590e-02, 1.1296e-02],\n", + " [-4.5015e-02, 1.1341e-03, 2.6447e-02]]]])),\n", + " ('conv2.bias',\n", + " tensor([-0.0538, -0.0320, -0.0153, 0.0558, 0.0254, 0.0281, -0.0148, 0.0060,\n", + " -0.0283, -0.0062, 0.0437, -0.0064, 0.0341, 0.0233, -0.0201, 0.0391,\n", + " 0.0243, 0.0071, 0.0125, -0.0138, -0.0377, -0.0169, -0.0475, -0.0004,\n", + " -0.0105, -0.0502, 0.0241, 0.0090, 0.0069, -0.0315, -0.0192, 0.0204])),\n", + " ('conv3.weight',\n", + " tensor([[[[-0.0215, -0.0208, -0.0272],\n", + " [-0.0493, -0.0117, -0.0285],\n", + " [ 0.0515, -0.0041, 0.0126]],\n", + " \n", + " [[ 0.0299, -0.0301, 0.0552],\n", + " [ 0.0450, 0.0449, -0.0583],\n", + " [-0.0452, -0.0480, -0.0275]],\n", + " \n", + " [[-0.0262, -0.0338, 0.0505],\n", + " [ 0.0146, -0.0364, -0.0044],\n", + " [-0.0102, -0.0051, 0.0017]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0367, -0.0468, -0.0586],\n", + " [ 0.0126, 0.0037, 0.0191],\n", + " [-0.0153, 0.0048, -0.0160]],\n", + " \n", + " [[-0.0050, 0.0364, 0.0582],\n", + " [ 0.0093, -0.0268, -0.0355],\n", + " [-0.0125, 0.0500, 0.0009]],\n", + " \n", + " [[ 0.0237, -0.0211, -0.0130],\n", + " [-0.0489, 0.0118, 0.0387],\n", + " [-0.0006, 0.0301, 0.0283]]],\n", + " \n", + " \n", + " [[[-0.0391, -0.0464, -0.0158],\n", + " [ 0.0201, -0.0054, 0.0422],\n", + " [ 0.0085, -0.0474, -0.0251]],\n", + " \n", + " [[-0.0346, 0.0536, -0.0391],\n", + " [ 0.0244, -0.0263, -0.0073],\n", + " [ 0.0076, 0.0160, 0.0044]],\n", + " \n", + " [[-0.0128, 0.0146, -0.0381],\n", + " [-0.0277, -0.0142, 0.0226],\n", + " [ 0.0190, 0.0326, -0.0219]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0217, 0.0129, 0.0558],\n", + " [ 0.0164, -0.0292, -0.0467],\n", + " [-0.0296, 0.0205, -0.0300]],\n", + " \n", + " [[ 0.0254, -0.0151, -0.0583],\n", + " [ 0.0111, -0.0469, -0.0300],\n", + " [-0.0462, 0.0293, -0.0351]],\n", + " \n", + " [[ 0.0401, -0.0251, 0.0160],\n", + " [-0.0160, -0.0195, -0.0065],\n", + " [-0.0519, 0.0351, 0.0357]]],\n", + " \n", + " \n", + " [[[ 0.0544, -0.0209, -0.0454],\n", + " [ 0.0287, -0.0205, -0.0294],\n", + " [-0.0195, -0.0235, -0.0378]],\n", + " \n", + " [[-0.0294, -0.0380, 0.0301],\n", + " [ 0.0360, 0.0367, 0.0458],\n", + " [-0.0189, -0.0017, 0.0145]],\n", + " \n", + " [[-0.0297, 0.0567, 0.0276],\n", + " [ 0.0298, 0.0383, 0.0227],\n", + " [ 0.0262, 0.0063, 0.0131]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.0002, 0.0432, -0.0247],\n", + " [ 0.0068, -0.0298, -0.0484],\n", + " [-0.0361, -0.0014, 0.0444]],\n", + " \n", + " [[-0.0184, -0.0201, -0.0163],\n", + " [-0.0466, 0.0255, -0.0244],\n", + " [ 0.0283, 0.0149, -0.0588]],\n", + " \n", + " [[ 0.0323, 0.0392, -0.0254],\n", + " [ 0.0560, 0.0137, -0.0401],\n", + " [-0.0236, 0.0589, 0.0448]]],\n", + " \n", + " \n", + " ...,\n", + " \n", + " \n", + " [[[ 0.0213, 0.0204, 0.0574],\n", + " [-0.0276, -0.0196, 0.0117],\n", + " [ 0.0569, -0.0158, -0.0502]],\n", + " \n", + " [[ 0.0452, -0.0038, 0.0502],\n", + " [ 0.0428, -0.0398, -0.0486],\n", + " [ 0.0130, 0.0563, 0.0576]],\n", + " \n", + " [[ 0.0484, -0.0535, 0.0048],\n", + " [ 0.0268, -0.0290, -0.0390],\n", + " [ 0.0189, -0.0194, -0.0588]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.0163, -0.0113, -0.0520],\n", + " [ 0.0288, -0.0547, -0.0544],\n", + " [ 0.0442, 0.0376, 0.0566]],\n", + " \n", + " [[-0.0343, 0.0569, 0.0438],\n", + " [-0.0403, -0.0372, -0.0532],\n", + " [ 0.0322, 0.0126, 0.0423]],\n", + " \n", + " [[ 0.0577, 0.0136, -0.0480],\n", + " [-0.0293, -0.0348, 0.0342],\n", + " [-0.0510, -0.0078, -0.0042]]],\n", + " \n", + " \n", + " [[[-0.0243, 0.0406, 0.0537],\n", + " [ 0.0209, -0.0059, -0.0487],\n", + " [-0.0425, 0.0339, 0.0444]],\n", + " \n", + " [[ 0.0465, -0.0467, 0.0461],\n", + " [-0.0389, 0.0144, -0.0502],\n", + " [ 0.0274, 0.0552, 0.0356]],\n", + " \n", + " [[-0.0289, 0.0474, -0.0217],\n", + " [ 0.0472, -0.0135, 0.0164],\n", + " [-0.0165, -0.0049, -0.0475]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0377, 0.0267, 0.0367],\n", + " [ 0.0111, 0.0114, -0.0329],\n", + " [ 0.0031, -0.0223, -0.0280]],\n", + " \n", + " [[-0.0500, -0.0529, 0.0116],\n", + " [ 0.0483, 0.0121, -0.0149],\n", + " [ 0.0328, 0.0201, 0.0402]],\n", + " \n", + " [[ 0.0463, 0.0157, -0.0332],\n", + " [ 0.0150, 0.0479, 0.0461],\n", + " [ 0.0275, 0.0506, -0.0466]]],\n", + " \n", + " \n", + " [[[-0.0478, 0.0274, 0.0500],\n", + " [-0.0394, 0.0032, -0.0496],\n", + " [ 0.0381, 0.0391, 0.0330]],\n", + " \n", + " [[ 0.0184, -0.0560, -0.0345],\n", + " [-0.0459, -0.0215, 0.0452],\n", + " [ 0.0049, 0.0537, 0.0544]],\n", + " \n", + " [[-0.0413, -0.0084, 0.0585],\n", + " [ 0.0338, -0.0067, -0.0113],\n", + " [-0.0187, -0.0234, -0.0525]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0389, 0.0325, -0.0538],\n", + " [ 0.0118, 0.0509, 0.0352],\n", + " [-0.0351, -0.0341, -0.0506]],\n", + " \n", + " [[ 0.0136, -0.0349, 0.0082],\n", + " [ 0.0358, -0.0211, 0.0537],\n", + " [-0.0183, 0.0390, -0.0267]],\n", + " \n", + " [[-0.0219, -0.0145, -0.0351],\n", + " [ 0.0556, 0.0033, -0.0030],\n", + " [ 0.0075, -0.0425, -0.0365]]]])),\n", + " ('conv3.bias',\n", + " tensor([ 0.0304, 0.0099, -0.0004, 0.0334, 0.0301, 0.0491, 0.0530, -0.0432,\n", + " -0.0127, -0.0549, -0.0419, 0.0159, -0.0284, 0.0295, -0.0148, 0.0275,\n", + " 0.0554, -0.0056, 0.0389, -0.0264, -0.0383, 0.0126, 0.0320, 0.0312,\n", + " 0.0018, 0.0560, -0.0329, -0.0155, -0.0391, -0.0539, -0.0571, -0.0254])),\n", + " ('conv4.weight',\n", + " tensor([[[[-3.8911e-02, -3.4220e-02, 4.2567e-03],\n", + " [-4.5321e-02, -5.2531e-02, -8.1722e-03],\n", + " [-2.2638e-02, 4.4213e-02, 5.6989e-02]],\n", + " \n", + " [[ 1.8417e-03, -1.4453e-02, 4.9892e-02],\n", + " [ 5.7762e-02, 9.6610e-03, -3.9509e-02],\n", + " [ 3.3795e-02, 5.0409e-02, -5.8834e-02]],\n", + " \n", + " [[ 4.6645e-03, -1.6286e-02, 4.3410e-02],\n", + " [-3.4043e-02, -2.2207e-02, 4.0967e-02],\n", + " [ 5.3004e-02, -2.2756e-02, -6.7993e-03]],\n", + " \n", + " ...,\n", + " \n", + " [[ 1.1741e-02, -5.5062e-02, -3.3625e-02],\n", + " [-9.2320e-03, -3.3036e-02, 3.3196e-02],\n", + " [ 2.3940e-02, 2.0442e-02, 1.4183e-02]],\n", + " \n", + " [[-2.7139e-02, -3.4129e-03, -1.0090e-02],\n", + " [ 1.3073e-02, -1.6998e-02, -4.7540e-02],\n", + " [-2.5758e-02, -1.9363e-02, 2.1905e-02]],\n", + " \n", + " [[ 5.7593e-02, 1.5013e-02, -5.7894e-02],\n", + " [ 5.7964e-02, -2.3412e-02, 2.6955e-02],\n", + " [-3.9814e-02, -4.6015e-02, -5.3240e-02]]],\n", + " \n", + " \n", + " [[[ 5.8211e-02, -4.1118e-02, 2.7704e-02],\n", + " [ 5.7198e-02, 8.4165e-03, -5.1708e-02],\n", + " [ 3.1423e-02, 1.5026e-02, 3.5922e-02]],\n", + " \n", + " [[ 8.8858e-03, 3.2818e-02, 5.4486e-02],\n", + " [-2.6636e-02, 2.2604e-02, 2.9531e-02],\n", + " [-1.0327e-03, 2.2348e-03, 2.4103e-02]],\n", + " \n", + " [[ 3.8683e-02, -5.0057e-03, 5.0224e-02],\n", + " [ 3.5756e-02, -2.7295e-02, -2.2854e-02],\n", + " [-3.2043e-02, -3.2415e-02, 4.1034e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[-1.9791e-02, 4.3243e-02, -3.5177e-02],\n", + " [ 2.4554e-02, 4.2845e-03, 4.8009e-02],\n", + " [ 2.4897e-03, 3.9550e-02, -3.0833e-02]],\n", + " \n", + " [[ 4.5807e-02, 6.5845e-03, 9.3362e-05],\n", + " [-1.9411e-02, -2.9161e-02, 5.0828e-02],\n", + " [ 1.2028e-03, 2.1260e-02, -4.3710e-03]],\n", + " \n", + " [[-4.8702e-02, -2.0571e-02, -3.5162e-02],\n", + " [-2.5856e-02, -2.6619e-02, 8.1867e-03],\n", + " [-2.7671e-02, -9.6651e-03, -5.3279e-02]]],\n", + " \n", + " \n", + " [[[-5.6432e-02, 2.3722e-02, -2.1750e-02],\n", + " [-4.8247e-03, -2.1226e-02, -1.0829e-02],\n", + " [-1.9523e-02, -1.8187e-02, 2.2772e-03]],\n", + " \n", + " [[-1.0907e-02, 3.3984e-02, -1.2088e-02],\n", + " [-1.8657e-02, -4.8297e-02, 2.3614e-02],\n", + " [-3.9670e-02, 6.1733e-03, -2.9168e-02]],\n", + " \n", + " [[-4.2112e-02, 2.8203e-02, -1.7385e-03],\n", + " [-2.5282e-02, -9.4592e-05, 6.5093e-03],\n", + " [-4.1745e-02, 4.3988e-03, -1.1622e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[ 4.2991e-02, 1.6912e-02, -4.3689e-02],\n", + " [ 5.1871e-02, 4.8566e-02, 3.6205e-02],\n", + " [-3.2016e-02, -1.3596e-02, -2.7950e-02]],\n", + " \n", + " [[-2.8307e-02, -4.0278e-02, 1.5087e-02],\n", + " [ 4.0443e-02, -3.5727e-02, 3.7196e-02],\n", + " [-1.4194e-02, -2.7319e-02, -5.1305e-02]],\n", + " \n", + " [[-2.9962e-02, 2.4693e-02, -4.4912e-02],\n", + " [ 5.5890e-03, 4.6671e-02, 3.3599e-02],\n", + " [-3.9949e-02, -4.4716e-02, 2.2345e-02]]],\n", + " \n", + " \n", + " ...,\n", + " \n", + " \n", + " [[[ 3.1920e-02, -4.9932e-02, -1.0871e-02],\n", + " [-3.7500e-02, 4.1638e-02, -1.3246e-02],\n", + " [ 1.6447e-02, -5.6741e-02, -3.7524e-02]],\n", + " \n", + " [[ 3.3903e-02, -3.1321e-02, -4.4877e-02],\n", + " [-2.2473e-02, -2.4225e-02, 4.5838e-02],\n", + " [-2.0069e-02, 3.8338e-02, 5.8010e-02]],\n", + " \n", + " [[ 1.7602e-02, -5.2530e-02, 4.9331e-02],\n", + " [ 2.4509e-02, 2.3943e-02, -2.1774e-02],\n", + " [-5.7154e-02, 5.7090e-02, 3.7531e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[ 2.8630e-02, -4.8644e-04, -5.3822e-02],\n", + " [-1.1102e-02, 4.8524e-02, -2.7142e-02],\n", + " [-5.3463e-02, 3.5607e-02, -1.1110e-02]],\n", + " \n", + " [[ 4.7891e-02, -3.4098e-02, 3.6984e-02],\n", + " [ 3.9062e-02, -5.1119e-03, -3.3252e-02],\n", + " [ 5.5029e-02, 3.1092e-03, -5.9391e-03]],\n", + " \n", + " [[ 9.6775e-03, 2.2903e-02, -1.5971e-02],\n", + " [-2.6969e-02, 8.5069e-04, -2.5744e-02],\n", + " [ 2.7311e-03, -2.2119e-02, 2.4367e-03]]],\n", + " \n", + " \n", + " [[[ 1.5206e-02, -5.3504e-02, -1.8013e-02],\n", + " [-3.0064e-02, 3.2887e-02, 1.7612e-02],\n", + " [-4.2775e-02, -2.7335e-02, 5.1532e-02]],\n", + " \n", + " [[ 4.3325e-02, 1.2909e-03, 5.6831e-02],\n", + " [ 4.8283e-03, -4.4274e-02, 1.7624e-02],\n", + " [-7.2574e-03, 1.3743e-02, -5.1502e-02]],\n", + " \n", + " [[-1.4299e-02, -3.0024e-02, 4.9578e-02],\n", + " [-4.6143e-02, -3.0686e-02, -5.0727e-02],\n", + " [ 5.4766e-02, -1.4012e-02, 3.3267e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[ 5.2773e-02, 4.0294e-02, -2.6113e-02],\n", + " [-2.5069e-02, 4.3956e-02, -4.7841e-02],\n", + " [ 3.0924e-02, 1.5174e-02, -4.8323e-02]],\n", + " \n", + " [[-2.0325e-02, 3.2666e-02, 2.5174e-02],\n", + " [ 5.3775e-03, -3.2712e-02, -5.2251e-02],\n", + " [-2.7426e-02, -4.5502e-04, -3.2174e-02]],\n", + " \n", + " [[ 8.5780e-03, -5.2099e-02, 5.7285e-02],\n", + " [-5.0897e-02, -5.3995e-03, 4.2719e-02],\n", + " [-5.2257e-02, -7.9682e-03, 2.1848e-02]]],\n", + " \n", + " \n", + " [[[-2.8660e-02, 1.2392e-02, 2.9940e-02],\n", + " [-1.1170e-02, -1.7499e-02, -5.3951e-02],\n", + " [ 4.9738e-02, 2.7240e-02, 2.9588e-03]],\n", + " \n", + " [[ 3.7728e-02, -3.1084e-02, 4.6459e-02],\n", + " [-1.4961e-02, 1.6951e-02, -1.3976e-02],\n", + " [ 2.5768e-02, -3.8991e-02, -2.8738e-02]],\n", + " \n", + " [[ 1.6084e-02, -7.3174e-03, -4.5839e-02],\n", + " [-3.6029e-02, 1.5303e-02, -5.4380e-02],\n", + " [ 2.1913e-02, -4.4792e-02, 4.7973e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[ 5.4945e-02, -4.5541e-02, -1.5806e-02],\n", + " [-4.5216e-02, -2.0338e-02, 3.3373e-02],\n", + " [-1.8431e-02, 4.3953e-02, -4.8196e-02]],\n", + " \n", + " [[-3.6129e-02, 5.7705e-02, -1.2229e-02],\n", + " [-5.2801e-02, 7.0930e-03, -2.6721e-02],\n", + " [ 3.6720e-02, 7.5540e-04, 5.5401e-02]],\n", + " \n", + " [[-4.0386e-02, 1.6714e-02, 2.9246e-02],\n", + " [ 4.7033e-02, -8.7923e-03, 5.1267e-02],\n", + " [-3.1211e-02, -2.7036e-02, 4.0346e-02]]]])),\n", + " ('conv4.bias',\n", + " tensor([-0.0368, -0.0535, -0.0568, -0.0364, 0.0167, 0.0255, 0.0031, 0.0562,\n", + " -0.0575, 0.0470, -0.0128, 0.0159, -0.0387, 0.0517, -0.0508, -0.0104,\n", + " 0.0133, -0.0265, 0.0290, 0.0147, -0.0405, 0.0074, 0.0434, -0.0436,\n", + " -0.0295, 0.0417, 0.0324, 0.0316, 0.0143, -0.0312, 0.0463, 0.0441])),\n", + " ('fc1.weight',\n", + " tensor([[-0.0206, -0.0161, -0.0053, ..., -0.0257, -0.0314, -0.0105],\n", + " [-0.0229, 0.0030, 0.0239, ..., 0.0199, -0.0279, 0.0155]])),\n", + " ('fc1.bias', tensor([-0.5000, -1.2000]))])" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "o1.param_groups" + "m2 = m1.state_dict()\n", + "m2['fc1.bias'] = torch.Tensor([-0.5, -1.2])\n", + "m1.load_state_dict(m2)\n", + "m1.state_dict()" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'state': {},\n", + " 'param_groups': [{'lr': 0.6,\n", + " 'momentum': 0,\n", + " 'dampening': 0,\n", + " 'weight_decay': 0,\n", + " 'nesterov': False,\n", + " 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}]}" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "o1.state_dict()" ] }, { @@ -3698,6 +5636,59 @@ "torch.tensor([1,2,3]).shape[0]" ] }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "sizes: [0.2857142857142857, 0.2857142857142857, 0.4285714285714286]\n", + "2 3 \n", + "6 5 \n", + "1 4 7 \n" + ] + } + ], + "source": [ + "from decentralizepy.datasets.Partitioner import DataPartitioner\n", + "l = [1, 2, 3, 4, 5, 6, 7]\n", + "e = len(l) // 3\n", + "frac = e / len(l)\n", + "sizes = [frac] * 3\n", + "sizes[-1] += 1.0 - frac * 3\n", + "print(\"sizes: \", sizes)\n", + "\n", + "for i in range(3):\n", + " myPar = DataPartitioner(l, sizes).use(i)\n", + " for j in range(len(myPar)):\n", + " print(myPar.__getitem__(j), end=' ')\n", + " print()\n", + " \n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.5\n" + ] + } + ], + "source": [ + "w = 1\n", + "p = 2\n", + "i = w/p\n", + "print(i)" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/src/decentralizepy/communication/TCP.py b/src/decentralizepy/communication/TCP.py index 54a1af8..e78a3b1 100644 --- a/src/decentralizepy/communication/TCP.py +++ b/src/decentralizepy/communication/TCP.py @@ -15,6 +15,7 @@ class TCP(Communication): TCP Communication API """ + def addr(self, rank, machine_id): """ Returns TCP address of the process. diff --git a/src/decentralizepy/datasets/Celeba.py b/src/decentralizepy/datasets/Celeba.py index 0b43268..baf3434 100644 --- a/src/decentralizepy/datasets/Celeba.py +++ b/src/decentralizepy/datasets/Celeba.py @@ -138,10 +138,13 @@ class Celeba(Dataset): os.path.join(self.train_dir, cur_file) ) for cur_client in clients: + logging.debug("Got data of client: {}".format(cur_client)) self.clients.append(cur_client) my_train_data["x"].extend(self.process_x(train_data[cur_client]["x"])) my_train_data["y"].extend(train_data[cur_client]["y"]) self.num_samples.append(len(train_data[cur_client]["y"])) + + logging.debug("Initial shape of x: {}".format(np.array(my_train_data["x"], dtype=np.dtype("float32")).shape)) self.train_x = ( np.array(my_train_data["x"], dtype=np.dtype("float32")) .reshape(-1, IMAGE_DIM, IMAGE_DIM, CHANNELS) @@ -409,6 +412,7 @@ class CNN(Model): Class for a CNN Model for Celeba """ + def __init__(self): """ Constructor. Instantiates the CNN Model diff --git a/src/decentralizepy/datasets/Femnist.py b/src/decentralizepy/datasets/Femnist.py index 5d0dc27..c0360c5 100644 --- a/src/decentralizepy/datasets/Femnist.py +++ b/src/decentralizepy/datasets/Femnist.py @@ -413,6 +413,7 @@ class CNN(Model): Class for a CNN Model for FEMNIST """ + def __init__(self): """ Constructor. Instantiates the CNN Model diff --git a/src/decentralizepy/models/Model.py b/src/decentralizepy/models/Model.py index da83402..cf073bc 100644 --- a/src/decentralizepy/models/Model.py +++ b/src/decentralizepy/models/Model.py @@ -7,6 +7,7 @@ class Model(nn.Module): More fields can be added here """ + def __init__(self): """ Constructor diff --git a/src/decentralizepy/node/Node.py b/src/decentralizepy/node/Node.py index 183fbfe..eb7200a 100644 --- a/src/decentralizepy/node/Node.py +++ b/src/decentralizepy/node/Node.py @@ -42,42 +42,20 @@ class Node: plt.title(title) plt.savefig(filename) - def instantiate( - self, - rank: int, - machine_id: int, - mapping: Mapping, - graph: Graph, - config, - iterations=1, - log_dir=".", - log_level=logging.INFO, - test_after=5, - *args - ): + def init_log(self, log_dir, rank, log_level, force=True): """ - Construct objects. + Instantiate Logging. Parameters ---------- - rank : int - Rank of process local to the machine - machine_id : int - Machine ID on which the process in running - n_procs_local : int - Number of processes on current machine - mapping : decentralizepy.mappings - The object containing the mapping rank <--> uid - graph : decentralizepy.graphs - The object containing the global graph - config : dict - A dictionary of configurations. log_dir : str Logging directory + rank : rank : int + Rank of process local to the machine log_level : logging.Level One of DEBUG, INFO, WARNING, ERROR, CRITICAL - args : optional - Other arguments + force : bool + Argument to logging.basicConfig() """ log_file = os.path.join(log_dir, str(rank) + ".log") @@ -88,20 +66,49 @@ class Node: force=True, ) - logging.info("Started process.") + def cache_fields( + self, rank, machine_id, mapping, graph, iterations, log_dir, test_after + ): + """ + Instantiate object field with arguments. + + Parameters + ---------- + rank : int + Rank of process local to the machine + machine_id : int + Machine ID on which the process in running + mapping : decentralizepy.mappings + The object containing the mapping rank <--> uid + graph : decentralizepy.graphs + The object containing the global graph + log_dir : str + Logging directory + """ self.rank = rank self.machine_id = machine_id self.graph = graph self.mapping = mapping self.uid = self.mapping.get_uid(rank, machine_id) self.log_dir = log_dir + self.iterations = iterations + self.test_after = test_after logging.debug("Rank: %d", self.rank) logging.debug("type(graph): %s", str(type(self.rank))) logging.debug("type(mapping): %s", str(type(self.mapping))) - dataset_configs = config["DATASET"] + def init_dataset_model(self, dataset_configs): + """ + Instantiate dataset and model from config. + + Parameters + ---------- + dataset_configs : dict + Python dict containing dataset config params + + """ dataset_module = importlib.import_module(dataset_configs["dataset_package"]) self.dataset_class = getattr(dataset_module, dataset_configs["dataset_class"]) self.dataset_params = utils.remove_keys( @@ -116,7 +123,16 @@ class Node: self.model_class = getattr(dataset_module, dataset_configs["model_class"]) self.model = self.model_class() - optimizer_configs = config["OPTIMIZER_PARAMS"] + def init_optimizer(self, optimizer_configs): + """ + Instantiate optimizer from config. + + Parameters + ---------- + optimizer_configs : dict + Python dict containing optimizer config params + + """ optimizer_module = importlib.import_module( optimizer_configs["optimizer_package"] ) @@ -130,7 +146,16 @@ class Node: self.model.parameters(), **self.optimizer_params ) - train_configs = config["TRAIN_PARAMS"] + def init_trainer(self, train_configs): + """ + Instantiate training module and loss from config. + + Parameters + ---------- + train_configs : dict + Python dict containing training config params + + """ train_module = importlib.import_module(train_configs["training_package"]) train_class = getattr(train_module, train_configs["training_class"]) @@ -155,7 +180,16 @@ class Node: self.model, self.optimizer, self.loss, **train_params ) - comm_configs = config["COMMUNICATION"] + def init_comm(self, comm_configs): + """ + Instantiate communication module from config. + + Parameters + ---------- + comm_configs : dict + Python dict containing communication config params + + """ comm_module = importlib.import_module(comm_configs["comm_package"]) comm_class = getattr(comm_module, comm_configs["comm_class"]) comm_params = utils.remove_keys(comm_configs, ["comm_package", "comm_class"]) @@ -163,7 +197,16 @@ class Node: self.rank, self.machine_id, self.mapping, self.graph.n_procs, **comm_params ) - sharing_configs = config["SHARING"] + def init_sharing(self, sharing_configs): + """ + Instantiate sharing module from config. + + Parameters + ---------- + sharing_configs : dict + Python dict containing sharing config params + + """ sharing_package = importlib.import_module(sharing_configs["sharing_package"]) sharing_class = getattr(sharing_package, sharing_configs["sharing_class"]) sharing_params = utils.remove_keys( @@ -181,9 +224,53 @@ class Node: **sharing_params ) - self.iterations = iterations - self.test_after = test_after - self.log_dir = log_dir + def instantiate( + self, + rank: int, + machine_id: int, + mapping: Mapping, + graph: Graph, + config, + iterations=1, + log_dir=".", + log_level=logging.INFO, + test_after=5, + *args + ): + """ + Construct objects. + + Parameters + ---------- + rank : int + Rank of process local to the machine + machine_id : int + Machine ID on which the process in running + mapping : decentralizepy.mappings + The object containing the mapping rank <--> uid + graph : decentralizepy.graphs + The object containing the global graph + config : dict + A dictionary of configurations. + log_dir : str + Logging directory + log_level : logging.Level + One of DEBUG, INFO, WARNING, ERROR, CRITICAL + args : optional + Other arguments + + """ + logging.info("Started process.") + + self.cache_fields( + rank, machine_id, mapping, graph, iterations, log_dir, test_after + ) + self.init_log(log_dir, rank, log_level) + self.init_dataset_model(config["DATASET"]) + self.init_optimizer(config["OPTIMIZER_PARAMS"]) + self.init_trainer(config["TRAIN_PARAMS"]) + self.init_comm(config["COMMUNICATION"]) + self.init_sharing(config["SHARING"]) def run(self): """ diff --git a/src/decentralizepy/sharing/GrowingAlpha.py b/src/decentralizepy/sharing/GrowingAlpha.py index 4979828..7fe7bf5 100644 --- a/src/decentralizepy/sharing/GrowingAlpha.py +++ b/src/decentralizepy/sharing/GrowingAlpha.py @@ -8,6 +8,7 @@ class GrowingAlpha(PartialModel): This class implements the basic growing partial model sharing using a linear function. """ + def __init__( self, rank, diff --git a/src/decentralizepy/sharing/PartialModel.py b/src/decentralizepy/sharing/PartialModel.py index 7addc0e..7eef218 100644 --- a/src/decentralizepy/sharing/PartialModel.py +++ b/src/decentralizepy/sharing/PartialModel.py @@ -13,6 +13,7 @@ class PartialModel(Sharing): This class implements the vanilla version of partial model sharing. """ + def __init__( self, rank, diff --git a/src/decentralizepy/training/GradientAccumulator.py b/src/decentralizepy/training/GradientAccumulator.py index e4feff2..718e793 100644 --- a/src/decentralizepy/training/GradientAccumulator.py +++ b/src/decentralizepy/training/GradientAccumulator.py @@ -8,6 +8,7 @@ class GradientAccumulator(Training): This class implements the training module which also accumulates gradients of steps in a list. """ + def __init__( self, model, diff --git a/src/decentralizepy/training/Training.py b/src/decentralizepy/training/Training.py index 52a8e9d..bd80773 100644 --- a/src/decentralizepy/training/Training.py +++ b/src/decentralizepy/training/Training.py @@ -109,7 +109,7 @@ class Training: self.optimizer.step() return loss_val.item() - def train_full(self, trainset): + def train_full(self, dataset): """ One training iteration, goes through the entire dataset @@ -120,9 +120,12 @@ class Training: """ for epoch in range(self.rounds): + trainset = dataset.get_trainset(self.batch_size, self.shuffle) epoch_loss = 0.0 count = 0 for data, target in trainset: + logging.info("Starting minibatch {} with num_samples: {}".format(count, len(data))) + logging.info("Classes: {}".format(target)) epoch_loss += self.trainstep(data, target) count += 1 logging.info("Epoch: {} loss: {}".format(epoch, epoch_loss / count)) @@ -137,13 +140,13 @@ class Training: The training dataset. Should implement get_trainset(batch_size, shuffle) """ - trainset = dataset.get_trainset(self.batch_size, self.shuffle) if self.full_epochs: - self.train_full(trainset) + self.train_full(dataset) else: iter_loss = 0.0 count = 0 + trainset = dataset.get_trainset(self.batch_size, self.shuffle) while count < self.rounds: for data, target in trainset: iter_loss += self.trainstep(data, target) diff --git a/src/decentralizepy/utils.py b/src/decentralizepy/utils.py index eac1e17..c6bf149 100644 --- a/src/decentralizepy/utils.py +++ b/src/decentralizepy/utils.py @@ -16,7 +16,7 @@ def conditional_value(var, nul, default): The null value. Assigns default if var == nul default : any The default value - + Returns ------- type(var) -- GitLab