
Commit 65e1146

Commit message: test
Parent: 7caeaa2

4 files changed: +291 -32 lines changed

test/test_examples.expected

Lines changed: 0 additions & 16 deletions
@@ -460,27 +460,11 @@ def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor, *, _la
 _RDIM_SIZE_2 = 64
 # src[attention.py:N]: m_i = hl.full([tile_b, tile_m], float("-inf"), dtype=torch.float32)
 _BLOCK_SIZE_0 = 1
-# src[attention.py:N]: q = q_view[tile_b, tile_m, :]
-_SHAPE_DIM = q_in.size(3)
-_SHAPE_DIM_1 = q_in.size(3)
-_SHAPE_DIM_2 = q_in.size(3)
 # src[attention.py:N]: for tile_n in hl.tile(v_view.size(1)):
 # src[attention.py:N]: k = k_view[tile_b, :, tile_n]
 # src[attention.py:N]: qk = torch.bmm(q, k)
 # src[attention.py:N-N]: ...
 _BLOCK_SIZE_3 = 32
-# src[attention.py:N]: k = k_view[tile_b, :, tile_n]
-_SHAPE_DIM_3 = q_in.size(3)
-_SHAPE_DIM_4 = q_in.size(3)
-_SHAPE_DIM_5 = q_in.size(3)
-# src[attention.py:N]: v = v_view[tile_b, tile_n, :]
-_SHAPE_DIM_6 = q_in.size(3)
-_SHAPE_DIM_7 = q_in.size(3)
-_SHAPE_DIM_8 = q_in.size(3)
-# src[attention.py:N]: out[tile_b, tile_m, :] = acc.to(out.dtype)
-_SHAPE_DIM_9 = q_in.size(3)
-_SHAPE_DIM_10 = q_in.size(3)
-_SHAPE_DIM_11 = q_in.size(3)
 # src[attention.py:N]: for tile_b, tile_m in hl.tile([q_view.size(0), m_dim]):
 # src[attention.py:N]: m_i = hl.full([tile_b, tile_m], float("-inf"), dtype=torch.float32)
 # src[attention.py:N]: l_i = torch.full_like(m_i, 1.0)

test/test_specialize.expected

Lines changed: 139 additions & 0 deletions
@@ -1132,3 +1132,142 @@ def reduce_kernel(x: torch.Tensor, tensor_factory_fn, test_host, *, _launcher=_d
 _launcher(_helion_reduce_kernel, (triton.cdiv(128, _BLOCK_SIZE_0),), x, grad_weight, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=1)
 # src[test_specialize.py:N]: return grad_weight.sum(0).to(x.dtype)
 return grad_weight.sum(0).to(x.dtype)
+
+--- assertExpectedJournal(TestMarkStatic.test_mark_static)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_matmul(x, y, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
+# src[test_specialize.py:N]: for tile_m, tile_n in hl.tile([m, n]):
+num_blocks_0 = tl.cdiv(64, _BLOCK_SIZE_0)
+pid_0 = tl.program_id(0) % num_blocks_0
+pid_1 = tl.program_id(0) // num_blocks_0
+offset_0 = pid_0 * _BLOCK_SIZE_0
+indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+mask_0 = indices_0 < 64
+offset_1 = pid_1 * _BLOCK_SIZE_1
+indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+mask_1 = indices_1 < 56
+# src[test_specialize.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
+# src[test_specialize.py:N]: for tile_k in hl.tile(k):
+# src[test_specialize.py:N]: acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
+symnode_0 = 128
+for offset_2 in tl.range(0, symnode_0.to(tl.int32), _BLOCK_SIZE_2):
+indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
+mask_2 = indices_2 < symnode_0
+acc_copy = acc
+acc_copy_0 = acc_copy
+# src[test_specialize.py:N]: acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
+load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
+load_1 = tl.load(y + (indices_2[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_2[:, None] & mask_1[None, :], other=0)
+acc = tl.dot(tl.cast(load, tl.float16), tl.cast(load_1, tl.float16), acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32)
+# src[test_specialize.py:N]: out[tile_m, tile_n] = acc.to(x.dtype)
+v_0 = tl.cast(acc, tl.float16)
+tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_0, mask_0[:, None] & mask_1[None, :])
+
+def matmul(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
+# src[test_specialize.py:N]: m, k = x.size()
+m, k = x.size()
+# src[test_specialize.py:N]: k2, n = y.size()
+k2, n = y.size()
+# src[test_specialize.py:N]: out = torch.empty([m, n], device=x.device, dtype=x.dtype)
+out = torch.empty([m, n], device=x.device, dtype=x.dtype)
+# src[test_specialize.py:N]: for tile_m, tile_n in hl.tile([m, n]):
+_BLOCK_SIZE_0 = 32
+_BLOCK_SIZE_1 = 32
+# src[test_specialize.py:N]: for tile_k in hl.tile(k):
+# src[test_specialize.py:N]: acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
+_BLOCK_SIZE_2 = 32
+# src[test_specialize.py:N]: for tile_m, tile_n in hl.tile([m, n]):
+# src[test_specialize.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+# src[test_specialize.py:N]: for tile_k in hl.tile(k):
+# src[test_specialize.py:N-N]: ...
+_launcher(_helion_matmul, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(56, _BLOCK_SIZE_1),), x, y, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1)
+# src[test_specialize.py:N]: return out
+return out
+
+--- assertExpectedJournal(TestMarkStatic.test_mark_static_and_hl_specialize)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_dual_specialize(x, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+# src[test_specialize.py:N]: for tile in hl.tile(x.size()):
+num_blocks_0 = tl.cdiv(320, _BLOCK_SIZE_0)
+pid_0 = tl.program_id(0) % num_blocks_0
+pid_1 = tl.program_id(0) // num_blocks_0
+offset_0 = pid_0 * _BLOCK_SIZE_0
+indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+offset_1 = pid_1 * _BLOCK_SIZE_1
+indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+mask_1 = indices_1 < 640
+# src[test_specialize.py:N]: out[tile] = x[tile] * 2
+load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_1[None, :], other=0)
+v_0 = 2.0
+v_1 = load * v_0
+tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_1, mask_1[None, :])
+
+def dual_specialize(x: torch.Tensor, *, _launcher=_default_launcher):
+# src[test_specialize.py:N]: out = torch.empty_like(x)
+out = torch.empty_like(x)
+# src[test_specialize.py:N]: for tile in hl.tile(x.size()):
+_BLOCK_SIZE_0 = 16
+_BLOCK_SIZE_1 = 16
+# src[test_specialize.py:N]: for tile in hl.tile(x.size()):
+# src[test_specialize.py:N]: out[tile] = x[tile] * 2
+_launcher(_helion_dual_specialize, (triton.cdiv(320, _BLOCK_SIZE_0) * triton.cdiv(640, _BLOCK_SIZE_1),), x, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1)
+# src[test_specialize.py:N]: return out
+return out
+
+--- assertExpectedJournal(TestMarkStatic.test_mark_static_multiple_tensors)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_fn(x, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, n, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+# src[test_specialize.py:N]: for tile_m, tile_n in hl.tile([m, n]):
+num_blocks_0 = tl.cdiv(37, _BLOCK_SIZE_0)
+pid_0 = tl.program_id(0) % num_blocks_0
+pid_1 = tl.program_id(0) // num_blocks_0
+offset_0 = pid_0 * _BLOCK_SIZE_0
+indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+mask_0 = indices_0 < 37
+offset_1 = pid_1 * _BLOCK_SIZE_1
+indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+mask_1 = indices_1 < n
+# src[test_specialize.py:N]: out[tile_m, tile_n] = x[tile_m, tile_n] * p
+load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+symnode_0 = 127
+v_0 = tl.cast(symnode_0, tl.float32)
+v_1 = load * v_0
+tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_1, mask_0[:, None] & mask_1[None, :])
+
+def fn(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
+# src[test_specialize.py:N]: m, n = x.size()
+m, n = x.size()
+# src[test_specialize.py:N]: p = y.size(1) # use y's dim 1 as a scalar
+p = y.size(1)
+# src[test_specialize.py:N]: out = x.new_empty([m, n])
+out = x.new_empty([m, n])
+# src[test_specialize.py:N]: for tile_m, tile_n in hl.tile([m, n]):
+_BLOCK_SIZE_0 = 16
+_BLOCK_SIZE_1 = 16
+# src[test_specialize.py:N]: for tile_m, tile_n in hl.tile([m, n]):
+# src[test_specialize.py:N]: out[tile_m, tile_n] = x[tile_m, tile_n] * p
+_launcher(_helion_fn, (triton.cdiv(37, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), x, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), n, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1)
+# src[test_specialize.py:N]: return out
+return out
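For reference, the three journals above come from inputs whose dimensions were marked static before compilation, which is why 64, 128, 56, 320, 640, 37, and 127 appear as literal constants in the generated kernels and launchers. A minimal sketch of the marking step only (the "cuda" device string is an assumption; the shapes mirror the test inputs below):

import torch

x = torch.randn([64, 128], device="cuda", dtype=torch.float16)
y = torch.randn([128, 56], device="cuda", dtype=torch.float16)
# Specialize m (dim 0) and k (dim -1) of x, plus n (dim 1) of y, so the
# compiled code can treat those sizes as compile-time constants.
torch._dynamo.mark_static(x, [0, -1])
torch._dynamo.mark_static(y, 1)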

test/test_specialize.py

Lines changed: 152 additions & 0 deletions
@@ -327,5 +327,157 @@ def foo(x: torch.Tensor, bitshift: tuple[int, int]) -> torch.Tensor:
 self.assertExpectedJournal(code)


+@skipIfCpu("needs to be debugged")
+class TestMarkStatic(RefEagerTestBase, TestCase):
+"""Tests for torch._dynamo.mark_static() external specialization API."""
+
+maxDiff = 163842
+
+def test_mark_static(self):
+"""Test mark_static: multiple tensors, multiple dims, negative indexing."""
+
+@helion.kernel(autotune_effort="none", static_shapes=False)
+def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+m, k = x.size()
+k2, n = y.size()
+out = torch.empty([m, n], device=x.device, dtype=x.dtype)
+for tile_m, tile_n in hl.tile([m, n]):
+acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+for tile_k in hl.tile(k):
+acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
+out[tile_m, tile_n] = acc.to(x.dtype)
+return out
+
+m, k, n = 64, 128, 56
+x = torch.randn([m, k], device=DEVICE, dtype=torch.float16)
+y = torch.randn([k, n], device=DEVICE, dtype=torch.float16)
+
+# First, run WITHOUT mark_static - dimensions should NOT be constants
+code_no_spec, result_no_spec = code_and_output(
+matmul,
+(x, y),
+block_sizes=[32, 32, 32],
+)
+torch.testing.assert_close(result_no_spec, x @ y, rtol=1e-2, atol=1e-2)
+self.assertNotIn("64", code_no_spec) # x dim 0 = m should NOT be specialized
+self.assertNotIn("128", code_no_spec) # x dim -1 = k should NOT be specialized
+self.assertNotIn("56", code_no_spec) # y dim 1 = n should NOT be specialized
+
+# Now, run WITH mark_static - dimensions SHOULD be constants
+# Create fresh tensors and mark them static
+x_static = torch.randn([m, k], device=DEVICE, dtype=torch.float16)
+y_static = torch.randn([k, n], device=DEVICE, dtype=torch.float16)
+torch._dynamo.mark_static(x_static, [0, -1])
+torch._dynamo.mark_static(y_static, 1)
+
+code, result = code_and_output(
+matmul,
+(x_static, y_static),
+block_sizes=[32, 32, 32],
+)
+torch.testing.assert_close(result, x_static @ y_static, rtol=1e-2, atol=1e-2)
+self.assertIn("64", code) # x dim 0 = m
+self.assertIn("128", code) # x dim -1 = k
+self.assertIn("56", code) # y dim 1 = n
+self.assertExpectedJournal(code)
+
+# Verify cache behavior: same specialized values hit cache
+self.assertIs(matmul.bind((x_static, y_static)), matmul.bind((x_static, y_static)))
+# Verify cache behavior: different specialized values produce different bound kernels
+x2 = torch.randn([48, 96], device=DEVICE, dtype=torch.float16)
+y2 = torch.randn([96, 24], device=DEVICE, dtype=torch.float16)
+torch._dynamo.mark_static(x2, [0, -1])
+torch._dynamo.mark_static(y2, 1)
+self.assertIsNot(
+matmul.bind((x_static, y_static)), matmul.bind((x2, y2))
+)
+
+def test_mark_static_and_hl_specialize(self):
+"""Test that external mark_static and internal hl.specialize form a union."""
+
+@helion.kernel(autotune_effort="none", static_shapes=False)
+def dual_specialize(x: torch.Tensor) -> torch.Tensor:
+# Internal specialize on dim 0
+hl.specialize(x.size(0))
+out = torch.empty_like(x)
+for tile in hl.tile(x.size()):
+out[tile] = x[tile] * 2
+return out
+
+x = torch.randn([320, 640], device=DEVICE)
+
+# First, run WITHOUT external mark_static - only dim 0 should be specialized
+code_no_spec, result_no_spec = code_and_output(
+dual_specialize,
+(x,),
+block_sizes=[16, 16],
+)
+torch.testing.assert_close(result_no_spec, x * 2)
+self.assertIn("320", code_no_spec) # dim 0 from internal specialize
+self.assertNotIn("640", code_no_spec) # dim 1 should NOT be specialized
+
+# Now, run WITH external mark_static on dim -1 (dim 1)
+# Result: both dim 0 AND dim 1 are specialized (union)
+x_static = torch.randn([320, 640], device=DEVICE)
+torch._dynamo.mark_static(x_static, -1)
+
+code, result = code_and_output(
+dual_specialize,
+(x_static,),
+block_sizes=[16, 16],
+)
+torch.testing.assert_close(result, x_static * 2)
+# Both dimensions should appear as constants
+self.assertIn("320", code) # dim 0 from internal specialize
+self.assertIn("640", code) # dim 1 from external mark_static
+self.assertExpectedJournal(code)
+
+# Verify cache behavior: changing dim 1 (external) produces different bound kernel
+x2 = torch.randn([320, 128], device=DEVICE) # same dim 0, different dim 1
+torch._dynamo.mark_static(x2, -1)
+self.assertIsNot(dual_specialize.bind((x_static,)), dual_specialize.bind((x2,)))
+
+def test_mark_static_multiple_tensors(self):
+"""Test mark_static on multiple tensors."""
+
+@helion.kernel(autotune_effort="none", static_shapes=False)
+def fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+m, n = x.size()
+p = y.size(1) # use y's dim 1 as a scalar
+out = x.new_empty([m, n])
+for tile_m, tile_n in hl.tile([m, n]):
+out[tile_m, tile_n] = x[tile_m, tile_n] * p
+return out
+
+x = torch.randn([37, 64], device=DEVICE)
+y = torch.randn([48, 127], device=DEVICE)
+
+# First, run WITHOUT mark_static - dimensions should NOT be constants
+code_no_spec, result_no_spec = code_and_output(fn, (x, y), block_sizes=[16, 16])
+torch.testing.assert_close(result_no_spec, x * 127)
+self.assertNotIn("37", code_no_spec) # x dim 0 should NOT be specialized
+self.assertNotIn("127", code_no_spec) # y dim 1 should NOT be specialized
+
+# Now, mark both tensors static
+x_static = torch.randn([37, 64], device=DEVICE)
+y_static = torch.randn([48, 127], device=DEVICE)
+torch._dynamo.mark_static(x_static, 0)
+torch._dynamo.mark_static(y_static, 1)
+
+code, result = code_and_output(fn, (x_static, y_static), block_sizes=[16, 16])
+torch.testing.assert_close(result, x_static * 127)
+# Both specializations should be present
+self.assertIn("37", code) # x dim 0
+self.assertIn("127", code) # y dim 1
+self.assertExpectedJournal(code)
+
+# Verify cache behavior: changing specialized values produces different bound kernels
+x2 = torch.randn([48, 64], device=DEVICE) # different dim 0
+y2 = torch.randn([48, 256], device=DEVICE) # different dim 1
+torch._dynamo.mark_static(x2, 0)
+torch._dynamo.mark_static(y2, 1)
+self.assertIsNot(fn.bind((x_static, y_static)), fn.bind((x2, y2)))
+
+
 if __name__ == "__main__":
 unittest.main()
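The new tests above exercise the external specialization flow end to end. A minimal standalone sketch of the union behavior from test_mark_static_and_hl_specialize, assuming a CUDA device and the usual import helion.language as hl convention (the kernel body is copied from the test, not a library helper):

import torch
import helion
import helion.language as hl

@helion.kernel(autotune_effort="none", static_shapes=False)
def dual_specialize(x: torch.Tensor) -> torch.Tensor:
    hl.specialize(x.size(0))  # internal specialization: dim 0 is always a constant
    out = torch.empty_like(x)
    for tile in hl.tile(x.size()):
        out[tile] = x[tile] * 2
    return out

a = torch.randn([320, 640], device="cuda")
torch._dynamo.mark_static(a, -1)  # external specialization: dim 1 becomes a constant too
dual_specialize(a)                # generated code bakes in both 320 and 640

b = torch.randn([320, 128], device="cuda")
torch._dynamo.mark_static(b, -1)
# A different specialized value for dim 1 yields a different bound kernel (cache miss).
assert dual_specialize.bind((a,)) is not dual_specialize.bind((b,))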

test/test_tensor_descriptor.expected

Lines changed: 0 additions & 16 deletions
@@ -123,27 +123,11 @@ def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor, *, _la
 _RDIM_SIZE_2 = 64
 # src[attention.py:N]: m_i = hl.full([tile_b, tile_m], float("-inf"), dtype=torch.float32)
 _BLOCK_SIZE_0 = 1
-# src[attention.py:N]: q = q_view[tile_b, tile_m, :]
-_SHAPE_DIM = q_in.size(3)
-_SHAPE_DIM_1 = q_in.size(3)
-_SHAPE_DIM_2 = q_in.size(3)
 # src[attention.py:N]: for tile_n in hl.tile(v_view.size(1)):
 # src[attention.py:N]: k = k_view[tile_b, :, tile_n]
 # src[attention.py:N]: qk = torch.bmm(q, k)
 # src[attention.py:N-N]: ...
 _BLOCK_SIZE_3 = 16
-# src[attention.py:N]: k = k_view[tile_b, :, tile_n]
-_SHAPE_DIM_3 = q_in.size(3)
-_SHAPE_DIM_4 = q_in.size(3)
-_SHAPE_DIM_5 = q_in.size(3)
-# src[attention.py:N]: v = v_view[tile_b, tile_n, :]
-_SHAPE_DIM_6 = q_in.size(3)
-_SHAPE_DIM_7 = q_in.size(3)
-_SHAPE_DIM_8 = q_in.size(3)
-# src[attention.py:N]: out[tile_b, tile_m, :] = acc.to(out.dtype)
-_SHAPE_DIM_9 = q_in.size(3)
-_SHAPE_DIM_10 = q_in.size(3)
-_SHAPE_DIM_11 = q_in.size(3)
 # src[attention.py:N]: for tile_b, tile_m in hl.tile([q_view.size(0), m_dim]):
 # src[attention.py:N]: m_i = hl.full([tile_b, tile_m], float("-inf"), dtype=torch.float32)
 # src[attention.py:N]: l_i = torch.full_like(m_i, 1.0)
