diff options
author | Bas van Beek <b.f.van.beek@vu.nl> | 2021-02-03 00:50:40 +0100 |
---|---|---|
committer | Bas van Beek <b.f.van.beek@vu.nl> | 2021-02-05 17:55:25 +0100 |
commit | af7106f76e3eb2ab6ada000ced512951f378b5fc (patch) | |
tree | fc83b8a7fd3d7a76571e83506a5da79bcd6228fb /numpy/typing/mypy_plugin.py | |
parent | a1640ad416c427d397695f51011000a1d7583f22 (diff) | |
download | numpy-af7106f76e3eb2ab6ada000ced512951f378b5fc.tar.gz |
ENH: Add a plugin for exposing platform-specific extended-precision `np.number`s
Diffstat (limited to 'numpy/typing/mypy_plugin.py')
-rw-r--r-- | numpy/typing/mypy_plugin.py | 60 |
1 files changed, 60 insertions, 0 deletions
diff --git a/numpy/typing/mypy_plugin.py b/numpy/typing/mypy_plugin.py index 813ff16d2..901bf4fb1 100644 --- a/numpy/typing/mypy_plugin.py +++ b/numpy/typing/mypy_plugin.py @@ -10,6 +10,9 @@ try: import mypy.types from mypy.types import Type from mypy.plugin import Plugin, AnalyzeTypeContext + from mypy.nodes import MypyFile, ImportFrom, Statement + from mypy.build import PRI_MED + _HookFunc = t.Callable[[AnalyzeTypeContext], Type] MYPY_EX: t.Optional[ModuleNotFoundError] = None except ModuleNotFoundError as ex: @@ -39,10 +42,32 @@ def _get_precision_dict() -> t.Dict[str, str]: return ret +def _get_extended_precision_list() -> t.List[str]: + extended_types = [np.ulonglong, np.longlong, np.longdouble, np.clongdouble] + extended_names = { + "uint128", + "uint256", + "int128", + "int256", + "float80", + "float96", + "float128", + "float256", + "complex160", + "complex192", + "complex256", + "complex512", + } + return [i.__name__ for i in extended_types if i.__name__ in extended_names] + + #: A dictionary mapping type-aliases in `numpy.typing._nbit` to #: concrete `numpy.typing.NBitBase` subclasses. _PRECISION_DICT: t.Final = _get_precision_dict() +#: A list with the names of all extended precision `np.number` subclasses. +_EXTENDED_PRECISION_LIST: t.Final = _get_extended_precision_list() + def _hook(ctx: AnalyzeTypeContext) -> Type: """Replace a type-alias with a concrete ``NBitBase`` subclass.""" @@ -53,14 +78,49 @@ def _hook(ctx: AnalyzeTypeContext) -> Type: if t.TYPE_CHECKING or MYPY_EX is None: + def _index(iterable: t.Iterable[Statement], id: str) -> int: + """Identify the first ``ImportFrom`` instance the specified `id`.""" + for i, value in enumerate(iterable): + if getattr(value, "id", None) == id: + return i + else: + raise ValueError("Failed to identify a `ImportFrom` instance " + f"with the following id: {id!r}") + class _NumpyPlugin(Plugin): """A plugin for assigning platform-specific `numpy.number` precisions.""" def get_type_analyze_hook(self, fullname: str) -> t.Optional[_HookFunc]: + """Set the precision of platform-specific `numpy.number` subclasses. + + For example: `numpy.int_`, `numpy.longlong` and `numpy.longdouble`. + """ if fullname in _PRECISION_DICT: return _hook return None + def get_additional_deps(self, file: MypyFile) -> t.List[t.Tuple[int, str, int]]: + """Import platform-specific extended-precision `numpy.number` subclasses. + + For example: `numpy.float96`, `numpy.float128` and `numpy.complex256`. + """ + ret = [(PRI_MED, file.fullname, -1)] + if file.fullname == "numpy": + # Import ONLY the extended precision types available to the + # platform in question + imports = ImportFrom( + "numpy.typing._extended_precision", 0, + names=[(v, v) for v in _EXTENDED_PRECISION_LIST], + ) + imports.is_top_level = True + + # Replace the much broader extended-precision import + # (defined in `numpy/__init__.pyi`) with a more specific one + for lst in [file.defs, file.imports]: # type: t.List[Statement] + i = _index(lst, "numpy.typing._extended_precision") + lst[i] = imports + return ret + def plugin(version: str) -> t.Type[_NumpyPlugin]: """An entry-point for mypy.""" return _NumpyPlugin |