summaryrefslogtreecommitdiff
path: root/numpy/lib/shape_base.pyi
diff options
context:
space:
mode:
authorBas van Beek <b.f.van.beek@vu.nl>2021-06-28 12:28:41 +0200
committerBas van Beek <b.f.van.beek@vu.nl>2021-07-09 16:04:06 +0200
commit1d82bcb65bae28194f2aa80e57728f54b7c3f6ad (patch)
tree870dafa4a5db653ac12a5391393dc55e5a1088d0 /numpy/lib/shape_base.pyi
parentb32b72e3d98d784b98d9c38d4f9905574a60707d (diff)
downloadnumpy-1d82bcb65bae28194f2aa80e57728f54b7c3f6ad.tar.gz
ENH: Add annotations for `np.lib.shape_base`
Diffstat (limited to 'numpy/lib/shape_base.pyi')
-rw-r--r--numpy/lib/shape_base.pyi225
1 files changed, 208 insertions, 17 deletions
diff --git a/numpy/lib/shape_base.pyi b/numpy/lib/shape_base.pyi
index 09edbcb6c..cfb3040b7 100644
--- a/numpy/lib/shape_base.pyi
+++ b/numpy/lib/shape_base.pyi
@@ -1,24 +1,215 @@
-from typing import List
+from typing import List, TypeVar, Callable, Sequence, Any, overload, Tuple
+from typing_extensions import SupportsIndex, Protocol
+
+from numpy import (
+ generic,
+ integer,
+ dtype,
+ ufunc,
+ bool_,
+ unsignedinteger,
+ signedinteger,
+ floating,
+ complexfloating,
+ object_,
+)
+
+from numpy.typing import (
+ ArrayLike,
+ NDArray,
+ _ShapeLike,
+ _NestedSequence,
+ _SupportsDType,
+ _ArrayLikeBool_co,
+ _ArrayLikeUInt_co,
+ _ArrayLikeInt_co,
+ _ArrayLikeFloat_co,
+ _ArrayLikeComplex_co,
+ _ArrayLikeObject_co,
+)
from numpy.core.shape_base import vstack
+_SCT = TypeVar("_SCT", bound=generic)
+
+_ArrayLike = _NestedSequence[_SupportsDType[dtype[_SCT]]]
+
+# The signatures of `__array_wrap__` and `__array_prepare__` are the same;
+# give them unique names for the sake of clarity
+class _ArrayWrap(Protocol):
+ def __call__(
+ self,
+ __array: NDArray[Any],
+ __context: None | Tuple[ufunc, Tuple[Any, ...], int] = ...,
+ ) -> Any: ...
+
+class _ArrayPrepare(Protocol):
+ def __call__(
+ self,
+ __array: NDArray[Any],
+ __context: None | Tuple[ufunc, Tuple[Any, ...], int] = ...,
+ ) -> Any: ...
+
+class _SupportsArrayWrap(Protocol):
+ @property
+ def __array_wrap__(self) -> _ArrayWrap: ...
+
+class _SupportsArrayPrepare(Protocol):
+ @property
+ def __array_prepare__(self) -> _ArrayPrepare: ...
+
__all__: List[str]
row_stack = vstack
-def take_along_axis(arr, indices, axis): ...
-def put_along_axis(arr, indices, values, axis): ...
-def apply_along_axis(func1d, axis, arr, *args, **kwargs): ...
-def apply_over_axes(func, a, axes): ...
-def expand_dims(a, axis): ...
-def column_stack(tup): ...
-def dstack(tup): ...
-def array_split(ary, indices_or_sections, axis=...): ...
-def split(ary, indices_or_sections, axis=...): ...
-def hsplit(ary, indices_or_sections): ...
-def vsplit(ary, indices_or_sections): ...
-def dsplit(ary, indices_or_sections): ...
-def get_array_prepare(*args): ...
-def get_array_wrap(*args): ...
-def kron(a, b): ...
-def tile(A, reps): ...
+def take_along_axis(
+ arr: _SCT | NDArray[_SCT],
+ indices: NDArray[integer[Any]],
+ axis: None | int,
+) -> NDArray[_SCT]: ...
+
+def put_along_axis(
+ arr: NDArray[_SCT],
+ indices: NDArray[integer[Any]],
+ values: ArrayLike,
+ axis: None | int,
+) -> None: ...
+
+@overload
+def apply_along_axis(
+ func1d: Callable[..., _ArrayLike[_SCT]],
+ axis: SupportsIndex,
+ arr: ArrayLike,
+ *args: Any,
+ **kwargs: Any,
+) -> NDArray[_SCT]: ...
+@overload
+def apply_along_axis(
+ func1d: Callable[..., ArrayLike],
+ axis: SupportsIndex,
+ arr: ArrayLike,
+ *args: Any,
+ **kwargs: Any,
+) -> NDArray[Any]: ...
+
+def apply_over_axes(
+ func: Callable[[NDArray[Any], int], NDArray[_SCT]],
+ a: ArrayLike,
+ axes: int | Sequence[int],
+) -> NDArray[_SCT]: ...
+
+@overload
+def expand_dims(
+ a: _ArrayLike[_SCT],
+ axis: _ShapeLike,
+) -> NDArray[_SCT]: ...
+@overload
+def expand_dims(
+ a: ArrayLike,
+ axis: _ShapeLike,
+) -> NDArray[Any]: ...
+
+@overload
+def column_stack(tup: Sequence[_ArrayLike[_SCT]]) -> NDArray[_SCT]: ...
+@overload
+def column_stack(tup: Sequence[ArrayLike]) -> NDArray[Any]: ...
+
+@overload
+def dstack(tup: Sequence[_ArrayLike[_SCT]]) -> NDArray[_SCT]: ...
+@overload
+def dstack(tup: Sequence[ArrayLike]) -> NDArray[Any]: ...
+
+@overload
+def array_split(
+ ary: _ArrayLike[_SCT],
+ indices_or_sections: _ShapeLike,
+ axis: SupportsIndex = ...,
+) -> List[NDArray[_SCT]]: ...
+@overload
+def array_split(
+ ary: ArrayLike,
+ indices_or_sections: _ShapeLike,
+ axis: SupportsIndex = ...,
+) -> List[NDArray[Any]]: ...
+
+@overload
+def split(
+ ary: _ArrayLike[_SCT],
+ indices_or_sections: _ShapeLike,
+ axis: SupportsIndex = ...,
+) -> List[NDArray[_SCT]]: ...
+@overload
+def split(
+ ary: ArrayLike,
+ indices_or_sections: _ShapeLike,
+ axis: SupportsIndex = ...,
+) -> List[NDArray[Any]]: ...
+
+@overload
+def hsplit(
+ ary: _ArrayLike[_SCT],
+ indices_or_sections: _ShapeLike,
+) -> List[NDArray[_SCT]]: ...
+@overload
+def hsplit(
+ ary: ArrayLike,
+ indices_or_sections: _ShapeLike,
+) -> List[NDArray[Any]]: ...
+
+@overload
+def vsplit(
+ ary: _ArrayLike[_SCT],
+ indices_or_sections: _ShapeLike,
+) -> List[NDArray[_SCT]]: ...
+@overload
+def vsplit(
+ ary: ArrayLike,
+ indices_or_sections: _ShapeLike,
+) -> List[NDArray[Any]]: ...
+
+@overload
+def dsplit(
+ ary: _ArrayLike[_SCT],
+ indices_or_sections: _ShapeLike,
+) -> List[NDArray[_SCT]]: ...
+@overload
+def dsplit(
+ ary: ArrayLike,
+ indices_or_sections: _ShapeLike,
+) -> List[NDArray[Any]]: ...
+
+@overload
+def get_array_prepare(*args: _SupportsArrayPrepare) -> _ArrayPrepare: ...
+@overload
+def get_array_prepare(*args: object) -> None | _ArrayPrepare: ...
+
+@overload
+def get_array_wrap(*args: _SupportsArrayWrap) -> _ArrayWrap: ...
+@overload
+def get_array_wrap(*args: object) -> None | _ArrayWrap: ...
+
+@overload
+def kron(a: _ArrayLikeBool_co, b: _ArrayLikeBool_co) -> NDArray[bool_]: ... # type: ignore[misc]
+@overload
+def kron(a: _ArrayLikeUInt_co, b: _ArrayLikeUInt_co) -> NDArray[unsignedinteger[Any]]: ... # type: ignore[misc]
+@overload
+def kron(a: _ArrayLikeInt_co, b: _ArrayLikeInt_co) -> NDArray[signedinteger[Any]]: ... # type: ignore[misc]
+@overload
+def kron(a: _ArrayLikeFloat_co, b: _ArrayLikeFloat_co) -> NDArray[floating[Any]]: ... # type: ignore[misc]
+@overload
+def kron(a: _ArrayLikeComplex_co, b: _ArrayLikeComplex_co) -> NDArray[complexfloating[Any, Any]]: ...
+@overload
+def kron(a: _ArrayLikeObject_co, b: Any) -> NDArray[object_]: ...
+@overload
+def kron(a: Any, b: _ArrayLikeObject_co) -> NDArray[object_]: ...
+
+@overload
+def tile(
+ A: _ArrayLike[_SCT],
+ reps: int | Sequence[int],
+) -> NDArray[_SCT]: ...
+@overload
+def tile(
+ A: ArrayLike,
+ reps: int | Sequence[int],
+) -> NDArray[Any]: ...