
Commit a8369d5

Commit message: up
1 parent 5e26f71 commit a8369d5

7 files changed: +47 additions, -34 deletions

helion/_compiler/compile_environment.py

Lines changed: 10 additions & 0 deletions
@@ -128,6 +128,16 @@ def __init__(
             0  # Track number of loads in all device code for eviction policy tuning
         )
 
+    def specialize_expr(self, expr: sympy.Expr) -> sympy.Expr:
+        """Substitute any specialized vars with their concrete values."""
+        if subs := {
+            s: sympy.Integer(self.shape_env.size_hint(s))
+            for s in expr.free_symbols & self.specialized_vars
+        }:
+            # pyrefly: ignore [bad-assignment]
+            expr = expr.xreplace(subs)
+        return expr
+
     def add_kernel_tensor_size(self, sizes: Sequence[int | torch.SymInt]) -> None:
         from .device_function import contains_only_block_size_symbols

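For reference, the new helper boils down to a conditional sympy xreplace over the intersection of the expression's free symbols and the specialized set. A minimal standalone sketch of that pattern follows; the symbol names and the size-hint mapping are made up for illustration, whereas the real method pulls hints from shape_env.size_hint:

    import sympy

    # Hypothetical symbolic sizes; s1 has been marked as specialized.
    s0, s1 = sympy.symbols("s0 s1", integer=True)
    specialized_vars = {s1}
    size_hints = {s1: 64}  # concrete value the shape env would report

    expr = s0 * s1 + 4

    # Same shape as specialize_expr: substitute only the symbols that are both
    # free in the expression and marked as specialized.
    if subs := {
        s: sympy.Integer(size_hints[s])
        for s in expr.free_symbols & specialized_vars
    }:
        expr = expr.xreplace(subs)

    print(expr)  # 64*s0 + 4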
helion/_compiler/device_function.py

Lines changed: 3 additions & 1 deletion
@@ -374,7 +374,8 @@ def set_pid(self, pid: ProgramIDs) -> None:
         self.pid = pid
 
     def sympy_expr(self, expr: sympy.Expr) -> str:
-        expr = CompileEnvironment.current().shape_env.simplify(expr)
+        env = CompileEnvironment.current()
+        expr = env.specialize_expr(env.shape_env.simplify(expr))
         if not expr.free_symbols:
             return texpr(expr)
         if expr in self.expr_to_var_info:
@@ -394,6 +395,7 @@ def sympy_expr(self, expr: sympy.Expr) -> str:
             replacements[sym] = sympy.Symbol(
                 self._lift_sympy_arg(sym), integer=True
             )
+        # pyrefly: ignore [bad-argument-type]
         return texpr(expr.xreplace(replacements))
 
     def _lift_sympy_arg(self, expr: sympy.Expr) -> str:

helion/_compiler/host_function.py

Lines changed: 5 additions & 1 deletion
@@ -191,14 +191,18 @@ def set_local_types(self, local_types: dict[str, TypeInfo]) -> None:
             type_info.populate_symbol_origins(NameOrigin(name, fn))
 
     def sympy_expr(self, expr: sympy.Expr) -> str:
-        expr = CompileEnvironment.current().shape_env.simplify(expr)
+        env = CompileEnvironment.current()
+        expr = env.specialize_expr(env.shape_env.simplify(expr))
+        if not expr.free_symbols:
+            return pexpr(expr)
         if expr in self.expr_to_origin:
             return self.expr_to_origin[expr].origin.host_str()
         replacements = {}
         for sym in sorted(expr.free_symbols, key=lambda x: x.name):
             assert isinstance(sym, sympy.Symbol)
             origin = self.expr_to_origin[sym].origin
             replacements[sym] = sympy.Symbol(origin.host_str(), integer=True)
+        # pyrefly: ignore [bad-argument-type]
         return pexpr(expr.xreplace(replacements))
 
     def literal_expr(self, expr: object) -> str:

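Worth noting for the new early return in host_function.sympy_expr: once specialization has run, an expression can lose all of its free symbols and collapse to a plain sympy constant, which has no entry in expr_to_origin, so it has to be printed directly. A tiny sympy-only sketch of that collapse (the symbol name and value are illustrative):

    import sympy

    s77 = sympy.Symbol("s77", integer=True)
    expr = 2 * s77 + 1

    # After the specialized symbol is replaced, nothing symbolic remains,
    # so a printer can emit the value without any origin lookup.
    specialized = expr.xreplace({s77: sympy.Integer(64)})
    assert not specialized.free_symbols
    print(specialized)  # 129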
helion/_testing.py

Lines changed: 8 additions & 0 deletions
@@ -511,6 +511,14 @@ def assertNotIn(
         if not self._in_ref_eager_mode:
             super().assertNotIn(member, container, msg)  # type: ignore[misc]
 
+    def assertIs(self, expr1: object, expr2: object, msg: str | None = None) -> None:
+        if not self._in_ref_eager_mode:
+            super().assertIs(expr1, expr2, msg)  # type: ignore[misc]
+
+    def assertIsNot(self, expr1: object, expr2: object, msg: str | None = None) -> None:
+        if not self._in_ref_eager_mode:
+            super().assertIsNot(expr1, expr2, msg)  # type: ignore[misc]
+
     def assertTrueIfInNormalMode(self, condition: bool, msg: str | None = None) -> None:
         if not self._in_ref_eager_mode:
             self.assertTrue(condition, msg)  # type: ignore[attr-defined]

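The two new overrides follow the same pattern as the surrounding assertNotIn / assertTrueIfInNormalMode wrappers: in ref-eager mode the assertion is skipped rather than forwarded to unittest. A rough sketch of that mixin pattern (the class names and the hard-coded _in_ref_eager_mode flag are illustrative, not the actual helion._testing classes):

    import unittest


    class RefEagerSkipMixin:
        """Illustrative mixin: skip code-inspection asserts in ref-eager mode."""

        _in_ref_eager_mode = True  # the real base class derives this from settings

        def assertIs(self, expr1, expr2, msg=None):
            if not self._in_ref_eager_mode:
                super().assertIs(expr1, expr2, msg)


    class ExampleTest(RefEagerSkipMixin, unittest.TestCase):
        def test_assert_skipped(self):
            # Would fail on a plain TestCase, but is a no-op under the mixin
            # because _in_ref_eager_mode is True.
            self.assertIs(object(), object())


    if __name__ == "__main__":
        unittest.main()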
helion/runtime/kernel.py

Lines changed: 21 additions & 0 deletions
@@ -403,6 +403,9 @@ def __init__(
                 constexpr_args[name] = arg
             else:
                 self.fake_args.append(self.env.to_fake(arg, ArgumentOrigin(name)))
+
+        self._apply_mark_static(args)
+
         with (
             _maybe_skip_dtype_check_in_meta_registrations(),
             patch_inductor_lowerings(),
@@ -420,6 +423,20 @@ def __init__(
             self.maybe_log_repro(log.warning, args, config=config)
             raise
 
+    def _apply_mark_static(self, args: tuple[object, ...]) -> None:
+        """
+        Apply torch._dynamo.mark_static() markings from input tensors.
+
+        This reads _dynamo_static_indices from each tensor argument and marks
+        the corresponding dimensions as specialized (constant) in the kernel.
+        """
+        for arg, fake_arg in zip(args, self.fake_args, strict=True):
+            if isinstance(arg, torch.Tensor) and isinstance(fake_arg, torch.Tensor):
+                for dim in getattr(arg, "_dynamo_static_indices", ()):
+                    size = fake_arg.size(dim)
+                    if isinstance(size, torch.SymInt):
+                        self.env.specialized_vars.update(size._sympy_().free_symbols)
+
     @property
     def settings(self) -> Settings:
         """
@@ -891,12 +908,14 @@ def kernel(
 def _tensor_key(fn: Kernel, obj: torch.Tensor) -> Hashable:
     # NOTE: If a machine has two different gpu types on the same machine,
     # obj.device.type will incorrectly hit
+    static_indices = frozenset(getattr(obj, "_dynamo_static_indices", ()))
     if fn.settings.static_shapes:
         return (
             obj.dtype,
             obj.device.type,
             (*obj.size(),),
             (*obj.stride(),),
+            static_indices,
         )
     bucketed = tuple([min(s, 2) for s in obj.size()])
     if fn.settings.index_dtype is None:
@@ -909,11 +928,13 @@ def _tensor_key(fn: Kernel, obj: torch.Tensor) -> Hashable:
             obj.device.type,
             bucketed,
             needs_int64,
+            static_indices,
         )
     return (
         obj.dtype,
         obj.device.type,
         bucketed,
+        static_indices,
     )

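From the user side, the hook is torch._dynamo.mark_static(tensor, dim), which records the dim in the tensor's _dynamo_static_indices set; _apply_mark_static then adds the matching symbolic size to env.specialized_vars, and _tensor_key includes the marked dims so marked and unmarked calls do not share a cached specialization. A rough usage sketch follows; the add_one kernel is illustrative only and assumes the usual helion.kernel / hl.tile API and a CUDA device:

    import torch
    import helion
    import helion.language as hl


    @helion.kernel()  # illustrative kernel, not part of this commit
    def add_one(x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        for tile in hl.tile(x.size(0)):
            out[tile] = x[tile] + 1
        return out


    x = torch.randn(1024, device="cuda")

    # Record dim 0 as static on the input; with this commit the bound kernel
    # reads x._dynamo_static_indices, specializes that dim's symbol, and folds
    # it to a constant in the generated code.
    torch._dynamo.mark_static(x, 0)

    result = add_one(x)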
test/test_examples.expected

Lines changed: 0 additions & 16 deletions
@@ -460,27 +460,11 @@ def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor, *, _la
     _RDIM_SIZE_2 = 64
     # src[attention.py:N]: m_i = hl.full([tile_b, tile_m], float("-inf"), dtype=torch.float32)
     _BLOCK_SIZE_0 = 1
-    # src[attention.py:N]: q = q_view[tile_b, tile_m, :]
-    _SHAPE_DIM = q_in.size(3)
-    _SHAPE_DIM_1 = q_in.size(3)
-    _SHAPE_DIM_2 = q_in.size(3)
     # src[attention.py:N]: for tile_n in hl.tile(v_view.size(1)):
     # src[attention.py:N]: k = k_view[tile_b, :, tile_n]
     # src[attention.py:N]: qk = torch.bmm(q, k)
     # src[attention.py:N-N]: ...
     _BLOCK_SIZE_3 = 32
-    # src[attention.py:N]: k = k_view[tile_b, :, tile_n]
-    _SHAPE_DIM_3 = q_in.size(3)
-    _SHAPE_DIM_4 = q_in.size(3)
-    _SHAPE_DIM_5 = q_in.size(3)
-    # src[attention.py:N]: v = v_view[tile_b, tile_n, :]
-    _SHAPE_DIM_6 = q_in.size(3)
-    _SHAPE_DIM_7 = q_in.size(3)
-    _SHAPE_DIM_8 = q_in.size(3)
-    # src[attention.py:N]: out[tile_b, tile_m, :] = acc.to(out.dtype)
-    _SHAPE_DIM_9 = q_in.size(3)
-    _SHAPE_DIM_10 = q_in.size(3)
-    _SHAPE_DIM_11 = q_in.size(3)
     # src[attention.py:N]: for tile_b, tile_m in hl.tile([q_view.size(0), m_dim]):
     # src[attention.py:N]: m_i = hl.full([tile_b, tile_m], float("-inf"), dtype=torch.float32)
     # src[attention.py:N]: l_i = torch.full_like(m_i, 1.0)

test/test_tensor_descriptor.expected

Lines changed: 0 additions & 16 deletions
@@ -123,27 +123,11 @@ def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor, *, _la
     _RDIM_SIZE_2 = 64
     # src[attention.py:N]: m_i = hl.full([tile_b, tile_m], float("-inf"), dtype=torch.float32)
     _BLOCK_SIZE_0 = 1
-    # src[attention.py:N]: q = q_view[tile_b, tile_m, :]
-    _SHAPE_DIM = q_in.size(3)
-    _SHAPE_DIM_1 = q_in.size(3)
-    _SHAPE_DIM_2 = q_in.size(3)
     # src[attention.py:N]: for tile_n in hl.tile(v_view.size(1)):
     # src[attention.py:N]: k = k_view[tile_b, :, tile_n]
     # src[attention.py:N]: qk = torch.bmm(q, k)
     # src[attention.py:N-N]: ...
     _BLOCK_SIZE_3 = 16
-    # src[attention.py:N]: k = k_view[tile_b, :, tile_n]
-    _SHAPE_DIM_3 = q_in.size(3)
-    _SHAPE_DIM_4 = q_in.size(3)
-    _SHAPE_DIM_5 = q_in.size(3)
-    # src[attention.py:N]: v = v_view[tile_b, tile_n, :]
-    _SHAPE_DIM_6 = q_in.size(3)
-    _SHAPE_DIM_7 = q_in.size(3)
-    _SHAPE_DIM_8 = q_in.size(3)
-    # src[attention.py:N]: out[tile_b, tile_m, :] = acc.to(out.dtype)
-    _SHAPE_DIM_9 = q_in.size(3)
-    _SHAPE_DIM_10 = q_in.size(3)
-    _SHAPE_DIM_11 = q_in.size(3)
     # src[attention.py:N]: for tile_b, tile_m in hl.tile([q_view.size(0), m_dim]):
     # src[attention.py:N]: m_i = hl.full([tile_b, tile_m], float("-inf"), dtype=torch.float32)
     # src[attention.py:N]: l_i = torch.full_like(m_i, 1.0)
