Commit 4a86965 ("test")

1 parent: a2f5ed1
File tree: 1 file changed, +161 / -0

test/test_specialize.py

Lines changed: 161 additions & 0 deletions
@@ -327,5 +327,166 @@ def foo(x: torch.Tensor, bitshift: tuple[int, int]) -> torch.Tensor:
         self.assertExpectedJournal(code)
 
 
+@skipIfCpu("needs to be debugged")
+class TestSpecializeArgs(RefEagerTestBase, TestCase):
+    """Tests for kernel.specialize_args() external specialization API."""
+
+    maxDiff = 163842
+
+    def test_specialize_args(self):
+        """Test specialize_args: multiple tensors, multiple dims, negative indexing."""
+
+        @helion.kernel(autotune_effort="none", static_shapes=False)
+        def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            m, k = x.size()
+            k2, n = y.size()
+            out = torch.empty([m, n], device=x.device, dtype=x.dtype)
+            for tile_m, tile_n in hl.tile([m, n]):
+                acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+                for tile_k in hl.tile(k):
+                    acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
+                out[tile_m, tile_n] = acc.to(x.dtype)
+            return out
+
+        m, k, n = 64, 128, 56
+        x = torch.randn([m, k], device=DEVICE, dtype=torch.float16)
+        y = torch.randn([k, n], device=DEVICE, dtype=torch.float16)
+
+        # First, run WITHOUT specialize_args - dimensions should NOT be constants
+        code_no_spec, result_no_spec = code_and_output(
+            matmul,
+            (x, y),
+            block_sizes=[32, 32, 32],
+        )
+        torch.testing.assert_close(result_no_spec, x @ y, rtol=1e-2, atol=1e-2)
+        self.assertNotIn("64", code_no_spec)  # x dim 0 = m should NOT be specialized
+        self.assertNotIn("128", code_no_spec)  # x dim -1 = k should NOT be specialized
+        self.assertNotIn("56", code_no_spec)  # y dim 1 = n should NOT be specialized
+
+        # Now, run WITH specialize_args - dimensions SHOULD be constants
+        code, result = code_and_output(
+            matmul.specialize_args(x=[0, -1], y=[1]),
+            (x, y),
+            block_sizes=[32, 32, 32],
+        )
+        torch.testing.assert_close(result, x @ y, rtol=1e-2, atol=1e-2)
+        self.assertIn("64", code)  # x dim 0 = m
+        self.assertIn("128", code)  # x dim -1 = k
+        self.assertIn("56", code)  # y dim 1 = n
+        self.assertExpectedJournal(code)
+
+        # Verify cache behavior: same specialized values hit cache
+        specialized_kernel = matmul.specialize_args(x=[0, -1], y=[1])
+        self.assertIs(
+            specialized_kernel.bind((x, y)), specialized_kernel.bind((x, y))
+        )
+        # Verify cache behavior: different specialized values produce different bound kernels
+        x2 = torch.randn([48, 96], device=DEVICE, dtype=torch.float16)
+        y2 = torch.randn([96, 24], device=DEVICE, dtype=torch.float16)
+        self.assertIsNot(
+            specialized_kernel.bind((x, y)), specialized_kernel.bind((x2, y2))
+        )
+
+    def test_specialize_args_and_hl_specialize(self):
+        """Test that external specialize_args and internal hl.specialize form a union."""
+
+        @helion.kernel(autotune_effort="none", static_shapes=False)
+        def dual_specialize(x: torch.Tensor) -> torch.Tensor:
+            # Internal specialize on dim 0
+            hl.specialize(x.size(0))
+            out = torch.empty_like(x)
+            for tile in hl.tile(x.size()):
+                out[tile] = x[tile] * 2
+            return out
+
+        x = torch.randn([320, 640], device=DEVICE)
+
+        # First, run WITHOUT external specialize_args - only dim 0 should be specialized
+        code_no_spec, result_no_spec = code_and_output(
+            dual_specialize,
+            (x,),
+            block_sizes=[16, 16],
+        )
+        torch.testing.assert_close(result_no_spec, x * 2)
+        self.assertIn("320", code_no_spec)  # dim 0 from internal specialize
+        self.assertNotIn("640", code_no_spec)  # dim 1 should NOT be specialized
+
+        # Now, run WITH external specialize_args on dim -1 (dim 1)
+        # Result: both dim 0 AND dim 1 are specialized (union)
+        code, result = code_and_output(
+            dual_specialize.specialize_args(x=[-1]),
+            (x,),
+            block_sizes=[16, 16],
+        )
+        torch.testing.assert_close(result, x * 2)
+        # Both dimensions should appear as constants
+        self.assertIn("320", code)  # dim 0 from internal specialize
+        self.assertIn("640", code)  # dim 1 from external specialize
+        self.assertExpectedJournal(code)
+
+        # Verify cache behavior: changing dim 1 (external) produces different bound kernel
+        x2 = torch.randn([320, 128], device=DEVICE)  # same dim 0, different dim 1
+        specialized_kernel = dual_specialize.specialize_args(x=[-1])
+        self.assertIsNot(specialized_kernel.bind((x,)), specialized_kernel.bind((x2,)))
+
+    @skipIfRefEager("Error checking not available in ref eager mode")
+    def test_specialize_args_errors(self):
+        """Test error handling for invalid specialize_args usage."""
+
+        @helion.kernel(autotune_effort="none", static_shapes=False)
+        def fn(x: torch.Tensor) -> torch.Tensor:
+            out = torch.empty_like(x)
+            for tile in hl.tile(x.size()):
+                out[tile] = x[tile]
+            return out
+
+        x = torch.randn([32, 64], device=DEVICE)  # 2D tensor
+
+        # Error: dim out of range
+        with self.assertRaises((IndexError, ValueError)):
+            fn.specialize_args(x=[5])(x)
+
+        # Error: unknown argument name
+        with self.assertRaises(ValueError) as cm:
+            fn.specialize_args(z=[-1])
+        self.assertIn("Unknown argument", str(cm.exception))
+
+    def test_specialize_args_chaining(self):
+        """Test that chained specialize_args calls merge specializations."""
+
+        @helion.kernel(autotune_effort="none", static_shapes=False)
+        def fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            m, n = x.size()
+            p = y.size(1)  # use y's dim 1 as a scalar
+            out = x.new_empty([m, n])
+            for tile_m, tile_n in hl.tile([m, n]):
+                out[tile_m, tile_n] = x[tile_m, tile_n] * p
+            return out
+
+        x = torch.randn([37, 64], device=DEVICE)
+        y = torch.randn([48, 127], device=DEVICE)
+
+        # First, run WITHOUT specialize_args - dimensions should NOT be constants
+        code_no_spec, result_no_spec = code_and_output(fn, (x, y), block_sizes=[16, 16])
+        torch.testing.assert_close(result_no_spec, x * 127)
+        self.assertNotIn("37", code_no_spec)  # x dim 0 should NOT be specialized
+        self.assertNotIn("127", code_no_spec)  # y dim 1 should NOT be specialized
+
+        # Now, chain two specialize_args calls - both should be preserved
+        chained = fn.specialize_args(x=[0]).specialize_args(y=[1])
+
+        code, result = code_and_output(chained, (x, y), block_sizes=[16, 16])
+        torch.testing.assert_close(result, x * 127)
+        # Both specializations should be present
+        self.assertIn("37", code)  # x dim 0
+        self.assertIn("127", code)  # y dim 1
+        self.assertExpectedJournal(code)
+
+        # Verify cache behavior: changing specialized values produces different bound kernels
+        x2 = torch.randn([48, 64], device=DEVICE)  # different dim 0
+        y2 = torch.randn([48, 256], device=DEVICE)  # different dim 1
+        self.assertIsNot(chained.bind((x, y)), chained.bind((x2, y2)))
+
+
 if __name__ == "__main__":
     unittest.main()
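For context, a minimal sketch of how the specialize_args() entry point exercised by these tests would be driven outside the test harness. It only reuses calls visible in the diff above (helion.kernel, hl.tile, specialize_args, bind); the import path helion.language as hl and the "cuda" device string are assumptions standing in for the test module's hl alias and DEVICE constant, not confirmed by this commit.

import torch
import helion
import helion.language as hl  # assumed import path for the `hl` alias used in the tests

@helion.kernel(autotune_effort="none", static_shapes=False)
def double(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    for tile in hl.tile(x.size()):
        out[tile] = x[tile] * 2
    return out

x = torch.randn([320, 640], device="cuda")  # "cuda" stands in for the tests' DEVICE

# Specialize on x's last dimension so the generated code treats 640 as a constant.
specialized = double.specialize_args(x=[-1])
result = specialized(x)

# Same specialized shape reuses the cached bound kernel; a different last
# dimension would bind a new one (mirroring the cache assertions above).
assert specialized.bind((x,)) is specialized.bind((x,))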
