@@ -327,5 +327,164 @@ def foo(x: torch.Tensor, bitshift: tuple[int, int]) -> torch.Tensor:
         self.assertExpectedJournal(code)


+@skipIfCpu("needs to be debugged")
+class TestSpecializeArgs(RefEagerTestBase, TestCase):
+    """Tests for kernel.specialize_args() external specialization API."""
+
+    maxDiff = 163842
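+
+    # Behaviour assumed throughout these tests: kernel.specialize_args(name=[dims])
+    # returns a new kernel whose listed tensor dimensions are treated as
+    # compile-time constants, so the specialized sizes are expected to appear as
+    # literals in the generated code while unspecialized sizes stay symbolic.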
+
+    def test_specialize_args(self):
+        """Test specialize_args: multiple tensors, multiple dims, negative indexing."""
+
+        @helion.kernel(autotune_effort="none", static_shapes=False)
+        def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            m, k = x.size()
+            k2, n = y.size()
+            out = torch.empty([m, n], device=x.device, dtype=x.dtype)
+            for tile_m, tile_n in hl.tile([m, n]):
+                acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+                for tile_k in hl.tile(k):
+                    acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
+                out[tile_m, tile_n] = acc.to(x.dtype)
+            return out
+
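+        # static_shapes=False above is assumed to keep tensor sizes symbolic unless
+        # explicitly specialized, which is what the assertNotIn checks below rely on;
+        # 64/128/56 are distinctive values that are easy to grep for in the code.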
+        m, k, n = 64, 128, 56
+        x = torch.randn([m, k], device=DEVICE, dtype=torch.float16)
+        y = torch.randn([k, n], device=DEVICE, dtype=torch.float16)
+
+        # First, run WITHOUT specialize_args - dimensions should NOT be constants
+        code_no_spec, result_no_spec = code_and_output(
+            matmul,
+            (x, y),
+            block_sizes=[32, 32, 32],
+        )
+        torch.testing.assert_close(result_no_spec, x @ y, rtol=1e-2, atol=1e-2)
+        self.assertNotIn("64", code_no_spec)  # x dim 0 = m should NOT be specialized
+        self.assertNotIn("128", code_no_spec)  # x dim -1 = k should NOT be specialized
+        self.assertNotIn("56", code_no_spec)  # y dim 1 = n should NOT be specialized
+
+        # Now, run WITH specialize_args - dimensions SHOULD be constants
+        code, result = code_and_output(
+            matmul.specialize_args(x=[0, -1], y=[1]),
+            (x, y),
+            block_sizes=[32, 32, 32],
+        )
+        torch.testing.assert_close(result, x @ y, rtol=1e-2, atol=1e-2)
+        self.assertIn("64", code)  # x dim 0 = m
+        self.assertIn("128", code)  # x dim -1 = k
+        self.assertIn("56", code)  # y dim 1 = n
+        self.assertExpectedJournal(code)
+
+        # Verify cache behavior: same specialized values hit cache
+        specialized_kernel = matmul.specialize_args(x=[0, -1], y=[1])
+        self.assertIs(specialized_kernel.bind((x, y)), specialized_kernel.bind((x, y)))
+        # Verify cache behavior: different specialized values produce different bound kernels
+        x2 = torch.randn([48, 96], device=DEVICE, dtype=torch.float16)
+        y2 = torch.randn([96, 24], device=DEVICE, dtype=torch.float16)
+        self.assertIsNot(
+            specialized_kernel.bind((x, y)), specialized_kernel.bind((x2, y2))
+        )
+
+    def test_specialize_args_and_hl_specialize(self):
+        """Test that external specialize_args and internal hl.specialize form a union."""
+
+        @helion.kernel(autotune_effort="none", static_shapes=False)
+        def dual_specialize(x: torch.Tensor) -> torch.Tensor:
+            # Internal specialize on dim 0
+            hl.specialize(x.size(0))
+            out = torch.empty_like(x)
+            for tile in hl.tile(x.size()):
+                out[tile] = x[tile] * 2
+            return out
+
+        x = torch.randn([320, 640], device=DEVICE)
+
+        # First, run WITHOUT external specialize_args - only dim 0 should be specialized
+        code_no_spec, result_no_spec = code_and_output(
+            dual_specialize,
+            (x,),
+            block_sizes=[16, 16],
+        )
+        torch.testing.assert_close(result_no_spec, x * 2)
+        self.assertIn("320", code_no_spec)  # dim 0 from internal specialize
+        self.assertNotIn("640", code_no_spec)  # dim 1 should NOT be specialized
+
+        # Now, run WITH external specialize_args on dim -1 (dim 1)
+        # Result: both dim 0 AND dim 1 are specialized (union)
+        code, result = code_and_output(
+            dual_specialize.specialize_args(x=[-1]),
+            (x,),
+            block_sizes=[16, 16],
+        )
+        torch.testing.assert_close(result, x * 2)
+        # Both dimensions should appear as constants
+        self.assertIn("320", code)  # dim 0 from internal specialize
+        self.assertIn("640", code)  # dim 1 from external specialize
+        self.assertExpectedJournal(code)
+
+        # Verify cache behavior: changing dim 1 (external) produces different bound kernel
+        x2 = torch.randn([320, 128], device=DEVICE)  # same dim 0, different dim 1
+        specialized_kernel = dual_specialize.specialize_args(x=[-1])
+        self.assertIsNot(specialized_kernel.bind((x,)), specialized_kernel.bind((x2,)))
+
+    @skipIfRefEager("Error checking not available in ref eager mode")
+    def test_specialize_args_errors(self):
+        """Test error handling for invalid specialize_args usage."""
+
+        @helion.kernel(autotune_effort="none", static_shapes=False)
+        def fn(x: torch.Tensor) -> torch.Tensor:
+            out = torch.empty_like(x)
+            for tile in hl.tile(x.size()):
+                out[tile] = x[tile]
+            return out
+
+        x = torch.randn([32, 64], device=DEVICE)  # 2D tensor
+
+        # Error: dim out of range
+        with self.assertRaises((IndexError, ValueError)):
+            fn.specialize_args(x=[5])(x)
+
+        # Error: unknown argument name
+        with self.assertRaises(ValueError) as cm:
+            fn.specialize_args(z=[-1])
+        self.assertIn("Unknown argument", str(cm.exception))
+
+    def test_specialize_args_chaining(self):
+        """Test that chained specialize_args calls merge specializations."""
+
+        @helion.kernel(autotune_effort="none", static_shapes=False)
+        def fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            m, n = x.size()
+            p = y.size(1)  # use y's dim 1 as a scalar
+            out = x.new_empty([m, n])
+            for tile_m, tile_n in hl.tile([m, n]):
+                out[tile_m, tile_n] = x[tile_m, tile_n] * p
+            return out
+
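+        # p = y.size(1) is read as a plain scalar inside the kernel, so once y's
+        # dim 1 is specialized the literal 127 is expected to appear in the
+        # generated code (asserted below) and the output should equal x * 127.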
+        x = torch.randn([37, 64], device=DEVICE)
+        y = torch.randn([48, 127], device=DEVICE)
+
+        # First, run WITHOUT specialize_args - dimensions should NOT be constants
+        code_no_spec, result_no_spec = code_and_output(fn, (x, y), block_sizes=[16, 16])
+        torch.testing.assert_close(result_no_spec, x * 127)
+        self.assertNotIn("37", code_no_spec)  # x dim 0 should NOT be specialized
+        self.assertNotIn("127", code_no_spec)  # y dim 1 should NOT be specialized
+
+        # Now, chain two specialize_args calls - both should be preserved
+        chained = fn.specialize_args(x=[0]).specialize_args(y=[1])
+
+        code, result = code_and_output(chained, (x, y), block_sizes=[16, 16])
+        torch.testing.assert_close(result, x * 127)
+        # Both specializations should be present
+        self.assertIn("37", code)  # x dim 0
+        self.assertIn("127", code)  # y dim 1
+        self.assertExpectedJournal(code)
+
+        # Verify cache behavior: changing specialized values produces different bound kernels
+        x2 = torch.randn([48, 64], device=DEVICE)  # different dim 0
+        y2 = torch.randn([48, 256], device=DEVICE)  # different dim 1
+        self.assertIsNot(chained.bind((x, y)), chained.bind((x2, y2)))
+
+
 if __name__ == "__main__":
     unittest.main()