summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
authorNathaniel J. Smith <njs@pobox.com>2015-06-21 18:41:34 -0700
committerNathaniel J. Smith <njs@pobox.com>2015-06-27 19:21:39 -0700
commit1adcdf7aa5b20a9afd778290105ec327b705c93e (patch)
treeb8117375595d7c0c05ef332b298891294c1b79f8 /numpy/core
parentf0c898b6017cc011561d7f1e611e08283ecfdb08 (diff)
downloadnumpy-1adcdf7aa5b20a9afd778290105ec327b705c93e.tar.gz
BUG: Make a @= b error out
Before this change, we defined a nb_matrix_multiply slot but not a nb_inplace_matrix_multiply slot, which means that a statement like a @= b would be silently expanded by the CPython interpreter to become a = a @ b This is undesireable, because it produces unexpected memory allocations, breaks view relationships, and so forth. This commit adds a nb_inplace_matrix_multiply slot which simply errors out, and suggests that users write 'a = a @ b' explicitly if that's what they want.
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/src/multiarray/number.c11
-rw-r--r--numpy/core/tests/test_multiarray.py13
2 files changed, 23 insertions, 1 deletions
diff --git a/numpy/core/src/multiarray/number.c b/numpy/core/src/multiarray/number.c
index 3e7521582..953a84eef 100644
--- a/numpy/core/src/multiarray/number.c
+++ b/numpy/core/src/multiarray/number.c
@@ -405,6 +405,15 @@ array_matrix_multiply(PyArrayObject *m1, PyObject *m2)
0, nb_matrix_multiply);
return PyArray_GenericBinaryFunction(m1, m2, matmul);
}
+
+static PyObject *
+array_inplace_matrix_multiply(PyArrayObject *m1, PyObject *m2)
+{
+ PyErr_SetString(PyExc_TypeError,
+ "In-place matrix multiplication is not (yet) supported. "
+ "Use 'a = a @ b' instead of 'a @= b'.");
+ return NULL;
+}
#endif
/* Determine if object is a scalar and if so, convert the object
@@ -1092,6 +1101,6 @@ NPY_NO_EXPORT PyNumberMethods array_as_number = {
(unaryfunc)array_index, /*nb_index */
#if PY_VERSION_HEX >= 0x03050000
(binaryfunc)array_matrix_multiply, /*nb_matrix_multiply*/
- (binaryfunc)NULL, /*nb_inplacematrix_multiply*/
+ (binaryfunc)array_inplace_matrix_multiply, /*nb_inplace_matrix_multiply*/
#endif
};
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index ac645f013..9822d7dfc 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -4322,6 +4322,19 @@ if sys.version_info[:2] >= (3, 5):
assert_equal(self.matmul(a, b), "A")
assert_equal(self.matmul(b, a), "A")
+ def test_matmul_inplace():
+ # It would be nice to support in-place matmul eventually, but for now
+ # we don't have a working implementation, so better just to error out
+ # and nudge people to writing "a = a @ b".
+ a = np.eye(3)
+ b = np.eye(3)
+ assert_raises(TypeError, a.__imatmul__, b)
+ import operator
+ assert_raises(TypeError, operator.imatmul, a, b)
+ # we avoid writing the token `exec` so as not to crash python 2's
+ # parser
+ exec_ = getattr(builtins, "exec")
+ assert_raises(TypeError, exec_, "a @= b", globals(), locals())
class TestInner(TestCase):