Skip to content

Commit a2f5ed1

Browse files
authored
[Interpret Mode] Support custom block size (#1194)
1 parent 2e53e64 commit a2f5ed1

File tree

4 files changed

+110
-55
lines changed

4 files changed

+110
-55
lines changed

helion/language/loops.py

Lines changed: 32 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from .._compiler.ast_extension import LoopType
2424
from .._compiler.ast_extension import expr_from_string
2525
from .._compiler.compile_environment import CompileEnvironment
26-
from .._compiler.compile_environment import warning
2726
from .._compiler.type_propagation import GridIndexType
2827
from .._compiler.type_propagation import IterType
2928
from .._compiler.type_propagation import LiteralType
@@ -519,47 +518,41 @@ def _(
519518
end_or_none: int | torch.Tensor | list[int | torch.Tensor] | None = None,
520519
block_size: int | torch.Tensor | list[int | torch.Tensor] | None = None,
521520
) -> Iterator[RefTile | tuple[RefTile, ...]]:
522-
# Issue warning if block_size is specified in interpret mode
523-
if block_size is not None:
524-
warning(exc.BlockSizeIgnoredInInterpretMode(block_size))
525-
526-
# Step 1: Normalize begin and end values
527521
begin, end = _normalize_begin_end_ref(begin_or_end, end_or_none)
528-
529-
# Step 2: Convert to lists and then to ints
522+
scalar_input = not isinstance(begin, list) and not isinstance(end, list)
530523
begin_list = _normalize_to_list(begin)
531524
end_list = _normalize_to_list(end)
532-
begin_ints = [_to_int(b) for b in begin_list]
533-
end_ints = [_to_int(e) for e in end_list]
534-
535-
# Step 3: Determine block sizes - always return full dimension size, ignoring block_size parameter
536-
block_size_list = []
537-
for b, e in zip(begin_ints, end_ints, strict=True):
538-
assert b is not None and e is not None
539-
block_size_list.append(e - b)
540-
541-
# Step 4: Determine return type
542-
# Return single tiles if input was not a list
543-
return_single = not isinstance(begin, list) and not isinstance(end, list)
544-
545-
# Step 5: Generate tiles
546-
# Build tiles for each dimension
547-
tiles = []
548-
for b, e in zip(begin_ints, end_ints, strict=True):
549-
assert b is not None and e is not None
550-
if b != e:
551-
# Only create tile if range is non-empty
552-
tiles.append(RefTile(b, e, e - b))
553-
554-
# Yield result based on return type
555-
if tiles: # Only yield if we have at least one non-empty dimension
556-
if return_single:
557-
# Single dimension case - yield the tile directly
558-
assert len(tiles) == 1
559-
yield tiles[0]
560-
else:
561-
# Multi-dimensional case - yield as tuple
562-
yield tuple(tiles)
525+
526+
# Normalize block_size to list matching dimensions
527+
bs_list: list[int | torch.Tensor | None]
528+
if block_size is None:
529+
bs_list = [None] * len(begin_list)
530+
else:
531+
bs_list = cast(
532+
"list[int | torch.Tensor | None]", _normalize_to_list(block_size)
533+
)
534+
if len(bs_list) == 1 and len(begin_list) > 1:
535+
bs_list = bs_list * len(begin_list)
536+
537+
# Build tile ranges for each dimension
538+
dim_ranges: list[list[tuple[int, int, int]]] = []
539+
for b, e, bs in zip(begin_list, end_list, bs_list, strict=True):
540+
b_int, e_int = _to_int(b), _to_int(e)
541+
assert b_int is not None and e_int is not None
542+
if b_int == e_int:
543+
continue
544+
bs_int = _to_int(bs) if bs is not None else (e_int - b_int)
545+
assert bs_int is not None
546+
dim_ranges.append(
547+
[(s, min(s + bs_int, e_int), bs_int) for s in range(b_int, e_int, bs_int)]
548+
)
549+
550+
if not dim_ranges:
551+
return
552+
553+
for combo in itertools.product(*dim_ranges):
554+
tiles = list(starmap(RefTile, combo))
555+
yield tiles[0] if scalar_input else tuple(tiles)
563556

564557

565558
def _codegen_loop_helper(

helion/language/tile_ops.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -259,5 +259,4 @@ def _(state: CodegenState) -> ast.AST:
259259

260260
@_decorators.ref(tile_id)
def _(tile: RefTile) -> int:
    """Ref-mode implementation of ``tile_id``.

    Returns the start of the tile's slice floor-divided by the tile's block
    size.
    """
    start = tile._slice.start
    return start // tile._block_size

helion/runtime/ref_mode.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44
import threading
55
import typing
66
from typing import TYPE_CHECKING
7+
from typing import Any
78
from typing import Callable
89
from typing import Protocol
910
from typing import cast
1011

1112
import torch
13+
from torch._prims_common import is_integer_dtype
1214
from torch.overrides import BaseTorchFunctionMode
1315

1416
from .._compiler.compile_environment import CompileEnvironment
@@ -180,6 +182,10 @@ def __torch_function__(
180182
return self._method_handlers[func_name](args, kwargs)
181183
if func_name in self._binary_op_names:
182184
return self._handle_binary_op(func, args, kwargs)
185+
if func_name == "__getitem__":
186+
return self._handle_getitem(args, kwargs)
187+
if func_name == "__setitem__":
188+
return self._handle_setitem(args, kwargs)
183189

184190
if func in self._binary_ops:
185191
return self._handle_binary_op(func, args, kwargs)
@@ -334,6 +340,59 @@ def _should_handle_binary_op(self, lhs: object, rhs: object) -> bool:
334340
# Only handle shape-based masking for non-broadcasting cases
335341
return True
336342

343+
@staticmethod
344+
def _is_int_tensor(x: object) -> bool:
345+
return type(x) is torch.Tensor and is_integer_dtype(x.dtype)
346+
347+
def _handle_getitem(
    self,
    args: tuple[object, ...],
    kwargs: dict[str, object],
) -> torch.Tensor:
    """Index a tensor after clamping integer-tensor indices into range.

    Args:
        args: ``(tensor, indices)`` as passed to ``Tensor.__getitem__``.
        kwargs: Accepted for handler-signature uniformity; not read here.

    Returns:
        ``tensor[indices]`` where every index entry accepted by
        ``self._is_int_tensor`` has been clamped to
        ``[0, tensor.size(position) - 1]``; ``position`` is the entry's
        position in the index tuple (0 for a non-tuple index). All other
        entries are passed through untouched.
    """
    source = cast("torch.Tensor", args[0])
    raw_index: Any = args[1]

    # A bare (non-tuple) index is treated as a one-entry index and unwrapped
    # again before the actual lookup.
    bare = not isinstance(raw_index, tuple)
    entries = (raw_index,) if bare else raw_index

    clamped = tuple(
        torch.clamp(entry, min=0, max=source.size(position) - 1)
        if self._is_int_tensor(entry)
        else entry
        for position, entry in enumerate(entries)
    )
    return source[clamped[0] if bare else clamped]
363+
364+
def _handle_setitem(
    self,
    args: tuple[object, ...],
    kwargs: dict[str, object],
) -> None:
    """Assign into a tensor, discarding writes whose index is out of range.

    Every index entry accepted by ``self._is_int_tensor`` is clamped to
    ``[0, tensor.size(position) - 1]`` so the indexing itself cannot fail.
    Entries that were out of range before clamping are then neutralised by
    writing back the value already stored at the clamped location.

    Args:
        args: ``(tensor, indices, value)`` as passed to
            ``Tensor.__setitem__``.
        kwargs: Accepted for handler-signature uniformity; not read here.

    Masking is applied when ``value`` is a plain ``torch.Tensor`` or a Python
    ``bool``/``int``/``float``. Any other value type is written to the
    clamped indices unchanged.
    """
    tensor = cast("torch.Tensor", args[0])
    indices: Any = args[1]
    value: Any = args[2]
    is_tuple = isinstance(indices, tuple)
    indices_list = list(indices) if is_tuple else [indices]

    # Accumulate, across all integer-tensor index entries, which positions
    # were in range *before* clamping.
    valid_mask: torch.Tensor | None = None
    for dim, idx in enumerate(indices_list):
        if not self._is_int_tensor(idx):
            continue
        max_idx = tensor.size(dim) - 1
        dim_valid = (idx >= 0) & (idx <= max_idx)
        valid_mask = dim_valid if valid_mask is None else (valid_mask & dim_valid)
        indices_list[dim] = torch.clamp(idx, min=0, max=max_idx)

    final_indices = tuple(indices_list) if is_tuple else indices_list[0]
    if valid_mask is not None:
        is_plain_tensor = type(value) is torch.Tensor
        if is_plain_tensor or isinstance(value, (bool, int, float)):
            current = tensor[final_indices]
            # Align the mask with the *indexed region* (not with ``value``)
            # so that a broadcasting or 0-dim ``value`` is masked per indexed
            # element instead of being mis-broadcast against the mask.
            # NOTE(review): trailing unsqueeze assumes the index dimensions
            # lead the result (tensor indices first / adjacent) - confirm for
            # patterns such as ``t[:, idx]``.
            mask: torch.Tensor = valid_mask
            while mask.dim() < current.dim():
                mask = mask.unsqueeze(-1)
            if not is_plain_tensor:
                # Expand a Python scalar so out-of-range positions can keep
                # their current contents instead of receiving the scalar.
                value = torch.full_like(current, value)
            # NOTE(review): if a clamped out-of-range entry lands on the same
            # element as an in-range entry, the final write contains duplicate
            # indices; which write wins is up to index_put - confirm whether
            # callers can produce that pattern.
            value = torch.where(mask, value, current)

    tensor[final_indices] = value
395+
337396
def _setup_binary_ops_handling(self) -> None:
338397
"""Initialize binary operation tracking sets and mappings."""
339398
# Define binary operations and their variants

test/test_ref_eager.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import torch
99

1010
import helion
11-
from helion import exc
1211
from helion._testing import DEVICE
1312
from helion._testing import TestCase
1413
from helion._testing import assert_ref_eager_mode
@@ -95,7 +94,7 @@ def kernel(x: torch.Tensor) -> torch.Tensor:
9594
expected = x * 2.0
9695
torch.testing.assert_close(result, expected)
9796

98-
def test_block_size_warning(self):
97+
def test_block_size_support(self):
9998
@helion.kernel(ref_mode=helion.RefMode.EAGER)
10099
def kernel(x: torch.Tensor) -> torch.Tensor:
101100
m, n = x.shape
@@ -105,20 +104,25 @@ def kernel(x: torch.Tensor) -> torch.Tensor:
105104
return out
106105

107106
with assert_ref_eager_mode():
108-
# Run the kernel to capture the warning message
109-
captured_stderr = io.StringIO()
110-
with contextlib.redirect_stderr(captured_stderr):
111-
x = torch.randn(128, 128, device=DEVICE)
112-
kernel(x)
113-
114-
stderr_output = captured_stderr.getvalue()
107+
x = torch.randn(128, 128, device=DEVICE)
108+
result = kernel(x)
109+
expected = x * 2.0
110+
torch.testing.assert_close(result, expected)
115111

116-
# Create expected warning message using the actual class
117-
expected_warning = exc.BlockSizeIgnoredInInterpretMode(2)
118-
expected_warning_text = expected_warning.report()
112+
def test_tile_begin_with_block_size_1(self):
    """Adding ``tile.begin`` to zeros with ``block_size=1`` yields ``arange(8)``."""

    @helion.kernel(ref_mode=helion.RefMode.EAGER)
    def kernel(x: torch.Tensor) -> torch.Tensor:
        n = x.size(0)
        out = torch.empty_like(x)
        for tile in hl.tile(n, block_size=1):
            out[tile] = x[tile] + tile.begin
        return out

    with assert_ref_eager_mode():
        zeros = torch.zeros(8, device=DEVICE)
        torch.testing.assert_close(
            kernel(zeros),
            torch.arange(8, device=DEVICE, dtype=torch.float32),
        )
122126

123127

124128
if __name__ == "__main__":

0 commit comments

Comments
 (0)