summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2023-03-10 13:59:18 -0500
committerGitHub <noreply@github.com>2023-03-10 13:59:18 -0500
commitf14cd445731206d03010aed4a0e4d7f4d8adb693 (patch)
tree1e87c96720deac64cf6cb6a358294935967f8673
parent60be753557d1e44b255bfa0eea492f449e657069 (diff)
parent834490f8efcec7d6d50ab19af2c4cba81203c26d (diff)
downloadnumpy-f14cd445731206d03010aed4a0e4d7f4d8adb693.tar.gz
Merge pull request #23370 from seberg/fixup-like-kwargs
BUG: Ensure like is only stripped for `like=` dispatched functions
-rw-r--r--numpy/core/_asarray.py2
-rw-r--r--numpy/core/numeric.py8
-rw-r--r--numpy/core/overrides.py5
-rw-r--r--numpy/core/src/multiarray/arrayfunction_override.c13
-rw-r--r--numpy/lib/npyio.py4
-rw-r--r--numpy/lib/tests/test_io.py10
-rw-r--r--numpy/lib/twodim_base.py4
7 files changed, 32 insertions, 14 deletions
diff --git a/numpy/core/_asarray.py b/numpy/core/_asarray.py
index cbaab8c3f..28f1fe6fa 100644
--- a/numpy/core/_asarray.py
+++ b/numpy/core/_asarray.py
@@ -136,5 +136,5 @@ def require(a, dtype=None, requirements=None, *, like=None):
_require_with_like = array_function_dispatch(
- _require_dispatcher
+ _require_dispatcher, use_like=True
)(require)
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index 5b92972d1..fdd81a20e 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -208,7 +208,7 @@ def ones(shape, dtype=None, order='C', *, like=None):
_ones_with_like = array_function_dispatch(
- _ones_dispatcher
+ _ones_dispatcher, use_like=True
)(ones)
@@ -347,7 +347,7 @@ def full(shape, fill_value, dtype=None, order='C', *, like=None):
_full_with_like = array_function_dispatch(
- _full_dispatcher
+ _full_dispatcher, use_like=True
)(full)
@@ -1867,7 +1867,7 @@ def fromfunction(function, shape, *, dtype=float, like=None, **kwargs):
_fromfunction_with_like = array_function_dispatch(
- _fromfunction_dispatcher
+ _fromfunction_dispatcher, use_like=True
)(fromfunction)
@@ -2188,7 +2188,7 @@ def identity(n, dtype=None, *, like=None):
_identity_with_like = array_function_dispatch(
- _identity_dispatcher
+ _identity_dispatcher, use_like=True
)(identity)
diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py
index 450464f89..c567dfefd 100644
--- a/numpy/core/overrides.py
+++ b/numpy/core/overrides.py
@@ -126,7 +126,7 @@ def set_module(module):
def array_function_dispatch(dispatcher, module=None, verify=True,
- docs_from_dispatcher=False):
+ docs_from_dispatcher=False, use_like=False):
"""Decorator for adding dispatch with the __array_function__ protocol.
See NEP-18 for example usage.
@@ -198,7 +198,8 @@ def array_function_dispatch(dispatcher, module=None, verify=True,
raise TypeError(new_msg) from None
return implement_array_function(
- implementation, public_api, relevant_args, args, kwargs)
+ implementation, public_api, relevant_args, args, kwargs,
+ use_like)
public_api.__code__ = public_api.__code__.replace(
co_name=implementation.__name__,
diff --git a/numpy/core/src/multiarray/arrayfunction_override.c b/numpy/core/src/multiarray/arrayfunction_override.c
index 2bb3fbe28..a3d55bdc2 100644
--- a/numpy/core/src/multiarray/arrayfunction_override.c
+++ b/numpy/core/src/multiarray/arrayfunction_override.c
@@ -334,10 +334,16 @@ array_implement_array_function(
PyObject *NPY_UNUSED(dummy), PyObject *positional_args)
{
PyObject *res, *implementation, *public_api, *relevant_args, *args, *kwargs;
+ /*
+ * Very few functions use **kwargs, only check for like then (note that
+ * this is a backport only change, 1.25.x has been refactored)
+ */
+ PyObject *uses_like;
if (!PyArg_UnpackTuple(
- positional_args, "implement_array_function", 5, 5,
- &implementation, &public_api, &relevant_args, &args, &kwargs)) {
+ positional_args, "implement_array_function", 6, 6,
+ &implementation, &public_api, &relevant_args, &args, &kwargs,
+ &uses_like)) {
return NULL;
}
@@ -346,7 +352,8 @@ array_implement_array_function(
* in downstream libraries. If `like=` is specified but doesn't
* implement `__array_function__`, raise a `TypeError`.
*/
- if (kwargs != NULL && PyDict_Contains(kwargs, npy_ma_str_like)) {
+ if (uses_like == Py_True
+ && kwargs != NULL && PyDict_Contains(kwargs, npy_ma_str_like)) {
PyObject *like_arg = PyDict_GetItem(kwargs, npy_ma_str_like);
if (like_arg != NULL) {
PyObject *tmp_has_override = get_array_function(like_arg);
diff --git a/numpy/lib/npyio.py b/numpy/lib/npyio.py
index 71d600c30..3d6951416 100644
--- a/numpy/lib/npyio.py
+++ b/numpy/lib/npyio.py
@@ -1362,7 +1362,7 @@ def loadtxt(fname, dtype=float, comments='#', delimiter=None,
_loadtxt_with_like = array_function_dispatch(
- _loadtxt_dispatcher
+ _loadtxt_dispatcher, use_like=True
)(loadtxt)
@@ -2472,7 +2472,7 @@ def genfromtxt(fname, dtype=float, comments='#', delimiter=None,
_genfromtxt_with_like = array_function_dispatch(
- _genfromtxt_dispatcher
+ _genfromtxt_dispatcher, use_like=True
)(genfromtxt)
diff --git a/numpy/lib/tests/test_io.py b/numpy/lib/tests/test_io.py
index 4699935ca..3af2e6f42 100644
--- a/numpy/lib/tests/test_io.py
+++ b/numpy/lib/tests/test_io.py
@@ -232,6 +232,16 @@ class TestSavezLoad(RoundtripTest):
assert_equal(a, l['file_a'])
assert_equal(b, l['file_b'])
+ def test_named_arrays_with_like(self):
+ a = np.array([[1, 2], [3, 4]], float)
+ b = np.array([[1 + 2j, 2 + 7j], [3 - 6j, 4 + 12j]], complex)
+ c = BytesIO()
+ np.savez(c, file_a=a, like=b)
+ c.seek(0)
+ l = np.load(c)
+ assert_equal(a, l['file_a'])
+ assert_equal(b, l['like'])
+
def test_BagObj(self):
a = np.array([[1, 2], [3, 4]], float)
b = np.array([[1 + 2j, 2 + 7j], [3 - 6j, 4 + 12j]], complex)
diff --git a/numpy/lib/twodim_base.py b/numpy/lib/twodim_base.py
index 55d8ca896..a5f421e5c 100644
--- a/numpy/lib/twodim_base.py
+++ b/numpy/lib/twodim_base.py
@@ -229,7 +229,7 @@ def eye(N, M=None, k=0, dtype=float, order='C', *, like=None):
_eye_with_like = array_function_dispatch(
- _eye_dispatcher
+ _eye_dispatcher, use_like=True
)(eye)
@@ -431,7 +431,7 @@ def tri(N, M=None, k=0, dtype=float, *, like=None):
_tri_with_like = array_function_dispatch(
- _tri_dispatcher
+ _tri_dispatcher, use_like=True
)(tri)