From 4627acd6f2e0c7ce3bd637e114587bd9065c68d1 Mon Sep 17 00:00:00 2001
From: Rishi Sharma <rishi.sharma@epfl.ch>
Date: Tue, 8 Mar 2022 21:24:52 +0100
Subject: [PATCH] Models start at the same point; optional optimizer reset

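Seed torch's RNG from a new "random_seed" key in the [DATASET]
configuration section, so that model initialization is deterministic
and every node starts from the same point. The key is stripped from
the dataset constructor arguments along with the other bookkeeping
keys.

Also make the per-round optimizer reset optional: the new
reset_optimizer argument (CLI flag -ro / --reset_optimizer, default 1)
controls whether the optimizer is re-created after every sharing step.
With 0, optimizer state (e.g. momentum buffers) is kept across
communication rounds. eval/testing.py forwards the flag, and
write_args records it in args.json.

A minimal usage sketch (the section and key names match the code; the
placeholder values and any omitted command-line arguments are
illustrative only):

    [DATASET]
    dataset_package = <your dataset package>
    dataset_class = <your dataset class>
    random_seed = 90

    python eval/testing.py --reset_optimizer 0 ...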
---
 eval/testing.py                 |  1 +
 src/decentralizepy/node/Node.py | 43 +++++++++++++++++++++++++++------
 src/decentralizepy/utils.py     |  2 ++
 3 files changed, 39 insertions(+), 7 deletions(-)

diff --git a/eval/testing.py b/eval/testing.py
index f24e65b..f60e291 100644
--- a/eval/testing.py
+++ b/eval/testing.py
@@ -62,5 +62,6 @@ if __name__ == "__main__":
             args.log_dir,
             log_level[args.log_level],
             args.test_after,
+            args.reset_optimizer,
         ],
     )
diff --git a/src/decentralizepy/node/Node.py b/src/decentralizepy/node/Node.py
index 9f1b7da..48bb1ac 100644
--- a/src/decentralizepy/node/Node.py
+++ b/src/decentralizepy/node/Node.py
@@ -69,7 +69,15 @@ class Node:
         )
 
     def cache_fields(
-        self, rank, machine_id, mapping, graph, iterations, log_dir, test_after
+        self,
+        rank,
+        machine_id,
+        mapping,
+        graph,
+        iterations,
+        log_dir,
+        test_after,
+        reset_optimizer,
     ):
         """
         Instantiate object field with arguments.
@@ -86,6 +94,8 @@ class Node:
             The object containing the global graph
         log_dir : str
             Logging directory
+        reset_optimizer : int
+            1 if the optimizer should be reset every communication round, else 0
 
         """
         self.rank = rank
@@ -96,6 +106,7 @@ class Node:
         self.log_dir = log_dir
         self.iterations = iterations
         self.test_after = test_after
+        self.reset_optimizer = reset_optimizer
 
         logging.debug("Rank: %d", self.rank)
         logging.debug("type(graph): %s", str(type(self.rank)))
@@ -113,8 +124,10 @@ class Node:
         """
         dataset_module = importlib.import_module(dataset_configs["dataset_package"])
         self.dataset_class = getattr(dataset_module, dataset_configs["dataset_class"])
+        torch.manual_seed(dataset_configs["random_seed"])  # shared seed from config => identical model initialization on all nodes
         self.dataset_params = utils.remove_keys(
-            dataset_configs, ["dataset_package", "dataset_class", "model_class"]
+            dataset_configs,
+            ["dataset_package", "dataset_class", "model_class", "random_seed"],
         )
         self.dataset = self.dataset_class(
             self.rank, self.machine_id, self.mapping, **self.dataset_params
@@ -244,6 +257,7 @@ class Node:
         log_dir=".",
         log_level=logging.INFO,
         test_after=5,
+        reset_optimizer=1,
         *args
     ):
         """
@@ -265,6 +279,8 @@ class Node:
             Logging directory
         log_level : logging.Level
             One of DEBUG, INFO, WARNING, ERROR, CRITICAL
+        reset_optimizer : int
+            1 if the optimizer should be reset every communication round, else 0
         args : optional
             Other arguments
 
@@ -272,7 +288,14 @@ class Node:
         logging.info("Started process.")
 
         self.cache_fields(
-            rank, machine_id, mapping, graph, iterations, log_dir, test_after
+            rank,
+            machine_id,
+            mapping,
+            graph,
+            iterations,
+            log_dir,
+            test_after,
+            reset_optimizer,
         )
         self.init_log(log_dir, rank, log_level)
         self.init_dataset_model(config["DATASET"])
@@ -295,10 +318,12 @@ class Node:
             self.trainer.train(self.dataset)
 
             self.sharing.step()
-            self.optimizer = self.optimizer_class(
-                self.model.parameters(), **self.optimizer_params
-            )  # Reset optimizer state
-            self.trainer.reset_optimizer(self.optimizer)
+
+            if self.reset_optimizer:
+                self.optimizer = self.optimizer_class(
+                    self.model.parameters(), **self.optimizer_params
+                )  # Reset optimizer state
+                self.trainer.reset_optimizer(self.optimizer)
 
             loss_after_sharing = self.trainer.eval_loss(self.dataset)
 
@@ -385,6 +410,7 @@ class Node:
         log_dir=".",
         log_level=logging.INFO,
         test_after=5,
+        reset_optimizer=1,
         *args
     ):
         """
@@ -418,6 +444,8 @@ class Node:
             Logging directory
         log_level : logging.Level
             One of DEBUG, INFO, WARNING, ERROR, CRITICAL
+        reset_optimizer : int
+            1 if the optimizer should be reset every communication round, else 0
         args : optional
             Other arguments
 
@@ -436,6 +464,7 @@ class Node:
             log_dir,
             log_level,
             test_after,
+            reset_optimizer,
             *args
         )
         logging.info(
diff --git a/src/decentralizepy/utils.py b/src/decentralizepy/utils.py
index c6bf149..996e4bc 100644
--- a/src/decentralizepy/utils.py
+++ b/src/decentralizepy/utils.py
@@ -75,6 +75,7 @@ def get_args():
     parser.add_argument("-gf", "--graph_file", type=str, default="36_nodes.edges")
     parser.add_argument("-gt", "--graph_type", type=str, default="edges")
     parser.add_argument("-ta", "--test_after", type=int, default=5)
+    parser.add_argument("-ro", "--reset_optimizer", type=int, default=1)
 
     args = parser.parse_args()
     return args
@@ -103,6 +104,7 @@ def write_args(args, path):
         "graph_file": args.graph_file,
         "graph_type": args.graph_type,
         "test_after": args.test_after,
+        "reset_optimizer": args.reset_optimizer,
     }
     with open(os.path.join(path, "args.json"), "w") as of:
         json.dump(data, of)
-- 
GitLab