summaryrefslogtreecommitdiff
path: root/numpy/core/src/common/simd/avx512/arithmetic.h
blob: a63da87d0c4a53ee02753ff5ed27210c771228c8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
#ifndef NPY_SIMD
    #error "Not a standalone header"
#endif

#ifndef _NPY_SIMD_AVX512_ARITHMETIC_H
#define _NPY_SIMD_AVX512_ARITHMETIC_H

#include "../avx2/utils.h"
#include "../sse/utils.h"
/***************************
 * Addition
 ***************************/
// wrap-around (non-saturated) addition.
// 8/16-bit lanes need AVX512BW; without it they are emulated through the
// given AVX2 intrinsic via NPYV_IMPL_AVX512_FROM_AVX2_2ARG.
#ifndef NPY_HAVE_AVX512BW
    NPYV_IMPL_AVX512_FROM_AVX2_2ARG(npyv_add_u8,  _mm256_add_epi8)
    NPYV_IMPL_AVX512_FROM_AVX2_2ARG(npyv_add_u16, _mm256_add_epi16)
#else
    #define npyv_add_u8  _mm512_add_epi8
    #define npyv_add_u16 _mm512_add_epi16
#endif
// two's-complement wrap-around addition is sign-agnostic
#define npyv_add_s8  npyv_add_u8
#define npyv_add_s16 npyv_add_u16
#define npyv_add_u32 _mm512_add_epi32
#define npyv_add_s32 _mm512_add_epi32
#define npyv_add_u64 _mm512_add_epi64
#define npyv_add_s64 _mm512_add_epi64
#define npyv_add_f32 _mm512_add_ps
#define npyv_add_f64 _mm512_add_pd

// saturating addition (results clamp to the lane type's range);
// only 8/16-bit lanes are provided
#ifndef NPY_HAVE_AVX512BW
    NPYV_IMPL_AVX512_FROM_AVX2_2ARG(npyv_adds_u8,  _mm256_adds_epu8)
    NPYV_IMPL_AVX512_FROM_AVX2_2ARG(npyv_adds_s8,  _mm256_adds_epi8)
    NPYV_IMPL_AVX512_FROM_AVX2_2ARG(npyv_adds_u16, _mm256_adds_epu16)
    NPYV_IMPL_AVX512_FROM_AVX2_2ARG(npyv_adds_s16, _mm256_adds_epi16)
#else
    #define npyv_adds_u8  _mm512_adds_epu8
    #define npyv_adds_s8  _mm512_adds_epi8
    #define npyv_adds_u16 _mm512_adds_epu16
    #define npyv_adds_s16 _mm512_adds_epi16
#endif
// TODO: the remaining saturated widths, once the Packs intrinsics are implemented

/***************************
 * Subtraction
 ***************************/
// wrap-around (non-saturated) subtraction.
// 8/16-bit lanes need AVX512BW; without it they are emulated through the
// given AVX2 intrinsic via NPYV_IMPL_AVX512_FROM_AVX2_2ARG.
#ifndef NPY_HAVE_AVX512BW
    NPYV_IMPL_AVX512_FROM_AVX2_2ARG(npyv_sub_u8,  _mm256_sub_epi8)
    NPYV_IMPL_AVX512_FROM_AVX2_2ARG(npyv_sub_u16, _mm256_sub_epi16)
#else
    #define npyv_sub_u8  _mm512_sub_epi8
    #define npyv_sub_u16 _mm512_sub_epi16
#endif
// two's-complement wrap-around subtraction is sign-agnostic
#define npyv_sub_s8  npyv_sub_u8
#define npyv_sub_s16 npyv_sub_u16
#define npyv_sub_u32 _mm512_sub_epi32
#define npyv_sub_s32 _mm512_sub_epi32
#define npyv_sub_u64 _mm512_sub_epi64
#define npyv_sub_s64 _mm512_sub_epi64
#define npyv_sub_f32 _mm512_sub_ps
#define npyv_sub_f64 _mm512_sub_pd

// saturating subtraction (results clamp to the lane type's range);
// only 8/16-bit lanes are provided
#ifndef NPY_HAVE_AVX512BW
    NPYV_IMPL_AVX512_FROM_AVX2_2ARG(npyv_subs_u8,  _mm256_subs_epu8)
    NPYV_IMPL_AVX512_FROM_AVX2_2ARG(npyv_subs_s8,  _mm256_subs_epi8)
    NPYV_IMPL_AVX512_FROM_AVX2_2ARG(npyv_subs_u16, _mm256_subs_epu16)
    NPYV_IMPL_AVX512_FROM_AVX2_2ARG(npyv_subs_s16, _mm256_subs_epi16)
#else
    #define npyv_subs_u8  _mm512_subs_epu8
    #define npyv_subs_s8  _mm512_subs_epi8
    #define npyv_subs_u16 _mm512_subs_epu16
    #define npyv_subs_s16 _mm512_subs_epi16
#endif
// TODO: the remaining saturated widths, once the Packs intrinsics are implemented

/***************************
 * Multiplication
 ***************************/
// non-saturated
#ifdef NPY_HAVE_AVX512BW
// There is no 8-bit multiply instruction, so bytes are multiplied as 16-bit
// lanes: the even bytes in place (the low byte of each 16-bit product is
// already the wanted result), the odd bytes after shifting them down, and
// the two partial results are merged byte-wise. Only the low 8 bits of each
// product are kept, so the sign-extension done by srai does not matter.
NPY_FINLINE __m512i npyv_mul_u8(__m512i a, __m512i b)
{
    const __m512i lo_prod = _mm512_mullo_epi16(a, b);
    const __m512i hi_prod = _mm512_slli_epi16(
        _mm512_mullo_epi16(_mm512_srai_epi16(a, 8), _mm512_srai_epi16(b, 8)), 8
    );
    // mask bit set (odd byte position) -> hi_prod, otherwise lo_prod
    return _mm512_mask_blend_epi8(0xAAAAAAAAAAAAAAAA, lo_prod, hi_prod);
}
    #define npyv_mul_u16 _mm512_mullo_epi16
#else
    NPYV_IMPL_AVX512_FROM_AVX2_2ARG(npyv_mul_u8,  npyv256_mul_u8)
    NPYV_IMPL_AVX512_FROM_AVX2_2ARG(npyv_mul_u16, _mm256_mullo_epi16)
#endif
// the low bits of a wrap-around product are the same for signed and unsigned lanes
#define npyv_mul_s8  npyv_mul_u8
#define npyv_mul_s16 npyv_mul_u16
#define npyv_mul_u32 _mm512_mullo_epi32
#define npyv_mul_s32 _mm512_mullo_epi32
#define npyv_mul_f32 _mm512_mul_ps
#define npyv_mul_f64 _mm512_mul_pd

// saturated
// TODO: after the Packs intrinsics are implemented

/***************************
 * Integer Division
 ***************************/
// See simd/intdiv.h for more clarification
// divide each unsigned 8-bit element by divisor
//
// floor(a/d) = (mulhi + ((a - mulhi) >> sh1)) >> sh2
// There is no 8-bit multiply instruction, so mulhi (the high byte of the
// 8x8-bit product) is built from 16-bit multiplies of the even and odd bytes,
// which are then merged byte-wise. The byte shifts are likewise done as
// 16-bit shifts followed by a per-byte mask. sh1/sh2 are read from the low
// bits of divisor.val[1] / divisor.val[2].
// NOTE(review): assumes divisor.val[0] holds the 8-bit multiplier
// zero-extended into every 16-bit lane, as prepared by simd/intdiv.h — confirm.
NPY_FINLINE npyv_u8 npyv_divc_u8(npyv_u8 a, const npyv_u8x3 divisor)
{
    const __m128i shf1  = _mm512_castsi512_si128(divisor.val[1]);
    const __m128i shf2  = _mm512_castsi512_si128(divisor.val[2]);
#ifdef NPY_HAVE_AVX512BW
    // selects the even (low) byte of every 16-bit lane
    const __m512i bmask = _mm512_set1_epi32(0x00FF00FF);
    // per-byte masks that clear the bits a 16-bit right shift by sh1/sh2
    // drags into each even byte from its odd neighbour
    const __m512i shf1b = _mm512_set1_epi8(0xFFU >> _mm_cvtsi128_si32(shf1));
    const __m512i shf2b = _mm512_set1_epi8(0xFFU >> _mm_cvtsi128_si32(shf2));
    // high part of unsigned multiplication
    // even bytes: isolate, multiply, then move the high product byte down
    __m512i mulhi_even  = _mm512_mullo_epi16(_mm512_and_si512(a, bmask), divisor.val[0]);
            mulhi_even  = _mm512_srli_epi16(mulhi_even, 8);
    // odd bytes: once shifted down, the high product byte lands in the odd position
    __m512i mulhi_odd   = _mm512_mullo_epi16(_mm512_srli_epi16(a, 8), divisor.val[0]);
    // odd byte positions come from mulhi_odd, even ones from mulhi_even
    __m512i mulhi       = _mm512_mask_mov_epi8(mulhi_even, 0xAAAAAAAAAAAAAAAA, mulhi_odd);
    // floor(a/d)       = (mulhi + ((a-mulhi) >> sh1)) >> sh2
    __m512i q           = _mm512_sub_epi8(a, mulhi);
            q           = _mm512_and_si512(_mm512_srl_epi16(q, shf1), shf1b);
            q           = _mm512_add_epi8(mulhi, q);
            q           = _mm512_and_si512(_mm512_srl_epi16(q, shf2), shf2b);
    return  q;
#else
    // AVX2 fallback: the same algorithm applied to each 256-bit half
    const __m256i bmask = _mm256_set1_epi32(0x00FF00FF);
    const __m256i shf1b = _mm256_set1_epi8(0xFFU >> _mm_cvtsi128_si32(shf1));
    const __m256i shf2b = _mm256_set1_epi8(0xFFU >> _mm_cvtsi128_si32(shf2));
    // the sh2 clean-up mask is applied once, to the combined 512-bit result
    const __m512i shf2bw= npyv512_combine_si256(shf2b, shf2b);
    const __m256i mulc  = npyv512_lower_si256(divisor.val[0]);
    //// lower 256-bit
    __m256i lo_a        = npyv512_lower_si256(a);
    // high part of unsigned multiplication
    __m256i mulhi_even  = _mm256_mullo_epi16(_mm256_and_si256(lo_a, bmask), mulc);
            mulhi_even  = _mm256_srli_epi16(mulhi_even, 8);
    __m256i mulhi_odd   = _mm256_mullo_epi16(_mm256_srli_epi16(lo_a, 8), mulc);
    // bmask's 0xFF bytes (even positions) pick mulhi_even, the rest mulhi_odd
    __m256i mulhi       = _mm256_blendv_epi8(mulhi_odd, mulhi_even, bmask);
    // floor(a/d)       = (mulhi + ((a-mulhi) >> sh1)) >> sh2
    __m256i lo_q        = _mm256_sub_epi8(lo_a, mulhi);
            lo_q        = _mm256_and_si256(_mm256_srl_epi16(lo_q, shf1), shf1b);
            lo_q        = _mm256_add_epi8(mulhi, lo_q);
            lo_q        = _mm256_srl_epi16(lo_q, shf2); // bits leaked across bytes are cleared after combining

    //// higher 256-bit
    __m256i hi_a        = npyv512_higher_si256(a);
    // high part of unsigned multiplication
            mulhi_even  = _mm256_mullo_epi16(_mm256_and_si256(hi_a, bmask), mulc);
            mulhi_even  = _mm256_srli_epi16(mulhi_even, 8);
            mulhi_odd   = _mm256_mullo_epi16(_mm256_srli_epi16(hi_a, 8), mulc);
            mulhi       = _mm256_blendv_epi8(mulhi_odd, mulhi_even, bmask);
    // floor(a/d)       = (mulhi + ((a-mulhi) >> sh1)) >> sh2
    __m256i hi_q        = _mm256_sub_epi8(hi_a, mulhi);
            hi_q        = _mm256_and_si256(_mm256_srl_epi16(hi_q, shf1), shf1b);
            hi_q        = _mm256_add_epi8(mulhi, hi_q);
            hi_q        = _mm256_srl_epi16(hi_q, shf2); // bits leaked across bytes are cleared after combining
    return _mm512_and_si512(npyv512_combine_si256(lo_q, hi_q), shf2bw); // clear the bits the 16-bit sh2 shift moved across byte boundaries
#endif
}
// divide each signed 8-bit element by divisor (round towards zero)
// Each byte is sign-extended to 16 bits, divided with npyv_divc_s16, and the
// low byte of each quotient is put back at its original position.
NPY_FINLINE npyv_s16 npyv_divc_s16(npyv_s16 a, const npyv_s16x3 divisor);
NPY_FINLINE npyv_s8 npyv_divc_s8(npyv_s8 a, const npyv_s8x3 divisor)
{
    // even bytes: shift up, then arithmetic shift down to sign-extend
    const __m512i even_s16 = npyv_shri_s16(npyv_shli_s16(a, 8), 8);
    // odd bytes: an arithmetic shift down sign-extends them directly
    const __m512i odd_s16  = npyv_shri_s16(a, 8);
    const __m512i q_even   = npyv_divc_s16(even_s16, divisor);
    // move the odd quotients back up into the odd byte positions
    const __m512i q_odd    = npyv_shli_s16(npyv_divc_s16(odd_s16, divisor), 8);
#ifdef NPY_HAVE_AVX512BW
    // odd byte positions from q_odd, even ones from q_even
    return _mm512_mask_mov_epi8(q_even, 0xAAAAAAAAAAAAAAAA, q_odd);
#else
    // 0xFF bytes (even positions) select q_even, the others q_odd
    return npyv_select_u8(_mm512_set1_epi32(0x00FF00FF), q_even, q_odd);
#endif
}
// divide each unsigned 16-bit element by divisor
// floor(a/d) = (mulhi + ((a - mulhi) >> sh1)) >> sh2, where mulhi is the high
// half of the unsigned product a * divisor.val[0] and the shift counts are
// taken from the low bits of divisor.val[1] / divisor.val[2].
NPY_FINLINE npyv_u16 npyv_divc_u16(npyv_u16 a, const npyv_u16x3 divisor)
{
    const __m128i sh1 = _mm512_castsi512_si128(divisor.val[1]);
    const __m128i sh2 = _mm512_castsi512_si128(divisor.val[2]);
#ifdef NPY_HAVE_AVX512BW
    const __m512i hi = _mm512_mulhi_epu16(a, divisor.val[0]);
    const __m512i t  = _mm512_srl_epi16(_mm512_sub_epi16(a, hi), sh1);
    return _mm512_srl_epi16(_mm512_add_epi16(hi, t), sh2);
#else
    // AVX2 fallback: process each 256-bit half separately
    const __m256i mul   = npyv512_lower_si256(divisor.val[0]);
    const __m256i a_lo  = npyv512_lower_si256(a);
    const __m256i a_hi  = npyv512_higher_si256(a);
    const __m256i hi_lo = _mm256_mulhi_epu16(a_lo, mul);
    const __m256i hi_hi = _mm256_mulhi_epu16(a_hi, mul);
    __m256i q_lo = _mm256_srl_epi16(_mm256_sub_epi16(a_lo, hi_lo), sh1);
    __m256i q_hi = _mm256_srl_epi16(_mm256_sub_epi16(a_hi, hi_hi), sh1);
    q_lo = _mm256_srl_epi16(_mm256_add_epi16(hi_lo, q_lo), sh2);
    q_hi = _mm256_srl_epi16(_mm256_add_epi16(hi_hi, q_hi), sh2);
    return npyv512_combine_si256(q_lo, q_hi);
#endif
}
// divide each signed 16-bit element by divisor (round towards zero)
// q          = ((a + mulhi) >> sh1) - XSIGN(a)
// trunc(a/d) = (q ^ dsign) - dsign
// mulhi is the high half of the signed product a * divisor.val[0], sh1 comes
// from the low bits of divisor.val[1], and dsign is divisor.val[2].
NPY_FINLINE npyv_s16 npyv_divc_s16(npyv_s16 a, const npyv_s16x3 divisor)
{
    const __m128i sh1 = _mm512_castsi512_si128(divisor.val[1]);
#ifdef NPY_HAVE_AVX512BW
    const __m512i dsign = divisor.val[2];
    const __m512i hi    = _mm512_mulhi_epi16(a, divisor.val[0]);
    __m512i q = _mm512_sra_epi16(_mm512_add_epi16(a, hi), sh1);
    // srai by 15 yields -1 for negative lanes, so subtracting it adds 1
    q = _mm512_sub_epi16(q, _mm512_srai_epi16(a, 15));
    return _mm512_sub_epi16(_mm512_xor_si512(q, dsign), dsign);
#else
    // AVX2 fallback: process each 256-bit half separately
    const __m256i mul   = npyv512_lower_si256(divisor.val[0]);
    const __m256i dsign = npyv512_lower_si256(divisor.val[2]);
    const __m256i a_lo  = npyv512_lower_si256(a);
    const __m256i a_hi  = npyv512_higher_si256(a);
    __m256i q_lo = _mm256_sra_epi16(_mm256_add_epi16(a_lo, _mm256_mulhi_epi16(a_lo, mul)), sh1);
    __m256i q_hi = _mm256_sra_epi16(_mm256_add_epi16(a_hi, _mm256_mulhi_epi16(a_hi, mul)), sh1);
    q_lo = _mm256_sub_epi16(q_lo, _mm256_srai_epi16(a_lo, 15));
    q_hi = _mm256_sub_epi16(q_hi, _mm256_srai_epi16(a_hi, 15));
    q_lo = _mm256_sub_epi16(_mm256_xor_si256(q_lo, dsign), dsign);
    q_hi = _mm256_sub_epi16(_mm256_xor_si256(q_hi, dsign), dsign);
    return npyv512_combine_si256(q_lo, q_hi);
#endif
}
// divide each unsigned 32-bit element by divisor
// floor(a/d) = (mulhi + ((a - mulhi) >> sh1)) >> sh2
NPY_FINLINE npyv_u32 npyv_divc_u32(npyv_u32 a, const npyv_u32x3 divisor)
{
    const __m512i mul = divisor.val[0];
    const __m128i sh1 = _mm512_castsi512_si128(divisor.val[1]);
    const __m128i sh2 = _mm512_castsi512_si128(divisor.val[2]);
    // _mm512_mul_epu32 multiplies only the low 32-bit element of each 64-bit
    // pair, so the odd elements are shifted down and multiplied separately
    const __m512i prod_even = _mm512_mul_epu32(a, mul);
    const __m512i prod_odd  = _mm512_mul_epu32(_mm512_srli_epi64(a, 32), mul);
    // high 32 bits of each product: even ones must be shifted down, odd ones
    // already sit in the upper half of their 64-bit pair
    const __m512i mulhi = _mm512_mask_mov_epi32(
        _mm512_srli_epi64(prod_even, 32), 0xAAAA, prod_odd
    );
    __m512i q = _mm512_srl_epi32(_mm512_sub_epi32(a, mulhi), sh1);
    q = _mm512_add_epi32(mulhi, q);
    return _mm512_srl_epi32(q, sh2);
}
// divide each signed 32-bit element by divisor (round towards zero)
// q          = ((a + mulhi) >> sh1) - XSIGN(a)
// trunc(a/d) = (q ^ dsign) - dsign
NPY_FINLINE npyv_s32 npyv_divc_s32(npyv_s32 a, const npyv_s32x3 divisor)
{
    const __m512i mul   = divisor.val[0];
    const __m512i dsign = divisor.val[2];
    const __m128i sh1   = _mm512_castsi512_si128(divisor.val[1]);
    // signed 32x32->64 products of the even elements, then of the odd ones
    // (shifted down so that _mm512_mul_epi32 sees them)
    const __m512i prod_even = _mm512_mul_epi32(a, mul);
    const __m512i prod_odd  = _mm512_mul_epi32(_mm512_srli_epi64(a, 32), mul);
    const __m512i mulhi = _mm512_mask_mov_epi32(
        _mm512_srli_epi64(prod_even, 32), 0xAAAA, prod_odd
    );
    __m512i q = _mm512_sra_epi32(_mm512_add_epi32(a, mulhi), sh1);
    // srai by 31 yields -1 for negative lanes, so subtracting it adds 1
    q = _mm512_sub_epi32(q, _mm512_srai_epi32(a, 31));
    return _mm512_sub_epi32(_mm512_xor_si512(q, dsign), dsign);
}
// returns the high 64 bits of the 128-bit product of each pair of unsigned
// 64-bit lanes, assembled from four 32x32->64 partial products
// xref https://stackoverflow.com/a/28827013
NPY_FINLINE npyv_u64 npyv__mullhi_u64(npyv_u64 a, npyv_u64 b)
{
    const __m512i mask32 = npyv_setall_s64(0xffffffff);
    // move the upper 32-bit half of every 64-bit lane into its lower half
    const __m512i ah = _mm512_srli_epi64(a, 32);
    const __m512i bh = _mm512_srli_epi64(b, 32);
    // _mm512_mul_epu32 multiplies the low 32 bits of each 64-bit lane
    const __m512i ll = _mm512_mul_epu32(a,  b);   // alo * blo
    const __m512i lh = _mm512_mul_epu32(a,  bh);  // alo * bhi
    const __m512i hl = _mm512_mul_epu32(ah, b);   // ahi * blo
    const __m512i hh = _mm512_mul_epu32(ah, bh);  // ahi * bhi
    // middle column: alo*bhi plus the carry-out of alo*blo (cannot overflow)
    const __m512i mid  = _mm512_add_epi64(lh, _mm512_srli_epi64(ll, 32));
    // add ahi*blo to the low half of the middle column; its upper half is a carry
    const __m512i mid2 = _mm512_add_epi64(hl, _mm512_and_si512(mid, mask32));
    __m512i hi = _mm512_add_epi64(hh, _mm512_srli_epi64(mid, 32));
    return _mm512_add_epi64(hi, _mm512_srli_epi64(mid2, 32));
}
// divide each unsigned 64-bit element by a divisor
// floor(a/d) = (mulhi + ((a - mulhi) >> sh1)) >> sh2
NPY_FINLINE npyv_u64 npyv_divc_u64(npyv_u64 a, const npyv_u64x3 divisor)
{
    // high 64 bits of the unsigned product a * divisor.val[0]
    const __m512i mulhi = npyv__mullhi_u64(a, divisor.val[0]);
    const __m128i sh1   = _mm512_castsi512_si128(divisor.val[1]);
    const __m128i sh2   = _mm512_castsi512_si128(divisor.val[2]);
    const __m512i t     = _mm512_srl_epi64(_mm512_sub_epi64(a, mulhi), sh1);
    return _mm512_srl_epi64(_mm512_add_epi64(mulhi, t), sh2);
}
// divide each signed 64-bit element by a divisor (round towards zero)
// q          = ((a + mulhi) >> sh1) - XSIGN(a)
// trunc(a/d) = (q ^ dsign) - dsign
NPY_FINLINE npyv_s64 npyv_divc_s64(npyv_s64 a, const npyv_s64x3 divisor)
{
    const __m512i mul   = divisor.val[0];
    const __m512i dsign = divisor.val[2];
    const __m128i sh1   = _mm512_castsi512_si128(divisor.val[1]);
    // all-ones in lanes whose operand is negative, zero otherwise
    const __m512i a_neg = _mm512_srai_epi64(a, 63);
    const __m512i m_neg = _mm512_srai_epi64(mul, 63);
    // convert the unsigned high product into the signed one:
    // mulhi_s = mulhi_u - ((a < 0) ? m : 0) - ((m < 0) ? a : 0)
    __m512i mulhi = npyv__mullhi_u64(a, mul);
    mulhi = _mm512_sub_epi64(mulhi, _mm512_and_si512(mul, a_neg));
    mulhi = _mm512_sub_epi64(mulhi, _mm512_and_si512(a, m_neg));
    __m512i q = _mm512_sra_epi64(_mm512_add_epi64(a, mulhi), sh1);
    q = _mm512_sub_epi64(q, a_neg);
    return _mm512_sub_epi64(_mm512_xor_si512(q, dsign), dsign);
}
/***************************
 * Division
 ***************************/
// Only floating-point lanes are mapped here; integer division by a
// precomputed divisor is provided by npyv_divc_* above.
// TODO: emulate integer division
#define npyv_div_f64 _mm512_div_pd
#define npyv_div_f32 _mm512_div_ps

/***************************
 * FUSED
 ***************************/
// single precision
#define npyv_muladd_f32    _mm512_fmadd_ps     // a*b + c
#define npyv_mulsub_f32    _mm512_fmsub_ps     // a*b - c
#define npyv_nmuladd_f32   _mm512_fnmadd_ps    // -(a*b) + c
#define npyv_nmulsub_f32   _mm512_fnmsub_ps    // -(a*b) - c
#define npyv_muladdsub_f32 _mm512_fmaddsub_ps  // (a*b) -+ c: subtract c in even lanes, add in odd lanes
// double precision
#define npyv_muladd_f64    _mm512_fmadd_pd     // a*b + c
#define npyv_mulsub_f64    _mm512_fmsub_pd     // a*b - c
#define npyv_nmuladd_f64   _mm512_fnmadd_pd    // -(a*b) + c
#define npyv_nmulsub_f64   _mm512_fnmsub_pd    // -(a*b) - c
#define npyv_muladdsub_f64 _mm512_fmaddsub_pd  // (a*b) -+ c: subtract c in even lanes, add in odd lanes

/***************************
 * Summation: calculates the sum of all vector elements.
 * There are three ways to implement a reduce-sum for AVX512:
 * 1- split(256) /add /split(128) /add /hadd /hadd /extract
 * 2- shuff(cross) /add /shuff(cross) /add /shuff /add /shuff /add /extract
 * 3- _mm512_reduce_add_ps/pd
 * The first one has been widely used by many projects.
 *
 * The second one is used by the Intel compiler, possibly because the
 * latency of hadd increased (by 2-3) starting from Skylake-X, which makes two
 * extra (non-cross-lane) shuffles cheaper. See https://godbolt.org/z/s3G9Er for more info.
 *
 * The third one is almost the same as the second one, but it is only available
 * with the Intel compiler, GCC 7.1+ and Clang 4+, and we still need to support older GCC.
 ***************************/
// reduce sum across vector
#ifdef NPY_HAVE_AVX512F_REDUCE
    #define npyv_sum_u32 _mm512_reduce_add_epi32
    #define npyv_sum_u64 _mm512_reduce_add_epi64
    #define npyv_sum_f32 _mm512_reduce_add_ps
    #define npyv_sum_f64 _mm512_reduce_add_pd
#else
    // sum of all 16 unsigned 32-bit lanes (wraps modulo 2^32)
    NPY_FINLINE npy_uint32 npyv_sum_u32(npyv_u32 a)
    {
        // 512 -> 256 -> 128 bits, then two horizontal adds
        __m256i half = _mm256_add_epi32(npyv512_lower_si256(a), npyv512_higher_si256(a));
        __m128i quarter = _mm_add_epi32(_mm256_castsi256_si128(half), _mm256_extracti128_si256(half, 1));
        quarter = _mm_hadd_epi32(quarter, quarter);
        return _mm_cvtsi128_si32(_mm_hadd_epi32(quarter, quarter));
    }

    // sum of all 8 unsigned 64-bit lanes (wraps modulo 2^64)
    NPY_FINLINE npy_uint64 npyv_sum_u64(npyv_u64 a)
    {
        __m256i four = _mm256_add_epi64(npyv512_lower_si256(a), npyv512_higher_si256(a));
        // swap the two 64-bit halves inside each 128-bit lane and add
        __m256i two = _mm256_add_epi64(four, _mm256_shuffle_epi32(four, _MM_SHUFFLE(1, 0, 3, 2)));
        __m128i one = _mm_add_epi64(_mm256_castsi256_si128(two), _mm256_extracti128_si256(two, 1));
        return (npy_uint64)npyv128_cvtsi128_si64(one);
    }

    // sum of all 16 single-precision lanes
    NPY_FINLINE float npyv_sum_f32(npyv_f32 a)
    {
        // _mm512_shuffle_f32x4 moves whole 128-bit lanes: destination lane i
        // takes source lane imm8[2i+1:2i]. Only destination lane 0 feeds the
        // final result.
        // lane0 += lane2, lane1 += lane3
        __m512 h64   = _mm512_shuffle_f32x4(a, a, _MM_SHUFFLE(3, 2, 3, 2));
        __m512 sum32 = _mm512_add_ps(a, h64);
        // lane0 += lane1: imm8[1:0] must select lane 1. Selecting lane 2 here
        // would add 2*a.lane2 and drop a.lane1 from the total.
        __m512 h32   = _mm512_shuffle_f32x4(sum32, sum32, _MM_SHUFFLE(1, 0, 3, 1));
        __m512 sum16 = _mm512_add_ps(sum32, h32);
        // reduce the four floats left in lane 0
        __m512 h16   = _mm512_permute_ps(sum16, _MM_SHUFFLE(1, 0, 3, 2));
        __m512 sum8  = _mm512_add_ps(sum16, h16);
        __m512 h4    = _mm512_permute_ps(sum8, _MM_SHUFFLE(2, 3, 0, 1));
        __m512 sum4  = _mm512_add_ps(sum8, h4);
        return _mm_cvtss_f32(_mm512_castps512_ps128(sum4));
    }

    // sum of all 8 double-precision lanes
    NPY_FINLINE double npyv_sum_f64(npyv_f64 a)
    {
        // 128-bit lane0 += lane2, lane1 += lane3
        __m512d h64   = _mm512_shuffle_f64x2(a, a, _MM_SHUFFLE(3, 2, 3, 2));
        __m512d sum32 = _mm512_add_pd(a, h64);
        // within each 256-bit half, elements 0,1 += elements 2,3
        __m512d h32   = _mm512_permutex_pd(sum32, _MM_SHUFFLE(1, 0, 3, 2));
        __m512d sum16 = _mm512_add_pd(sum32, h32);
        // element 0 += element 1
        __m512d h16   = _mm512_permute_pd(sum16, _MM_SHUFFLE(2, 3, 0, 1));
        __m512d sum8  = _mm512_add_pd(sum16, h16);
        return _mm_cvtsd_f64(_mm512_castpd512_pd128(sum8));
    }

#endif

// widen and sum all 64 unsigned 8-bit elements; the total (at most 64*255)
// fits in 16 bits
NPY_FINLINE npy_uint16 npyv_sumup_u8(npyv_u8 a)
{
    // sad_epu8 against zero sums each group of 8 bytes into a 64-bit lane
#ifdef NPY_HAVE_AVX512BW
    const __m512i sad  = _mm512_sad_epu8(a, _mm512_setzero_si512());
    const __m256i sum4 = _mm256_add_epi16(npyv512_lower_si256(sad), npyv512_higher_si256(sad));
#else
    const __m256i zero = _mm256_setzero_si256();
    const __m256i sum4 = _mm256_add_epi16(
        _mm256_sad_epu8(npyv512_lower_si256(a), zero),
        _mm256_sad_epu8(npyv512_higher_si256(a), zero)
    );
#endif
    const __m128i sum2 = _mm_add_epi16(_mm256_castsi256_si128(sum4), _mm256_extracti128_si256(sum4, 1));
    const __m128i sum1 = _mm_add_epi16(sum2, _mm_unpackhi_epi64(sum2, sum2));
    return (npy_uint16)_mm_cvtsi128_si32(sum1);
}

// widen and sum all 32 unsigned 16-bit elements into a 32-bit total
NPY_FINLINE npy_uint32 npyv_sumup_u16(npyv_u16 a)
{
    // add the upper 16-bit half of every 32-bit lane to its lower half,
    // then reduce the resulting 32-bit lanes
    const __m512i lo16 = _mm512_and_si512(a, _mm512_set1_epi32(0x0000FFFF));
    const __m512i hi16 = _mm512_srli_epi32(a, 16);
    return npyv_sum_u32(_mm512_add_epi32(lo16, hi16));
}

#endif // _NPY_SIMD_AVX512_ARITHMETIC_H