summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSebastian Berg <sebastian@sipsolutions.net>2022-11-29 17:56:10 +0100
committerGitHub <noreply@github.com>2022-11-29 17:56:10 +0100
commit2cf8c40727b08c029cdaf7fce803d031a60499a2 (patch)
treea59d6662d17e7bfad1e94b36ab4a77930574c3a0
parent7f0f045625022c3f816911cd80f8635ac2a36f21 (diff)
parent1a8d3ca45f0a7294784bc200ec436dc8563f654a (diff)
downloadnumpy-2cf8c40727b08c029cdaf7fce803d031a60499a2.tar.gz
Merge pull request #22533 from ngoldbaum/ufunc-and-function-listing
API: Add numpy.testing.overrides to aid testing of custom array containers
-rw-r--r--doc/source/reference/routines.rst1
-rw-r--r--doc/source/reference/routines.testing.overrides.rst18
-rw-r--r--doc/source/user/basics.dispatch.rst30
-rw-r--r--numpy/core/overrides.py4
-rw-r--r--numpy/testing/overrides.py82
-rw-r--r--numpy/tests/test_public_api.py1
6 files changed, 136 insertions, 0 deletions
diff --git a/doc/source/reference/routines.rst b/doc/source/reference/routines.rst
index 593d017cc..24117895b 100644
--- a/doc/source/reference/routines.rst
+++ b/doc/source/reference/routines.rst
@@ -44,4 +44,5 @@ indentation.
routines.sort
routines.statistics
routines.testing
+ routines.testing.overrides
routines.window
diff --git a/doc/source/reference/routines.testing.overrides.rst b/doc/source/reference/routines.testing.overrides.rst
new file mode 100644
index 000000000..262852633
--- /dev/null
+++ b/doc/source/reference/routines.testing.overrides.rst
@@ -0,0 +1,18 @@
+.. module:: numpy.testing.overrides
+
+Support for testing overrides (:mod:`numpy.testing.overrides`)
+==============================================================
+
+.. currentmodule:: numpy.testing.overrides
+
+Support for testing custom array container implementations.
+
+Utility Functions
+-----------------
+.. autosummary::
+ :toctree: generated/
+
+ allows_array_function_override
+ allows_array_ufunc_override
+ get_overridable_numpy_ufuncs
+ get_overridable_numpy_array_functions
diff --git a/doc/source/user/basics.dispatch.rst b/doc/source/user/basics.dispatch.rst
index 0d0ddfcdb..7c30272ad 100644
--- a/doc/source/user/basics.dispatch.rst
+++ b/doc/source/user/basics.dispatch.rst
@@ -271,6 +271,36 @@ array([[1., 0., 0., 0., 0.],
[0., 0., 0., 1., 0.],
[0., 0., 0., 0., 1.]])
+
+The implementation of ``DiagonalArray`` in this example only handles the
+``np.sum`` and ``np.mean`` functions for brevity. Many other functions in the
+Numpy API are also available to wrap and a full-fledged custom array container
+can explicitly support all functions that Numpy makes available to wrap.
+
+Numpy provides some utilities to aid testing of custom array containers that
+implement the ``__array_ufunc__`` and ``__array_function__`` protocols in the
+``numpy.testing.overrides`` namespace.
+
+To check if a Numpy function can be overriden via ``__array_ufunc__``, you can
+use :func:`~numpy.testing.overrides.allows_array_ufunc_override`:
+
+>>> from np.testing.overrides import allows_array_ufunc_override
+>>> allows_array_ufunc_override(np.add)
+True
+
+Similarly, you can check if a function can be overriden via
+``__array_function__`` using
+:func:`~numpy.testing.overrides.allows_array_function_override`.
+
+Lists of every overridable function in the Numpy API are also available via
+:func:`~numpy.testing.overrides.get_overridable_numpy_array_functions` for
+functions that support the ``__array_function__`` protocol and
+:func:`~numpy.testing.overrides.get_overridable_numpy_ufuncs` for functions that
+support the ``__array_ufunc__`` protocol. Both functions return sets of
+functions that are present in the Numpy public API. User-defined ufuncs or
+ufuncs defined in other libraries that depend on Numpy are not present in
+these sets.
+
Refer to the `dask source code <https://github.com/dask/dask>`_ and
`cupy source code <https://github.com/cupy/cupy>`_ for more fully-worked
examples of custom array containers.
diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py
index 450464f89..6d3680915 100644
--- a/numpy/core/overrides.py
+++ b/numpy/core/overrides.py
@@ -8,6 +8,8 @@ from numpy.core._multiarray_umath import (
from numpy.compat._inspect import getargspec
+ARRAY_FUNCTIONS = set()
+
ARRAY_FUNCTION_ENABLED = bool(
int(os.environ.get('NUMPY_EXPERIMENTAL_ARRAY_FUNCTION', 1)))
@@ -208,6 +210,8 @@ def array_function_dispatch(dispatcher, module=None, verify=True,
public_api._implementation = implementation
+ ARRAY_FUNCTIONS.add(public_api)
+
return public_api
return decorator
diff --git a/numpy/testing/overrides.py b/numpy/testing/overrides.py
new file mode 100644
index 000000000..d20ed60e5
--- /dev/null
+++ b/numpy/testing/overrides.py
@@ -0,0 +1,82 @@
+"""Tools for testing implementations of __array_function__ and ufunc overrides
+
+
+"""
+
+from numpy.core.overrides import ARRAY_FUNCTIONS as _array_functions
+from numpy import ufunc as _ufunc
+import numpy.core.umath as _umath
+
+def get_overridable_numpy_ufuncs():
+ """List all numpy ufuncs overridable via `__array_ufunc__`
+
+ Parameters
+ ----------
+ None
+
+ Returns
+ -------
+ set
+ A set containing all overridable ufuncs in the public numpy API.
+ """
+ ufuncs = {obj for obj in _umath.__dict__.values()
+ if isinstance(obj, _ufunc)}
+
+
+def allows_array_ufunc_override(func):
+ """Determine if a function can be overriden via `__array_ufunc__`
+
+ Parameters
+ ----------
+ func : callable
+ Function that may be overridable via `__array_ufunc__`
+
+ Returns
+ -------
+ bool
+ `True` if `func` is overridable via `__array_ufunc__` and
+ `False` otherwise.
+
+ Note
+ ----
+ This function is equivalent to `isinstance(func, np.ufunc)` and
+ will work correctly for ufuncs defined outside of Numpy.
+
+ """
+ return isinstance(func, np.ufunc)
+
+
+def get_overridable_numpy_array_functions():
+ """List all numpy functions overridable via `__array_function__`
+
+ Parameters
+ ----------
+ None
+
+ Returns
+ -------
+ set
+ A set containing all functions in the public numpy API that are
+ overridable via `__array_function__`.
+
+ """
+ # 'import numpy' doesn't import recfunctions, so make sure it's imported
+ # so ufuncs defined there show up in the ufunc listing
+ from numpy.lib import recfunctions
+ return _array_functions.copy()
+
+def allows_array_function_override(func):
+ """Determine if a Numpy function can be overriden via `__array_function__`
+
+ Parameters
+ ----------
+ func : callable
+ Function that may be overridable via `__array_function__`
+
+ Returns
+ -------
+ bool
+ `True` if `func` is a function in the Numpy API that is
+ overridable via `__array_function__` and `False` otherwise.
+ """
+ return func in _array_functions
diff --git a/numpy/tests/test_public_api.py b/numpy/tests/test_public_api.py
index 35eb4fd00..396375d87 100644
--- a/numpy/tests/test_public_api.py
+++ b/numpy/tests/test_public_api.py
@@ -159,6 +159,7 @@ PUBLIC_MODULES = ['numpy.' + s for s in [
"polynomial.polynomial",
"random",
"testing",
+ "testing.overrides",
"typing",
"typing.mypy_plugin",
"version",