
Commit 9a40f13

test
1 parent 601d7dd commit 9a40f13


2 files changed: +183 −0 lines changed


test/test_specialize.expected

Lines changed: 96 additions & 0 deletions
@@ -1,6 +1,102 @@
 This file is automatically generated by assertExpectedJournal calls in test_specialize.py.
 Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set.

+--- assertExpectedJournal(TestMarkStatic.test_mark_static)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_matmul(x, y, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
+    # src[test_specialize.py:N]: for tile_m, tile_n in hl.tile([m, n]):
+    num_blocks_0 = tl.cdiv(64, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < 64
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    mask_1 = indices_1 < 56
+    # src[test_specialize.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
+    # src[test_specialize.py:N]: for tile_k in hl.tile(k):
+    # src[test_specialize.py:N]: acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
+    symnode_0 = 128
+    for offset_2 in tl.range(0, symnode_0.to(tl.int32), _BLOCK_SIZE_2):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
+        mask_2 = indices_2 < symnode_0
+        acc_copy = acc
+        acc_copy_0 = acc_copy
+        # src[test_specialize.py:N]: acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
+        load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
+        load_1 = tl.load(y + (indices_2[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_2[:, None] & mask_1[None, :], other=0)
+        acc = tl.dot(tl.cast(load, tl.float16), tl.cast(load_1, tl.float16), acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32)
+    # src[test_specialize.py:N]: out[tile_m, tile_n] = acc.to(x.dtype)
+    v_0 = tl.cast(acc, tl.float16)
+    tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_0, mask_0[:, None] & mask_1[None, :])
+
+def matmul(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
+    # src[test_specialize.py:N]: m, k = x.size()
+    m, k = x.size()
+    # src[test_specialize.py:N]: k2, n = y.size()
+    k2, n = y.size()
+    # src[test_specialize.py:N]: out = torch.empty([m, n], device=x.device, dtype=x.dtype)
+    out = torch.empty([m, n], device=x.device, dtype=x.dtype)
+    # src[test_specialize.py:N]: for tile_m, tile_n in hl.tile([m, n]):
+    _BLOCK_SIZE_0 = 32
+    _BLOCK_SIZE_1 = 32
+    # src[test_specialize.py:N]: for tile_k in hl.tile(k):
+    # src[test_specialize.py:N]: acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
+    _BLOCK_SIZE_2 = 32
+    # src[test_specialize.py:N]: for tile_m, tile_n in hl.tile([m, n]):
+    # src[test_specialize.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+    # src[test_specialize.py:N]: for tile_k in hl.tile(k):
+    # src[test_specialize.py:N-N]: ...
+    _launcher(_helion_matmul, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(56, _BLOCK_SIZE_1),), x, y, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1)
+    # src[test_specialize.py:N]: return out
+    return out
+
+--- assertExpectedJournal(TestMarkStatic.test_mark_static_and_hl_specialize)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_fn(x, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    # src[test_specialize.py:N]: for tile in hl.tile(x.size()):
+    num_blocks_0 = tl.cdiv(320, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    mask_1 = indices_1 < 640
+    # src[test_specialize.py:N]: out[tile] = x[tile] * 2
+    load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_1[None, :], other=0)
+    v_0 = 2.0
+    v_1 = load * v_0
+    tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_1, mask_1[None, :])
+
+def fn(x: torch.Tensor, *, _launcher=_default_launcher):
+    # src[test_specialize.py:N]: out = torch.empty_like(x)
+    out = torch.empty_like(x)
+    # src[test_specialize.py:N]: for tile in hl.tile(x.size()):
+    _BLOCK_SIZE_0 = 16
+    _BLOCK_SIZE_1 = 16
+    # src[test_specialize.py:N]: for tile in hl.tile(x.size()):
+    # src[test_specialize.py:N]: out[tile] = x[tile] * 2
+    _launcher(_helion_fn, (triton.cdiv(320, _BLOCK_SIZE_0) * triton.cdiv(640, _BLOCK_SIZE_1),), x, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1)
+    # src[test_specialize.py:N]: return out
+    return out
+
 --- assertExpectedJournal(TestSpecialize.test_dynamic_size_block_non_power_of_two)
 from __future__ import annotations


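The generated header above notes that these journal entries are refreshed by running the tests with EXPECTTEST_ACCEPT=1. A minimal sketch of that workflow in Python, assuming the test path from this commit and a pytest-style invocation (the exact command is not part of the diff):

# Hypothetical regeneration helper: re-run the new TestMarkStatic cases with
# EXPECTTEST_ACCEPT=1 so assertExpectedJournal rewrites test_specialize.expected.
import os
import subprocess

env = dict(os.environ, EXPECTTEST_ACCEPT="1")
subprocess.run(
    ["python", "-m", "pytest", "test/test_specialize.py", "-k", "TestMarkStatic"],
    env=env,
    check=True,  # fail loudly if the accepted run itself errors out
)
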
test/test_specialize.py

Lines changed: 87 additions & 0 deletions
@@ -327,5 +327,92 @@ def foo(x: torch.Tensor, bitshift: tuple[int, int]) -> torch.Tensor:
         self.assertExpectedJournal(code)


+@skipIfCpu("needs to be debugged")
+class TestMarkStatic(RefEagerTestBase, TestCase):
+    """Tests for torch._dynamo.mark_static() external specialization API."""
+
+    maxDiff = 163842
+
+    def test_mark_static(self):
+        """Test mark_static: multiple tensors, multiple dims, negative indexing."""
+
+        @helion.kernel(autotune_effort="none", static_shapes=False)
+        def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            m, k = x.size()
+            k2, n = y.size()
+            out = torch.empty([m, n], device=x.device, dtype=x.dtype)
+            for tile_m, tile_n in hl.tile([m, n]):
+                acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+                for tile_k in hl.tile(k):
+                    acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
+                out[tile_m, tile_n] = acc.to(x.dtype)
+            return out
+
+        m, k, n = 64, 128, 56
+
+        # First, run WITHOUT mark_static - dimensions should NOT be constants
+        x = torch.randn([m, k], device=DEVICE, dtype=torch.float16)
+        y = torch.randn([k, n], device=DEVICE, dtype=torch.float16)
+        code_no_spec, result_no_spec = code_and_output(
+            matmul, (x, y), block_sizes=[32, 32, 32]
+        )
+        torch.testing.assert_close(result_no_spec, x @ y, rtol=1e-2, atol=1e-2)
+        self.assertNotIn("64", code_no_spec)
+        self.assertNotIn("128", code_no_spec)
+        self.assertNotIn("56", code_no_spec)
+
+        # Now, run WITH mark_static - dimensions SHOULD be constants
+        x_static = torch.randn([m, k], device=DEVICE, dtype=torch.float16)
+        y_static = torch.randn([k, n], device=DEVICE, dtype=torch.float16)
+        torch._dynamo.mark_static(x_static, [0, -1])  # test list and negative index
+        torch._dynamo.mark_static(y_static, 1)
+
+        code, result = code_and_output(
+            matmul, (x_static, y_static), block_sizes=[32, 32, 32]
+        )
+        torch.testing.assert_close(result, x_static @ y_static, rtol=1e-2, atol=1e-2)
+        self.assertIn("64", code)
+        self.assertIn("128", code)
+        self.assertIn("56", code)
+        self.assertExpectedJournal(code)
+
+        # Cache hit: same tensors
+        self.assertIs(
+            matmul.bind((x_static, y_static)), matmul.bind((x_static, y_static))
+        )
+        # Cache miss: different specialized values
+        x2 = torch.randn([48, 96], device=DEVICE, dtype=torch.float16)
+        y2 = torch.randn([96, 24], device=DEVICE, dtype=torch.float16)
+        torch._dynamo.mark_static(x2, [0, -1])
+        torch._dynamo.mark_static(y2, 1)
+        self.assertIsNot(matmul.bind((x_static, y_static)), matmul.bind((x2, y2)))
+
+    def test_mark_static_and_hl_specialize(self):
+        """Test that external mark_static and internal hl.specialize form a union."""
+
+        @helion.kernel(autotune_effort="none", static_shapes=False)
+        def fn(x: torch.Tensor) -> torch.Tensor:
+            hl.specialize(x.size(0))  # internal specialize on dim 0
+            out = torch.empty_like(x)
+            for tile in hl.tile(x.size()):
+                out[tile] = x[tile] * 2
+            return out
+
+        # mark_static on dim 1 should combine with hl.specialize on dim 0
+        x = torch.randn([320, 640], device=DEVICE)
+        torch._dynamo.mark_static(x, -1)
+
+        code, result = code_and_output(fn, (x,), block_sizes=[16, 16])
+        torch.testing.assert_close(result, x * 2)
+        self.assertIn("320", code)  # dim 0 from hl.specialize
+        self.assertIn("640", code)  # dim 1 from mark_static
+        self.assertExpectedJournal(code)
+
+        # Cache miss: changing externally-specialized dim
+        x2 = torch.randn([320, 128], device=DEVICE)
+        torch._dynamo.mark_static(x2, -1)
+        self.assertIsNot(fn.bind((x,)), fn.bind((x2,)))
+
+
 if __name__ == "__main__":
     unittest.main()

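Taken together, the two new tests exercise one pattern: torch._dynamo.mark_static() specializes a tensor dimension from outside the kernel, and it unions with hl.specialize() inside the kernel, so both sizes end up as compile-time constants in the generated Triton code. A condensed sketch of that pattern, with an assumed CUDA device and illustrative names not taken verbatim from the diff:

# Condensed sketch of the specialization pattern exercised above; imports
# mirror the test file's usage of helion and helion.language.
import torch
import helion
import helion.language as hl

@helion.kernel(autotune_effort="none", static_shapes=False)
def double(x: torch.Tensor) -> torch.Tensor:
    hl.specialize(x.size(0))          # internal: specialize dim 0
    out = torch.empty_like(x)
    for tile in hl.tile(x.size()):
        out[tile] = x[tile] * 2
    return out

x = torch.randn([320, 640], device="cuda")
torch._dynamo.mark_static(x, -1)      # external: specialize dim 1
result = double(x)                    # both 320 and 640 appear as constants
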