From 09a3a71a86d4a975517a65616cc93d0433f107a1 Mon Sep 17 00:00:00 2001
From: Rishi Sharma <rishi.sharma@epfl.ch>
Date: Sun, 26 Dec 2021 11:21:10 +0100
Subject: [PATCH] Add more config, dataset shapes logging

---
 eval/config_celeba_100.ini             | 33 ++++++++++++++++++++++++++
 src/decentralizepy/datasets/Celeba.py  |  8 +++----
 src/decentralizepy/datasets/Femnist.py |  8 +++----
 3 files changed, 41 insertions(+), 8 deletions(-)
 create mode 100644 eval/config_celeba_100.ini

diff --git a/eval/config_celeba_100.ini b/eval/config_celeba_100.ini
new file mode 100644
index 0000000..dcaff4f
--- /dev/null
+++ b/eval/config_celeba_100.ini
@@ -0,0 +1,33 @@
+[DATASET]
+dataset_package = decentralizepy.datasets.Celeba
+dataset_class = Celeba
+model_class = CNN
+n_procs = 96
+images_dir = /home/risharma/leaf/data/celeba/data/raw/img_align_celeba
+train_dir = /home/risharma/leaf/data/celeba/per_user_data/train
+test_dir = /home/risharma/leaf/data/celeba/data/test
+; python list of fractions below
+sizes = 
+
+[OPTIMIZER_PARAMS]
+optimizer_package = torch.optim
+optimizer_class = Adam
+lr = 0.001
+
+[TRAIN_PARAMS]
+training_package = decentralizepy.training.Training
+training_class = Training
+epochs_per_round = 5
+batch_size = 512
+shuffle = True
+loss_package = torch.nn
+loss_class = CrossEntropyLoss
+
+[COMMUNICATION]
+comm_package = decentralizepy.communication.TCP
+comm_class = TCP
+addresses_filepath = ip_addr_6Machines.json
+
+[SHARING]
+sharing_package = decentralizepy.sharing.Sharing
+sharing_class = Sharing
diff --git a/src/decentralizepy/datasets/Celeba.py b/src/decentralizepy/datasets/Celeba.py
index 687d58f..748f596 100644
--- a/src/decentralizepy/datasets/Celeba.py
+++ b/src/decentralizepy/datasets/Celeba.py
@@ -115,8 +115,8 @@ class Celeba(Dataset):
             .transpose(0, 3, 1, 2)  # Channel first: torch
         )
         self.train_y = np.array(my_train_data["y"], dtype=np.dtype("int64")).reshape(-1)
-        logging.debug("train_x.shape: %s", str(self.train_x.shape))
-        logging.debug("train_y.shape: %s", str(self.train_y.shape))
+        logging.info("train_x.shape: %s", str(self.train_x.shape))
+        logging.info("train_y.shape: %s", str(self.train_y.shape))
         assert self.train_x.shape[0] == self.train_y.shape[0]
         assert self.train_x.shape[0] > 0
 
@@ -134,8 +134,8 @@ class Celeba(Dataset):
             .transpose(0, 3, 1, 2)
         )
         self.test_y = np.array(test_y, dtype=np.dtype("int64")).reshape(-1)
-        logging.debug("test_x.shape: %s", str(self.test_x.shape))
-        logging.debug("test_y.shape: %s", str(self.test_y.shape))
+        logging.info("test_x.shape: %s", str(self.test_x.shape))
+        logging.info("test_y.shape: %s", str(self.test_y.shape))
         assert self.test_x.shape[0] == self.test_y.shape[0]
         assert self.test_x.shape[0] > 0
 
diff --git a/src/decentralizepy/datasets/Femnist.py b/src/decentralizepy/datasets/Femnist.py
index 3bdf6cd..7370a0d 100644
--- a/src/decentralizepy/datasets/Femnist.py
+++ b/src/decentralizepy/datasets/Femnist.py
@@ -113,8 +113,8 @@ class Femnist(Dataset):
             .transpose(0, 3, 1, 2)
         )
         self.train_y = np.array(my_train_data["y"], dtype=np.dtype("int64")).reshape(-1)
-        logging.debug("train_x.shape: %s", str(self.train_x.shape))
-        logging.debug("train_y.shape: %s", str(self.train_y.shape))
+        logging.info("train_x.shape: %s", str(self.train_x.shape))
+        logging.info("train_y.shape: %s", str(self.train_y.shape))
         assert self.train_x.shape[0] == self.train_y.shape[0]
         assert self.train_x.shape[0] > 0
 
@@ -134,8 +134,8 @@ class Femnist(Dataset):
             .transpose(0, 3, 1, 2)
         )
         self.test_y = np.array(test_y, dtype=np.dtype("int64")).reshape(-1)
-        logging.debug("test_x.shape: %s", str(self.test_x.shape))
-        logging.debug("test_y.shape: %s", str(self.test_y.shape))
+        logging.info("test_x.shape: %s", str(self.test_x.shape))
+        logging.info("test_y.shape: %s", str(self.test_y.shape))
         assert self.test_x.shape[0] == self.test_y.shape[0]
         assert self.test_x.shape[0] > 0
 
-- 
GitLab