Skip to content

Commit 69586a0

Browse files
authored
Removed idist.barrier() where needed (#194)
* removed idist.barrier() where needed * added configuration for idist.barrier() * after bash formatting * made some more required changes
1 parent c478264 commit 69586a0

File tree

1 file changed

+7
-0
lines changed
  • src/templates/template-vision-classification

1 file changed

+7
-0
lines changed

src/templates/template-vision-classification/data.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,21 @@ def setup_data(config: Any):
1212
----------
1313
config: needs to contain `data_path`, `train_batch_size`, `eval_batch_size`, and `num_workers`
1414
"""
15+
#::: if (it.use_dist) { :::#
1516
local_rank = idist.get_local_rank()
17+
#::: } :::#
1618
transform = T.Compose(
1719
[
1820
T.ToTensor(),
1921
T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
2022
]
2123
)
2224

25+
#::: if (it.use_dist) { :::#
2326
if local_rank > 0:
2427
# Ensure that only rank 0 download the dataset
2528
idist.barrier()
29+
#::: } :::#
2630

2731
dataset_train = torchvision.datasets.CIFAR10(
2832
root=config.data_path,
@@ -36,9 +40,12 @@ def setup_data(config: Any):
3640
download=True,
3741
transform=transform,
3842
)
43+
44+
#::: if (it.use_dist) { :::#
3945
if local_rank == 0:
4046
# Ensure that only rank 0 download the dataset
4147
idist.barrier()
48+
#::: } :::#
4249

4350
dataloader_train = idist.auto_dataloader(
4451
dataset_train,

0 commit comments

Comments
 (0)