 from ignite.engine.events import Events
 from ignite.utils import manual_seed
 
+from datasets import get_datasets
 from trainers import create_trainers, TrainEvents
 from handlers import get_handlers, get_logger
 from utils import setup_logging, log_metrics, log_basic_info, initialize, resume_from
@@ -34,16 +35,7 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
     # TODO: PLEASE replace `kwargs` with your desired DataLoader arguments
     # See: https://pytorch.org/ignite/distributed.html#ignite.distributed.auto.auto_dataloader
 
-    if rank > 0:
-        # Ensure that only rank 0 downloads the dataset
-        idist.barrier()
-
-    train_dataset = ...
-    eval_dataset = ...
-
-    if rank == 0:
-        # Ensure that only rank 0 downloads the dataset
-        idist.barrier()
+    train_dataset, eval_dataset = get_datasets()
 
     train_dataloader = idist.auto_dataloader(train_dataset, **kwargs)
     eval_dataloader = idist.auto_dataloader(eval_dataset, **kwargs)
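
The `datasets` module itself is not part of this diff, so here is a minimal sketch of what `get_datasets` might encapsulate, assuming it simply absorbs the rank-0 download barrier that this change removes from `run()`; the torchvision import and the `MNIST` dataset are stand-ins for illustration, not something the template prescribes:

```python
# Hypothetical sketch of datasets.py -- not shown in this diff.
# Assumes get_datasets() absorbs the rank-0 download barrier that
# the change above removes from run(); MNIST is a stand-in dataset.
import ignite.distributed as idist
from torchvision import datasets, transforms


def get_datasets():
    rank = idist.get_rank()

    if rank > 0:
        # Non-zero ranks wait here so that only rank 0 downloads the dataset
        idist.barrier()

    transform = transforms.ToTensor()
    train_dataset = datasets.MNIST(root=".", train=True, download=True, transform=transform)
    eval_dataset = datasets.MNIST(root=".", train=False, download=True, transform=transform)

    if rank == 0:
        # Rank 0 has finished downloading; release the waiting ranks
        idist.barrier()

    return train_dataset, eval_dataset
```

With the datasets returned, `idist.auto_dataloader` accepts the usual `torch.utils.data.DataLoader` keyword arguments (e.g. `batch_size`, `num_workers`, `shuffle`) and adapts them to the current distributed configuration, for instance by injecting a `DistributedSampler` when needed.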
@@ -104,7 +96,9 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
 
     # setup ignite logger only on rank 0
     if rank == 0:
-        logger_handler = get_logger(config=config, train_engine=train_engine, eval_engine=eval_engine, optimizers=optimizer)
+        logger_handler = get_logger(
+            config=config, train_engine=train_engine, eval_engine=eval_engine, optimizers=optimizer
+        )
 
     # -----------------------------------
     # resume from the saved checkpoints
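
`get_logger` lives in the `handlers` module and is likewise outside this diff. A plausible sketch, built on ignite's stock `common.setup_tb_logging` helper and assuming `config` carries `output_dir` and `log_every_iters` fields and that TensorBoard is the chosen backend (all assumptions, not confirmed by the diff):

```python
# Hypothetical sketch of get_logger() in handlers.py -- not shown in this diff.
# Assumes config.output_dir and config.log_every_iters exist and that a
# TensorBoard logger is wanted; swap in another common.setup_*_logging
# helper for a different backend.
from ignite.contrib.engines import common


def get_logger(config, train_engine, eval_engine, optimizers):
    # Attaches handlers that log training output, optimizer learning rates,
    # and evaluation metrics to TensorBoard, and returns the logger.
    return common.setup_tb_logging(
        config.output_dir,
        train_engine,
        optimizers=optimizers,
        evaluators=eval_engine,
        log_every_iters=config.log_every_iters,
    )
```

Since the call site above is guarded by `if rank == 0:`, only one process writes logs; the returned logger should be closed when training finishes (e.g. `logger_handler.close()` on `Events.COMPLETED`).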