summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorChunlin <fangchunlin@huawei.com>2020-07-03 21:06:04 +0800
committerGitHub <noreply@github.com>2020-07-03 16:06:04 +0300
commit9298eeb4f6c73ca1259f627860abe98b63f89da4 (patch)
tree399bdcbe7b708ffb5e891713e1788affe614c03c
parent6e3be61d6113e63a1d1be54e6c8341760ded5842 (diff)
downloadnumpy-9298eeb4f6c73ca1259f627860abe98b63f89da4.tar.gz
SIMD: Optimize the performace of np.packbits in ARM-based machine. (#16482)
* optimize pack_bit function for NEON.
-rw-r--r--numpy/core/src/multiarray/compiled_base.c40
1 files changed, 40 insertions, 0 deletions
diff --git a/numpy/core/src/multiarray/compiled_base.c b/numpy/core/src/multiarray/compiled_base.c
index 7a232b5d9..a8e4aa789 100644
--- a/numpy/core/src/multiarray/compiled_base.c
+++ b/numpy/core/src/multiarray/compiled_base.c
@@ -1504,6 +1504,17 @@ arr_add_docstring(PyObject *NPY_UNUSED(dummy), PyObject *args)
#include <emmintrin.h>
#endif
+#ifdef NPY_HAVE_NEON
+ typedef npy_uint64 uint64_unaligned __attribute__((aligned(16)));
+ static NPY_INLINE int32_t
+ sign_mask(uint8x16_t input)
+ {
+ int8x8_t m0 = vcreate_s8(0x0706050403020100ULL);
+ uint8x16_t v0 = vshlq_u8(vshrq_n_u8(input, 7), vcombine_s8(m0, m0));
+ uint64x2_t v1 = vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(v0)));
+ return (int)vgetq_lane_u64(v1, 0) + ((int)vgetq_lane_u64(v1, 1) << 8);
+ }
+#endif
/*
* This function packs boolean values in the input array into the bits of a
* byte array. Truth values are determined as usual: 0 is false, everything
@@ -1543,6 +1554,7 @@ pack_inner(const char *inptr,
a = npy_bswap8(a);
b = npy_bswap8(b);
}
+
/* note x86 can load unaligned */
__m128i v = _mm_set_epi64(_m_from_int64(b), _m_from_int64(a));
/* false -> 0x00 and true -> 0xFF (there is no cmpneq) */
@@ -1558,6 +1570,34 @@ pack_inner(const char *inptr,
inptr += 16;
}
}
+#elif defined NPY_HAVE_NEON
+ if (in_stride == 1 && element_size == 1 && n_out > 2) {
+ /* don't handle non-full 8-byte remainder */
+ npy_intp vn_out = n_out - (remain ? 1 : 0);
+ vn_out -= (vn_out & 1);
+ for (index = 0; index < vn_out; index += 2) {
+ unsigned int r;
+ npy_uint64 a = *((uint64_unaligned*)inptr);
+ npy_uint64 b = *((uint64_unaligned*)(inptr + 8));
+ if (order == 'b') {
+ a = npy_bswap8(a);
+ b = npy_bswap8(b);
+ }
+ uint64x2_t v = vcombine_u64(vcreate_u64(a), vcreate_u64(b));
+ uint64x2_t zero = vdupq_n_u64(0);
+ /* false -> 0x00 and true -> 0xFF */
+ v = vreinterpretq_u64_u8(vmvnq_u8(vceqq_u8(vreinterpretq_u8_u64(v), vreinterpretq_u8_u64(zero))));
+ /* extract msb of 16 bytes and pack it into 16 bit */
+ uint8x16_t input = vreinterpretq_u8_u64(v);
+ r = sign_mask(input);
+ /* store result */
+ memcpy(outptr, &r, 1);
+ outptr += out_stride;
+ memcpy(outptr, (char*)&r + 1, 1);
+ outptr += out_stride;
+ inptr += 16;
+ }
+ }
#endif
if (remain == 0) { /* assumes n_in > 0 */