summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSebastian Berg <sebastianb@nvidia.com>2022-11-22 15:26:11 +0100
committerSebastian Berg <sebastianb@nvidia.com>2022-12-02 00:29:32 +0100
commit6f3bc1ec10e4e4270c458e4224669d325a233bfb (patch)
tree4314b845c004cc8efc83b2b4f84166b16d2823b9
parentbb59cd8da569837335c67091caa50291e74032a3 (diff)
downloadnumpy-6f3bc1ec10e4e4270c458e4224669d325a233bfb.tar.gz
ENH: Implement matmul using the nuclear options
This uses the `axes` argument. Arguably the most correct version since we use the full ufunc machinery, so we can't mess up with random conversions (which the test suite actually did notice, at least for an array subclass). OTOH, for now the errors are ridiculously unhelpful, some of which could be fixed in the gufunc code (at least made _better_).
-rw-r--r--numpy/core/src/multiarray/number.c63
1 files changed, 40 insertions, 23 deletions
diff --git a/numpy/core/src/multiarray/number.c b/numpy/core/src/multiarray/number.c
index 2c037d918..a96b49a08 100644
--- a/numpy/core/src/multiarray/number.c
+++ b/numpy/core/src/multiarray/number.c
@@ -349,35 +349,52 @@ array_matrix_multiply(PyObject *m1, PyObject *m2)
}
static PyObject *
-array_inplace_matrix_multiply(PyArrayObject *m1, PyObject *m2)
+array_inplace_matrix_multiply(PyArrayObject *self, PyObject *other)
{
- PyArrayObject *m2_array;
- int m1_ndim;
- int m2_ndim;
-
- INPLACE_GIVE_UP_IF_NEEDED(m1, m2_array,
+ INPLACE_GIVE_UP_IF_NEEDED(self, other,
nb_inplace_matrix_multiply, array_inplace_matrix_multiply);
- /* Explicitly raise a ValueError when the output would
- * otherwise be broadcasted to `m1`. Three conditions must be met:
- * * `m1.ndim in [1, 2]`
- * * `m2.ndim == 1`
- * * `m1.shape[-1] == m2.shape[0]`
+ /*
+ * Unlike `matmul(a, b, out=a)` we ensure that the result is not broadcast
+ * if the result without `out` would have less dimensions than `a`.
+ * Since the signature of matmul is '(n?,k),(k,m?)->(n?,m?)' this is the
+ * case exactly when the second operand has both core dimensions.
+ *
+ * The error here will be confusing, but for now, we enforce this by
+ * passing the correct `axes=`.
*/
- m2_array = (PyArrayObject *)PyArray_FromAny(m2, NULL, 0, 0, 0, NULL);
- m1_ndim = PyArray_NDIM(m1);
- m2_ndim = PyArray_NDIM(m2_array);
- if (((m1_ndim == 1) || (m1_ndim == 2)) && (m2_ndim == 1)
- && (PyArray_DIMS(m1)[m1_ndim - 1] == PyArray_DIMS(m2_array)[0])) {
- PyErr_Format(PyExc_ValueError,
- "output parameter has the wrong number of dimensions: "
- "Found %d but expected %d. Certain broadcasts may be "
- "accepted for `np.matmul(a, b, out=a)`.",
- m1_ndim - 1, m1_ndim);
- return NULL;
+ static PyObject *axes_1d_obj_kwargs = NULL;
+ static PyObject *axes_2d_obj_kwargs = NULL;
+ if (NPY_UNLIKELY(axes_1d_obj_kwargs == NULL)) {
+ axes_1d_obj_kwargs = Py_BuildValue(
+ "{s, [(i), (i, i), (i)]}", "axes", -1, -2, -1, -1);
+ if (axes_1d_obj_kwargs == NULL) {
+ return NULL;
+ }
+ }
+ if (NPY_UNLIKELY(axes_2d_obj_kwargs == NULL)) {
+ axes_2d_obj_kwargs = Py_BuildValue(
+ "{s, [(i, i), (i, i), (i, i)]}", "axes", -2, -1, -2, -1, -2, -1);
+ if (axes_2d_obj_kwargs == NULL) {
+ return NULL;
+ }
}
- return PyArray_GenericInplaceBinaryFunction(m1, (PyObject *)m2_array, n_ops.matmul);
+ PyObject *args = PyTuple_Pack(3, self, other, self);
+ if (args == NULL) {
+ return NULL;
+ }
+ PyObject *kwargs;
+ if (PyArray_NDIM(self) == 1) {
+ kwargs = axes_1d_obj_kwargs;
+ }
+ else {
+ kwargs = axes_2d_obj_kwargs;
+ }
+ PyObject *res = PyObject_Call(n_ops.matmul, args, kwargs);
+ Py_DECREF(args);
+ assert(res == (PyObject *)self);
+ return res;
}
/*