Skip to content

Commit f4dddb7

Browse files
authored
added configuration for local rank (#197)
1 parent 69586a0 commit f4dddb7

File tree

3 files changed

+14
-3
lines changed
  • src/templates

3 files changed

+14
-3
lines changed

src/templates/template-text-classification/data.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,12 @@ def __len__(self):
4242

4343

4444
def setup_data(config):
45+
#::: if (it.use_dist) { :::#
4546
local_rank = idist.get_local_rank()
4647

4748
if local_rank > 0:
4849
idist.barrier()
50+
#::: } :::#
4951

5052
dataset_train, dataset_eval = load_dataset(
5153
"imdb", split=["train", "test"], cache_dir=config.data_path
@@ -61,9 +63,10 @@ def setup_data(config):
6163
dataset_eval = TransformerDataset(
6264
test_texts, test_labels, tokenizer, config.max_length
6365
)
64-
66+
#::: if (it.use_dist) { :::#
6567
if local_rank == 0:
6668
idist.barrier()
69+
#::: } :::#
6770

6871
dataloader_train = idist.auto_dataloader(
6972
dataset_train,

src/templates/template-vision-dcgan/data.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,21 @@ def setup_data(config: Any):
1212
----------
1313
config: needs to contain `data_path`, `train_batch_size`, `eval_batch_size`, and `num_workers`
1414
"""
15+
#::: if (it.use_dist) { :::#
1516
local_rank = idist.get_local_rank()
17+
#::: } :::#
1618
transform = T.Compose(
1719
[
1820
T.Resize(64),
1921
T.ToTensor(),
2022
T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
2123
]
2224
)
23-
25+
#::: if (it.use_dist) { :::#
2426
if local_rank > 0:
2527
# Ensure that only rank 0 download the dataset
2628
idist.barrier()
29+
#::: } :::#
2730

2831
dataset_train = torchvision.datasets.CIFAR10(
2932
root=config.data_path,
@@ -38,9 +41,11 @@ def setup_data(config: Any):
3841
transform=transform,
3942
)
4043
nc = 3
44+
#::: if (it.use_dist) { :::#
4145
if local_rank == 0:
4246
# Ensure that only rank 0 download the dataset
4347
idist.barrier()
48+
#::: } :::#
4449

4550
dataloader_train = idist.auto_dataloader(
4651
dataset_train,

src/templates/template-vision-segmentation/data.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,14 +177,17 @@ def prepare_image_mask(batch, device, non_blocking):
177177

178178

179179
def download_datasets(data_path):
180+
#::: if (it.use_dist) { :::#
180181
local_rank = idist.get_local_rank()
181182
if local_rank > 0:
182183
# Ensure that only rank 0 download the dataset
183184
idist.barrier()
185+
#::: } :::#
184186

185187
VOCSegmentation(data_path, image_set="train", download=True)
186188
VOCSegmentation(data_path, image_set="val", download=True)
187-
189+
#::: if (it.use_dist) { :::#
188190
if local_rank == 0:
189191
# Ensure that only rank 0 download the dataset
190192
idist.barrier()
193+
#::: } :::#

0 commit comments

Comments
 (0)