@@ -327,5 +327,166 @@ def foo(x: torch.Tensor, bitshift: tuple[int, int]) -> torch.Tensor:
         self.assertExpectedJournal(code)
 
 
+@skipIfCpu("needs to be debugged")
+class TestSpecializeArgs(RefEagerTestBase, TestCase):
+    """Tests for kernel.specialize_args() external specialization API."""
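+    # Behavior these tests assume for the external API: kernel.specialize_args(name=[dims])
+    # returns a new kernel whose listed tensor dims are treated as compile-time constants in
+    # the generated code; negative dims index from the end, chained calls and internal
+    # hl.specialize() calls are unioned, and bound-kernel caching keys on the specialized values.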
+
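+    # Raise unittest's diff-truncation limit so failing generated-code comparisons print in full.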
+    maxDiff = 163842
+
+    def test_specialize_args(self):
+        """Test specialize_args: multiple tensors, multiple dims, negative indexing."""
+
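+        # static_shapes=False should keep tensor sizes symbolic here, so shape literals only
+        # appear in the generated code when pulled in via specialize_args (asserted below).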
+        @helion.kernel(autotune_effort="none", static_shapes=False)
+        def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            m, k = x.size()
+            k2, n = y.size()
+            out = torch.empty([m, n], device=x.device, dtype=x.dtype)
+            for tile_m, tile_n in hl.tile([m, n]):
+                acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+                for tile_k in hl.tile(k):
+                    acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
+                out[tile_m, tile_n] = acc.to(x.dtype)
+            return out
+
+        m, k, n = 64, 128, 56
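+        # Distinct, easy-to-grep sizes so their presence/absence in the generated code can be asserted.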
+        x = torch.randn([m, k], device=DEVICE, dtype=torch.float16)
+        y = torch.randn([k, n], device=DEVICE, dtype=torch.float16)
+
+        # First, run WITHOUT specialize_args - dimensions should NOT be constants
+        code_no_spec, result_no_spec = code_and_output(
+            matmul,
+            (x, y),
+            block_sizes=[32, 32, 32],
+        )
+        torch.testing.assert_close(result_no_spec, x @ y, rtol=1e-2, atol=1e-2)
+        self.assertNotIn("64", code_no_spec)  # x dim 0 = m should NOT be specialized
+        self.assertNotIn("128", code_no_spec)  # x dim -1 = k should NOT be specialized
+        self.assertNotIn("56", code_no_spec)  # y dim 1 = n should NOT be specialized
+
+        # Now, run WITH specialize_args - dimensions SHOULD be constants
+        code, result = code_and_output(
+            matmul.specialize_args(x=[0, -1], y=[1]),
+            (x, y),
+            block_sizes=[32, 32, 32],
+        )
+        torch.testing.assert_close(result, x @ y, rtol=1e-2, atol=1e-2)
+        self.assertIn("64", code)  # x dim 0 = m
+        self.assertIn("128", code)  # x dim -1 = k
+        self.assertIn("56", code)  # y dim 1 = n
+        self.assertExpectedJournal(code)
+
+        # Verify cache behavior: same specialized values hit cache
+        specialized_kernel = matmul.specialize_args(x=[0, -1], y=[1])
+        self.assertIs(
+            specialized_kernel.bind((x, y)), specialized_kernel.bind((x, y))
+        )
+        # Verify cache behavior: different specialized values produce different bound kernels
+        x2 = torch.randn([48, 96], device=DEVICE, dtype=torch.float16)
+        y2 = torch.randn([96, 24], device=DEVICE, dtype=torch.float16)
+        self.assertIsNot(
+            specialized_kernel.bind((x, y)), specialized_kernel.bind((x2, y2))
+        )
+
+    def test_specialize_args_and_hl_specialize(self):
+        """Test that external specialize_args and internal hl.specialize form a union."""
+
+        @helion.kernel(autotune_effort="none", static_shapes=False)
+        def dual_specialize(x: torch.Tensor) -> torch.Tensor:
+            # Internal specialize on dim 0
+            hl.specialize(x.size(0))
+            out = torch.empty_like(x)
+            for tile in hl.tile(x.size()):
+                out[tile] = x[tile] * 2
+            return out
+
+        x = torch.randn([320, 640], device=DEVICE)
+
+        # First, run WITHOUT external specialize_args - only dim 0 should be specialized
+        code_no_spec, result_no_spec = code_and_output(
+            dual_specialize,
+            (x,),
+            block_sizes=[16, 16],
+        )
+        torch.testing.assert_close(result_no_spec, x * 2)
+        self.assertIn("320", code_no_spec)  # dim 0 from internal specialize
+        self.assertNotIn("640", code_no_spec)  # dim 1 should NOT be specialized
+
+        # Now, run WITH external specialize_args on dim -1 (dim 1)
+        # Result: both dim 0 AND dim 1 are specialized (union)
+        code, result = code_and_output(
+            dual_specialize.specialize_args(x=[-1]),
+            (x,),
+            block_sizes=[16, 16],
+        )
+        torch.testing.assert_close(result, x * 2)
+        # Both dimensions should appear as constants
+        self.assertIn("320", code)  # dim 0 from internal specialize
+        self.assertIn("640", code)  # dim 1 from external specialize
+        self.assertExpectedJournal(code)
+
+        # Verify cache behavior: changing dim 1 (external) produces different bound kernel
+        x2 = torch.randn([320, 128], device=DEVICE)  # same dim 0, different dim 1
+        specialized_kernel = dual_specialize.specialize_args(x=[-1])
+        self.assertIsNot(specialized_kernel.bind((x,)), specialized_kernel.bind((x2,)))
+
+    @skipIfRefEager("Error checking not available in ref eager mode")
+    def test_specialize_args_errors(self):
+        """Test error handling for invalid specialize_args usage."""
+
+        @helion.kernel(autotune_effort="none", static_shapes=False)
+        def fn(x: torch.Tensor) -> torch.Tensor:
+            out = torch.empty_like(x)
+            for tile in hl.tile(x.size()):
+                out[tile] = x[tile]
+            return out
+
+        x = torch.randn([32, 64], device=DEVICE)  # 2D tensor
+
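+        # As written below, an out-of-range dim is only expected to be rejected once the
+        # kernel is invoked with a concrete tensor, whereas an unknown argument name is
+        # expected to fail eagerly at specialize_args() time.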
+        # Error: dim out of range
+        with self.assertRaises((IndexError, ValueError)):
+            fn.specialize_args(x=[5])(x)
+
+        # Error: unknown argument name
+        with self.assertRaises(ValueError) as cm:
+            fn.specialize_args(z=[-1])
+        self.assertIn("Unknown argument", str(cm.exception))
+
+    def test_specialize_args_chaining(self):
+        """Test that chained specialize_args calls merge specializations."""
+
+        @helion.kernel(autotune_effort="none", static_shapes=False)
+        def fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            m, n = x.size()
+            p = y.size(1)  # use y's dim 1 as a scalar
+            out = x.new_empty([m, n])
+            for tile_m, tile_n in hl.tile([m, n]):
+                out[tile_m, tile_n] = x[tile_m, tile_n] * p
+            return out
+
+        x = torch.randn([37, 64], device=DEVICE)
+        y = torch.randn([48, 127], device=DEVICE)
+
+        # First, run WITHOUT specialize_args - dimensions should NOT be constants
+        code_no_spec, result_no_spec = code_and_output(fn, (x, y), block_sizes=[16, 16])
+        torch.testing.assert_close(result_no_spec, x * 127)
+        self.assertNotIn("37", code_no_spec)  # x dim 0 should NOT be specialized
+        self.assertNotIn("127", code_no_spec)  # y dim 1 should NOT be specialized
+
+        # Now, chain two specialize_args calls - both should be preserved
+        chained = fn.specialize_args(x=[0]).specialize_args(y=[1])
+
+        code, result = code_and_output(chained, (x, y), block_sizes=[16, 16])
+        torch.testing.assert_close(result, x * 127)
+        # Both specializations should be present
+        self.assertIn("37", code)  # x dim 0
+        self.assertIn("127", code)  # y dim 1
+        self.assertExpectedJournal(code)
+
+        # Verify cache behavior: changing specialized values produces different bound kernels
+        x2 = torch.randn([48, 64], device=DEVICE)  # different dim 0
+        y2 = torch.randn([48, 256], device=DEVICE)  # different dim 1
+        self.assertIsNot(chained.bind((x, y)), chained.bind((x2, y2)))
+
+
 if __name__ == "__main__":
     unittest.main()