Skip to content

Commit a2cb804

Browse files
authored
[Interpret Mode] Fix hl.load with multiple 1D tensor indices (#1227)
1 parent 96f169f commit a2cb804

File tree

2 files changed

+48
-1
lines changed

2 files changed

+48
-1
lines changed

helion/language/memory_ops.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,8 +353,19 @@ def _(
353353
from .ref_tile import RefTile
354354

355355
if extra_mask is None:
356+
# Unwrap each RefTile to its underlying `.index`; every other kind of index is passed through unchanged
357+
indices = [idx.index if isinstance(idx, RefTile) else idx for idx in index]
358+
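# Multiple index tensors in one subscript are broadcast together and paired element-wise, so expand them with meshgrid to index their Cartesian product instead (N 1D indices -> N-D result)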
359+
tensor_idxs = [
360+
i for i, idx in enumerate(indices) if isinstance(idx, torch.Tensor)
361+
]
362+
if len(tensor_idxs) > 1:
363+
# pyrefly: ignore [bad-argument-type]
364+
grids = torch.meshgrid(*(indices[i] for i in tensor_idxs), indexing="ij")
365+
for i, grid in zip(tensor_idxs, grids, strict=False):
366+
indices[i] = grid
356367
# pyrefly: ignore [bad-argument-type]
357-
return tensor[tuple(index)]
368+
return tensor[tuple(indices)]
358369

359370
# Create zero result matching mask shape
360371
result = torch.zeros(extra_mask.shape, dtype=tensor.dtype, device=tensor.device)

test/test_ref_eager.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,42 @@ def kernel(x: torch.Tensor) -> torch.Tensor:
174174
result.to(torch.float32), x.to(torch.float32), atol=1e-2, rtol=1e-2
175175
)
176176

177+
def test_load_2d_indexing_without_extra_mask(self):
178+
"""Test that hl.load with two 1D tensor indices produces 2D output in ref eager mode."""
179+
180+
@helion.kernel(ref_mode=helion.RefMode.EAGER)
181+
def kernel(mask: torch.Tensor) -> torch.Tensor:
182+
n = mask.size(0)
183+
out = torch.zeros_like(mask)
184+
for tile_i, tile_j in hl.tile([n, n]):
185+
# Load with two 1D tensor indices - should produce [tile_I, tile_J] output
186+
vals = hl.load(mask, [tile_i.index, tile_j.index])
187+
out[tile_i, tile_j] = vals
188+
return out
189+
190+
with assert_ref_eager_mode():
191+
mask = torch.tril(torch.ones(4, 4, device=DEVICE, dtype=torch.float32))
192+
result = kernel(mask)
193+
torch.testing.assert_close(result, mask)
194+
195+
def test_load_3d_indexing_without_extra_mask(self):
196+
"""Test that hl.load with three 1D tensor indices produces 3D output in ref eager mode."""
197+
198+
@helion.kernel(ref_mode=helion.RefMode.EAGER)
199+
def kernel(x: torch.Tensor) -> torch.Tensor:
200+
d0, d1, d2 = x.shape
201+
out = torch.zeros_like(x)
202+
for tile_i, tile_j, tile_k in hl.tile([d0, d1, d2]):
203+
# Load with three 1D tensor indices - should produce [tile_I, tile_J, tile_K] output
204+
vals = hl.load(x, [tile_i.index, tile_j.index, tile_k.index])
205+
out[tile_i, tile_j, tile_k] = vals
206+
return out
207+
208+
with assert_ref_eager_mode():
209+
x = torch.arange(24, device=DEVICE, dtype=torch.float32).reshape(2, 3, 4)
210+
result = kernel(x)
211+
torch.testing.assert_close(result, x)
212+
177213

178214
if __name__ == "__main__":
179215
unittest.main()

0 commit comments

Comments
 (0)