diff options
author | Nathaniel J. Smith <njs@pobox.com> | 2015-06-21 18:41:34 -0700 |
---|---|---|
committer | Nathaniel J. Smith <njs@pobox.com> | 2015-06-27 19:21:39 -0700 |
commit | 1adcdf7aa5b20a9afd778290105ec327b705c93e (patch) | |
tree | b8117375595d7c0c05ef332b298891294c1b79f8 /numpy/core | |
parent | f0c898b6017cc011561d7f1e611e08283ecfdb08 (diff) | |
download | numpy-1adcdf7aa5b20a9afd778290105ec327b705c93e.tar.gz |
BUG: Make a @= b error out
Before this change, we defined a nb_matrix_multiply slot but not a
nb_inplace_matrix_multiply slot, which means that a statement like
a @= b
would be silently expanded by the CPython interpreter to become
a = a @ b
This is undesireable, because it produces unexpected memory
allocations, breaks view relationships, and so forth.
This commit adds a nb_inplace_matrix_multiply slot which simply errors
out, and suggests that users write 'a = a @ b' explicitly if that's
what they want.
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/src/multiarray/number.c | 11 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 13 |
2 files changed, 23 insertions, 1 deletions
diff --git a/numpy/core/src/multiarray/number.c b/numpy/core/src/multiarray/number.c index 3e7521582..953a84eef 100644 --- a/numpy/core/src/multiarray/number.c +++ b/numpy/core/src/multiarray/number.c @@ -405,6 +405,15 @@ array_matrix_multiply(PyArrayObject *m1, PyObject *m2) 0, nb_matrix_multiply); return PyArray_GenericBinaryFunction(m1, m2, matmul); } + +static PyObject * +array_inplace_matrix_multiply(PyArrayObject *m1, PyObject *m2) +{ + PyErr_SetString(PyExc_TypeError, + "In-place matrix multiplication is not (yet) supported. " + "Use 'a = a @ b' instead of 'a @= b'."); + return NULL; +} #endif /* Determine if object is a scalar and if so, convert the object @@ -1092,6 +1101,6 @@ NPY_NO_EXPORT PyNumberMethods array_as_number = { (unaryfunc)array_index, /*nb_index */ #if PY_VERSION_HEX >= 0x03050000 (binaryfunc)array_matrix_multiply, /*nb_matrix_multiply*/ - (binaryfunc)NULL, /*nb_inplacematrix_multiply*/ + (binaryfunc)array_inplace_matrix_multiply, /*nb_inplace_matrix_multiply*/ #endif }; diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index ac645f013..9822d7dfc 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -4322,6 +4322,19 @@ if sys.version_info[:2] >= (3, 5): assert_equal(self.matmul(a, b), "A") assert_equal(self.matmul(b, a), "A") + def test_matmul_inplace(): + # It would be nice to support in-place matmul eventually, but for now + # we don't have a working implementation, so better just to error out + # and nudge people to writing "a = a @ b". + a = np.eye(3) + b = np.eye(3) + assert_raises(TypeError, a.__imatmul__, b) + import operator + assert_raises(TypeError, operator.imatmul, a, b) + # we avoid writing the token `exec` so as not to crash python 2's + # parser + exec_ = getattr(builtins, "exec") + assert_raises(TypeError, exec_, "a @= b", globals(), locals()) class TestInner(TestCase): |