summary | refs | log | tree | commit | diff
path: root/numpy/core
diff options
context:
space:
mode:
author: Sebastian Berg <sebastian@sipsolutions.net>, 2022-06-08 13:39:31 -0700
committer: GitHub <noreply@github.com>, 2022-06-08 13:39:31 -0700
commit: 4a0e5078be694a893ec173ea7b40bf107eb1cd14 (patch)
tree: 89f82ede35393f0c55460bbd5da3ca3991965f8f /numpy/core
parent: 11cc8a2476df22b9998bc90b56048c81b943cca0 (diff)
parent: ad9a03084919c5be168327180bc6eb9ae186c2dc (diff)
download: numpy-4a0e5078be694a893ec173ea7b40bf107eb1cd14.tar.gz
Merge pull request #21687 from rafaelcfsousa/bug_comparison
BUG: switch _CMP_NEQ_OQ to _CMP_NEQ_UQ for npyv_cmpneq_f[32,64]
Diffstat (limited to 'numpy/core')
-rw-r--r-- numpy/core/src/common/simd/avx2/operators.h | 4
-rw-r--r-- numpy/core/src/common/simd/avx512/operators.h | 4
-rw-r--r-- numpy/core/tests/test_simd.py | 28
3 files changed, 32 insertions, 4 deletions
diff --git a/numpy/core/src/common/simd/avx2/operators.h b/numpy/core/src/common/simd/avx2/operators.h
index 99ef76dcb..7682b24cb 100644
--- a/numpy/core/src/common/simd/avx2/operators.h
+++ b/numpy/core/src/common/simd/avx2/operators.h
@@ -208,8 +208,8 @@ NPY_FINLINE __m256i npyv_cmpge_u32(__m256i a, __m256i b)
// precision comparison
#define npyv_cmpeq_f32(A, B) _mm256_castps_si256(_mm256_cmp_ps(A, B, _CMP_EQ_OQ))
#define npyv_cmpeq_f64(A, B) _mm256_castpd_si256(_mm256_cmp_pd(A, B, _CMP_EQ_OQ))
-#define npyv_cmpneq_f32(A, B) _mm256_castps_si256(_mm256_cmp_ps(A, B, _CMP_NEQ_OQ))
-#define npyv_cmpneq_f64(A, B) _mm256_castpd_si256(_mm256_cmp_pd(A, B, _CMP_NEQ_OQ))
+#define npyv_cmpneq_f32(A, B) _mm256_castps_si256(_mm256_cmp_ps(A, B, _CMP_NEQ_UQ))
+#define npyv_cmpneq_f64(A, B) _mm256_castpd_si256(_mm256_cmp_pd(A, B, _CMP_NEQ_UQ))
#define npyv_cmplt_f32(A, B) _mm256_castps_si256(_mm256_cmp_ps(A, B, _CMP_LT_OQ))
#define npyv_cmplt_f64(A, B) _mm256_castpd_si256(_mm256_cmp_pd(A, B, _CMP_LT_OQ))
#define npyv_cmple_f32(A, B) _mm256_castps_si256(_mm256_cmp_ps(A, B, _CMP_LE_OQ))
diff --git a/numpy/core/src/common/simd/avx512/operators.h b/numpy/core/src/common/simd/avx512/operators.h
index b856b345a..804cd24e8 100644
--- a/numpy/core/src/common/simd/avx512/operators.h
+++ b/numpy/core/src/common/simd/avx512/operators.h
@@ -319,8 +319,8 @@
// precision comparison
#define npyv_cmpeq_f32(A, B) _mm512_cmp_ps_mask(A, B, _CMP_EQ_OQ)
#define npyv_cmpeq_f64(A, B) _mm512_cmp_pd_mask(A, B, _CMP_EQ_OQ)
-#define npyv_cmpneq_f32(A, B) _mm512_cmp_ps_mask(A, B, _CMP_NEQ_OQ)
-#define npyv_cmpneq_f64(A, B) _mm512_cmp_pd_mask(A, B, _CMP_NEQ_OQ)
+#define npyv_cmpneq_f32(A, B) _mm512_cmp_ps_mask(A, B, _CMP_NEQ_UQ)
+#define npyv_cmpneq_f64(A, B) _mm512_cmp_pd_mask(A, B, _CMP_NEQ_UQ)
#define npyv_cmplt_f32(A, B) _mm512_cmp_ps_mask(A, B, _CMP_LT_OQ)
#define npyv_cmplt_f64(A, B) _mm512_cmp_pd_mask(A, B, _CMP_LT_OQ)
#define npyv_cmple_f32(A, B) _mm512_cmp_ps_mask(A, B, _CMP_LE_OQ)
diff --git a/numpy/core/tests/test_simd.py b/numpy/core/tests/test_simd.py
index f33db95fc..324948cf2 100644
--- a/numpy/core/tests/test_simd.py
+++ b/numpy/core/tests/test_simd.py
@@ -501,6 +501,34 @@ class _SIMD_FP(_Test_Utility):
nnan = self.notnan(self.setall(self._nan()))
assert nnan == [0]*self.nlanes
+ import operator
+
+ @pytest.mark.parametrize('py_comp,np_comp', [
+ (operator.lt, "cmplt"),
+ (operator.le, "cmple"),
+ (operator.gt, "cmpgt"),
+ (operator.ge, "cmpge"),
+ (operator.eq, "cmpeq"),
+ (operator.ne, "cmpneq")
+ ])
+ def test_comparison_with_nan(self, py_comp, np_comp):
+ pinf, ninf, nan = self._pinfinity(), self._ninfinity(), self._nan()
+ mask_true = self._true_mask()
+
+ def to_bool(vector):
+ return [lane == mask_true for lane in vector]
+
+ intrin = getattr(self, np_comp)
+ cmp_cases = ((0, nan), (nan, 0), (nan, nan), (pinf, nan), (ninf, nan))
+ for case_operand1, case_operand2 in cmp_cases:
+ data_a = [case_operand1]*self.nlanes
+ data_b = [case_operand2]*self.nlanes
+ vdata_a = self.setall(case_operand1)
+ vdata_b = self.setall(case_operand2)
+ vcmp = to_bool(intrin(vdata_a, vdata_b))
+ data_cmp = [py_comp(a, b) for a, b in zip(data_a, data_b)]
+ assert vcmp == data_cmp
+
class _SIMD_ALL(_Test_Utility):
"""
To test all vector types at once