summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
authorMarten van Kerkwijk <mhvk@astro.utoronto.ca>2018-06-21 16:40:31 -0400
committerGitHub <noreply@github.com>2018-06-21 16:40:31 -0400
commita6d080015a4fba344868f3ca464cecdb7bc5f5aa (patch)
tree1ecb15004973555389988b8dca2f631384fb435f /numpy/core
parent9c43b1767e41726801e8ff70a23be7a0148d504c (diff)
parent77cc1609b5d7446d6c1b2d68d4ff779d107d2f69 (diff)
downloadnumpy-a6d080015a4fba344868f3ca464cecdb7bc5f5aa.tar.gz
Merge pull request #11405 from mhvk/correct_scalar_comparison
BUG: Ensure comparisons on scalar strings pass without warning.
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/src/multiarray/arrayobject.c13
-rw-r--r--numpy/core/tests/test_deprecations.py8
-rw-r--r--numpy/core/tests/test_ufunc.py10
3 files changed, 22 insertions, 9 deletions
diff --git a/numpy/core/src/multiarray/arrayobject.c b/numpy/core/src/multiarray/arrayobject.c
index e536de66a..368f5ded7 100644
--- a/numpy/core/src/multiarray/arrayobject.c
+++ b/numpy/core/src/multiarray/arrayobject.c
@@ -1285,6 +1285,7 @@ _failed_comparison_workaround(PyArrayObject *self, PyObject *other, int cmp_op)
PyObject *exc, *val, *tb;
PyArrayObject *array_other;
int other_is_flexible, ndim_other;
+ int self_is_flexible = PyTypeNum_ISFLEXIBLE(PyArray_DESCR(self)->type_num);
PyErr_Fetch(&exc, &val, &tb);
/*
@@ -1305,8 +1306,11 @@ _failed_comparison_workaround(PyArrayObject *self, PyObject *other, int cmp_op)
ndim_other = 0;
}
if (cmp_op == Py_EQ || cmp_op == Py_NE) {
- /* note: for == and !=, a flexible self cannot get here */
- if (other_is_flexible) {
+ /*
+ * note: for == and !=, a structured dtype self cannot get here,
+ * but a string can. Other can be string or structured.
+ */
+ if (other_is_flexible || self_is_flexible) {
/*
* For scalars, returning NotImplemented is correct.
* For arrays, we emit a future deprecation warning.
@@ -1325,7 +1329,7 @@ _failed_comparison_workaround(PyArrayObject *self, PyObject *other, int cmp_op)
}
else {
/*
- * If other did not have a flexible dtype, the error cannot
+ * If neither self nor other had a flexible dtype, the error cannot
* have been caused by a lack of implementation in the ufunc.
*
* 2015-05-14, 1.10
@@ -1342,8 +1346,7 @@ _failed_comparison_workaround(PyArrayObject *self, PyObject *other, int cmp_op)
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
}
- else if (other_is_flexible ||
- PyTypeNum_ISFLEXIBLE(PyArray_DESCR(self)->type_num)) {
+ else if (other_is_flexible || self_is_flexible) {
/*
* For LE, LT, GT, GE and a flexible self or other, we return
* NotImplemented, which is the correct answer since the ufuncs do
diff --git a/numpy/core/tests/test_deprecations.py b/numpy/core/tests/test_deprecations.py
index 285b2de3c..8eb258666 100644
--- a/numpy/core/tests/test_deprecations.py
+++ b/numpy/core/tests/test_deprecations.py
@@ -190,10 +190,10 @@ class TestComparisonDeprecations(_DeprecationTestCase):
b = np.array(['a', 'b', 'c'])
assert_raises(ValueError, lambda x, y: x == y, a, b)
- # The empty list is not cast to string, as this is only to document
- # that fact (it likely should be changed). This means that the
- # following works (and returns False) due to dtype mismatch:
- a == []
+ # The empty list is not cast to string, and this used to pass due
+ # to dtype mismatch; now (2018-06-21) it correctly leads to a
+ # FutureWarning.
+ assert_warns(FutureWarning, lambda: a == [])
def test_void_dtype_equality_failures(self):
class NotArray(object):
diff --git a/numpy/core/tests/test_ufunc.py b/numpy/core/tests/test_ufunc.py
index 49a4dbbc9..0e564e305 100644
--- a/numpy/core/tests/test_ufunc.py
+++ b/numpy/core/tests/test_ufunc.py
@@ -1643,6 +1643,16 @@ class TestUfunc(object):
target = np.array([ True, False, False, False], dtype=bool)
assert_equal(np.all(target == (mra == ra[0])), True)
+ def test_scalar_equal(self):
+ # Scalar comparisons should always work, without deprecation warnings.
+ # even when the ufunc fails.
+ a = np.array(0.)
+ b = np.array('a')
+ assert_(a != b)
+ assert_(b != a)
+ assert_(not (a == b))
+ assert_(not (b == a))
+
def test_NotImplemented_not_returned(self):
# See gh-5964 and gh-2091. Some of these functions are not operator
# related and were fixed for other reasons in the past.