This file is automatically generated by assertExpectedJournal calls in test_specialize.py.
Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set.

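For example, assuming the suite is driven with pytest, the journal can be regenerated with:

    EXPECTTEST_ACCEPT=1 python -m pytest test_specialize.py
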
+--- assertExpectedJournal(TestMarkStatic.test_mark_static)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_matmul(x, y, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
+    # src[test_specialize.py:N]: for tile_m, tile_n in hl.tile([m, n]):
+    num_blocks_0 = tl.cdiv(64, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < 64
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    mask_1 = indices_1 < 56
+    # src[test_specialize.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
+    # src[test_specialize.py:N]: for tile_k in hl.tile(k):
+    # src[test_specialize.py:N]: acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
+    symnode_0 = 128
+    for offset_2 in tl.range(0, symnode_0, _BLOCK_SIZE_2):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
+        mask_2 = indices_2 < symnode_0
+        acc_copy = acc
+        acc_copy_0 = acc_copy
+        # src[test_specialize.py:N]: acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
+        load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
+        load_1 = tl.load(y + (indices_2[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_2[:, None] & mask_1[None, :], other=0)
+        acc = tl.dot(tl.cast(load, tl.float16), tl.cast(load_1, tl.float16), acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32)
+    # src[test_specialize.py:N]: out[tile_m, tile_n] = acc.to(x.dtype)
+    v_0 = tl.cast(acc, tl.float16)
+    tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_0, mask_0[:, None] & mask_1[None, :])
+
+def matmul(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
+    # src[test_specialize.py:N]: m, k = x.size()
+    m, k = x.size()
+    # src[test_specialize.py:N]: k2, n = y.size()
+    k2, n = y.size()
+    # src[test_specialize.py:N]: out = torch.empty([m, n], device=x.device, dtype=x.dtype)
+    out = torch.empty([m, n], device=x.device, dtype=x.dtype)
+    # src[test_specialize.py:N]: for tile_m, tile_n in hl.tile([m, n]):
+    _BLOCK_SIZE_0 = 32
+    _BLOCK_SIZE_1 = 32
+    # src[test_specialize.py:N]: for tile_k in hl.tile(k):
+    # src[test_specialize.py:N]: acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
+    _BLOCK_SIZE_2 = 32
+    # src[test_specialize.py:N]: for tile_m, tile_n in hl.tile([m, n]):
+    # src[test_specialize.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+    # src[test_specialize.py:N]: for tile_k in hl.tile(k):
+    # src[test_specialize.py:N-N]: ...
+    _launcher(_helion_matmul, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(56, _BLOCK_SIZE_1),), x, y, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1)
+    # src[test_specialize.py:N]: return out
+    return out
+
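A minimal usage sketch for the entry above, assuming a CUDA device and float16 inputs of
the shapes this kernel was specialized for (64, 56, and 128 are baked into the generated
grid, masks, and loop bound, so other shapes would require recompilation):

    x = torch.randn([64, 128], device="cuda", dtype=torch.float16)
    y = torch.randn([128, 56], device="cuda", dtype=torch.float16)
    out = matmul(x, y)  # (64, 56) float16 result, accumulated in float32
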
+--- assertExpectedJournal(TestMarkStatic.test_mark_static_and_hl_specialize)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_fn(x, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    # src[test_specialize.py:N]: for tile in hl.tile(x.size()):
+    num_blocks_0 = tl.cdiv(320, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    mask_1 = indices_1 < 640
+    # src[test_specialize.py:N]: out[tile] = x[tile] * 2
+    load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_1[None, :], other=0)
+    v_0 = 2.0
+    v_1 = load * v_0
+    tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_1, mask_1[None, :])
+
+def fn(x: torch.Tensor, *, _launcher=_default_launcher):
+    # src[test_specialize.py:N]: out = torch.empty_like(x)
+    out = torch.empty_like(x)
+    # src[test_specialize.py:N]: for tile in hl.tile(x.size()):
+    _BLOCK_SIZE_0 = 16
+    _BLOCK_SIZE_1 = 16
+    # src[test_specialize.py:N]: for tile in hl.tile(x.size()):
+    # src[test_specialize.py:N]: out[tile] = x[tile] * 2
+    _launcher(_helion_fn, (triton.cdiv(320, _BLOCK_SIZE_0) * triton.cdiv(640, _BLOCK_SIZE_1),), x, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1)
+    # src[test_specialize.py:N]: return out
+    return out
+
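A minimal usage sketch for the entry above, assuming a CUDA device and the specialized
320x640 input (both extents are hard-coded into the generated grid and mask):

    x = torch.randn([320, 640], device="cuda")
    out = fn(x)  # elementwise x * 2, same shape and dtype as x
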
--- assertExpectedJournal(TestSpecialize.test_dynamic_size_block_non_power_of_two)
from __future__ import annotations
