
Commit 462d6d8

refactor: train_engine to trainer, eval_engine to evaluator (#76)
1 parent 979cc78 commit 462d6d8

File tree

10 files changed: +130 -130 lines changed

templates/_base/_handlers.py

Lines changed: 30 additions & 30 deletions
@@ -16,8 +16,8 @@
 def get_handlers(
     config: Any,
     model: Module,
-    train_engine: Engine,
-    eval_engine: Engine,
+    trainer: Engine,
+    evaluator: Engine,
     metric_name: str,
     es_metric_name: str,
     train_sampler: Optional[DistributedSampler] = None,
@@ -48,9 +48,9 @@ def get_handlers(

     model
         best model to save
-    train_engine
+    trainer
         the engine used for training
-    eval_engine
+    evaluator
         the engine used for evaluation
     metric_name
         evaluation metric to save the best model
@@ -63,7 +63,7 @@ def get_handlers(
     lr_scheduler
         learning rate scheduler as native torch LRScheduler or ignite’s parameter scheduler
     output_names
-        list of names associated with `train_engine`'s process_function output dictionary
+        list of names associated with `trainer`'s process_function output dictionary
     kwargs
         keyword arguments passed to Checkpoint handler

@@ -78,7 +78,7 @@ def get_handlers(
     # kwargs can be passed to save the model based on training stats
     # like score_name, score_function
     common.setup_common_training_handlers(
-        trainer=train_engine,
+        trainer=trainer,
         train_sampler=train_sampler,
         to_save=to_save,
         lr_scheduler=lr_scheduler,
@@ -99,11 +99,11 @@ def get_handlers(
     # https://pytorch.org/ignite/contrib/engines.html#ignite.contrib.engines.common.save_best_model_by_val_score
     best_model_handler = common.save_best_model_by_val_score(
         output_path=config.output_dir / 'checkpoints',
-        evaluator=eval_engine,
+        evaluator=evaluator,
         model=model,
         metric_name=metric_name,
         n_saved=config.n_saved,
-        trainer=train_engine,
+        trainer=trainer,
         tag='eval',
     )
 {% endif %}
@@ -112,8 +112,8 @@ def get_handlers(
     # https://pytorch.org/ignite/contrib/engines.html#ignite.contrib.engines.common.add_early_stopping_by_val_score
     es_handler = common.add_early_stopping_by_val_score(
         patience=config.patience,
-        evaluator=eval_engine,
-        trainer=train_engine,
+        evaluator=evaluator,
+        trainer=trainer,
         metric_name=es_metric_name,
     )
 {% endif %}
@@ -125,7 +125,7 @@ def get_handlers(
     # you can replace with the events you want to measure
     timer_handler = Timer(average=True)
     timer_handler.attach(
-        engine=train_engine,
+        engine=trainer,
         start=Events.EPOCH_STARTED,
         resume=Events.ITERATION_STARTED,
         pause=Events.ITERATION_COMPLETED,
@@ -135,7 +135,7 @@ def get_handlers(
 {% if setup_timelimit %}

     # training will terminate if training time exceed `limit_sec`.
-    train_engine.add_event_handler(
+    trainer.add_event_handler(
         Events.ITERATION_COMPLETED, TimeLimit(limit_sec=config.limit_sec)
     )
 {% endif %}
@@ -144,8 +144,8 @@ def get_handlers(

 def get_logger(
     config: Any,
-    train_engine: Engine,
-    eval_engine: Optional[Union[Engine, Dict[str, Engine]]] = None,
+    trainer: Engine,
+    evaluator: Optional[Union[Engine, Dict[str, Engine]]] = None,
     optimizers: Optional[Union[Optimizer, Dict[str, Optimizer]]] = None,
     **kwargs: Any,
 ) -> Optional[BaseLogger]:
@@ -160,9 +160,9 @@ def get_logger(
         - `filepath`: logging path to output file
         - `logger_log_every_iters`: logging iteration interval for loggers

-    train_engine
+    trainer
         trainer engine
-    eval_engine
+    evaluator
         evaluator engine
     optimizers
         optimizers to log optimizer parameters
@@ -177,58 +177,58 @@ def get_logger(

 {% if logger_deps == 'clearml' %}
     logger_handler = common.setup_clearml_logging(
-        trainer=train_engine,
+        trainer=trainer,
         optimizers=optimizers,
-        evaluators=eval_engine,
+        evaluators=evaluator,
         log_every_iters=config.logger_log_every_iters,
         **kwargs,
     )
 {% elif logger_deps == 'mlflow' %}
     logger_handler = common.setup_mlflow_logging(
-        trainer=train_engine,
+        trainer=trainer,
         optimizers=optimizers,
-        evaluators=eval_engine,
+        evaluators=evaluator,
         log_every_iters=config.logger_log_every_iters,
         **kwargs,
     )
 {% elif logger_deps == 'neptune-client' %}
     logger_handler = common.setup_neptune_logging(
-        trainer=train_engine,
+        trainer=trainer,
         optimizers=optimizers,
-        evaluators=eval_engine,
+        evaluators=evaluator,
         log_every_iters=config.logger_log_every_iters,
         **kwargs,
     )
 {% elif logger_deps == 'polyaxon-client' %}
     logger_handler = common.setup_plx_logging(
-        trainer=train_engine,
+        trainer=trainer,
         optimizers=optimizers,
-        evaluators=eval_engine,
+        evaluators=evaluator,
         log_every_iters=config.logger_log_every_iters,
         **kwargs,
     )
 {% elif logger_deps == 'tensorboard' %}
     logger_handler = common.setup_tb_logging(
         output_path=config.output_dir,
-        trainer=train_engine,
+        trainer=trainer,
         optimizers=optimizers,
-        evaluators=eval_engine,
+        evaluators=evaluator,
         log_every_iters=config.logger_log_every_iters,
         **kwargs,
     )
 {% elif logger_deps == 'visdom' %}
     logger_handler = common.setup_visdom_logging(
-        trainer=train_engine,
+        trainer=trainer,
         optimizers=optimizers,
-        evaluators=eval_engine,
+        evaluators=evaluator,
         log_every_iters=config.logger_log_every_iters,
         **kwargs,
     )
 {% elif logger_deps == 'wandb' %}
     logger_handler = common.setup_wandb_logging(
-        trainer=train_engine,
+        trainer=trainer,
         optimizers=optimizers,
-        evaluators=eval_engine,
+        evaluators=evaluator,
         log_every_iters=config.logger_log_every_iters,
         **kwargs,
     )

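The rename is purely at the call-site level: get_handlers and get_logger keep their behaviour, and only the keyword names train_engine/eval_engine become trainer/evaluator. Below is a minimal sketch of the new get_logger call, modelled on the template's own test; the import path, the output directory, and the concrete config fields are assumptions, not part of this commit.

from argparse import Namespace
from pathlib import Path

import torch.nn as nn
import torch.optim as optim
from ignite.engine import Engine

from _handlers import get_logger  # assumed import path inside a generated project

# placeholder process functions; real projects pass their actual engines
trainer = Engine(lambda engine, batch: batch)
evaluator = Engine(lambda engine, batch: batch)

config = Namespace(output_dir=Path("./logs"), logger_log_every_iters=1)  # assumed fields
optimizer = optim.Adam(nn.Linear(1, 1).parameters())

# was: get_logger(config=..., train_engine=..., eval_engine=..., optimizers=...)
logger_handler = get_logger(
    config=config,
    trainer=trainer,
    evaluator=evaluator,  # still accepts a single Engine or a Dict[str, Engine]
    optimizers=optimizer,
)

The same one-for-one substitution applies to get_handlers: train_engine= becomes trainer= and eval_engine= becomes evaluator=; every other argument is unchanged.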
templates/gan/_test_internal.py

Lines changed: 8 additions & 8 deletions
@@ -27,7 +27,7 @@


 def test_get_handlers(tmp_path):
-    train_engine = Engine(lambda e, b: b)
+    trainer = Engine(lambda e, b: b)
     config = Namespace(
         output_dir=tmp_path,
         save_every_iters=1,
@@ -44,8 +44,8 @@ def test_get_handlers(tmp_path):
     bm_handler, es_handler, timer_handler = get_handlers(
         config=config,
         model=nn.Linear(1, 1),
-        train_engine=train_engine,
-        eval_engine=train_engine,
+        trainer=trainer,
+        evaluator=trainer,
         metric_name="eval_loss",
         es_metric_name="eval_loss",
     )
@@ -56,12 +56,12 @@ def test_get_handlers(tmp_path):

 def test_get_logger(tmp_path):
     config = Namespace(output_dir=tmp_path, logger_log_every_iters=1)
-    train_engine = Engine(lambda e, b: b)
+    trainer = Engine(lambda e, b: b)
     optimizer = optim.Adam(nn.Linear(1, 1).parameters())
     logger_handler = get_logger(
         config=config,
-        train_engine=train_engine,
-        eval_engine=train_engine,
+        trainer=trainer,
+        evaluator=trainer,
         optimizers=optimizer,
     )
     types = (
@@ -82,7 +82,7 @@ def test_create_trainers():
     model, optimizer, device, loss_fn, batch = set_up()
     real_labels = torch.ones(2, device=device)
     fake_labels = torch.zeros(2, device=device)
-    train_engine = create_trainers(
+    trainer = create_trainers(
         config=Namespace(use_amp=True),
         netD=model,
         netG=model,
@@ -93,7 +93,7 @@ def test_create_trainers():
         real_labels=real_labels,
         fake_labels=fake_labels,
     )
-    assert isinstance(train_engine, Engine)
+    assert isinstance(trainer, Engine)


 def test_get_default_parser():

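One consequence worth noting for downstream projects: because both get_handlers and get_logger take **kwargs, a stale call site that still says train_engine=... does not fail with an "unexpected keyword argument" error; the old name is silently swallowed by **kwargs, and the failure surfaces as a missing required trainer argument. A hedged regression-test sketch follows (the import path is an assumption; the pytest.raises check itself is plain Python behaviour):

from argparse import Namespace

import pytest
from ignite.engine import Engine

from _handlers import get_logger  # assumed import path inside a generated project


def test_old_keyword_name_is_rejected(tmp_path):
    config = Namespace(output_dir=tmp_path, logger_log_every_iters=1)
    engine = Engine(lambda e, b: b)

    # `train_engine` lands in **kwargs, so the required `trainer` is reported as missing
    with pytest.raises(TypeError):
        get_logger(config=config, train_engine=engine)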
templates/gan/main.py

Lines changed: 14 additions & 14 deletions
@@ -71,14 +71,14 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
     netD, netG, optimizerD, optimizerG, loss_fn, lr_scheduler = initialize(config, num_channels)

     # -----------------------------
-    # train_engine and eval_engine
+    # trainer and evaluator
     # -----------------------------
     ws = idist.get_world_size()
     real_labels = torch.ones(config.batch_size // ws, device=device)
     fake_labels = torch.zeros(config.batch_size // ws, device=device)
     fixed_noise = torch.randn(config.batch_size // ws, config.z_dim, 1, 1, device=device)

-    train_engine = create_trainers(
+    trainer = create_trainers(
         config=config,
         netD=netD,
         netG=netG,
@@ -97,19 +97,19 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):

     logger = setup_logging(config)
     log_basic_info(logger, config)
-    train_engine.logger = logger
+    trainer.logger = logger

     # -------------------------------------
     # ignite handlers and ignite loggers
     # -------------------------------------

-    to_save = {'netD': netD, 'netG': netG, 'optimizerD': optimizerD, 'optimizerG': optimizerG, 'trainer': train_engine}
+    to_save = {'netD': netD, 'netG': netG, 'optimizerD': optimizerD, 'optimizerG': optimizerG, 'trainer': trainer}
     optimizers = {'optimizerD': optimizerD, 'optimizerG': optimizerG}
     best_model_handler, es_handler, timer_handler = get_handlers(
         config=config,
         model={'netD', netD, 'netG', netG},
-        train_engine=train_engine,
-        eval_engine=train_engine,
+        trainer=trainer,
+        evaluator=trainer,
         metric_name='errD',
         es_metric_name='errD',
         to_save=to_save,
@@ -119,7 +119,7 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):

     # setup ignite logger only on rank 0
     if rank == 0:
-        logger_handler = get_logger(config=config, train_engine=train_engine, optimizers=optimizers)
+        logger_handler = get_logger(config=config, trainer=trainer, optimizers=optimizers)

     # -----------------------------------
     # resume from the saved checkpoints
@@ -132,7 +132,7 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
     # adding handlers using `trainer.on` decorator API
     # --------------------------------------------------

-    @train_engine.on(Events.EPOCH_COMPLETED)
+    @trainer.on(Events.EPOCH_COMPLETED)
     def save_fake_example(engine):
         fake = netG(fixed_noise)
         path = config.output_dir / (FAKE_IMG_FNAME.format(engine.state.epoch))
@@ -141,7 +141,7 @@ def save_fake_example(engine):
     # --------------------------------------------------
     # adding handlers using `trainer.on` decorator API
     # --------------------------------------------------
-    @train_engine.on(Events.EPOCH_COMPLETED)
+    @trainer.on(Events.EPOCH_COMPLETED)
     def save_real_example(engine):
         img, y = engine.state.batch
         path = config.output_dir / (REAL_IMG_FNAME.format(engine.state.epoch))
@@ -150,13 +150,13 @@ def save_real_example(engine):
     # -------------------------------------------------------------
     # adding handlers using `trainer.on` decorator API
     # -------------------------------------------------------------
-    @train_engine.on(Events.EPOCH_COMPLETED)
+    @trainer.on(Events.EPOCH_COMPLETED)
     def print_times(engine):
         if not timer_handler:
             logger.info(f"Epoch {engine.state.epoch} done. Time per batch: {timer_handler.value():.3f}[s]")
             timer_handler.reset()

-    @train_engine.on(Events.ITERATION_COMPLETED(every=config.log_every_iters))
+    @trainer.on(Events.ITERATION_COMPLETED(every=config.log_every_iters))
     @idist.one_rank_only()
     def print_logs(engine):
         fname = config.output_dir / LOGS_FNAME
@@ -174,7 +174,7 @@ def print_logs(engine):
     # -------------------------------------------------------------
     # adding handlers using `trainer.on` decorator API
     # -------------------------------------------------------------
-    @train_engine.on(Events.EPOCH_COMPLETED)
+    @trainer.on(Events.EPOCH_COMPLETED)
     def create_plots(engine):
         try:
             import matplotlib as mpl
@@ -202,13 +202,13 @@ def create_plots(engine):
     # for training stats
     # --------------------------------

-    train_engine.add_event_handler(Events.ITERATION_COMPLETED(every=config.log_every_iters), log_metrics, tag="train")
+    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=config.log_every_iters), log_metrics, tag="train")

     # ------------------------------------------
     # setup if done. let's run the training
     # ------------------------------------------

-    train_engine.run(train_dataloader, max_epochs=config.max_epochs, epoch_length=config.epoch_length)
+    trainer.run(train_dataloader, max_epochs=config.max_epochs, epoch_length=config.epoch_length)

     # ------------------------------------------------------------
     # close the logger after the training completed / terminated

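The main.py changes are mechanical: every @train_engine.on(...) decorator, train_engine.add_event_handler(...) call, and the final train_engine.run(...) now go through trainer. For reference, the two attachment styles used above are equivalent in plain ignite, independent of this template; a self-contained sketch:

from ignite.engine import Engine, Events

# toy engine; the process function just echoes the batch
trainer = Engine(lambda engine, batch: batch)

# decorator style, as used for save_fake_example / save_real_example / print_times
@trainer.on(Events.EPOCH_COMPLETED)
def on_epoch_completed(engine):
    print(f"epoch {engine.state.epoch} completed")

# explicit style, as used for log_metrics
def on_iteration(engine):
    print(f"iteration {engine.state.iteration}")

trainer.add_event_handler(Events.ITERATION_COMPLETED(every=2), on_iteration)

trainer.run([0, 1, 2, 3], max_epochs=1)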
templates/gan/trainers.py

Lines changed: 4 additions & 4 deletions
@@ -1,5 +1,5 @@
 """
-`train_engine` and `eval_engine` like trainer and evaluator
+`trainer` and `evaluator` like trainer and evaluator
 """
 from typing import Any

@@ -131,13 +131,13 @@ def create_trainers(**kwargs) -> Engine:

     Returns
     -------
-    train_engine
+    trainer
     """
-    train_engine = Engine(
+    trainer = Engine(
         lambda e, b: train_function(
             engine=e,
             batch=b,
             **kwargs,
         )
     )
-    return train_engine
+    return trainer

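In trainers.py the factory shape is untouched: the training step is still wrapped in a single Engine, and only the local variable and the documented return value are now called trainer. A minimal sketch of the same pattern with a stand-in step function (the real train_function in the template implements the GAN update and is not reproduced here):

from ignite.engine import Engine


def train_function(engine, batch, **kwargs):
    # stand-in for the template's GAN training step
    return {"errD": 0.0, "errG": 0.0}


def create_trainers(**kwargs) -> Engine:
    # same structure as the template after the rename
    trainer = Engine(
        lambda e, b: train_function(engine=e, batch=b, **kwargs)
    )
    return trainer


trainer = create_trainers()
assert isinstance(trainer, Engine)
trainer.run([0, 1], max_epochs=1)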