Skip to content

Commit bc2a3c6

Browse files
committed
wip
1 parent 65e1146 commit bc2a3c6

File tree

5 files changed

+49
-2
lines changed

5 files changed

+49
-2
lines changed

helion/_compiler/compile_environment.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,16 @@ def __init__(
127127
0 # Track number of loads in all device code for eviction policy tuning
128128
)
129129

130+
def specialize_expr(self, expr: sympy.Expr) -> sympy.Expr:
    """Return ``expr`` with its specialized symbols replaced by size hints.

    Only the free symbols of ``expr`` that also appear in
    ``self.specialized_vars`` are substituted; each is mapped to
    ``sympy.Integer(self.shape_env.size_hint(symbol))``. When ``expr`` has
    no such symbols it is returned unchanged.
    """
    specialized = expr.free_symbols & self.specialized_vars
    if specialized:
        replacements = {
            symbol: sympy.Integer(self.shape_env.size_hint(symbol))
            for symbol in specialized
        }
        # pyrefly: ignore [bad-assignment]
        expr = expr.xreplace(replacements)
    return expr
139+
130140
def add_kernel_tensor_size(self, sizes: Sequence[int | torch.SymInt]) -> None:
131141
from .device_function import contains_only_block_size_symbols
132142

helion/_compiler/device_function.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,8 @@ def set_pid(self, pid: ProgramIDs) -> None:
373373
self.pid = pid
374374

375375
def sympy_expr(self, expr: sympy.Expr) -> str:
376-
expr = CompileEnvironment.current().shape_env.simplify(expr)
376+
env = CompileEnvironment.current()
377+
expr = env.specialize_expr(env.shape_env.simplify(expr))
377378
if not expr.free_symbols:
378379
return texpr(expr)
379380
if expr in self.expr_to_var_info:

helion/_compiler/host_function.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,10 @@ def set_local_types(self, local_types: dict[str, TypeInfo]) -> None:
191191
type_info.populate_symbol_origins(NameOrigin(name, fn))
192192

193193
def sympy_expr(self, expr: sympy.Expr) -> str:
194-
expr = CompileEnvironment.current().shape_env.simplify(expr)
194+
env = CompileEnvironment.current()
195+
expr = env.specialize_expr(env.shape_env.simplify(expr))
196+
if not expr.free_symbols:
197+
return pexpr(expr)
195198
if expr in self.expr_to_origin:
196199
return self.expr_to_origin[expr].origin.host_str()
197200
replacements = {}

helion/_testing.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,14 @@ def assertNotIn(
499499
if not self._in_ref_eager_mode:
500500
super().assertNotIn(member, container, msg) # type: ignore[misc]
501501

502+
def assertIs(
    self, expr1: object, expr2: object, msg: str | None = None
) -> None:
    """Assert ``expr1 is expr2`` unless running in ref eager mode.

    When ``self._in_ref_eager_mode`` is truthy the check is skipped;
    otherwise the call is forwarded unchanged to the parent ``assertIs``.
    """
    if self._in_ref_eager_mode:
        return
    super().assertIs(expr1, expr2, msg)  # type: ignore[misc]
505+
506+
def assertIsNot(
    self, expr1: object, expr2: object, msg: str | None = None
) -> None:
    """Assert ``expr1 is not expr2`` unless running in ref eager mode.

    When ``self._in_ref_eager_mode`` is truthy the check is skipped;
    otherwise the call is forwarded unchanged to the parent ``assertIsNot``.
    """
    if self._in_ref_eager_mode:
        return
    super().assertIsNot(expr1, expr2, msg)  # type: ignore[misc]
509+
502510
def assertTrueIfInNormalMode(self, condition: bool, msg: str | None = None) -> None:
503511
if not self._in_ref_eager_mode:
504512
self.assertTrue(condition, msg) # type: ignore[attr-defined]

helion/runtime/kernel.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,9 @@ def __init__(
403403
constexpr_args[name] = arg
404404
else:
405405
self.fake_args.append(self.env.to_fake(arg, ArgumentOrigin(name)))
406+
407+
self._apply_mark_static(args)
408+
406409
with (
407410
_maybe_skip_dtype_check_in_meta_registrations(),
408411
patch_inductor_lowerings(),
@@ -420,6 +423,24 @@ def __init__(
420423
self.maybe_log_repro(log.warning, args, config=config)
421424
raise
422425

426+
def _apply_mark_static(self, args: tuple[object, ...]) -> None:
427+
"""
428+
Apply torch._dynamo.mark_static() markings from input tensors.
429+
430+
This reads _dynamo_static_indices from each tensor argument and marks
431+
the corresponding dimensions as specialized (constant) in the kernel.
432+
"""
433+
for arg_idx, (arg, fake_arg) in enumerate(zip(args, self.fake_args, strict=True)):
434+
if isinstance(arg, torch.Tensor):
435+
static_indices = getattr(arg, "_dynamo_static_indices", None)
436+
if static_indices:
437+
assert isinstance(fake_arg, torch.Tensor)
438+
for dim in static_indices:
439+
size = fake_arg.size(dim)
440+
if isinstance(size, torch.SymInt):
441+
sym_expr = size._sympy_()
442+
self.env.specialized_vars.update(sym_expr.free_symbols)
443+
423444
@property
424445
def settings(self) -> Settings:
425446
"""
@@ -889,12 +910,14 @@ def kernel(
889910
def _tensor_key(fn: Kernel, obj: torch.Tensor) -> Hashable:
890911
# NOTE: If a machine has two different gpu types on the same machine,
891912
# obj.device.type will incorrectly hit
913+
static_indices = frozenset(getattr(obj, "_dynamo_static_indices", ()))
892914
if fn.settings.static_shapes:
893915
return (
894916
obj.dtype,
895917
obj.device.type,
896918
(*obj.size(),),
897919
(*obj.stride(),),
920+
static_indices,
898921
)
899922
bucketed = tuple([min(s, 2) for s in obj.size()])
900923
if fn.settings.index_dtype is None:
@@ -907,11 +930,13 @@ def _tensor_key(fn: Kernel, obj: torch.Tensor) -> Hashable:
907930
obj.device.type,
908931
bucketed,
909932
needs_int64,
933+
static_indices,
910934
)
911935
return (
912936
obj.dtype,
913937
obj.device.type,
914938
bucketed,
939+
static_indices,
915940
)
916941

917942

0 commit comments

Comments
 (0)