summaryrefslogtreecommitdiff
path: root/numpy/typing/_generic_alias.py
diff options
context:
space:
mode:
authorAaron Meurer <asmeurer@gmail.com>2021-06-14 14:07:18 -0600
committerAaron Meurer <asmeurer@gmail.com>2021-06-14 14:07:18 -0600
commit8c78b84968e580f24b3705378fb35705a434cdf1 (patch)
treec9f82beeb5a2c3f0301f7984d4b6d19539c35d23 /numpy/typing/_generic_alias.py
parent8bf3a4618f1de951c7a4ccdb8bc3e36825a1b744 (diff)
parent75f852edf94a7293e7982ad516bee314d7187c2d (diff)
downloadnumpy-8c78b84968e580f24b3705378fb35705a434cdf1.tar.gz
Merge branch 'main' into matrix_rank-doc-fix
Diffstat (limited to 'numpy/typing/_generic_alias.py')
-rw-r--r--numpy/typing/_generic_alias.py216
1 files changed, 216 insertions, 0 deletions
diff --git a/numpy/typing/_generic_alias.py b/numpy/typing/_generic_alias.py
new file mode 100644
index 000000000..8d65ef855
--- /dev/null
+++ b/numpy/typing/_generic_alias.py
@@ -0,0 +1,216 @@
+from __future__ import annotations
+
+import sys
+import types
+from typing import (
+ Any,
+ ClassVar,
+ FrozenSet,
+ Generator,
+ Iterable,
+ Iterator,
+ List,
+ NoReturn,
+ Tuple,
+ Type,
+ TypeVar,
+ TYPE_CHECKING,
+)
+
+import numpy as np
+
+__all__ = ["_GenericAlias", "NDArray"]
+
+_T = TypeVar("_T", bound="_GenericAlias")
+
+
+def _to_str(obj: object) -> str:
+ """Helper function for `_GenericAlias.__repr__`."""
+ if obj is Ellipsis:
+ return '...'
+ elif isinstance(obj, type) and not isinstance(obj, _GENERIC_ALIAS_TYPE):
+ if obj.__module__ == 'builtins':
+ return obj.__qualname__
+ else:
+ return f'{obj.__module__}.{obj.__qualname__}'
+ else:
+ return repr(obj)
+
+
+def _parse_parameters(args: Iterable[Any]) -> Generator[TypeVar, None, None]:
+ """Search for all typevars and typevar-containing objects in `args`.
+
+ Helper function for `_GenericAlias.__init__`.
+
+ """
+ for i in args:
+ if hasattr(i, "__parameters__"):
+ yield from i.__parameters__
+ elif isinstance(i, TypeVar):
+ yield i
+
+
+def _reconstruct_alias(alias: _T, parameters: Iterator[TypeVar]) -> _T:
+ """Recursivelly replace all typevars with those from `parameters`.
+
+ Helper function for `_GenericAlias.__getitem__`.
+
+ """
+ args = []
+ for i in alias.__args__:
+ if isinstance(i, TypeVar):
+ value: Any = next(parameters)
+ elif isinstance(i, _GenericAlias):
+ value = _reconstruct_alias(i, parameters)
+ elif hasattr(i, "__parameters__"):
+ prm_tup = tuple(next(parameters) for _ in i.__parameters__)
+ value = i[prm_tup]
+ else:
+ value = i
+ args.append(value)
+
+ cls = type(alias)
+ return cls(alias.__origin__, tuple(args))
+
+
+class _GenericAlias:
+ """A python-based backport of the `types.GenericAlias` class.
+
+ E.g. for ``t = list[int]``, ``t.__origin__`` is ``list`` and
+ ``t.__args__`` is ``(int,)``.
+
+ See Also
+ --------
+ :pep:`585`
+ The PEP responsible for introducing `types.GenericAlias`.
+
+ """
+
+ __slots__ = ("__weakref__", "_origin", "_args", "_parameters", "_hash")
+
+ @property
+ def __origin__(self) -> type:
+ return super().__getattribute__("_origin")
+
+ @property
+ def __args__(self) -> Tuple[object, ...]:
+ return super().__getattribute__("_args")
+
+ @property
+ def __parameters__(self) -> Tuple[TypeVar, ...]:
+ """Type variables in the ``GenericAlias``."""
+ return super().__getattribute__("_parameters")
+
+ def __init__(
+ self,
+ origin: type,
+ args: object | Tuple[object, ...],
+ ) -> None:
+ self._origin = origin
+ self._args = args if isinstance(args, tuple) else (args,)
+ self._parameters = tuple(_parse_parameters(self.__args__))
+
+ @property
+ def __call__(self) -> type:
+ return self.__origin__
+
+ def __reduce__(self: _T) -> Tuple[
+ Type[_T],
+ Tuple[type, Tuple[object, ...]],
+ ]:
+ cls = type(self)
+ return cls, (self.__origin__, self.__args__)
+
+ def __mro_entries__(self, bases: Iterable[object]) -> Tuple[type]:
+ return (self.__origin__,)
+
+ def __dir__(self) -> List[str]:
+ """Implement ``dir(self)``."""
+ cls = type(self)
+ dir_origin = set(dir(self.__origin__))
+ return sorted(cls._ATTR_EXCEPTIONS | dir_origin)
+
+ def __hash__(self) -> int:
+ """Return ``hash(self)``."""
+ # Attempt to use the cached hash
+ try:
+ return super().__getattribute__("_hash")
+ except AttributeError:
+ self._hash: int = hash(self.__origin__) ^ hash(self.__args__)
+ return super().__getattribute__("_hash")
+
+ def __instancecheck__(self, obj: object) -> NoReturn:
+ """Check if an `obj` is an instance."""
+ raise TypeError("isinstance() argument 2 cannot be a "
+ "parameterized generic")
+
+ def __subclasscheck__(self, cls: type) -> NoReturn:
+ """Check if a `cls` is a subclass."""
+ raise TypeError("issubclass() argument 2 cannot be a "
+ "parameterized generic")
+
+ def __repr__(self) -> str:
+ """Return ``repr(self)``."""
+ args = ", ".join(_to_str(i) for i in self.__args__)
+ origin = _to_str(self.__origin__)
+ return f"{origin}[{args}]"
+
+ def __getitem__(self: _T, key: object | Tuple[object, ...]) -> _T:
+ """Return ``self[key]``."""
+ key_tup = key if isinstance(key, tuple) else (key,)
+
+ if len(self.__parameters__) == 0:
+ raise TypeError(f"There are no type variables left in {self}")
+ elif len(key_tup) > len(self.__parameters__):
+ raise TypeError(f"Too many arguments for {self}")
+ elif len(key_tup) < len(self.__parameters__):
+ raise TypeError(f"Too few arguments for {self}")
+
+ key_iter = iter(key_tup)
+ return _reconstruct_alias(self, key_iter)
+
+ def __eq__(self, value: object) -> bool:
+ """Return ``self == value``."""
+ if not isinstance(value, _GENERIC_ALIAS_TYPE):
+ return NotImplemented
+ return (
+ self.__origin__ == value.__origin__ and
+ self.__args__ == value.__args__
+ )
+
+ _ATTR_EXCEPTIONS: ClassVar[FrozenSet[str]] = frozenset({
+ "__origin__",
+ "__args__",
+ "__parameters__",
+ "__mro_entries__",
+ "__reduce__",
+ "__reduce_ex__",
+ })
+
+ def __getattribute__(self, name: str) -> Any:
+ """Return ``getattr(self, name)``."""
+ # Pull the attribute from `__origin__` unless its
+ # name is in `_ATTR_EXCEPTIONS`
+ cls = type(self)
+ if name in cls._ATTR_EXCEPTIONS:
+ return super().__getattribute__(name)
+ return getattr(self.__origin__, name)
+
+
+# See `_GenericAlias.__eq__`
+if sys.version_info >= (3, 9):
+ _GENERIC_ALIAS_TYPE = (_GenericAlias, types.GenericAlias)
+else:
+ _GENERIC_ALIAS_TYPE = (_GenericAlias,)
+
+ScalarType = TypeVar("ScalarType", bound=np.generic, covariant=True)
+
+if TYPE_CHECKING:
+ _DType = np.dtype[ScalarType]
+ NDArray = np.ndarray[Any, np.dtype[ScalarType]]
+elif sys.version_info >= (3, 9):
+ _DType = types.GenericAlias(np.dtype, (ScalarType,))
+ NDArray = types.GenericAlias(np.ndarray, (Any, _DType))
+else:
+ _DType = _GenericAlias(np.dtype, (ScalarType,))
+ NDArray = _GenericAlias(np.ndarray, (Any, _DType))