Commit 069ae43

Commit message: test
1 parent: a2cb804

File tree: 1 file changed (+159 -0 lines)

test/test_specialize.py

Lines changed: 159 additions & 0 deletions
@@ -327,5 +327,164 @@ def foo(x: torch.Tensor, bitshift: tuple[int, int]) -> torch.Tensor:
         self.assertExpectedJournal(code)


+@skipIfCpu("needs to be debugged")
+class TestSpecializeArgs(RefEagerTestBase, TestCase):
+    """Tests for kernel.specialize_args() external specialization API."""
+
+    maxDiff = 163842
+
+    def test_specialize_args(self):
+        """Test specialize_args: multiple tensors, multiple dims, negative indexing."""
+
+        @helion.kernel(autotune_effort="none", static_shapes=False)
+        def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            m, k = x.size()
+            k2, n = y.size()
+            out = torch.empty([m, n], device=x.device, dtype=x.dtype)
+            for tile_m, tile_n in hl.tile([m, n]):
+                acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+                for tile_k in hl.tile(k):
+                    acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
+                out[tile_m, tile_n] = acc.to(x.dtype)
+            return out
+
+        m, k, n = 64, 128, 56
+        x = torch.randn([m, k], device=DEVICE, dtype=torch.float16)
+        y = torch.randn([k, n], device=DEVICE, dtype=torch.float16)
+
+        # First, run WITHOUT specialize_args - dimensions should NOT be constants
+        code_no_spec, result_no_spec = code_and_output(
+            matmul,
+            (x, y),
+            block_sizes=[32, 32, 32],
+        )
+        torch.testing.assert_close(result_no_spec, x @ y, rtol=1e-2, atol=1e-2)
+        self.assertNotIn("64", code_no_spec)  # x dim 0 = m should NOT be specialized
+        self.assertNotIn("128", code_no_spec)  # x dim -1 = k should NOT be specialized
+        self.assertNotIn("56", code_no_spec)  # y dim 1 = n should NOT be specialized
+
+        # Now, run WITH specialize_args - dimensions SHOULD be constants
+        code, result = code_and_output(
+            matmul.specialize_args(x=[0, -1], y=[1]),
+            (x, y),
+            block_sizes=[32, 32, 32],
+        )
+        torch.testing.assert_close(result, x @ y, rtol=1e-2, atol=1e-2)
+        self.assertIn("64", code)  # x dim 0 = m
+        self.assertIn("128", code)  # x dim -1 = k
+        self.assertIn("56", code)  # y dim 1 = n
+        self.assertExpectedJournal(code)
+
+        # Verify cache behavior: same specialized values hit cache
+        specialized_kernel = matmul.specialize_args(x=[0, -1], y=[1])
+        self.assertIs(specialized_kernel.bind((x, y)), specialized_kernel.bind((x, y)))
+        # Verify cache behavior: different specialized values produce different bound kernels
+        x2 = torch.randn([48, 96], device=DEVICE, dtype=torch.float16)
+        y2 = torch.randn([96, 24], device=DEVICE, dtype=torch.float16)
+        self.assertIsNot(
+            specialized_kernel.bind((x, y)), specialized_kernel.bind((x2, y2))
+        )
+
+    def test_specialize_args_and_hl_specialize(self):
+        """Test that external specialize_args and internal hl.specialize form a union."""
+
+        @helion.kernel(autotune_effort="none", static_shapes=False)
+        def dual_specialize(x: torch.Tensor) -> torch.Tensor:
+            # Internal specialize on dim 0
+            hl.specialize(x.size(0))
+            out = torch.empty_like(x)
+            for tile in hl.tile(x.size()):
+                out[tile] = x[tile] * 2
+            return out
+
+        x = torch.randn([320, 640], device=DEVICE)
+
+        # First, run WITHOUT external specialize_args - only dim 0 should be specialized
+        code_no_spec, result_no_spec = code_and_output(
+            dual_specialize,
+            (x,),
+            block_sizes=[16, 16],
+        )
+        torch.testing.assert_close(result_no_spec, x * 2)
+        self.assertIn("320", code_no_spec)  # dim 0 from internal specialize
+        self.assertNotIn("640", code_no_spec)  # dim 1 should NOT be specialized
+
+        # Now, run WITH external specialize_args on dim -1 (dim 1)
+        # Result: both dim 0 AND dim 1 are specialized (union)
+        code, result = code_and_output(
+            dual_specialize.specialize_args(x=[-1]),
+            (x,),
+            block_sizes=[16, 16],
+        )
+        torch.testing.assert_close(result, x * 2)
+        # Both dimensions should appear as constants
+        self.assertIn("320", code)  # dim 0 from internal specialize
+        self.assertIn("640", code)  # dim 1 from external specialize
+        self.assertExpectedJournal(code)
+
+        # Verify cache behavior: changing dim 1 (external) produces different bound kernel
+        x2 = torch.randn([320, 128], device=DEVICE)  # same dim 0, different dim 1
+        specialized_kernel = dual_specialize.specialize_args(x=[-1])
+        self.assertIsNot(specialized_kernel.bind((x,)), specialized_kernel.bind((x2,)))
+
+    @skipIfRefEager("Error checking not available in ref eager mode")
+    def test_specialize_args_errors(self):
+        """Test error handling for invalid specialize_args usage."""
+
+        @helion.kernel(autotune_effort="none", static_shapes=False)
+        def fn(x: torch.Tensor) -> torch.Tensor:
+            out = torch.empty_like(x)
+            for tile in hl.tile(x.size()):
+                out[tile] = x[tile]
+            return out
+
+        x = torch.randn([32, 64], device=DEVICE)  # 2D tensor
+
+        # Error: dim out of range
+        with self.assertRaises((IndexError, ValueError)):
+            fn.specialize_args(x=[5])(x)
+
+        # Error: unknown argument name
+        with self.assertRaises(ValueError) as cm:
+            fn.specialize_args(z=[-1])
+        self.assertIn("Unknown argument", str(cm.exception))
+
+    def test_specialize_args_chaining(self):
+        """Test that chained specialize_args calls merge specializations."""
+
+        @helion.kernel(autotune_effort="none", static_shapes=False)
+        def fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            m, n = x.size()
+            p = y.size(1)  # use y's dim 1 as a scalar
+            out = x.new_empty([m, n])
+            for tile_m, tile_n in hl.tile([m, n]):
+                out[tile_m, tile_n] = x[tile_m, tile_n] * p
+            return out
+
+        x = torch.randn([37, 64], device=DEVICE)
+        y = torch.randn([48, 127], device=DEVICE)
+
+        # First, run WITHOUT specialize_args - dimensions should NOT be constants
+        code_no_spec, result_no_spec = code_and_output(fn, (x, y), block_sizes=[16, 16])
+        torch.testing.assert_close(result_no_spec, x * 127)
+        self.assertNotIn("37", code_no_spec)  # x dim 0 should NOT be specialized
+        self.assertNotIn("127", code_no_spec)  # y dim 1 should NOT be specialized
+
+        # Now, chain two specialize_args calls - both should be preserved
+        chained = fn.specialize_args(x=[0]).specialize_args(y=[1])
+
+        code, result = code_and_output(chained, (x, y), block_sizes=[16, 16])
+        torch.testing.assert_close(result, x * 127)
+        # Both specializations should be present
+        self.assertIn("37", code)  # x dim 0
+        self.assertIn("127", code)  # y dim 1
+        self.assertExpectedJournal(code)
+
+        # Verify cache behavior: changing specialized values produces different bound kernels
+        x2 = torch.randn([48, 64], device=DEVICE)  # different dim 0
+        y2 = torch.randn([48, 256], device=DEVICE)  # different dim 1
+        self.assertIsNot(chained.bind((x, y)), chained.bind((x2, y2)))
+
+
 if __name__ == "__main__":
     unittest.main()
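
For orientation, here is a minimal usage sketch of the specialize_args API these tests exercise. It is not part of the commit: the import paths, the device string, and the scale kernel below are assumptions made for illustration; only the specialize_args and bind calls and the decorator arguments mirror the tests above.

import torch
import helion
import helion.language as hl  # assumed import path for the hl helpers used in the tests


@helion.kernel(autotune_effort="none", static_shapes=False)  # same settings as the tests
def scale(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical elementwise kernel, used only to demonstrate specialize_args.
    out = torch.empty_like(x)
    for tile in hl.tile(x.size()):
        out[tile] = x[tile] * 2
    return out


x = torch.randn([320, 640], device="cuda")  # device is an assumption

# Treat x's last dimension (640 here) as a compile-time constant,
# mirroring dual_specialize.specialize_args(x=[-1]) in the tests above.
specialized = scale.specialize_args(x=[-1])
result = specialized(x)

# Cache behavior shown by the tests: binding the same specialized sizes reuses
# one bound kernel; a different size in a specialized dim produces a new one.
assert specialized.bind((x,)) is specialized.bind((x,))
x2 = torch.randn([320, 128], device="cuda")
assert specialized.bind((x,)) is not specialized.bind((x2,))

As the chaining and hl.specialize tests show, repeated specialize_args calls and in-kernel hl.specialize calls merge, so the specialized dimensions end up being the union of both.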
