
Commit e9df949

Author: Jeff Yang

fix: download datasets on local rank 0 in multi node (#65)

* fix: download datasets on local rank 0 in multi node
* fix: idist.get_local_rank()
1 parent 280bebd commit e9df949

File tree: 6 files changed, +51 -29 lines
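All six files apply the same double-barrier guard: processes with a non-zero local rank wait at a barrier while local rank 0 downloads the data, and local rank 0 reaches a matching barrier only after the download, releasing the waiters. The local rank matters here, rather than the global rank, because in a multi-node job each node has its own filesystem, so exactly one process per node must perform the download. A minimal standalone sketch of the pattern, assuming a process group set up through ignite.distributed (the download_guard helper is hypothetical, not part of this commit):

from contextlib import contextmanager

import ignite.distributed as idist


@contextmanager
def download_guard():
    # Hypothetical helper sketching the commit's pattern: non-zero local
    # ranks block here until local rank 0 has executed the guarded body.
    if idist.get_local_rank() > 0:
        idist.barrier()
    yield
    # Local rank 0 arrives here only after downloading, which releases
    # the ranks waiting at the barrier above.
    if idist.get_local_rank() == 0:
        idist.barrier()

By the time the waiting ranks run the guarded body themselves, the files are already on disk, so their download call reduces to a read.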

templates/gan/datasets.py (11 additions, 0 deletions)

@@ -1,5 +1,6 @@
 from torchvision import transforms as T
 from torchvision import datasets as dset
+import ignite.distributed as idist
 
 
 def get_datasets(dataset, dataroot):
@@ -12,6 +13,12 @@ def get_datasets(dataset, dataroot):
     Returns:
         dataset, num_channels
     """
+    local_rank = idist.get_local_rank()
+
+    if local_rank > 0:
+        # Ensure that only rank 0 downloads the dataset
+        idist.barrier()
+
     resize = T.Resize(64)
     crop = T.CenterCrop(64)
     to_tensor = T.ToTensor()
@@ -42,4 +49,8 @@ def get_datasets(dataset, dataroot):
     else:
         raise RuntimeError(f"Invalid dataset name: {dataset}")
 
+    if local_rank == 0:
+        # Ensure that only rank 0 downloads the dataset
+        idist.barrier()
+
     return dataset, nc

templates/gan/main.py (0 additions, 8 deletions)

@@ -40,16 +40,8 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
     # datasets and dataloaders
     # -----------------------------
 
-    if rank > 0:
-        # Ensure that only rank 0 download the dataset
-        idist.barrier()
-
     train_dataset, num_channels = get_datasets(config.dataset, config.data_path)
 
-    if rank == 0:
-        # Ensure that only rank 0 download the dataset
-        idist.barrier()
-
     train_dataloader = idist.auto_dataloader(
         train_dataset,
         batch_size=config.batch_size,

templates/image_classification/datasets.py (11 additions, 0 deletions)

@@ -1,5 +1,6 @@
 from torchvision import datasets
 from torchvision.transforms import Compose, Normalize, Pad, RandomCrop, RandomHorizontalFlip, ToTensor
+import ignite.distributed as idist
 
 train_transform = Compose(
     [
@@ -20,7 +21,17 @@
 
 
 def get_datasets(path):
+    local_rank = idist.get_local_rank()
+
+    if local_rank > 0:
+        # Ensure that only rank 0 downloads the dataset
+        idist.barrier()
+
     train_ds = datasets.CIFAR10(root=path, train=True, download=True, transform=train_transform)
     eval_ds = datasets.CIFAR10(root=path, train=False, download=True, transform=eval_transform)
 
+    if local_rank == 0:
+        # Ensure that only rank 0 downloads the dataset
+        idist.barrier()
+
     return train_ds, eval_ds
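The post-barrier call on the non-zero ranks is safe because torchvision's download=True is idempotent: it checks root for a verified copy and skips the network fetch when one exists. What those ranks effectively execute once the barrier releases them (the path is a placeholder):

from torchvision import datasets

# Local rank 0 already fetched the archive, so this call only verifies the
# existing files and loads them ("Files already downloaded and verified").
train_ds = datasets.CIFAR10(root="/tmp/cifar10", train=True, download=True)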

templates/image_classification/main.py (3 additions, 9 deletions)

@@ -36,16 +36,8 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
     # TODO : PLEASE replace `kwargs` with your desirable DataLoader arguments
     # See : https://pytorch.org/ignite/distributed.html#ignite.distributed.auto.auto_dataloader
 
-    if rank > 0:
-        # Ensure that only rank 0 download the dataset
-        idist.barrier()
-
     train_dataset, eval_dataset = get_datasets(path=config.data_path)
 
-    if rank == 0:
-        # Ensure that only rank 0 download the dataset
-        idist.barrier()
-
     train_dataloader = idist.auto_dataloader(
         train_dataset,
         batch_size=config.train_batch_size,
@@ -128,7 +120,9 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
 
     # setup ignite logger only on rank 0
     if rank == 0:
-        logger_handler = get_logger(config=config, train_engine=train_engine, eval_engine=eval_engine, optimizers=optimizer)
+        logger_handler = get_logger(
+            config=config, train_engine=train_engine, eval_engine=eval_engine, optimizers=optimizer
+        )
 
     # -----------------------------------
     # resume from the saved checkpoints

templates/single/datasets.py (21 additions, 1 deletion)

@@ -1 +1,21 @@
-# CUSTOM DATASETS AND DATALOADERS GO HERE
+# MAKE SURE YOUR DATASETS ARE DOWNLOADED ON LOCAL_RANK 0.
+
+import ignite.distributed as idist
+
+
+def get_datasets(*args, **kwargs):
+    local_rank = idist.get_local_rank()
+
+    if local_rank > 0:
+        # Ensure that only rank 0 downloads the dataset
+        idist.barrier()
+
+    # CUSTOM DATASETS GO HERE
+    train_dataset = ...
+    eval_dataset = ...
+
+    if local_rank == 0:
+        # Ensure that only rank 0 downloads the dataset
+        idist.barrier()
+
+    return train_dataset, eval_dataset
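One way a user might fill in the template's placeholders, sketched here with torchvision's MNIST (the dataset choice and path are illustrative, not part of the template):

import ignite.distributed as idist
from torchvision import datasets, transforms


def get_datasets(path="/tmp/data"):
    local_rank = idist.get_local_rank()

    if local_rank > 0:
        # Ensure that only rank 0 downloads the dataset
        idist.barrier()

    # CUSTOM DATASETS GO HERE (illustrative: MNIST)
    transform = transforms.ToTensor()
    train_dataset = datasets.MNIST(root=path, train=True, download=True, transform=transform)
    eval_dataset = datasets.MNIST(root=path, train=False, download=True, transform=transform)

    if local_rank == 0:
        # Ensure that only rank 0 downloads the dataset
        idist.barrier()

    return train_dataset, eval_dataset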

templates/single/main.py (5 additions, 11 deletions)

@@ -11,6 +11,7 @@
 from ignite.engine.events import Events
 from ignite.utils import manual_seed
 
+from datasets import get_datasets
 from trainers import create_trainers, TrainEvents
 from handlers import get_handlers, get_logger
 from utils import setup_logging, log_metrics, log_basic_info, initialize, resume_from
@@ -34,16 +35,7 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
     # TODO : PLEASE replace `kwargs` with your desirable DataLoader arguments
     # See : https://pytorch.org/ignite/distributed.html#ignite.distributed.auto.auto_dataloader
 
-    if rank > 0:
-        # Ensure that only rank 0 download the dataset
-        idist.barrier()
-
-    train_dataset = ...
-    eval_dataset = ...
-
-    if rank == 0:
-        # Ensure that only rank 0 download the dataset
-        idist.barrier()
+    train_dataset, eval_dataset = get_datasets()
 
     train_dataloader = idist.auto_dataloader(train_dataset, **kwargs)
     eval_dataloader = idist.auto_dataloader(eval_dataset, **kwargs)
@@ -104,7 +96,9 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
 
     # setup ignite logger only on rank 0
     if rank == 0:
-        logger_handler = get_logger(config=config, train_engine=train_engine, eval_engine=eval_engine, optimizers=optimizer)
+        logger_handler = get_logger(
+            config=config, train_engine=train_engine, eval_engine=eval_engine, optimizers=optimizer
+        )
 
     # -----------------------------------
     # resume from the saved checkpoints
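For context, the local-rank logic above only comes into play under a distributed launch. A sketch of spawning a template's run function on one node of a two-node job with ignite.distributed (backend, node count, address, and port are placeholder values, not part of this commit):

import ignite.distributed as idist

from main import run  # the template's run(local_rank, config, ...) entry point


if __name__ == "__main__":
    config = ...  # hypothetical: built from the template's argparse config
    # Two nodes with two GPUs each: every node has its own local rank 0,
    # so each node downloads its copy of the dataset exactly once.
    with idist.Parallel(
        backend="nccl",
        nproc_per_node=2,
        nnodes=2,
        node_rank=0,  # set to 1 on the second node
        master_addr="127.0.0.1",
        master_port=29500,
    ) as parallel:
        parallel.run(run, config)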
