Skip to content

Commit 27bd245

Browse files
committed
up
1 parent 41d144f commit 27bd245

File tree

4 files changed

+16
-10
lines changed

4 files changed

+16
-10
lines changed

helion/_compiler/device_function.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -602,11 +602,15 @@ def tensor_size(self, fake_value: torch.Tensor, dim: int) -> Argument:
602602
return self._tensor_property(TensorSizeArg, fake_value, dim, "size")
603603

604604
def tensor_stride(self, fake_value: torch.Tensor, dim: int) -> Argument:
605-
if (
606-
isinstance(v := fake_value.stride(dim), int)
607-
and CompileEnvironment.current().settings.static_shapes
608-
):
609-
return StaticShape(v)
605+
v = fake_value.stride(dim)
606+
env = CompileEnvironment.current()
607+
if isinstance(v, int):
608+
if env.settings.static_shapes:
609+
return StaticShape(v)
610+
elif isinstance(
611+
expr := v._sympy_(), sympy.Integer
612+
) or expr.free_symbols.issubset(env.specialized_vars):
613+
return StaticShape(int(v))
610614
return self._tensor_property(TensorStrideArg, fake_value, dim, "stride")
611615

612616
def sorted_args(self) -> list[Argument]:

helion/_compiler/type_propagation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -645,7 +645,7 @@ def propagate_call(
645645
attr = self.attr()
646646
if attr in {"dim", "ndimension"} and not (args or kwargs):
647647
return TypeInfo.from_example(self.tensor.fake_value.ndim, origin)
648-
if attr in {"shape", "size"} and not kwargs:
648+
if attr in {"shape", "size", "stride"} and not kwargs:
649649
fn = getattr(self.tensor.fake_value, attr)
650650
try:
651651
return TypeInfo.from_example(

helion/exc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ class SpecializeOnDevice(BaseError):
186186

187187

188188
class SpecializeArgType(BaseError):
189-
message = "hl.specialize() must be called on a size from an input tensor, got: {}"
189+
message = "hl.specialize() must be called on a size or stride from an input tensor, got: {}"
190190

191191

192192
class StackTensorcOnHost(BaseError):

helion/runtime/kernel.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -624,12 +624,14 @@ def _specialize_extra(self) -> list[Callable[[Sequence[object]], Hashable]]:
624624

625625
def make_extractor(v: Source) -> Callable[[Sequence[object]], Hashable]:
626626
if isinstance(v, TensorPropertySource):
627-
assert v.prop == TensorProperty.SIZE
628627
index = v.idx
629628
assert index is not None
630629
inner = make_extractor(v.base)
631-
632-
return lambda args: cast("torch.Tensor", inner(args)).size(index)
630+
if v.prop == TensorProperty.SIZE:
631+
return lambda args: cast("torch.Tensor", inner(args)).size(index)
632+
if v.prop == TensorProperty.STRIDE:
633+
return lambda args: cast("torch.Tensor", inner(args)).stride(index)
634+
raise exc.SpecializeArgType(v)
633635
if isinstance(v, LocalSource):
634636
index = arg_name_to_index[v.local_name]
635637
return operator.itemgetter(index)

0 commit comments

Comments
 (0)