diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/_typing/_array_like.py | 6 | ||||
-rw-r--r-- | numpy/_typing/_dtype_like.py | 2 | ||||
-rw-r--r-- | numpy/_typing/_nested_sequence.py | 2 | ||||
-rw-r--r-- | numpy/typing/tests/test_runtime.py | 33 |
4 files changed, 40 insertions, 3 deletions
diff --git a/numpy/_typing/_array_like.py b/numpy/_typing/_array_like.py index 2e5684b0b..67d67ce19 100644 --- a/numpy/_typing/_array_like.py +++ b/numpy/_typing/_array_like.py @@ -3,7 +3,7 @@ from __future__ import annotations # NOTE: Import `Sequence` from `typing` as we it is needed for a type-alias, # not an annotation from collections.abc import Collection, Callable -from typing import Any, Sequence, Protocol, Union, TypeVar +from typing import Any, Sequence, Protocol, Union, TypeVar, runtime_checkable from numpy import ( ndarray, dtype, @@ -33,10 +33,12 @@ _DType_co = TypeVar("_DType_co", covariant=True, bound="dtype[Any]") # array. # Concrete implementations of the protocol are responsible for adding # any and all remaining overloads +@runtime_checkable class _SupportsArray(Protocol[_DType_co]): def __array__(self) -> ndarray[Any, _DType_co]: ... +@runtime_checkable class _SupportsArrayFunc(Protocol): """A protocol class representing `~class.__array_function__`.""" def __array_function__( @@ -146,7 +148,7 @@ _ArrayLikeInt = _DualArrayLike[ # Used as the first overload, should only match NDArray[Any], # not any actual types. # https://github.com/numpy/numpy/pull/22193 -class _UnknownType: +class _UnknownType: ... diff --git a/numpy/_typing/_dtype_like.py b/numpy/_typing/_dtype_like.py index b705d82fd..e92e17dd2 100644 --- a/numpy/_typing/_dtype_like.py +++ b/numpy/_typing/_dtype_like.py @@ -8,6 +8,7 @@ from typing import ( TypeVar, Protocol, TypedDict, + runtime_checkable, ) import numpy as np @@ -80,6 +81,7 @@ class _DTypeDict(_DTypeDictBase, total=False): # A protocol for anything with the dtype attribute +@runtime_checkable class _SupportsDType(Protocol[_DType_co]): @property def dtype(self) -> _DType_co: ... diff --git a/numpy/_typing/_nested_sequence.py b/numpy/_typing/_nested_sequence.py index 360c0f1b2..789bf3844 100644 --- a/numpy/_typing/_nested_sequence.py +++ b/numpy/_typing/_nested_sequence.py @@ -8,6 +8,7 @@ from typing import ( overload, TypeVar, Protocol, + runtime_checkable, ) __all__ = ["_NestedSequence"] @@ -15,6 +16,7 @@ __all__ = ["_NestedSequence"] _T_co = TypeVar("_T_co", covariant=True) +@runtime_checkable class _NestedSequence(Protocol[_T_co]): """A protocol for representing nested sequences. diff --git a/numpy/typing/tests/test_runtime.py b/numpy/typing/tests/test_runtime.py index 5b5df49dc..44d069006 100644 --- a/numpy/typing/tests/test_runtime.py +++ b/numpy/typing/tests/test_runtime.py @@ -3,11 +3,19 @@ from __future__ import annotations import sys -from typing import get_type_hints, Union, NamedTuple, get_args, get_origin +from typing import ( + get_type_hints, + Union, + NamedTuple, + get_args, + get_origin, + Any, +) import pytest import numpy as np import numpy.typing as npt +import numpy._typing as _npt class TypeTup(NamedTuple): @@ -80,3 +88,26 @@ def test_keys() -> None: keys = TYPES.keys() ref = set(npt.__all__) assert keys == ref + + +PROTOCOLS: dict[str, tuple[type[Any], object]] = { + "_SupportsDType": (_npt._SupportsDType, np.int64(1)), + "_SupportsArray": (_npt._SupportsArray, np.arange(10)), + "_SupportsArrayFunc": (_npt._SupportsArrayFunc, np.arange(10)), + "_NestedSequence": (_npt._NestedSequence, [1]), +} + + +@pytest.mark.parametrize("cls,obj", PROTOCOLS.values(), ids=PROTOCOLS.keys()) +class TestRuntimeProtocol: + def test_isinstance(self, cls: type[Any], obj: object) -> None: + assert isinstance(obj, cls) + assert not isinstance(None, cls) + + def test_issubclass(self, cls: type[Any], obj: object) -> None: + if cls is _npt._SupportsDType: + pytest.xfail( + "Protocols with non-method members don't support issubclass()" + ) + assert issubclass(type(obj), cls) + assert not issubclass(type(None), cls) |