summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/scalartypes.c.src40
-rw-r--r--numpy/core/tests/test_scalarmath.py23
2 files changed, 33 insertions, 30 deletions
diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src
index f9af33491..dca6e3840 100644
--- a/numpy/core/src/multiarray/scalartypes.c.src
+++ b/numpy/core/src/multiarray/scalartypes.c.src
@@ -238,44 +238,34 @@ gentype_@name@(PyObject *m1, PyObject *m2)
/**end repeat**/
#endif
+/* Get a nested slot, or NULL if absent */
+#define GET_NESTED_SLOT(type, group, slot) \
+ ((type)->group == NULL ? NULL : (type)->group->slot)
+
static PyObject *
gentype_multiply(PyObject *m1, PyObject *m2)
{
- npy_intp repeat;
-
/*
* If the other object supports sequence repeat and not number multiply
- * we should call sequence repeat to support e.g. list repeat by numpy
- * scalars (they may be converted to ndarray otherwise).
+ * we fall back on the python builtin to invoke the sequence repeat, rather
+ * than promoting both arguments to ndarray.
+ * This covers a list repeat by numpy scalars.
* A python defined class will always only have the nb_multiply slot and
* some classes may have neither defined. For the latter we want need
* to give the normal case a chance to convert the object to ndarray.
* Probably no class has both defined, but if they do, prefer number.
*/
if (!PyArray_IsScalar(m1, Generic) &&
- ((Py_TYPE(m1)->tp_as_sequence != NULL) &&
- (Py_TYPE(m1)->tp_as_sequence->sq_repeat != NULL)) &&
- ((Py_TYPE(m1)->tp_as_number == NULL) ||
- (Py_TYPE(m1)->tp_as_number->nb_multiply == NULL))) {
- /* Try to convert m2 to an int and try sequence repeat */
- repeat = PyArray_PyIntAsIntp(m2);
- if (error_converting(repeat)) {
- return NULL;
- }
- /* Note that npy_intp is compatible to Py_Ssize_t */
- return PySequence_Repeat(m1, repeat);
+ GET_NESTED_SLOT(Py_TYPE(m1), tp_as_sequence, sq_repeat) != NULL &&
+ GET_NESTED_SLOT(Py_TYPE(m1), tp_as_number, nb_multiply) == NULL) {
+ Py_INCREF(Py_NotImplemented);
+ return Py_NotImplemented;
}
if (!PyArray_IsScalar(m2, Generic) &&
- ((Py_TYPE(m2)->tp_as_sequence != NULL) &&
- (Py_TYPE(m2)->tp_as_sequence->sq_repeat != NULL)) &&
- ((Py_TYPE(m2)->tp_as_number == NULL) ||
- (Py_TYPE(m2)->tp_as_number->nb_multiply == NULL))) {
- /* Try to convert m1 to an int and try sequence repeat */
- repeat = PyArray_PyIntAsIntp(m1);
- if (error_converting(repeat)) {
- return NULL;
- }
- return PySequence_Repeat(m2, repeat);
+ GET_NESTED_SLOT(Py_TYPE(m2), tp_as_sequence, sq_repeat) != NULL &&
+ GET_NESTED_SLOT(Py_TYPE(m2), tp_as_number, nb_multiply) == NULL) {
+ Py_INCREF(Py_NotImplemented);
+ return Py_NotImplemented;
}
/* All normal cases are handled by PyArray's multiply */
BINOP_GIVE_UP_IF_NEEDED(m1, m2, nb_multiply, gentype_multiply);
diff --git a/numpy/core/tests/test_scalarmath.py b/numpy/core/tests/test_scalarmath.py
index 53b67327b..50824da41 100644
--- a/numpy/core/tests/test_scalarmath.py
+++ b/numpy/core/tests/test_scalarmath.py
@@ -10,7 +10,7 @@ from numpy.testing import (
run_module_suite,
assert_, assert_equal, assert_raises,
assert_almost_equal, assert_allclose, assert_array_equal,
- IS_PYPY, suppress_warnings, dec, _gen_alignment_data,
+ IS_PYPY, suppress_warnings, dec, _gen_alignment_data, assert_warns
)
types = [np.bool_, np.byte, np.ubyte, np.short, np.ushort, np.intc, np.uintc,
@@ -561,16 +561,29 @@ class TestMultiply(object):
# numpy integers. And errors are raised when multiplied with others.
# Some of this behaviour may be controversial and could be open for
# change.
+ accepted_types = set(np.typecodes["AllInteger"])
+ deprecated_types = set('?')
+ forbidden_types = (
+ set(np.typecodes["All"]) - accepted_types - deprecated_types)
+ forbidden_types -= set('V') # can't default-construct void scalars
+
for seq_type in (list, tuple):
seq = seq_type([1, 2, 3])
- for numpy_type in np.typecodes["AllInteger"]:
+ for numpy_type in accepted_types:
i = np.dtype(numpy_type).type(2)
assert_equal(seq * i, seq * int(i))
assert_equal(i * seq, int(i) * seq)
- for numpy_type in np.typecodes["All"].replace("V", ""):
- if numpy_type in np.typecodes["AllInteger"]:
- continue
+ for numpy_type in deprecated_types:
+ i = np.dtype(numpy_type).type()
+ assert_equal(
+ assert_warns(DeprecationWarning, operator.mul, seq, i),
+ seq * int(i))
+ assert_equal(
+ assert_warns(DeprecationWarning, operator.mul, i, seq),
+ int(i) * seq)
+
+ for numpy_type in forbidden_types:
i = np.dtype(numpy_type).type()
assert_raises(TypeError, operator.mul, seq, i)
assert_raises(TypeError, operator.mul, i, seq)