Skip to content
Snippets Groups Projects
Commit 2367e854 authored by Jeffrey Wigger
Browse files

cte only starts when the flag is set

parent bf5f03d1
No related branches found
No related tags found
1 merge request: !13 "Only start star topology when needed"
......@@ -243,6 +243,17 @@ class Node:
comm_class = getattr(comm_module, comm_configs["comm_class"])
comm_params = utils.remove_keys(comm_configs, ["comm_package", "comm_class"])
self.addresses_filepath = comm_params.get("addresses_filepath", None)
if self.centralized_test_eval:
self.testing_comm = TCP(
self.rank,
self.machine_id,
self.mapping,
self.star.n_procs,
self.addresses_filepath,
offset=self.star.n_procs,
)
self.testing_comm.connect_neighbors(self.star.neighbors(self.uid))
self.communication = comm_class(
self.rank, self.machine_id, self.mapping, self.graph.n_procs, **comm_params
)
......@@ -360,16 +371,6 @@ class Node:
self.testset = self.dataset.get_testset()
self.communication.connect_neighbors(self.graph.neighbors(self.uid))
rounds_to_test = self.test_after
testing_comm = TCP(
self.rank,
self.machine_id,
self.mapping,
self.star.n_procs,
self.addresses_filepath,
offset=self.star.n_procs,
)
testing_comm.connect_neighbors(self.star.neighbors(self.uid))
rounds_to_train_evaluate = self.train_evaluate_after
global_epoch = 1
change = 1
......@@ -387,19 +388,20 @@ class Node:
**dataset_params_copy
)
dataset = self.whole_dataset
tthelper = TrainTestHelper(
dataset, # self.whole_dataset,
# self.model_test, # todo: this only works if eval_train is set to false
self.model,
self.loss,
self.weights_store_dir,
self.mapping.get_n_procs(),
self.trainer,
testing_comm,
self.star,
self.threads_per_proc,
eval_train=self.centralized_train_eval,
)
if self.centralized_test_eval:
tthelper = TrainTestHelper(
dataset, # self.whole_dataset,
# self.model_test, # todo: this only works if eval_train is set to false
self.model,
self.loss,
self.weights_store_dir,
self.mapping.get_n_procs(),
self.trainer,
self.testing_comm,
self.star,
self.threads_per_proc,
eval_train=self.centralized_train_eval,
)
for iteration in range(self.iterations):
logging.info("Starting training iteration: %d", iteration)
......@@ -475,8 +477,8 @@ class Node:
if trl is not None:
results_dict["train_loss"][iteration + 1] = trl
else:
testing_comm.send(0, self.model.get_weights())
sender, data = testing_comm.receive()
self.testing_comm.send(0, self.model.get_weights())
sender, data = self.testing_comm.receive()
assert sender == 0 and data == "finished"
else:
logging.info("Evaluating on test set.")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment