summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/__init__.pyi17
-rw-r--r--numpy/core/_add_newdocs.py10
-rw-r--r--numpy/lib/shape_base.pyi225
-rw-r--r--numpy/typing/tests/data/reveal/ndarray_misc.py14
-rw-r--r--numpy/typing/tests/data/reveal/shape_base.py57
5 files changed, 296 insertions, 27 deletions
diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi
index fc3b0501d..cafa296eb 100644
--- a/numpy/__init__.pyi
+++ b/numpy/__init__.pyi
@@ -1219,10 +1219,9 @@ class _ArrayOrScalarCommon:
@property
def __array_interface__(self): ...
@property
- def __array_priority__(self): ...
+ def __array_priority__(self) -> float: ...
@property
def __array_struct__(self): ...
- def __array_wrap__(array, context=...): ...
def __setstate__(self, __state): ...
# a `bool_` is returned when `keepdims=True` and `self` is a 0d array
@@ -1599,6 +1598,7 @@ _FlexDType = TypeVar("_FlexDType", bound=dtype[flexible])
# TODO: Set the `bound` to something more suitable once we
# have proper shape support
_ShapeType = TypeVar("_ShapeType", bound=Any)
+_ShapeType2 = TypeVar("_ShapeType2", bound=Any)
_NumberType = TypeVar("_NumberType", bound=number[Any])
# There is currently no exhaustive way to type the buffer protocol,
@@ -1674,6 +1674,19 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]):
def __array__(self, __dtype: None = ...) -> ndarray[Any, _DType_co]: ...
@overload
def __array__(self, __dtype: _DType) -> ndarray[Any, _DType]: ...
+
+ def __array_wrap__(
+ self,
+ __array: ndarray[_ShapeType2, _DType],
+ __context: None | Tuple[ufunc, Tuple[Any, ...], int] = ...,
+ ) -> ndarray[_ShapeType2, _DType]: ...
+
+ def __array_prepare__(
+ self,
+ __array: ndarray[_ShapeType2, _DType],
+ __context: None | Tuple[ufunc, Tuple[Any, ...], int] = ...,
+ ) -> ndarray[_ShapeType2, _DType]: ...
+
@property
def ctypes(self) -> _ctypes[int]: ...
@property
diff --git a/numpy/core/_add_newdocs.py b/numpy/core/_add_newdocs.py
index a29e2e8a8..41e4e4d08 100644
--- a/numpy/core/_add_newdocs.py
+++ b/numpy/core/_add_newdocs.py
@@ -1585,7 +1585,7 @@ add_newdoc('numpy.core.multiarray', 'arange',
For integer arguments the function is equivalent to the Python built-in
`range` function, but returns an ndarray rather than a list.
- When using a non-integer step, such as 0.1, it is often better to use
+ When using a non-integer step, such as 0.1, it is often better to use
`numpy.linspace`. See the warnings section below for more information.
Parameters
@@ -2771,13 +2771,17 @@ add_newdoc('numpy.core.multiarray', 'ndarray', ('__array__',
add_newdoc('numpy.core.multiarray', 'ndarray', ('__array_prepare__',
- """a.__array_prepare__(obj) -> Object of same type as ndarray object obj.
+ """a.__array_prepare__(array[, context], /)
+
+ Returns a view of `array` with the same type as self.
"""))
add_newdoc('numpy.core.multiarray', 'ndarray', ('__array_wrap__',
- """a.__array_wrap__(obj) -> Object of same type as ndarray object a.
+ """a.__array_wrap__(array[, context], /)
+
+ Returns a view of `array` with the same type as self.
"""))
diff --git a/numpy/lib/shape_base.pyi b/numpy/lib/shape_base.pyi
index 09edbcb6c..cfb3040b7 100644
--- a/numpy/lib/shape_base.pyi
+++ b/numpy/lib/shape_base.pyi
@@ -1,24 +1,215 @@
-from typing import List
+from typing import List, TypeVar, Callable, Sequence, Any, overload, Tuple
+from typing_extensions import SupportsIndex, Protocol
+
+from numpy import (
+ generic,
+ integer,
+ dtype,
+ ufunc,
+ bool_,
+ unsignedinteger,
+ signedinteger,
+ floating,
+ complexfloating,
+ object_,
+)
+
+from numpy.typing import (
+ ArrayLike,
+ NDArray,
+ _ShapeLike,
+ _NestedSequence,
+ _SupportsDType,
+ _ArrayLikeBool_co,
+ _ArrayLikeUInt_co,
+ _ArrayLikeInt_co,
+ _ArrayLikeFloat_co,
+ _ArrayLikeComplex_co,
+ _ArrayLikeObject_co,
+)
from numpy.core.shape_base import vstack
+_SCT = TypeVar("_SCT", bound=generic)
+
+_ArrayLike = _NestedSequence[_SupportsDType[dtype[_SCT]]]
+
+# The signatures of `__array_wrap__` and `__array_prepare__` are the same;
+# give them unique names for the sake of clarity
+class _ArrayWrap(Protocol):
+ def __call__(
+ self,
+ __array: NDArray[Any],
+ __context: None | Tuple[ufunc, Tuple[Any, ...], int] = ...,
+ ) -> Any: ...
+
+class _ArrayPrepare(Protocol):
+ def __call__(
+ self,
+ __array: NDArray[Any],
+ __context: None | Tuple[ufunc, Tuple[Any, ...], int] = ...,
+ ) -> Any: ...
+
+class _SupportsArrayWrap(Protocol):
+ @property
+ def __array_wrap__(self) -> _ArrayWrap: ...
+
+class _SupportsArrayPrepare(Protocol):
+ @property
+ def __array_prepare__(self) -> _ArrayPrepare: ...
+
__all__: List[str]
row_stack = vstack
-def take_along_axis(arr, indices, axis): ...
-def put_along_axis(arr, indices, values, axis): ...
-def apply_along_axis(func1d, axis, arr, *args, **kwargs): ...
-def apply_over_axes(func, a, axes): ...
-def expand_dims(a, axis): ...
-def column_stack(tup): ...
-def dstack(tup): ...
-def array_split(ary, indices_or_sections, axis=...): ...
-def split(ary, indices_or_sections, axis=...): ...
-def hsplit(ary, indices_or_sections): ...
-def vsplit(ary, indices_or_sections): ...
-def dsplit(ary, indices_or_sections): ...
-def get_array_prepare(*args): ...
-def get_array_wrap(*args): ...
-def kron(a, b): ...
-def tile(A, reps): ...
+def take_along_axis(
+ arr: _SCT | NDArray[_SCT],
+ indices: NDArray[integer[Any]],
+ axis: None | int,
+) -> NDArray[_SCT]: ...
+
+def put_along_axis(
+ arr: NDArray[_SCT],
+ indices: NDArray[integer[Any]],
+ values: ArrayLike,
+ axis: None | int,
+) -> None: ...
+
+@overload
+def apply_along_axis(
+ func1d: Callable[..., _ArrayLike[_SCT]],
+ axis: SupportsIndex,
+ arr: ArrayLike,
+ *args: Any,
+ **kwargs: Any,
+) -> NDArray[_SCT]: ...
+@overload
+def apply_along_axis(
+ func1d: Callable[..., ArrayLike],
+ axis: SupportsIndex,
+ arr: ArrayLike,
+ *args: Any,
+ **kwargs: Any,
+) -> NDArray[Any]: ...
+
+def apply_over_axes(
+ func: Callable[[NDArray[Any], int], NDArray[_SCT]],
+ a: ArrayLike,
+ axes: int | Sequence[int],
+) -> NDArray[_SCT]: ...
+
+@overload
+def expand_dims(
+ a: _ArrayLike[_SCT],
+ axis: _ShapeLike,
+) -> NDArray[_SCT]: ...
+@overload
+def expand_dims(
+ a: ArrayLike,
+ axis: _ShapeLike,
+) -> NDArray[Any]: ...
+
+@overload
+def column_stack(tup: Sequence[_ArrayLike[_SCT]]) -> NDArray[_SCT]: ...
+@overload
+def column_stack(tup: Sequence[ArrayLike]) -> NDArray[Any]: ...
+
+@overload
+def dstack(tup: Sequence[_ArrayLike[_SCT]]) -> NDArray[_SCT]: ...
+@overload
+def dstack(tup: Sequence[ArrayLike]) -> NDArray[Any]: ...
+
+@overload
+def array_split(
+ ary: _ArrayLike[_SCT],
+ indices_or_sections: _ShapeLike,
+ axis: SupportsIndex = ...,
+) -> List[NDArray[_SCT]]: ...
+@overload
+def array_split(
+ ary: ArrayLike,
+ indices_or_sections: _ShapeLike,
+ axis: SupportsIndex = ...,
+) -> List[NDArray[Any]]: ...
+
+@overload
+def split(
+ ary: _ArrayLike[_SCT],
+ indices_or_sections: _ShapeLike,
+ axis: SupportsIndex = ...,
+) -> List[NDArray[_SCT]]: ...
+@overload
+def split(
+ ary: ArrayLike,
+ indices_or_sections: _ShapeLike,
+ axis: SupportsIndex = ...,
+) -> List[NDArray[Any]]: ...
+
+@overload
+def hsplit(
+ ary: _ArrayLike[_SCT],
+ indices_or_sections: _ShapeLike,
+) -> List[NDArray[_SCT]]: ...
+@overload
+def hsplit(
+ ary: ArrayLike,
+ indices_or_sections: _ShapeLike,
+) -> List[NDArray[Any]]: ...
+
+@overload
+def vsplit(
+ ary: _ArrayLike[_SCT],
+ indices_or_sections: _ShapeLike,
+) -> List[NDArray[_SCT]]: ...
+@overload
+def vsplit(
+ ary: ArrayLike,
+ indices_or_sections: _ShapeLike,
+) -> List[NDArray[Any]]: ...
+
+@overload
+def dsplit(
+ ary: _ArrayLike[_SCT],
+ indices_or_sections: _ShapeLike,
+) -> List[NDArray[_SCT]]: ...
+@overload
+def dsplit(
+ ary: ArrayLike,
+ indices_or_sections: _ShapeLike,
+) -> List[NDArray[Any]]: ...
+
+@overload
+def get_array_prepare(*args: _SupportsArrayPrepare) -> _ArrayPrepare: ...
+@overload
+def get_array_prepare(*args: object) -> None | _ArrayPrepare: ...
+
+@overload
+def get_array_wrap(*args: _SupportsArrayWrap) -> _ArrayWrap: ...
+@overload
+def get_array_wrap(*args: object) -> None | _ArrayWrap: ...
+
+@overload
+def kron(a: _ArrayLikeBool_co, b: _ArrayLikeBool_co) -> NDArray[bool_]: ... # type: ignore[misc]
+@overload
+def kron(a: _ArrayLikeUInt_co, b: _ArrayLikeUInt_co) -> NDArray[unsignedinteger[Any]]: ... # type: ignore[misc]
+@overload
+def kron(a: _ArrayLikeInt_co, b: _ArrayLikeInt_co) -> NDArray[signedinteger[Any]]: ... # type: ignore[misc]
+@overload
+def kron(a: _ArrayLikeFloat_co, b: _ArrayLikeFloat_co) -> NDArray[floating[Any]]: ... # type: ignore[misc]
+@overload
+def kron(a: _ArrayLikeComplex_co, b: _ArrayLikeComplex_co) -> NDArray[complexfloating[Any, Any]]: ...
+@overload
+def kron(a: _ArrayLikeObject_co, b: Any) -> NDArray[object_]: ...
+@overload
+def kron(a: Any, b: _ArrayLikeObject_co) -> NDArray[object_]: ...
+
+@overload
+def tile(
+ A: _ArrayLike[_SCT],
+ reps: int | Sequence[int],
+) -> NDArray[_SCT]: ...
+@overload
+def tile(
+ A: ArrayLike,
+ reps: int | Sequence[int],
+) -> NDArray[Any]: ...
diff --git a/numpy/typing/tests/data/reveal/ndarray_misc.py b/numpy/typing/tests/data/reveal/ndarray_misc.py
index 2e198eb6f..050b82cdc 100644
--- a/numpy/typing/tests/data/reveal/ndarray_misc.py
+++ b/numpy/typing/tests/data/reveal/ndarray_misc.py
@@ -11,14 +11,15 @@ import ctypes as ct
from typing import Any
import numpy as np
+from numpy.typing import NDArray
-class SubClass(np.ndarray): ...
+class SubClass(NDArray[np.object_]): ...
f8: np.float64
B: SubClass
-AR_f8: np.ndarray[Any, np.dtype[np.float64]]
-AR_i8: np.ndarray[Any, np.dtype[np.int64]]
-AR_U: np.ndarray[Any, np.dtype[np.str_]]
+AR_f8: NDArray[np.float64]
+AR_i8: NDArray[np.int64]
+AR_U: NDArray[np.str_]
ctypes_obj = AR_f8.ctypes
@@ -126,7 +127,7 @@ reveal_type(AR_f8.round(out=B)) # E: SubClass
reveal_type(f8.repeat(1)) # E: numpy.ndarray[Any, numpy.dtype[{float64}]]
reveal_type(AR_f8.repeat(1)) # E: numpy.ndarray[Any, numpy.dtype[{float64}]]
-reveal_type(B.repeat(1)) # E: numpy.ndarray[Any, Any]
+reveal_type(B.repeat(1)) # E: numpy.ndarray[Any, numpy.dtype[numpy.object_]]
reveal_type(f8.std()) # E: Any
reveal_type(AR_f8.std()) # E: Any
@@ -189,3 +190,6 @@ reveal_type(float(AR_U)) # E: float
reveal_type(complex(AR_f8)) # E: complex
reveal_type(operator.index(AR_i8)) # E: int
+
+reveal_type(AR_f8.__array_prepare__(B)) # E: numpy.ndarray[Any, numpy.dtype[numpy.object_]]
+reveal_type(AR_f8.__array_wrap__(B)) # E: numpy.ndarray[Any, numpy.dtype[numpy.object_]]
diff --git a/numpy/typing/tests/data/reveal/shape_base.py b/numpy/typing/tests/data/reveal/shape_base.py
new file mode 100644
index 000000000..57633defb
--- /dev/null
+++ b/numpy/typing/tests/data/reveal/shape_base.py
@@ -0,0 +1,57 @@
+import numpy as np
+from numpy.typing import NDArray
+from typing import Any, List
+
+i8: np.int64
+f8: np.float64
+
+AR_b: NDArray[np.bool_]
+AR_i8: NDArray[np.int64]
+AR_f8: NDArray[np.float64]
+
+AR_LIKE_f8: List[float]
+
+reveal_type(np.take_along_axis(AR_f8, AR_i8, axis=1)) # E: numpy.ndarray[Any, numpy.dtype[{float64}]]
+reveal_type(np.take_along_axis(f8, AR_i8, axis=None)) # E: numpy.ndarray[Any, numpy.dtype[{float64}]]
+
+reveal_type(np.put_along_axis(AR_f8, AR_i8, "1.0", axis=1)) # E: None
+
+reveal_type(np.expand_dims(AR_i8, 2)) # E: numpy.ndarray[Any, numpy.dtype[{int64}]]
+reveal_type(np.expand_dims(AR_LIKE_f8, 2)) # E: numpy.ndarray[Any, numpy.dtype[Any]]
+
+reveal_type(np.column_stack([AR_i8])) # E: numpy.ndarray[Any, numpy.dtype[{int64}]]
+reveal_type(np.column_stack([AR_LIKE_f8])) # E: numpy.ndarray[Any, numpy.dtype[Any]]
+
+reveal_type(np.dstack([AR_i8])) # E: numpy.ndarray[Any, numpy.dtype[{int64}]]
+reveal_type(np.dstack([AR_LIKE_f8])) # E: numpy.ndarray[Any, numpy.dtype[Any]]
+
+reveal_type(np.row_stack([AR_i8])) # E: numpy.ndarray[Any, numpy.dtype[{int64}]]
+reveal_type(np.row_stack([AR_LIKE_f8])) # E: numpy.ndarray[Any, numpy.dtype[Any]]
+
+reveal_type(np.array_split(AR_i8, [3, 5, 6, 10])) # E: list[numpy.ndarray[Any, numpy.dtype[{int64}]]]
+reveal_type(np.array_split(AR_LIKE_f8, [3, 5, 6, 10])) # E: list[numpy.ndarray[Any, numpy.dtype[Any]]]
+
+reveal_type(np.split(AR_i8, [3, 5, 6, 10])) # E: list[numpy.ndarray[Any, numpy.dtype[{int64}]]]
+reveal_type(np.split(AR_LIKE_f8, [3, 5, 6, 10])) # E: list[numpy.ndarray[Any, numpy.dtype[Any]]]
+
+reveal_type(np.hsplit(AR_i8, [3, 5, 6, 10])) # E: list[numpy.ndarray[Any, numpy.dtype[{int64}]]]
+reveal_type(np.hsplit(AR_LIKE_f8, [3, 5, 6, 10])) # E: list[numpy.ndarray[Any, numpy.dtype[Any]]]
+
+reveal_type(np.vsplit(AR_i8, [3, 5, 6, 10])) # E: list[numpy.ndarray[Any, numpy.dtype[{int64}]]]
+reveal_type(np.vsplit(AR_LIKE_f8, [3, 5, 6, 10])) # E: list[numpy.ndarray[Any, numpy.dtype[Any]]]
+
+reveal_type(np.dsplit(AR_i8, [3, 5, 6, 10])) # E: list[numpy.ndarray[Any, numpy.dtype[{int64}]]]
+reveal_type(np.dsplit(AR_LIKE_f8, [3, 5, 6, 10])) # E: list[numpy.ndarray[Any, numpy.dtype[Any]]]
+
+reveal_type(np.lib.shape_base.get_array_prepare(AR_i8)) # E: numpy.lib.shape_base._ArrayPrepare
+reveal_type(np.lib.shape_base.get_array_prepare(AR_i8, 1)) # E: Union[None, numpy.lib.shape_base._ArrayPrepare]
+
+reveal_type(np.get_array_wrap(AR_i8)) # E: numpy.lib.shape_base._ArrayWrap
+reveal_type(np.get_array_wrap(AR_i8, 1)) # E: Union[None, numpy.lib.shape_base._ArrayWrap]
+
+reveal_type(np.kron(AR_b, AR_b)) # E: numpy.ndarray[Any, numpy.dtype[numpy.bool_]]
+reveal_type(np.kron(AR_b, AR_i8)) # E: numpy.ndarray[Any, numpy.dtype[numpy.signedinteger[Any]]]
+reveal_type(np.kron(AR_f8, AR_f8)) # E: numpy.ndarray[Any, numpy.dtype[numpy.floating[Any]]]
+
+reveal_type(np.tile(AR_i8, 5)) # E: numpy.ndarray[Any, numpy.dtype[{int64}]]
+reveal_type(np.tile(AR_LIKE_f8, [2, 2])) # E: numpy.ndarray[Any, numpy.dtype[Any]]