summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPedro Lameiras <87664900+JLameiras@users.noreply.github.com>2023-03-28 10:31:54 +0100
committerGitHub <noreply@github.com>2023-03-28 11:31:54 +0200
commit94d4f702017be274acb130cbe9cf163910137fbc (patch)
treedaf4886a782d7987ad56effdbb792f7b40aa6ab4
parent08e778407064c37d6f814e0c59065344dea1ec2a (diff)
downloadnumpy-94d4f702017be274acb130cbe9cf163910137fbc.tar.gz
BUG: Use output when given on numpy.dot C-API branch (#23459)
Updated the dot function C-API so that it now calls `np.multiply` with `out=` and returns it on branch of the function where the correct behaviour was not in place. Added two tests regarding this issue. Closes #21081. Co-authored-by: Sebastian Berg <sebastian@sipsolutions.net>
-rw-r--r--numpy/core/src/multiarray/multiarraymodule.c7
-rw-r--r--numpy/core/tests/test_multiarray.py16
-rw-r--r--numpy/typing/tests/data/pass/ndarray_misc.py2
3 files changed, 20 insertions, 5 deletions
diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c
index adc1558da..98ca15ac4 100644
--- a/numpy/core/src/multiarray/multiarraymodule.c
+++ b/numpy/core/src/multiarray/multiarraymodule.c
@@ -1042,12 +1042,11 @@ PyArray_MatrixProduct2(PyObject *op1, PyObject *op2, PyArrayObject* out)
#endif
if (PyArray_NDIM(ap1) == 0 || PyArray_NDIM(ap2) == 0) {
- result = (PyArray_NDIM(ap1) == 0 ? ap1 : ap2);
- result = (PyArrayObject *)Py_TYPE(result)->tp_as_number->nb_multiply(
- (PyObject *)ap1, (PyObject *)ap2);
+ PyObject *mul_res = PyObject_CallFunctionObjArgs(
+ n_ops.multiply, ap1, ap2, out, NULL);
Py_DECREF(ap1);
Py_DECREF(ap2);
- return (PyObject *)result;
+ return mul_res;
}
l = PyArray_DIMS(ap1)[PyArray_NDIM(ap1) - 1];
if (PyArray_NDIM(ap2) > 1) {
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 3f3c1de8e..4a064827d 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -6675,6 +6675,22 @@ class TestDot:
r = np.empty((1024, 32), dtype=int)
assert_raises(ValueError, dot, f, v, r)
+ def test_dot_out_result(self):
+ x = np.ones((), dtype=np.float16)
+ y = np.ones((5,), dtype=np.float16)
+ z = np.zeros((5,), dtype=np.float16)
+ res = x.dot(y, out=z)
+ assert np.array_equal(res, y)
+ assert np.array_equal(z, y)
+
+ def test_dot_out_aliasing(self):
+ x = np.ones((), dtype=np.float16)
+ y = np.ones((5,), dtype=np.float16)
+ z = np.zeros((5,), dtype=np.float16)
+ res = x.dot(y, out=z)
+ z[0] = 2
+ assert np.array_equal(res, z)
+
def test_dot_array_order(self):
a = np.array([[1, 2], [3, 4]], order='C')
b = np.array([[1, 2], [3, 4]], order='F')
diff --git a/numpy/typing/tests/data/pass/ndarray_misc.py b/numpy/typing/tests/data/pass/ndarray_misc.py
index 19a1af9e2..6beacc5d7 100644
--- a/numpy/typing/tests/data/pass/ndarray_misc.py
+++ b/numpy/typing/tests/data/pass/ndarray_misc.py
@@ -150,7 +150,7 @@ A.argpartition([0])
A.diagonal()
A.dot(1)
-A.dot(1, out=B0)
+A.dot(1, out=B2)
A.nonzero()