summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorStefan Behnel <stefan_ml@behnel.de>2014-08-30 10:13:31 +0200
committerStefan Behnel <stefan_ml@behnel.de>2014-08-30 10:13:31 +0200
commitbe95915ce42d9481b82ee93106b57f620b00895b (patch)
tree5ee5c60f4e94fb2ea1d0f24a8d5a6424e642751a
parent1fed1015ae2f13e0974df8559191c153fb702a7e (diff)
downloadcython-be95915ce42d9481b82ee93106b57f620b00895b.tar.gz
fix error handling in backported matrix multiplication and compare to actual behaviour in Py3.5
-rw-r--r--Cython/Utility/ObjectHandling.c23
-rw-r--r--tests/run/matrix_multiplier.pyx34
2 files changed, 51 insertions, 6 deletions
diff --git a/Cython/Utility/ObjectHandling.c b/Cython/Utility/ObjectHandling.c
index c5baf2c35..cbf1e4afe 100644
--- a/Cython/Utility/ObjectHandling.c
+++ b/Cython/Utility/ObjectHandling.c
@@ -1353,7 +1353,8 @@ static CYTHON_INLINE PyObject* __Pyx_PyObject_CallNoArg(PyObject *func) {
#define __Pyx_PyNumber_MatrixMultiply(x,y) PyNumber_MatrixMultiply(x,y)
#define __Pyx_PyNumber_InPlaceMatrixMultiply(x,y) PyNumber_InPlaceMatrixMultiply(x,y)
#else
-static PyObject* __Pyx_PyNumber_MatrixMultiply(PyObject* x, PyObject* y);
+#define __Pyx_PyNumber_MatrixMultiply(x,y) __Pyx__PyNumber_MatrixMultiply(x, y, "@")
+static PyObject* __Pyx__PyNumber_MatrixMultiply(PyObject* x, PyObject* y, const char* op_name);
static PyObject* __Pyx_PyNumber_InPlaceMatrixMultiply(PyObject* x, PyObject* y);
#endif
@@ -1392,7 +1393,7 @@ bad:
return result;
}
-static PyObject* __Pyx_PyNumber_MatrixMultiply(PyObject* x, PyObject* y) {
+static PyObject* __Pyx__PyNumber_MatrixMultiply(PyObject* x, PyObject* y, const char* op_name) {
PyObject *func;
// FIXME: make subtype aware
// see note at https://docs.python.org/3/reference/datamodel.html#emulating-numeric-types
@@ -1410,10 +1411,20 @@ static PyObject* __Pyx_PyNumber_MatrixMultiply(PyObject* x, PyObject* y) {
func = __Pyx_PyObject_GetAttrStr(y, PYIDENT("__rmatmul__"));
if (func) {
PyObject *result = __Pyx_PyObject_CallMatrixMethod(func, x);
- return result;
+ if (result != Py_NotImplemented)
+ return result;
+ Py_DECREF(result);
+ } else {
+ if (!PyErr_ExceptionMatches(PyExc_AttributeError))
+ return NULL;
+ PyErr_Clear();
}
- Py_INCREF(Py_NotImplemented);
- return Py_NotImplemented;
+ PyErr_Format(PyExc_TypeError,
+ "unsupported operand type(s) for %.2s: '%.100s' and '%.100s'",
+ op_name,
+ Py_TYPE(x)->tp_name,
+ Py_TYPE(y)->tp_name);
+ return NULL;
}
static PyObject* __Pyx_PyNumber_InPlaceMatrixMultiply(PyObject* x, PyObject* y) {
@@ -1429,6 +1440,6 @@ static PyObject* __Pyx_PyNumber_InPlaceMatrixMultiply(PyObject* x, PyObject* y)
return NULL;
PyErr_Clear();
}
- return __Pyx_PyNumber_MatrixMultiply(x, y);
+ return __Pyx__PyNumber_MatrixMultiply(x, y, "@=");
}
#endif
diff --git a/tests/run/matrix_multiplier.pyx b/tests/run/matrix_multiplier.pyx
index 5c19879ac..83d09a2cb 100644
--- a/tests/run/matrix_multiplier.pyx
+++ b/tests/run/matrix_multiplier.pyx
@@ -17,6 +17,28 @@ ExtMatMult(1) @ 22
ExtMatMult('ExtMatMult(1) @ ExtMatMult(2)')
>>> print(test_imatmul(a, b))
ExtMatMult("ExtMatMult('ExtMatMult(1) @ ExtMatMult(2)') @ ExtMatMult(2)")
+
+>>> x = y = 1
+>>> x @ y
+Traceback (most recent call last):
+TypeError: unsupported operand type(s) for @: 'int' and 'int'
+>>> x @= y
+Traceback (most recent call last):
+TypeError: unsupported operand type(s) for @=: 'int' and 'int'
+
+>>> y = MatMult(22)
+>>> x @= y
+>>> print(x)
+1 @ MatMult(22)
+
+>>> x = MatMult(22)
+>>> print(x @ 1)
+MatMult(22) @ 1
+>>> print(1 @ x)
+1 @ MatMult(22)
+>>> x @= 1
+>>> print(x)
+MatMult('MatMult(22) @ 1')
"""
@@ -71,6 +93,10 @@ def test_matmul(a, b):
11 @ MatMult(2)
>>> print(test_matmul(MatMult('abc'), MatMult('def')))
MatMult('abc') @ MatMult('def')
+
+ >>> test_matmul(1, 2)
+ Traceback (most recent call last):
+ TypeError: unsupported operand type(s) for @: 'int' and 'int'
"""
return a @ b
@@ -81,6 +107,14 @@ def test_imatmul(a, b):
MatMult('MatMult(1) @ MatMult(2)')
>>> print(test_imatmul(MatMult('abc'), MatMult('def')))
MatMult("MatMult('abc') @ MatMult('def')")
+ >>> print(test_imatmul(11, MatMult('def')))
+ 11 @ MatMult('def')
+ >>> print(test_imatmul(MatMult('abc'), 11))
+ MatMult("MatMult('abc') @ 11")
+
+ >>> test_imatmul(1, 2)
+ Traceback (most recent call last):
+ TypeError: unsupported operand type(s) for @=: 'int' and 'int'
"""
a @= b
return a