Skip to content

Commit b0de996

Browse files
ydcjeff and vfdev-5 authored
ci: speed up test runs (#85)
* ci: speed up test runs * fix: rm max_epochs in evaluators, run log_metrics * fix: replace jinja value with 5 in single * Apply suggestions from code review [skip ci] Co-authored-by: vfdev <vfdev.5@gmail.com> * remove useless code, fix #87 * remove useless code, fix #87 * readd custom events in single * This commit add the somehow missed evaluation * fix: reuse eval_epoch_length Co-authored-by: vfdev <vfdev.5@gmail.com>
1 parent 10ef370 commit b0de996

File tree

6 files changed

+42
-42
lines changed

6 files changed

+42
-42
lines changed

.github/run_test.sh

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,12 @@ elif [ $1 == "unittest" ]; then
2020
elif [ $1 == "default" ]; then
2121
for file in $(find ./tests/dist -iname "main.py" -not -path "./tests/dist/launch/*" -not -path "./tests/dist/spawn/*" -not -path "./tests/dist/single/*")
2222
do
23-
python $file --verbose --log_every_iters 2 --num_workers 1 --epoch_length 10
23+
python $file \
24+
--verbose \
25+
--log_every_iters 2 \
26+
--num_workers 1 \
27+
--train_epoch_length 10 \
28+
--eval_epoch_length 10
2429
done
2530
elif [ $1 == "launch" ]; then
2631
for file in $(find ./tests/dist/launch -iname "main.py" -not -path "./tests/dist/launch/single/*")
@@ -31,7 +36,8 @@ elif [ $1 == "launch" ]; then
3136
--verbose \
3237
--backend gloo \
3338
--num_workers 1 \
34-
--epoch_length 10 \
39+
--eval_epoch_length 10 \
40+
--train_epoch_length 10 \
3541
--log_every_iters 2
3642
done
3743
elif [ $1 == "spawn" ]; then
@@ -41,7 +47,8 @@ elif [ $1 == "spawn" ]; then
4147
--verbose \
4248
--backend gloo \
4349
--num_workers 1 \
44-
--epoch_length 10 \
50+
--eval_epoch_length 10 \
51+
--train_epoch_length 10 \
4552
--nproc_per_node 2 \
4653
--log_every_iters 2
4754
done

templates/_base/_argparse.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,15 @@
5656
"type": int,
5757
"help": "master node port for torch native backends. Default: %(default)s",
5858
},
59-
"epoch_length": {
59+
"train_epoch_length": {
6060
"default": None,
6161
"type": int,
62-
"help": "epoch_length of Engine.run(). Default: %(default)s"
62+
"help": "epoch_length of Engine.run() for training. Default: %(default)s"
63+
},
64+
"eval_epoch_length": {
65+
"default": None,
66+
"type": int,
67+
"help": "epoch_length of Engine.run() for evaluation. Default: %(default)s"
6368
},
6469
# ignite handlers options
6570
"save_every_iters": {

templates/gan/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def create_plots(engine):
207207
# setup if done. let's run the training
208208
# ------------------------------------------
209209

210-
trainer.run(train_dataloader, max_epochs=config.max_epochs, epoch_length=config.epoch_length)
210+
trainer.run(train_dataloader, max_epochs=config.max_epochs, epoch_length=config.train_epoch_length)
211211

212212
# ------------------------------------------------------------
213213
# close the logger after the training completed / terminated

templates/image_classification/main.py

Lines changed: 5 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from ignite.metrics import Accuracy, Loss
1313

1414
from datasets import get_datasets
15-
from trainers import create_trainers, TrainEvents
15+
from trainers import create_trainers
1616
from utils import setup_logging, log_metrics, log_basic_info, initialize, resume_from, get_handlers, get_logger
1717
from config import get_default_parser
1818

@@ -140,30 +140,6 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
140140
if config.resume_from:
141141
resume_from(to_load=to_save, checkpoint_fp=config.resume_from)
142142

143-
# --------------------------------------------
144-
# let's trigger custom events we registered
145-
# we will use a `event_filter` to trigger that
146-
# `event_filter` has to return boolean
147-
# whether this event should be executed
148-
# here will log the gradients on the 1st iteration
149-
# and every 100 iterations
150-
# --------------------------------------------
151-
152-
@trainer.on(TrainEvents.BACKWARD_COMPLETED(lambda _, ev: (ev % 100 == 0) or (ev == 1)))
153-
def _():
154-
# do something interesting
155-
pass
156-
157-
# ----------------------------------------
158-
# here we will use `every` to trigger
159-
# every 100 iterations
160-
# ----------------------------------------
161-
162-
@trainer.on(TrainEvents.OPTIM_STEP_COMPLETED(every=100))
163-
def _():
164-
# do something interesting
165-
pass
166-
167143
# --------------------------------
168144
# print metrics to the stderr
169145
# with `add_event_handler` API
@@ -182,23 +158,22 @@ def _():
182158

183159
@trainer.on(Events.EPOCH_COMPLETED(every=1))
184160
def _():
185-
evaluator.run(eval_dataloader, max_epochs=1)
186-
evaluator.add_event_handler(Events.EPOCH_COMPLETED(every=1), log_metrics, tag="eval")
161+
evaluator.run(eval_dataloader, epoch_length=config.eval_epoch_length)
162+
log_metrics(evaluator, "eval")
187163

188164
# --------------------------------------------------
189165
# let's try run evaluation first as a sanity check
190166
# --------------------------------------------------
191167

192168
@trainer.on(Events.STARTED)
193169
def _():
194-
evaluator.run(eval_dataloader, max_epochs=1, epoch_length=2)
195-
evaluator.state.max_epochs = None
170+
evaluator.run(eval_dataloader, epoch_length=config.eval_epoch_length)
196171

197172
# ------------------------------------------
198173
# setup if done. let's run the training
199174
# ------------------------------------------
200175

201-
trainer.run(train_dataloader, max_epochs=config.max_epochs, epoch_length=config.epoch_length)
176+
trainer.run(train_dataloader, max_epochs=config.max_epochs, epoch_length=config.train_epoch_length)
202177

203178
# ------------------------------------------------------------
204179
# close the logger after the training completed / terminated

templates/single/config.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,15 @@
1-
{% include "_argparse.py" %}
1+
{% extends "_argparse.py" %}
2+
{% block get_default_parser %}
3+
UPDATES = {
4+
# training options
5+
"max_epochs": {
6+
"default": 5,
7+
"type": int,
8+
"help": "max_epochs of ignite.Engine.run() for training. Default: %(default)s",
9+
}
10+
}
11+
12+
DEFAULTS.update(UPDATES)
13+
14+
{{ super() }}
15+
{% endblock %}

templates/single/main.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -160,24 +160,23 @@ def _():
160160

161161
@trainer.on(Events.EPOCH_COMPLETED(every=1))
162162
def _():
163-
evaluator.run(eval_dataloader, max_epochs=1)
164-
evaluator.add_event_handler(Events.EPOCH_COMPLETED(every=1), log_metrics, tag="eval")
163+
evaluator.run(eval_dataloader, epoch_length=config.eval_epoch_length)
164+
log_metrics(evaluator, "eval")
165165

166166
# --------------------------------------------------
167167
# let's try run evaluation first as a sanity check
168168
# --------------------------------------------------
169169

170170
@trainer.on(Events.STARTED)
171171
def _():
172-
evaluator.run(eval_dataloader, max_epochs=1, epoch_length=2)
173-
evaluator.state.max_epochs = None
172+
evaluator.run(eval_dataloader, epoch_length=config.eval_epoch_length)
174173

175174
# ------------------------------------------
176175
# setup if done. let's run the training
177176
# ------------------------------------------
178177
# TODO : PLEASE provide `max_epochs` parameters
179178

180-
trainer.run(train_dataloader, epoch_length=config.epoch_length)
179+
trainer.run(train_dataloader, max_epochs=config.max_epochs, epoch_length=config.train_epoch_length)
181180

182181
# ------------------------------------------------------------
183182
# close the logger after the training completed / terminated

0 commit comments

Comments (0)