@@ -334,7 +334,7 @@ class TestSpecializeArgs(RefEagerTestBase, TestCase):
     maxDiff = 163842

     def test_specialize_args(self):
-        """Test specialize_args: multiple tensors, multiple dims, negative indexing."""
+        """Test specialize_args(): multiple tensors, multiple dims, negative indexing."""

         @helion.kernel(autotune_effort="none", static_shapes=False)
         def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
@@ -352,7 +352,7 @@ def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         x = torch.randn([m, k], device=DEVICE, dtype=torch.float16)
         y = torch.randn([k, n], device=DEVICE, dtype=torch.float16)

-        # First, run WITHOUT specialize_args - dimensions should NOT be constants
+        # First, run WITHOUT specialize_args() - dimensions should NOT be constants
         code_no_spec, result_no_spec = code_and_output(
             matmul,
             (x, y),
@@ -363,7 +363,7 @@ def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         self.assertNotIn("128", code_no_spec)  # x dim -1 = k should NOT be specialized
         self.assertNotIn("56", code_no_spec)  # y dim 1 = n should NOT be specialized

-        # Now, run WITH specialize_args - dimensions SHOULD be constants
+        # Now, run WITH specialize_args() - dimensions SHOULD be constants
         code, result = code_and_output(
             matmul.specialize_args(x=[0, -1], y=[1]),
             (x, y),
@@ -386,7 +386,7 @@ def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         )

     def test_specialize_args_and_hl_specialize(self):
-        """Test that external specialize_args and internal hl.specialize form a union."""
+        """Test that external specialize_args() and internal hl.specialize() form a union."""

         @helion.kernel(autotune_effort="none", static_shapes=False)
         def dual_specialize(x: torch.Tensor) -> torch.Tensor:
@@ -399,7 +399,7 @@ def dual_specialize(x: torch.Tensor) -> torch.Tensor:

         x = torch.randn([320, 640], device=DEVICE)

-        # First, run WITHOUT external specialize_args - only dim 0 should be specialized
+        # First, run WITHOUT external specialize_args() - only dim 0 should be specialized
         code_no_spec, result_no_spec = code_and_output(
             dual_specialize,
             (x,),
@@ -409,7 +409,7 @@ def dual_specialize(x: torch.Tensor) -> torch.Tensor:
         self.assertIn("320", code_no_spec)  # dim 0 from internal specialize
         self.assertNotIn("640", code_no_spec)  # dim 1 should NOT be specialized

-        # Now, run WITH external specialize_args on dim -1 (dim 1)
+        # Now, run WITH external specialize_args() on dim -1 (dim 1)
         # Result: both dim 0 AND dim 1 are specialized (union)
         code, result = code_and_output(
             dual_specialize.specialize_args(x=[-1]),
@@ -429,7 +429,7 @@ def dual_specialize(x: torch.Tensor) -> torch.Tensor:

     @skipIfRefEager("Error checking not available in ref eager mode")
     def test_specialize_args_errors(self):
-        """Test error handling for invalid specialize_args usage."""
+        """Test error handling for invalid specialize_args() usage."""

         @helion.kernel(autotune_effort="none", static_shapes=False)
         def fn(x: torch.Tensor) -> torch.Tensor:
@@ -450,7 +450,7 @@ def fn(x: torch.Tensor) -> torch.Tensor:
         self.assertIn("Unknown argument", str(cm.exception))

     def test_specialize_args_chaining(self):
-        """Test that chained specialize_args calls merge specializations."""
+        """Test that chained specialize_args() calls merge specializations."""

         @helion.kernel(autotune_effort="none", static_shapes=False)
         def fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
@@ -464,13 +464,13 @@ def fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         x = torch.randn([37, 64], device=DEVICE)
         y = torch.randn([48, 127], device=DEVICE)

-        # First, run WITHOUT specialize_args - dimensions should NOT be constants
+        # First, run WITHOUT specialize_args() - dimensions should NOT be constants
         code_no_spec, result_no_spec = code_and_output(fn, (x, y), block_sizes=[16, 16])
         torch.testing.assert_close(result_no_spec, x * 127)
         self.assertNotIn("37", code_no_spec)  # x dim 0 should NOT be specialized
         self.assertNotIn("127", code_no_spec)  # y dim 1 should NOT be specialized

-        # Now, chain two specialize_args calls - both should be preserved
+        # Now, chain two specialize_args() calls - both should be preserved
         chained = fn.specialize_args(x=[0]).specialize_args(y=[1])

         code, result = code_and_output(chained, (x, y), block_sizes=[16, 16])
@@ -485,6 +485,92 @@ def fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         y2 = torch.randn([48, 256], device=DEVICE)  # different dim 1
         self.assertIsNot(chained.bind((x, y)), chained.bind((x2, y2)))

+    def test_specialize_args_does_not_mutate_original(self):
+        """
+        Test that specialize_args() returns a new kernel and does not mutate the original.
+        This test explicitly verifies:
+        1. Calling original kernel before specialize_args() works normally
+        2. specialize_args() returns a different kernel object
+        3. Original kernel remains unspecialized after specialize_args() is called
+        4. Both kernels produce correct results independently
+        """
+        config = helion.Config(block_sizes=[16, 16])
+
+        @helion.kernel(config=config, static_shapes=False)
+        def kernel_fn(x: torch.Tensor) -> torch.Tensor:
+            m, n = x.size()
+            out = torch.empty_like(x)
+            for tile in hl.tile(x.size()):
+                out[tile] = x[tile] * 2.0
+            return out
+
+        x = torch.randn([64, 128], device=DEVICE)
+
+        # Step 1: Call original kernel BEFORE specialize_args()
+        code_before, result_before = code_and_output(
+            kernel_fn, (x,), block_sizes=[16, 16]
+        )
+        torch.testing.assert_close(result_before, x * 2.0)
+        # Original should NOT have specialized dimensions
+        self.assertNotIn("64", code_before)
+        self.assertNotIn("128", code_before)
+
+        # Step 2: Create specialized version
+        specialized_kernel_fn = kernel_fn.specialize_args(x=[0, 1])
+
+        # Verify it's a different kernel object
+        self.assertIsNot(kernel_fn, specialized_kernel_fn)
+
+        # Step 3: Call specialized kernel
+        code_spec, result_spec = code_and_output(
+            specialized_kernel_fn, (x,), block_sizes=[16, 16]
+        )
+        torch.testing.assert_close(result_spec, x * 2.0)
+        # Specialized should have constant dimensions
+        self.assertIn("64", code_spec)
+        self.assertIn("128", code_spec)
+
+        # Step 4: Call original kernel AFTER specialize_args() - should still be unspecialized
+        kernel_fn.reset()  # Clear cache to force recompilation
+        code_after, result_after = code_and_output(
+            kernel_fn, (x,), block_sizes=[16, 16]
+        )
+        torch.testing.assert_close(result_after, x * 2.0)
+        # Original should STILL NOT have specialized dimensions
+        self.assertNotIn("64", code_after)
+        self.assertNotIn("128", code_after)
+
+        # Verify that specialize_args() creates a true copy without shared mutable state.
+        mutable_attrs = [
+            "_bound_kernels",
+            "_specialize_extra",
+            "_specialized_args",
+            "_arg_name_to_index",
+            "_annotations",
+        ]
+        for attr in mutable_attrs:
+            self.assertIsNot(
+                getattr(kernel_fn, attr),
+                getattr(specialized_kernel_fn, attr),
+                f"Attribute '{attr}' is shared between original and specialized kernel",
+            )
+
+        # These objects are currently shared between original and specialized kernel.
+        self.assertIs(
+            kernel_fn.settings,
+            specialized_kernel_fn.settings,
+        )
+        # Config objects inside the configs list are shared (list itself is copied)
+        self.assertIsNot(
+            kernel_fn.configs,
+            specialized_kernel_fn.configs,
+        )
+        for i, orig_config in enumerate(kernel_fn.configs):
+            self.assertIs(
+                orig_config,
+                specialized_kernel_fn.configs[i],
+            )
+

 if __name__ == "__main__":
     unittest.main()
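A minimal standalone sketch of the specialize_args() usage exercised above, assuming a CUDA device and that the kernel imports match the test file's (helion, helion.language as hl); illustrative only, not part of the diff:

import torch
import helion
import helion.language as hl

# Supplying a Config up front skips autotuning, mirroring the test setup above.
@helion.kernel(config=helion.Config(block_sizes=[16, 16]), static_shapes=False)
def double(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    for tile in hl.tile(x.size()):
        out[tile] = x[tile] * 2.0
    return out

# specialize_args() returns a NEW kernel; `double` itself stays unspecialized.
specialized = double.specialize_args(x=[0, -1])  # bake dim 0 and the last dim into the generated code

x = torch.randn([64, 128], device="cuda")
torch.testing.assert_close(specialized(x), x * 2.0)  # specialized copy computes the same result
torch.testing.assert_close(double(x), x * 2.0)       # original kernel is unaffected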