diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/multiarray/scalartypes.c.src | 40 | ||||
-rw-r--r-- | numpy/core/tests/test_scalarmath.py | 23 |
2 files changed, 33 insertions, 30 deletions
diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src index f9af33491..dca6e3840 100644 --- a/numpy/core/src/multiarray/scalartypes.c.src +++ b/numpy/core/src/multiarray/scalartypes.c.src @@ -238,44 +238,34 @@ gentype_@name@(PyObject *m1, PyObject *m2) /**end repeat**/ #endif +/* Get a nested slot, or NULL if absent */ +#define GET_NESTED_SLOT(type, group, slot) \ + ((type)->group == NULL ? NULL : (type)->group->slot) + static PyObject * gentype_multiply(PyObject *m1, PyObject *m2) { - npy_intp repeat; - /* * If the other object supports sequence repeat and not number multiply - * we should call sequence repeat to support e.g. list repeat by numpy - * scalars (they may be converted to ndarray otherwise). + * we fall back on the python builtin to invoke the sequence repeat, rather + * than promoting both arguments to ndarray. + * This covers a list repeat by numpy scalars. * A python defined class will always only have the nb_multiply slot and * some classes may have neither defined. For the latter we want need * to give the normal case a chance to convert the object to ndarray. * Probably no class has both defined, but if they do, prefer number. */ if (!PyArray_IsScalar(m1, Generic) && - ((Py_TYPE(m1)->tp_as_sequence != NULL) && - (Py_TYPE(m1)->tp_as_sequence->sq_repeat != NULL)) && - ((Py_TYPE(m1)->tp_as_number == NULL) || - (Py_TYPE(m1)->tp_as_number->nb_multiply == NULL))) { - /* Try to convert m2 to an int and try sequence repeat */ - repeat = PyArray_PyIntAsIntp(m2); - if (error_converting(repeat)) { - return NULL; - } - /* Note that npy_intp is compatible to Py_Ssize_t */ - return PySequence_Repeat(m1, repeat); + GET_NESTED_SLOT(Py_TYPE(m1), tp_as_sequence, sq_repeat) != NULL && + GET_NESTED_SLOT(Py_TYPE(m1), tp_as_number, nb_multiply) == NULL) { + Py_INCREF(Py_NotImplemented); + return Py_NotImplemented; } if (!PyArray_IsScalar(m2, Generic) && - ((Py_TYPE(m2)->tp_as_sequence != NULL) && - (Py_TYPE(m2)->tp_as_sequence->sq_repeat != NULL)) && - ((Py_TYPE(m2)->tp_as_number == NULL) || - (Py_TYPE(m2)->tp_as_number->nb_multiply == NULL))) { - /* Try to convert m1 to an int and try sequence repeat */ - repeat = PyArray_PyIntAsIntp(m1); - if (error_converting(repeat)) { - return NULL; - } - return PySequence_Repeat(m2, repeat); + GET_NESTED_SLOT(Py_TYPE(m2), tp_as_sequence, sq_repeat) != NULL && + GET_NESTED_SLOT(Py_TYPE(m2), tp_as_number, nb_multiply) == NULL) { + Py_INCREF(Py_NotImplemented); + return Py_NotImplemented; } /* All normal cases are handled by PyArray's multiply */ BINOP_GIVE_UP_IF_NEEDED(m1, m2, nb_multiply, gentype_multiply); diff --git a/numpy/core/tests/test_scalarmath.py b/numpy/core/tests/test_scalarmath.py index 53b67327b..50824da41 100644 --- a/numpy/core/tests/test_scalarmath.py +++ b/numpy/core/tests/test_scalarmath.py @@ -10,7 +10,7 @@ from numpy.testing import ( run_module_suite, assert_, assert_equal, assert_raises, assert_almost_equal, assert_allclose, assert_array_equal, - IS_PYPY, suppress_warnings, dec, _gen_alignment_data, + IS_PYPY, suppress_warnings, dec, _gen_alignment_data, assert_warns ) types = [np.bool_, np.byte, np.ubyte, np.short, np.ushort, np.intc, np.uintc, @@ -561,16 +561,29 @@ class TestMultiply(object): # numpy integers. And errors are raised when multiplied with others. # Some of this behaviour may be controversial and could be open for # change. + accepted_types = set(np.typecodes["AllInteger"]) + deprecated_types = set('?') + forbidden_types = ( + set(np.typecodes["All"]) - accepted_types - deprecated_types) + forbidden_types -= set('V') # can't default-construct void scalars + for seq_type in (list, tuple): seq = seq_type([1, 2, 3]) - for numpy_type in np.typecodes["AllInteger"]: + for numpy_type in accepted_types: i = np.dtype(numpy_type).type(2) assert_equal(seq * i, seq * int(i)) assert_equal(i * seq, int(i) * seq) - for numpy_type in np.typecodes["All"].replace("V", ""): - if numpy_type in np.typecodes["AllInteger"]: - continue + for numpy_type in deprecated_types: + i = np.dtype(numpy_type).type() + assert_equal( + assert_warns(DeprecationWarning, operator.mul, seq, i), + seq * int(i)) + assert_equal( + assert_warns(DeprecationWarning, operator.mul, i, seq), + int(i) * seq) + + for numpy_type in forbidden_types: i = np.dtype(numpy_type).type() assert_raises(TypeError, operator.mul, seq, i) assert_raises(TypeError, operator.mul, i, seq) |