diff options
author | Marten van Kerkwijk <mhvk@astro.utoronto.ca> | 2018-06-21 16:40:31 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-06-21 16:40:31 -0400 |
commit | a6d080015a4fba344868f3ca464cecdb7bc5f5aa (patch) | |
tree | 1ecb15004973555389988b8dca2f631384fb435f /numpy/core | |
parent | 9c43b1767e41726801e8ff70a23be7a0148d504c (diff) | |
parent | 77cc1609b5d7446d6c1b2d68d4ff779d107d2f69 (diff) | |
download | numpy-a6d080015a4fba344868f3ca464cecdb7bc5f5aa.tar.gz |
Merge pull request #11405 from mhvk/correct_scalar_comparison
BUG: Ensure comparisons on scalar strings pass without warning.
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/src/multiarray/arrayobject.c | 13 | ||||
-rw-r--r-- | numpy/core/tests/test_deprecations.py | 8 | ||||
-rw-r--r-- | numpy/core/tests/test_ufunc.py | 10 |
3 files changed, 22 insertions, 9 deletions
diff --git a/numpy/core/src/multiarray/arrayobject.c b/numpy/core/src/multiarray/arrayobject.c index e536de66a..368f5ded7 100644 --- a/numpy/core/src/multiarray/arrayobject.c +++ b/numpy/core/src/multiarray/arrayobject.c @@ -1285,6 +1285,7 @@ _failed_comparison_workaround(PyArrayObject *self, PyObject *other, int cmp_op) PyObject *exc, *val, *tb; PyArrayObject *array_other; int other_is_flexible, ndim_other; + int self_is_flexible = PyTypeNum_ISFLEXIBLE(PyArray_DESCR(self)->type_num); PyErr_Fetch(&exc, &val, &tb); /* @@ -1305,8 +1306,11 @@ _failed_comparison_workaround(PyArrayObject *self, PyObject *other, int cmp_op) ndim_other = 0; } if (cmp_op == Py_EQ || cmp_op == Py_NE) { - /* note: for == and !=, a flexible self cannot get here */ - if (other_is_flexible) { + /* + * note: for == and !=, a structured dtype self cannot get here, + * but a string can. Other can be string or structured. + */ + if (other_is_flexible || self_is_flexible) { /* * For scalars, returning NotImplemented is correct. * For arrays, we emit a future deprecation warning. @@ -1325,7 +1329,7 @@ _failed_comparison_workaround(PyArrayObject *self, PyObject *other, int cmp_op) } else { /* - * If other did not have a flexible dtype, the error cannot + * If neither self nor other had a flexible dtype, the error cannot * have been caused by a lack of implementation in the ufunc. * * 2015-05-14, 1.10 @@ -1342,8 +1346,7 @@ _failed_comparison_workaround(PyArrayObject *self, PyObject *other, int cmp_op) Py_INCREF(Py_NotImplemented); return Py_NotImplemented; } - else if (other_is_flexible || - PyTypeNum_ISFLEXIBLE(PyArray_DESCR(self)->type_num)) { + else if (other_is_flexible || self_is_flexible) { /* * For LE, LT, GT, GE and a flexible self or other, we return * NotImplemented, which is the correct answer since the ufuncs do diff --git a/numpy/core/tests/test_deprecations.py b/numpy/core/tests/test_deprecations.py index 285b2de3c..8eb258666 100644 --- a/numpy/core/tests/test_deprecations.py +++ b/numpy/core/tests/test_deprecations.py @@ -190,10 +190,10 @@ class TestComparisonDeprecations(_DeprecationTestCase): b = np.array(['a', 'b', 'c']) assert_raises(ValueError, lambda x, y: x == y, a, b) - # The empty list is not cast to string, as this is only to document - # that fact (it likely should be changed). This means that the - # following works (and returns False) due to dtype mismatch: - a == [] + # The empty list is not cast to string, and this used to pass due + # to dtype mismatch; now (2018-06-21) it correctly leads to a + # FutureWarning. + assert_warns(FutureWarning, lambda: a == []) def test_void_dtype_equality_failures(self): class NotArray(object): diff --git a/numpy/core/tests/test_ufunc.py b/numpy/core/tests/test_ufunc.py index 49a4dbbc9..0e564e305 100644 --- a/numpy/core/tests/test_ufunc.py +++ b/numpy/core/tests/test_ufunc.py @@ -1643,6 +1643,16 @@ class TestUfunc(object): target = np.array([ True, False, False, False], dtype=bool) assert_equal(np.all(target == (mra == ra[0])), True) + def test_scalar_equal(self): + # Scalar comparisons should always work, without deprecation warnings. + # even when the ufunc fails. + a = np.array(0.) + b = np.array('a') + assert_(a != b) + assert_(b != a) + assert_(not (a == b)) + assert_(not (b == a)) + def test_NotImplemented_not_returned(self): # See gh-5964 and gh-2091. Some of these functions are not operator # related and were fixed for other reasons in the past. |