1212from ignite .metrics import Accuracy , Loss
1313
1414from datasets import get_datasets
15- from trainers import create_trainers , TrainEvents
15+ from trainers import create_trainers
1616from utils import setup_logging , log_metrics , log_basic_info , initialize , resume_from , get_handlers , get_logger
1717from config import get_default_parser
1818
@@ -140,30 +140,6 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
140140 if config .resume_from :
141141 resume_from (to_load = to_save , checkpoint_fp = config .resume_from )
142142
143- # --------------------------------------------
144- # let's trigger custom events we registered
145- # we will use a `event_filter` to trigger that
146- # `event_filter` has to return boolean
147- # whether this event should be executed
148- # here will log the gradients on the 1st iteration
149- # and every 100 iterations
150- # --------------------------------------------
151-
152- @trainer .on (TrainEvents .BACKWARD_COMPLETED (lambda _ , ev : (ev % 100 == 0 ) or (ev == 1 )))
153- def _ ():
154- # do something interesting
155- pass
156-
157- # ----------------------------------------
158- # here we will use `every` to trigger
159- # every 100 iterations
160- # ----------------------------------------
161-
162- @trainer .on (TrainEvents .OPTIM_STEP_COMPLETED (every = 100 ))
163- def _ ():
164- # do something interesting
165- pass
166-
167143 # --------------------------------
168144 # print metrics to the stderr
169145 # with `add_event_handler` API
@@ -182,23 +158,22 @@ def _():
182158
183159 @trainer .on (Events .EPOCH_COMPLETED (every = 1 ))
184160 def _ ():
185- evaluator .run (eval_dataloader , max_epochs = 1 )
186- evaluator . add_event_handler ( Events . EPOCH_COMPLETED ( every = 1 ), log_metrics , tag = "eval" )
161+ evaluator .run (eval_dataloader , epoch_length = config . eval_epoch_length )
162+ log_metrics ( evaluator , "eval" )
187163
188164 # --------------------------------------------------
189165 # let's try run evaluation first as a sanity check
190166 # --------------------------------------------------
191167
192168 @trainer .on (Events .STARTED )
193169 def _ ():
194- evaluator .run (eval_dataloader , max_epochs = 1 , epoch_length = 2 )
195- evaluator .state .max_epochs = None
170+ evaluator .run (eval_dataloader , epoch_length = config .eval_epoch_length )
196171
197172 # ------------------------------------------
198173 # setup if done. let's run the training
199174 # ------------------------------------------
200175
201- trainer .run (train_dataloader , max_epochs = config .max_epochs , epoch_length = config .epoch_length )
176+ trainer .run (train_dataloader , max_epochs = config .max_epochs , epoch_length = config .train_epoch_length )
202177
203178 # ------------------------------------------------------------
204179 # close the logger after the training completed / terminated
0 commit comments