summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSebastian Berg <sebastian@sipsolutions.net>2021-12-07 20:17:17 -0600
committerSebastian Berg <sebastianb@nvidia.com>2023-01-17 18:40:44 +0100
commit60a858a372b14b73547baacf4a472eccfade1073 (patch)
tree1061a985383ad6ab2a8dc56f144ec25cfbff071e
parent9b6a7b4f874f5502112f36d485b12d92889eb808 (diff)
downloadnumpy-60a858a372b14b73547baacf4a472eccfade1073.tar.gz
ENH: Improve array function overhead by using vectorcall
This moves dispatching for `__array_function__` into a C-wrapper. This helps speed for multiple reasons: * Avoids one additional dispatching function call to C * Avoids the use of `*args, **kwargs` which is slower. * For simple NumPy calls we can stay in the faster "vectorcall" world This speeds up things generally a little, but can speed things up a lot when keyword arguments are used on lightweight functions, for example:: np.can_cast(arr, dtype, casting="same_kind") is more than twice as fast with this. There is one alternative in principle to get best speed: We could inline the "relevant argument"/dispatcher extraction. That changes behavior in an acceptable but larger way (passes default arguments). Unless the C-entry point seems unwanted, this should be a decent step in the right direction even if we want to do that eventually, though. Closes gh-20790 Closes gh-18547 (although not quite sure why)
-rw-r--r--numpy/core/_asarray.py10
-rw-r--r--numpy/core/numeric.py37
-rw-r--r--numpy/core/overrides.py94
-rw-r--r--numpy/core/src/multiarray/arrayfunction_override.c552
-rw-r--r--numpy/core/src/multiarray/arrayfunction_override.h4
-rw-r--r--numpy/core/src/multiarray/multiarraymodule.c9
-rw-r--r--numpy/lib/npyio.py33
-rw-r--r--numpy/lib/twodim_base.py20
8 files changed, 413 insertions, 346 deletions
diff --git a/numpy/core/_asarray.py b/numpy/core/_asarray.py
index cbaab8c3f..a9abc5a88 100644
--- a/numpy/core/_asarray.py
+++ b/numpy/core/_asarray.py
@@ -24,10 +24,6 @@ POSSIBLE_FLAGS = {
}
-def _require_dispatcher(a, dtype=None, requirements=None, *, like=None):
- return (like,)
-
-
@set_array_function_like_doc
@set_module('numpy')
def require(a, dtype=None, requirements=None, *, like=None):
@@ -100,10 +96,10 @@ def require(a, dtype=None, requirements=None, *, like=None):
"""
if like is not None:
return _require_with_like(
+ like,
a,
dtype=dtype,
requirements=requirements,
- like=like,
)
if not requirements:
@@ -135,6 +131,4 @@ def require(a, dtype=None, requirements=None, *, like=None):
return arr
-_require_with_like = array_function_dispatch(
- _require_dispatcher
-)(require)
+_require_with_like = array_function_dispatch()(require)
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index 577f8e7cd..91e35c684 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -130,10 +130,6 @@ def zeros_like(a, dtype=None, order='K', subok=True, shape=None):
return res
-def _ones_dispatcher(shape, dtype=None, order=None, *, like=None):
- return(like,)
-
-
@set_array_function_like_doc
@set_module('numpy')
def ones(shape, dtype=None, order='C', *, like=None):
@@ -187,16 +183,13 @@ def ones(shape, dtype=None, order='C', *, like=None):
"""
if like is not None:
- return _ones_with_like(shape, dtype=dtype, order=order, like=like)
+ return _ones_with_like(like, shape, dtype=dtype, order=order)
a = empty(shape, dtype, order)
multiarray.copyto(a, 1, casting='unsafe')
return a
-
-_ones_with_like = array_function_dispatch(
- _ones_dispatcher
-)(ones)
+_ones_with_like = array_function_dispatch()(ones)
def _ones_like_dispatcher(a, dtype=None, order=None, subok=None, shape=None):
@@ -323,7 +316,7 @@ def full(shape, fill_value, dtype=None, order='C', *, like=None):
"""
if like is not None:
- return _full_with_like(shape, fill_value, dtype=dtype, order=order, like=like)
+ return _full_with_like(like, shape, fill_value, dtype=dtype, order=order)
if dtype is None:
fill_value = asarray(fill_value)
@@ -333,9 +326,7 @@ def full(shape, fill_value, dtype=None, order='C', *, like=None):
return a
-_full_with_like = array_function_dispatch(
- _full_dispatcher
-)(full)
+_full_with_like = array_function_dispatch()(full)
def _full_like_dispatcher(a, fill_value, dtype=None, order=None, subok=None, shape=None):
@@ -1778,10 +1769,6 @@ def indices(dimensions, dtype=int, sparse=False):
return res
-def _fromfunction_dispatcher(function, shape, *, dtype=None, like=None, **kwargs):
- return (like,)
-
-
@set_array_function_like_doc
@set_module('numpy')
def fromfunction(function, shape, *, dtype=float, like=None, **kwargs):
@@ -1847,15 +1834,13 @@ def fromfunction(function, shape, *, dtype=float, like=None, **kwargs):
"""
if like is not None:
- return _fromfunction_with_like(function, shape, dtype=dtype, like=like, **kwargs)
+ return _fromfunction_with_like(like, function, shape, dtype=dtype, **kwargs)
args = indices(shape, dtype=dtype)
return function(*args, **kwargs)
-_fromfunction_with_like = array_function_dispatch(
- _fromfunction_dispatcher
-)(fromfunction)
+_fromfunction_with_like = array_function_dispatch()(fromfunction)
def _frombuffer(buf, dtype, shape, order):
@@ -2130,10 +2115,6 @@ def _maketup(descr, val):
return tuple(res)
-def _identity_dispatcher(n, dtype=None, *, like=None):
- return (like,)
-
-
@set_array_function_like_doc
@set_module('numpy')
def identity(n, dtype=None, *, like=None):
@@ -2168,15 +2149,13 @@ def identity(n, dtype=None, *, like=None):
"""
if like is not None:
- return _identity_with_like(n, dtype=dtype, like=like)
+ return _identity_with_like(like, n, dtype=dtype)
from numpy import eye
return eye(n, dtype=dtype, like=like)
-_identity_with_like = array_function_dispatch(
- _identity_dispatcher
-)(identity)
+_identity_with_like = array_function_dispatch()(identity)
def _allclose_dispatcher(a, b, rtol=None, atol=None, equal_nan=None):
diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py
index 46e1fbe2c..25892d5de 100644
--- a/numpy/core/overrides.py
+++ b/numpy/core/overrides.py
@@ -6,7 +6,7 @@ import os
from .._utils import set_module
from .._utils._inspect import getargspec
from numpy.core._multiarray_umath import (
- add_docstring, implement_array_function, _get_implementing_args)
+ add_docstring, _get_implementing_args, _ArrayFunctionDispatcher)
ARRAY_FUNCTIONS = set()
@@ -33,40 +33,33 @@ def set_array_function_like_doc(public_api):
add_docstring(
- implement_array_function,
+ _ArrayFunctionDispatcher,
"""
- Implement a function with checks for __array_function__ overrides.
+ Class to wrap functions with checks for __array_function__ overrides.
All arguments are required, and can only be passed by position.
Parameters
----------
+ dispatcher : function or None
+ The dispatcher function that returns a single sequence-like object
+ of all arguments relevant. It must have the same signature (except
+ the default values) as the actual implementation.
+ If ``None``, this is a ``like=`` dispatcher and the
+ ``_ArrayFunctionDispatcher`` must be called with ``like`` as the
+ first (additional and positional) argument.
implementation : function
Function that implements the operation on NumPy array without
- overrides when called like ``implementation(*args, **kwargs)``.
- public_api : function
- Function exposed by NumPy's public API originally called like
- ``public_api(*args, **kwargs)`` on which arguments are now being
- checked.
- relevant_args : iterable
- Iterable of arguments to check for __array_function__ methods.
- args : tuple
- Arbitrary positional arguments originally passed into ``public_api``.
- kwargs : dict
- Arbitrary keyword arguments originally passed into ``public_api``.
+ overrides when called like.
- Returns
- -------
- Result from calling ``implementation()`` or an ``__array_function__``
- method, as appropriate.
-
- Raises
- ------
- TypeError : if no implementation is found.
+ Attributes
+ ----------
+ _implementation : function
+ The original implementation passed in.
""")
-# exposed for testing purposes; used internally by implement_array_function
+# exposed for testing purposes; used internally by _ArrayFunctionDispatcher
add_docstring(
_get_implementing_args,
"""
@@ -110,7 +103,7 @@ def verify_matching_signatures(implementation, dispatcher):
'default argument values')
-def array_function_dispatch(dispatcher, module=None, verify=True,
+def array_function_dispatch(dispatcher=None, module=None, verify=True,
docs_from_dispatcher=False):
"""Decorator for adding dispatch with the __array_function__ protocol.
@@ -118,10 +111,14 @@ def array_function_dispatch(dispatcher, module=None, verify=True,
Parameters
----------
- dispatcher : callable
+ dispatcher : callable or None
Function that when called like ``dispatcher(*args, **kwargs)`` with
arguments from the NumPy function call returns an iterable of
array-like arguments to check for ``__array_function__``.
+
+ If `None`, the first argument is used as the single `like=` argument
+ and not passed on. A function implementing `like=` must call its
+ dispatcher with `like` as the first non-keyword argument.
module : str, optional
__module__ attribute to set on new function, e.g., ``module='numpy'``.
By default, module is copied from the decorated function.
@@ -154,45 +151,28 @@ def array_function_dispatch(dispatcher, module=None, verify=True,
def decorator(implementation):
if verify:
- verify_matching_signatures(implementation, dispatcher)
+ if dispatcher is not None:
+ verify_matching_signatures(implementation, dispatcher)
+ else:
+ # Using __code__ directly similar to verify_matching_signature
+ co = implementation.__code__
+ last_arg = co.co_argcount + co.co_kwonlyargcount - 1
+ last_arg = co.co_varnames[last_arg]
+ if last_arg != "like" or co.co_kwonlyargcount == 0:
+ raise RuntimeError(
+ "__array_function__ expects `like=` to be the last "
+ "argument and a keyword-only argument. "
+ f"{implementation} does not seem to comply.")
if docs_from_dispatcher:
add_docstring(implementation, dispatcher.__doc__)
- @functools.wraps(implementation)
- def public_api(*args, **kwargs):
- try:
- relevant_args = dispatcher(*args, **kwargs)
- except TypeError as exc:
- # Try to clean up a signature related TypeError. Such an
- # error will be something like:
- # dispatcher.__name__() got an unexpected keyword argument
- #
- # So replace the dispatcher name in this case. In principle
- # TypeErrors may be raised from _within_ the dispatcher, so
- # we check that the traceback contains a string that starts
- # with the name. (In principle we could also check the
- # traceback length, as it would be deeper.)
- msg = exc.args[0]
- disp_name = dispatcher.__name__
- if not isinstance(msg, str) or not msg.startswith(disp_name):
- raise
-
- # Replace with the correct name and re-raise:
- new_msg = msg.replace(disp_name, public_api.__name__)
- raise TypeError(new_msg) from None
-
- return implement_array_function(
- implementation, public_api, relevant_args, args, kwargs)
-
- public_api.__code__ = public_api.__code__.replace(
- co_name=implementation.__name__,
- co_filename='<__array_function__ internals>')
+ public_api = _ArrayFunctionDispatcher(dispatcher, implementation)
+ public_api = functools.wraps(implementation)(public_api)
+
if module is not None:
public_api.__module__ = module
- public_api._implementation = implementation
-
ARRAY_FUNCTIONS.add(public_api)
return public_api
diff --git a/numpy/core/src/multiarray/arrayfunction_override.c b/numpy/core/src/multiarray/arrayfunction_override.c
index 2bb3fbe28..e27bb516e 100644
--- a/numpy/core/src/multiarray/arrayfunction_override.c
+++ b/numpy/core/src/multiarray/arrayfunction_override.c
@@ -1,11 +1,15 @@
#define NPY_NO_DEPRECATED_API NPY_API_VERSION
#define _MULTIARRAYMODULE
+#include <Python.h>
+#include "structmember.h"
+
#include "npy_pycompat.h"
#include "get_attr_string.h"
#include "npy_import.h"
#include "multiarraymodule.h"
+#include "arrayfunction_override.h"
/* Return the ndarray.__array_function__ method. */
static PyObject *
@@ -200,183 +204,67 @@ call_array_function(PyObject* argument, PyObject* method,
}
-/**
- * Internal handler for the array-function dispatching. The helper returns
- * either the result, or NotImplemented (as a borrowed reference).
- *
- * @param public_api The public API symbol used for dispatching
- * @param relevant_args Arguments which may implement __array_function__
- * @param args Original arguments
- * @param kwargs Original keyword arguments
- *
- * @returns The result of the dispatched version, or a borrowed reference
- * to NotImplemented to indicate the default implementation should
- * be used.
+
+/*
+ * Helper to convert from vectorcall convention, since the protocol requires
+ * args and kwargs to be passed as tuple and dict explicitly.
+ * We always pass a dict, so always returns it.
*/
-static PyObject *
-array_implement_array_function_internal(
- PyObject *public_api, PyObject *relevant_args,
- PyObject *args, PyObject *kwargs)
+static int
+get_args_and_kwargs(
+ PyObject *const *fast_args, Py_ssize_t len_args, PyObject *kwnames,
+ PyObject **out_args, PyObject **out_kwargs)
{
- PyObject *implementing_args[NPY_MAXARGS];
- PyObject *array_function_methods[NPY_MAXARGS];
- PyObject *types = NULL;
-
- PyObject *result = NULL;
-
- static PyObject *errmsg_formatter = NULL;
+ len_args = PyVectorcall_NARGS(len_args);
+ PyObject *args = PyTuple_New(len_args);
+ PyObject *kwargs = NULL;
- relevant_args = PySequence_Fast(
- relevant_args,
- "dispatcher for __array_function__ did not return an iterable");
- if (relevant_args == NULL) {
- return NULL;
+ if (args == NULL) {
+ return -1;
}
-
- /* Collect __array_function__ implementations */
- int num_implementing_args = get_implementing_args_and_methods(
- relevant_args, implementing_args, array_function_methods);
- if (num_implementing_args == -1) {
- goto cleanup;
+ for (Py_ssize_t i = 0; i < len_args; i++) {
+ Py_INCREF(fast_args[i]);
+ PyTuple_SET_ITEM(args, i, fast_args[i]);
}
-
- /*
- * Handle the typical case of no overrides. This is merely an optimization
- * if some arguments are ndarray objects, but is also necessary if no
- * arguments implement __array_function__ at all (e.g., if they are all
- * built-in types).
- */
- int any_overrides = 0;
- for (int j = 0; j < num_implementing_args; j++) {
- if (!is_default_array_function(array_function_methods[j])) {
- any_overrides = 1;
- break;
- }
- }
- if (!any_overrides) {
- /*
- * When the default implementation should be called, return
- * `Py_NotImplemented` to indicate this.
- */
- result = Py_NotImplemented;
- goto cleanup;
- }
-
- /*
- * Create a Python object for types.
- * We use a tuple, because it's the fastest Python collection to create
- * and has the bonus of being immutable.
- */
- types = PyTuple_New(num_implementing_args);
- if (types == NULL) {
- goto cleanup;
- }
- for (int j = 0; j < num_implementing_args; j++) {
- PyObject *arg_type = (PyObject *)Py_TYPE(implementing_args[j]);
- Py_INCREF(arg_type);
- PyTuple_SET_ITEM(types, j, arg_type);
- }
-
- /* Call __array_function__ methods */
- for (int j = 0; j < num_implementing_args; j++) {
- PyObject *argument = implementing_args[j];
- PyObject *method = array_function_methods[j];
-
- /*
- * We use `public_api` instead of `implementation` here so
- * __array_function__ implementations can do equality/identity
- * comparisons.
- */
- result = call_array_function(
- argument, method, public_api, types, args, kwargs);
-
- if (result == Py_NotImplemented) {
- /* Try the next one */
- Py_DECREF(result);
- result = NULL;
+ kwargs = PyDict_New();
+ if (kwnames != NULL) {
+ if (kwargs == NULL) {
+ Py_DECREF(args);
+ return -1;
}
- else {
- /* Either a good result, or an exception was raised. */
- goto cleanup;
+ Py_ssize_t nkwargs = PyTuple_GET_SIZE(kwnames);
+ for (Py_ssize_t i = 0; i < nkwargs; i++) {
+ PyObject *key = PyTuple_GET_ITEM(kwnames, i);
+ PyObject *value = fast_args[i+len_args];
+ if (PyDict_SetItem(kwargs, key, value) < 0) {
+ Py_DECREF(args);
+ Py_DECREF(kwargs);
+ return -1;
+ }
}
}
+ *out_args = args;
+ *out_kwargs = kwargs;
+ return 0;
+}
+
+static void
+set_no_matching_types_error(PyObject *public_api, PyObject *types)
+{
+ static PyObject *errmsg_formatter = NULL;
/* No acceptable override found, raise TypeError. */
npy_cache_import("numpy.core._internal",
"array_function_errmsg_formatter",
&errmsg_formatter);
if (errmsg_formatter != NULL) {
PyObject *errmsg = PyObject_CallFunctionObjArgs(
- errmsg_formatter, public_api, types, NULL);
+ errmsg_formatter, public_api, types, NULL);
if (errmsg != NULL) {
PyErr_SetObject(PyExc_TypeError, errmsg);
Py_DECREF(errmsg);
}
}
-
-cleanup:
- for (int j = 0; j < num_implementing_args; j++) {
- Py_DECREF(implementing_args[j]);
- Py_DECREF(array_function_methods[j]);
- }
- Py_XDECREF(types);
- Py_DECREF(relevant_args);
- return result;
-}
-
-
-/*
- * Implements the __array_function__ protocol for a Python function, as described in
- * in NEP-18. See numpy.core.overrides for a full docstring.
- */
-NPY_NO_EXPORT PyObject *
-array_implement_array_function(
- PyObject *NPY_UNUSED(dummy), PyObject *positional_args)
-{
- PyObject *res, *implementation, *public_api, *relevant_args, *args, *kwargs;
-
- if (!PyArg_UnpackTuple(
- positional_args, "implement_array_function", 5, 5,
- &implementation, &public_api, &relevant_args, &args, &kwargs)) {
- return NULL;
- }
-
- /*
- * Remove `like=` kwarg, which is NumPy-exclusive and thus not present
- * in downstream libraries. If `like=` is specified but doesn't
- * implement `__array_function__`, raise a `TypeError`.
- */
- if (kwargs != NULL && PyDict_Contains(kwargs, npy_ma_str_like)) {
- PyObject *like_arg = PyDict_GetItem(kwargs, npy_ma_str_like);
- if (like_arg != NULL) {
- PyObject *tmp_has_override = get_array_function(like_arg);
- if (tmp_has_override == NULL) {
- return PyErr_Format(PyExc_TypeError,
- "The `like` argument must be an array-like that "
- "implements the `__array_function__` protocol.");
- }
- Py_DECREF(tmp_has_override);
- PyDict_DelItem(kwargs, npy_ma_str_like);
-
- /*
- * If `like=` kwarg was removed, `implementation` points to the NumPy
- * public API, as `public_api` is in that case the wrapper dispatcher
- * function. For example, in the `np.full` case, `implementation` is
- * `np.full`, whereas `public_api` is `_full_with_like`. This is done
- * to ensure `__array_function__` implementations can do
- * equality/identity comparisons when `like=` is present.
- */
- public_api = implementation;
- }
- }
-
- res = array_implement_array_function_internal(
- public_api, relevant_args, args, kwargs);
-
- if (res == Py_NotImplemented) {
- return PyObject_Call(implementation, args, kwargs);
- }
- return res;
}
/*
@@ -392,64 +280,48 @@ array_implement_c_array_function_creation(
PyObject *args, PyObject *kwargs,
PyObject *const *fast_args, Py_ssize_t len_args, PyObject *kwnames)
{
- PyObject *relevant_args = NULL;
+ PyObject *dispatch_types = NULL;
PyObject *numpy_module = NULL;
PyObject *public_api = NULL;
PyObject *result = NULL;
/* If `like` doesn't implement `__array_function__`, raise a `TypeError` */
- PyObject *tmp_has_override = get_array_function(like);
- if (tmp_has_override == NULL) {
+ PyObject *method = get_array_function(like);
+ if (method == NULL) {
return PyErr_Format(PyExc_TypeError,
"The `like` argument must be an array-like that "
"implements the `__array_function__` protocol.");
}
- Py_DECREF(tmp_has_override);
-
- if (fast_args != NULL) {
+ if (is_default_array_function(method)) {
/*
- * Convert from vectorcall convention, since the protocol requires
- * the normal convention. We have to do this late to ensure the
- * normal path where NotImplemented is returned is fast.
+ * Return a borrowed reference of Py_NotImplemented to defer back to
+ * the original function.
*/
+ Py_DECREF(method);
+ return Py_NotImplemented;
+ }
+
+ dispatch_types = PyTuple_Pack(1, Py_TYPE(like));
+ if (dispatch_types == NULL) {
+ goto finish;
+ }
+
+ /* We have to call __array_function__ properly, which needs some prep */
+ if (fast_args != NULL) {
assert(args == NULL);
assert(kwargs == NULL);
- args = PyTuple_New(len_args);
- if (args == NULL) {
- return NULL;
- }
- for (Py_ssize_t i = 0; i < len_args; i++) {
- Py_INCREF(fast_args[i]);
- PyTuple_SET_ITEM(args, i, fast_args[i]);
- }
- if (kwnames != NULL) {
- kwargs = PyDict_New();
- if (kwargs == NULL) {
- Py_DECREF(args);
- return NULL;
- }
- Py_ssize_t nkwargs = PyTuple_GET_SIZE(kwnames);
- for (Py_ssize_t i = 0; i < nkwargs; i++) {
- PyObject *key = PyTuple_GET_ITEM(kwnames, i);
- PyObject *value = fast_args[i+len_args];
- if (PyDict_SetItem(kwargs, key, value) < 0) {
- Py_DECREF(args);
- Py_DECREF(kwargs);
- return NULL;
- }
- }
+ if (get_args_and_kwargs(
+ fast_args, len_args, kwnames, &args, &kwargs) < 0) {
+ goto finish;
}
}
- relevant_args = PyTuple_Pack(1, like);
- if (relevant_args == NULL) {
- goto finish;
- }
/* The like argument must be present in the keyword arguments, remove it */
if (PyDict_DelItem(kwargs, npy_ma_str_like) < 0) {
goto finish;
}
+ /* Fetch the actual symbol (the long way right now) */
numpy_module = PyImport_Import(npy_ma_str_numpy);
if (numpy_module == NULL) {
goto finish;
@@ -466,16 +338,20 @@ array_implement_c_array_function_creation(
goto finish;
}
- result = array_implement_array_function_internal(
- public_api, relevant_args, args, kwargs);
+ result = call_array_function(like, method,
+ public_api, dispatch_types, args, kwargs);
- finish:
- if (kwnames != NULL) {
- /* args and kwargs were converted from vectorcall convention */
- Py_XDECREF(args);
- Py_XDECREF(kwargs);
+ if (result == Py_NotImplemented) {
+ Py_DECREF(result);
+ result = NULL;
+ set_no_matching_types_error(public_api, dispatch_types);
}
- Py_XDECREF(relevant_args);
+
+ finish:
+ Py_DECREF(method);
+ Py_XDECREF(args);
+ Py_XDECREF(kwargs);
+ Py_XDECREF(dispatch_types);
Py_XDECREF(public_api);
return result;
}
@@ -530,3 +406,275 @@ cleanup:
Py_DECREF(relevant_args);
return result;
}
+
+
+typedef struct {
+ PyObject_HEAD
+ vectorcallfunc vectorcall;
+ PyObject *dict;
+ PyObject *relevant_arg_func;
+ PyObject *default_impl;
+} PyArray_ArrayFunctionDispatcherObject;
+
+
+static void
+dispatcher_dealloc(PyArray_ArrayFunctionDispatcherObject *self)
+{
+ Py_CLEAR(self->relevant_arg_func);
+ Py_CLEAR(self->default_impl);
+ Py_CLEAR(self->dict);
+ PyObject_FREE(self);
+}
+
+
+static PyObject *
+dispatcher_vectorcall(PyArray_ArrayFunctionDispatcherObject *self,
+ PyObject *const *args, Py_ssize_t len_args, PyObject *kwnames)
+{
+ PyObject *result = NULL;
+ PyObject *types = NULL;
+ PyObject *relevant_args = NULL;
+
+ PyObject *public_api;
+
+ /* __array_function__ passes args, kwargs. These may be filled: */
+ PyObject *packed_args = NULL;
+ PyObject *packed_kwargs = NULL;
+
+ PyObject *implementing_args[NPY_MAXARGS];
+ PyObject *array_function_methods[NPY_MAXARGS];
+
+ int num_implementing_args;
+
+ if (self->relevant_arg_func != NULL) {
+ public_api = (PyObject *)self;
+
+ /* Typical path, need to call the relevant_arg_func and unpack them */
+ relevant_args = PyObject_Vectorcall(
+ self->relevant_arg_func, args, len_args, kwnames);
+ if (relevant_args == NULL) {
+ return NULL;
+ }
+ Py_SETREF(relevant_args, PySequence_Fast(relevant_args,
+ "dispatcher for __array_function__ did not return an iterable"));
+ if (relevant_args == NULL) {
+ return NULL;
+ }
+
+ num_implementing_args = get_implementing_args_and_methods(
+ relevant_args, implementing_args, array_function_methods);
+ if (num_implementing_args < 0) {
+ Py_DECREF(relevant_args);
+ return NULL;
+ }
+ }
+ else {
+ /* For like= dispatching from Python, the public_symbol is the impl */
+ public_api = self->default_impl;
+
+ /*
+ * We are dealing with `like=` from Python. For simplicity, the
+ * Python code passes it on as the first argument.
+ */
+ if (PyVectorcall_NARGS(len_args) == 0) {
+ PyErr_Format(PyExc_TypeError,
+ "`like` argument dispatching, but first argument is not "
+ "positional in call to %S.", self->default_impl);
+ return NULL;
+ }
+
+ array_function_methods[0] = get_array_function(args[0]);
+ if (array_function_methods[0] == NULL) {
+ return PyErr_Format(PyExc_TypeError,
+ "The `like` argument must be an array-like that "
+ "implements the `__array_function__` protocol.");
+ }
+ num_implementing_args = 1;
+ implementing_args[0] = args[0];
+ Py_INCREF(implementing_args[0]);
+
+ /* do not pass the like argument */
+ len_args = PyVectorcall_NARGS(len_args) - 1;
+ len_args |= PY_VECTORCALL_ARGUMENTS_OFFSET;
+ args++;
+ }
+
+ /*
+ * Handle the typical case of no overrides. This is merely an optimization
+ * if some arguments are ndarray objects, but is also necessary if no
+ * arguments implement __array_function__ at all (e.g., if they are all
+ * built-in types).
+ */
+ int any_overrides = 0;
+ for (int j = 0; j < num_implementing_args; j++) {
+ if (!is_default_array_function(array_function_methods[j])) {
+ any_overrides = 1;
+ break;
+ }
+ }
+ if (!any_overrides) {
+ /* Directly call the actual implementation. */
+ result = PyObject_Vectorcall(self->default_impl, args, len_args, kwnames);
+ goto cleanup;
+ }
+
+ /* Find args and kwargs as tuple and dict, as we pass them out: */
+ if (get_args_and_kwargs(
+ args, len_args, kwnames, &packed_args, &packed_kwargs) < 0) {
+ goto cleanup;
+ }
+
+ /*
+ * Create a Python object for types.
+ * We use a tuple, because it's the fastest Python collection to create
+ * and has the bonus of being immutable.
+ */
+ types = PyTuple_New(num_implementing_args);
+ if (types == NULL) {
+ goto cleanup;
+ }
+ for (int j = 0; j < num_implementing_args; j++) {
+ PyObject *arg_type = (PyObject *)Py_TYPE(implementing_args[j]);
+ Py_INCREF(arg_type);
+ PyTuple_SET_ITEM(types, j, arg_type);
+ }
+
+ /* Call __array_function__ methods */
+ for (int j = 0; j < num_implementing_args; j++) {
+ PyObject *argument = implementing_args[j];
+ PyObject *method = array_function_methods[j];
+
+ result = call_array_function(
+ argument, method, public_api, types,
+ packed_args, packed_kwargs);
+
+ if (result == Py_NotImplemented) {
+ /* Try the next one */
+ Py_DECREF(result);
+ result = NULL;
+ }
+ else {
+ /* Either a good result, or an exception was raised. */
+ goto cleanup;
+ }
+ }
+
+ set_no_matching_types_error(public_api, types);
+
+cleanup:
+ for (int j = 0; j < num_implementing_args; j++) {
+ Py_DECREF(implementing_args[j]);
+ Py_DECREF(array_function_methods[j]);
+ }
+ Py_XDECREF(packed_args);
+ Py_XDECREF(packed_kwargs);
+ Py_XDECREF(types);
+ Py_XDECREF(relevant_args);
+ return result;
+}
+
+
+static PyObject *
+dispatcher_new(PyTypeObject *NPY_UNUSED(cls), PyObject *args, PyObject *kwargs)
+{
+ PyArray_ArrayFunctionDispatcherObject *self;
+
+ self = PyObject_New(
+ PyArray_ArrayFunctionDispatcherObject,
+ &PyArrayFunctionDispatcher_Type);
+ if (self == NULL) {
+ return PyErr_NoMemory();
+ }
+
+ char *kwlist[] = {"", "", NULL};
+ if (!PyArg_ParseTupleAndKeywords(
+ args, kwargs, "OO:_ArrayFunctionDispatcher", kwlist,
+ &self->relevant_arg_func, &self->default_impl)) {
+ Py_DECREF(self);
+ return NULL;
+ }
+
+ self->vectorcall = (vectorcallfunc)dispatcher_vectorcall;
+ if (self->relevant_arg_func == Py_None) {
+ /* NULL in the relevant arg function means we use `like=` */
+ Py_CLEAR(self->relevant_arg_func);
+ }
+ else {
+ Py_INCREF(self->relevant_arg_func);
+ }
+ Py_INCREF(self->default_impl);
+
+ /* Need to be like a Python function that has arbitrary attributes */
+ self->dict = PyDict_New();
+ if (self->dict == NULL) {
+ Py_DECREF(self);
+ return NULL;
+ }
+ return (PyObject *)self;
+}
+
+
+static PyObject *
+dispatcher_str(PyArray_ArrayFunctionDispatcherObject *self)
+{
+ return PyObject_Str(self->default_impl);
+}
+
+
+static PyObject *
+dispatcher_repr(PyObject *self)
+{
+ PyObject *name = PyObject_GetAttrString(self, "__name__");
+ if (name == NULL) {
+ return NULL;
+ }
+ /* Print like a normal function */
+ return PyUnicode_FromFormat("<function %S at %p>", name, self);
+}
+
+static PyObject *
+dispatcher_get_implementation(
+ PyArray_ArrayFunctionDispatcherObject *self, void *NPY_UNUSED(closure))
+{
+ Py_INCREF(self->default_impl);
+ return self->default_impl;
+}
+
+
+static PyObject *
+dispatcher_reduce(PyObject *self, PyObject *NPY_UNUSED(args))
+{
+ return PyObject_GetAttrString(self, "__qualname__");
+}
+
+
+static struct PyMethodDef func_dispatcher_methods[] = {
+ {"__reduce__",
+ (PyCFunction)dispatcher_reduce, METH_NOARGS, NULL},
+ {NULL, NULL, 0, NULL}
+};
+
+
+static struct PyGetSetDef func_dispatcher_getset[] = {
+ {"__dict__", &PyObject_GenericGetDict, 0, NULL, 0},
+ {"_implementation", (getter)&dispatcher_get_implementation, 0, NULL, 0},
+ {0, 0, 0, 0, 0}
+};
+
+
+NPY_NO_EXPORT PyTypeObject PyArrayFunctionDispatcher_Type = {
+ PyVarObject_HEAD_INIT(NULL, 0)
+ .tp_name = "numpy._ArrayFunctionDispatcher",
+ .tp_basicsize = sizeof(PyArray_ArrayFunctionDispatcherObject),
+ /* We have a dict, so in theory could traverse, but in practice... */
+ .tp_dictoffset = offsetof(PyArray_ArrayFunctionDispatcherObject, dict),
+ .tp_dealloc = (destructor)dispatcher_dealloc,
+ .tp_new = (newfunc)dispatcher_new,
+ .tp_str = (reprfunc)dispatcher_str,
+ .tp_repr = (reprfunc)dispatcher_repr,
+ .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_VECTORCALL,
+ .tp_methods = func_dispatcher_methods,
+ .tp_getset = func_dispatcher_getset,
+ .tp_call = &PyVectorcall_Call,
+ .tp_vectorcall_offset = offsetof(PyArray_ArrayFunctionDispatcherObject, vectorcall),
+};
diff --git a/numpy/core/src/multiarray/arrayfunction_override.h b/numpy/core/src/multiarray/arrayfunction_override.h
index 09f7ee548..3b8b88bac 100644
--- a/numpy/core/src/multiarray/arrayfunction_override.h
+++ b/numpy/core/src/multiarray/arrayfunction_override.h
@@ -1,9 +1,7 @@
#ifndef NUMPY_CORE_SRC_MULTIARRAY_ARRAYFUNCTION_OVERRIDE_H_
#define NUMPY_CORE_SRC_MULTIARRAY_ARRAYFUNCTION_OVERRIDE_H_
-NPY_NO_EXPORT PyObject *
-array_implement_array_function(
- PyObject *NPY_UNUSED(dummy), PyObject *positional_args);
+extern NPY_NO_EXPORT PyTypeObject PyArrayFunctionDispatcher_Type;
NPY_NO_EXPORT PyObject *
array__get_implementing_args(
diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c
index 94fa2a909..6b1862b18 100644
--- a/numpy/core/src/multiarray/multiarraymodule.c
+++ b/numpy/core/src/multiarray/multiarraymodule.c
@@ -4539,9 +4539,6 @@ static struct PyMethodDef array_module_methods[] = {
METH_VARARGS | METH_KEYWORDS, NULL},
{"_monotonicity", (PyCFunction)arr__monotonicity,
METH_VARARGS | METH_KEYWORDS, NULL},
- {"implement_array_function",
- (PyCFunction)array_implement_array_function,
- METH_VARARGS, NULL},
{"interp", (PyCFunction)arr_interp,
METH_VARARGS | METH_KEYWORDS, NULL},
{"interp_complex", (PyCFunction)arr_interp_complex,
@@ -5112,6 +5109,12 @@ PyMODINIT_FUNC PyInit__multiarray_umath(void) {
if (set_typeinfo(d) != 0) {
goto err;
}
+ if (PyType_Ready(&PyArrayFunctionDispatcher_Type) < 0) {
+ goto err;
+ }
+ PyDict_SetItemString(
+ d, "_ArrayFunctionDispatcher",
+ (PyObject *)&PyArrayFunctionDispatcher_Type);
if (PyType_Ready(&PyArrayMethod_Type) < 0) {
goto err;
}
diff --git a/numpy/lib/npyio.py b/numpy/lib/npyio.py
index 71d600c30..0c1740df1 100644
--- a/numpy/lib/npyio.py
+++ b/numpy/lib/npyio.py
@@ -760,13 +760,6 @@ def _ensure_ndmin_ndarray(a, *, ndmin: int):
_loadtxt_chunksize = 50000
-def _loadtxt_dispatcher(
- fname, dtype=None, comments=None, delimiter=None,
- converters=None, skiprows=None, usecols=None, unpack=None,
- ndmin=None, encoding=None, max_rows=None, *, like=None):
- return (like,)
-
-
def _check_nonneg_int(value, name="argument"):
try:
operator.index(value)
@@ -1331,10 +1324,10 @@ def loadtxt(fname, dtype=float, comments='#', delimiter=None,
if like is not None:
return _loadtxt_with_like(
- fname, dtype=dtype, comments=comments, delimiter=delimiter,
+ like, fname, dtype=dtype, comments=comments, delimiter=delimiter,
converters=converters, skiprows=skiprows, usecols=usecols,
unpack=unpack, ndmin=ndmin, encoding=encoding,
- max_rows=max_rows, like=like
+ max_rows=max_rows
)
if isinstance(delimiter, bytes):
@@ -1361,9 +1354,7 @@ def loadtxt(fname, dtype=float, comments='#', delimiter=None,
return arr
-_loadtxt_with_like = array_function_dispatch(
- _loadtxt_dispatcher
-)(loadtxt)
+_loadtxt_with_like = array_function_dispatch()(loadtxt)
def _savetxt_dispatcher(fname, X, fmt=None, delimiter=None, newline=None,
@@ -1724,17 +1715,6 @@ def fromregex(file, regexp, dtype, encoding=None):
#####--------------------------------------------------------------------------
-def _genfromtxt_dispatcher(fname, dtype=None, comments=None, delimiter=None,
- skip_header=None, skip_footer=None, converters=None,
- missing_values=None, filling_values=None, usecols=None,
- names=None, excludelist=None, deletechars=None,
- replace_space=None, autostrip=None, case_sensitive=None,
- defaultfmt=None, unpack=None, usemask=None, loose=None,
- invalid_raise=None, max_rows=None, encoding=None,
- *, ndmin=None, like=None):
- return (like,)
-
-
@set_array_function_like_doc
@set_module('numpy')
def genfromtxt(fname, dtype=float, comments='#', delimiter=None,
@@ -1932,7 +1912,7 @@ def genfromtxt(fname, dtype=float, comments='#', delimiter=None,
if like is not None:
return _genfromtxt_with_like(
- fname, dtype=dtype, comments=comments, delimiter=delimiter,
+ like, fname, dtype=dtype, comments=comments, delimiter=delimiter,
skip_header=skip_header, skip_footer=skip_footer,
converters=converters, missing_values=missing_values,
filling_values=filling_values, usecols=usecols, names=names,
@@ -1942,7 +1922,6 @@ def genfromtxt(fname, dtype=float, comments='#', delimiter=None,
unpack=unpack, usemask=usemask, loose=loose,
invalid_raise=invalid_raise, max_rows=max_rows, encoding=encoding,
ndmin=ndmin,
- like=like
)
_ensure_ndmin_ndarray_check_param(ndmin)
@@ -2471,9 +2450,7 @@ def genfromtxt(fname, dtype=float, comments='#', delimiter=None,
return output
-_genfromtxt_with_like = array_function_dispatch(
- _genfromtxt_dispatcher
-)(genfromtxt)
+_genfromtxt_with_like = array_function_dispatch()(genfromtxt)
def recfromtxt(fname, **kwargs):
diff --git a/numpy/lib/twodim_base.py b/numpy/lib/twodim_base.py
index dcb4ed46c..ed4f98704 100644
--- a/numpy/lib/twodim_base.py
+++ b/numpy/lib/twodim_base.py
@@ -155,10 +155,6 @@ def flipud(m):
return m[::-1, ...]
-def _eye_dispatcher(N, M=None, k=None, dtype=None, order=None, *, like=None):
- return (like,)
-
-
@set_array_function_like_doc
@set_module('numpy')
def eye(N, M=None, k=0, dtype=float, order='C', *, like=None):
@@ -209,7 +205,7 @@ def eye(N, M=None, k=0, dtype=float, order='C', *, like=None):
"""
if like is not None:
- return _eye_with_like(N, M=M, k=k, dtype=dtype, order=order, like=like)
+ return _eye_with_like(like, N, M=M, k=k, dtype=dtype, order=order)
if M is None:
M = N
m = zeros((N, M), dtype=dtype, order=order)
@@ -228,9 +224,7 @@ def eye(N, M=None, k=0, dtype=float, order='C', *, like=None):
return m
-_eye_with_like = array_function_dispatch(
- _eye_dispatcher
-)(eye)
+_eye_with_like = array_function_dispatch()(eye)
def _diag_dispatcher(v, k=None):
@@ -369,10 +363,6 @@ def diagflat(v, k=0):
return wrap(res)
-def _tri_dispatcher(N, M=None, k=None, dtype=None, *, like=None):
- return (like,)
-
-
@set_array_function_like_doc
@set_module('numpy')
def tri(N, M=None, k=0, dtype=float, *, like=None):
@@ -416,7 +406,7 @@ def tri(N, M=None, k=0, dtype=float, *, like=None):
"""
if like is not None:
- return _tri_with_like(N, M=M, k=k, dtype=dtype, like=like)
+ return _tri_with_like(like, N, M=M, k=k, dtype=dtype)
if M is None:
M = N
@@ -430,9 +420,7 @@ def tri(N, M=None, k=0, dtype=float, *, like=None):
return m
-_tri_with_like = array_function_dispatch(
- _tri_dispatcher
-)(tri)
+_tri_with_like = array_function_dispatch()(tri)
def _trilu_dispatcher(m, k=None):