diff options
author | Pedro Lameiras <87664900+JLameiras@users.noreply.github.com> | 2023-03-28 10:31:54 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-03-28 11:31:54 +0200 |
commit | 94d4f702017be274acb130cbe9cf163910137fbc (patch) | |
tree | daf4886a782d7987ad56effdbb792f7b40aa6ab4 | |
parent | 08e778407064c37d6f814e0c59065344dea1ec2a (diff) | |
download | numpy-94d4f702017be274acb130cbe9cf163910137fbc.tar.gz |
BUG: Use output when given on numpy.dot C-API branch (#23459)
Updated the dot function C-API so that it now calls `np.multiply` with `out=` and returns it on branch of the function where the correct behaviour was not in place. Added two tests regarding this issue.
Closes #21081.
Co-authored-by: Sebastian Berg <sebastian@sipsolutions.net>
-rw-r--r-- | numpy/core/src/multiarray/multiarraymodule.c | 7 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 16 | ||||
-rw-r--r-- | numpy/typing/tests/data/pass/ndarray_misc.py | 2 |
3 files changed, 20 insertions, 5 deletions
diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c index adc1558da..98ca15ac4 100644 --- a/numpy/core/src/multiarray/multiarraymodule.c +++ b/numpy/core/src/multiarray/multiarraymodule.c @@ -1042,12 +1042,11 @@ PyArray_MatrixProduct2(PyObject *op1, PyObject *op2, PyArrayObject* out) #endif if (PyArray_NDIM(ap1) == 0 || PyArray_NDIM(ap2) == 0) { - result = (PyArray_NDIM(ap1) == 0 ? ap1 : ap2); - result = (PyArrayObject *)Py_TYPE(result)->tp_as_number->nb_multiply( - (PyObject *)ap1, (PyObject *)ap2); + PyObject *mul_res = PyObject_CallFunctionObjArgs( + n_ops.multiply, ap1, ap2, out, NULL); Py_DECREF(ap1); Py_DECREF(ap2); - return (PyObject *)result; + return mul_res; } l = PyArray_DIMS(ap1)[PyArray_NDIM(ap1) - 1]; if (PyArray_NDIM(ap2) > 1) { diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 3f3c1de8e..4a064827d 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -6675,6 +6675,22 @@ class TestDot: r = np.empty((1024, 32), dtype=int) assert_raises(ValueError, dot, f, v, r) + def test_dot_out_result(self): + x = np.ones((), dtype=np.float16) + y = np.ones((5,), dtype=np.float16) + z = np.zeros((5,), dtype=np.float16) + res = x.dot(y, out=z) + assert np.array_equal(res, y) + assert np.array_equal(z, y) + + def test_dot_out_aliasing(self): + x = np.ones((), dtype=np.float16) + y = np.ones((5,), dtype=np.float16) + z = np.zeros((5,), dtype=np.float16) + res = x.dot(y, out=z) + z[0] = 2 + assert np.array_equal(res, z) + def test_dot_array_order(self): a = np.array([[1, 2], [3, 4]], order='C') b = np.array([[1, 2], [3, 4]], order='F') diff --git a/numpy/typing/tests/data/pass/ndarray_misc.py b/numpy/typing/tests/data/pass/ndarray_misc.py index 19a1af9e2..6beacc5d7 100644 --- a/numpy/typing/tests/data/pass/ndarray_misc.py +++ b/numpy/typing/tests/data/pass/ndarray_misc.py @@ -150,7 +150,7 @@ A.argpartition([0]) A.diagonal() A.dot(1) -A.dot(1, out=B0) +A.dot(1, out=B2) A.nonzero() |