@@ -1132,3 +1132,142 @@ def reduce_kernel(x: torch.Tensor, tensor_factory_fn, test_host, *, _launcher=_d
     _launcher(_helion_reduce_kernel, (triton.cdiv(128, _BLOCK_SIZE_0),), x, grad_weight, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=1)
     # src[test_specialize.py:N]: return grad_weight.sum(0).to(x.dtype)
     return grad_weight.sum(0).to(x.dtype)
+
+--- assertExpectedJournal(TestSpecializeArgs.test_specialize_args)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_matmul(x, y, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
+    # src[test_specialize.py:N]: for tile_m, tile_n in hl.tile([m, n]):
+    num_blocks_0 = tl.cdiv(64, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < 64
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    mask_1 = indices_1 < 56
+    # src[test_specialize.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
+    # src[test_specialize.py:N]: for tile_k in hl.tile(k):
+    # src[test_specialize.py:N]: acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
+    symnode_0 = 128
+    for offset_2 in tl.range(0, symnode_0.to(tl.int32), _BLOCK_SIZE_2):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
+        mask_2 = indices_2 < symnode_0
+        acc_copy = acc
+        acc_copy_0 = acc_copy
+        # src[test_specialize.py:N]: acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
+        load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
+        load_1 = tl.load(y + (indices_2[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_2[:, None] & mask_1[None, :], other=0)
+        acc = tl.dot(tl.cast(load, tl.float16), tl.cast(load_1, tl.float16), acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32)
+    # src[test_specialize.py:N]: out[tile_m, tile_n] = acc.to(x.dtype)
+    v_0 = tl.cast(acc, tl.float16)
+    tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_0, mask_0[:, None] & mask_1[None, :])
+
+def matmul(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
+    # src[test_specialize.py:N]: m, k = x.size()
+    m, k = x.size()
+    # src[test_specialize.py:N]: k2, n = y.size()
+    k2, n = y.size()
+    # src[test_specialize.py:N]: out = torch.empty([m, n], device=x.device, dtype=x.dtype)
+    out = torch.empty([m, n], device=x.device, dtype=x.dtype)
+    # src[test_specialize.py:N]: for tile_m, tile_n in hl.tile([m, n]):
+    _BLOCK_SIZE_0 = 32
+    _BLOCK_SIZE_1 = 32
+    # src[test_specialize.py:N]: for tile_k in hl.tile(k):
+    # src[test_specialize.py:N]: acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
+    _BLOCK_SIZE_2 = 32
+    # src[test_specialize.py:N]: for tile_m, tile_n in hl.tile([m, n]):
+    # src[test_specialize.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+    # src[test_specialize.py:N]: for tile_k in hl.tile(k):
+    # src[test_specialize.py:N-N]: ...
+    _launcher(_helion_matmul, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(56, _BLOCK_SIZE_1),), x, y, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1)
+    # src[test_specialize.py:N]: return out
+    return out
+
+--- assertExpectedJournal(TestSpecializeArgs.test_specialize_args_and_hl_specialize)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_dual_specialize(x, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    # src[test_specialize.py:N]: for tile in hl.tile(x.size()):
+    num_blocks_0 = tl.cdiv(320, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    mask_1 = indices_1 < 640
+    # src[test_specialize.py:N]: out[tile] = x[tile] * 2
+    load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_1[None, :], other=0)
+    v_0 = 2.0
+    v_1 = load * v_0
+    tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_1, mask_1[None, :])
+
+def dual_specialize(x: torch.Tensor, *, _launcher=_default_launcher):
+    # src[test_specialize.py:N]: out = torch.empty_like(x)
+    out = torch.empty_like(x)
+    # src[test_specialize.py:N]: for tile in hl.tile(x.size()):
+    _BLOCK_SIZE_0 = 16
+    _BLOCK_SIZE_1 = 16
+    # src[test_specialize.py:N]: for tile in hl.tile(x.size()):
+    # src[test_specialize.py:N]: out[tile] = x[tile] * 2
+    _launcher(_helion_dual_specialize, (triton.cdiv(320, _BLOCK_SIZE_0) * triton.cdiv(640, _BLOCK_SIZE_1),), x, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1)
+    # src[test_specialize.py:N]: return out
+    return out
+
+--- assertExpectedJournal(TestSpecializeArgs.test_specialize_args_chaining)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_fn(x, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, n, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    # src[test_specialize.py:N]: for tile_m, tile_n in hl.tile([m, n]):
+    num_blocks_0 = tl.cdiv(37, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < 37
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    mask_1 = indices_1 < n
+    # src[test_specialize.py:N]: out[tile_m, tile_n] = x[tile_m, tile_n] * p
+    load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+    symnode_0 = 127
+    v_0 = tl.cast(symnode_0, tl.float32)
+    v_1 = load * v_0
+    tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_1, mask_0[:, None] & mask_1[None, :])
+
+def fn(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
+    # src[test_specialize.py:N]: m, n = x.size()
+    m, n = x.size()
+    # src[test_specialize.py:N]: p = y.size(1) # use y's dim 1 as a scalar
+    p = y.size(1)
+    # src[test_specialize.py:N]: out = x.new_empty([m, n])
+    out = x.new_empty([m, n])
+    # src[test_specialize.py:N]: for tile_m, tile_n in hl.tile([m, n]):
+    _BLOCK_SIZE_0 = 16
+    _BLOCK_SIZE_1 = 16
+    # src[test_specialize.py:N]: for tile_m, tile_n in hl.tile([m, n]):
+    # src[test_specialize.py:N]: out[tile_m, tile_n] = x[tile_m, tile_n] * p
+    _launcher(_helion_fn, (triton.cdiv(37, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), x, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), n, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1)
+    # src[test_specialize.py:N]: return out
+    return out
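
For reference, a minimal sketch of the Helion source behind the chaining journal above, reconstructed only from its src[test_specialize.py:N] annotations; the decorator, signature annotations, and any specialization arguments are assumptions, since the journal does not show them.

import torch
import helion
import helion.language as hl

@helion.kernel()  # assumed decorator; specialization/config options not visible in the journal
def fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    m, n = x.size()
    p = y.size(1)  # use y's dim 1 as a scalar
    out = x.new_empty([m, n])
    # each (tile_m, tile_n) becomes one Triton program in the generated _helion_fn kernel
    for tile_m, tile_n in hl.tile([m, n]):
        out[tile_m, tile_n] = x[tile_m, tile_n] * p
    return out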