We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent beef62e commit 11060f8Copy full SHA for 11060f8
train.py
@@ -355,6 +355,8 @@ def main():
355
args.world_size = 1
356
args.rank = 0 # global rank
357
if args.distributed:
358
+ if 'LOCAL_RANK' in os.environ:
359
+ args.local_rank = int(os.getenv('LOCAL_RANK'))
360
args.device = 'cuda:%d' % args.local_rank
361
torch.cuda.set_device(args.local_rank)
362
torch.distributed.init_process_group(backend='nccl', init_method='env://')
0 commit comments