Commit dc19935
Commit message: up
1 parent 377873c

2 files changed: +156 −23 lines changed

helion/runtime/kernel.py

Lines changed: 52 additions & 13 deletions
@@ -352,26 +352,65 @@ def reset(self) -> None:
 
     def specialize_args(self, **kwargs: list[int]) -> Kernel[_R]:
         """
-        Returns a kernel that will specialize on the given argument dimensions.
-        This allows specialization decisions to be made outside the kernel,
+        Returns a new kernel that will specialize on the given argument dimensions.
+        The original kernel is not mutated - you can call the original kernel before
+        or after this method and it will behave identically.
+
+        This allows specialization decisions to be made outside the kernel definition,
         binding to argument names via kwargs.
 
         Args:
-            **kwargs: Mapping of argument name -> dims to specialize on
-                e.g., specialize_args(q_in=[-1], k_in=[-1])
+            **kwargs: Mapping of argument name -> list of dimension indices to specialize.
+                Supports negative indexing (e.g., -1 for last dimension).
+                Example: specialize_args(x=[0, -1], y=[1])
 
         Returns:
-            Kernel: A new kernel with same settings and configs, adding the given
-                specializations to any existing ones.
+            Kernel: A new kernel with the same settings and configs, adding the given
+                specializations to any existing ones. Can be chained for multiple arguments.
+
+        Examples:
+            Basic usage - specialize specific dimensions:
+
+                @helion.kernel(static_shapes=False)
+                def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+                    m, k = x.size()
+                    k2, n = y.size()
+                    ...
+
+                # Original kernel - dimensions are dynamic (symbolic)
+                result1 = matmul(x, y)
+
+                # Specialized kernel - m and k are compiled as constants
+                specialized = matmul.specialize_args(x=[0, 1])
+                result2 = specialized(x, y)
+
+                # Original kernel is unaffected by specialize_args
+                result3 = matmul(x, y)  # still uses dynamic dimensions
+
+            Calling original kernel before specialize_args:
+
+                # This works fine - original kernel is independent
+                result1 = matmul(x, y)  # dynamic dimensions
+
+                # Create specialized version - does NOT affect prior calls
+                specialized = matmul.specialize_args(x=[0])
+                result2 = specialized(x, y)  # m is now a constant within the `specialized` kernel
+
+            Chaining specializations for multiple arguments:
+
+                # Specialize x's first dim and y's second dim
+                chained = matmul.specialize_args(x=[0]).specialize_args(y=[1])
+                result = chained(x, y)
+
+            Combining with hl.specialize() within the kernel:
 
-        Example:
-            @helion.kernel
-            def attention(q_in, k_in, v_in):
-                head_dim = q_in.size(0)  # Specialized if specified externally
-                seq_len = k_in.size(1)  # Specialized if specified externally
-                ...
+                @helion.kernel(static_shapes=False)
+                def fn(x: torch.Tensor) -> torch.Tensor:
+                    hl.specialize(x.size(0))  # Always specialize dim 0
+                    ...
 
-            result = attention.specialize_args(q_in=[0], k_in=[1])(q, k, v)
+                # Adds dim 1 specialization to the existing dim 0
+                both_dims = fn.specialize_args(x=[1])
         """
         if not kwargs:
             return self
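
For context, the new docstring's examples compose into a standalone script along the following lines. This is a minimal sketch rather than part of the commit: it assumes a CUDA device and the conventional `import helion.language as hl` spelling, and it borrows the simple scale-by-two kernel body and the explicit block-size config from the new test below instead of the elided matmul body.

    import torch
    import helion
    import helion.language as hl

    # Toy kernel with dynamic (symbolic) shapes; the body mirrors kernel_fn in the
    # new test_specialize_args_does_not_mutate_original test below.
    @helion.kernel(config=helion.Config(block_sizes=[16, 16]), static_shapes=False)
    def scale(x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        for tile in hl.tile(x.size()):
            out[tile] = x[tile] * 2.0
        return out

    x = torch.randn([64, 128], device="cuda")

    r1 = scale(x)  # original kernel: both dims stay symbolic in the generated code

    # New kernel object with x's dims 0 and 1 baked in as compile-time constants;
    # further specializations can be chained as in the docstring's chaining example.
    specialized = scale.specialize_args(x=[0, 1])
    r2 = specialized(x)

    # The original kernel is not mutated and still accepts other shapes.
    r3 = scale(torch.randn([32, 256], device="cuda"))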

test/test_specialize.py

Lines changed: 104 additions & 10 deletions
@@ -334,7 +334,7 @@ class TestSpecializeArgs(RefEagerTestBase, TestCase):
     maxDiff = 163842
 
     def test_specialize_args(self):
-        """Test specialize_args: multiple tensors, multiple dims, negative indexing."""
+        """Test specialize_args(): multiple tensors, multiple dims, negative indexing."""
 
         @helion.kernel(autotune_effort="none", static_shapes=False)
         def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
@@ -352,7 +352,7 @@ def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         x = torch.randn([m, k], device=DEVICE, dtype=torch.float16)
         y = torch.randn([k, n], device=DEVICE, dtype=torch.float16)
 
-        # First, run WITHOUT specialize_args - dimensions should NOT be constants
+        # First, run WITHOUT specialize_args() - dimensions should NOT be constants
         code_no_spec, result_no_spec = code_and_output(
             matmul,
             (x, y),
@@ -363,7 +363,7 @@ def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         self.assertNotIn("128", code_no_spec)  # x dim -1 = k should NOT be specialized
         self.assertNotIn("56", code_no_spec)  # y dim 1 = n should NOT be specialized
 
-        # Now, run WITH specialize_args - dimensions SHOULD be constants
+        # Now, run WITH specialize_args() - dimensions SHOULD be constants
         code, result = code_and_output(
             matmul.specialize_args(x=[0, -1], y=[1]),
             (x, y),
@@ -386,7 +386,7 @@ def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         )
 
     def test_specialize_args_and_hl_specialize(self):
-        """Test that external specialize_args and internal hl.specialize form a union."""
+        """Test that external specialize_args() and internal hl.specialize() form a union."""
 
         @helion.kernel(autotune_effort="none", static_shapes=False)
         def dual_specialize(x: torch.Tensor) -> torch.Tensor:
@@ -399,7 +399,7 @@ def dual_specialize(x: torch.Tensor) -> torch.Tensor:
 
         x = torch.randn([320, 640], device=DEVICE)
 
-        # First, run WITHOUT external specialize_args - only dim 0 should be specialized
+        # First, run WITHOUT external specialize_args() - only dim 0 should be specialized
         code_no_spec, result_no_spec = code_and_output(
             dual_specialize,
             (x,),
@@ -409,7 +409,7 @@ def dual_specialize(x: torch.Tensor) -> torch.Tensor:
         self.assertIn("320", code_no_spec)  # dim 0 from internal specialize
         self.assertNotIn("640", code_no_spec)  # dim 1 should NOT be specialized
 
-        # Now, run WITH external specialize_args on dim -1 (dim 1)
+        # Now, run WITH external specialize_args() on dim -1 (dim 1)
         # Result: both dim 0 AND dim 1 are specialized (union)
         code, result = code_and_output(
             dual_specialize.specialize_args(x=[-1]),
@@ -429,7 +429,7 @@ def dual_specialize(x: torch.Tensor) -> torch.Tensor:
 
     @skipIfRefEager("Error checking not available in ref eager mode")
     def test_specialize_args_errors(self):
-        """Test error handling for invalid specialize_args usage."""
+        """Test error handling for invalid specialize_args() usage."""
 
         @helion.kernel(autotune_effort="none", static_shapes=False)
         def fn(x: torch.Tensor) -> torch.Tensor:
@@ -450,7 +450,7 @@ def fn(x: torch.Tensor) -> torch.Tensor:
         self.assertIn("Unknown argument", str(cm.exception))
 
     def test_specialize_args_chaining(self):
-        """Test that chained specialize_args calls merge specializations."""
+        """Test that chained specialize_args() calls merge specializations."""
 
         @helion.kernel(autotune_effort="none", static_shapes=False)
         def fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
@@ -464,13 +464,13 @@ def fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         x = torch.randn([37, 64], device=DEVICE)
         y = torch.randn([48, 127], device=DEVICE)
 
-        # First, run WITHOUT specialize_args - dimensions should NOT be constants
+        # First, run WITHOUT specialize_args() - dimensions should NOT be constants
         code_no_spec, result_no_spec = code_and_output(fn, (x, y), block_sizes=[16, 16])
         torch.testing.assert_close(result_no_spec, x * 127)
         self.assertNotIn("37", code_no_spec)  # x dim 0 should NOT be specialized
        self.assertNotIn("127", code_no_spec)  # y dim 1 should NOT be specialized
 
-        # Now, chain two specialize_args calls - both should be preserved
+        # Now, chain two specialize_args() calls - both should be preserved
         chained = fn.specialize_args(x=[0]).specialize_args(y=[1])
 
         code, result = code_and_output(chained, (x, y), block_sizes=[16, 16])
@@ -485,6 +485,100 @@ def fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         y2 = torch.randn([48, 256], device=DEVICE)  # different dim 1
         self.assertIsNot(chained.bind((x, y)), chained.bind((x2, y2)))
 
+    def test_specialize_args_does_not_mutate_original(self):
+        """
+        Test that specialize_args() returns a new kernel and does not mutate the original.
+        This test explicitly verifies:
+        1. Calling original kernel before specialize_args() works normally
+        2. specialize_args() returns a different kernel object
+        3. Original kernel remains unspecialized after specialize_args() is called
+        4. Both kernels produce correct results independently
+        """
+        config = helion.Config(block_sizes=[16, 16])
+
+        @helion.kernel(config=config, static_shapes=False)
+        def kernel_fn(x: torch.Tensor) -> torch.Tensor:
+            m, n = x.size()
+            out = torch.empty_like(x)
+            for tile in hl.tile(x.size()):
+                out[tile] = x[tile] * 2.0
+            return out
+
+        x = torch.randn([64, 128], device=DEVICE)
+
+        # Step 1: Call original kernel BEFORE specialize_args()
+        code_before, result_before = code_and_output(
+            kernel_fn, (x,), block_sizes=[16, 16]
+        )
+        torch.testing.assert_close(result_before, x * 2.0)
+        # Original should NOT have specialized dimensions
+        self.assertNotIn("64", code_before)
+        self.assertNotIn("128", code_before)
+
+        # Step 2: Create specialized version
+        specialized_kernel_fn = kernel_fn.specialize_args(x=[0, 1])
+
+        # Verify it's a different kernel object
+        self.assertIsNot(kernel_fn, specialized_kernel_fn)
+
+        # Step 3: Call specialized kernel
+        code_spec, result_spec = code_and_output(
+            specialized_kernel_fn, (x,), block_sizes=[16, 16]
+        )
+        torch.testing.assert_close(result_spec, x * 2.0)
+        # Specialized should have constant dimensions
+        self.assertIn("64", code_spec)
+        self.assertIn("128", code_spec)
+
+        # Step 4: Call original kernel AFTER specialize_args() - should still be unspecialized
+        kernel_fn.reset()  # Clear cache to force recompilation
+        code_after, result_after = code_and_output(
+            kernel_fn, (x,), block_sizes=[16, 16]
+        )
+        torch.testing.assert_close(result_after, x * 2.0)
+        # Original should STILL NOT have specialized dimensions
+        self.assertNotIn("64", code_after)
+        self.assertNotIn("128", code_after)
+
+        # Verify that specialize_args() creates a true copy without shared mutable state.
+        mutable_attrs = [
+            "_bound_kernels",
+            "_specialize_extra",
+            "_specialized_args",
+            "_arg_name_to_index",
+            "_annotations",
+        ]
+        for attr in mutable_attrs:
+            self.assertIsNot(
+                getattr(kernel_fn, attr),
+                getattr(specialized_kernel_fn, attr),
+                f"Attribute '{attr}' is shared between original and specialized kernel",
+            )
+
+        # These objects are shared between original and specialized kernel.
+        self.assertIs(
+            kernel_fn.settings,
+            specialized_kernel_fn.settings,
+        )
+        # Config objects inside the configs list are shared (list itself is copied)
+        self.assertIsNot(
+            kernel_fn.configs,
+            specialized_kernel_fn.configs,
+        )
+        for i, orig_config in enumerate(kernel_fn.configs):
+            self.assertIs(
+                orig_config,
+                specialized_kernel_fn.configs[i],
+            )
+        self.assertIs(
+            kernel_fn.fn,
+            specialized_kernel_fn.fn,
+        )
+        self.assertIs(
+            kernel_fn._key_fn,
+            specialized_kernel_fn._key_fn,
+        )
+
 
 if __name__ == "__main__":
     unittest.main()
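
The identity assertions at the end of the new test pin down a specific copy discipline: specialize_args() shallow-copies the kernel, shares the pieces that never change per kernel (settings, the wrapped fn, the individual Config objects), and gives the copy fresh mutable containers so nothing can leak back into the original. The snippet below is only a generic illustration of that discipline, not helion's actual implementation; the KernelLike class and the union-merge behavior are assumptions, with attribute names borrowed from the assertions above.

    import copy
    from typing import Any


    class KernelLike:
        """Illustrative stand-in for a kernel object (hypothetical, not helion's class)."""

        def __init__(self, fn: Any, settings: Any, configs: list[Any]) -> None:
            self.fn = fn                  # shared with copies: the wrapped function
            self.settings = settings      # shared with copies
            self.configs = list(configs)  # list copied per kernel, Config objects shared
            self._bound_kernels: dict[Any, Any] = {}  # per-kernel cache, never shared
            self._specialized_args: dict[str, tuple[int, ...]] = {}

        def specialize_args(self, **kwargs: list[int]) -> "KernelLike":
            # Shallow-copy, then replace every mutable container so the new kernel
            # cannot mutate state owned by the original.
            new = copy.copy(self)
            new.configs = list(self.configs)
            new._bound_kernels = {}
            new._specialized_args = dict(self._specialized_args)
            for name, dims in kwargs.items():
                # Merge with any existing specializations for this argument (union).
                merged = set(new._specialized_args.get(name, ())) | set(dims)
                new._specialized_args[name] = tuple(sorted(merged))
            return new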
