diff options
author | Bas van Beek <b.f.van.beek@vu.nl> | 2021-05-21 20:03:44 +0200 |
---|---|---|
committer | Bas van Beek <b.f.van.beek@vu.nl> | 2021-05-21 21:15:16 +0200 |
commit | 869243e50c5a607b792d5102bd0ce360c377e8eb (patch) | |
tree | 4e050cb9c15e4c4afa6369a45528f1290919e0eb /numpy/typing/mypy_plugin.py | |
parent | 7de0fa959e476900725d8a654775e0a38745de08 (diff) | |
download | numpy-869243e50c5a607b792d5102bd0ce360c377e8eb.tar.gz |
ENH: Add a mypy plugin for inferring the precision of `np.ctypeslib.c_intp`
Diffstat (limited to 'numpy/typing/mypy_plugin.py')
-rw-r--r-- | numpy/typing/mypy_plugin.py | 56 |
1 files changed, 41 insertions, 15 deletions
diff --git a/numpy/typing/mypy_plugin.py b/numpy/typing/mypy_plugin.py index 901bf4fb1..2a5e729f3 100644 --- a/numpy/typing/mypy_plugin.py +++ b/numpy/typing/mypy_plugin.py @@ -61,6 +61,13 @@ def _get_extended_precision_list() -> t.List[str]: return [i.__name__ for i in extended_types if i.__name__ in extended_names] +def _get_c_intp_name() -> str: + if np.ctypeslib.c_intp is np.intp: + return "c_int64" # Plan B, in case `ctypes` fails to import + else: + return np.ctypeslib.c_intp.__qualname__ + + #: A dictionary mapping type-aliases in `numpy.typing._nbit` to #: concrete `numpy.typing.NBitBase` subclasses. _PRECISION_DICT: t.Final = _get_precision_dict() @@ -68,6 +75,9 @@ _PRECISION_DICT: t.Final = _get_precision_dict() #: A list with the names of all extended precision `np.number` subclasses. _EXTENDED_PRECISION_LIST: t.Final = _get_extended_precision_list() +#: The name of the ctypes quivalent of `np.intp` +_C_INTP: t.Final = _get_c_intp_name() + def _hook(ctx: AnalyzeTypeContext) -> Type: """Replace a type-alias with a concrete ``NBitBase`` subclass.""" @@ -87,8 +97,23 @@ if t.TYPE_CHECKING or MYPY_EX is None: raise ValueError("Failed to identify a `ImportFrom` instance " f"with the following id: {id!r}") + def _override_imports( + file: MypyFile, + module: str, + imports: t.List[t.Tuple[str, t.Optional[str]]], + ) -> None: + """Override the first `module`-based import with new `imports`.""" + # Construct a new `from module import y` statement + import_obj = ImportFrom(module, 0, names=imports) + import_obj.is_top_level = True + + # Replace the first `module`-based import statement with `import_obj` + for lst in [file.defs, file.imports]: # type: t.List[Statement] + i = _index(lst, module) + lst[i] = import_obj + class _NumpyPlugin(Plugin): - """A plugin for assigning platform-specific `numpy.number` precisions.""" + """A mypy plugin for handling versus numpy-specific typing tasks.""" def get_type_analyze_hook(self, fullname: str) -> t.Optional[_HookFunc]: """Set the precision of platform-specific `numpy.number` subclasses. @@ -100,25 +125,26 @@ if t.TYPE_CHECKING or MYPY_EX is None: return None def get_additional_deps(self, file: MypyFile) -> t.List[t.Tuple[int, str, int]]: - """Import platform-specific extended-precision `numpy.number` subclasses. + """Handle all import-based overrides. + + * Import platform-specific extended-precision `numpy.number` + subclasses (*e.g.* `numpy.float96`, `numpy.float128` and + `numpy.complex256`). + * Import the appropriate `ctypes` equivalent to `numpy.intp`. - For example: `numpy.float96`, `numpy.float128` and `numpy.complex256`. """ ret = [(PRI_MED, file.fullname, -1)] + if file.fullname == "numpy": - # Import ONLY the extended precision types available to the - # platform in question - imports = ImportFrom( - "numpy.typing._extended_precision", 0, - names=[(v, v) for v in _EXTENDED_PRECISION_LIST], + _override_imports( + file, "numpy.typing._extended_precision", + imports=[(v, v) for v in _EXTENDED_PRECISION_LIST], + ) + elif file.fullname == "numpy.ctypeslib": + _override_imports( + file, "ctypes", + imports=[(_C_INTP, "_c_intp")], ) - imports.is_top_level = True - - # Replace the much broader extended-precision import - # (defined in `numpy/__init__.pyi`) with a more specific one - for lst in [file.defs, file.imports]: # type: t.List[Statement] - i = _index(lst, "numpy.typing._extended_precision") - lst[i] = imports return ret def plugin(version: str) -> t.Type[_NumpyPlugin]: |