@@ -335,6 +335,87 @@ def fn(x: torch.Tensor, *, _launcher=_default_launcher):
335335 # src[test_specialize.py:N]: return out
336336 return out
337337
338+ --- assertExpectedJournal(TestSpecialize.test_specialize_stride_basic)
339+ from __future__ import annotations
340+
341+ import torch
342+ import triton
343+ import triton.language as tl
344+ from helion.runtime import default_launcher as _default_launcher
345+
346+ @triton.jit
347+ def _helion_fn(x, out, x_size_0, x_size_1, out_stride_0, out_stride_1, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
348+ # src[test_specialize.py:N]: for tile in hl.tile(x.size()):
349+ num_blocks_0 = tl.cdiv(x_size_0, _BLOCK_SIZE_0)
350+ pid_0 = tl.program_id(0) % num_blocks_0
351+ pid_1 = tl.program_id(0) // num_blocks_0
352+ offset_0 = pid_0 * _BLOCK_SIZE_0
353+ indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
354+ mask_0 = indices_0 < x_size_0
355+ offset_1 = pid_1 * _BLOCK_SIZE_1
356+ indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
357+ mask_1 = indices_1 < x_size_1
358+ # src[test_specialize.py:N]: out[tile] = x[tile] + stride
359+ load = tl.load(x + (indices_0[:, None] * 137 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
360+ v_0 = 137.0
361+ v_1 = load + v_0
362+ tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_1, mask_0[:, None] & mask_1[None, :])
363+
364+ def fn(x: torch.Tensor, *, _launcher=_default_launcher):
365+ # src[test_specialize.py:N]: out = torch.empty_like(x)
366+ out = torch.empty_like(x)
367+ # src[test_specialize.py:N]: for tile in hl.tile(x.size()):
368+ _BLOCK_SIZE_0 = 32
369+ _BLOCK_SIZE_1 = 32
370+ # src[test_specialize.py:N]: for tile in hl.tile(x.size()):
371+ # src[test_specialize.py:N]: # Use stride in computation to verify it's a constant
372+ # src[test_specialize.py:N]: out[tile] = x[tile] + stride
373+ _launcher(_helion_fn, (triton.cdiv(x.size(0), _BLOCK_SIZE_0) * triton.cdiv(x.size(1), _BLOCK_SIZE_1),), x, out, x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1)
374+ # src[test_specialize.py:N]: return out
375+ return out
376+
377+ --- assertExpectedJournal(TestSpecialize.test_specialize_stride_tuple)
378+ from __future__ import annotations
379+
380+ import torch
381+ import triton
382+ import triton.language as tl
383+ from helion.runtime import default_launcher as _default_launcher
384+
385+ @triton.jit
386+ def _helion_fn(x, out, x_size_0, x_size_1, out_stride_0, out_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
387+ # src[test_specialize.py:N]: for tile in hl.tile(x.size()):
388+ num_blocks_0 = tl.cdiv(x_size_0, _BLOCK_SIZE_0)
389+ pid_0 = tl.program_id(0) % num_blocks_0
390+ pid_1 = tl.program_id(0) // num_blocks_0
391+ offset_0 = pid_0 * _BLOCK_SIZE_0
392+ indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
393+ mask_0 = indices_0 < x_size_0
394+ offset_1 = pid_1 * _BLOCK_SIZE_1
395+ indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
396+ mask_1 = indices_1 < x_size_1
397+ # src[test_specialize.py:N]: out[tile] = x[tile] + stride0 + stride1
398+ load = tl.load(x + (indices_0[:, None] * 311 + indices_1[None, :] * 131), mask_0[:, None] & mask_1[None, :], other=0)
399+ v_0 = 311.0
400+ v_1 = load + v_0
401+ v_2 = 131.0
402+ v_3 = v_1 + v_2
403+ tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_3, mask_0[:, None] & mask_1[None, :])
404+
405+ def fn(x: torch.Tensor, *, _launcher=_default_launcher):
406+ # src[test_specialize.py:N]: stride0, stride1 = hl.specialize((x.stride(0), x.stride(1)))
407+ stride0, stride1 = (311, 131)
408+ # src[test_specialize.py:N]: out = torch.empty_like(x)
409+ out = torch.empty_like(x)
410+ # src[test_specialize.py:N]: for tile in hl.tile(x.size()):
411+ _BLOCK_SIZE_0 = 32
412+ _BLOCK_SIZE_1 = 32
413+ # src[test_specialize.py:N]: for tile in hl.tile(x.size()):
414+ # src[test_specialize.py:N]: out[tile] = x[tile] + stride0 + stride1
415+ _launcher(_helion_fn, (triton.cdiv(x.size(0), _BLOCK_SIZE_0) * triton.cdiv(x.size(1), _BLOCK_SIZE_1),), x, out, x.size(0), x.size(1), out.stride(0), out.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1)
416+ # src[test_specialize.py:N]: return out
417+ return out
418+
338419--- assertExpectedJournal(TestSpecialize.test_specialize_tuple_element)
339420from __future__ import annotations
340421
0 commit comments