From 5e26f71571e32c9b755b98bf5c9bf470288b242e Mon Sep 17 00:00:00 2001 From: Will Feng Date: Mon, 8 Dec 2025 22:40:49 -0800 Subject: [PATCH 1/5] test --- test/test_specialize.expected | 96 +++++++++++++++++++++++++++++++++++ test/test_specialize.py | 87 +++++++++++++++++++++++++++++++ 2 files changed, 183 insertions(+) diff --git a/test/test_specialize.expected b/test/test_specialize.expected index 03a52f23c..7a113c6a4 100644 --- a/test/test_specialize.expected +++ b/test/test_specialize.expected @@ -1,6 +1,102 @@ This file is automatically generated by assertExpectedJournal calls in test_specialize.py. Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set. +--- assertExpectedJournal(TestMarkStatic.test_mark_static) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_matmul(x, y, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + # src[test_specialize.py:N]: for tile_m, tile_n in hl.tile([m, n]): + num_blocks_0 = tl.cdiv(64, _BLOCK_SIZE_0) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < 64 + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < 56 + # src[test_specialize.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + # src[test_specialize.py:N]: for tile_k in hl.tile(k): + # src[test_specialize.py:N]: acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n]) + symnode_0 = 128 + for offset_2 in tl.range(0, symnode_0.to(tl.int32), _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + mask_2 = indices_2 < symnode_0 + acc_copy = acc + acc_copy_0 = acc_copy + # src[test_specialize.py:N]: acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n]) + load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0) + load_1 = tl.load(y + (indices_2[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_2[:, None] & mask_1[None, :], other=0) + acc = tl.dot(tl.cast(load, tl.float16), tl.cast(load_1, tl.float16), acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32) + # src[test_specialize.py:N]: out[tile_m, tile_n] = acc.to(x.dtype) + v_0 = tl.cast(acc, tl.float16) + tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_0, mask_0[:, None] & mask_1[None, :]) + +def matmul(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher): + # src[test_specialize.py:N]: m, k = x.size() + m, k = x.size() + # src[test_specialize.py:N]: k2, n = y.size() + k2, n = y.size() + # src[test_specialize.py:N]: out = torch.empty([m, n], device=x.device, dtype=x.dtype) + out = torch.empty([m, n], device=x.device, dtype=x.dtype) + # src[test_specialize.py:N]: for tile_m, tile_n in hl.tile([m, n]): + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 + # src[test_specialize.py:N]: for tile_k in hl.tile(k): + # src[test_specialize.py:N]: acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n]) + _BLOCK_SIZE_2 = 32 + # src[test_specialize.py:N]: for tile_m, tile_n in hl.tile([m, n]): + # src[test_specialize.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + # src[test_specialize.py:N]: for tile_k in hl.tile(k): + # src[test_specialize.py:N-N]: ... + _launcher(_helion_matmul, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(56, _BLOCK_SIZE_1),), x, y, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) + # src[test_specialize.py:N]: return out + return out + +--- assertExpectedJournal(TestMarkStatic.test_mark_static_and_hl_specialize) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_fn(x, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr): + # src[test_specialize.py:N]: for tile in hl.tile(x.size()): + num_blocks_0 = tl.cdiv(320, _BLOCK_SIZE_0) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < 640 + # src[test_specialize.py:N]: out[tile] = x[tile] * 2 + load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_1[None, :], other=0) + v_0 = 2.0 + v_1 = load * v_0 + tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_1, mask_1[None, :]) + +def fn(x: torch.Tensor, *, _launcher=_default_launcher): + # src[test_specialize.py:N]: out = torch.empty_like(x) + out = torch.empty_like(x) + # src[test_specialize.py:N]: for tile in hl.tile(x.size()): + _BLOCK_SIZE_0 = 16 + _BLOCK_SIZE_1 = 16 + # src[test_specialize.py:N]: for tile in hl.tile(x.size()): + # src[test_specialize.py:N]: out[tile] = x[tile] * 2 + _launcher(_helion_fn, (triton.cdiv(320, _BLOCK_SIZE_0) * triton.cdiv(640, _BLOCK_SIZE_1),), x, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) + # src[test_specialize.py:N]: return out + return out + --- assertExpectedJournal(TestSpecialize.test_dynamic_size_block_non_power_of_two) from __future__ import annotations diff --git a/test/test_specialize.py b/test/test_specialize.py index b4224f4b2..197890517 100644 --- a/test/test_specialize.py +++ b/test/test_specialize.py @@ -446,5 +446,92 @@ def fn(x: torch.Tensor) -> torch.Tensor: self.assertExpectedJournal(code) +@skipIfCpu("needs to be debugged") +class TestMarkStatic(RefEagerTestBase, TestCase): + """Tests for torch._dynamo.mark_static() external specialization API.""" + + maxDiff = 163842 + + def test_mark_static(self): + """Test mark_static: multiple tensors, multiple dims, negative indexing.""" + + @helion.kernel(autotune_effort="none", static_shapes=False) + def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + m, k = x.size() + k2, n = y.size() + out = torch.empty([m, n], device=x.device, dtype=x.dtype) + for tile_m, tile_n in hl.tile([m, n]): + acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + for tile_k in hl.tile(k): + acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n]) + out[tile_m, tile_n] = acc.to(x.dtype) + return out + + m, k, n = 64, 128, 56 + + # First, run WITHOUT mark_static - dimensions should NOT be constants + x = torch.randn([m, k], device=DEVICE, dtype=torch.float16) + y = torch.randn([k, n], device=DEVICE, dtype=torch.float16) + code_no_spec, result_no_spec = code_and_output( + matmul, (x, y), block_sizes=[32, 32, 32] + ) + torch.testing.assert_close(result_no_spec, x @ y, rtol=1e-2, atol=1e-2) + self.assertNotIn("64", code_no_spec) + self.assertNotIn("128", code_no_spec) + self.assertNotIn("56", code_no_spec) + + # Now, run WITH mark_static - dimensions SHOULD be constants + x_static = torch.randn([m, k], device=DEVICE, dtype=torch.float16) + y_static = torch.randn([k, n], device=DEVICE, dtype=torch.float16) + torch._dynamo.mark_static(x_static, [0, -1]) # test list and negative index + torch._dynamo.mark_static(y_static, 1) + + code, result = code_and_output( + matmul, (x_static, y_static), block_sizes=[32, 32, 32] + ) + torch.testing.assert_close(result, x_static @ y_static, rtol=1e-2, atol=1e-2) + self.assertIn("64", code) + self.assertIn("128", code) + self.assertIn("56", code) + self.assertExpectedJournal(code) + + # Cache hit: same tensors + self.assertIs( + matmul.bind((x_static, y_static)), matmul.bind((x_static, y_static)) + ) + # Cache miss: different specialized values + x2 = torch.randn([48, 96], device=DEVICE, dtype=torch.float16) + y2 = torch.randn([96, 24], device=DEVICE, dtype=torch.float16) + torch._dynamo.mark_static(x2, [0, -1]) + torch._dynamo.mark_static(y2, 1) + self.assertIsNot(matmul.bind((x_static, y_static)), matmul.bind((x2, y2))) + + def test_mark_static_and_hl_specialize(self): + """Test that external mark_static and internal hl.specialize form a union.""" + + @helion.kernel(autotune_effort="none", static_shapes=False) + def fn(x: torch.Tensor) -> torch.Tensor: + hl.specialize(x.size(0)) # internal specialize on dim 0 + out = torch.empty_like(x) + for tile in hl.tile(x.size()): + out[tile] = x[tile] * 2 + return out + + # mark_static on dim 1 should combine with hl.specialize on dim 0 + x = torch.randn([320, 640], device=DEVICE) + torch._dynamo.mark_static(x, -1) + + code, result = code_and_output(fn, (x,), block_sizes=[16, 16]) + torch.testing.assert_close(result, x * 2) + self.assertIn("320", code) # dim 0 from hl.specialize + self.assertIn("640", code) # dim 1 from mark_static + self.assertExpectedJournal(code) + + # Cache miss: changing externally-specialized dim + x2 = torch.randn([320, 128], device=DEVICE) + torch._dynamo.mark_static(x2, -1) + self.assertIsNot(fn.bind((x,)), fn.bind((x2,))) + + if __name__ == "__main__": unittest.main() From a8369d50b0ffcbc2f39703bb31b50bac76108396 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Mon, 8 Dec 2025 22:40:55 -0800 Subject: [PATCH 2/5] up --- helion/_compiler/compile_environment.py | 10 ++++++++++ helion/_compiler/device_function.py | 4 +++- helion/_compiler/host_function.py | 6 +++++- helion/_testing.py | 8 ++++++++ helion/runtime/kernel.py | 21 +++++++++++++++++++++ test/test_examples.expected | 16 ---------------- test/test_tensor_descriptor.expected | 16 ---------------- 7 files changed, 47 insertions(+), 34 deletions(-) diff --git a/helion/_compiler/compile_environment.py b/helion/_compiler/compile_environment.py index 7b1a48783..2ec32c23f 100644 --- a/helion/_compiler/compile_environment.py +++ b/helion/_compiler/compile_environment.py @@ -128,6 +128,16 @@ def __init__( 0 # Track number of loads in all device code for eviction policy tuning ) + def specialize_expr(self, expr: sympy.Expr) -> sympy.Expr: + """Substitute any specialized vars with their concrete values.""" + if subs := { + s: sympy.Integer(self.shape_env.size_hint(s)) + for s in expr.free_symbols & self.specialized_vars + }: + # pyrefly: ignore [bad-assignment] + expr = expr.xreplace(subs) + return expr + def add_kernel_tensor_size(self, sizes: Sequence[int | torch.SymInt]) -> None: from .device_function import contains_only_block_size_symbols diff --git a/helion/_compiler/device_function.py b/helion/_compiler/device_function.py index a8203c0a6..d7f59dcf7 100644 --- a/helion/_compiler/device_function.py +++ b/helion/_compiler/device_function.py @@ -374,7 +374,8 @@ def set_pid(self, pid: ProgramIDs) -> None: self.pid = pid def sympy_expr(self, expr: sympy.Expr) -> str: - expr = CompileEnvironment.current().shape_env.simplify(expr) + env = CompileEnvironment.current() + expr = env.specialize_expr(env.shape_env.simplify(expr)) if not expr.free_symbols: return texpr(expr) if expr in self.expr_to_var_info: @@ -394,6 +395,7 @@ def sympy_expr(self, expr: sympy.Expr) -> str: replacements[sym] = sympy.Symbol( self._lift_sympy_arg(sym), integer=True ) + # pyrefly: ignore [bad-argument-type] return texpr(expr.xreplace(replacements)) def _lift_sympy_arg(self, expr: sympy.Expr) -> str: diff --git a/helion/_compiler/host_function.py b/helion/_compiler/host_function.py index e1cf5ff00..0f5ce286c 100644 --- a/helion/_compiler/host_function.py +++ b/helion/_compiler/host_function.py @@ -191,7 +191,10 @@ def set_local_types(self, local_types: dict[str, TypeInfo]) -> None: type_info.populate_symbol_origins(NameOrigin(name, fn)) def sympy_expr(self, expr: sympy.Expr) -> str: - expr = CompileEnvironment.current().shape_env.simplify(expr) + env = CompileEnvironment.current() + expr = env.specialize_expr(env.shape_env.simplify(expr)) + if not expr.free_symbols: + return pexpr(expr) if expr in self.expr_to_origin: return self.expr_to_origin[expr].origin.host_str() replacements = {} @@ -199,6 +202,7 @@ def sympy_expr(self, expr: sympy.Expr) -> str: assert isinstance(sym, sympy.Symbol) origin = self.expr_to_origin[sym].origin replacements[sym] = sympy.Symbol(origin.host_str(), integer=True) + # pyrefly: ignore [bad-argument-type] return pexpr(expr.xreplace(replacements)) def literal_expr(self, expr: object) -> str: diff --git a/helion/_testing.py b/helion/_testing.py index 9aaa05daa..f99201cad 100644 --- a/helion/_testing.py +++ b/helion/_testing.py @@ -511,6 +511,14 @@ def assertNotIn( if not self._in_ref_eager_mode: super().assertNotIn(member, container, msg) # type: ignore[misc] + def assertIs(self, expr1: object, expr2: object, msg: str | None = None) -> None: + if not self._in_ref_eager_mode: + super().assertIs(expr1, expr2, msg) # type: ignore[misc] + + def assertIsNot(self, expr1: object, expr2: object, msg: str | None = None) -> None: + if not self._in_ref_eager_mode: + super().assertIsNot(expr1, expr2, msg) # type: ignore[misc] + def assertTrueIfInNormalMode(self, condition: bool, msg: str | None = None) -> None: if not self._in_ref_eager_mode: self.assertTrue(condition, msg) # type: ignore[attr-defined] diff --git a/helion/runtime/kernel.py b/helion/runtime/kernel.py index f4527d8e6..b5b361f0b 100644 --- a/helion/runtime/kernel.py +++ b/helion/runtime/kernel.py @@ -403,6 +403,9 @@ def __init__( constexpr_args[name] = arg else: self.fake_args.append(self.env.to_fake(arg, ArgumentOrigin(name))) + + self._apply_mark_static(args) + with ( _maybe_skip_dtype_check_in_meta_registrations(), patch_inductor_lowerings(), @@ -420,6 +423,20 @@ def __init__( self.maybe_log_repro(log.warning, args, config=config) raise + def _apply_mark_static(self, args: tuple[object, ...]) -> None: + """ + Apply torch._dynamo.mark_static() markings from input tensors. + + This reads _dynamo_static_indices from each tensor argument and marks + the corresponding dimensions as specialized (constant) in the kernel. + """ + for arg, fake_arg in zip(args, self.fake_args, strict=True): + if isinstance(arg, torch.Tensor) and isinstance(fake_arg, torch.Tensor): + for dim in getattr(arg, "_dynamo_static_indices", ()): + size = fake_arg.size(dim) + if isinstance(size, torch.SymInt): + self.env.specialized_vars.update(size._sympy_().free_symbols) + @property def settings(self) -> Settings: """ @@ -891,12 +908,14 @@ def kernel( def _tensor_key(fn: Kernel, obj: torch.Tensor) -> Hashable: # NOTE: If a machine has two different gpu types on the same machine, # obj.device.type will incorrectly hit + static_indices = frozenset(getattr(obj, "_dynamo_static_indices", ())) if fn.settings.static_shapes: return ( obj.dtype, obj.device.type, (*obj.size(),), (*obj.stride(),), + static_indices, ) bucketed = tuple([min(s, 2) for s in obj.size()]) if fn.settings.index_dtype is None: @@ -909,11 +928,13 @@ def _tensor_key(fn: Kernel, obj: torch.Tensor) -> Hashable: obj.device.type, bucketed, needs_int64, + static_indices, ) return ( obj.dtype, obj.device.type, bucketed, + static_indices, ) diff --git a/test/test_examples.expected b/test/test_examples.expected index 1931c3d6e..92f0edf38 100644 --- a/test/test_examples.expected +++ b/test/test_examples.expected @@ -460,27 +460,11 @@ def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor, *, _la _RDIM_SIZE_2 = 64 # src[attention.py:N]: m_i = hl.full([tile_b, tile_m], float("-inf"), dtype=torch.float32) _BLOCK_SIZE_0 = 1 - # src[attention.py:N]: q = q_view[tile_b, tile_m, :] - _SHAPE_DIM = q_in.size(3) - _SHAPE_DIM_1 = q_in.size(3) - _SHAPE_DIM_2 = q_in.size(3) # src[attention.py:N]: for tile_n in hl.tile(v_view.size(1)): # src[attention.py:N]: k = k_view[tile_b, :, tile_n] # src[attention.py:N]: qk = torch.bmm(q, k) # src[attention.py:N-N]: ... _BLOCK_SIZE_3 = 32 - # src[attention.py:N]: k = k_view[tile_b, :, tile_n] - _SHAPE_DIM_3 = q_in.size(3) - _SHAPE_DIM_4 = q_in.size(3) - _SHAPE_DIM_5 = q_in.size(3) - # src[attention.py:N]: v = v_view[tile_b, tile_n, :] - _SHAPE_DIM_6 = q_in.size(3) - _SHAPE_DIM_7 = q_in.size(3) - _SHAPE_DIM_8 = q_in.size(3) - # src[attention.py:N]: out[tile_b, tile_m, :] = acc.to(out.dtype) - _SHAPE_DIM_9 = q_in.size(3) - _SHAPE_DIM_10 = q_in.size(3) - _SHAPE_DIM_11 = q_in.size(3) # src[attention.py:N]: for tile_b, tile_m in hl.tile([q_view.size(0), m_dim]): # src[attention.py:N]: m_i = hl.full([tile_b, tile_m], float("-inf"), dtype=torch.float32) # src[attention.py:N]: l_i = torch.full_like(m_i, 1.0) diff --git a/test/test_tensor_descriptor.expected b/test/test_tensor_descriptor.expected index d8ddcc09f..d1145a8fe 100644 --- a/test/test_tensor_descriptor.expected +++ b/test/test_tensor_descriptor.expected @@ -123,27 +123,11 @@ def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor, *, _la _RDIM_SIZE_2 = 64 # src[attention.py:N]: m_i = hl.full([tile_b, tile_m], float("-inf"), dtype=torch.float32) _BLOCK_SIZE_0 = 1 - # src[attention.py:N]: q = q_view[tile_b, tile_m, :] - _SHAPE_DIM = q_in.size(3) - _SHAPE_DIM_1 = q_in.size(3) - _SHAPE_DIM_2 = q_in.size(3) # src[attention.py:N]: for tile_n in hl.tile(v_view.size(1)): # src[attention.py:N]: k = k_view[tile_b, :, tile_n] # src[attention.py:N]: qk = torch.bmm(q, k) # src[attention.py:N-N]: ... _BLOCK_SIZE_3 = 16 - # src[attention.py:N]: k = k_view[tile_b, :, tile_n] - _SHAPE_DIM_3 = q_in.size(3) - _SHAPE_DIM_4 = q_in.size(3) - _SHAPE_DIM_5 = q_in.size(3) - # src[attention.py:N]: v = v_view[tile_b, tile_n, :] - _SHAPE_DIM_6 = q_in.size(3) - _SHAPE_DIM_7 = q_in.size(3) - _SHAPE_DIM_8 = q_in.size(3) - # src[attention.py:N]: out[tile_b, tile_m, :] = acc.to(out.dtype) - _SHAPE_DIM_9 = q_in.size(3) - _SHAPE_DIM_10 = q_in.size(3) - _SHAPE_DIM_11 = q_in.size(3) # src[attention.py:N]: for tile_b, tile_m in hl.tile([q_view.size(0), m_dim]): # src[attention.py:N]: m_i = hl.full([tile_b, tile_m], float("-inf"), dtype=torch.float32) # src[attention.py:N]: l_i = torch.full_like(m_i, 1.0) From d91f8dec646fbff9ca504ebd5996c5d1a4414aa5 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Fri, 12 Dec 2025 10:36:41 -0800 Subject: [PATCH 3/5] update docs --- docs/deployment_autotuning.md | 116 ++++++++++++++++++++++++++++++++++ 1 file changed, 116 insertions(+) diff --git a/docs/deployment_autotuning.md b/docs/deployment_autotuning.md index b79ee046d..c2a07cf13 100644 --- a/docs/deployment_autotuning.md +++ b/docs/deployment_autotuning.md @@ -178,6 +178,122 @@ def my_kernel(x, y): See {doc}`api/kernel` for the full decorator reference. +## Selective Shape Specialization + +The `static_shapes` setting is all-or-nothing: either every dimension is +specialized (`static_shapes=True`) or dimensions are bucketed dynamically +(`static_shapes=False`). Sometimes you want finer control - specializing +only specific dimensions while keeping others dynamic. + +Helion provides two APIs for selective shape specialization: + +| API | Location | Effect | +|-----|----------|--------| +| `hl.specialize()` | Inside kernel | Dimension always specialized for all calls | +| `torch._dynamo.mark_static()` | Outside kernel | Dimension specialized only for marked tensors | + +### `hl.specialize()` - Internal Specialization + +Use {func}`~helion.language.specialize` inside the kernel to make specific +dimensions compile-time constants. This applies to **every call** to the kernel: + +```python +import torch +import helion +import helion.language as hl + +@helion.kernel(static_shapes=False) +def rms_norm_fwd( + x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5 +) -> torch.Tensor: + m, n = x.size() + hl.specialize(n) # hidden dimension becomes a compile-time constant + out = torch.empty_like(x) + for tile_m in hl.tile(m): + x_tile = x[tile_m, :].to(torch.float32) + x_squared = x_tile * x_tile + mean_x_squared = torch.mean(x_squared, dim=-1) + inv_rms = torch.rsqrt(mean_x_squared + eps) + normalized = x_tile * inv_rms[:, None] + out[tile_m, :] = (normalized * weight[:].to(torch.float32)).to(out.dtype) + return out + +# Every call specializes on n - different hidden sizes = different cache entries +weight_4096 = torch.randn([4096], device="cuda") +weight_2048 = torch.randn([2048], device="cuda") +result1 = rms_norm_fwd(torch.randn([2048, 4096], device="cuda"), weight_4096) # compiles for n=4096 +result2 = rms_norm_fwd(torch.randn([1024, 4096], device="cuda"), weight_4096) # reuses n=4096 +result3 = rms_norm_fwd(torch.randn([2048, 2048], device="cuda"), weight_2048) # compiles for n=2048 +``` + +Use `hl.specialize()` when a dimension is performance-critical and you want +it specialized regardless of how the kernel is called. + +### `torch._dynamo.mark_static()` - External Specialization + +Use `torch._dynamo.mark_static()` **before** calling the kernel to specialize +dimensions on specific tensors. This is useful when you want the **same kernel** +to serve both dynamic and specialized code paths: + +```python +@helion.kernel(static_shapes=False) +def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + m, k = x.size() + k2, n = y.size() + out = torch.empty([m, n], device=x.device, dtype=x.dtype) + for tile_m, tile_n in hl.tile([m, n]): + acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + for tile_k in hl.tile(k): + acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n]) + out[tile_m, tile_n] = acc.to(x.dtype) + return out + +# Dynamic call - all dimensions remain symbolic +x_dyn = torch.randn([m, k], device="cuda", dtype=torch.float16) +y_dyn = torch.randn([k, n], device="cuda", dtype=torch.float16) +result = matmul(x_dyn, y_dyn) + +# Specialized call - mark specific dimensions as compile-time constants +x_opt = torch.randn([64, 128], device="cuda", dtype=torch.float16) +y_opt = torch.randn([128, 56], device="cuda", dtype=torch.float16) +torch._dynamo.mark_static(x_opt, [0, -1]) # specialize dims 0 and -1 (M and K) +torch._dynamo.mark_static(y_opt, 1) # specialize dim 1 (N) +result = matmul(x_opt, y_opt) # generates code with 64, 128, 56 as constants +``` + +This pattern enables a **single kernel definition** to serve both: +- Fully dynamic fallback paths (for rare edge-case shapes) +- Optimized hot paths (with shape constants baked into generated code) + +### Combining Both APIs + +The two APIs form a **union** - you can use `hl.specialize()` for dimensions +that should always be specialized, and `mark_static()` for additional +per-call specialization: + +```python +@helion.kernel(static_shapes=False) +def fn(x: torch.Tensor) -> torch.Tensor: + hl.specialize(x.size(0)) # dim 0 always specialized (internal) + out = torch.empty_like(x) + for tile in hl.tile(x.size()): + out[tile] = x[tile] * 2 + return out + +# mark_static on dim 1 combines with hl.specialize on dim 0 +x = torch.randn([320, 640], device="cuda") +torch._dynamo.mark_static(x, -1) # specialize dim 1 (external) +result = fn(x) # both 320 and 640 become constants +``` + +### Cache Behavior + +Each unique combination of specialized dimension values creates a separate +cache entry: +- Unspecialized calls share one dynamic cache entry +- Calls with `mark_static()` create entries keyed by the specialized values +- Different specialized values (e.g., `[64, 128]` vs `[48, 96]`) create separate entries + ## Advanced Manual Deployment Some teams prefer to skip all runtime selection, using Helion only as From c1abf78406ff2ae9ff92aa9551b003aa69cd4162 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Fri, 12 Dec 2025 12:00:46 -0800 Subject: [PATCH 4/5] fix after rebase --- helion/_compiler/device_function.py | 5 ----- test/test_specialize.expected | 2 +- test/test_specialize.py | 14 +++++++++++--- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/helion/_compiler/device_function.py b/helion/_compiler/device_function.py index d7f59dcf7..2cdc414ef 100644 --- a/helion/_compiler/device_function.py +++ b/helion/_compiler/device_function.py @@ -617,11 +617,6 @@ def tensor_stride(self, fake_value: torch.Tensor, dim: int) -> Argument: if isinstance(v, int): if env.settings.static_shapes: return StaticShape(v) - else: - # Check if all free symbols are specialized - syms = v._sympy_().free_symbols - if syms and syms <= env.specialized_vars: - return StaticShape(int(v)) return self._tensor_property(TensorStrideArg, fake_value, dim, "stride") def sorted_args(self) -> list[Argument]: diff --git a/test/test_specialize.expected b/test/test_specialize.expected index 7a113c6a4..c34513b23 100644 --- a/test/test_specialize.expected +++ b/test/test_specialize.expected @@ -459,7 +459,7 @@ def fn(x: torch.Tensor, *, _launcher=_default_launcher): _BLOCK_SIZE_0 = 32 # src[test_specialize.py:N]: for tile in hl.tile(n): # src[test_specialize.py:N]: out[tile] = x[tile] + 1 - _launcher(_helion_fn, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, out, out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=1) + _launcher(_helion_fn, (triton.cdiv(137, _BLOCK_SIZE_0),), x, out, out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_specialize.py:N]: return out return out diff --git a/test/test_specialize.py b/test/test_specialize.py index 197890517..50cf4bd52 100644 --- a/test/test_specialize.py +++ b/test/test_specialize.py @@ -1,6 +1,7 @@ from __future__ import annotations import math +import re import unittest import torch @@ -16,6 +17,11 @@ import helion.language as hl +def _strip_source_comments(code: str) -> str: + """Remove source line references (# src[...]) that contain line numbers.""" + return re.sub(r"# src\[.*?\].*", "", code) + + @skipIfCpu("needs to be debugged") class TestSpecialize(RefEagerTestBase, TestCase): maxDiff = 163842 @@ -476,9 +482,11 @@ def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: matmul, (x, y), block_sizes=[32, 32, 32] ) torch.testing.assert_close(result_no_spec, x @ y, rtol=1e-2, atol=1e-2) - self.assertNotIn("64", code_no_spec) - self.assertNotIn("128", code_no_spec) - self.assertNotIn("56", code_no_spec) + # Strip source line comments to avoid matching line numbers like "464" + code_stripped = _strip_source_comments(code_no_spec) + self.assertNotIn("64", code_stripped) + self.assertNotIn("128", code_stripped) + self.assertNotIn("56", code_stripped) # Now, run WITH mark_static - dimensions SHOULD be constants x_static = torch.randn([m, k], device=DEVICE, dtype=torch.float16) From 1d1e9ccbc12633e7f68a871bc248cf41aa31497c Mon Sep 17 00:00:00 2001 From: Will Feng Date: Fri, 12 Dec 2025 13:29:58 -0800 Subject: [PATCH 5/5] simplify --- test/test_specialize.expected | 8 ++++---- test/test_specialize.py | 20 ++++++-------------- 2 files changed, 10 insertions(+), 18 deletions(-) diff --git a/test/test_specialize.expected b/test/test_specialize.expected index c34513b23..839368543 100644 --- a/test/test_specialize.expected +++ b/test/test_specialize.expected @@ -12,15 +12,15 @@ from helion.runtime import default_launcher as _default_launcher @triton.jit def _helion_matmul(x, y, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): # src[test_specialize.py:N]: for tile_m, tile_n in hl.tile([m, n]): - num_blocks_0 = tl.cdiv(64, _BLOCK_SIZE_0) + num_blocks_0 = tl.cdiv(96, _BLOCK_SIZE_0) pid_0 = tl.program_id(0) % num_blocks_0 pid_1 = tl.program_id(0) // num_blocks_0 offset_0 = pid_0 * _BLOCK_SIZE_0 indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) - mask_0 = indices_0 < 64 + mask_0 = indices_0 < 96 offset_1 = pid_1 * _BLOCK_SIZE_1 indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) - mask_1 = indices_1 < 56 + mask_1 = indices_1 < 48 # src[test_specialize.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) # src[test_specialize.py:N]: for tile_k in hl.tile(k): @@ -56,7 +56,7 @@ def matmul(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher): # src[test_specialize.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) # src[test_specialize.py:N]: for tile_k in hl.tile(k): # src[test_specialize.py:N-N]: ... - _launcher(_helion_matmul, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(56, _BLOCK_SIZE_1),), x, y, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) + _launcher(_helion_matmul, (triton.cdiv(96, _BLOCK_SIZE_0) * triton.cdiv(48, _BLOCK_SIZE_1),), x, y, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_specialize.py:N]: return out return out diff --git a/test/test_specialize.py b/test/test_specialize.py index 50cf4bd52..57a1eadf8 100644 --- a/test/test_specialize.py +++ b/test/test_specialize.py @@ -1,7 +1,6 @@ from __future__ import annotations import math -import re import unittest import torch @@ -17,11 +16,6 @@ import helion.language as hl -def _strip_source_comments(code: str) -> str: - """Remove source line references (# src[...]) that contain line numbers.""" - return re.sub(r"# src\[.*?\].*", "", code) - - @skipIfCpu("needs to be debugged") class TestSpecialize(RefEagerTestBase, TestCase): maxDiff = 163842 @@ -473,7 +467,7 @@ def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: out[tile_m, tile_n] = acc.to(x.dtype) return out - m, k, n = 64, 128, 56 + m, k, n = 96, 128, 48 # First, run WITHOUT mark_static - dimensions should NOT be constants x = torch.randn([m, k], device=DEVICE, dtype=torch.float16) @@ -482,11 +476,9 @@ def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: matmul, (x, y), block_sizes=[32, 32, 32] ) torch.testing.assert_close(result_no_spec, x @ y, rtol=1e-2, atol=1e-2) - # Strip source line comments to avoid matching line numbers like "464" - code_stripped = _strip_source_comments(code_no_spec) - self.assertNotIn("64", code_stripped) - self.assertNotIn("128", code_stripped) - self.assertNotIn("56", code_stripped) + self.assertNotIn("96", code_no_spec) + self.assertNotIn("128", code_no_spec) + self.assertNotIn("48", code_no_spec) # Now, run WITH mark_static - dimensions SHOULD be constants x_static = torch.randn([m, k], device=DEVICE, dtype=torch.float16) @@ -498,9 +490,9 @@ def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: matmul, (x_static, y_static), block_sizes=[32, 32, 32] ) torch.testing.assert_close(result, x_static @ y_static, rtol=1e-2, atol=1e-2) - self.assertIn("64", code) + self.assertIn("96", code) self.assertIn("128", code) - self.assertIn("56", code) + self.assertIn("48", code) self.assertExpectedJournal(code) # Cache hit: same tensors