summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/arrayfunction_override.c17
-rw-r--r--numpy/core/tests/test_overrides.py33
2 files changed, 49 insertions, 1 deletions
diff --git a/numpy/core/src/multiarray/arrayfunction_override.c b/numpy/core/src/multiarray/arrayfunction_override.c
index c9b579ffe..04768504e 100644
--- a/numpy/core/src/multiarray/arrayfunction_override.c
+++ b/numpy/core/src/multiarray/arrayfunction_override.c
@@ -637,6 +637,19 @@ dispatcher_repr(PyObject *self)
return PyUnicode_FromFormat("<function %S at %p>", name, self);
}
+
+static PyObject *
+func_dispatcher___get__(PyObject *self, PyObject *obj, PyObject *cls)
+{
+ if (obj == NULL) {
+ /* Act like a static method, no need to bind */
+ Py_INCREF(self);
+ return self;
+ }
+ return PyMethod_New(self, obj);
+}
+
+
static PyObject *
dispatcher_get_implementation(
PyArray_ArrayFunctionDispatcherObject *self, void *NPY_UNUSED(closure))
@@ -677,9 +690,11 @@ NPY_NO_EXPORT PyTypeObject PyArrayFunctionDispatcher_Type = {
.tp_new = (newfunc)dispatcher_new,
.tp_str = (reprfunc)dispatcher_str,
.tp_repr = (reprfunc)dispatcher_repr,
- .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_VECTORCALL,
+ .tp_flags = (Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_VECTORCALL
+ | Py_TPFLAGS_METHOD_DESCRIPTOR),
.tp_methods = func_dispatcher_methods,
.tp_getset = func_dispatcher_getset,
+ .tp_descr_get = func_dispatcher___get__,
.tp_call = &PyVectorcall_Call,
.tp_vectorcall_offset = offsetof(PyArray_ArrayFunctionDispatcherObject, vectorcall),
};
diff --git a/numpy/core/tests/test_overrides.py b/numpy/core/tests/test_overrides.py
index 47aaca6a9..c27354311 100644
--- a/numpy/core/tests/test_overrides.py
+++ b/numpy/core/tests/test_overrides.py
@@ -690,3 +690,36 @@ class TestArrayLike:
array_like.fill(1)
expected.fill(1)
assert_equal(array_like, expected)
+
+
+@requires_array_function
+def test_function_like():
+ # We provide a `__get__` implementation, make sure it works
+ assert type(np.mean) is np.core._multiarray_umath._ArrayFunctionDispatcher
+
+ class MyClass:
+ def __array__(self):
+ # valid argument to mean:
+ return np.arange(3)
+
+ func1 = staticmethod(np.mean)
+ func2 = np.mean
+ func3 = classmethod(np.mean)
+
+ m = MyClass()
+ assert m.func1([10]) == 10
+ assert m.func2() == 1 # mean of the arange
+ with pytest.raises(TypeError, match="unsupported operand type"):
+ # Tries to operate on the class
+ m.func3()
+
+ # Manual binding also works (the above may shortcut):
+ bound = np.mean.__get__(m, MyClass)
+ assert bound() == 1
+
+ bound = np.mean.__get__(None, MyClass) # unbound actually
+ assert bound([10]) == 10
+
+ bound = np.mean.__get__(MyClass) # classmethod
+ with pytest.raises(TypeError, match="unsupported operand type"):
+ bound()