Commit ca30352

test

1 parent 7aada66 commit ca30352

2 files changed: +177 -0 lines changed

test/test_specialize.expected

Lines changed: 81 additions & 0 deletions
@@ -335,6 +335,87 @@ def fn(x: torch.Tensor, *, _launcher=_default_launcher):
     # src[test_specialize.py:N]: return out
     return out
 
+--- assertExpectedJournal(TestSpecialize.test_specialize_stride_basic)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_fn(x, out, x_size_0, x_size_1, out_stride_0, out_stride_1, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    # src[test_specialize.py:N]: for tile in hl.tile(x.size()):
+    num_blocks_0 = tl.cdiv(x_size_0, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < x_size_0
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    mask_1 = indices_1 < x_size_1
+    # src[test_specialize.py:N]: out[tile] = x[tile] + stride
+    load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+    v_0 = 137.0
+    v_1 = load + v_0
+    tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_1, mask_0[:, None] & mask_1[None, :])
+
+def fn(x: torch.Tensor, *, _launcher=_default_launcher):
+    # src[test_specialize.py:N]: out = torch.empty_like(x)
+    out = torch.empty_like(x)
+    # src[test_specialize.py:N]: for tile in hl.tile(x.size()):
+    _BLOCK_SIZE_0 = 32
+    _BLOCK_SIZE_1 = 32
+    # src[test_specialize.py:N]: for tile in hl.tile(x.size()):
+    # src[test_specialize.py:N]: # Use stride in computation to verify it's a constant
+    # src[test_specialize.py:N]: out[tile] = x[tile] + stride
+    _launcher(_helion_fn, (triton.cdiv(x.size(0), _BLOCK_SIZE_0) * triton.cdiv(x.size(1), _BLOCK_SIZE_1),), x, out, x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1)
+    # src[test_specialize.py:N]: return out
+    return out
+
+--- assertExpectedJournal(TestSpecialize.test_specialize_stride_tuple)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_fn(x, out, x_size_0, x_size_1, out_stride_0, out_stride_1, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    # src[test_specialize.py:N]: for tile in hl.tile(x.size()):
+    num_blocks_0 = tl.cdiv(x_size_0, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < x_size_0
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    mask_1 = indices_1 < x_size_1
+    # src[test_specialize.py:N]: out[tile] = x[tile] + stride0 + stride1
+    load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+    v_0 = 311.0
+    v_1 = load + v_0
+    v_2 = 131.0
+    v_3 = v_1 + v_2
+    tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_3, mask_0[:, None] & mask_1[None, :])
+
+def fn(x: torch.Tensor, *, _launcher=_default_launcher):
+    # src[test_specialize.py:N]: stride0, stride1 = hl.specialize((x.stride(0), x.stride(1)))
+    stride0, stride1 = (311, 131)
+    # src[test_specialize.py:N]: out = torch.empty_like(x)
+    out = torch.empty_like(x)
+    # src[test_specialize.py:N]: for tile in hl.tile(x.size()):
+    _BLOCK_SIZE_0 = 32
+    _BLOCK_SIZE_1 = 32
+    # src[test_specialize.py:N]: for tile in hl.tile(x.size()):
+    # src[test_specialize.py:N]: out[tile] = x[tile] + stride0 + stride1
+    _launcher(_helion_fn, (triton.cdiv(x.size(0), _BLOCK_SIZE_0) * triton.cdiv(x.size(1), _BLOCK_SIZE_1),), x, out, x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1)
+    # src[test_specialize.py:N]: return out
+    return out
+
 --- assertExpectedJournal(TestSpecialize.test_specialize_tuple_element)
 from __future__ import annotations
 
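Note on the generated indexing above: the pointer arithmetic indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1 is the standard strided layout, in which element (i, j) of x lives at linear storage offset i * stride(0) + j * stride(1). The following is a minimal standalone sketch (plain PyTorch, not part of this commit; the values mirror test_specialize_stride_basic) that checks this relationship for the kind of as_strided view the tests construct:

import torch

# Element (i, j) of an as_strided view maps to storage offset i * stride0 + j * stride1,
# the same formula the generated Triton kernel evaluates per tile.
size = (64, 64)
stride0, stride1 = 137, 1
storage = torch.randn((size[0] - 1) * stride0 + (size[1] - 1) * stride1 + 1)
x = torch.as_strided(storage, size, (stride0, stride1))

i, j = 5, 7
assert torch.equal(x[i, j], storage[i * stride0 + j * stride1])

The stride value itself (137) plays no special role in this check; in the kernels above it is the hl.specialize call in the test that turns it into the inlined constant v_0 = 137.0 instead of a runtime argument.
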
test/test_specialize.py

Lines changed: 96 additions & 0 deletions
@@ -326,6 +326,102 @@ def foo(x: torch.Tensor, bitshift: tuple[int, int]) -> torch.Tensor:
         self.assertIn("65536", code)
         self.assertExpectedJournal(code)
 
+    def test_specialize_stride_basic(self):
+        """Test that hl.specialize works with tensor strides."""
+
+        @helion.kernel(static_shapes=False, autotune_effort="none")
+        def fn(x: torch.Tensor) -> torch.Tensor:
+            stride = hl.specialize(x.stride(0))
+            out = torch.empty_like(x)
+            for tile in hl.tile(x.size()):
+                # Use stride in computation to verify it's a constant
+                out[tile] = x[tile] + stride
+            return out
+
+        # Use empty_strided to create tensor with a unique stride value (137)
+        # that won't be confused with shape values
+        size = (64, 64)
+        stride0 = 137  # Distinctive prime number for stride(0)
+        stride1 = 1
+        # Need storage size to fit: (size[0]-1)*stride0 + (size[1]-1)*stride1 + 1
+        storage_size = (size[0] - 1) * stride0 + (size[1] - 1) * stride1 + 1
+        storage = torch.randn(storage_size, device=DEVICE)
+        x = torch.as_strided(storage, size, (stride0, stride1))
+
+        code, result = code_and_output(fn, (x,))
+        torch.testing.assert_close(result, x + x.stride(0))
+        # Verify the unique stride value 137 is inlined as a constant
+        self.assertIn("137", code)
+        self.assertExpectedJournal(code)
+
+    def test_specialize_stride_creates_different_variants(self):
+        """Test that different stride patterns create different kernel variants."""
+
+        @helion.kernel(static_shapes=False, autotune_effort="none")
+        def fn(x: torch.Tensor) -> torch.Tensor:
+            stride = hl.specialize(x.stride(0))
+            out = torch.empty_like(x)
+            for tile in hl.tile(x.size()):
+                out[tile] = x[tile] + stride
+            return out
+
+        # Create two tensors with different unique stride values using empty_strided
+        size = (64, 64)
+
+        # First tensor with stride(0) = 173 (distinctive prime)
+        stride0_a = 173
+        storage_size_a = (size[0] - 1) * stride0_a + (size[1] - 1) * 1 + 1
+        storage_a = torch.randn(storage_size_a, device=DEVICE)
+        x_a = torch.as_strided(storage_a, size, (stride0_a, 1))
+
+        # Second tensor with stride(0) = 257 (different distinctive prime)
+        stride0_b = 257
+        storage_size_b = (size[0] - 1) * stride0_b + (size[1] - 1) * 1 + 1
+        storage_b = torch.randn(storage_size_b, device=DEVICE)
+        x_b = torch.as_strided(storage_b, size, (stride0_b, 1))
+
+        # These should create different bound kernels due to different strides
+        bound1 = fn.bind((x_a,))
+        bound2 = fn.bind((x_b,))
+
+        # Verify different variants are used
+        self.assertTrueIfInNormalMode(bound1 is not bound2)
+
+        # Verify correctness
+        result1 = fn(x_a)
+        result2 = fn(x_b)
+        torch.testing.assert_close(result1, x_a + stride0_a)
+        torch.testing.assert_close(result2, x_b + stride0_b)
+
+    def test_specialize_stride_tuple(self):
+        """Test that hl.specialize works with tuple of strides."""
+
+        @helion.kernel(static_shapes=False, autotune_effort="none")
+        def fn(x: torch.Tensor) -> torch.Tensor:
+            stride0, stride1 = hl.specialize((x.stride(0), x.stride(1)))
+            out = torch.empty_like(x)
+            for tile in hl.tile(x.size()):
+                out[tile] = x[tile] + stride0 + stride1
+            return out
+
+        # Create tensor with unique stride values using empty_strided
+        # stride0 = 311, stride1 = 131 (distinctive primes unlikely to appear elsewhere)
+        size = (64, 64)
+        stride0 = 311
+        stride1 = 131
+        # Storage must fit the largest offset: (size[0]-1)*stride0 + (size[1]-1)*stride1 + 1
+        storage_size = (size[0] - 1) * stride0 + (size[1] - 1) * stride1 + 1
+        storage = torch.randn(storage_size, device=DEVICE)
+        x = torch.as_strided(storage, size, (stride0, stride1))
+
+        code, result = code_and_output(fn, (x,))
+        expected = x + stride0 + stride1
+        torch.testing.assert_close(result, expected)
+        # Verify both unique stride values appear in the generated code
+        self.assertIn("311", code)
+        self.assertIn("131", code)
+        self.assertExpectedJournal(code)
+
 
 if __name__ == "__main__":
     unittest.main()
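
Note on the storage-size formula used in these tests: for size (64, 64) with strides (311, 131), the largest linear offset reached is (64 - 1) * 311 + (64 - 1) * 131 = 27846, so 27847 elements of backing storage are exactly enough. A short CPU-only sketch (illustrative, not part of the commit) confirming the formula:

import torch

size = (64, 64)
stride0, stride1 = 311, 131
# Largest offset touched is (size[0]-1)*stride0 + (size[1]-1)*stride1; add one
# element for the entry at that offset itself: 63*311 + 63*131 + 1 = 27847.
storage_size = (size[0] - 1) * stride0 + (size[1] - 1) * stride1 + 1
storage = torch.randn(storage_size)
# as_strided raises if the requested view would run past the end of the storage,
# so constructing the view already validates that storage_size is sufficient.
x = torch.as_strided(storage, size, (stride0, stride1))
assert x.stride() == (stride0, stride1)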
