|
32 | 32 | *_int_dtypes, |
33 | 33 | torch.float32, |
34 | 34 | torch.float64, |
| 35 | + torch.complex64, |
| 36 | + torch.complex128, |
35 | 37 | } |
36 | 38 |
|
37 | 39 | _promotion_table = { |
|
70 | 72 | (torch.float32, torch.float64): torch.float64, |
71 | 73 | (torch.float64, torch.float32): torch.float64, |
72 | 74 | (torch.float64, torch.float64): torch.float64, |
| 75 | + # complexes |
| 76 | + (torch.complex64, torch.complex64): torch.complex64, |
| 77 | + (torch.complex64, torch.complex128): torch.complex128, |
| 78 | + (torch.complex128, torch.complex64): torch.complex128, |
| 79 | + (torch.complex128, torch.complex128): torch.complex128, |
| 80 | + # Mixed float and complex |
| 81 | + (torch.float32, torch.complex64): torch.complex64, |
| 82 | + (torch.float32, torch.complex128): torch.complex128, |
| 83 | + (torch.float64, torch.complex64): torch.complex128, |
| 84 | + (torch.float64, torch.complex128): torch.complex128, |
73 | 85 | } |
74 | 86 |
|
75 | 87 |
|
@@ -129,7 +141,6 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool: |
129 | 141 | return torch.can_cast(from_, to) |
130 | 142 |
|
131 | 143 | # Basic renames |
132 | | -permute_dims = torch.permute |
133 | 144 | bitwise_invert = torch.bitwise_not |
134 | 145 |
|
135 | 146 | # Two-arg elementwise functions |
@@ -439,18 +450,26 @@ def squeeze(x: array, /, axis: Union[int, Tuple[int, ...]]) -> array: |
439 | 450 | x = torch.squeeze(x, a) |
440 | 451 | return x |
441 | 452 |
|
| 453 | +# torch.broadcast_to uses size instead of shape |
| 454 | +def broadcast_to(x: array, /, shape: Tuple[int, ...], **kwargs) -> array: |
| 455 | + return torch.broadcast_to(x, shape, **kwargs) |
| 456 | + |
| 457 | +# torch.permute uses dims instead of axes |
| 458 | +def permute_dims(x: array, /, axes: Tuple[int, ...]) -> array: |
| 459 | + return torch.permute(x, axes) |
| 460 | + |
442 | 461 | # The axis parameter doesn't work for flip() and roll() |
443 | 462 | # https://github.com/pytorch/pytorch/issues/71210. Also torch.flip() doesn't |
444 | 463 | # accept axis=None |
445 | | -def flip(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> array: |
| 464 | +def flip(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> array: |
446 | 465 | if axis is None: |
447 | 466 | axis = tuple(range(x.ndim)) |
448 | 467 | # torch.flip doesn't accept dim as an int but the method does |
449 | 468 | # https://github.com/pytorch/pytorch/issues/18095 |
450 | | - return x.flip(axis) |
| 469 | + return x.flip(axis, **kwargs) |
451 | 470 |
|
452 | | -def roll(x: array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> array: |
453 | | - return torch.roll(x, shift, axis) |
| 471 | +def roll(x: array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> array: |
| 472 | + return torch.roll(x, shift, axis, **kwargs) |
454 | 473 |
|
455 | 474 | def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]: |
456 | 475 | return torch.nonzero(x, as_tuple=True, **kwargs) |
@@ -662,15 +681,18 @@ def isdtype( |
662 | 681 | else: |
663 | 682 | return dtype == kind |
664 | 683 |
|
| 684 | +def take(x: array, indices: array, /, *, axis: int, **kwargs) -> array: |
| 685 | + return torch.index_select(x, axis, indices, **kwargs) |
| 686 | + |
665 | 687 | __all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert', 'add', |
666 | 688 | 'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or', |
667 | 689 | 'bitwise_right_shift', 'bitwise_xor', 'divide', 'equal', |
668 | 690 | 'floor_divide', 'greater', 'greater_equal', 'less', 'less_equal', |
669 | 691 | 'logaddexp', 'multiply', 'not_equal', 'pow', 'remainder', |
670 | 692 | 'subtract', 'max', 'min', 'sort', 'prod', 'sum', 'any', 'all', |
671 | | - 'mean', 'std', 'var', 'concat', 'squeeze', 'flip', 'roll', |
| 693 | + 'mean', 'std', 'var', 'concat', 'squeeze', 'broadcast_to', 'flip', 'roll', |
672 | 694 | 'nonzero', 'where', 'reshape', 'arange', 'eye', 'linspace', 'full', |
673 | 695 | 'ones', 'zeros', 'empty', 'tril', 'triu', 'expand_dims', 'astype', |
674 | 696 | 'broadcast_arrays', 'unique_all', 'unique_counts', |
675 | 697 | 'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose', |
676 | | - 'vecdot', 'tensordot', 'isdtype'] |
| 698 | + 'vecdot', 'tensordot', 'isdtype', 'take'] |
0 commit comments