diff options
author | Allan Haldane <allan.haldane@gmail.com> | 2017-09-09 10:53:09 -0400 |
---|---|---|
committer | Allan Haldane <allan.haldane@gmail.com> | 2017-09-21 22:26:00 -0400 |
commit | d4387da79f685d29f007b0c894fee729804048d2 (patch) | |
tree | cbfb23f5f2c6e3e579cdc3aa771848d26433d85c /numpy/core | |
parent | 1ccfa62d8ac3f25deab22e415ccc5fe8a7d75ab4 (diff) | |
download | numpy-d4387da79f685d29f007b0c894fee729804048d2.tar.gz |
BUG: dot/matmul 'out' arg should accept any ndarray subclass
Fixes #9641
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/src/multiarray/cblasfuncs.c | 37 | ||||
-rw-r--r-- | numpy/core/src/multiarray/multiarraymodule.c | 58 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 12 |
3 files changed, 61 insertions, 46 deletions
diff --git a/numpy/core/src/multiarray/cblasfuncs.c b/numpy/core/src/multiarray/cblasfuncs.c index 8432ae5cf..7cb1652bb 100644 --- a/numpy/core/src/multiarray/cblasfuncs.c +++ b/numpy/core/src/multiarray/cblasfuncs.c @@ -250,8 +250,6 @@ cblas_matrixproduct(int typenum, PyArrayObject *ap1, PyArrayObject *ap2, npy_intp ap1stride = 0; npy_intp dimensions[NPY_MAXDIMS]; npy_intp numbytes; - double prior1, prior2; - PyTypeObject *subtype; MatrixShape ap1shape, ap2shape; if (_bad_strides(ap1)) { @@ -381,29 +379,17 @@ cblas_matrixproduct(int typenum, PyArrayObject *ap1, PyArrayObject *ap2, } } - /* Choose which subtype to return */ - if (Py_TYPE(ap1) != Py_TYPE(ap2)) { - prior2 = PyArray_GetPriority((PyObject *)ap2, 0.0); - prior1 = PyArray_GetPriority((PyObject *)ap1, 0.0); - subtype = (prior2 > prior1 ? Py_TYPE(ap2) : Py_TYPE(ap1)); - } - else { - prior1 = prior2 = 0.0; - subtype = Py_TYPE(ap1); - } - if (out != NULL) { int d; /* verify that out is usable */ - if (Py_TYPE(out) != subtype || - PyArray_NDIM(out) != nd || + if (PyArray_NDIM(out) != nd || PyArray_TYPE(out) != typenum || !PyArray_ISCARRAY(out)) { PyErr_SetString(PyExc_ValueError, - "output array is not acceptable " - "(must have the right type, nr dimensions, and be a C-Array)"); + "output array is not acceptable (must have the right datatype, " + "number of dimensions, and be a C-Array)"); goto fail; } for (d = 0; d < nd; ++d) { @@ -439,7 +425,22 @@ cblas_matrixproduct(int typenum, PyArrayObject *ap1, PyArrayObject *ap2, result = out; } else { - PyObject *tmp = (PyObject *)(prior2 > prior1 ? ap2 : ap1); + double prior1, prior2; + PyTypeObject *subtype; + PyObject *tmp; + + /* Choose which subtype to return */ + if (Py_TYPE(ap1) != Py_TYPE(ap2)) { + prior2 = PyArray_GetPriority((PyObject *)ap2, 0.0); + prior1 = PyArray_GetPriority((PyObject *)ap1, 0.0); + subtype = (prior2 > prior1 ? Py_TYPE(ap2) : Py_TYPE(ap1)); + } + else { + prior1 = prior2 = 0.0; + subtype = Py_TYPE(ap1); + } + + tmp = (PyObject *)(prior2 > prior1 ? ap2 : ap1); out_buf = (PyArrayObject *)PyArray_New(subtype, nd, dimensions, typenum, NULL, NULL, 0, 0, tmp); diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c index 66a076dc6..210882ff0 100644 --- a/numpy/core/src/multiarray/multiarraymodule.c +++ b/numpy/core/src/multiarray/multiarraymodule.c @@ -796,32 +796,17 @@ new_array_for_sum(PyArrayObject *ap1, PyArrayObject *ap2, PyArrayObject* out, int nd, npy_intp dimensions[], int typenum, PyArrayObject **result) { PyArrayObject *out_buf; - PyTypeObject *subtype; - double prior1, prior2; - /* - * Need to choose an output array that can hold a sum - * -- use priority to determine which subtype. - */ - if (Py_TYPE(ap2) != Py_TYPE(ap1)) { - prior2 = PyArray_GetPriority((PyObject *)ap2, 0.0); - prior1 = PyArray_GetPriority((PyObject *)ap1, 0.0); - subtype = (prior2 > prior1 ? Py_TYPE(ap2) : Py_TYPE(ap1)); - } - else { - prior1 = prior2 = 0.0; - subtype = Py_TYPE(ap1); - } + if (out) { int d; /* verify that out is usable */ - if (Py_TYPE(out) != subtype || - PyArray_NDIM(out) != nd || + if (PyArray_NDIM(out) != nd || PyArray_TYPE(out) != typenum || !PyArray_ISCARRAY(out)) { PyErr_SetString(PyExc_ValueError, - "output array is not acceptable " - "(must have the right type, nr dimensions, and be a C-Array)"); + "output array is not acceptable (must have the right datatype, " + "number of dimensions, and be a C-Array)"); return 0; } for (d = 0; d < nd; ++d) { @@ -862,18 +847,35 @@ new_array_for_sum(PyArrayObject *ap1, PyArrayObject *ap2, PyArrayObject* out, return out_buf; } + else { + PyTypeObject *subtype; + double prior1, prior2; + /* + * Need to choose an output array that can hold a sum + * -- use priority to determine which subtype. + */ + if (Py_TYPE(ap2) != Py_TYPE(ap1)) { + prior2 = PyArray_GetPriority((PyObject *)ap2, 0.0); + prior1 = PyArray_GetPriority((PyObject *)ap1, 0.0); + subtype = (prior2 > prior1 ? Py_TYPE(ap2) : Py_TYPE(ap1)); + } + else { + prior1 = prior2 = 0.0; + subtype = Py_TYPE(ap1); + } - out_buf = (PyArrayObject *)PyArray_New(subtype, nd, dimensions, - typenum, NULL, NULL, 0, 0, - (PyObject *) - (prior2 > prior1 ? ap2 : ap1)); + out_buf = (PyArrayObject *)PyArray_New(subtype, nd, dimensions, + typenum, NULL, NULL, 0, 0, + (PyObject *) + (prior2 > prior1 ? ap2 : ap1)); - if (out_buf != NULL && result) { - Py_INCREF(out_buf); - *result = out_buf; - } + if (out_buf != NULL && result) { + Py_INCREF(out_buf); + *result = out_buf; + } - return out_buf; + return out_buf; + } } /* Could perhaps be redone to not make contiguous arrays */ diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index bbdf4dbfa..a7629240e 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -2498,6 +2498,18 @@ class TestMethods(object): assert_raises(ValueError, np.dot, a, b, out=b[::2]) assert_raises(ValueError, np.dot, a, b, out=b.T) + def test_dot_matmul_out(self): + # gh-9641 + class Sub(np.ndarray): + pass + a = np.ones((2, 2)).view(Sub) + b = np.ones((2, 2)).view(Sub) + out = np.ones((2, 2)) + + # make sure out can be any ndarray (not only subclass of inputs) + np.dot(a, b, out=out) + np.matmul(a, b, out=out) + def test_diagonal(self): a = np.arange(12).reshape((3, 4)) assert_equal(a.diagonal(), [0, 5, 10]) |