diff options
author | Raghuveer Devulapalli <raghuveer.devulapalli@intel.com> | 2022-10-12 14:02:08 -0700 |
---|---|---|
committer | Raghuveer Devulapalli <raghuveer.devulapalli@intel.com> | 2023-01-30 13:38:39 -0800 |
commit | e91610af8ed4b9ba200086c7edea2f9a1a4ca280 (patch) | |
tree | c298f96a8408208b6e01d48aff6e81b78151606e /numpy | |
parent | c71352232164ab7ddc4142ebc1db694493b34ff9 (diff) | |
download | numpy-e91610af8ed4b9ba200086c7edea2f9a1a4ca280.tar.gz |
MAINT: Use loadu intrinsic instead of set1_epi16
gcc-8 is missing the _mm512_set1_epi16 intrinsic
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/npysort/x86-simd-sort/src/avx512-16bit-qsort.hpp | 170 |
1 files changed, 74 insertions, 96 deletions
diff --git a/numpy/core/src/npysort/x86-simd-sort/src/avx512-16bit-qsort.hpp b/numpy/core/src/npysort/x86-simd-sort/src/avx512-16bit-qsort.hpp index 51cb4dbb0..5fcb8902d 100644 --- a/numpy/core/src/npysort/x86-simd-sort/src/avx512-16bit-qsort.hpp +++ b/numpy/core/src/npysort/x86-simd-sort/src/avx512-16bit-qsort.hpp @@ -15,24 +15,20 @@ * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg) */ // ZMM register: 31,30,29,28,27,26,25,24,23,22,21,20,19,18,17,16,15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0 -#define NETWORK_16BIT_1 \ - 24, 25, 26, 27, 28, 29, 30, 31, 16, 17, 18, 19, 20, 21, 22, 23, 8, 9, 10, \ - 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7 -#define NETWORK_16BIT_2 \ - 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, \ - 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 -#define NETWORK_16BIT_3 \ - 27, 26, 25, 24, 31, 30, 29, 28, 19, 18, 17, 16, 23, 22, 21, 20, 11, 10, 9, \ - 8, 15, 14, 13, 12, 3, 2, 1, 0, 7, 6, 5, 4 -#define NETWORK_16BIT_4 \ - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, \ - 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31 -#define NETWORK_16BIT_5 \ - 23, 22, 21, 20, 19, 18, 17, 16, 31, 30, 29, 28, 27, 26, 25, 24, 7, 6, 5, \ - 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8 -#define NETWORK_16BIT_6 \ - 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 31, 30, 29, 28, 27, \ - 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16 +static const uint16_t network[6][32] + = {{7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8, + 23, 22, 21, 20, 19, 18, 17, 16, 31, 30, 29, 28, 27, 26, 25, 24}, + {15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, + 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16}, + {4, 5, 6, 7, 0, 1, 2, 3, 12, 13, 14, 15, 8, 9, 10, 11, + 20, 21, 22, 23, 16, 17, 18, 19, 28, 29, 30, 31, 24, 25, 26, 27}, + {31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, + 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}, + {8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, + 24, 25, 26, 27, 28, 29, 30, 31, 16, 17, 18, 19, 20, 21, 22, 23}, + {16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}; + template <> struct vector<int16_t> { @@ -42,6 +38,10 @@ struct vector<int16_t> { using opmask_t = __mmask32; static const uint8_t numlanes = 32; + static zmm_t get_network(int index) + { + return _mm512_loadu_si512(&network[index-1][0]); + } static type_t type_max() { return X86_SIMD_SORT_MAX_INT16; @@ -54,20 +54,15 @@ struct vector<int16_t> { { return _mm512_set1_epi16(type_max()); } - static opmask_t knot_opmask(opmask_t x) { return npyv_not_b16(x); } + static opmask_t ge(zmm_t x, zmm_t y) { return _mm512_cmp_epi16_mask(x, y, _MM_CMPINT_NLT); } - //template <int scale> - //static zmm_t i64gather(__m512i index, void const *base) - //{ - // return _mm512_i64gather_epi64(index, base, scale); - //} static zmm_t loadu(void const *mem) { return _mm512_loadu_si512(mem); @@ -141,6 +136,10 @@ struct vector<uint16_t> { using opmask_t = __mmask32; static const uint8_t numlanes = 32; + static zmm_t get_network(int index) + { + return _mm512_loadu_si512(&network[index-1][0]); + } static type_t type_max() { return X86_SIMD_SORT_MAX_UINT16; @@ -152,13 +151,8 @@ struct vector<uint16_t> { static zmm_t zmm_max() { return _mm512_set1_epi16(type_max()); - } // TODO: this should broadcast bits as is? + } - //template<int scale> - //static zmm_t i64gather(__m512i index, void const *base) - //{ - // return _mm512_i64gather_epi64(index, base, scale); - //} static opmask_t knot_opmask(opmask_t x) { return npyv_not_b16(x); @@ -254,9 +248,7 @@ NPY_FINLINE zmm_t sort_zmm_16bit(zmm_t zmm) 0xAAAAAAAA); // Level 3 zmm = cmp_merge<vtype>( - zmm, - vtype::permutexvar(_mm512_set_epi16(NETWORK_16BIT_1), zmm), - 0xF0F0F0F0); + zmm, vtype::permutexvar(vtype::get_network(1), zmm), 0xF0F0F0F0); zmm = cmp_merge<vtype>( zmm, vtype::template shuffle<SHUFFLE_MASK(1, 0, 3, 2)>(zmm), @@ -267,13 +259,9 @@ NPY_FINLINE zmm_t sort_zmm_16bit(zmm_t zmm) 0xAAAAAAAA); // Level 4 zmm = cmp_merge<vtype>( - zmm, - vtype::permutexvar(_mm512_set_epi16(NETWORK_16BIT_2), zmm), - 0xFF00FF00); + zmm, vtype::permutexvar(vtype::get_network(2), zmm), 0xFF00FF00); zmm = cmp_merge<vtype>( - zmm, - vtype::permutexvar(_mm512_set_epi16(NETWORK_16BIT_3), zmm), - 0xF0F0F0F0); + zmm, vtype::permutexvar(vtype::get_network(3), zmm), 0xF0F0F0F0); zmm = cmp_merge<vtype>( zmm, vtype::template shuffle<SHUFFLE_MASK(1, 0, 3, 2)>(zmm), @@ -284,17 +272,11 @@ NPY_FINLINE zmm_t sort_zmm_16bit(zmm_t zmm) 0xAAAAAAAA); // Level 5 zmm = cmp_merge<vtype>( - zmm, - vtype::permutexvar(_mm512_set_epi16(NETWORK_16BIT_4), zmm), - 0xFFFF0000); + zmm, vtype::permutexvar(vtype::get_network(4), zmm), 0xFFFF0000); zmm = cmp_merge<vtype>( - zmm, - vtype::permutexvar(_mm512_set_epi16(NETWORK_16BIT_5), zmm), - 0xFF00FF00); + zmm, vtype::permutexvar(vtype::get_network(5), zmm), 0xFF00FF00); zmm = cmp_merge<vtype>( - zmm, - vtype::permutexvar(_mm512_set_epi16(NETWORK_16BIT_3), zmm), - 0xF0F0F0F0); + zmm, vtype::permutexvar(vtype::get_network(3), zmm), 0xF0F0F0F0); zmm = cmp_merge<vtype>( zmm, vtype::template shuffle<SHUFFLE_MASK(1, 0, 3, 2)>(zmm), @@ -312,19 +294,13 @@ NPY_FINLINE zmm_t bitonic_merge_zmm_16bit(zmm_t zmm) { // 1) half_cleaner[32]: compare 1-17, 2-18, 3-19 etc .. zmm = cmp_merge<vtype>( - zmm, - vtype::permutexvar(_mm512_set_epi16(NETWORK_16BIT_6), zmm), - 0xFFFF0000); + zmm, vtype::permutexvar(vtype::get_network(6), zmm), 0xFFFF0000); // 2) half_cleaner[16]: compare 1-9, 2-10, 3-11 etc .. zmm = cmp_merge<vtype>( - zmm, - vtype::permutexvar(_mm512_set_epi16(NETWORK_16BIT_5), zmm), - 0xFF00FF00); + zmm, vtype::permutexvar(vtype::get_network(5), zmm), 0xFF00FF00); // 3) half_cleaner[8] zmm = cmp_merge<vtype>( - zmm, - vtype::permutexvar(_mm512_set_epi16(NETWORK_16BIT_3), zmm), - 0xF0F0F0F0); + zmm, vtype::permutexvar(vtype::get_network(3), zmm), 0xF0F0F0F0); // 3) half_cleaner[4] zmm = cmp_merge<vtype>( zmm, @@ -343,7 +319,7 @@ template <typename vtype, typename zmm_t = typename vtype::zmm_t> NPY_FINLINE void bitonic_merge_two_zmm_16bit(zmm_t &zmm1, zmm_t &zmm2) { // 1) First step of a merging network: coex of zmm1 and zmm2 reversed - zmm2 = vtype::permutexvar(_mm512_set_epi16(NETWORK_16BIT_4), zmm2); + zmm2 = vtype::permutexvar(vtype::get_network(4), zmm2); zmm_t zmm3 = vtype::min(zmm1, zmm2); zmm_t zmm4 = vtype::max(zmm1, zmm2); // 2) Recursive half cleaner for each @@ -356,13 +332,13 @@ NPY_FINLINE void bitonic_merge_two_zmm_16bit(zmm_t &zmm1, zmm_t &zmm2) template <typename vtype, typename zmm_t = typename vtype::zmm_t> NPY_FINLINE void bitonic_merge_four_zmm_16bit(zmm_t *zmm) { - zmm_t zmm2r = vtype::permutexvar(_mm512_set_epi16(NETWORK_16BIT_4), zmm[2]); - zmm_t zmm3r = vtype::permutexvar(_mm512_set_epi16(NETWORK_16BIT_4), zmm[3]); + zmm_t zmm2r = vtype::permutexvar(vtype::get_network(4), zmm[2]); + zmm_t zmm3r = vtype::permutexvar(vtype::get_network(4), zmm[3]); zmm_t zmm_t1 = vtype::min(zmm[0], zmm3r); zmm_t zmm_t2 = vtype::min(zmm[1], zmm2r); - zmm_t zmm_t3 = vtype::permutexvar(_mm512_set_epi16(NETWORK_16BIT_4), + zmm_t zmm_t3 = vtype::permutexvar(vtype::get_network(4), vtype::max(zmm[1], zmm2r)); - zmm_t zmm_t4 = vtype::permutexvar(_mm512_set_epi16(NETWORK_16BIT_4), + zmm_t zmm_t4 = vtype::permutexvar(vtype::get_network(4), vtype::max(zmm[0], zmm3r)); zmm_t zmm0 = vtype::min(zmm_t1, zmm_t2); zmm_t zmm1 = vtype::max(zmm_t1, zmm_t2); @@ -436,43 +412,45 @@ NPY_FINLINE void sort_128_16bit(type_t *arr, int32_t N) } template <typename vtype, typename type_t> -NPY_FINLINE type_t -get_pivot_16bit(type_t *arr, const int64_t left, const int64_t right) +NPY_FINLINE type_t get_pivot_16bit(type_t *arr, + const int64_t left, + const int64_t right) { // median of 32 int64_t size = (right - left) / 32; - __m512i rand_vec = _mm512_set_epi16(arr[left], - arr[left + size], - arr[left + 2 * size], - arr[left + 3 * size], - arr[left + 4 * size], - arr[left + 5 * size], - arr[left + 6 * size], - arr[left + 7 * size], - arr[left + 8 * size], - arr[left + 9 * size], - arr[left + 10 * size], - arr[left + 11 * size], - arr[left + 12 * size], - arr[left + 13 * size], - arr[left + 14 * size], - arr[left + 15 * size], - arr[left + 16 * size], - arr[left + 17 * size], - arr[left + 18 * size], - arr[left + 19 * size], - arr[left + 20 * size], - arr[left + 21 * size], - arr[left + 22 * size], - arr[left + 23 * size], - arr[left + 24 * size], - arr[left + 25 * size], - arr[left + 26 * size], - arr[left + 27 * size], - arr[left + 28 * size], - arr[left + 29 * size], - arr[left + 30 * size], - arr[left + 31 * size]); + type_t vec_arr[32] = {arr[left], + arr[left + size], + arr[left + 2 * size], + arr[left + 3 * size], + arr[left + 4 * size], + arr[left + 5 * size], + arr[left + 6 * size], + arr[left + 7 * size], + arr[left + 8 * size], + arr[left + 9 * size], + arr[left + 10 * size], + arr[left + 11 * size], + arr[left + 12 * size], + arr[left + 13 * size], + arr[left + 14 * size], + arr[left + 15 * size], + arr[left + 16 * size], + arr[left + 17 * size], + arr[left + 18 * size], + arr[left + 19 * size], + arr[left + 20 * size], + arr[left + 21 * size], + arr[left + 22 * size], + arr[left + 23 * size], + arr[left + 24 * size], + arr[left + 25 * size], + arr[left + 26 * size], + arr[left + 27 * size], + arr[left + 28 * size], + arr[left + 29 * size], + arr[left + 30 * size], + arr[left + 31 * size]}; + __m512i rand_vec = _mm512_loadu_si512(vec_arr); __m512i sort = sort_zmm_16bit<vtype>(rand_vec); return ((type_t *)&sort)[16]; } |