diff options
Diffstat (limited to 'numpy/random/bit_generator.pyi')
-rw-r--r-- | numpy/random/bit_generator.pyi | 45 |
1 files changed, 21 insertions, 24 deletions
diff --git a/numpy/random/bit_generator.pyi b/numpy/random/bit_generator.pyi index 80a2e829b..7f066dbfa 100644 --- a/numpy/random/bit_generator.pyi +++ b/numpy/random/bit_generator.pyi @@ -18,8 +18,8 @@ from typing import ( overload, ) -from numpy import dtype, ndarray, uint32, uint64, unsignedinteger -from numpy.typing import DTypeLike, _ArrayLikeInt_co, _DTypeLikeUInt, _ShapeLike, _SupportsDType +from numpy import dtype, ndarray, uint32, uint64 +from numpy.typing import _ArrayLikeInt_co, _ShapeLike, _SupportsDType, _UInt64Codes, _UInt32Codes if sys.version_info >= (3, 8): from typing import Literal @@ -28,11 +28,17 @@ else: _T = TypeVar("_T") -_UIntType = TypeVar("_UIntType", uint64, uint32) -_DTypeLike = Union[ - Type[_UIntType], - dtype[_UIntType], - _SupportsDType[dtype[_UIntType]], +_DTypeLikeUint32 = Union[ + dtype[uint32], + _SupportsDType[dtype[uint32]], + Type[uint32], + _UInt32Codes, +] +_DTypeLikeUint64 = Union[ + dtype[uint64], + _SupportsDType[dtype[uint64]], + Type[uint64], + _UInt64Codes, ] class _SeedSeqState(TypedDict): @@ -50,30 +56,19 @@ class _Interface(NamedTuple): bit_generator: Any class ISeedSequence(abc.ABC): - @overload - @abc.abstractmethod - def generate_state( - self, n_words: int, dtype: _DTypeLike[_UIntType] = ... - ) -> ndarray[Any, dtype[_UIntType]]: ... - @overload @abc.abstractmethod def generate_state( - self, n_words: int, dtype: _DTypeLikeUInt = ... - ) -> ndarray[Any, dtype[unsignedinteger[Any]]]: ... + self, n_words: int, dtype: Union[_DTypeLikeUint32, _DTypeLikeUint64] = ... + ) -> ndarray[Any, dtype[Union[uint32, uint64]]]: ... class ISpawnableSeedSequence(ISeedSequence): @abc.abstractmethod def spawn(self: _T, n_children: int) -> List[_T]: ... class SeedlessSeedSequence(ISpawnableSeedSequence): - @overload - def generate_state( - self, n_words: int, dtype: _DTypeLike[_UIntType] = ... - ) -> ndarray[Any, dtype[_UIntType]]: ... - @overload def generate_state( - self, n_words: int, dtype: _DTypeLikeUInt = ... - ) -> ndarray[Any, dtype[unsignedinteger[Any]]]: ... + self, n_words: int, dtype: Union[_DTypeLikeUint32, _DTypeLikeUint64] = ... + ) -> ndarray[Any, dtype[Union[uint32, uint64]]]: ... def spawn(self: _T, n_children: int) -> List[_T]: ... class SeedSequence(ISpawnableSeedSequence): @@ -84,7 +79,7 @@ class SeedSequence(ISpawnableSeedSequence): pool: ndarray[Any, dtype[uint32]] def __init__( self, - entropy: Union[None, int, Sequence[int]] = ..., + entropy: Union[None, int, Sequence[int], _ArrayLikeInt_co] = ..., *, spawn_key: Sequence[int] = ..., pool_size: int = ..., @@ -95,7 +90,9 @@ class SeedSequence(ISpawnableSeedSequence): def state( self, ) -> _SeedSeqState: ... - def generate_state(self, n_words: int, dtype: DTypeLike = ...) -> ndarray[Any, Any]: ... + def generate_state( + self, n_words: int, dtype: Union[_DTypeLikeUint32, _DTypeLikeUint64] = ... + ) -> ndarray[Any, dtype[Union[uint32, uint64]]]: ... def spawn(self, n_children: int) -> List[SeedSequence]: ... class BitGenerator(abc.ABC): |