Description
🐛 Bug
The all_to_all operation generates invalid HLO that fails verification with the error RET_CHECK failure hlo->operand_count() == split_count. The generated all-to-all instruction is missing the required attributes (split_dimension, concat_dimension, split_count), and its operand count does not match its split count.
To Reproduce
Steps to reproduce the behavior:
- Create a simple tensor on an XLA device
- Call xm.all_to_all() with split parameters
- Try to execute the tensor (e.g., via .cpu())
Minimal reproduction code:
import os
os.environ["PJRT_DEVICE"] = "CPU"
import torch
import torch_xla.core.xla_model as xm
# Create tensor on XLA device
device = xm.xla_device()
value = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.int32, device=device)
# Call all_to_all - this generates invalid HLO
result = xm.all_to_all(
    value,
    split_dimension=0,
    concat_dimension=0,
    split_count=2)
# Forcing execution triggers the error
print(result.cpu())

Error message:
RET_CHECK failure (external/xla/xla/service/hlo_verifier.cc:566) hlo->operand_count() == split_count
Expected behavior
The all_to_all operation should generate valid HLO that passes verification and executes successfully. The HLO instruction should include proper split_dimension, concat_dimension, and split_count attributes that match the operand structure.
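For reference, a verifier-clean array all-to-all carries the split dimension inline in its HLO text, with the split count implied by the replica group size. This is a hand-written sketch of the expected shape, not actual torch_xla output; the module name, shapes, and replica groups are illustrative:

```
HloModule all_to_all_example

ENTRY main {
  p0 = s32[8]{0} parameter(0)
  // Array variant: one operand, split/concat dimension 0,
  // split count 2 implied by the two-replica group.
  ROOT a2a = s32[8]{0} all-to-all(p0), replica_groups={{0,1}}, dimensions={0}
}
```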
Environment
- Reproducible on XLA backend [CPU/TPU]: CPU/NEURON
- torch_xla version: 2.5.0+
Additional context
This affects both the CPU and Neuron backends. The bug appears to be in the HLO generation layer, where TokenHandler::GetInput() modifies the input tensor, causing PyTorch/XLA to create multiple operands without setting the corresponding HLO attributes.