Skip to content

Commit 2d9af3f

Browse files
authored
🏷️ Annotate einsum_path return type (Issue #724) (#746)
1 parent 08e8118 commit 2d9af3f

File tree

2 files changed

+12
-10
lines changed

2 files changed

+12
-10
lines changed

src/numpy-stubs/@test/static/accept/einsumfunc.pyi

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,15 @@ assert_type(np.einsum("i,i->i", AR_LIKE_U, AR_LIKE_U, dtype=bool, casting="unsaf
2727
assert_type(np.einsum("i,i->i", AR_LIKE_f, AR_LIKE_f, dtype="c16"), Any)
2828
assert_type(np.einsum("i,i->i", AR_LIKE_U, AR_LIKE_U, dtype=bool, casting="unsafe"), Any)
2929

30-
assert_type(np.einsum_path("i,i->i", AR_LIKE_b, AR_LIKE_b), tuple[list[Any], str])
31-
assert_type(np.einsum_path("i,i->i", AR_LIKE_u, AR_LIKE_u), tuple[list[Any], str])
32-
assert_type(np.einsum_path("i,i->i", AR_LIKE_i, AR_LIKE_i), tuple[list[Any], str])
33-
assert_type(np.einsum_path("i,i->i", AR_LIKE_f, AR_LIKE_f), tuple[list[Any], str])
34-
assert_type(np.einsum_path("i,i->i", AR_LIKE_c, AR_LIKE_c), tuple[list[Any], str])
35-
assert_type(np.einsum_path("i,i->i", AR_LIKE_b, AR_LIKE_i), tuple[list[Any], str])
36-
assert_type(np.einsum_path("i,i,i,i->i", AR_LIKE_b, AR_LIKE_u, AR_LIKE_i, AR_LIKE_c), tuple[list[Any], str])
30+
assert_type(np.einsum_path("i,i->i", AR_LIKE_b, AR_LIKE_b), tuple[list[str | tuple[int, ...]], str])
31+
assert_type(np.einsum_path("i,i->i", AR_LIKE_u, AR_LIKE_u), tuple[list[str | tuple[int, ...]], str])
32+
assert_type(np.einsum_path("i,i->i", AR_LIKE_i, AR_LIKE_i), tuple[list[str | tuple[int, ...]], str])
33+
assert_type(np.einsum_path("i,i->i", AR_LIKE_f, AR_LIKE_f), tuple[list[str | tuple[int, ...]], str])
34+
assert_type(np.einsum_path("i,i->i", AR_LIKE_c, AR_LIKE_c), tuple[list[str | tuple[int, ...]], str])
35+
assert_type(np.einsum_path("i,i->i", AR_LIKE_b, AR_LIKE_i), tuple[list[str | tuple[int, ...]], str])
36+
assert_type(
37+
np.einsum_path("i,i,i,i->i", AR_LIKE_b, AR_LIKE_u, AR_LIKE_i, AR_LIKE_c), tuple[list[str | tuple[int, ...]], str]
38+
)
3739

3840
assert_type(np.einsum([[1, 1], [1, 1]], AR_LIKE_i, AR_LIKE_i), Any)
39-
assert_type(np.einsum_path([[1, 1], [1, 1]], AR_LIKE_i, AR_LIKE_i), tuple[list[Any], str])
41+
assert_type(np.einsum_path([[1, 1], [1, 1]], AR_LIKE_i, AR_LIKE_i), tuple[list[str | tuple[int, ...]], str])

src/numpy-stubs/_core/einsumfunc.pyi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ __all__ = ["einsum", "einsum_path"]
2626
_ArrayT = TypeVar("_ArrayT", bound=_nt.Array[_nt.co_complex])
2727

2828
# TODO (@jorenham): Annotate the `Sequence` value (numpy/numtype#724)
29-
_OptimizeKind: TypeAlias = bool | Literal["greedy", "optimal"] | Sequence[Incomplete]
29+
_OptimizeKind: TypeAlias = bool | Literal["greedy", "optimal"] | Sequence[str | tuple[int, ...]]
3030
_CastingSafe: TypeAlias = Literal["no", "equiv", "safe", "same_kind"]
3131
_CastingUnsafe: TypeAlias = Literal["unsafe"]
3232

@@ -178,4 +178,4 @@ def einsum_path(
178178
*operands: _ArrayLikeComplex_co | _DTypeLikeObject,
179179
optimize: _OptimizeKind = "greedy",
180180
einsum_call: L[False] = False,
181-
) -> tuple[list[Incomplete], str]: ...
181+
) -> tuple[list[str | tuple[int, ...]], str]: ...

0 commit comments

Comments
 (0)