diff options
author | Sebastian Berg <sebastianb@nvidia.com> | 2022-11-22 15:26:11 +0100 |
---|---|---|
committer | Sebastian Berg <sebastianb@nvidia.com> | 2022-12-02 00:29:32 +0100 |
commit | 6f3bc1ec10e4e4270c458e4224669d325a233bfb (patch) | |
tree | 4314b845c004cc8efc83b2b4f84166b16d2823b9 | |
parent | bb59cd8da569837335c67091caa50291e74032a3 (diff) | |
download | numpy-6f3bc1ec10e4e4270c458e4224669d325a233bfb.tar.gz |
ENH: Implement matmul using the nuclear options
This uses the `axes` argument. Arguably the most correct version since
we use the full ufunc machinery, so we can't mess up with random conversions
(which the test suite actually did notice, at least for an array subclass).
OTOH, for now the errors are ridiculously unhelpful, some of which could
be fixed in the gufunc code (at least made _better_).
-rw-r--r-- | numpy/core/src/multiarray/number.c | 63 |
1 files changed, 40 insertions, 23 deletions
diff --git a/numpy/core/src/multiarray/number.c b/numpy/core/src/multiarray/number.c index 2c037d918..a96b49a08 100644 --- a/numpy/core/src/multiarray/number.c +++ b/numpy/core/src/multiarray/number.c @@ -349,35 +349,52 @@ array_matrix_multiply(PyObject *m1, PyObject *m2) } static PyObject * -array_inplace_matrix_multiply(PyArrayObject *m1, PyObject *m2) +array_inplace_matrix_multiply(PyArrayObject *self, PyObject *other) { - PyArrayObject *m2_array; - int m1_ndim; - int m2_ndim; - - INPLACE_GIVE_UP_IF_NEEDED(m1, m2_array, + INPLACE_GIVE_UP_IF_NEEDED(self, other, nb_inplace_matrix_multiply, array_inplace_matrix_multiply); - /* Explicitly raise a ValueError when the output would - * otherwise be broadcasted to `m1`. Three conditions must be met: - * * `m1.ndim in [1, 2]` - * * `m2.ndim == 1` - * * `m1.shape[-1] == m2.shape[0]` + /* + * Unlike `matmul(a, b, out=a)` we ensure that the result is not broadcast + * if the result without `out` would have less dimensions than `a`. + * Since the signature of matmul is '(n?,k),(k,m?)->(n?,m?)' this is the + * case exactly when the second operand has both core dimensions. + * + * The error here will be confusing, but for now, we enforce this by + * passing the correct `axes=`. */ - m2_array = (PyArrayObject *)PyArray_FromAny(m2, NULL, 0, 0, 0, NULL); - m1_ndim = PyArray_NDIM(m1); - m2_ndim = PyArray_NDIM(m2_array); - if (((m1_ndim == 1) || (m1_ndim == 2)) && (m2_ndim == 1) - && (PyArray_DIMS(m1)[m1_ndim - 1] == PyArray_DIMS(m2_array)[0])) { - PyErr_Format(PyExc_ValueError, - "output parameter has the wrong number of dimensions: " - "Found %d but expected %d. Certain broadcasts may be " - "accepted for `np.matmul(a, b, out=a)`.", - m1_ndim - 1, m1_ndim); - return NULL; + static PyObject *axes_1d_obj_kwargs = NULL; + static PyObject *axes_2d_obj_kwargs = NULL; + if (NPY_UNLIKELY(axes_1d_obj_kwargs == NULL)) { + axes_1d_obj_kwargs = Py_BuildValue( + "{s, [(i), (i, i), (i)]}", "axes", -1, -2, -1, -1); + if (axes_1d_obj_kwargs == NULL) { + return NULL; + } + } + if (NPY_UNLIKELY(axes_2d_obj_kwargs == NULL)) { + axes_2d_obj_kwargs = Py_BuildValue( + "{s, [(i, i), (i, i), (i, i)]}", "axes", -2, -1, -2, -1, -2, -1); + if (axes_2d_obj_kwargs == NULL) { + return NULL; + } } - return PyArray_GenericInplaceBinaryFunction(m1, (PyObject *)m2_array, n_ops.matmul); + PyObject *args = PyTuple_Pack(3, self, other, self); + if (args == NULL) { + return NULL; + } + PyObject *kwargs; + if (PyArray_NDIM(self) == 1) { + kwargs = axes_1d_obj_kwargs; + } + else { + kwargs = axes_2d_obj_kwargs; + } + PyObject *res = PyObject_Call(n_ops.matmul, args, kwargs); + Py_DECREF(args); + assert(res == (PyObject *)self); + return res; } /* |