|
1 | 1 | from typing import Any, Optional, Tuple, Union |
2 | 2 |
|
| 3 | +import array_api_compat.numpy as np_array_api |
3 | 4 | import numpy as np |
4 | 5 | from onnx import FunctionProto, ModelProto, NodeProto, TensorProto |
5 | 6 | from onnx.helper import np_dtype_to_tensor_dtype |
6 | 7 | from onnx.numpy_helper import from_array |
7 | 8 |
|
8 | 9 | from .npx_constants import FUNCTION_DOMAIN |
9 | | -from .npx_core_api import cst, make_tuple, npxapi_inline, var |
| 10 | +from .npx_core_api import cst, make_tuple, npxapi_inline, npxapi_no_inline, var |
10 | 11 | from .npx_tensors import ArrayApi |
11 | 12 | from .npx_types import ( |
| 13 | + DType, |
12 | 14 | ElemType, |
13 | 15 | OptParType, |
14 | 16 | ParType, |
@@ -397,6 +399,17 @@ def identity(n: ParType[int], dtype=None) -> TensorType[ElemType.numerics, "T"]: |
397 | 399 | return v |
398 | 400 |
|
399 | 401 |
|
| 402 | +@npxapi_no_inline |
| 403 | +def isdtype( |
| 404 | + dtype: DType, kind: Union[DType, str, Tuple[Union[DType, str], ...]] |
| 405 | +) -> bool: |
| 406 | + """ |
| 407 | + See :epkg:`ArrayAPI:isdtype`. |
| 408 | + This function is not converted into an onnx graph. |
| 409 | + """ |
| 410 | + return np_array_api.isdtype(dtype, kind) |
| 411 | + |
| 412 | + |
400 | 413 | @npxapi_inline |
401 | 414 | def isnan(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.bool_, "T"]: |
402 | 415 | "See :func:`numpy.isnan`." |
@@ -460,9 +473,23 @@ def relu(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.numerics, |
460 | 473 |
|
461 | 474 | @npxapi_inline |
462 | 475 | def reshape( |
463 | | - x: TensorType[ElemType.numerics, "T"], shape: TensorType[ElemType.int64, "I"] |
| 476 | + x: TensorType[ElemType.numerics, "T"], |
| 477 | + shape: TensorType[ElemType.int64, "I", (None,)], |
464 | 478 | ) -> TensorType[ElemType.numerics, "T"]: |
465 | | - "See :func:`numpy.reshape`." |
| 479 | + """ |
| 480 | + See :func:`numpy.reshape`. |
| 481 | +
|
| 482 | + .. warning:: |
| 483 | +
|
| 484 | + Numpy definition is tricky because onnxruntime does not handle well |
| 485 | + dimensions with an undefined number of dimensions. |
| 486 | + However the array API defines a more stricly signature for |
| 487 | + `reshape <https://data-apis.org/array-api/2022.12/ |
| 488 | + API_specification/generated/array_api.reshape.html>`_. |
| 489 | + :epkg:`scikit-learn` updated its code to follow the Array API in |
| 490 | + `PR 26030 ENH Forces shape to be tuple when using Array API's reshape |
| 491 | + <https://github.com/scikit-learn/scikit-learn/pull/26030>`_. |
| 492 | + """ |
466 | 493 | if isinstance(shape, int): |
467 | 494 | shape = cst(np.array([shape], dtype=np.int64)) |
468 | 495 | shape_reshaped = var(shape, cst(np.array([-1], dtype=np.int64)), op="Reshape") |
|
0 commit comments