Conversation

@khasinski commented Dec 14, 2025

Summary

This adds Torch::NN::DataParallel for multi-GPU training, allowing automatic batch splitting across GPUs.

What about torch-ddp?

I missed the other PR adding multi-GPU (and even distributed) workloads :) I still think this is worth submitting, since it's a much smaller changeset and still has value. Using multiple GPUs locally is simpler than setting up a cluster for distributed training.

Usage

model = MyModel.new.to("cuda:0")
dp_model = Torch::NN::DataParallel.new(model, device_ids: [0, 1])  # replicate across GPUs 0 and 1

optimizer.zero_grad
output = dp_model.call(input)          # batch is split across devices, outputs gathered on cuda:0
loss = criterion.call(output, target)
loss.backward
optimizer.step

Models that return loss

If the model returns a scalar loss (e.g., [logits, loss]), use dp_model.backward instead of loss.backward:

optimizer.zero_grad
logits, loss = dp_model.call(input, targets: targets)
dp_model.backward(scale: 1.0)
optimizer.step

This is necessary because gathering scalar tensors across devices breaks the autograd graph. The backward method calls backward on each replica's loss separately, then reduces gradients to the original module.
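For intuition, here is a rough sketch of that reduction step. This is illustrative only, not the PR's actual code: replica_losses and replicas are hypothetical locals, and it assumes Tensor#grad= and the in-place add! are available.

replica_losses.each(&:backward)            # backward on each replica's own scalar loss

Torch.no_grad do
  model.parameters.zip(*replicas.map(&:parameters)).each do |original, *copies|
    copies.each do |param|
      next if param.grad.nil?
      g = param.grad.to(original.device)   # move the replica's grad to the primary device
      if original.grad.nil?
        original.grad = g                  # assumes Tensor#grad= exists
      else
        original.grad.add!(g)              # accumulate in place
      end
    end
  end
end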

What's included

CUDA device management (see the sketch just after this list):

  • Torch::CUDA.current_device - get the current CUDA device index
  • Torch::CUDA.set_device(id) - set the current CUDA device (handy in tests)
  • Torch::CUDA.synchronize - wait for all queued CUDA operations to complete
  • Torch::CUDA.nccl_available? - check whether NCCL is available (useful for checking whether DataParallel can run)

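A quick sketch exercising the helpers above (Torch::CUDA.available? and Torch::CUDA.device_count already exist in torch.rb):

if Torch::CUDA.available? && Torch::CUDA.nccl_available?
  puts "GPUs: #{Torch::CUDA.device_count}"
  puts "current: #{Torch::CUDA.current_device}"

  Torch::CUDA.set_device(1)   # make cuda:1 the current device
  Torch::CUDA.synchronize     # block until all queued CUDA work finishes
end
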
DataParallel:

  • Torch::NN::DataParallel - wraps a module for multi-GPU training
  • Torch::NN::Parallel.replicate - copies modules to multiple devices
  • Torch::NN::Parallel.parallel_apply - runs forward pass on replicas in parallel
  • Torch::NN._scatter / Torch::NN._gather - split and combine tensors across devices (internal, not public API; the underscore prefix mirrors PyTorch's naming). These compose as sketched below.
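
Conceptually, DataParallel's forward pass composes these helpers roughly like this. The argument lists are assumptions based on the PyTorch helpers they mirror, not the exact Ruby signatures:

inputs   = Torch::NN._scatter(input, device_ids)                 # split the batch, one chunk per GPU
replicas = Torch::NN::Parallel.replicate(model, device_ids)      # copy the module to each GPU
outputs  = Torch::NN::Parallel.parallel_apply(replicas, inputs)  # threaded forward pass per replica
result   = Torch::NN._gather(outputs, device_ids.first)          # collect outputs on the primary device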

Testing

Verified with nanogpt-rb on 2 GPUs. Both GPUs were utilized; the speedup was... well, negative, since they were mismatched (RTX 4090 and GTX 1050 Ti) 😅 A better-matched pair should show an actual improvement.

Implementation notes

  • Automatic batch scattering across devices
  • Module replication with proper gradient flow
  • Thread-based parallel forward pass execution
  • Replica caching with in-place weight sync (~35% speedup; sketched below)
  • CUDA synchronization before gathering results
  • Output gathering back to the primary device
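
A rough illustration of the replica-caching idea (hypothetical code, not the PR's implementation; it assumes replicate takes (module, device_ids) and that the in-place copy! is available on tensors):

@replicas ||= Torch::NN::Parallel.replicate(model, device_ids)  # replicate once, reuse on later calls

Torch.no_grad do
  @replicas.each do |replica|
    replica.parameters.zip(model.parameters) do |dst, src|
      dst.copy!(src.to(dst.device))   # sync current weights in place instead of re-replicating
    end
  end
end
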
@orlando-labs (Contributor) commented

Hi, @khasinski. There was a similar, fully functional PR that mimics the PyTorch interface. The torch-ddp gem is available from this repo.

@khasinski (Author) commented

Hey @orlando-labs. I saw that PR after I'd pushed mine 😅 I still think this one is worth keeping, since I don't really need distributed training (too much setup); I just need multiple CUDA devices in one machine with as little change to the underlying codebase as possible.
