diff --git a/eval/testing.py b/eval/testing.py
index f24e65b6b2b05ecca4e641b0509b68d8c9ce5976..f60e2919f997bac33464be9fd64b71ed30f1ab94 100644
--- a/eval/testing.py
+++ b/eval/testing.py
@@ -62,5 +62,6 @@ if __name__ == "__main__":
             args.log_dir,
             log_level[args.log_level],
             args.test_after,
+            args.reset_optimizer,
         ],
     )
diff --git a/src/decentralizepy/node/Node.py b/src/decentralizepy/node/Node.py
index 9f1b7dae9357a988edc1d0e526a461e7b3abcb42..48bb1acfce05a09c27a93cc8d17b638f3f329191 100644
--- a/src/decentralizepy/node/Node.py
+++ b/src/decentralizepy/node/Node.py
@@ -69,7 +69,15 @@ class Node:
         )

     def cache_fields(
-        self, rank, machine_id, mapping, graph, iterations, log_dir, test_after
+        self,
+        rank,
+        machine_id,
+        mapping,
+        graph,
+        iterations,
+        log_dir,
+        test_after,
+        reset_optimizer,
     ):
         """
         Instantiate object field with arguments.
@@ -86,6 +94,8 @@
             The object containing the global graph
         log_dir : str
             Logging directory
+        reset_optimizer : int
+            1 if optimizer should be reset every communication round, else 0

         """
         self.rank = rank
@@ -96,6 +106,7 @@
         self.log_dir = log_dir
         self.iterations = iterations
         self.test_after = test_after
+        self.reset_optimizer = reset_optimizer

         logging.debug("Rank: %d", self.rank)
         logging.debug("type(graph): %s", str(type(self.rank)))
@@ -113,8 +124,10 @@
         """
         dataset_module = importlib.import_module(dataset_configs["dataset_package"])
         self.dataset_class = getattr(dataset_module, dataset_configs["dataset_class"])
+        torch.manual_seed(dataset_configs["random_seed"])
         self.dataset_params = utils.remove_keys(
-            dataset_configs, ["dataset_package", "dataset_class", "model_class"]
+            dataset_configs,
+            ["dataset_package", "dataset_class", "model_class", "random_seed"],
         )
         self.dataset = self.dataset_class(
             self.rank, self.machine_id, self.mapping, **self.dataset_params
         )
@@ -244,6 +257,7 @@
         log_dir=".",
         log_level=logging.INFO,
         test_after=5,
+        reset_optimizer=1,
         *args
     ):
         """
@@ -265,6 +279,8 @@
             Logging directory
         log_level : logging.Level
             One of DEBUG, INFO, WARNING, ERROR, CRITICAL
+        reset_optimizer : int
+            1 if optimizer should be reset every communication round, else 0
         args : optional
             Other arguments

@@ -272,7 +288,14 @@
         """
         logging.info("Started process.")
         self.cache_fields(
-            rank, machine_id, mapping, graph, iterations, log_dir, test_after
+            rank,
+            machine_id,
+            mapping,
+            graph,
+            iterations,
+            log_dir,
+            test_after,
+            reset_optimizer,
         )
         self.init_log(log_dir, rank, log_level)
         self.init_dataset_model(config["DATASET"])
@@ -295,10 +318,12 @@
             self.trainer.train(self.dataset)

             self.sharing.step()
-            self.optimizer = self.optimizer_class(
-                self.model.parameters(), **self.optimizer_params
-            )  # Reset optimizer state
-            self.trainer.reset_optimizer(self.optimizer)
+
+            if self.reset_optimizer:
+                self.optimizer = self.optimizer_class(
+                    self.model.parameters(), **self.optimizer_params
+                )  # Reset optimizer state
+                self.trainer.reset_optimizer(self.optimizer)

             loss_after_sharing = self.trainer.eval_loss(self.dataset)

@@ -385,6 +410,7 @@
         log_dir=".",
         log_level=logging.INFO,
         test_after=5,
+        reset_optimizer=1,
         *args
     ):
         """
@@ -418,6 +444,8 @@
             Logging directory
         log_level : logging.Level
             One of DEBUG, INFO, WARNING, ERROR, CRITICAL
+        reset_optimizer : int
+            1 if optimizer should be reset every communication round, else 0
         args : optional
             Other arguments

@@ -436,6 +464,7 @@
             log_dir,
             log_level,
             test_after,
+            reset_optimizer,
             *args
         )
         logging.info(
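The run-loop hunk above makes the per-round optimizer reset conditional on the new reset_optimizer flag. As a rough sketch of what that reset does (a standalone illustration, not code from this patch; the toy model and the SGD-with-momentum settings are assumptions), re-instantiating the optimizer from optimizer_class and optimizer_params discards internal optimizer state such as momentum buffers, while the freshly shared model weights stay untouched:

    import torch

    # Illustrative stand-ins for self.model, self.optimizer_class and
    # self.optimizer_params in Node; the concrete values are assumptions.
    model = torch.nn.Linear(4, 2)
    optimizer_class = torch.optim.SGD
    optimizer_params = {"lr": 0.1, "momentum": 0.9}

    optimizer = optimizer_class(model.parameters(), **optimizer_params)
    model(torch.randn(8, 4)).sum().backward()
    optimizer.step()  # populates the momentum buffers

    reset_optimizer = 1  # as cached by cache_fields
    if reset_optimizer:
        # Same pattern as in Node.run: a fresh optimizer with empty state,
        # bound to the same (post-sharing) model parameters.
        optimizer = optimizer_class(model.parameters(), **optimizer_params)
        assert optimizer.state_dict()["state"] == {}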
diff --git a/src/decentralizepy/utils.py b/src/decentralizepy/utils.py
index c6bf14904561a428bc24a9b252f522313fd327b0..996e4bc4316f3590da1dcfcbc287df81d52fa9b7 100644
--- a/src/decentralizepy/utils.py
+++ b/src/decentralizepy/utils.py
@@ -75,6 +75,7 @@
     parser.add_argument("-gf", "--graph_file", type=str, default="36_nodes.edges")
     parser.add_argument("-gt", "--graph_type", type=str, default="edges")
    parser.add_argument("-ta", "--test_after", type=int, default=5)
+    parser.add_argument("-ro", "--reset_optimizer", type=int, default=1)
     args = parser.parse_args()
     return args

@@ -103,6 +104,7 @@
         "graph_file": args.graph_file,
         "graph_type": args.graph_type,
         "test_after": args.test_after,
+        "reset_optimizer": args.reset_optimizer,
     }
     with open(os.path.join(path, "args.json"), "w") as of:
         json.dump(data, of)
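With the flag wired through get_args and write_args, the reset behaviour can be toggled per experiment from the command line. A hypothetical invocation (only flags visible in this patch are shown; any other required arguments are omitted) and the resulting args.json fragment:

    python eval/testing.py -gf 36_nodes.edges -gt edges -ta 5 -ro 0

    # args.json then records, among the other settings:
    #     "test_after": 5,
    #     "reset_optimizer": 0

The default of 1 preserves the previous behaviour, in which the optimizer state was discarded unconditionally after every sharing step.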