|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +from typing import Callable |
| 4 | +from typing import ClassVar |
| 5 | + |
| 6 | +import torch |
| 7 | +from torch.utils._python_dispatch import TorchDispatchMode |
| 8 | +from torch.utils._pytree import tree_map |
| 9 | +from triton import next_power_of_2 |
| 10 | + |
| 11 | + |
class _PadTensorFactoryMode(TorchDispatchMode):
    """Dispatch mode that enlarges the shapes requested from tensor factories.

    For every op listed in ``_SIZE_ARG_INDEX`` the size argument is rewritten
    so that each positive ``int`` dimension becomes ``next_power_of_2(dim)``.
    All other ops, and all other arguments, are forwarded unchanged.
    """

    # Maps each intercepted aten overload to the positional index of its size
    # argument. The ``new_*`` overloads use index 1 (presumably because
    # position 0 holds the source tensor -- confirm against the aten schema).
    _SIZE_ARG_INDEX: ClassVar[dict[Callable[..., torch.Tensor], int]] = {
        torch.ops.aten.zeros.default: 0,  # pyright: ignore[reportAttributeAccessIssue]
        torch.ops.aten.ones.default: 0,  # pyright: ignore[reportAttributeAccessIssue]
        torch.ops.aten.empty.memory_format: 0,  # pyright: ignore[reportAttributeAccessIssue]
        torch.ops.aten.full.default: 0,  # pyright: ignore[reportAttributeAccessIssue]
        torch.ops.aten.new_empty.default: 1,  # pyright: ignore[reportAttributeAccessIssue]
        torch.ops.aten.new_full.default: 1,  # pyright: ignore[reportAttributeAccessIssue]
        torch.ops.aten.new_zeros.default: 1,  # pyright: ignore[reportAttributeAccessIssue]
        torch.ops.aten.new_ones.default: 1,  # pyright: ignore[reportAttributeAccessIssue]
    }

    def __torch_dispatch__(
        self,
        func: Callable[..., torch.Tensor],
        types: tuple[type, ...],
        args: tuple[object, ...] = (),
        kwargs: dict[str, object] | None = None,
    ) -> torch.Tensor:
        """Forward ``func``, padding its size argument if it is a known factory.

        Args:
            func: The op overload being dispatched.
            types: Passed in by the dispatcher; not used here.
            args: Positional arguments for ``func``.
            kwargs: Keyword arguments for ``func``; ``None`` is treated as
                empty. The mapping passed in is never mutated.

        Returns:
            Whatever ``func`` returns for the (possibly rewritten) arguments.
        """
        # Work on a private copy so the caller's mapping stays untouched.
        call_kwargs = dict(kwargs) if kwargs else {}

        position = self._SIZE_ARG_INDEX.get(func)
        if position is None:
            # Not an intercepted factory op: plain pass-through.
            return func(*args, **call_kwargs)

        def _round_up(leaf: object) -> object:
            # Only positive ints are rewritten; zero, negative and non-int
            # leaves are returned as they came in.
            is_positive_int = isinstance(leaf, int) and leaf > 0
            return next_power_of_2(leaf) if is_positive_int else leaf

        if "size" in call_kwargs:
            # A keyword ``size`` takes precedence over the positional slot.
            call_kwargs["size"] = tree_map(_round_up, call_kwargs["size"])
        elif position < len(args):
            padded_size = tree_map(_round_up, args[position])
            args = (*args[:position], padded_size, *args[position + 1:])

        return func(*args, **call_kwargs)


# Public name under which the dispatch mode is exported.
patch_tensor_factories = _PadTensorFactoryMode


__all__ = ["patch_tensor_factories"]
0 commit comments