Commit 531cbdc

Use torch._dynamo.mark_static() API to allow tensor shape specialization outside of the kernel code (#1210)
1 parent 28cc903 commit 531cbdc

10 files changed: +347 -40 lines

docs/deployment_autotuning.md

Lines changed: 116 additions & 0 deletions
@@ -178,6 +178,122 @@ def my_kernel(x, y):
See {doc}`api/kernel` for the full decorator reference.

## Selective Shape Specialization

The `static_shapes` setting is all-or-nothing: either every dimension is
specialized (`static_shapes=True`) or dimensions are bucketed dynamically
(`static_shapes=False`). Sometimes you want finer control - specializing
only specific dimensions while keeping others dynamic.

Helion provides two APIs for selective shape specialization:

| API | Location | Effect |
|-----|----------|--------|
| `hl.specialize()` | Inside kernel | Dimension always specialized for all calls |
| `torch._dynamo.mark_static()` | Outside kernel | Dimension specialized only for marked tensors |

### `hl.specialize()` - Internal Specialization

Use {func}`~helion.language.specialize` inside the kernel to make specific
dimensions compile-time constants. This applies to **every call** to the kernel:

```python
import torch
import helion
import helion.language as hl

@helion.kernel(static_shapes=False)
def rms_norm_fwd(
    x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5
) -> torch.Tensor:
    m, n = x.size()
    hl.specialize(n)  # hidden dimension becomes a compile-time constant
    out = torch.empty_like(x)
    for tile_m in hl.tile(m):
        x_tile = x[tile_m, :].to(torch.float32)
        x_squared = x_tile * x_tile
        mean_x_squared = torch.mean(x_squared, dim=-1)
        inv_rms = torch.rsqrt(mean_x_squared + eps)
        normalized = x_tile * inv_rms[:, None]
        out[tile_m, :] = (normalized * weight[:].to(torch.float32)).to(out.dtype)
    return out

# Every call specializes on n - different hidden sizes = different cache entries
weight_4096 = torch.randn([4096], device="cuda")
weight_2048 = torch.randn([2048], device="cuda")
result1 = rms_norm_fwd(torch.randn([2048, 4096], device="cuda"), weight_4096)  # compiles for n=4096
result2 = rms_norm_fwd(torch.randn([1024, 4096], device="cuda"), weight_4096)  # reuses n=4096
result3 = rms_norm_fwd(torch.randn([2048, 2048], device="cuda"), weight_2048)  # compiles for n=2048
```

Use `hl.specialize()` when a dimension is performance-critical and you want
it specialized regardless of how the kernel is called.

### `torch._dynamo.mark_static()` - External Specialization

Use `torch._dynamo.mark_static()` **before** calling the kernel to specialize
dimensions on specific tensors. This is useful when you want the **same kernel**
to serve both dynamic and specialized code paths:

```python
@helion.kernel(static_shapes=False)
def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    m, k = x.size()
    k2, n = y.size()
    out = torch.empty([m, n], device=x.device, dtype=x.dtype)
    for tile_m, tile_n in hl.tile([m, n]):
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        for tile_k in hl.tile(k):
            acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
        out[tile_m, tile_n] = acc.to(x.dtype)
    return out

# Dynamic call - all dimensions remain symbolic
m, k, n = 1024, 512, 768  # arbitrary runtime sizes
x_dyn = torch.randn([m, k], device="cuda", dtype=torch.float16)
y_dyn = torch.randn([k, n], device="cuda", dtype=torch.float16)
result = matmul(x_dyn, y_dyn)

# Specialized call - mark specific dimensions as compile-time constants
x_opt = torch.randn([64, 128], device="cuda", dtype=torch.float16)
y_opt = torch.randn([128, 56], device="cuda", dtype=torch.float16)
torch._dynamo.mark_static(x_opt, [0, -1])  # specialize dims 0 and -1 (M and K)
torch._dynamo.mark_static(y_opt, 1)  # specialize dim 1 (N)
result = matmul(x_opt, y_opt)  # generates code with 64, 128, 56 as constants
```

This pattern enables a **single kernel definition** to serve both:
- Fully dynamic fallback paths (for rare edge-case shapes)
- Optimized hot paths (with shape constants baked into generated code)

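One way to wire this up in practice is a thin dispatch wrapper that marks
tensors only for the shapes you care about and falls through to the dynamic
entry otherwise. This is an illustrative sketch only - `HOT_SHAPES` and
`matmul_dispatch` are made-up names, not Helion APIs:

```python
# Hypothetical dispatcher around the matmul kernel defined above.
HOT_SHAPES = {(64, 128, 56)}  # (M, K, N) combinations worth specializing

def matmul_dispatch(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    m, k = x.size()
    n = y.size(1)
    if (m, k, n) in HOT_SHAPES:
        # Hot path: bake every dimension into the generated code.
        torch._dynamo.mark_static(x)
        torch._dynamo.mark_static(y)
    # Cold path: unmarked tensors hit the shared dynamic cache entry.
    return matmul(x, y)
```
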
### Combining Both APIs

The two APIs form a **union** - you can use `hl.specialize()` for dimensions
that should always be specialized, and `mark_static()` for additional
per-call specialization:

```python
@helion.kernel(static_shapes=False)
def fn(x: torch.Tensor) -> torch.Tensor:
    hl.specialize(x.size(0))  # dim 0 always specialized (internal)
    out = torch.empty_like(x)
    for tile in hl.tile(x.size()):
        out[tile] = x[tile] * 2
    return out

# mark_static on dim 1 combines with hl.specialize on dim 0
x = torch.randn([320, 640], device="cuda")
torch._dynamo.mark_static(x, -1)  # specialize dim 1 (external)
result = fn(x)  # both 320 and 640 become constants
```

### Cache Behavior

Each unique combination of specialized dimension values creates a separate
cache entry:
- Unspecialized calls share one dynamic cache entry
- Calls with `mark_static()` create entries keyed by the specialized values
- Different specialized values (e.g., `[64, 128]` vs `[48, 96]`) create separate entries

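For example, reusing the `matmul` kernel defined above (a sketch only; the
comments describe the expected behavior, and the exact cache keys are an
internal detail):

```python
a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
b = torch.randn([128, 56], device="cuda", dtype=torch.float16)
matmul(a, b)  # unspecialized: shares the single dynamic cache entry

c = torch.randn([64, 128], device="cuda", dtype=torch.float16)
d = torch.randn([128, 56], device="cuda", dtype=torch.float16)
torch._dynamo.mark_static(c, 0)
matmul(c, d)  # specialized on M=64: separate entry keyed by that value

e = torch.randn([48, 128], device="cuda", dtype=torch.float16)
f = torch.randn([128, 56], device="cuda", dtype=torch.float16)
torch._dynamo.mark_static(e, 0)
matmul(e, f)  # specialized on M=48: yet another entry
```
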
## Advanced Manual Deployment

Some teams prefer to skip all runtime selection, using Helion only as

helion/_compiler/compile_environment.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -128,6 +128,16 @@ def __init__(
             0  # Track number of loads in all device code for eviction policy tuning
         )
 
+    def specialize_expr(self, expr: sympy.Expr) -> sympy.Expr:
+        """Substitute any specialized vars with their concrete values."""
+        if subs := {
+            s: sympy.Integer(self.shape_env.size_hint(s))
+            for s in expr.free_symbols & self.specialized_vars
+        }:
+            # pyrefly: ignore [bad-assignment]
+            expr = expr.xreplace(subs)
+        return expr
+
     def add_kernel_tensor_size(self, sizes: Sequence[int | torch.SymInt]) -> None:
         from .device_function import contains_only_block_size_symbols
```
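The substitution idea behind `specialize_expr` can be shown with a standalone
sympy sketch (plain sympy only, not Helion's actual classes; `specialized_vars`
and `size_hints` stand in for the compile environment's state):

```python
import sympy

s0, s1 = sympy.symbols("s0 s1", integer=True)
specialized_vars = {s0}   # symbols whose dims were marked static
size_hints = {s0: 64}     # concrete sizes recorded for those symbols

expr = s0 * s1 + 4
subs = {s: sympy.Integer(size_hints[s]) for s in expr.free_symbols & specialized_vars}
print(expr.xreplace(subs))  # 64*s1 + 4: s0 folds to a constant, s1 stays symbolic
```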

helion/_compiler/device_function.py

Lines changed: 3 additions & 6 deletions
```diff
@@ -374,7 +374,8 @@ def set_pid(self, pid: ProgramIDs) -> None:
         self.pid = pid
 
     def sympy_expr(self, expr: sympy.Expr) -> str:
-        expr = CompileEnvironment.current().shape_env.simplify(expr)
+        env = CompileEnvironment.current()
+        expr = env.specialize_expr(env.shape_env.simplify(expr))
         if not expr.free_symbols:
             return texpr(expr)
         if expr in self.expr_to_var_info:
@@ -394,6 +395,7 @@ def sympy_expr(self, expr: sympy.Expr) -> str:
             replacements[sym] = sympy.Symbol(
                 self._lift_sympy_arg(sym), integer=True
             )
+        # pyrefly: ignore [bad-argument-type]
         return texpr(expr.xreplace(replacements))
 
     def _lift_sympy_arg(self, expr: sympy.Expr) -> str:
@@ -615,11 +617,6 @@ def tensor_stride(self, fake_value: torch.Tensor, dim: int) -> Argument:
         if isinstance(v, int):
             if env.settings.static_shapes:
                 return StaticShape(v)
-            else:
-                # Check if all free symbols are specialized
-                syms = v._sympy_().free_symbols
-                if syms and syms <= env.specialized_vars:
-                    return StaticShape(int(v))
         return self._tensor_property(TensorStrideArg, fake_value, dim, "stride")
 
     def sorted_args(self) -> list[Argument]:
```

helion/_compiler/host_function.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -191,14 +191,18 @@ def set_local_types(self, local_types: dict[str, TypeInfo]) -> None:
             type_info.populate_symbol_origins(NameOrigin(name, fn))
 
     def sympy_expr(self, expr: sympy.Expr) -> str:
-        expr = CompileEnvironment.current().shape_env.simplify(expr)
+        env = CompileEnvironment.current()
+        expr = env.specialize_expr(env.shape_env.simplify(expr))
+        if not expr.free_symbols:
+            return pexpr(expr)
         if expr in self.expr_to_origin:
             return self.expr_to_origin[expr].origin.host_str()
         replacements = {}
         for sym in sorted(expr.free_symbols, key=lambda x: x.name):
             assert isinstance(sym, sympy.Symbol)
             origin = self.expr_to_origin[sym].origin
             replacements[sym] = sympy.Symbol(origin.host_str(), integer=True)
+        # pyrefly: ignore [bad-argument-type]
         return pexpr(expr.xreplace(replacements))
 
     def literal_expr(self, expr: object) -> str:
```

helion/_testing.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -511,6 +511,14 @@ def assertNotIn(
         if not self._in_ref_eager_mode:
             super().assertNotIn(member, container, msg)  # type: ignore[misc]
 
+    def assertIs(self, expr1: object, expr2: object, msg: str | None = None) -> None:
+        if not self._in_ref_eager_mode:
+            super().assertIs(expr1, expr2, msg)  # type: ignore[misc]
+
+    def assertIsNot(self, expr1: object, expr2: object, msg: str | None = None) -> None:
+        if not self._in_ref_eager_mode:
+            super().assertIsNot(expr1, expr2, msg)  # type: ignore[misc]
+
     def assertTrueIfInNormalMode(self, condition: bool, msg: str | None = None) -> None:
         if not self._in_ref_eager_mode:
             self.assertTrue(condition, msg)  # type: ignore[attr-defined]
```

helion/runtime/kernel.py

Lines changed: 21 additions & 0 deletions
```diff
@@ -403,6 +403,9 @@ def __init__(
                 constexpr_args[name] = arg
             else:
                 self.fake_args.append(self.env.to_fake(arg, ArgumentOrigin(name)))
+
+        self._apply_mark_static(args)
+
         with (
             _maybe_skip_dtype_check_in_meta_registrations(),
             patch_inductor_lowerings(),
@@ -420,6 +423,20 @@ def __init__(
             self.maybe_log_repro(log.warning, args, config=config)
             raise
 
+    def _apply_mark_static(self, args: tuple[object, ...]) -> None:
+        """
+        Apply torch._dynamo.mark_static() markings from input tensors.
+
+        This reads _dynamo_static_indices from each tensor argument and marks
+        the corresponding dimensions as specialized (constant) in the kernel.
+        """
+        for arg, fake_arg in zip(args, self.fake_args, strict=True):
+            if isinstance(arg, torch.Tensor) and isinstance(fake_arg, torch.Tensor):
+                for dim in getattr(arg, "_dynamo_static_indices", ()):
+                    size = fake_arg.size(dim)
+                    if isinstance(size, torch.SymInt):
+                        self.env.specialized_vars.update(size._sympy_().free_symbols)
+
     @property
     def settings(self) -> Settings:
         """
@@ -891,12 +908,14 @@ def kernel(
 def _tensor_key(fn: Kernel, obj: torch.Tensor) -> Hashable:
     # NOTE: If a machine has two different gpu types on the same machine,
     # obj.device.type will incorrectly hit
+    static_indices = frozenset(getattr(obj, "_dynamo_static_indices", ()))
     if fn.settings.static_shapes:
         return (
             obj.dtype,
             obj.device.type,
             (*obj.size(),),
             (*obj.stride(),),
+            static_indices,
         )
     bucketed = tuple([min(s, 2) for s in obj.size()])
     if fn.settings.index_dtype is None:
@@ -909,11 +928,13 @@ def _tensor_key(fn: Kernel, obj: torch.Tensor) -> Hashable:
             obj.device.type,
             bucketed,
             needs_int64,
+            static_indices,
         )
     return (
         obj.dtype,
         obj.device.type,
         bucketed,
+        static_indices,
     )
```
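The hook that connects the two sides is the `_dynamo_static_indices` attribute
that `torch._dynamo.mark_static()` records on the tensor itself; the code above
reads it both to specialize symbols (`_apply_mark_static`) and to key the cache
(`_tensor_key`). A quick sketch of that handshake (illustrative; how dynamo
normalizes the indices is an internal detail):

```python
import torch

t = torch.randn([64, 128])
torch._dynamo.mark_static(t, [0, -1])  # mark dims 0 and -1 as static

# The marking lives on the tensor and is what _apply_mark_static() and
# _tensor_key() read back via getattr(..., "_dynamo_static_indices", ()).
print(getattr(t, "_dynamo_static_indices", set()))
```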

test/test_examples.expected

Lines changed: 0 additions & 16 deletions
```diff
@@ -460,27 +460,11 @@ def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor, *, _la
     _RDIM_SIZE_2 = 64
     # src[attention.py:N]: m_i = hl.full([tile_b, tile_m], float("-inf"), dtype=torch.float32)
     _BLOCK_SIZE_0 = 1
-    # src[attention.py:N]: q = q_view[tile_b, tile_m, :]
-    _SHAPE_DIM = q_in.size(3)
-    _SHAPE_DIM_1 = q_in.size(3)
-    _SHAPE_DIM_2 = q_in.size(3)
     # src[attention.py:N]: for tile_n in hl.tile(v_view.size(1)):
     # src[attention.py:N]: k = k_view[tile_b, :, tile_n]
     # src[attention.py:N]: qk = torch.bmm(q, k)
     # src[attention.py:N-N]: ...
     _BLOCK_SIZE_3 = 32
-    # src[attention.py:N]: k = k_view[tile_b, :, tile_n]
-    _SHAPE_DIM_3 = q_in.size(3)
-    _SHAPE_DIM_4 = q_in.size(3)
-    _SHAPE_DIM_5 = q_in.size(3)
-    # src[attention.py:N]: v = v_view[tile_b, tile_n, :]
-    _SHAPE_DIM_6 = q_in.size(3)
-    _SHAPE_DIM_7 = q_in.size(3)
-    _SHAPE_DIM_8 = q_in.size(3)
-    # src[attention.py:N]: out[tile_b, tile_m, :] = acc.to(out.dtype)
-    _SHAPE_DIM_9 = q_in.size(3)
-    _SHAPE_DIM_10 = q_in.size(3)
-    _SHAPE_DIM_11 = q_in.size(3)
     # src[attention.py:N]: for tile_b, tile_m in hl.tile([q_view.size(0), m_dim]):
     # src[attention.py:N]: m_i = hl.full([tile_b, tile_m], float("-inf"), dtype=torch.float32)
     # src[attention.py:N]: l_i = torch.full_like(m_i, 1.0)
```
