summaryrefslogtreecommitdiff
path: root/numpy/typing/mypy_plugin.py
diff options
context:
space:
mode:
authorBas van Beek <b.f.van.beek@vu.nl>2021-05-21 20:03:44 +0200
committerBas van Beek <b.f.van.beek@vu.nl>2021-05-21 21:15:16 +0200
commit869243e50c5a607b792d5102bd0ce360c377e8eb (patch)
tree4e050cb9c15e4c4afa6369a45528f1290919e0eb /numpy/typing/mypy_plugin.py
parent7de0fa959e476900725d8a654775e0a38745de08 (diff)
downloadnumpy-869243e50c5a607b792d5102bd0ce360c377e8eb.tar.gz
ENH: Add a mypy plugin for inferring the precision of `np.ctypeslib.c_intp`
Diffstat (limited to 'numpy/typing/mypy_plugin.py')
-rw-r--r--numpy/typing/mypy_plugin.py56
1 files changed, 41 insertions, 15 deletions
diff --git a/numpy/typing/mypy_plugin.py b/numpy/typing/mypy_plugin.py
index 901bf4fb1..2a5e729f3 100644
--- a/numpy/typing/mypy_plugin.py
+++ b/numpy/typing/mypy_plugin.py
@@ -61,6 +61,13 @@ def _get_extended_precision_list() -> t.List[str]:
return [i.__name__ for i in extended_types if i.__name__ in extended_names]
+def _get_c_intp_name() -> str:
+ if np.ctypeslib.c_intp is np.intp:
+ return "c_int64" # Plan B, in case `ctypes` fails to import
+ else:
+ return np.ctypeslib.c_intp.__qualname__
+
+
#: A dictionary mapping type-aliases in `numpy.typing._nbit` to
#: concrete `numpy.typing.NBitBase` subclasses.
_PRECISION_DICT: t.Final = _get_precision_dict()
@@ -68,6 +75,9 @@ _PRECISION_DICT: t.Final = _get_precision_dict()
#: A list with the names of all extended precision `np.number` subclasses.
_EXTENDED_PRECISION_LIST: t.Final = _get_extended_precision_list()
+#: The name of the ctypes quivalent of `np.intp`
+_C_INTP: t.Final = _get_c_intp_name()
+
def _hook(ctx: AnalyzeTypeContext) -> Type:
"""Replace a type-alias with a concrete ``NBitBase`` subclass."""
@@ -87,8 +97,23 @@ if t.TYPE_CHECKING or MYPY_EX is None:
raise ValueError("Failed to identify a `ImportFrom` instance "
f"with the following id: {id!r}")
+ def _override_imports(
+ file: MypyFile,
+ module: str,
+ imports: t.List[t.Tuple[str, t.Optional[str]]],
+ ) -> None:
+ """Override the first `module`-based import with new `imports`."""
+ # Construct a new `from module import y` statement
+ import_obj = ImportFrom(module, 0, names=imports)
+ import_obj.is_top_level = True
+
+ # Replace the first `module`-based import statement with `import_obj`
+ for lst in [file.defs, file.imports]: # type: t.List[Statement]
+ i = _index(lst, module)
+ lst[i] = import_obj
+
class _NumpyPlugin(Plugin):
- """A plugin for assigning platform-specific `numpy.number` precisions."""
+ """A mypy plugin for handling versus numpy-specific typing tasks."""
def get_type_analyze_hook(self, fullname: str) -> t.Optional[_HookFunc]:
"""Set the precision of platform-specific `numpy.number` subclasses.
@@ -100,25 +125,26 @@ if t.TYPE_CHECKING or MYPY_EX is None:
return None
def get_additional_deps(self, file: MypyFile) -> t.List[t.Tuple[int, str, int]]:
- """Import platform-specific extended-precision `numpy.number` subclasses.
+ """Handle all import-based overrides.
+
+ * Import platform-specific extended-precision `numpy.number`
+ subclasses (*e.g.* `numpy.float96`, `numpy.float128` and
+ `numpy.complex256`).
+ * Import the appropriate `ctypes` equivalent to `numpy.intp`.
- For example: `numpy.float96`, `numpy.float128` and `numpy.complex256`.
"""
ret = [(PRI_MED, file.fullname, -1)]
+
if file.fullname == "numpy":
- # Import ONLY the extended precision types available to the
- # platform in question
- imports = ImportFrom(
- "numpy.typing._extended_precision", 0,
- names=[(v, v) for v in _EXTENDED_PRECISION_LIST],
+ _override_imports(
+ file, "numpy.typing._extended_precision",
+ imports=[(v, v) for v in _EXTENDED_PRECISION_LIST],
+ )
+ elif file.fullname == "numpy.ctypeslib":
+ _override_imports(
+ file, "ctypes",
+ imports=[(_C_INTP, "_c_intp")],
)
- imports.is_top_level = True
-
- # Replace the much broader extended-precision import
- # (defined in `numpy/__init__.pyi`) with a more specific one
- for lst in [file.defs, file.imports]: # type: t.List[Statement]
- i = _index(lst, "numpy.typing._extended_precision")
- lst[i] = imports
return ret
def plugin(version: str) -> t.Type[_NumpyPlugin]: