From 1a8d3ca45f0a7294784bc200ec436dc8563f654a Mon Sep 17 00:00:00 2001 From: Nathan Goldbaum Date: Fri, 4 Nov 2022 12:13:56 -0600 Subject: API: Add numpy.testing.overrides to aid testing of custom array containers Closes #15544 --- doc/source/reference/routines.rst | 1 + .../reference/routines.testing.overrides.rst | 18 +++++ doc/source/user/basics.dispatch.rst | 30 ++++++++ numpy/core/overrides.py | 4 ++ numpy/testing/overrides.py | 82 ++++++++++++++++++++++ numpy/tests/test_public_api.py | 1 + 6 files changed, 136 insertions(+) create mode 100644 doc/source/reference/routines.testing.overrides.rst create mode 100644 numpy/testing/overrides.py diff --git a/doc/source/reference/routines.rst b/doc/source/reference/routines.rst index 593d017cc..24117895b 100644 --- a/doc/source/reference/routines.rst +++ b/doc/source/reference/routines.rst @@ -44,4 +44,5 @@ indentation. routines.sort routines.statistics routines.testing + routines.testing.overrides routines.window diff --git a/doc/source/reference/routines.testing.overrides.rst b/doc/source/reference/routines.testing.overrides.rst new file mode 100644 index 000000000..262852633 --- /dev/null +++ b/doc/source/reference/routines.testing.overrides.rst @@ -0,0 +1,18 @@ +.. module:: numpy.testing.overrides + +Support for testing overrides (:mod:`numpy.testing.overrides`) +============================================================== + +.. currentmodule:: numpy.testing.overrides + +Support for testing custom array container implementations. + +Utility Functions +----------------- +.. autosummary:: + :toctree: generated/ + + allows_array_function_override + allows_array_ufunc_override + get_overridable_numpy_ufuncs + get_overridable_numpy_array_functions diff --git a/doc/source/user/basics.dispatch.rst b/doc/source/user/basics.dispatch.rst index 0d0ddfcdb..7c30272ad 100644 --- a/doc/source/user/basics.dispatch.rst +++ b/doc/source/user/basics.dispatch.rst @@ -271,6 +271,36 @@ array([[1., 0., 0., 0., 0.], [0., 0., 0., 1., 0.], [0., 0., 0., 0., 1.]]) + +The implementation of ``DiagonalArray`` in this example only handles the +``np.sum`` and ``np.mean`` functions for brevity. Many other functions in the +Numpy API are also available to wrap and a full-fledged custom array container +can explicitly support all functions that Numpy makes available to wrap. + +Numpy provides some utilities to aid testing of custom array containers that +implement the ``__array_ufunc__`` and ``__array_function__`` protocols in the +``numpy.testing.overrides`` namespace. + +To check if a Numpy function can be overriden via ``__array_ufunc__``, you can +use :func:`~numpy.testing.overrides.allows_array_ufunc_override`: + +>>> from np.testing.overrides import allows_array_ufunc_override +>>> allows_array_ufunc_override(np.add) +True + +Similarly, you can check if a function can be overriden via +``__array_function__`` using +:func:`~numpy.testing.overrides.allows_array_function_override`. + +Lists of every overridable function in the Numpy API are also available via +:func:`~numpy.testing.overrides.get_overridable_numpy_array_functions` for +functions that support the ``__array_function__`` protocol and +:func:`~numpy.testing.overrides.get_overridable_numpy_ufuncs` for functions that +support the ``__array_ufunc__`` protocol. Both functions return sets of +functions that are present in the Numpy public API. User-defined ufuncs or +ufuncs defined in other libraries that depend on Numpy are not present in +these sets. + Refer to the `dask source code `_ and `cupy source code `_ for more fully-worked examples of custom array containers. diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py index 450464f89..6d3680915 100644 --- a/numpy/core/overrides.py +++ b/numpy/core/overrides.py @@ -8,6 +8,8 @@ from numpy.core._multiarray_umath import ( from numpy.compat._inspect import getargspec +ARRAY_FUNCTIONS = set() + ARRAY_FUNCTION_ENABLED = bool( int(os.environ.get('NUMPY_EXPERIMENTAL_ARRAY_FUNCTION', 1))) @@ -208,6 +210,8 @@ def array_function_dispatch(dispatcher, module=None, verify=True, public_api._implementation = implementation + ARRAY_FUNCTIONS.add(public_api) + return public_api return decorator diff --git a/numpy/testing/overrides.py b/numpy/testing/overrides.py new file mode 100644 index 000000000..d20ed60e5 --- /dev/null +++ b/numpy/testing/overrides.py @@ -0,0 +1,82 @@ +"""Tools for testing implementations of __array_function__ and ufunc overrides + + +""" + +from numpy.core.overrides import ARRAY_FUNCTIONS as _array_functions +from numpy import ufunc as _ufunc +import numpy.core.umath as _umath + +def get_overridable_numpy_ufuncs(): + """List all numpy ufuncs overridable via `__array_ufunc__` + + Parameters + ---------- + None + + Returns + ------- + set + A set containing all overridable ufuncs in the public numpy API. + """ + ufuncs = {obj for obj in _umath.__dict__.values() + if isinstance(obj, _ufunc)} + + +def allows_array_ufunc_override(func): + """Determine if a function can be overriden via `__array_ufunc__` + + Parameters + ---------- + func : callable + Function that may be overridable via `__array_ufunc__` + + Returns + ------- + bool + `True` if `func` is overridable via `__array_ufunc__` and + `False` otherwise. + + Note + ---- + This function is equivalent to `isinstance(func, np.ufunc)` and + will work correctly for ufuncs defined outside of Numpy. + + """ + return isinstance(func, np.ufunc) + + +def get_overridable_numpy_array_functions(): + """List all numpy functions overridable via `__array_function__` + + Parameters + ---------- + None + + Returns + ------- + set + A set containing all functions in the public numpy API that are + overridable via `__array_function__`. + + """ + # 'import numpy' doesn't import recfunctions, so make sure it's imported + # so ufuncs defined there show up in the ufunc listing + from numpy.lib import recfunctions + return _array_functions.copy() + +def allows_array_function_override(func): + """Determine if a Numpy function can be overriden via `__array_function__` + + Parameters + ---------- + func : callable + Function that may be overridable via `__array_function__` + + Returns + ------- + bool + `True` if `func` is a function in the Numpy API that is + overridable via `__array_function__` and `False` otherwise. + """ + return func in _array_functions diff --git a/numpy/tests/test_public_api.py b/numpy/tests/test_public_api.py index 92a34bf91..d04e24218 100644 --- a/numpy/tests/test_public_api.py +++ b/numpy/tests/test_public_api.py @@ -157,6 +157,7 @@ PUBLIC_MODULES = ['numpy.' + s for s in [ "polynomial.polynomial", "random", "testing", + "testing.overrides", "typing", "typing.mypy_plugin", "version", -- cgit v1.2.1