
Commit 4320adc

Commit message: up
1 parent: 069ae43

8 files changed, +232 -37 lines changed


helion/_compiler/compile_environment.py

Lines changed: 10 additions & 0 deletions
@@ -127,6 +127,16 @@ def __init__(
             0  # Track number of loads in all device code for eviction policy tuning
         )

+    def specialize_expr(self, expr: sympy.Expr) -> sympy.Expr:
+        """Substitute any specialized vars with their concrete values."""
+        if subs := {
+            s: sympy.Integer(self.shape_env.size_hint(s))
+            for s in expr.free_symbols & self.specialized_vars
+        }:
+            # pyrefly: ignore [bad-assignment]
+            expr = expr.xreplace(subs)
+        return expr
+
     def add_kernel_tensor_size(self, sizes: Sequence[int | torch.SymInt]) -> None:
         from .device_function import contains_only_block_size_symbols
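
The new specialize_expr helper above is targeted sympy substitution: any symbol the compile environment has marked as specialized is replaced by the shape env's concrete size hint, while other symbols stay symbolic. A minimal sketch of the same pattern in plain sympy (the s0/s1 symbols and the size-hint dict are illustrative stand-ins, not Helion's actual shape environment):

import sympy

# Hypothetical shape symbols and their recorded size hints.
s0, s1 = sympy.symbols("s0 s1", integer=True)
specialized_vars = {s0}   # only s0 has been marked as specialized
size_hints = {s0: 64}     # concrete value recorded for s0

expr = s0 * s1 + 3

# Mirror of specialize_expr: substitute only the specialized symbols.
if subs := {s: sympy.Integer(size_hints[s]) for s in expr.free_symbols & specialized_vars}:
    expr = expr.xreplace(subs)

print(expr)               # 64*s1 + 3 -- s1 stays symbolic
print(expr.free_symbols)  # {s1}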

helion/_compiler/device_function.py

Lines changed: 3 additions & 1 deletion
@@ -373,7 +373,8 @@ def set_pid(self, pid: ProgramIDs) -> None:
         self.pid = pid

     def sympy_expr(self, expr: sympy.Expr) -> str:
-        expr = CompileEnvironment.current().shape_env.simplify(expr)
+        env = CompileEnvironment.current()
+        expr = env.specialize_expr(env.shape_env.simplify(expr))
         if not expr.free_symbols:
             return texpr(expr)
         if expr in self.expr_to_var_info:
@@ -393,6 +394,7 @@ def sympy_expr(self, expr: sympy.Expr) -> str:
                 replacements[sym] = sympy.Symbol(
                     self._lift_sympy_arg(sym), integer=True
                 )
+        # pyrefly: ignore [bad-argument-type]
         return texpr(expr.xreplace(replacements))

     def _lift_sympy_arg(self, expr: sympy.Expr) -> str:

helion/_compiler/host_function.py

Lines changed: 5 additions & 1 deletion
@@ -191,14 +191,18 @@ def set_local_types(self, local_types: dict[str, TypeInfo]) -> None:
             type_info.populate_symbol_origins(NameOrigin(name, fn))

     def sympy_expr(self, expr: sympy.Expr) -> str:
-        expr = CompileEnvironment.current().shape_env.simplify(expr)
+        env = CompileEnvironment.current()
+        expr = env.specialize_expr(env.shape_env.simplify(expr))
+        if not expr.free_symbols:
+            return pexpr(expr)
         if expr in self.expr_to_origin:
             return self.expr_to_origin[expr].origin.host_str()
         replacements = {}
         for sym in sorted(expr.free_symbols, key=lambda x: x.name):
             assert isinstance(sym, sympy.Symbol)
             origin = self.expr_to_origin[sym].origin
             replacements[sym] = sympy.Symbol(origin.host_str(), integer=True)
+        # pyrefly: ignore [bad-argument-type]
         return pexpr(expr.xreplace(replacements))

     def literal_expr(self, expr: object) -> str:
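
Both sympy_expr paths (device_function.py above and host_function.py here) now route expressions through specialize_expr before printing, so an expression whose symbols are all specialized collapses to a constant and takes the early `not expr.free_symbols` exit instead of being lifted into a host variable or kernel argument. A rough illustration of that short-circuit using sympy directly, with str() standing in for the texpr/pexpr printers:

import sympy

s_head_dim = sympy.Symbol("s_head_dim", integer=True)
expr = 4 * s_head_dim

# Unspecialized: a free symbol remains, so the real code would lift it
# into a host variable / kernel argument before printing.
print(expr.free_symbols)            # {s_head_dim}

# Specialized (hint 64): the expression folds to a plain integer and can
# be printed directly via the early return.
folded = expr.xreplace({s_head_dim: sympy.Integer(64)})
print(folded.free_symbols, folded)  # set() 256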

helion/_testing.py

Lines changed: 8 additions & 0 deletions
@@ -499,6 +499,14 @@ def assertNotIn(
         if not self._in_ref_eager_mode:
             super().assertNotIn(member, container, msg)  # type: ignore[misc]

+    def assertIs(self, expr1: object, expr2: object, msg: str | None = None) -> None:
+        if not self._in_ref_eager_mode:
+            super().assertIs(expr1, expr2, msg)  # type: ignore[misc]
+
+    def assertIsNot(self, expr1: object, expr2: object, msg: str | None = None) -> None:
+        if not self._in_ref_eager_mode:
+            super().assertIsNot(expr1, expr2, msg)  # type: ignore[misc]
+
     def assertTrueIfInNormalMode(self, condition: bool, msg: str | None = None) -> None:
         if not self._in_ref_eager_mode:
             self.assertTrue(condition, msg)  # type: ignore[attr-defined]

helion/runtime/kernel.py

Lines changed: 67 additions & 3 deletions
@@ -136,6 +136,10 @@ def __init__(
         self._specialize_extra: dict[
             Hashable, list[Callable[[Sequence[object]], Hashable]]
         ] = {}
+        self._specialized_args: dict[int, tuple[int, ...]] = {}
+        self._arg_name_to_index: dict[str, int] = {
+            name: i for i, name in enumerate(self.signature.parameters.keys())
+        }
         if any(
             param.kind
             in (
@@ -346,6 +350,51 @@ def reset(self) -> None:
         """
         self._bound_kernels.clear()

+    def specialize_args(self, **kwargs: list[int]) -> Kernel[_R]:
+        """
+        Returns a kernel that will specialize on the given argument dimensions.
+        This allows specialization decisions to be made outside the kernel,
+        binding to argument names via kwargs.
+
+        Args:
+            **kwargs: Mapping of argument name -> dims to specialize on
+                e.g., specialize_args(q_in=[-1], k_in=[-1])
+
+        Returns:
+            Kernel: A new kernel with same settings and configs, adding the given
+                specializations to any existing ones.
+
+        Example:
+            @helion.kernel
+            def attention(q_in, k_in, v_in):
+                head_dim = q_in.size(0)  # Specialized if specified externally
+                seq_len = k_in.size(1)  # Specialized if specified externally
+                ...
+
+            result = attention.specialize_args(q_in=[0], k_in=[1])(q, k, v)
+        """
+        if not kwargs:
+            return self
+        try:
+            specialized_args = {
+                self._arg_name_to_index[name]: tuple(dims)
+                for name, dims in kwargs.items()
+            }
+        except KeyError as e:
+            valid_args = ", ".join(self._arg_name_to_index.keys())
+            raise ValueError(
+                f"Unknown argument '{e.args[0]}' for kernel '{self.name}'. Valid arguments: {valid_args}"
+            ) from e
+
+        specialized = Kernel(
+            self.fn,
+            configs=list(self.configs),
+            settings=self.settings,
+            key=self._key_fn,
+        )
+        specialized._specialized_args = {**self._specialized_args, **specialized_args}
+        return specialized
+

 class BoundKernel(Generic[_R]):
     def __init__(
@@ -403,6 +452,10 @@ def __init__(
                 constexpr_args[name] = arg
             else:
                 self.fake_args.append(self.env.to_fake(arg, ArgumentOrigin(name)))
+
+        if kernel._specialized_args:
+            self._apply_specialized_args(kernel._specialized_args)
+
         with (
             _maybe_skip_dtype_check_in_meta_registrations(),
             patch_inductor_lowerings(),
@@ -420,6 +473,18 @@ def __init__(
             self.maybe_log_repro(log.warning, args, config=config)
             raise

+    def _apply_specialized_args(
+        self, specialized_args: dict[int, tuple[int, ...]]
+    ) -> None:
+        for arg_idx, dims in specialized_args.items():
+            fake_tensor = self.fake_args[arg_idx]
+            if isinstance(fake_tensor, torch.Tensor):
+                for dim in dims:
+                    size = fake_tensor.size(dim)
+                    if isinstance(size, torch.SymInt):
+                        sym_expr = size._sympy_()
+                        self.env.specialized_vars.update(sym_expr.free_symbols)
+
     @property
     def settings(self) -> Settings:
         """
@@ -622,6 +687,8 @@ def _specialize_extra(self) -> list[Callable[[Sequence[object]], Hashable]]:
         if not self.env.specialized_vars:
             return []

+        arg_name_to_index = self.kernel._arg_name_to_index
+
         def make_extractor(v: Source) -> Callable[[Sequence[object]], Hashable]:
             if isinstance(v, TensorPropertySource):
                 assert v.prop == TensorProperty.SIZE
@@ -635,9 +702,6 @@ def make_extractor(v: Source) -> Callable[[Sequence[object]], Hashable]:
                 return operator.itemgetter(index)
             raise exc.SpecializeArgType(v)

-        arg_name_to_index: dict[str, int] = {
-            n: i for i, n in enumerate(self.kernel.signature.parameters.keys())
-        }
         extractors = []
         for v in sorted(self.env.specialized_vars, key=lambda v: v.name):
             source = self.env.shape_env.var_to_sources[v][0]
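
A short usage sketch of the new specialize_args API; the kernel body, shapes, and device are illustrative and loosely based on the tests below, not copied from them:

import torch
import helion
import helion.language as hl

@helion.kernel
def scale(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    for tile in hl.tile(x.size()):
        out[tile] = x[tile] * 2
    return out

x = torch.randn(320, 640, device="cuda")

# Specialize on x's dim 0 without editing the kernel source; the returned
# Kernel shares settings/configs and bakes size(0) into the generated code.
scale_specialized = scale.specialize_args(x=[0])
result = scale_specialized(x)

# Calls chain: each call returns a new Kernel that adds to the existing
# specializations, and unknown argument names raise ValueError.
scale_both = scale_specialized.specialize_args(x=[1])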

test/test_examples.expected

Lines changed: 0 additions & 16 deletions
@@ -460,27 +460,11 @@ def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor, *, _la
     _RDIM_SIZE_2 = 64
     # src[attention.py:N]: m_i = hl.full([tile_b, tile_m], float("-inf"), dtype=torch.float32)
     _BLOCK_SIZE_0 = 1
-    # src[attention.py:N]: q = q_view[tile_b, tile_m, :]
-    _SHAPE_DIM = q_in.size(3)
-    _SHAPE_DIM_1 = q_in.size(3)
-    _SHAPE_DIM_2 = q_in.size(3)
     # src[attention.py:N]: for tile_n in hl.tile(v_view.size(1)):
     # src[attention.py:N]: k = k_view[tile_b, :, tile_n]
     # src[attention.py:N]: qk = torch.bmm(q, k)
     # src[attention.py:N-N]: ...
     _BLOCK_SIZE_3 = 32
-    # src[attention.py:N]: k = k_view[tile_b, :, tile_n]
-    _SHAPE_DIM_3 = q_in.size(3)
-    _SHAPE_DIM_4 = q_in.size(3)
-    _SHAPE_DIM_5 = q_in.size(3)
-    # src[attention.py:N]: v = v_view[tile_b, tile_n, :]
-    _SHAPE_DIM_6 = q_in.size(3)
-    _SHAPE_DIM_7 = q_in.size(3)
-    _SHAPE_DIM_8 = q_in.size(3)
-    # src[attention.py:N]: out[tile_b, tile_m, :] = acc.to(out.dtype)
-    _SHAPE_DIM_9 = q_in.size(3)
-    _SHAPE_DIM_10 = q_in.size(3)
-    _SHAPE_DIM_11 = q_in.size(3)
     # src[attention.py:N]: for tile_b, tile_m in hl.tile([q_view.size(0), m_dim]):
     # src[attention.py:N]: m_i = hl.full([tile_b, tile_m], float("-inf"), dtype=torch.float32)
     # src[attention.py:N]: l_i = torch.full_like(m_i, 1.0)

test/test_specialize.expected

Lines changed: 139 additions & 0 deletions
@@ -1132,3 +1132,142 @@ def reduce_kernel(x: torch.Tensor, tensor_factory_fn, test_host, *, _launcher=_d
     _launcher(_helion_reduce_kernel, (triton.cdiv(128, _BLOCK_SIZE_0),), x, grad_weight, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=1)
     # src[test_specialize.py:N]: return grad_weight.sum(0).to(x.dtype)
     return grad_weight.sum(0).to(x.dtype)
+
+--- assertExpectedJournal(TestSpecializeArgs.test_specialize_args)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_matmul(x, y, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
+    # src[test_specialize.py:N]: for tile_m, tile_n in hl.tile([m, n]):
+    num_blocks_0 = tl.cdiv(64, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < 64
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    mask_1 = indices_1 < 56
+    # src[test_specialize.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
+    # src[test_specialize.py:N]: for tile_k in hl.tile(k):
+    # src[test_specialize.py:N]: acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
+    symnode_0 = 128
+    for offset_2 in tl.range(0, symnode_0.to(tl.int32), _BLOCK_SIZE_2):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
+        mask_2 = indices_2 < symnode_0
+        acc_copy = acc
+        acc_copy_0 = acc_copy
+        # src[test_specialize.py:N]: acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
+        load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
+        load_1 = tl.load(y + (indices_2[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_2[:, None] & mask_1[None, :], other=0)
+        acc = tl.dot(tl.cast(load, tl.float16), tl.cast(load_1, tl.float16), acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32)
+    # src[test_specialize.py:N]: out[tile_m, tile_n] = acc.to(x.dtype)
+    v_0 = tl.cast(acc, tl.float16)
+    tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_0, mask_0[:, None] & mask_1[None, :])
+
+def matmul(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
+    # src[test_specialize.py:N]: m, k = x.size()
+    m, k = x.size()
+    # src[test_specialize.py:N]: k2, n = y.size()
+    k2, n = y.size()
+    # src[test_specialize.py:N]: out = torch.empty([m, n], device=x.device, dtype=x.dtype)
+    out = torch.empty([m, n], device=x.device, dtype=x.dtype)
+    # src[test_specialize.py:N]: for tile_m, tile_n in hl.tile([m, n]):
+    _BLOCK_SIZE_0 = 32
+    _BLOCK_SIZE_1 = 32
+    # src[test_specialize.py:N]: for tile_k in hl.tile(k):
+    # src[test_specialize.py:N]: acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
+    _BLOCK_SIZE_2 = 32
+    # src[test_specialize.py:N]: for tile_m, tile_n in hl.tile([m, n]):
+    # src[test_specialize.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+    # src[test_specialize.py:N]: for tile_k in hl.tile(k):
+    # src[test_specialize.py:N-N]: ...
+    _launcher(_helion_matmul, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(56, _BLOCK_SIZE_1),), x, y, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1)
+    # src[test_specialize.py:N]: return out
+    return out
+
+--- assertExpectedJournal(TestSpecializeArgs.test_specialize_args_and_hl_specialize)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_dual_specialize(x, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    # src[test_specialize.py:N]: for tile in hl.tile(x.size()):
+    num_blocks_0 = tl.cdiv(320, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    mask_1 = indices_1 < 640
+    # src[test_specialize.py:N]: out[tile] = x[tile] * 2
+    load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_1[None, :], other=0)
+    v_0 = 2.0
+    v_1 = load * v_0
+    tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_1, mask_1[None, :])
+
+def dual_specialize(x: torch.Tensor, *, _launcher=_default_launcher):
+    # src[test_specialize.py:N]: out = torch.empty_like(x)
+    out = torch.empty_like(x)
+    # src[test_specialize.py:N]: for tile in hl.tile(x.size()):
+    _BLOCK_SIZE_0 = 16
+    _BLOCK_SIZE_1 = 16
+    # src[test_specialize.py:N]: for tile in hl.tile(x.size()):
+    # src[test_specialize.py:N]: out[tile] = x[tile] * 2
+    _launcher(_helion_dual_specialize, (triton.cdiv(320, _BLOCK_SIZE_0) * triton.cdiv(640, _BLOCK_SIZE_1),), x, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1)
+    # src[test_specialize.py:N]: return out
+    return out
+
+--- assertExpectedJournal(TestSpecializeArgs.test_specialize_args_chaining)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_fn(x, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, n, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    # src[test_specialize.py:N]: for tile_m, tile_n in hl.tile([m, n]):
+    num_blocks_0 = tl.cdiv(37, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < 37
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    mask_1 = indices_1 < n
+    # src[test_specialize.py:N]: out[tile_m, tile_n] = x[tile_m, tile_n] * p
+    load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+    symnode_0 = 127
+    v_0 = tl.cast(symnode_0, tl.float32)
+    v_1 = load * v_0
+    tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_1, mask_0[:, None] & mask_1[None, :])
+
+def fn(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
+    # src[test_specialize.py:N]: m, n = x.size()
+    m, n = x.size()
+    # src[test_specialize.py:N]: p = y.size(1)  # use y's dim 1 as a scalar
+    p = y.size(1)
+    # src[test_specialize.py:N]: out = x.new_empty([m, n])
+    out = x.new_empty([m, n])
+    # src[test_specialize.py:N]: for tile_m, tile_n in hl.tile([m, n]):
+    _BLOCK_SIZE_0 = 16
+    _BLOCK_SIZE_1 = 16
+    # src[test_specialize.py:N]: for tile_m, tile_n in hl.tile([m, n]):
+    # src[test_specialize.py:N]: out[tile_m, tile_n] = x[tile_m, tile_n] * p
+    _launcher(_helion_fn, (triton.cdiv(37, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), x, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), n, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1)
+    # src[test_specialize.py:N]: return out
+    return out
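
For orientation, the src[test_specialize.py:N] comments in the first journal above imply a Helion kernel along these lines (a sketch reconstructed from those comments; the actual test source and its specialize_args call may differ in details):

import torch
import helion
import helion.language as hl

@helion.kernel
def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    m, k = x.size()
    k2, n = y.size()
    out = torch.empty([m, n], device=x.device, dtype=x.dtype)
    for tile_m, tile_n in hl.tile([m, n]):
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        for tile_k in hl.tile(k):
            acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
        out[tile_m, tile_n] = acc.to(x.dtype)
    return out

# Because the test specializes on the argument shapes, the 64/56/128 sizes
# appear as literals in the generated Triton code, while strides (and n in
# the chaining test, which is left unspecialized) stay runtime arguments.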

test/test_tensor_descriptor.expected

Lines changed: 0 additions & 16 deletions
@@ -123,27 +123,11 @@ def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor, *, _la
     _RDIM_SIZE_2 = 64
     # src[attention.py:N]: m_i = hl.full([tile_b, tile_m], float("-inf"), dtype=torch.float32)
     _BLOCK_SIZE_0 = 1
-    # src[attention.py:N]: q = q_view[tile_b, tile_m, :]
-    _SHAPE_DIM = q_in.size(3)
-    _SHAPE_DIM_1 = q_in.size(3)
-    _SHAPE_DIM_2 = q_in.size(3)
     # src[attention.py:N]: for tile_n in hl.tile(v_view.size(1)):
     # src[attention.py:N]: k = k_view[tile_b, :, tile_n]
     # src[attention.py:N]: qk = torch.bmm(q, k)
     # src[attention.py:N-N]: ...
     _BLOCK_SIZE_3 = 16
-    # src[attention.py:N]: k = k_view[tile_b, :, tile_n]
-    _SHAPE_DIM_3 = q_in.size(3)
-    _SHAPE_DIM_4 = q_in.size(3)
-    _SHAPE_DIM_5 = q_in.size(3)
-    # src[attention.py:N]: v = v_view[tile_b, tile_n, :]
-    _SHAPE_DIM_6 = q_in.size(3)
-    _SHAPE_DIM_7 = q_in.size(3)
-    _SHAPE_DIM_8 = q_in.size(3)
-    # src[attention.py:N]: out[tile_b, tile_m, :] = acc.to(out.dtype)
-    _SHAPE_DIM_9 = q_in.size(3)
-    _SHAPE_DIM_10 = q_in.size(3)
-    _SHAPE_DIM_11 = q_in.size(3)
     # src[attention.py:N]: for tile_b, tile_m in hl.tile([q_view.size(0), m_dim]):
     # src[attention.py:N]: m_i = hl.full([tile_b, tile_m], float("-inf"), dtype=torch.float32)
     # src[attention.py:N]: l_i = torch.full_like(m_i, 1.0)
