
Commit 55d6aa0

Pad to next power of 2 for hl.specialize'ed shape value used in device tensor creation (#804)
1 parent 0cf4232 commit 55d6aa0
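
The padding rule referenced in the title rounds each specialized dimension size up to the next power of 2, using triton.next_power_of_2. A quick illustration of the rule itself (the sizes below are made up, not taken from this commit):

from triton import next_power_of_2

# next_power_of_2 rounds a positive int up to the nearest power of 2
# (and leaves exact powers of 2 unchanged).
for size in (5, 48, 128, 1000):
    print(size, "->", next_power_of_2(size))
# 5 -> 8, 48 -> 64, 128 -> 128, 1000 -> 1024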

File tree

5 files changed: +807 / -8 lines changed


helion/_compiler/host_function.py

Lines changed: 3 additions & 1 deletion
@@ -24,6 +24,7 @@
 from .output_header import SOURCE_MODULE
 from .source_location import SourceLocation
 from .source_location import UnknownLocation
+from .tensor_utils import patch_tensor_factories
 from .type_printer import print_ast
 from .variable_origin import AttributeOrigin
 from .variable_origin import GlobalOrigin
@@ -112,7 +113,8 @@ def __init__(
         unroll_static_loops(self)
         propagate_types(self)
         env.finalize_config_spec()
-        self.device_ir = lower_to_device_ir(self)
+        with patch_tensor_factories():
+            self.device_ir = lower_to_device_ir(self)
 
     @staticmethod
     def validate_ast(root: ast.FunctionDef) -> None:
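
With this change, everything that runs under lower_to_device_ir executes inside the new dispatch mode. For readers unfamiliar with TorchDispatchMode, here is a minimal standalone sketch of the mechanism; the _LogFactories class is invented for illustration and is not part of the commit:

import torch
from torch.utils._python_dispatch import TorchDispatchMode

class _LogFactories(TorchDispatchMode):
    # Intercept every aten call made while the mode is active, then run it unchanged.
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print("intercepted:", func)
        return func(*args, **(kwargs or {}))

with _LogFactories():
    torch.zeros(4)  # prints something like: intercepted: aten.zeros.default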

helion/_compiler/tensor_utils.py

Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
+from __future__ import annotations
+
+from typing import Callable
+from typing import ClassVar
+
+import torch
+from torch.utils._python_dispatch import TorchDispatchMode
+from torch.utils._pytree import tree_map
+from triton import next_power_of_2
+
+
+class _PadTensorFactoryMode(TorchDispatchMode):
+    """Dispatch mode that pads tensor factory size arguments."""
+
+    _SIZE_ARG_INDEX: ClassVar[dict[Callable[..., torch.Tensor], int]] = {
+        torch.ops.aten.zeros.default: 0,  # pyright: ignore[reportAttributeAccessIssue]
+        torch.ops.aten.ones.default: 0,  # pyright: ignore[reportAttributeAccessIssue]
+        torch.ops.aten.empty.memory_format: 0,  # pyright: ignore[reportAttributeAccessIssue]
+        torch.ops.aten.full.default: 0,  # pyright: ignore[reportAttributeAccessIssue]
+        torch.ops.aten.new_empty.default: 1,  # pyright: ignore[reportAttributeAccessIssue]
+        torch.ops.aten.new_full.default: 1,  # pyright: ignore[reportAttributeAccessIssue]
+        torch.ops.aten.new_zeros.default: 1,  # pyright: ignore[reportAttributeAccessIssue]
+        torch.ops.aten.new_ones.default: 1,  # pyright: ignore[reportAttributeAccessIssue]
+    }
+
+    def __torch_dispatch__(
+        self,
+        func: Callable[..., torch.Tensor],
+        types: tuple[type, ...],
+        args: tuple[object, ...] = (),
+        kwargs: dict[str, object] | None = None,
+    ) -> torch.Tensor:
+        def _pad_shape(shape: object) -> object:
+            """Pad positive integer dimension sizes to the next power of 2."""
+
+            def _pad_dim(dim_size: object) -> object:
+                if isinstance(dim_size, int) and dim_size > 0:
+                    return next_power_of_2(dim_size)
+                return dim_size
+
+            return tree_map(_pad_dim, shape)
+
+        kwargs = dict(kwargs or {})
+        size_index = self._SIZE_ARG_INDEX.get(func)
+        if size_index is not None:
+            if "size" in kwargs:
+                kwargs["size"] = _pad_shape(kwargs["size"])
+            elif size_index < len(args):
+                args_list = list(args)
+                args_list[size_index] = _pad_shape(args_list[size_index])
+                args = tuple(args_list)
+        return func(*args, **kwargs)
+
+
+patch_tensor_factories = _PadTensorFactoryMode
+
+
+__all__ = ["patch_tensor_factories"]
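
A rough usage sketch of the new helper (not part of the commit; the concrete sizes are illustrative and assume triton is importable so next_power_of_2 is available):

import torch
from helion._compiler.tensor_utils import patch_tensor_factories

with patch_tensor_factories():
    # Factory ops listed in _SIZE_ARG_INDEX get their size argument padded per dimension.
    t = torch.zeros(5, 100)  # size is positional arg 0 for aten.zeros
    u = t.new_zeros(6)       # size is positional arg 1 for the new_* variants
print(t.shape)  # expected: torch.Size([8, 128])
print(u.shape)  # expected: torch.Size([8])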

helion/_compiler/type_propagation.py

Lines changed: 15 additions & 7 deletions
@@ -42,6 +42,7 @@
 from .host_function import SymbolOrigin
 from .output_header import library_imports
 from .source_location import current_location
+from .tensor_utils import patch_tensor_factories
 from .utils import compute_slice_size
 from .variable_origin import ArgumentOrigin
 from .variable_origin import AttributeOrigin
@@ -1042,7 +1043,8 @@ def proxy(self) -> object:
             torch._C._TorchDispatchModeKey.FAKE  # pyright: ignore[reportAttributeAccessIssue]
         )
         try:
-            return Tile(self.block_id)
+            with torch._C._DisableTorchDispatch():  # pyright: ignore[reportAttributeAccessIssue]
+                return Tile(self.block_id)
         finally:
             assert fake_mode is not None
             torch._C._set_dispatch_mode(fake_mode)  # pyright: ignore[reportAttributeAccessIssue]
@@ -2191,12 +2193,18 @@ def visit_For(self, node: ast.For) -> TypeInfo:
             raise exc.NestedGridLoop
 
         self.device_loop_depth += device_loop
-        body = self._loop_body(node.body)
-        with self.swap_scope(body):
-            # second pass for fixed point
-            body.merge(self._loop_body(node.body))
-        orelse = self._body(node.orelse)
-        self.scope.merge_if_else(body, orelse)
+        _maybe_patch_tensor_factories = (
+            patch_tensor_factories
+            if self.device_loop_depth > 0
+            else contextlib.nullcontext
+        )
+        with _maybe_patch_tensor_factories():
+            body = self._loop_body(node.body)
+            with self.swap_scope(body):
+                # second pass for fixed point
+                body.merge(self._loop_body(node.body))
+            orelse = self._body(node.orelse)
+            self.scope.merge_if_else(body, orelse)
         self.device_loop_depth -= device_loop
         return NoType(origin=self.origin())
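
The visit_For change only enables padding when type propagation is inside a device loop, via a conditional context-manager factory. The same pattern in isolation (a sketch; the _shape_padding_ctx name is invented):

import contextlib
from helion._compiler.tensor_utils import patch_tensor_factories

def _shape_padding_ctx(device_loop_depth: int):
    # Pad factory shapes only inside a device loop; otherwise use a no-op context.
    return patch_tensor_factories if device_loop_depth > 0 else contextlib.nullcontext

with _shape_padding_ctx(device_loop_depth=0)():
    pass  # factory calls here see their original sizes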
