Skip to content
Snippets Groups Projects
Commit b6096c2e authored by Jeffrey Wigger's avatar Jeffrey Wigger
Browse files

option for when to evaluate on the trainset

parent fc6ee11c
No related branches found
No related tags found
1 merge request!3FFT Wavelets and more
......@@ -10,6 +10,7 @@ config_file=~/tmp/config.ini
procs_per_machine=16
machines=6
iterations=5
train_evaluate_after=5
test_after=21 # we do not test
eval_file=testing.py
log_level=INFO
......@@ -32,7 +33,7 @@ do
mkdir -p $log_dir
cp $i $config_file
$python_bin/crudini --set $config_file COMMUNICATION addresses_filepath $ip_machines
$env_python $eval_file -ro 0 -ld $log_dir -mid $m -ps $procs_per_machine -ms $machines -is $iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level
$env_python $eval_file -ro 0 -tea $train_evaluate_after -ld $log_dir -mid $m -ps $procs_per_machine -ms $machines -is $iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level
echo $i is done
sleep 3
echo end of sleep
......
......@@ -62,6 +62,7 @@ if __name__ == "__main__":
args.log_dir,
log_level[args.log_level],
args.test_after,
args.train_evaluate_after,
args.reset_optimizer,
],
)
......@@ -77,6 +77,7 @@ class Node:
iterations,
log_dir,
test_after,
train_evaluate_after,
reset_optimizer,
):
"""
......@@ -96,6 +97,10 @@ class Node:
Number of iterations (communication steps) for which the model should be trained
log_dir : str
Logging directory
test_after : int
Number of iterations after which the test loss and accuracy are calculated
train_evaluate_after : int
Number of iterations after which the train loss is calculated
reset_optimizer : int
1 if optimizer should be reset every communication round, else 0
......@@ -108,6 +113,7 @@ class Node:
self.log_dir = log_dir
self.iterations = iterations
self.test_after = test_after
self.train_evaluate_after = train_evaluate_after
self.reset_optimizer = reset_optimizer
logging.debug("Rank: %d", self.rank)
......@@ -262,6 +268,7 @@ class Node:
log_dir=".",
log_level=logging.INFO,
test_after=5,
train_evaluate_after = 1,
reset_optimizer=1,
*args
):
......@@ -286,6 +293,10 @@ class Node:
Logging directory
log_level : logging.Level
One of DEBUG, INFO, WARNING, ERROR, CRITICAL
test_after : int
Number of iterations after which the test loss and accuracy are calculated
train_evaluate_after : int
Number of iterations after which the train loss is calculated
reset_optimizer : int
1 if optimizer should be reset every communication round, else 0
args : optional
......@@ -302,6 +313,7 @@ class Node:
iterations,
log_dir,
test_after,
train_evaluate_after,
reset_optimizer,
)
self.init_log(log_dir, rank, log_level)
......@@ -319,6 +331,7 @@ class Node:
self.testset = self.dataset.get_testset()
self.communication.connect_neighbors(self.graph.neighbors(self.uid))
rounds_to_test = self.test_after
rounds_to_train_evaluate = self.train_evaluate_after
for iteration in range(self.iterations):
logging.info("Starting training iteration: %d", iteration)
......@@ -332,7 +345,6 @@ class Node:
) # Reset optimizer state
self.trainer.reset_optimizer(self.optimizer)
loss_after_sharing = self.trainer.eval_loss(self.dataset)
if iteration:
with open(
......@@ -352,7 +364,6 @@ class Node:
"grad_std": {},
}
results_dict["train_loss"][iteration + 1] = loss_after_sharing
results_dict["total_bytes"][iteration + 1] = self.communication.total_bytes
if hasattr(self.sharing, "total_meta"):
......@@ -365,14 +376,21 @@ class Node:
results_dict["grad_mean"][iteration + 1] = self.sharing.mean
if hasattr(self.sharing, "std"):
results_dict["grad_std"][iteration + 1] = self.sharing.std
self.save_plot(
results_dict["train_loss"],
"train_loss",
"Training Loss",
"Communication Rounds",
os.path.join(self.log_dir, "{}_train_loss.png".format(self.rank)),
)
rounds_to_train_evaluate -= 1
if rounds_to_test == 0:
logging.info("Evaluating on train set.")
rounds_to_train_evaluate = self.train_evaluate_after
loss_after_sharing = self.trainer.eval_loss(self.dataset)
results_dict["train_loss"][iteration + 1] = loss_after_sharing
self.save_plot(
results_dict["train_loss"],
"train_loss",
"Training Loss",
"Communication Rounds",
os.path.join(self.log_dir, "{}_train_loss.png".format(self.rank)),
)
rounds_to_test -= 1
......@@ -417,6 +435,7 @@ class Node:
log_dir=".",
log_level=logging.INFO,
test_after=5,
train_evaluate_after=1,
reset_optimizer=1,
*args
):
......@@ -453,6 +472,10 @@ class Node:
Logging directory
log_level : logging.Level
One of DEBUG, INFO, WARNING, ERROR, CRITICAL
test_after : int
Number of iterations after which the test loss and accuracy are calculated
train_evaluate_after : int
Number of iterations after which the train loss is calculated
reset_optimizer : int
1 if optimizer should be reset every communication round, else 0
args : optional
......@@ -473,6 +496,7 @@ class Node:
log_dir,
log_level,
test_after,
train_evaluate_after,
reset_optimizer,
*args
)
......
......@@ -75,6 +75,7 @@ def get_args():
parser.add_argument("-gf", "--graph_file", type=str, default="36_nodes.edges")
parser.add_argument("-gt", "--graph_type", type=str, default="edges")
parser.add_argument("-ta", "--test_after", type=int, default=5)
parser.add_argument("-tea", "--train_evaluate_after", type=int, default=1)
parser.add_argument("-ro", "--reset_optimizer", type=int, default=1)
args = parser.parse_args()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment