diff options
-rw-r--r-- | numpy/core/src/multiarray/_multiarray_tests.c.src | 13 | ||||
-rw-r--r-- | numpy/core/src/multiarray/arraytypes.c.src | 12 | ||||
-rw-r--r-- | numpy/core/src/multiarray/multiarraymodule.c | 40 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 24 |
4 files changed, 71 insertions, 18 deletions
diff --git a/numpy/core/src/multiarray/_multiarray_tests.c.src b/numpy/core/src/multiarray/_multiarray_tests.c.src index b22b2c14d..1fd28e721 100644 --- a/numpy/core/src/multiarray/_multiarray_tests.c.src +++ b/numpy/core/src/multiarray/_multiarray_tests.c.src @@ -177,8 +177,14 @@ test_neighborhood_iterator(PyObject* NPY_UNUSED(self), PyObject* args) return NULL; } - typenum = PyArray_ObjectType(x, 0); + typenum = PyArray_ObjectType(x, NPY_NOTYPE); + if (typenum == NPY_NOTYPE) { + return NULL; + } typenum = PyArray_ObjectType(fill, typenum); + if (typenum == NPY_NOTYPE) { + return NULL; + } ax = (PyArrayObject*)PyArray_FromObject(x, typenum, 1, 10); if (ax == NULL) { @@ -343,7 +349,10 @@ test_neighborhood_iterator_oob(PyObject* NPY_UNUSED(self), PyObject* args) return NULL; } - typenum = PyArray_ObjectType(x, 0); + typenum = PyArray_ObjectType(x, NPY_NOTYPE); + if (typenum == NPY_NOTYPE) { + return NULL; + } ax = (PyArrayObject*)PyArray_FromObject(x, typenum, 1, 10); if (ax == NULL) { diff --git a/numpy/core/src/multiarray/arraytypes.c.src b/numpy/core/src/multiarray/arraytypes.c.src index b7c339aa2..34694aac6 100644 --- a/numpy/core/src/multiarray/arraytypes.c.src +++ b/numpy/core/src/multiarray/arraytypes.c.src @@ -3803,17 +3803,23 @@ BOOL_dot(char *ip1, npy_intp is1, char *ip2, npy_intp is2, char *op, npy_intp n, *((npy_bool *)op) = tmp; } +/* + * `dot` does not make sense for times, for DATETIME it never worked. + * For timedelta it does/did , but should probably also just be removed. + */ +#define DATETIME_dot NULL + /**begin repeat * * #name = BYTE, UBYTE, SHORT, USHORT, INT, UINT, * LONG, ULONG, LONGLONG, ULONGLONG, - * LONGDOUBLE, DATETIME, TIMEDELTA# + * LONGDOUBLE, TIMEDELTA# * #type = npy_byte, npy_ubyte, npy_short, npy_ushort, npy_int, npy_uint, * npy_long, npy_ulong, npy_longlong, npy_ulonglong, - * npy_longdouble, npy_datetime, npy_timedelta# + * npy_longdouble, npy_timedelta# * #out = npy_long, npy_ulong, npy_long, npy_ulong, npy_long, npy_ulong, * npy_long, npy_ulong, npy_longlong, npy_ulonglong, - * npy_longdouble, npy_datetime, npy_timedelta# + * npy_longdouble, npy_timedelta# */ static void @name@_dot(char *ip1, npy_intp is1, char *ip2, npy_intp is2, char *op, npy_intp n, diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c index dda8831c5..b2925f758 100644 --- a/numpy/core/src/multiarray/multiarraymodule.c +++ b/numpy/core/src/multiarray/multiarraymodule.c @@ -899,11 +899,15 @@ PyArray_InnerProduct(PyObject *op1, PyObject *op2) int i; PyObject* ret = NULL; - typenum = PyArray_ObjectType(op1, 0); - if (typenum == NPY_NOTYPE && PyErr_Occurred()) { + typenum = PyArray_ObjectType(op1, NPY_NOTYPE); + if (typenum == NPY_NOTYPE) { return NULL; } typenum = PyArray_ObjectType(op2, typenum); + if (typenum == NPY_NOTYPE) { + return NULL; + } + typec = PyArray_DescrFromType(typenum); if (typec == NULL) { if (!PyErr_Occurred()) { @@ -991,11 +995,15 @@ PyArray_MatrixProduct2(PyObject *op1, PyObject *op2, PyArrayObject* out) PyArray_Descr *typec = NULL; NPY_BEGIN_THREADS_DEF; - typenum = PyArray_ObjectType(op1, 0); - if (typenum == NPY_NOTYPE && PyErr_Occurred()) { + typenum = PyArray_ObjectType(op1, NPY_NOTYPE); + if (typenum == NPY_NOTYPE) { return NULL; } typenum = PyArray_ObjectType(op2, typenum); + if (typenum == NPY_NOTYPE) { + return NULL; + } + typec = PyArray_DescrFromType(typenum); if (typec == NULL) { if (!PyErr_Occurred()) { @@ -1373,8 +1381,14 @@ PyArray_Correlate2(PyObject *op1, PyObject *op2, int mode) int inverted; int st; - typenum = PyArray_ObjectType(op1, 0); + typenum = PyArray_ObjectType(op1, NPY_NOTYPE); + if (typenum == NPY_NOTYPE) { + return NULL; + } typenum = PyArray_ObjectType(op2, typenum); + if (typenum == NPY_NOTYPE) { + return NULL; + } typec = PyArray_DescrFromType(typenum); Py_INCREF(typec); @@ -1440,8 +1454,14 @@ PyArray_Correlate(PyObject *op1, PyObject *op2, int mode) int unused; PyArray_Descr *typec; - typenum = PyArray_ObjectType(op1, 0); + typenum = PyArray_ObjectType(op1, NPY_NOTYPE); + if (typenum == NPY_NOTYPE) { + return NULL; + } typenum = PyArray_ObjectType(op2, typenum); + if (typenum == NPY_NOTYPE) { + return NULL; + } typec = PyArray_DescrFromType(typenum); Py_INCREF(typec); @@ -2541,8 +2561,14 @@ array_vdot(PyObject *NPY_UNUSED(dummy), PyObject *args) * Conjugating dot product using the BLAS for vectors. * Flattens both op1 and op2 before dotting. */ - typenum = PyArray_ObjectType(op1, 0); + typenum = PyArray_ObjectType(op1, NPY_NOTYPE); + if (typenum == NPY_NOTYPE) { + return NULL; + } typenum = PyArray_ObjectType(op2, typenum); + if (typenum == NPY_NOTYPE) { + return NULL; + } type = PyArray_DescrFromType(typenum); Py_INCREF(type); diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 027384fba..15619bcb3 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -1257,9 +1257,9 @@ class TestStructured: # The main importance is that it does not return True: with pytest.raises(TypeError): x == y - + def test_empty_structured_array_comparison(self): - # Check that comparison works on empty arrays with nontrivially + # Check that comparison works on empty arrays with nontrivially # shaped fields a = np.zeros(0, [('a', '<f8', (1, 1))]) assert_equal(a, a) @@ -2232,7 +2232,7 @@ class TestMethods: assert_c(a.copy('C')) assert_fortran(a.copy('F')) assert_c(a.copy('A')) - + @pytest.mark.parametrize("dtype", ['O', np.int32, 'i,O']) def test__deepcopy__(self, dtype): # Force the entry of NULLs into array @@ -2441,7 +2441,7 @@ class TestMethods: np.array([0, 1, np.nan]), ]) def test_searchsorted_floats(self, a): - # test for floats arrays containing nans. Explicitly test + # test for floats arrays containing nans. Explicitly test # half, single, and double precision floats to verify that # the NaN-handling is correct. msg = "Test real (%s) searchsorted with nans, side='l'" % a.dtype @@ -2457,7 +2457,7 @@ class TestMethods: assert_equal(y, 2) def test_searchsorted_complex(self): - # test for complex arrays containing nans. + # test for complex arrays containing nans. # The search sorted routines use the compare functions for the # array type, so this checks if that is consistent with the sort # order. @@ -2479,7 +2479,7 @@ class TestMethods: a = np.array([0, 128], dtype='>i4') b = a.searchsorted(np.array(128, dtype='>i4')) assert_equal(b, 1, msg) - + def test_searchsorted_n_elements(self): # Check 0 elements a = np.ones(0) @@ -6731,6 +6731,18 @@ class TestDot: res = np.dot(data, data) assert res == 2**30+100 + def test_dtype_discovery_fails(self): + # See gh-14247, error checking was missing for failed dtype discovery + class BadObject(object): + def __array__(self): + raise TypeError("just this tiny mint leaf") + + with pytest.raises(TypeError): + np.dot(BadObject(), BadObject()) + + with pytest.raises(TypeError): + np.dot(3.0, BadObject()) + class MatmulCommon: """Common tests for '@' operator and numpy.matmul. |