See {doc}`api/kernel` for the full decorator reference.

## Selective Shape Specialization

The `static_shapes` setting is all-or-nothing: either every dimension is
specialized (`static_shapes=True`) or dimensions are bucketed dynamically
(`static_shapes=False`). Sometimes you want finer control: specializing
only specific dimensions while keeping others dynamic.
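
As a minimal sketch of the two extremes (the `scale_static`/`scale_dynamic`
kernels here are illustrative, not taken from the Helion docs):

```python
import torch
import helion
import helion.language as hl

# static_shapes=True: every dimension is baked in as a constant, so each new
# input shape gets its own specialized compilation.
@helion.kernel(static_shapes=True)
def scale_static(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    for tile in hl.tile(x.size()):
        out[tile] = x[tile] * 2
    return out

# static_shapes=False: dimensions stay symbolic and are bucketed dynamically,
# so many different shapes can share the same compiled code.
@helion.kernel(static_shapes=False)
def scale_dynamic(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    for tile in hl.tile(x.size()):
        out[tile] = x[tile] * 2
    return out
```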

Helion provides two APIs for selective shape specialization:

| API | Location | Effect |
|-----|----------|--------|
| `hl.specialize()` | Inside kernel | Dimension always specialized for all calls |
| `torch._dynamo.mark_static()` | Outside kernel | Dimension specialized only for marked tensors |

### `hl.specialize()` - Internal Specialization

Use {func}`~helion.language.specialize` inside the kernel to make specific
dimensions compile-time constants. This applies to **every call** to the kernel:

```python
import torch
import helion
import helion.language as hl

@helion.kernel(static_shapes=False)
def rms_norm_fwd(
    x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5
) -> torch.Tensor:
    m, n = x.size()
    hl.specialize(n)  # hidden dimension becomes a compile-time constant
    out = torch.empty_like(x)
    for tile_m in hl.tile(m):
        x_tile = x[tile_m, :].to(torch.float32)
        x_squared = x_tile * x_tile
        mean_x_squared = torch.mean(x_squared, dim=-1)
        inv_rms = torch.rsqrt(mean_x_squared + eps)
        normalized = x_tile * inv_rms[:, None]
        out[tile_m, :] = (normalized * weight[:].to(torch.float32)).to(out.dtype)
    return out

# Every call specializes on n - different hidden sizes = different cache entries
weight_4096 = torch.randn([4096], device="cuda")
weight_2048 = torch.randn([2048], device="cuda")
result1 = rms_norm_fwd(torch.randn([2048, 4096], device="cuda"), weight_4096)  # compiles for n=4096
result2 = rms_norm_fwd(torch.randn([1024, 4096], device="cuda"), weight_4096)  # reuses n=4096
result3 = rms_norm_fwd(torch.randn([2048, 2048], device="cuda"), weight_2048)  # compiles for n=2048
```

Use `hl.specialize()` when a dimension is performance-critical and you want
it specialized regardless of how the kernel is called.

### `torch._dynamo.mark_static()` - External Specialization

Use `torch._dynamo.mark_static()` **before** calling the kernel to specialize
dimensions on specific tensors. This is useful when you want the **same kernel**
to serve both dynamic and specialized code paths:

```python
@helion.kernel(static_shapes=False)
def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    m, k = x.size()
    k2, n = y.size()
    out = torch.empty([m, n], device=x.device, dtype=x.dtype)
    for tile_m, tile_n in hl.tile([m, n]):
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        for tile_k in hl.tile(k):
            acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
        out[tile_m, tile_n] = acc.to(x.dtype)
    return out

# Dynamic call - all dimensions remain symbolic
m, k, n = 256, 512, 128  # example sizes; any values work on this path
x_dyn = torch.randn([m, k], device="cuda", dtype=torch.float16)
y_dyn = torch.randn([k, n], device="cuda", dtype=torch.float16)
result = matmul(x_dyn, y_dyn)

# Specialized call - mark specific dimensions as compile-time constants
x_opt = torch.randn([64, 128], device="cuda", dtype=torch.float16)
y_opt = torch.randn([128, 56], device="cuda", dtype=torch.float16)
torch._dynamo.mark_static(x_opt, [0, -1])  # specialize dims 0 and -1 (M and K)
torch._dynamo.mark_static(y_opt, 1)  # specialize dim 1 (N)
result = matmul(x_opt, y_opt)  # generates code with 64, 128, 56 as constants
```

This pattern enables a **single kernel definition** to serve both (see the
sketch after this list):
- Fully dynamic fallback paths (for rare edge-case shapes)
- Optimized hot paths (with shape constants baked into generated code)
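
As a sketch of how such a dual-path setup might be wired together (the
`HOT_SHAPES` set, the `matmul_dispatch` wrapper, and the shape values below
are illustrative assumptions, not part of Helion's API):

```python
# Hypothetical dispatch helper built on the `matmul` kernel defined above.
# Shapes we expect to see repeatedly take the specialized hot path; anything
# else falls through to the fully dynamic compilation.
HOT_SHAPES = {(64, 128, 56)}  # (M, K, N) combinations worth specializing

def matmul_dispatch(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    m, k = x.shape
    n = y.shape[1]
    if (m, k, n) in HOT_SHAPES:
        # Hot path: bake M, K, and N into the generated code.
        torch._dynamo.mark_static(x, [0, -1])
        torch._dynamo.mark_static(y, 1)
    # Cold path: unmarked tensors keep every dimension symbolic.
    return matmul(x, y)
```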

### Combining Both APIs

The two APIs form a **union**: you can use `hl.specialize()` for dimensions
that should always be specialized, and `mark_static()` for additional
per-call specialization:

```python
@helion.kernel(static_shapes=False)
def fn(x: torch.Tensor) -> torch.Tensor:
    hl.specialize(x.size(0))  # dim 0 always specialized (internal)
    out = torch.empty_like(x)
    for tile in hl.tile(x.size()):
        out[tile] = x[tile] * 2
    return out

# mark_static on dim 1 combines with hl.specialize on dim 0
x = torch.randn([320, 640], device="cuda")
torch._dynamo.mark_static(x, -1)  # specialize dim 1 (external)
result = fn(x)  # both 320 and 640 become constants
```

### Cache Behavior

Each unique combination of specialized dimension values creates a separate
cache entry (see the sketch after this list):
- Unspecialized calls share one dynamic cache entry
- Calls with `mark_static()` create entries keyed by the specialized values
- Different specialized values (e.g., `[64, 128]` vs `[48, 96]`) create separate entries
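
As a rough illustration of these rules, reusing the `matmul` kernel from above
(the cache-entry comments restate the behavior described here; they are not
output captured from Helion, and the input shapes are arbitrary examples):

```python
a = torch.randn([100, 200], device="cuda", dtype=torch.float16)
b = torch.randn([200, 300], device="cuda", dtype=torch.float16)
matmul(a, b)  # unspecialized: uses the shared dynamic cache entry

c = torch.randn([128, 256], device="cuda", dtype=torch.float16)
d = torch.randn([256, 300], device="cuda", dtype=torch.float16)
matmul(c, d)  # also unspecialized: shares that same dynamic entry

e = torch.randn([64, 128], device="cuda", dtype=torch.float16)
f = torch.randn([128, 56], device="cuda", dtype=torch.float16)
torch._dynamo.mark_static(e, [0, -1])
torch._dynamo.mark_static(f, 1)
matmul(e, f)  # new entry keyed by the specialized values (64, 128, 56)

g = torch.randn([48, 96], device="cuda", dtype=torch.float16)
h = torch.randn([96, 32], device="cuda", dtype=torch.float16)
torch._dynamo.mark_static(g, [0, -1])
torch._dynamo.mark_static(h, 1)
matmul(g, h)  # different specialized values, so yet another separate entry
```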

## Advanced Manual Deployment

Some teams prefer to skip all runtime selection, using Helion only as