Skip to content

Commit 11060f8

Browse files
committed
make train.py compatible with torchrun
1 parent beef62e commit 11060f8

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

train.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,8 @@ def main():
355355
args.world_size = 1
356356
args.rank = 0 # global rank
357357
if args.distributed:
358+
if 'LOCAL_RANK' in os.environ:
359+
args.local_rank = int(os.getenv('LOCAL_RANK'))
358360
args.device = 'cuda:%d' % args.local_rank
359361
torch.cuda.set_device(args.local_rank)
360362
torch.distributed.init_process_group(backend='nccl', init_method='env://')

0 commit comments

Comments
 (0)