
Commit 6138675

Commit message: up
1 parent: 41d144f

7 files changed: +48 additions, −16 deletions


helion/_compiler/compile_environment.py

Lines changed: 10 additions & 0 deletions
@@ -112,6 +112,9 @@ def __init__(
             collections.Counter()
         )
         self.specialized_vars: set[sympy.Symbol] = set()
+        # Maps stride symbol -> (param_name, dim) for stride specialization
+        self.stride_symbols: dict[sympy.Symbol, tuple[str, int]] = {}
+        self.specialized_strides: set[tuple[str, int]] = set()
         self.loop_dependency_checker = LoopDependencyChecker()
         self._symint_cache: dict[object, torch.SymInt] = {}
         self.device_load_count = (
@@ -469,6 +472,13 @@ def _to_fake_tensor(self, tensor: torch.Tensor, source: Source) -> torch.Tensor:
                 self.debug_shape_renames[s._sympy_()] = sympy.Symbol(
                     f"{source.local_name}_size{i}", integer=True
                 )
+        # Record stride symbols for specialization tracking
+        if isinstance(source, LocalSource):
+            for i in range(result.ndim):
+                st = result.stride(i)
+                if isinstance(st, torch.SymInt):
+                    for sym in st._sympy_().free_symbols:
+                        self.stride_symbols[sym] = (source.local_name, i)
         return result

     def size_hint(self, n: int | torch.SymInt) -> int:
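In plain terms, _to_fake_tensor now remembers which sympy symbols came from which input stride. A standalone sketch of that bookkeeping, using a bare dict in place of CompileEnvironment (the tensor name "x" and the shape symbols here are illustrative, not from the commit):

import sympy

s0, s1 = sympy.symbols("s0 s1", integer=True)
# A contiguous (s0, s1) tensor named "x" has symbolic strides (s1, 1).
strides = {0: s1, 1: sympy.Integer(1)}

stride_symbols: dict[sympy.Symbol, tuple[str, int]] = {}
for dim, st in strides.items():
    for sym in st.free_symbols:  # constant strides contribute no symbols
        stride_symbols[sym] = ("x", dim)

assert stride_symbols == {s1: ("x", 0)}  # s1 appears in x.stride(0)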

helion/_compiler/device_function.py

Lines changed: 9 additions & 4 deletions
@@ -602,10 +602,15 @@ def tensor_size(self, fake_value: torch.Tensor, dim: int) -> Argument:
         return self._tensor_property(TensorSizeArg, fake_value, dim, "size")

     def tensor_stride(self, fake_value: torch.Tensor, dim: int) -> Argument:
-        if (
-            isinstance(v := fake_value.stride(dim), int)
-            and CompileEnvironment.current().settings.static_shapes
-        ):
+        from torch._dynamo.source import LocalSource
+
+        v = fake_value.stride(dim)
+        env = CompileEnvironment.current()
+        # Check if this specific stride was specialized
+        source = env.input_sources.get(fake_value)
+        if isinstance(source, LocalSource) and (source.local_name, dim) in env.specialized_strides:
+            return StaticShape(int(v))
+        if isinstance(v, int) and env.settings.static_shapes:
             return StaticShape(v)
         return self._tensor_property(TensorStrideArg, fake_value, dim, "stride")
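The dispatch order matters: a per-stride specialization wins even when the global static_shapes setting is off. A minimal model of that decision, with strings standing in for Helion's StaticShape and TensorStrideArg argument classes (illustrative, not the real API):

def lower_stride(name: str, dim: int, value: int,
                 specialized: set[tuple[str, int]],
                 static_shapes: bool) -> str:
    if (name, dim) in specialized:  # per-stride opt-in is checked first
        return f"StaticShape({value})"
    if static_shapes:               # global setting still applies
        return f"StaticShape({value})"
    return f"TensorStrideArg({name!r}, {dim})"

assert lower_stride("x", 0, 512, {("x", 0)}, False) == "StaticShape(512)"
assert lower_stride("x", 1, 1, set(), False) == "TensorStrideArg('x', 1)"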

helion/_compiler/type_propagation.py

Lines changed: 1 addition & 1 deletion
@@ -645,7 +645,7 @@ def propagate_call(
         attr = self.attr()
         if attr in {"dim", "ndimension"} and not (args or kwargs):
             return TypeInfo.from_example(self.tensor.fake_value.ndim, origin)
-        if attr in {"shape", "size"} and not kwargs:
+        if attr in {"shape", "size", "stride"} and not kwargs:
             fn = getattr(self.tensor.fake_value, attr)
             try:
                 return TypeInfo.from_example(
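With "stride" on the allow-list, x.stride(0) inside a kernel is evaluated against the fake tensor during type propagation, just like x.size(0). A hypothetical kernel exercising the feature end to end (a sketch that assumes the rest of this commit's plumbing; the kernel itself is not from the commit):

import torch
import helion
import helion.language as hl

@helion.kernel()
def double_rows(x: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    # .stride() now participates in specialization, like .size()
    hl.specialize(x.stride(0))
    for tile in hl.tile(x.size(0)):
        out[tile, :] = x[tile, :] * 2
    return out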

helion/exc.py

Lines changed: 1 addition & 1 deletion
@@ -186,7 +186,7 @@ class SpecializeOnDevice(BaseError):


 class SpecializeArgType(BaseError):
-    message = "hl.specialize() must be called on a size from an input tensor, got: {}"
+    message = "hl.specialize() must be called on a size or stride from an input tensor, got: {}"


 class StackTensorcOnHost(BaseError):
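(The new wording matches the extractor logic in helion/runtime/kernel.py below, which raises this error for any TensorPropertySource that is neither SIZE nor STRIDE.)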

helion/language/constexpr.py

Lines changed: 16 additions & 1 deletion
@@ -87,7 +87,22 @@ def _(value: TypeInfo, *, origin: Origin) -> TypeInfo:
     env = CompileEnvironment.current()

     def handle_symint(symint: torch.SymInt) -> int:
-        env.specialized_vars.update(symint._sympy_().free_symbols)
+        from torch._dynamo.source import TensorProperty
+        from torch._dynamo.source import TensorPropertySource
+
+        syms = symint._sympy_().free_symbols
+        env.specialized_vars.update(syms)
+        # Track which strides were specialized (only stride-only symbols)
+        for sym in syms:
+            if sym in env.stride_symbols:
+                # Check if this symbol is also a size via var_to_sources
+                sources = env.shape_env.var_to_sources.get(sym, [])
+                is_size = any(
+                    isinstance(s, TensorPropertySource) and s.prop == TensorProperty.SIZE
+                    for s in sources
+                )
+                if not is_size:
+                    env.specialized_strides.add(env.stride_symbols[sym])
         return symint.__int__()

     specialized = _convert_specializable(proxy, on_symint=handle_symint)
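The SIZE check guards against symbol aliasing: for a contiguous tensor, dynamo reuses a size symbol as a stride, so a symbol found in a stride may really be a size and must not be recorded as stride-specialized. A two-line illustration of the aliasing (not from the commit):

import torch

x = torch.empty(4, 8)
assert x.stride(0) == x.size(1)  # contiguous: stride(0) aliases size(1)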

helion/runtime/kernel.py

Lines changed: 5 additions & 3 deletions
@@ -624,12 +624,14 @@ def _specialize_extra(self) -> list[Callable[[Sequence[object]], Hashable]]:

         def make_extractor(v: Source) -> Callable[[Sequence[object]], Hashable]:
             if isinstance(v, TensorPropertySource):
-                assert v.prop == TensorProperty.SIZE
                 index = v.idx
                 assert index is not None
                 inner = make_extractor(v.base)
-
-                return lambda args: cast("torch.Tensor", inner(args)).size(index)
+                if v.prop == TensorProperty.SIZE:
+                    return lambda args: cast("torch.Tensor", inner(args)).size(index)
+                if v.prop == TensorProperty.STRIDE:
+                    return lambda args: cast("torch.Tensor", inner(args)).stride(index)
+                raise exc.SpecializeArgType(v)
             if isinstance(v, LocalSource):
                 index = arg_name_to_index[v.local_name]
                 return operator.itemgetter(index)
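These extractors feed the extra cache-key entries, so two calls differing only in a specialized stride compile separately. A standalone model with plain closures in place of the Source classes (names here are illustrative):

import torch

def local(index):
    # stands in for LocalSource: pull the argument by position
    return lambda args: args[index]

def stride_of(inner, dim):
    # stands in for TensorPropertySource with prop == STRIDE,
    # the case this commit adds alongside SIZE
    return lambda args: inner(args).stride(dim)

x = torch.empty(4, 8)
extract = stride_of(local(0), 0)
assert extract([x]) == 8  # contiguous (4, 8): stride(0) == 8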

test/test_examples.expected

Lines changed: 6 additions & 6 deletions
@@ -1674,18 +1674,18 @@ def _helion_fused_linear_jsd_kernel(student_logits, teacher_logits, loss, temper
     # src[fused_linear_jsd.py:N]: teacher_div = torch.nn.functional.kl_div(
     # src[fused_linear_jsd.py:N]:     torch.log(m), teacher_prob, reduction="none", log_target=True
     # src[fused_linear_jsd.py:N]: ).sum(dim=-1)
-    v_17 = teacher_prob_1 - v_16
-    v_18 = libdevice.exp(teacher_prob_1)
-    v_19 = v_18 * v_17
+    v_17 = libdevice.exp(teacher_prob_1)
+    v_18 = teacher_prob_1 - v_16
+    v_19 = v_17 * v_18
     teacher_div = tl.cast(tl.sum(v_19, 1), tl.float32)
     # src[fused_linear_jsd.py:N]:     torch.log(m), student_prob, reduction="none", log_target=True
     v_20 = tl_math.log(v_15)
     # src[fused_linear_jsd.py:N]: student_div = torch.nn.functional.kl_div(
     # src[fused_linear_jsd.py:N]:     torch.log(m), student_prob, reduction="none", log_target=True
     # src[fused_linear_jsd.py:N]: ).sum(dim=-1)
-    v_21 = student_prob_1 - v_20
-    v_22 = libdevice.exp(student_prob_1)
-    v_23 = v_22 * v_21
+    v_21 = libdevice.exp(student_prob_1)
+    v_22 = student_prob_1 - v_20
+    v_23 = v_21 * v_22
     student_div = tl.cast(tl.sum(v_23, 1), tl.float32)
     # src[fused_linear_jsd.py:N]: batch_loss = student_div + beta * (teacher_div - student_div)
     v_24 = teacher_div - student_div
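For orientation, the reordered v_17…v_23 lines still compute the log-target KL divergence pointwise; only the evaluation order of the product changed. The reference identity (this check is illustrative, not part of the expected-output file):

import torch
import torch.nn.functional as F

t = torch.randn(5).log_softmax(0)      # log-probabilities (target)
log_m = torch.randn(5).log_softmax(0)  # log of the mixture m
ref = torch.exp(t) * (t - log_m)       # exp(t) * (t - log m)
out = F.kl_div(log_m, t, reduction="none", log_target=True)
assert torch.allclose(ref, out)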
