Skip to content
Snippets Groups Projects
Commit 4627acd6 authored by Rishi Sharma
Browse files

Models start at the same point; optional optimizer reset

parent 98b1561d
No related branches found
No related tags found
No related merge requests found
......@@ -62,5 +62,6 @@ if __name__ == "__main__":
args.log_dir,
log_level[args.log_level],
args.test_after,
args.reset_optimzer,
],
)
......@@ -69,7 +69,15 @@ class Node:
)
def cache_fields(
self, rank, machine_id, mapping, graph, iterations, log_dir, test_after
self,
rank,
machine_id,
mapping,
graph,
iterations,
log_dir,
test_after,
reset_optimizer,
):
"""
Instantiate object field with arguments.
......@@ -86,6 +94,8 @@ class Node:
The object containing the global graph
log_dir : str
Logging directory
reset_optimizer : int
1 if optimizer should be reset every communication round, else 0
"""
self.rank = rank
......@@ -96,6 +106,7 @@ class Node:
self.log_dir = log_dir
self.iterations = iterations
self.test_after = test_after
self.reset_optimizer = reset_optimizer
logging.debug("Rank: %d", self.rank)
logging.debug("type(graph): %s", str(type(self.rank)))
......@@ -113,8 +124,10 @@ class Node:
"""
dataset_module = importlib.import_module(dataset_configs["dataset_package"])
self.dataset_class = getattr(dataset_module, dataset_configs["dataset_class"])
torch.manual_seed(dataset_configs["random_seed"])
self.dataset_params = utils.remove_keys(
dataset_configs, ["dataset_package", "dataset_class", "model_class"]
dataset_configs,
["dataset_package", "dataset_class", "model_class", "random_seed"],
)
self.dataset = self.dataset_class(
self.rank, self.machine_id, self.mapping, **self.dataset_params
......@@ -244,6 +257,7 @@ class Node:
log_dir=".",
log_level=logging.INFO,
test_after=5,
reset_optimizer=1,
*args
):
"""
......@@ -265,6 +279,8 @@ class Node:
Logging directory
log_level : logging.Level
One of DEBUG, INFO, WARNING, ERROR, CRITICAL
reset_optimizer : int
1 if optimizer should be reset every communication round, else 0
args : optional
Other arguments
......@@ -272,7 +288,14 @@ class Node:
logging.info("Started process.")
self.cache_fields(
rank, machine_id, mapping, graph, iterations, log_dir, test_after
rank,
machine_id,
mapping,
graph,
iterations,
log_dir,
test_after,
reset_optimizer,
)
self.init_log(log_dir, rank, log_level)
self.init_dataset_model(config["DATASET"])
......@@ -295,10 +318,12 @@ class Node:
self.trainer.train(self.dataset)
self.sharing.step()
self.optimizer = self.optimizer_class(
self.model.parameters(), **self.optimizer_params
) # Reset optimizer state
self.trainer.reset_optimizer(self.optimizer)
if self.reset_optimizer:
self.optimizer = self.optimizer_class(
self.model.parameters(), **self.optimizer_params
) # Reset optimizer state
self.trainer.reset_optimizer(self.optimizer)
loss_after_sharing = self.trainer.eval_loss(self.dataset)
......@@ -385,6 +410,7 @@ class Node:
log_dir=".",
log_level=logging.INFO,
test_after=5,
reset_optimizer=1,
*args
):
"""
......@@ -418,6 +444,8 @@ class Node:
Logging directory
log_level : logging.Level
One of DEBUG, INFO, WARNING, ERROR, CRITICAL
reset_optimizer : int
1 if optimizer should be reset every communication round, else 0
args : optional
Other arguments
......@@ -436,6 +464,7 @@ class Node:
log_dir,
log_level,
test_after,
reset_optimizer,
*args
)
logging.info(
......
......@@ -75,6 +75,7 @@ def get_args():
parser.add_argument("-gf", "--graph_file", type=str, default="36_nodes.edges")
parser.add_argument("-gt", "--graph_type", type=str, default="edges")
parser.add_argument("-ta", "--test_after", type=int, default=5)
parser.add_argument("-ro", "--reset_optimizer", type=int, default=1)
args = parser.parse_args()
return args
......@@ -103,6 +104,7 @@ def write_args(args, path):
"graph_file": args.graph_file,
"graph_type": args.graph_type,
"test_after": args.test_after,
"reset_optimizer": args.reset_optimizer,
}
with open(os.path.join(path, "args.json"), "w") as of:
json.dump(data, of)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment