
Commit 41d144f

test
1 parent 7aada66 commit 41d144f

File tree: test/test_specialize.expected, test/test_specialize.py

2 files changed: +182 -0 lines changed


test/test_specialize.expected

Lines changed: 81 additions & 0 deletions
@@ -335,6 +335,87 @@ def fn(x: torch.Tensor, *, _launcher=_default_launcher):
     # src[test_specialize.py:N]: return out
     return out
 
+--- assertExpectedJournal(TestSpecialize.test_specialize_stride_basic)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_fn(x, out, x_size_0, x_size_1, out_stride_0, out_stride_1, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    # src[test_specialize.py:N]: for tile in hl.tile(x.size()):
+    num_blocks_0 = tl.cdiv(x_size_0, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < x_size_0
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    mask_1 = indices_1 < x_size_1
+    # src[test_specialize.py:N]: out[tile] = x[tile] + stride
+    load = tl.load(x + (indices_0[:, None] * 137 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+    v_0 = 137.0
+    v_1 = load + v_0
+    tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_1, mask_0[:, None] & mask_1[None, :])
+
+def fn(x: torch.Tensor, *, _launcher=_default_launcher):
+    # src[test_specialize.py:N]: out = torch.empty_like(x)
+    out = torch.empty_like(x)
+    # src[test_specialize.py:N]: for tile in hl.tile(x.size()):
+    _BLOCK_SIZE_0 = 32
+    _BLOCK_SIZE_1 = 32
+    # src[test_specialize.py:N]: for tile in hl.tile(x.size()):
+    # src[test_specialize.py:N]: # Use stride in computation to verify it's a constant
+    # src[test_specialize.py:N]: out[tile] = x[tile] + stride
+    _launcher(_helion_fn, (triton.cdiv(x.size(0), _BLOCK_SIZE_0) * triton.cdiv(x.size(1), _BLOCK_SIZE_1),), x, out, x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1)
+    # src[test_specialize.py:N]: return out
+    return out
+
+--- assertExpectedJournal(TestSpecialize.test_specialize_stride_tuple)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_fn(x, out, x_size_0, x_size_1, out_stride_0, out_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    # src[test_specialize.py:N]: for tile in hl.tile(x.size()):
+    num_blocks_0 = tl.cdiv(x_size_0, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < x_size_0
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    mask_1 = indices_1 < x_size_1
+    # src[test_specialize.py:N]: out[tile] = x[tile] + stride0 + stride1
+    load = tl.load(x + (indices_0[:, None] * 311 + indices_1[None, :] * 131), mask_0[:, None] & mask_1[None, :], other=0)
+    v_0 = 311.0
+    v_1 = load + v_0
+    v_2 = 131.0
+    v_3 = v_1 + v_2
+    tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_3, mask_0[:, None] & mask_1[None, :])
+
+def fn(x: torch.Tensor, *, _launcher=_default_launcher):
+    # src[test_specialize.py:N]: stride0, stride1 = hl.specialize((x.stride(0), x.stride(1)))
+    stride0, stride1 = (311, 131)
+    # src[test_specialize.py:N]: out = torch.empty_like(x)
+    out = torch.empty_like(x)
+    # src[test_specialize.py:N]: for tile in hl.tile(x.size()):
+    _BLOCK_SIZE_0 = 32
+    _BLOCK_SIZE_1 = 32
+    # src[test_specialize.py:N]: for tile in hl.tile(x.size()):
+    # src[test_specialize.py:N]: out[tile] = x[tile] + stride0 + stride1
+    _launcher(_helion_fn, (triton.cdiv(x.size(0), _BLOCK_SIZE_0) * triton.cdiv(x.size(1), _BLOCK_SIZE_1),), x, out, x.size(0), x.size(1), out.stride(0), out.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1)
+    # src[test_specialize.py:N]: return out
+    return out
+
 --- assertExpectedJournal(TestSpecialize.test_specialize_tuple_element)
 from __future__ import annotations

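Note on the journal above: because the kernels specialize on strides via hl.specialize, the generated Triton code inlines the stride values as literals (137 in the basic test; 311 and 131 in the tuple test) and drops the corresponding x_stride_* parameters from the kernel signature. The index math itself is ordinary strided addressing: element (i, j) of x sits at flat offset i * x.stride(0) + j * x.stride(1) in the backing storage. The sketch below is illustrative only and not part of the commit; it uses plain CPU PyTorch (no Triton) to reproduce the specialized load from the first journal entry for the 32x32 tile at the origin and check it against direct indexing.

import torch

# Build a tensor whose stride(0) is the distinctive value 137, as in the test.
size = (64, 64)
stride0, stride1 = 137, 1
storage = torch.randn((size[0] - 1) * stride0 + (size[1] - 1) * stride1 + 1)
x = torch.as_strided(storage, size, (stride0, stride1))

# Emulate the kernel's load for the 32x32 tile at the origin. The flat offset
# i * 137 + j * stride1 mirrors the journal's
# indices_0[:, None] * 137 + indices_1[None, :] * x_stride_1 expression.
i = torch.arange(32)[:, None]
j = torch.arange(32)[None, :]
tile_via_offsets = storage[i * 137 + j * stride1]
assert torch.equal(tile_via_offsets, x[:32, :32])
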
test/test_specialize.py

Lines changed: 101 additions & 0 deletions
@@ -326,6 +326,107 @@ def foo(x: torch.Tensor, bitshift: tuple[int, int]) -> torch.Tensor:
         self.assertIn("65536", code)
         self.assertExpectedJournal(code)
 
+    def test_specialize_stride_basic(self):
+        """Test that hl.specialize works with tensor strides."""
+
+        @helion.kernel(static_shapes=False, autotune_effort="none")
+        def fn(x: torch.Tensor) -> torch.Tensor:
+            stride = hl.specialize(x.stride(0))
+            out = torch.empty_like(x)
+            for tile in hl.tile(x.size()):
+                # Use stride in computation to verify it's a constant
+                out[tile] = x[tile] + stride
+            return out
+
+        # Use as_strided to create a tensor with a unique stride value (137)
+        # that won't be confused with shape values
+        size = (64, 64)
+        stride0 = 137  # Distinctive prime number for stride(0)
+        stride1 = 1
+        # Need storage size to fit: (size[0]-1)*stride0 + (size[1]-1)*stride1 + 1
+        storage_size = (size[0] - 1) * stride0 + (size[1] - 1) * stride1 + 1
+        storage = torch.randn(storage_size, device=DEVICE)
+        x = torch.as_strided(storage, size, (stride0, stride1))
+
+        code, result = code_and_output(fn, (x,))
+        torch.testing.assert_close(result, x + x.stride(0))
+        # Verify the unique stride value 137 is inlined as a constant
+        self.assertIn("137", code)
+        # Verify x_stride_0 is NOT passed as an argument (it should be inlined)
+        self.assertNotIn("x_stride_0", code)
+        self.assertExpectedJournal(code)
+
+    def test_specialize_stride_creates_different_variants(self):
+        """Test that different stride patterns create different kernel variants."""
+
+        @helion.kernel(static_shapes=False, autotune_effort="none")
+        def fn(x: torch.Tensor) -> torch.Tensor:
+            stride = hl.specialize(x.stride(0))
+            out = torch.empty_like(x)
+            for tile in hl.tile(x.size()):
+                out[tile] = x[tile] + stride
+            return out
+
+        # Create two tensors with different unique stride values using as_strided
+        size = (64, 64)
+
+        # First tensor with stride(0) = 173 (distinctive prime)
+        stride0_a = 173
+        storage_size_a = (size[0] - 1) * stride0_a + (size[1] - 1) * 1 + 1
+        storage_a = torch.randn(storage_size_a, device=DEVICE)
+        x_a = torch.as_strided(storage_a, size, (stride0_a, 1))
+
+        # Second tensor with stride(0) = 257 (different distinctive prime)
+        stride0_b = 257
+        storage_size_b = (size[0] - 1) * stride0_b + (size[1] - 1) * 1 + 1
+        storage_b = torch.randn(storage_size_b, device=DEVICE)
+        x_b = torch.as_strided(storage_b, size, (stride0_b, 1))
+
+        # These should create different bound kernels due to different strides
+        bound1 = fn.bind((x_a,))
+        bound2 = fn.bind((x_b,))
+
+        # Verify different variants are used
+        self.assertTrueIfInNormalMode(bound1 is not bound2)
+
+        # Verify correctness
+        result1 = fn(x_a)
+        result2 = fn(x_b)
+        torch.testing.assert_close(result1, x_a + stride0_a)
+        torch.testing.assert_close(result2, x_b + stride0_b)
+
+    def test_specialize_stride_tuple(self):
+        """Test that hl.specialize works with a tuple of strides."""
+
+        @helion.kernel(static_shapes=False, autotune_effort="none")
+        def fn(x: torch.Tensor) -> torch.Tensor:
+            stride0, stride1 = hl.specialize((x.stride(0), x.stride(1)))
+            out = torch.empty_like(x)
+            for tile in hl.tile(x.size()):
+                out[tile] = x[tile] + stride0 + stride1
+            return out
+
+        # Create a tensor with unique stride values using as_strided
+        # stride0 = 311, stride1 = 131 (distinctive primes unlikely to appear elsewhere)
+        size = (64, 64)
+        stride0 = 311
+        stride1 = 131
+        # Storage must fit the largest offset: (size[0]-1)*stride0 + (size[1]-1)*stride1 + 1
+        storage_size = (size[0] - 1) * stride0 + (size[1] - 1) * stride1 + 1
+        storage = torch.randn(storage_size, device=DEVICE)
+        x = torch.as_strided(storage, size, (stride0, stride1))
+
+        code, result = code_and_output(fn, (x,))
+        expected = x + stride0 + stride1
+        torch.testing.assert_close(result, expected)
+        # Verify both unique stride values appear in the generated code
+        self.assertIn("311", code)
+        self.assertIn("131", code)
+        # Verify both x_stride_0 and x_stride_1 are NOT passed as arguments (they should be inlined)
+        self.assertNotIn("x_stride_0", code)
+        self.assertNotIn("x_stride_1", code)
+        self.assertExpectedJournal(code)
+
 
 if __name__ == "__main__":
     unittest.main()

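A side note on the test inputs: the storage_size expression sizes the 1-D backing storage so that the largest reachable flat offset, (size[0] - 1) * stride0 + (size[1] - 1) * stride1, still lands inside it, with one extra element for the last entry itself. A quick worked check of the values used in these tests, written as a standalone sketch with a hypothetical helper (not part of the commit):

# Minimum 1-D storage length needed for a 2-D as_strided view:
# the largest reachable offset plus one for the element itself.
def min_storage_size(size, strides):
    return sum((n - 1) * s for n, s in zip(size, strides)) + 1

print(min_storage_size((64, 64), (137, 1)))    # 8695  -> test_specialize_stride_basic
print(min_storage_size((64, 64), (311, 131)))  # 27847 -> test_specialize_stride_tuple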