summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
authorAllan Haldane <allan.haldane@gmail.com>2017-09-09 10:53:09 -0400
committerAllan Haldane <allan.haldane@gmail.com>2017-09-21 22:26:00 -0400
commitd4387da79f685d29f007b0c894fee729804048d2 (patch)
treecbfb23f5f2c6e3e579cdc3aa771848d26433d85c /numpy/core
parent1ccfa62d8ac3f25deab22e415ccc5fe8a7d75ab4 (diff)
downloadnumpy-d4387da79f685d29f007b0c894fee729804048d2.tar.gz
BUG: dot/matmul 'out' arg should accept any ndarray subclass
Fixes #9641
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/src/multiarray/cblasfuncs.c37
-rw-r--r--numpy/core/src/multiarray/multiarraymodule.c58
-rw-r--r--numpy/core/tests/test_multiarray.py12
3 files changed, 61 insertions, 46 deletions
diff --git a/numpy/core/src/multiarray/cblasfuncs.c b/numpy/core/src/multiarray/cblasfuncs.c
index 8432ae5cf..7cb1652bb 100644
--- a/numpy/core/src/multiarray/cblasfuncs.c
+++ b/numpy/core/src/multiarray/cblasfuncs.c
@@ -250,8 +250,6 @@ cblas_matrixproduct(int typenum, PyArrayObject *ap1, PyArrayObject *ap2,
npy_intp ap1stride = 0;
npy_intp dimensions[NPY_MAXDIMS];
npy_intp numbytes;
- double prior1, prior2;
- PyTypeObject *subtype;
MatrixShape ap1shape, ap2shape;
if (_bad_strides(ap1)) {
@@ -381,29 +379,17 @@ cblas_matrixproduct(int typenum, PyArrayObject *ap1, PyArrayObject *ap2,
}
}
- /* Choose which subtype to return */
- if (Py_TYPE(ap1) != Py_TYPE(ap2)) {
- prior2 = PyArray_GetPriority((PyObject *)ap2, 0.0);
- prior1 = PyArray_GetPriority((PyObject *)ap1, 0.0);
- subtype = (prior2 > prior1 ? Py_TYPE(ap2) : Py_TYPE(ap1));
- }
- else {
- prior1 = prior2 = 0.0;
- subtype = Py_TYPE(ap1);
- }
-
if (out != NULL) {
int d;
/* verify that out is usable */
- if (Py_TYPE(out) != subtype ||
- PyArray_NDIM(out) != nd ||
+ if (PyArray_NDIM(out) != nd ||
PyArray_TYPE(out) != typenum ||
!PyArray_ISCARRAY(out)) {
PyErr_SetString(PyExc_ValueError,
- "output array is not acceptable "
- "(must have the right type, nr dimensions, and be a C-Array)");
+ "output array is not acceptable (must have the right datatype, "
+ "number of dimensions, and be a C-Array)");
goto fail;
}
for (d = 0; d < nd; ++d) {
@@ -439,7 +425,22 @@ cblas_matrixproduct(int typenum, PyArrayObject *ap1, PyArrayObject *ap2,
result = out;
}
else {
- PyObject *tmp = (PyObject *)(prior2 > prior1 ? ap2 : ap1);
+ double prior1, prior2;
+ PyTypeObject *subtype;
+ PyObject *tmp;
+
+ /* Choose which subtype to return */
+ if (Py_TYPE(ap1) != Py_TYPE(ap2)) {
+ prior2 = PyArray_GetPriority((PyObject *)ap2, 0.0);
+ prior1 = PyArray_GetPriority((PyObject *)ap1, 0.0);
+ subtype = (prior2 > prior1 ? Py_TYPE(ap2) : Py_TYPE(ap1));
+ }
+ else {
+ prior1 = prior2 = 0.0;
+ subtype = Py_TYPE(ap1);
+ }
+
+ tmp = (PyObject *)(prior2 > prior1 ? ap2 : ap1);
out_buf = (PyArrayObject *)PyArray_New(subtype, nd, dimensions,
typenum, NULL, NULL, 0, 0, tmp);
diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c
index 66a076dc6..210882ff0 100644
--- a/numpy/core/src/multiarray/multiarraymodule.c
+++ b/numpy/core/src/multiarray/multiarraymodule.c
@@ -796,32 +796,17 @@ new_array_for_sum(PyArrayObject *ap1, PyArrayObject *ap2, PyArrayObject* out,
int nd, npy_intp dimensions[], int typenum, PyArrayObject **result)
{
PyArrayObject *out_buf;
- PyTypeObject *subtype;
- double prior1, prior2;
- /*
- * Need to choose an output array that can hold a sum
- * -- use priority to determine which subtype.
- */
- if (Py_TYPE(ap2) != Py_TYPE(ap1)) {
- prior2 = PyArray_GetPriority((PyObject *)ap2, 0.0);
- prior1 = PyArray_GetPriority((PyObject *)ap1, 0.0);
- subtype = (prior2 > prior1 ? Py_TYPE(ap2) : Py_TYPE(ap1));
- }
- else {
- prior1 = prior2 = 0.0;
- subtype = Py_TYPE(ap1);
- }
+
if (out) {
int d;
/* verify that out is usable */
- if (Py_TYPE(out) != subtype ||
- PyArray_NDIM(out) != nd ||
+ if (PyArray_NDIM(out) != nd ||
PyArray_TYPE(out) != typenum ||
!PyArray_ISCARRAY(out)) {
PyErr_SetString(PyExc_ValueError,
- "output array is not acceptable "
- "(must have the right type, nr dimensions, and be a C-Array)");
+ "output array is not acceptable (must have the right datatype, "
+ "number of dimensions, and be a C-Array)");
return 0;
}
for (d = 0; d < nd; ++d) {
@@ -862,18 +847,35 @@ new_array_for_sum(PyArrayObject *ap1, PyArrayObject *ap2, PyArrayObject* out,
return out_buf;
}
+ else {
+ PyTypeObject *subtype;
+ double prior1, prior2;
+ /*
+ * Need to choose an output array that can hold a sum
+ * -- use priority to determine which subtype.
+ */
+ if (Py_TYPE(ap2) != Py_TYPE(ap1)) {
+ prior2 = PyArray_GetPriority((PyObject *)ap2, 0.0);
+ prior1 = PyArray_GetPriority((PyObject *)ap1, 0.0);
+ subtype = (prior2 > prior1 ? Py_TYPE(ap2) : Py_TYPE(ap1));
+ }
+ else {
+ prior1 = prior2 = 0.0;
+ subtype = Py_TYPE(ap1);
+ }
- out_buf = (PyArrayObject *)PyArray_New(subtype, nd, dimensions,
- typenum, NULL, NULL, 0, 0,
- (PyObject *)
- (prior2 > prior1 ? ap2 : ap1));
+ out_buf = (PyArrayObject *)PyArray_New(subtype, nd, dimensions,
+ typenum, NULL, NULL, 0, 0,
+ (PyObject *)
+ (prior2 > prior1 ? ap2 : ap1));
- if (out_buf != NULL && result) {
- Py_INCREF(out_buf);
- *result = out_buf;
- }
+ if (out_buf != NULL && result) {
+ Py_INCREF(out_buf);
+ *result = out_buf;
+ }
- return out_buf;
+ return out_buf;
+ }
}
/* Could perhaps be redone to not make contiguous arrays */
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index bbdf4dbfa..a7629240e 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -2498,6 +2498,18 @@ class TestMethods(object):
assert_raises(ValueError, np.dot, a, b, out=b[::2])
assert_raises(ValueError, np.dot, a, b, out=b.T)
+ def test_dot_matmul_out(self):
+ # gh-9641
+ class Sub(np.ndarray):
+ pass
+ a = np.ones((2, 2)).view(Sub)
+ b = np.ones((2, 2)).view(Sub)
+ out = np.ones((2, 2))
+
+ # make sure out can be any ndarray (not only subclass of inputs)
+ np.dot(a, b, out=out)
+ np.matmul(a, b, out=out)
+
def test_diagonal(self):
a = np.arange(12).reshape((3, 4))
assert_equal(a.diagonal(), [0, 5, 10])