From b6096c2e808ab732b5824af510f5133d7923d921 Mon Sep 17 00:00:00 2001
From: Jeffrey Wigger <jeffrey.wigger@epfl.ch>
Date: Sat, 19 Mar 2022 19:39:39 +0100
Subject: [PATCH] option for when to evaluate on the trainset

---
 eval/run_all.sh                 |  3 ++-
 eval/testing.py                 |  1 +
 src/decentralizepy/node/Node.py | 44 +++++++++++++++++++++++++--------
 src/decentralizepy/utils.py     |  1 +
 4 files changed, 38 insertions(+), 11 deletions(-)

diff --git a/eval/run_all.sh b/eval/run_all.sh
index c9e5714..d5d0c04 100755
--- a/eval/run_all.sh
+++ b/eval/run_all.sh
@@ -10,6 +10,7 @@ config_file=~/tmp/config.ini
 procs_per_machine=16
 machines=6
 iterations=5
+train_evaluate_after=5
 test_after=21 # we do not test
 eval_file=testing.py
 log_level=INFO
@@ -32,7 +33,7 @@ do
   mkdir -p $log_dir
   cp $i $config_file
   $python_bin/crudini --set $config_file COMMUNICATION addresses_filepath $ip_machines
-  $env_python $eval_file -ro 0 -ld $log_dir -mid $m -ps $procs_per_machine -ms $machines -is $iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level
+  $env_python $eval_file -ro 0 -tea $train_evaluate_after -ld $log_dir -mid $m -ps $procs_per_machine -ms $machines -is $iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level
   echo $i is done
   sleep 3
   echo end of sleep
diff --git a/eval/testing.py b/eval/testing.py
index abd6333..efb80df 100644
--- a/eval/testing.py
+++ b/eval/testing.py
@@ -62,6 +62,7 @@ if __name__ == "__main__":
             args.log_dir,
             log_level[args.log_level],
             args.test_after,
+            args.train_evaluate_after,
             args.reset_optimizer,
         ],
     )
diff --git a/src/decentralizepy/node/Node.py b/src/decentralizepy/node/Node.py
index 7854c38..a543261 100644
--- a/src/decentralizepy/node/Node.py
+++ b/src/decentralizepy/node/Node.py
@@ -77,6 +77,7 @@ class Node:
         iterations,
         log_dir,
         test_after,
+        train_evaluate_after,
         reset_optimizer,
     ):
         """
@@ -96,6 +97,10 @@ class Node:
             Number of iterations (communication steps) for which the model should be trained
         log_dir : str
             Logging directory
+        test_after : int
+            Number of iterations after which the test loss and accuracy are calculated
+        train_evaluate_after : int
+            Number of iterations after which the train loss is calculated
         reset_optimizer : int
             1 if optimizer should be reset every communication round, else 0
 
@@ -108,6 +113,7 @@ class Node:
         self.log_dir = log_dir
         self.iterations = iterations
         self.test_after = test_after
+        self.train_evaluate_after = train_evaluate_after
         self.reset_optimizer = reset_optimizer
 
         logging.debug("Rank: %d", self.rank)
@@ -262,6 +268,7 @@ class Node:
         log_dir=".",
         log_level=logging.INFO,
         test_after=5,
+        train_evaluate_after=1,
         reset_optimizer=1,
         *args
     ):
@@ -286,6 +293,10 @@ class Node:
             Logging directory
         log_level : logging.Level
             One of DEBUG, INFO, WARNING, ERROR, CRITICAL
+        test_after : int
+            Number of iterations after which the test loss and accuracy are calculated
+        train_evaluate_after : int
+            Number of iterations after which the train loss is calculated
         reset_optimizer : int
             1 if optimizer should be reset every communication round, else 0
         args : optional
@@ -302,6 +313,7 @@ class Node:
             iterations,
             log_dir,
             test_after,
+            train_evaluate_after,
             reset_optimizer,
         )
         self.init_log(log_dir, rank, log_level)
@@ -319,6 +331,7 @@ class Node:
         self.testset = self.dataset.get_testset()
         self.communication.connect_neighbors(self.graph.neighbors(self.uid))
         rounds_to_test = self.test_after
+        rounds_to_train_evaluate = self.train_evaluate_after
 
         for iteration in range(self.iterations):
             logging.info("Starting training iteration: %d", iteration)
@@ -332,7 +345,6 @@ class Node:
                 )  # Reset optimizer state
                 self.trainer.reset_optimizer(self.optimizer)
 
-            loss_after_sharing = self.trainer.eval_loss(self.dataset)
 
             if iteration:
                 with open(
@@ -352,7 +364,6 @@ class Node:
                     "grad_std": {},
                 }
 
-            results_dict["train_loss"][iteration + 1] = loss_after_sharing
             results_dict["total_bytes"][iteration + 1] = self.communication.total_bytes
 
             if hasattr(self.sharing, "total_meta"):
@@ -365,14 +376,21 @@ class Node:
                 results_dict["grad_mean"][iteration + 1] = self.sharing.mean
             if hasattr(self.sharing, "std"):
                 results_dict["grad_std"][iteration + 1] = self.sharing.std
-
-            self.save_plot(
-                results_dict["train_loss"],
-                "train_loss",
-                "Training Loss",
-                "Communication Rounds",
-                os.path.join(self.log_dir, "{}_train_loss.png".format(self.rank)),
-            )
+            
+            rounds_to_train_evaluate -= 1
+
+            if rounds_to_train_evaluate == 0:
+                logging.info("Evaluating on train set.")
+                rounds_to_train_evaluate = self.train_evaluate_after
+                loss_after_sharing = self.trainer.eval_loss(self.dataset)
+                results_dict["train_loss"][iteration + 1] = loss_after_sharing
+                self.save_plot(
+                    results_dict["train_loss"],
+                    "train_loss",
+                    "Training Loss",
+                    "Communication Rounds",
+                    os.path.join(self.log_dir, "{}_train_loss.png".format(self.rank)),
+                )
 
             rounds_to_test -= 1
 
@@ -417,6 +435,7 @@ class Node:
         log_dir=".",
         log_level=logging.INFO,
         test_after=5,
+        train_evaluate_after=1,
         reset_optimizer=1,
         *args
     ):
@@ -453,6 +472,10 @@ class Node:
             Logging directory
         log_level : logging.Level
             One of DEBUG, INFO, WARNING, ERROR, CRITICAL
+        test_after : int
+            Number of iterations after which the test loss and accuracy are calculated
+        train_evaluate_after : int
+            Number of iterations after which the train loss is calculated
         reset_optimizer : int
             1 if optimizer should be reset every communication round, else 0
         args : optional
@@ -473,6 +496,7 @@ class Node:
             log_dir,
             log_level,
             test_after,
+            train_evaluate_after,
             reset_optimizer,
             *args
         )
diff --git a/src/decentralizepy/utils.py b/src/decentralizepy/utils.py
index 82f2068..3ca85f5 100644
--- a/src/decentralizepy/utils.py
+++ b/src/decentralizepy/utils.py
@@ -75,6 +75,7 @@ def get_args():
     parser.add_argument("-gf", "--graph_file", type=str, default="36_nodes.edges")
     parser.add_argument("-gt", "--graph_type", type=str, default="edges")
     parser.add_argument("-ta", "--test_after", type=int, default=5)
+    parser.add_argument("-tea", "--train_evaluate_after", type=int, default=1)
     parser.add_argument("-ro", "--reset_optimizer", type=int, default=1)
 
     args = parser.parse_args()
-- 
GitLab