summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorRaghuveer Devulapalli <raghuveer.devulapalli@intel.com>2022-10-12 14:02:08 -0700
committerRaghuveer Devulapalli <raghuveer.devulapalli@intel.com>2023-01-30 13:38:39 -0800
commite91610af8ed4b9ba200086c7edea2f9a1a4ca280 (patch)
treec298f96a8408208b6e01d48aff6e81b78151606e /numpy
parentc71352232164ab7ddc4142ebc1db694493b34ff9 (diff)
downloadnumpy-e91610af8ed4b9ba200086c7edea2f9a1a4ca280.tar.gz
MAINT: Use loadu intrinsic instead of set1_epi16
gcc-8 is missing the _mm512_set1_epi16 intrinsic
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/npysort/x86-simd-sort/src/avx512-16bit-qsort.hpp170
1 files changed, 74 insertions, 96 deletions
diff --git a/numpy/core/src/npysort/x86-simd-sort/src/avx512-16bit-qsort.hpp b/numpy/core/src/npysort/x86-simd-sort/src/avx512-16bit-qsort.hpp
index 51cb4dbb0..5fcb8902d 100644
--- a/numpy/core/src/npysort/x86-simd-sort/src/avx512-16bit-qsort.hpp
+++ b/numpy/core/src/npysort/x86-simd-sort/src/avx512-16bit-qsort.hpp
@@ -15,24 +15,20 @@
* https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg)
*/
// ZMM register: 31,30,29,28,27,26,25,24,23,22,21,20,19,18,17,16,15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0
-#define NETWORK_16BIT_1 \
- 24, 25, 26, 27, 28, 29, 30, 31, 16, 17, 18, 19, 20, 21, 22, 23, 8, 9, 10, \
- 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7
-#define NETWORK_16BIT_2 \
- 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, \
- 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
-#define NETWORK_16BIT_3 \
- 27, 26, 25, 24, 31, 30, 29, 28, 19, 18, 17, 16, 23, 22, 21, 20, 11, 10, 9, \
- 8, 15, 14, 13, 12, 3, 2, 1, 0, 7, 6, 5, 4
-#define NETWORK_16BIT_4 \
- 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, \
- 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31
-#define NETWORK_16BIT_5 \
- 23, 22, 21, 20, 19, 18, 17, 16, 31, 30, 29, 28, 27, 26, 25, 24, 7, 6, 5, \
- 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8
-#define NETWORK_16BIT_6 \
- 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 31, 30, 29, 28, 27, \
- 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16
+static const uint16_t network[6][32]
+ = {{7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8,
+ 23, 22, 21, 20, 19, 18, 17, 16, 31, 30, 29, 28, 27, 26, 25, 24},
+ {15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
+ 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16},
+ {4, 5, 6, 7, 0, 1, 2, 3, 12, 13, 14, 15, 8, 9, 10, 11,
+ 20, 21, 22, 23, 16, 17, 18, 19, 28, 29, 30, 31, 24, 25, 26, 27},
+ {31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16,
+ 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0},
+ {8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7,
+ 24, 25, 26, 27, 28, 29, 30, 31, 16, 17, 18, 19, 20, 21, 22, 23},
+ {16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}};
+
template <>
struct vector<int16_t> {
@@ -42,6 +38,10 @@ struct vector<int16_t> {
using opmask_t = __mmask32;
static const uint8_t numlanes = 32;
+ static zmm_t get_network(int index)
+ {
+ return _mm512_loadu_si512(&network[index-1][0]);
+ }
static type_t type_max()
{
return X86_SIMD_SORT_MAX_INT16;
@@ -54,20 +54,15 @@ struct vector<int16_t> {
{
return _mm512_set1_epi16(type_max());
}
-
static opmask_t knot_opmask(opmask_t x)
{
return npyv_not_b16(x);
}
+
static opmask_t ge(zmm_t x, zmm_t y)
{
return _mm512_cmp_epi16_mask(x, y, _MM_CMPINT_NLT);
}
- //template <int scale>
- //static zmm_t i64gather(__m512i index, void const *base)
- //{
- // return _mm512_i64gather_epi64(index, base, scale);
- //}
static zmm_t loadu(void const *mem)
{
return _mm512_loadu_si512(mem);
@@ -141,6 +136,10 @@ struct vector<uint16_t> {
using opmask_t = __mmask32;
static const uint8_t numlanes = 32;
+ static zmm_t get_network(int index)
+ {
+ return _mm512_loadu_si512(&network[index-1][0]);
+ }
static type_t type_max()
{
return X86_SIMD_SORT_MAX_UINT16;
@@ -152,13 +151,8 @@ struct vector<uint16_t> {
static zmm_t zmm_max()
{
return _mm512_set1_epi16(type_max());
- } // TODO: this should broadcast bits as is?
+ }
- //template<int scale>
- //static zmm_t i64gather(__m512i index, void const *base)
- //{
- // return _mm512_i64gather_epi64(index, base, scale);
- //}
static opmask_t knot_opmask(opmask_t x)
{
return npyv_not_b16(x);
@@ -254,9 +248,7 @@ NPY_FINLINE zmm_t sort_zmm_16bit(zmm_t zmm)
0xAAAAAAAA);
// Level 3
zmm = cmp_merge<vtype>(
- zmm,
- vtype::permutexvar(_mm512_set_epi16(NETWORK_16BIT_1), zmm),
- 0xF0F0F0F0);
+ zmm, vtype::permutexvar(vtype::get_network(1), zmm), 0xF0F0F0F0);
zmm = cmp_merge<vtype>(
zmm,
vtype::template shuffle<SHUFFLE_MASK(1, 0, 3, 2)>(zmm),
@@ -267,13 +259,9 @@ NPY_FINLINE zmm_t sort_zmm_16bit(zmm_t zmm)
0xAAAAAAAA);
// Level 4
zmm = cmp_merge<vtype>(
- zmm,
- vtype::permutexvar(_mm512_set_epi16(NETWORK_16BIT_2), zmm),
- 0xFF00FF00);
+ zmm, vtype::permutexvar(vtype::get_network(2), zmm), 0xFF00FF00);
zmm = cmp_merge<vtype>(
- zmm,
- vtype::permutexvar(_mm512_set_epi16(NETWORK_16BIT_3), zmm),
- 0xF0F0F0F0);
+ zmm, vtype::permutexvar(vtype::get_network(3), zmm), 0xF0F0F0F0);
zmm = cmp_merge<vtype>(
zmm,
vtype::template shuffle<SHUFFLE_MASK(1, 0, 3, 2)>(zmm),
@@ -284,17 +272,11 @@ NPY_FINLINE zmm_t sort_zmm_16bit(zmm_t zmm)
0xAAAAAAAA);
// Level 5
zmm = cmp_merge<vtype>(
- zmm,
- vtype::permutexvar(_mm512_set_epi16(NETWORK_16BIT_4), zmm),
- 0xFFFF0000);
+ zmm, vtype::permutexvar(vtype::get_network(4), zmm), 0xFFFF0000);
zmm = cmp_merge<vtype>(
- zmm,
- vtype::permutexvar(_mm512_set_epi16(NETWORK_16BIT_5), zmm),
- 0xFF00FF00);
+ zmm, vtype::permutexvar(vtype::get_network(5), zmm), 0xFF00FF00);
zmm = cmp_merge<vtype>(
- zmm,
- vtype::permutexvar(_mm512_set_epi16(NETWORK_16BIT_3), zmm),
- 0xF0F0F0F0);
+ zmm, vtype::permutexvar(vtype::get_network(3), zmm), 0xF0F0F0F0);
zmm = cmp_merge<vtype>(
zmm,
vtype::template shuffle<SHUFFLE_MASK(1, 0, 3, 2)>(zmm),
@@ -312,19 +294,13 @@ NPY_FINLINE zmm_t bitonic_merge_zmm_16bit(zmm_t zmm)
{
// 1) half_cleaner[32]: compare 1-17, 2-18, 3-19 etc ..
zmm = cmp_merge<vtype>(
- zmm,
- vtype::permutexvar(_mm512_set_epi16(NETWORK_16BIT_6), zmm),
- 0xFFFF0000);
+ zmm, vtype::permutexvar(vtype::get_network(6), zmm), 0xFFFF0000);
// 2) half_cleaner[16]: compare 1-9, 2-10, 3-11 etc ..
zmm = cmp_merge<vtype>(
- zmm,
- vtype::permutexvar(_mm512_set_epi16(NETWORK_16BIT_5), zmm),
- 0xFF00FF00);
+ zmm, vtype::permutexvar(vtype::get_network(5), zmm), 0xFF00FF00);
// 3) half_cleaner[8]
zmm = cmp_merge<vtype>(
- zmm,
- vtype::permutexvar(_mm512_set_epi16(NETWORK_16BIT_3), zmm),
- 0xF0F0F0F0);
+ zmm, vtype::permutexvar(vtype::get_network(3), zmm), 0xF0F0F0F0);
// 3) half_cleaner[4]
zmm = cmp_merge<vtype>(
zmm,
@@ -343,7 +319,7 @@ template <typename vtype, typename zmm_t = typename vtype::zmm_t>
NPY_FINLINE void bitonic_merge_two_zmm_16bit(zmm_t &zmm1, zmm_t &zmm2)
{
// 1) First step of a merging network: coex of zmm1 and zmm2 reversed
- zmm2 = vtype::permutexvar(_mm512_set_epi16(NETWORK_16BIT_4), zmm2);
+ zmm2 = vtype::permutexvar(vtype::get_network(4), zmm2);
zmm_t zmm3 = vtype::min(zmm1, zmm2);
zmm_t zmm4 = vtype::max(zmm1, zmm2);
// 2) Recursive half cleaner for each
@@ -356,13 +332,13 @@ NPY_FINLINE void bitonic_merge_two_zmm_16bit(zmm_t &zmm1, zmm_t &zmm2)
template <typename vtype, typename zmm_t = typename vtype::zmm_t>
NPY_FINLINE void bitonic_merge_four_zmm_16bit(zmm_t *zmm)
{
- zmm_t zmm2r = vtype::permutexvar(_mm512_set_epi16(NETWORK_16BIT_4), zmm[2]);
- zmm_t zmm3r = vtype::permutexvar(_mm512_set_epi16(NETWORK_16BIT_4), zmm[3]);
+ zmm_t zmm2r = vtype::permutexvar(vtype::get_network(4), zmm[2]);
+ zmm_t zmm3r = vtype::permutexvar(vtype::get_network(4), zmm[3]);
zmm_t zmm_t1 = vtype::min(zmm[0], zmm3r);
zmm_t zmm_t2 = vtype::min(zmm[1], zmm2r);
- zmm_t zmm_t3 = vtype::permutexvar(_mm512_set_epi16(NETWORK_16BIT_4),
+ zmm_t zmm_t3 = vtype::permutexvar(vtype::get_network(4),
vtype::max(zmm[1], zmm2r));
- zmm_t zmm_t4 = vtype::permutexvar(_mm512_set_epi16(NETWORK_16BIT_4),
+ zmm_t zmm_t4 = vtype::permutexvar(vtype::get_network(4),
vtype::max(zmm[0], zmm3r));
zmm_t zmm0 = vtype::min(zmm_t1, zmm_t2);
zmm_t zmm1 = vtype::max(zmm_t1, zmm_t2);
@@ -436,43 +412,45 @@ NPY_FINLINE void sort_128_16bit(type_t *arr, int32_t N)
}
template <typename vtype, typename type_t>
-NPY_FINLINE type_t
-get_pivot_16bit(type_t *arr, const int64_t left, const int64_t right)
+NPY_FINLINE type_t get_pivot_16bit(type_t *arr,
+ const int64_t left,
+ const int64_t right)
{
// median of 32
int64_t size = (right - left) / 32;
- __m512i rand_vec = _mm512_set_epi16(arr[left],
- arr[left + size],
- arr[left + 2 * size],
- arr[left + 3 * size],
- arr[left + 4 * size],
- arr[left + 5 * size],
- arr[left + 6 * size],
- arr[left + 7 * size],
- arr[left + 8 * size],
- arr[left + 9 * size],
- arr[left + 10 * size],
- arr[left + 11 * size],
- arr[left + 12 * size],
- arr[left + 13 * size],
- arr[left + 14 * size],
- arr[left + 15 * size],
- arr[left + 16 * size],
- arr[left + 17 * size],
- arr[left + 18 * size],
- arr[left + 19 * size],
- arr[left + 20 * size],
- arr[left + 21 * size],
- arr[left + 22 * size],
- arr[left + 23 * size],
- arr[left + 24 * size],
- arr[left + 25 * size],
- arr[left + 26 * size],
- arr[left + 27 * size],
- arr[left + 28 * size],
- arr[left + 29 * size],
- arr[left + 30 * size],
- arr[left + 31 * size]);
+ type_t vec_arr[32] = {arr[left],
+ arr[left + size],
+ arr[left + 2 * size],
+ arr[left + 3 * size],
+ arr[left + 4 * size],
+ arr[left + 5 * size],
+ arr[left + 6 * size],
+ arr[left + 7 * size],
+ arr[left + 8 * size],
+ arr[left + 9 * size],
+ arr[left + 10 * size],
+ arr[left + 11 * size],
+ arr[left + 12 * size],
+ arr[left + 13 * size],
+ arr[left + 14 * size],
+ arr[left + 15 * size],
+ arr[left + 16 * size],
+ arr[left + 17 * size],
+ arr[left + 18 * size],
+ arr[left + 19 * size],
+ arr[left + 20 * size],
+ arr[left + 21 * size],
+ arr[left + 22 * size],
+ arr[left + 23 * size],
+ arr[left + 24 * size],
+ arr[left + 25 * size],
+ arr[left + 26 * size],
+ arr[left + 27 * size],
+ arr[left + 28 * size],
+ arr[left + 29 * size],
+ arr[left + 30 * size],
+ arr[left + 31 * size]};
+ __m512i rand_vec = _mm512_loadu_si512(vec_arr);
__m512i sort = sort_zmm_16bit<vtype>(rand_vec);
return ((type_t *)&sort)[16];
}