
Commit 28cc903

Allow using hl.specialize to specialize on tensor strides (#1215)
1 parent 7c5406b commit 28cc903

File tree: 8 files changed, +271 -9 lines changed
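For orientation, a usage sketch of the new capability (modeled on the tests added in this commit; the kernel name `fn` and the input `x` are illustrative): `hl.specialize` can now be applied to a tensor stride, which bakes the stride value into the generated Triton code as a compile-time constant instead of passing it as a kernel argument.

import torch
import helion
import helion.language as hl

@helion.kernel(static_shapes=False)
def fn(x: torch.Tensor) -> torch.Tensor:
    # Specializing on the stride makes it a constant in the generated code,
    # so x_stride_0 is no longer passed to the Triton kernel.
    stride = hl.specialize(x.stride(0))
    out = torch.empty_like(x)
    for tile in hl.tile(x.size()):
        out[tile] = x[tile] + stride
    return out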

helion/_compiler/compile_environment.py

Lines changed: 1 addition & 0 deletions
@@ -121,6 +121,7 @@ def __init__(
             collections.Counter()
         )
         self.specialized_vars: set[sympy.Symbol] = set()
+        self.specialized_strides: set[tuple[str, int]] = set()
         self.loop_dependency_checker = LoopDependencyChecker()
         self._symint_cache: dict[object, torch.SymInt] = {}
         self.device_load_count = (

helion/_compiler/device_function.py

Lines changed: 16 additions & 3 deletions
@@ -14,6 +14,7 @@
 
 import sympy
 import torch
+from torch._dynamo.source import LocalSource
 from torch._inductor.codegen.triton import TritonPrinter
 from torch.fx.graph import _Namespace
 
@@ -602,11 +603,23 @@ def tensor_size(self, fake_value: torch.Tensor, dim: int) -> Argument:
         return self._tensor_property(TensorSizeArg, fake_value, dim, "size")
 
     def tensor_stride(self, fake_value: torch.Tensor, dim: int) -> Argument:
+        v = fake_value.stride(dim)
+        env = CompileEnvironment.current()
+        # Check if this stride was explicitly specialized
+        source = env.input_sources.get(fake_value)
         if (
-            isinstance(v := fake_value.stride(dim), int)
-            and CompileEnvironment.current().settings.static_shapes
+            isinstance(source, LocalSource)
+            and (source.local_name, dim) in env.specialized_strides
         ):
-            return StaticShape(v)
+            return StaticShape(int(v))
+        if isinstance(v, int):
+            if env.settings.static_shapes:
+                return StaticShape(v)
+        else:
+            # Check if all free symbols are specialized
+            syms = v._sympy_().free_symbols
+            if syms and syms <= env.specialized_vars:
+                return StaticShape(int(v))
         return self._tensor_property(TensorStrideArg, fake_value, dim, "stride")
 
     def sorted_args(self) -> list[Argument]:
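To illustrate the fallback branch above (a minimal sketch, not Helion's actual code path; the symbol names are made up): a symbolic stride can be inlined as a static value only when every sympy symbol it depends on has already been specialized.

import sympy

s0, s1 = sympy.symbols("s0 s1", positive=True, integer=True)
stride_expr = s0 * s1          # e.g. a stride expressed in terms of symbolic sizes
specialized_vars = {s0, s1}    # symbols pinned by earlier hl.specialize calls

free = stride_expr.free_symbols
# Mirrors the check `syms and syms <= env.specialized_vars` in tensor_stride.
can_inline = bool(free) and free <= specialized_vars
assert can_inline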

helion/_compiler/type_propagation.py

Lines changed: 1 addition & 1 deletion
@@ -675,7 +675,7 @@ def propagate_call(
         attr = self.attr()
         if attr in {"dim", "ndimension"} and not (args or kwargs):
             return TypeInfo.from_example(self.tensor.fake_value.ndim, origin)
-        if attr in {"shape", "size"} and not kwargs:
+        if attr in {"shape", "size", "stride"} and not kwargs:
             fn = getattr(self.tensor.fake_value, attr)
             try:
                 return TypeInfo.from_example(

helion/exc.py

Lines changed: 1 addition & 1 deletion
@@ -192,7 +192,7 @@ class SpecializeOnDevice(BaseError):
 
 
 class SpecializeArgType(BaseError):
-    message = "hl.specialize() must be called on a size from an input tensor, got: {}"
+    message = "hl.specialize() must be called on a size or stride from an input tensor, got: {}"
 
 
 class StackTensorcOnHost(BaseError):

helion/language/constexpr.py

Lines changed: 15 additions & 1 deletion
@@ -6,6 +6,9 @@
 from typing_extensions import TypeVar
 
 import torch
+from torch._dynamo.source import LocalSource
+from torch._dynamo.source import TensorProperty
+from torch._dynamo.source import TensorPropertySource
 
 from .. import exc
 from .._compiler.ast_extension import expr_from_string
@@ -87,7 +90,18 @@ def _(value: TypeInfo, *, origin: Origin) -> TypeInfo:
     env = CompileEnvironment.current()
 
     def handle_symint(symint: torch.SymInt) -> int:
-        env.specialized_vars.update(symint._sympy_().free_symbols)
+        syms = symint._sympy_().free_symbols
+        env.specialized_vars.update(syms)
+        # Track stride specializations
+        for sym in syms:
+            for source in env.shape_env.var_to_sources.get(sym, []):
+                if (
+                    isinstance(source, TensorPropertySource)
+                    and source.prop == TensorProperty.STRIDE
+                    and isinstance(source.base, LocalSource)
+                    and source.idx is not None
+                ):
+                    env.specialized_strides.add((source.base.local_name, source.idx))
         return symint.__int__()
 
     specialized = _convert_specializable(proxy, on_symint=handle_symint)

helion/runtime/kernel.py

Lines changed: 5 additions & 3 deletions
@@ -624,12 +624,14 @@ def _specialize_extra(self) -> list[Callable[[Sequence[object]], Hashable]]:
 
         def make_extractor(v: Source) -> Callable[[Sequence[object]], Hashable]:
             if isinstance(v, TensorPropertySource):
-                assert v.prop == TensorProperty.SIZE
                 index = v.idx
                 assert index is not None
                 inner = make_extractor(v.base)
-
-                return lambda args: cast("torch.Tensor", inner(args)).size(index)
+                if v.prop == TensorProperty.SIZE:
+                    return lambda args: cast("torch.Tensor", inner(args)).size(index)
+                if v.prop == TensorProperty.STRIDE:
+                    return lambda args: cast("torch.Tensor", inner(args)).stride(index)
+                raise exc.SpecializeArgType(v)
             if isinstance(v, LocalSource):
                 index = arg_name_to_index[v.local_name]
                 return operator.itemgetter(index)
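A standalone sketch of the idea behind this change (not Helion's actual code; `make_stride_extractor` is a hypothetical helper): each specialized stride becomes a small function that reads the value back out of the call arguments so it can be part of the specialization cache key, which is why tensors with different strides bind to different kernel variants.

from typing import Callable, Hashable, Sequence

import torch

def make_stride_extractor(arg_index: int, dim: int) -> Callable[[Sequence[object]], Hashable]:
    # Roughly what a TensorPropertySource(prop=STRIDE) over a LocalSource resolves to.
    def extract(args: Sequence[object]) -> Hashable:
        tensor = args[arg_index]
        assert isinstance(tensor, torch.Tensor)
        return tensor.stride(dim)
    return extract

# Two inputs with the same shape but different stride(0) yield different cache keys.
a = torch.empty(4, 4)            # stride(0) == 4
b = torch.empty(8, 8)[:4, :4]    # stride(0) == 8
key = make_stride_extractor(arg_index=0, dim=0)
assert key((a,)) != key((b,))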

test/test_specialize.expected

Lines changed: 113 additions & 0 deletions
@@ -335,6 +335,119 @@ def fn(x: torch.Tensor, *, _launcher=_default_launcher):
     # src[test_specialize.py:N]: return out
     return out
 
+--- assertExpectedJournal(TestSpecialize.test_specialize_size_becomes_static)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_fn(x, out, out_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
+    # src[test_specialize.py:N]: for tile in hl.tile(n):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < 137
+    # src[test_specialize.py:N]: out[tile] = x[tile] + 1
+    load = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
+    v_0 = 1.0
+    v_1 = load + v_0
+    tl.store(out + indices_0 * out_stride_0, v_1, mask_0)
+
+def fn(x: torch.Tensor, *, _launcher=_default_launcher):
+    # src[test_specialize.py:N]: out = torch.empty_like(x)
+    out = torch.empty_like(x)
+    # src[test_specialize.py:N]: for tile in hl.tile(n):
+    _BLOCK_SIZE_0 = 32
+    # src[test_specialize.py:N]: for tile in hl.tile(n):
+    # src[test_specialize.py:N]: out[tile] = x[tile] + 1
+    _launcher(_helion_fn, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, out, out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=1)
+    # src[test_specialize.py:N]: return out
+    return out
+
+--- assertExpectedJournal(TestSpecialize.test_specialize_stride_basic)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_fn(x, out, x_size_0, x_size_1, out_stride_0, out_stride_1, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    # src[test_specialize.py:N]: for tile in hl.tile(x.size()):
+    num_blocks_0 = tl.cdiv(x_size_0, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < x_size_0
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    mask_1 = indices_1 < x_size_1
+    # src[test_specialize.py:N]: out[tile] = x[tile] + stride
+    load = tl.load(x + (indices_0[:, None] * 137 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+    v_0 = 137.0
+    v_1 = load + v_0
+    tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_1, mask_0[:, None] & mask_1[None, :])
+
+def fn(x: torch.Tensor, *, _launcher=_default_launcher):
+    # src[test_specialize.py:N]: out = torch.empty_like(x)
+    out = torch.empty_like(x)
+    # src[test_specialize.py:N]: for tile in hl.tile(x.size()):
+    _BLOCK_SIZE_0 = 32
+    _BLOCK_SIZE_1 = 32
+    # src[test_specialize.py:N]: for tile in hl.tile(x.size()):
+    # src[test_specialize.py:N]: # Use stride in computation to verify it's a constant
+    # src[test_specialize.py:N]: out[tile] = x[tile] + stride
+    _launcher(_helion_fn, (triton.cdiv(x.size(0), _BLOCK_SIZE_0) * triton.cdiv(x.size(1), _BLOCK_SIZE_1),), x, out, x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1)
+    # src[test_specialize.py:N]: return out
+    return out
+
+--- assertExpectedJournal(TestSpecialize.test_specialize_stride_tuple)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_fn(x, out, x_size_0, x_size_1, out_stride_0, out_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    # src[test_specialize.py:N]: for tile in hl.tile(x.size()):
+    num_blocks_0 = tl.cdiv(x_size_0, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < x_size_0
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    mask_1 = indices_1 < x_size_1
+    # src[test_specialize.py:N]: out[tile] = x[tile] + stride0 + stride1
+    load = tl.load(x + (indices_0[:, None] * 311 + indices_1[None, :] * 131), mask_0[:, None] & mask_1[None, :], other=0)
+    v_0 = 311.0
+    v_1 = load + v_0
+    v_2 = 131.0
+    v_3 = v_1 + v_2
+    tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_3, mask_0[:, None] & mask_1[None, :])
+
+def fn(x: torch.Tensor, *, _launcher=_default_launcher):
+    # src[test_specialize.py:N]: stride0, stride1 = hl.specialize((x.stride(0), x.stride(1)))
+    stride0, stride1 = (311, 131)
+    # src[test_specialize.py:N]: out = torch.empty_like(x)
+    out = torch.empty_like(x)
+    # src[test_specialize.py:N]: for tile in hl.tile(x.size()):
+    _BLOCK_SIZE_0 = 32
+    _BLOCK_SIZE_1 = 32
+    # src[test_specialize.py:N]: for tile in hl.tile(x.size()):
+    # src[test_specialize.py:N]: out[tile] = x[tile] + stride0 + stride1
+    _launcher(_helion_fn, (triton.cdiv(x.size(0), _BLOCK_SIZE_0) * triton.cdiv(x.size(1), _BLOCK_SIZE_1),), x, out, x.size(0), x.size(1), out.stride(0), out.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1)
+    # src[test_specialize.py:N]: return out
+    return out
+
 --- assertExpectedJournal(TestSpecialize.test_specialize_tuple_element)
 from __future__ import annotations
 
test/test_specialize.py

Lines changed: 119 additions & 0 deletions
@@ -326,6 +326,125 @@ def foo(x: torch.Tensor, bitshift: tuple[int, int]) -> torch.Tensor:
         self.assertIn("65536", code)
         self.assertExpectedJournal(code)
 
+    def test_specialize_size_becomes_static(self):
+        """Test that hl.specialize on a size makes it NOT passed to the triton kernel."""
+
+        @helion.kernel(static_shapes=False)
+        def fn(x: torch.Tensor) -> torch.Tensor:
+            n = hl.specialize(x.size(0))
+            out = torch.empty_like(x)
+            for tile in hl.tile(n):
+                out[tile] = x[tile] + 1
+            return out
+
+        x = torch.randn([137], device=DEVICE)  # Use prime to avoid alignment
+        code, result = code_and_output(fn, (x,))
+        torch.testing.assert_close(result, x + 1)
+        # Verify x_size_0 is NOT passed as an argument (it should be static)
+        self.assertNotIn("x_size_0", code)
+        self.assertExpectedJournal(code)
+
+    def test_specialize_stride_basic(self):
+        """Test that hl.specialize works with tensor strides."""
+
+        @helion.kernel(static_shapes=False, autotune_effort="none")
+        def fn(x: torch.Tensor) -> torch.Tensor:
+            stride = hl.specialize(x.stride(0))
+            out = torch.empty_like(x)
+            for tile in hl.tile(x.size()):
+                # Use stride in computation to verify it's a constant
+                out[tile] = x[tile] + stride
+            return out
+
+        # Use empty_strided to create tensor with a unique stride value (137)
+        # that won't be confused with shape values
+        size = (64, 64)
+        stride0 = 137  # Distinctive prime number for stride(0)
+        stride1 = 1
+        # Need storage size to fit: (size[0]-1)*stride0 + (size[1]-1)*stride1 + 1
+        storage_size = (size[0] - 1) * stride0 + (size[1] - 1) * stride1 + 1
+        storage = torch.randn(storage_size, device=DEVICE)
+        x = torch.as_strided(storage, size, (stride0, stride1))
+
+        code, result = code_and_output(fn, (x,))
+        torch.testing.assert_close(result, x + x.stride(0))
+        # Verify the unique stride value 137 is inlined as a constant
+        self.assertIn("137", code)
+        # Verify x_stride_0 is NOT passed as an argument (it should be inlined)
+        self.assertNotIn("x_stride_0", code)
+        self.assertExpectedJournal(code)
+
+    def test_specialize_stride_creates_different_variants(self):
+        """Test that different stride patterns create different kernel variants."""
+
+        @helion.kernel(static_shapes=False, autotune_effort="none")
+        def fn(x: torch.Tensor) -> torch.Tensor:
+            stride = hl.specialize(x.stride(0))
+            out = torch.empty_like(x)
+            for tile in hl.tile(x.size()):
+                out[tile] = x[tile] + stride
+            return out
+
+        # Create two tensors with different unique stride values using empty_strided
+        size = (64, 64)
+
+        # First tensor with stride(0) = 173 (distinctive prime)
+        stride0_a = 173
+        storage_size_a = (size[0] - 1) * stride0_a + (size[1] - 1) * 1 + 1
+        storage_a = torch.randn(storage_size_a, device=DEVICE)
+        x_a = torch.as_strided(storage_a, size, (stride0_a, 1))
+
+        # Second tensor with stride(0) = 257 (different distinctive prime)
+        stride0_b = 257
+        storage_size_b = (size[0] - 1) * stride0_b + (size[1] - 1) * 1 + 1
+        storage_b = torch.randn(storage_size_b, device=DEVICE)
+        x_b = torch.as_strided(storage_b, size, (stride0_b, 1))
+
+        # These should create different bound kernels due to different strides
+        bound1 = fn.bind((x_a,))
+        bound2 = fn.bind((x_b,))
+
+        # Verify different variants are used
+        self.assertTrueIfInNormalMode(bound1 is not bound2)
+
+        # Verify correctness
+        result1 = fn(x_a)
+        result2 = fn(x_b)
+        torch.testing.assert_close(result1, x_a + stride0_a)
+        torch.testing.assert_close(result2, x_b + stride0_b)
+
+    def test_specialize_stride_tuple(self):
+        """Test that hl.specialize works with tuple of strides."""
+
+        @helion.kernel(static_shapes=False, autotune_effort="none")
+        def fn(x: torch.Tensor) -> torch.Tensor:
+            stride0, stride1 = hl.specialize((x.stride(0), x.stride(1)))
+            out = torch.empty_like(x)
+            for tile in hl.tile(x.size()):
+                out[tile] = x[tile] + stride0 + stride1
+            return out
+
+        # Create tensor with unique stride values using empty_strided
+        # stride0 = 311, stride1 = 131 (distinctive primes unlikely to appear elsewhere)
+        size = (64, 64)
+        stride0 = 311
+        stride1 = 131
+        # Storage must fit the largest offset: (size[0]-1)*stride0 + (size[1]-1)*stride1 + 1
+        storage_size = (size[0] - 1) * stride0 + (size[1] - 1) * stride1 + 1
+        storage = torch.randn(storage_size, device=DEVICE)
+        x = torch.as_strided(storage, size, (stride0, stride1))
+
+        code, result = code_and_output(fn, (x,))
+        expected = x + stride0 + stride1
+        torch.testing.assert_close(result, expected)
+        # Verify both unique stride values appear in the generated code
+        self.assertIn("311", code)
+        self.assertIn("131", code)
+        # Verify both x_stride_0 and x_stride_1 are NOT passed as arguments (they should be inlined)
+        self.assertNotIn("x_stride_0", code)
+        self.assertNotIn("x_stride_1", code)
+        self.assertExpectedJournal(code)
+
 
 if __name__ == "__main__":
     unittest.main()
