summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2021-05-28 14:11:55 -0600
committerGitHub <noreply@github.com>2021-05-28 14:11:55 -0600
commit336cda183332006ce57750af6b14642a79e3b3f4 (patch)
treea83a909b421e771b97c9332289a1e975efc079cc
parentca2880476f2f0f9ff3ba4ca0889ac9911727be9b (diff)
parentf5a5fdb345d3d9be90b2bc6832dc6bd0fe66aa5d (diff)
downloadnumpy-336cda183332006ce57750af6b14642a79e3b3f4.tar.gz
Merge pull request #19062 from BvB93/ctypes-plugin
ENH: Add a mypy plugin for inferring the precision of `np.ctypeslib.c_intp`
-rw-r--r--doc/release/upcoming_changes/19062.new_feature.rst21
-rw-r--r--numpy/core/_internal.pyi11
-rw-r--r--numpy/ctypeslib.pyi9
-rw-r--r--numpy/typing/__init__.py11
-rw-r--r--numpy/typing/mypy_plugin.py62
-rw-r--r--numpy/typing/tests/data/reveal/ctypeslib.py3
-rw-r--r--numpy/typing/tests/data/reveal/ndarray_misc.py4
-rw-r--r--numpy/typing/tests/test_typing.py9
8 files changed, 96 insertions, 34 deletions
diff --git a/doc/release/upcoming_changes/19062.new_feature.rst b/doc/release/upcoming_changes/19062.new_feature.rst
new file mode 100644
index 000000000..171715568
--- /dev/null
+++ b/doc/release/upcoming_changes/19062.new_feature.rst
@@ -0,0 +1,21 @@
+Assign the platform-specific ``c_intp`` precision via a mypy plugin
+-------------------------------------------------------------------
+
+The mypy_ plugin, introduced in `numpy/numpy#17843`_, has again been expanded:
+the plugin now is now responsible for setting the platform-specific precision
+of `numpy.ctypeslib.c_intp`, the latter being used as data type for various
+`numpy.ndarray.ctypes` attributes.
+
+Without the plugin, aforementioned type will default to `ctypes.c_int64`.
+
+To enable the plugin, one must add it to their mypy `configuration file`_:
+
+.. code-block:: ini
+
+ [mypy]
+ plugins = numpy.typing.mypy_plugin
+
+
+.. _mypy: http://mypy-lang.org/
+.. _configuration file: https://mypy.readthedocs.io/en/stable/config_file.html
+.. _`numpy/numpy#17843`: https://github.com/numpy/numpy/pull/17843
diff --git a/numpy/core/_internal.pyi b/numpy/core/_internal.pyi
index 1ef1c9fa1..f4bfd770f 100644
--- a/numpy/core/_internal.pyi
+++ b/numpy/core/_internal.pyi
@@ -2,6 +2,7 @@ from typing import Any, TypeVar, Type, overload, Optional, Generic
import ctypes as ct
from numpy import ndarray
+from numpy.ctypeslib import c_intp
_CastT = TypeVar("_CastT", bound=ct._CanCastTo) # Copied from `ctypes.cast`
_CT = TypeVar("_CT", bound=ct._CData)
@@ -15,18 +16,12 @@ class _ctypes(Generic[_PT]):
def __new__(cls, array: ndarray[Any, Any], ptr: None = ...) -> _ctypes[None]: ...
@overload
def __new__(cls, array: ndarray[Any, Any], ptr: _PT) -> _ctypes[_PT]: ...
-
- # NOTE: In practice `shape` and `strides` return one of the concrete
- # platform dependant array-types (`c_int`, `c_long` or `c_longlong`)
- # corresponding to C's `int_ptr_t`, as determined by `_getintp_ctype`
- # TODO: Hook this in to the mypy plugin so that a more appropiate
- # `ctypes._SimpleCData[int]` sub-type can be returned
@property
def data(self) -> _PT: ...
@property
- def shape(self) -> ct.Array[ct.c_int64]: ...
+ def shape(self) -> ct.Array[c_intp]: ...
@property
- def strides(self) -> ct.Array[ct.c_int64]: ...
+ def strides(self) -> ct.Array[c_intp]: ...
@property
def _as_parameter_(self) -> ct.c_void_p: ...
diff --git a/numpy/ctypeslib.pyi b/numpy/ctypeslib.pyi
index 689ea4164..642017ba7 100644
--- a/numpy/ctypeslib.pyi
+++ b/numpy/ctypeslib.pyi
@@ -1,11 +1,12 @@
from typing import List, Type
-from ctypes import _SimpleCData
+
+# NOTE: Numpy's mypy plugin is used for importing the correct
+# platform-specific `ctypes._SimpleCData[int]` sub-type
+from ctypes import c_int64 as _c_intp
__all__: List[str]
-# TODO: Update the `npt.mypy_plugin` such that it substitutes `c_intp` for
-# a specific `_SimpleCData[int]` subclass (e.g. `ctypes.c_long`)
-c_intp: Type[_SimpleCData[int]]
+c_intp = _c_intp
def load_library(libname, loader_path): ...
def ndpointer(dtype=..., ndim=..., shape=..., flags=...): ...
diff --git a/numpy/typing/__init__.py b/numpy/typing/__init__.py
index 04d34f0c7..252123a19 100644
--- a/numpy/typing/__init__.py
+++ b/numpy/typing/__init__.py
@@ -23,15 +23,18 @@ Mypy plugin
-----------
A mypy_ plugin is distributed in `numpy.typing` for managing a number of
-platform-specific annotations. Its function can be split into to parts:
+platform-specific annotations. Its functionality can be split into three
+distinct parts:
* Assigning the (platform-dependent) precisions of certain `~numpy.number` subclasses,
including the likes of `~numpy.int_`, `~numpy.intp` and `~numpy.longlong`.
See the documentation on :ref:`scalar types <arrays.scalars.built-in>` for a
- comprehensive overview of the affected classes. without the plugin the precision
- of all relevant classes will be inferred as `~typing.Any`.
+ comprehensive overview of the affected classes. Without the plugin the
+ precision of all relevant classes will be inferred as `~typing.Any`.
+* Assigning the (platform-dependent) precision of `~numpy.ctypeslib.c_intp`.
+ Without the plugin aforementioned type will default to `ctypes.c_int64`.
* Removing all extended-precision `~numpy.number` subclasses that are unavailable
- for the platform in question. Most notable this includes the likes of
+ for the platform in question. Most notably, this includes the likes of
`~numpy.float128` and `~numpy.complex256`. Without the plugin *all*
extended-precision types will, as far as mypy is concerned, be available
to all platforms.
diff --git a/numpy/typing/mypy_plugin.py b/numpy/typing/mypy_plugin.py
index 100e0d957..74dcd7a85 100644
--- a/numpy/typing/mypy_plugin.py
+++ b/numpy/typing/mypy_plugin.py
@@ -61,6 +61,19 @@ def _get_extended_precision_list() -> t.List[str]:
return [i.__name__ for i in extended_types if i.__name__ in extended_names]
+def _get_c_intp_name() -> str:
+ # Adapted from `np.core._internal._getintp_ctype`
+ char = np.dtype('p').char
+ if char == 'i':
+ return "c_int"
+ elif char == 'l':
+ return "c_long"
+ elif char == 'q':
+ return "c_longlong"
+ else:
+ return "c_long"
+
+
#: A dictionary mapping type-aliases in `numpy.typing._nbit` to
#: concrete `numpy.typing.NBitBase` subclasses.
_PRECISION_DICT: t.Final = _get_precision_dict()
@@ -68,6 +81,9 @@ _PRECISION_DICT: t.Final = _get_precision_dict()
#: A list with the names of all extended precision `np.number` subclasses.
_EXTENDED_PRECISION_LIST: t.Final = _get_extended_precision_list()
+#: The name of the ctypes quivalent of `np.intp`
+_C_INTP: t.Final = _get_c_intp_name()
+
def _hook(ctx: AnalyzeTypeContext) -> Type:
"""Replace a type-alias with a concrete ``NBitBase`` subclass."""
@@ -87,8 +103,23 @@ if t.TYPE_CHECKING or MYPY_EX is None:
raise ValueError("Failed to identify a `ImportFrom` instance "
f"with the following id: {id!r}")
+ def _override_imports(
+ file: MypyFile,
+ module: str,
+ imports: t.List[t.Tuple[str, t.Optional[str]]],
+ ) -> None:
+ """Override the first `module`-based import with new `imports`."""
+ # Construct a new `from module import y` statement
+ import_obj = ImportFrom(module, 0, names=imports)
+ import_obj.is_top_level = True
+
+ # Replace the first `module`-based import statement with `import_obj`
+ for lst in [file.defs, file.imports]: # type: t.List[Statement]
+ i = _index(lst, module)
+ lst[i] = import_obj
+
class _NumpyPlugin(Plugin):
- """A plugin for assigning platform-specific `numpy.number` precisions."""
+ """A mypy plugin for handling versus numpy-specific typing tasks."""
def get_type_analyze_hook(self, fullname: str) -> None | _HookFunc:
"""Set the precision of platform-specific `numpy.number` subclasses.
@@ -100,25 +131,26 @@ if t.TYPE_CHECKING or MYPY_EX is None:
return None
def get_additional_deps(self, file: MypyFile) -> t.List[t.Tuple[int, str, int]]:
- """Import platform-specific extended-precision `numpy.number` subclasses.
+ """Handle all import-based overrides.
+
+ * Import platform-specific extended-precision `numpy.number`
+ subclasses (*e.g.* `numpy.float96`, `numpy.float128` and
+ `numpy.complex256`).
+ * Import the appropriate `ctypes` equivalent to `numpy.intp`.
- For example: `numpy.float96`, `numpy.float128` and `numpy.complex256`.
"""
ret = [(PRI_MED, file.fullname, -1)]
+
if file.fullname == "numpy":
- # Import ONLY the extended precision types available to the
- # platform in question
- imports = ImportFrom(
- "numpy.typing._extended_precision", 0,
- names=[(v, v) for v in _EXTENDED_PRECISION_LIST],
+ _override_imports(
+ file, "numpy.typing._extended_precision",
+ imports=[(v, v) for v in _EXTENDED_PRECISION_LIST],
+ )
+ elif file.fullname == "numpy.ctypeslib":
+ _override_imports(
+ file, "ctypes",
+ imports=[(_C_INTP, "_c_intp")],
)
- imports.is_top_level = True
-
- # Replace the much broader extended-precision import
- # (defined in `numpy/__init__.pyi`) with a more specific one
- for lst in [file.defs, file.imports]: # type: t.List[Statement]
- i = _index(lst, "numpy.typing._extended_precision")
- lst[i] = imports
return ret
def plugin(version: str) -> t.Type[_NumpyPlugin]:
diff --git a/numpy/typing/tests/data/reveal/ctypeslib.py b/numpy/typing/tests/data/reveal/ctypeslib.py
new file mode 100644
index 000000000..0c32d70ed
--- /dev/null
+++ b/numpy/typing/tests/data/reveal/ctypeslib.py
@@ -0,0 +1,3 @@
+import numpy as np
+
+reveal_type(np.ctypeslib.c_intp()) # E: {c_intp}
diff --git a/numpy/typing/tests/data/reveal/ndarray_misc.py b/numpy/typing/tests/data/reveal/ndarray_misc.py
index ea01b7aa4..2e198eb6f 100644
--- a/numpy/typing/tests/data/reveal/ndarray_misc.py
+++ b/numpy/typing/tests/data/reveal/ndarray_misc.py
@@ -23,8 +23,8 @@ AR_U: np.ndarray[Any, np.dtype[np.str_]]
ctypes_obj = AR_f8.ctypes
reveal_type(ctypes_obj.data) # E: int
-reveal_type(ctypes_obj.shape) # E: ctypes.Array[ctypes.c_int64]
-reveal_type(ctypes_obj.strides) # E: ctypes.Array[ctypes.c_int64]
+reveal_type(ctypes_obj.shape) # E: ctypes.Array[{c_intp}]
+reveal_type(ctypes_obj.strides) # E: ctypes.Array[{c_intp}]
reveal_type(ctypes_obj._as_parameter_) # E: ctypes.c_void_p
reveal_type(ctypes_obj.data_as(ct.c_void_p)) # E: ctypes.c_void_p
diff --git a/numpy/typing/tests/test_typing.py b/numpy/typing/tests/test_typing.py
index be08c1359..35558c880 100644
--- a/numpy/typing/tests/test_typing.py
+++ b/numpy/typing/tests/test_typing.py
@@ -8,7 +8,11 @@ from typing import Optional, IO, Dict, List
import pytest
import numpy as np
-from numpy.typing.mypy_plugin import _PRECISION_DICT, _EXTENDED_PRECISION_LIST
+from numpy.typing.mypy_plugin import (
+ _PRECISION_DICT,
+ _EXTENDED_PRECISION_LIST,
+ _C_INTP,
+)
try:
from mypy import api
@@ -219,6 +223,9 @@ def _construct_format_dict():
# numpy.typing
"_NBitInt": dct['_NBitInt'],
+
+ # numpy.ctypeslib
+ "c_intp": f"ctypes.{_C_INTP}"
}