author    Jonathan Wright <jonathan.wright@arm.com>  2022-05-18 14:14:56 +0100
committer Jonathan Wright <jonathan.wright@arm.com>  2023-01-11 12:18:45 +0000
commit    f952068691bcc397a17721d004ac84e63e46bb3c (patch)
tree      672029ee9dec03e794cbc279b199f835e7f38ac2
parent    708c4aa8540ec81aa5f0d93edc2e1e4d6d4581ac (diff)
Implement horizontal convolutions using Neon USDOT instruction
Add additional AArch64 paths for vpx_convolve8_horiz_neon and
vpx_convolve8_avg_horiz_neon that use the Armv8.6-A USDOT (mixed-sign
dot-product) instruction.

The USDOT instruction takes an unsigned 8-bit operand vector and a signed
8-bit operand vector to produce a signed 32-bit result. This is helpful
because convolution filters often hold both positive and negative values,
while the 8-bit pixel channel data being filtered is all unsigned. As a
result, the USDOT convolution paths added here do not have to do the
"transform the pixel channel data to [-128, 128) and correct for it later"
dance that the SDOT paths require.

The USDOT instruction is optional from Armv8.2 to Armv8.5 and mandatory
from Armv8.6 onwards. Its availability is indicated by the feature macro
__ARM_FEATURE_MATMUL_INT8. The SDOT paths are retained for use on target
CPUs that do not implement the USDOT instructions.

Change-Id: If19f5872c3453458a8cfb7c7d2be82a2c0eab46a
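In miniature, the two strategies differ as follows. This sketch is
illustrative and not part of the patch: the helper names dot4_usdot and
dot4_sdot are hypothetical, and the patch's real helpers
(convolve8_4_usdot, convolve8_4_sdot) additionally permute the samples and
use the lane-indexed vusdotq_lane_s32 / vdotq_lane_s32 forms.

#include <arm_neon.h>

#if defined(__aarch64__) && defined(__ARM_FEATURE_MATMUL_INT8)
/* USDOT: unsigned pixels x signed taps accumulate directly into zero. */
static int32x4_t dot4_usdot(uint8x16_t pixels, int8x16_t taps) {
  return vusdotq_s32(vdupq_n_s32(0), pixels, taps);
}
#endif

#if defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD)
/* SDOT: both operands must be signed, so shift pixels into [-128, 127]
 * first (p ^ 0x80 == p - 128 for uint8), then seed the accumulator with
 * correction = 128 * sum(taps) to undo the shift, since
 * sum((p - 128) * t) = sum(p * t) - 128 * sum(t). */
static int32x4_t dot4_sdot(uint8x16_t pixels, int8x16_t taps,
                           int32x4_t correction) {
  int8x16_t clamped =
      vreinterpretq_s8_u8(veorq_u8(pixels, vdupq_n_u8(0x80)));
  return vdotq_s32(correction, clamped, taps);
}
#endif

With USDOT the per-block work reduces to the dot products themselves,
while the SDOT path pays for the extra EOR and has to thread 'correction'
and 'range_limit' through every helper - which is exactly what the new
paths in this patch avoid.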
-rw-r--r--  vpx_dsp/arm/vpx_convolve8_neon.c | 271
-rw-r--r--  vpx_dsp/arm/vpx_convolve8_neon.h |  91
2 files changed, 299 insertions(+), 63 deletions(-)
diff --git a/vpx_dsp/arm/vpx_convolve8_neon.c b/vpx_dsp/arm/vpx_convolve8_neon.c
index dba436b1a..81ceb518d 100644
--- a/vpx_dsp/arm/vpx_convolve8_neon.c
+++ b/vpx_dsp/arm/vpx_convolve8_neon.c
@@ -31,7 +31,9 @@
// instructions. This optimization is much faster in speed unit test, but slowed
// down the whole decoder by 5%.
-#if defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD)
+#if defined(__aarch64__) && \
+ (defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8))
+
DECLARE_ALIGNED(16, static const uint8_t, dot_prod_permute_tbl[48]) = {
0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6,
4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10,
@@ -96,6 +98,175 @@ static INLINE void transpose_concat_8x4(int8x8_t *a0, int8x8_t *a1,
*b1 = vqtbl2q_s8(samples, permute_tbl.val[1]);
}
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+
+void vpx_convolve8_horiz_neon(const uint8_t *src, ptrdiff_t src_stride,
+ uint8_t *dst, ptrdiff_t dst_stride,
+ const InterpKernel *filter, int x0_q4,
+ int x_step_q4, int y0_q4, int y_step_q4, int w,
+ int h) {
+ const int8x8_t filters = vmovn_s16(vld1q_s16(filter[x0_q4]));
+ uint8x16_t s0, s1, s2, s3;
+
+ assert(!((intptr_t)dst & 3));
+ assert(!(dst_stride & 3));
+ assert(x_step_q4 == 16);
+
+ (void)x_step_q4;
+ (void)y0_q4;
+ (void)y_step_q4;
+
+ src -= 3;
+
+ if (w == 4) {
+ const uint8x16x2_t permute_tbl = vld1q_u8_x2(dot_prod_permute_tbl);
+ do {
+ int32x4_t t0, t1, t2, t3;
+ int16x8_t t01, t23;
+ uint8x8_t d01, d23;
+
+ load_u8_16x4(src, src_stride, &s0, &s1, &s2, &s3);
+
+ t0 = convolve8_4_usdot(s0, filters, permute_tbl);
+ t1 = convolve8_4_usdot(s1, filters, permute_tbl);
+ t2 = convolve8_4_usdot(s2, filters, permute_tbl);
+ t3 = convolve8_4_usdot(s3, filters, permute_tbl);
+ t01 = vcombine_s16(vqmovn_s32(t0), vqmovn_s32(t1));
+ t23 = vcombine_s16(vqmovn_s32(t2), vqmovn_s32(t3));
+ d01 = vqrshrun_n_s16(t01, 7);
+ d23 = vqrshrun_n_s16(t23, 7);
+
+ store_u8(dst + 0 * dst_stride, dst_stride, d01);
+ store_u8(dst + 2 * dst_stride, dst_stride, d23);
+
+ src += 4 * src_stride;
+ dst += 4 * dst_stride;
+ h -= 4;
+ } while (h > 0);
+ } else {
+ const uint8x16x3_t permute_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
+ const uint8_t *s;
+ uint8_t *d;
+ int width;
+ uint8x8_t d0, d1, d2, d3;
+
+ do {
+ width = w;
+ s = src;
+ d = dst;
+ do {
+ load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);
+
+ d0 = convolve8_8_usdot(s0, filters, permute_tbl);
+ d1 = convolve8_8_usdot(s1, filters, permute_tbl);
+ d2 = convolve8_8_usdot(s2, filters, permute_tbl);
+ d3 = convolve8_8_usdot(s3, filters, permute_tbl);
+
+ store_u8_8x4(d, dst_stride, d0, d1, d2, d3);
+
+ s += 8;
+ d += 8;
+ width -= 8;
+ } while (width > 0);
+ src += 4 * src_stride;
+ dst += 4 * dst_stride;
+ h -= 4;
+ } while (h > 0);
+ }
+}
+
+void vpx_convolve8_avg_horiz_neon(const uint8_t *src, ptrdiff_t src_stride,
+ uint8_t *dst, ptrdiff_t dst_stride,
+ const InterpKernel *filter, int x0_q4,
+ int x_step_q4, int y0_q4, int y_step_q4,
+ int w, int h) {
+ const int8x8_t filters = vmovn_s16(vld1q_s16(filter[x0_q4]));
+ uint8x16_t s0, s1, s2, s3;
+
+ assert(!((intptr_t)dst & 3));
+ assert(!(dst_stride & 3));
+ assert(x_step_q4 == 16);
+
+ (void)x_step_q4;
+ (void)y0_q4;
+ (void)y_step_q4;
+
+ src -= 3;
+
+ if (w == 4) {
+ const uint8x16x2_t permute_tbl = vld1q_u8_x2(dot_prod_permute_tbl);
+ do {
+ int32x4_t t0, t1, t2, t3;
+ int16x8_t t01, t23;
+ uint8x8_t d01, d23, dd01, dd23;
+ dd01 = vdup_n_u8(0);
+ dd23 = vdup_n_u8(0);
+
+ load_u8_16x4(src, src_stride, &s0, &s1, &s2, &s3);
+
+ t0 = convolve8_4_usdot(s0, filters, permute_tbl);
+ t1 = convolve8_4_usdot(s1, filters, permute_tbl);
+ t2 = convolve8_4_usdot(s2, filters, permute_tbl);
+ t3 = convolve8_4_usdot(s3, filters, permute_tbl);
+ t01 = vcombine_s16(vqmovn_s32(t0), vqmovn_s32(t1));
+ t23 = vcombine_s16(vqmovn_s32(t2), vqmovn_s32(t3));
+ d01 = vqrshrun_n_s16(t01, 7);
+ d23 = vqrshrun_n_s16(t23, 7);
+
+ dd01 = load_u8(dst + 0 * dst_stride, dst_stride);
+ dd23 = load_u8(dst + 2 * dst_stride, dst_stride);
+
+ d01 = vrhadd_u8(d01, dd01);
+ d23 = vrhadd_u8(d23, dd23);
+
+ store_u8(dst + 0 * dst_stride, dst_stride, d01);
+ store_u8(dst + 2 * dst_stride, dst_stride, d23);
+
+ src += 4 * src_stride;
+ dst += 4 * dst_stride;
+ h -= 4;
+ } while (h > 0);
+ } else {
+ const uint8x16x3_t permute_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
+ const uint8_t *s;
+ uint8_t *d;
+ int width;
+ uint8x8_t d0, d1, d2, d3, dd0, dd1, dd2, dd3;
+
+ do {
+ width = w;
+ s = src;
+ d = dst;
+ do {
+ load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);
+
+ d0 = convolve8_8_usdot(s0, filters, permute_tbl);
+ d1 = convolve8_8_usdot(s1, filters, permute_tbl);
+ d2 = convolve8_8_usdot(s2, filters, permute_tbl);
+ d3 = convolve8_8_usdot(s3, filters, permute_tbl);
+
+ load_u8_8x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);
+
+ d0 = vrhadd_u8(d0, dd0);
+ d1 = vrhadd_u8(d1, dd1);
+ d2 = vrhadd_u8(d2, dd2);
+ d3 = vrhadd_u8(d3, dd3);
+
+ store_u8_8x4(d, dst_stride, d0, d1, d2, d3);
+
+ s += 8;
+ d += 8;
+ width -= 8;
+ } while (width > 0);
+ src += 4 * src_stride;
+ dst += 4 * dst_stride;
+ h -= 4;
+ } while (h > 0);
+ }
+}
+
+#else // !defined(__ARM_FEATURE_MATMUL_INT8)
+
void vpx_convolve8_horiz_neon(const uint8_t *src, ptrdiff_t src_stride,
uint8_t *dst, ptrdiff_t dst_stride,
const InterpKernel *filter, int x0_q4,
@@ -126,10 +297,10 @@ void vpx_convolve8_horiz_neon(const uint8_t *src, ptrdiff_t src_stride,
load_u8_16x4(src, src_stride, &s0, &s1, &s2, &s3);
- t0 = convolve8_4_dot(s0, filters, correction, range_limit, permute_tbl);
- t1 = convolve8_4_dot(s1, filters, correction, range_limit, permute_tbl);
- t2 = convolve8_4_dot(s2, filters, correction, range_limit, permute_tbl);
- t3 = convolve8_4_dot(s3, filters, correction, range_limit, permute_tbl);
+ t0 = convolve8_4_sdot(s0, filters, correction, range_limit, permute_tbl);
+ t1 = convolve8_4_sdot(s1, filters, correction, range_limit, permute_tbl);
+ t2 = convolve8_4_sdot(s2, filters, correction, range_limit, permute_tbl);
+ t3 = convolve8_4_sdot(s3, filters, correction, range_limit, permute_tbl);
t01 = vcombine_s16(vqmovn_s32(t0), vqmovn_s32(t1));
t23 = vcombine_s16(vqmovn_s32(t2), vqmovn_s32(t3));
d01 = vqrshrun_n_s16(t01, 7);
@@ -156,10 +327,14 @@ void vpx_convolve8_horiz_neon(const uint8_t *src, ptrdiff_t src_stride,
do {
load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);
- d0 = convolve8_8_dot(s0, filters, correction, range_limit, permute_tbl);
- d1 = convolve8_8_dot(s1, filters, correction, range_limit, permute_tbl);
- d2 = convolve8_8_dot(s2, filters, correction, range_limit, permute_tbl);
- d3 = convolve8_8_dot(s3, filters, correction, range_limit, permute_tbl);
+ d0 =
+ convolve8_8_sdot(s0, filters, correction, range_limit, permute_tbl);
+ d1 =
+ convolve8_8_sdot(s1, filters, correction, range_limit, permute_tbl);
+ d2 =
+ convolve8_8_sdot(s2, filters, correction, range_limit, permute_tbl);
+ d3 =
+ convolve8_8_sdot(s3, filters, correction, range_limit, permute_tbl);
store_u8_8x4(d, dst_stride, d0, d1, d2, d3);
@@ -206,10 +381,10 @@ void vpx_convolve8_avg_horiz_neon(const uint8_t *src, ptrdiff_t src_stride,
load_u8_16x4(src, src_stride, &s0, &s1, &s2, &s3);
- t0 = convolve8_4_dot(s0, filters, correction, range_limit, permute_tbl);
- t1 = convolve8_4_dot(s1, filters, correction, range_limit, permute_tbl);
- t2 = convolve8_4_dot(s2, filters, correction, range_limit, permute_tbl);
- t3 = convolve8_4_dot(s3, filters, correction, range_limit, permute_tbl);
+ t0 = convolve8_4_sdot(s0, filters, correction, range_limit, permute_tbl);
+ t1 = convolve8_4_sdot(s1, filters, correction, range_limit, permute_tbl);
+ t2 = convolve8_4_sdot(s2, filters, correction, range_limit, permute_tbl);
+ t3 = convolve8_4_sdot(s3, filters, correction, range_limit, permute_tbl);
t01 = vcombine_s16(vqmovn_s32(t0), vqmovn_s32(t1));
t23 = vcombine_s16(vqmovn_s32(t2), vqmovn_s32(t3));
d01 = vqrshrun_n_s16(t01, 7);
@@ -242,10 +417,14 @@ void vpx_convolve8_avg_horiz_neon(const uint8_t *src, ptrdiff_t src_stride,
do {
load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);
- d0 = convolve8_8_dot(s0, filters, correction, range_limit, permute_tbl);
- d1 = convolve8_8_dot(s1, filters, correction, range_limit, permute_tbl);
- d2 = convolve8_8_dot(s2, filters, correction, range_limit, permute_tbl);
- d3 = convolve8_8_dot(s3, filters, correction, range_limit, permute_tbl);
+ d0 =
+ convolve8_8_sdot(s0, filters, correction, range_limit, permute_tbl);
+ d1 =
+ convolve8_8_sdot(s1, filters, correction, range_limit, permute_tbl);
+ d2 =
+ convolve8_8_sdot(s2, filters, correction, range_limit, permute_tbl);
+ d3 =
+ convolve8_8_sdot(s3, filters, correction, range_limit, permute_tbl);
load_u8_8x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);
@@ -267,6 +446,8 @@ void vpx_convolve8_avg_horiz_neon(const uint8_t *src, ptrdiff_t src_stride,
}
}
+#endif // defined(__ARM_FEATURE_MATMUL_INT8)
+
void vpx_convolve8_vert_neon(const uint8_t *src, ptrdiff_t src_stride,
uint8_t *dst, ptrdiff_t dst_stride,
const InterpKernel *filter, int x0_q4,
@@ -342,10 +523,10 @@ void vpx_convolve8_vert_neon(const uint8_t *src, ptrdiff_t src_stride,
s5678 = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[1]);
s6789 = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[2]);
- d0 = convolve8_4_dot_partial(s0123, s4567, correction, filters);
- d1 = convolve8_4_dot_partial(s1234, s5678, correction, filters);
- d2 = convolve8_4_dot_partial(s2345, s6789, correction, filters);
- d3 = convolve8_4_dot_partial(s3456, s78910, correction, filters);
+ d0 = convolve8_4_sdot_partial(s0123, s4567, correction, filters);
+ d1 = convolve8_4_sdot_partial(s1234, s5678, correction, filters);
+ d2 = convolve8_4_sdot_partial(s2345, s6789, correction, filters);
+ d3 = convolve8_4_sdot_partial(s3456, s78910, correction, filters);
d01 = vqrshrun_n_s16(vcombine_s16(vqmovn_s32(d0), vqmovn_s32(d1)), 7);
d23 = vqrshrun_n_s16(vcombine_s16(vqmovn_s32(d2), vqmovn_s32(d3)), 7);
@@ -437,14 +618,14 @@ void vpx_convolve8_vert_neon(const uint8_t *src, ptrdiff_t src_stride,
s5678_hi = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[1]);
s6789_hi = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[2]);
- d0 = convolve8_8_dot_partial(s0123_lo, s4567_lo, s0123_hi, s4567_hi,
- correction, filters);
- d1 = convolve8_8_dot_partial(s1234_lo, s5678_lo, s1234_hi, s5678_hi,
- correction, filters);
- d2 = convolve8_8_dot_partial(s2345_lo, s6789_lo, s2345_hi, s6789_hi,
- correction, filters);
- d3 = convolve8_8_dot_partial(s3456_lo, s78910_lo, s3456_hi, s78910_hi,
- correction, filters);
+ d0 = convolve8_8_sdot_partial(s0123_lo, s4567_lo, s0123_hi, s4567_hi,
+ correction, filters);
+ d1 = convolve8_8_sdot_partial(s1234_lo, s5678_lo, s1234_hi, s5678_hi,
+ correction, filters);
+ d2 = convolve8_8_sdot_partial(s2345_lo, s6789_lo, s2345_hi, s6789_hi,
+ correction, filters);
+ d3 = convolve8_8_sdot_partial(s3456_lo, s78910_lo, s3456_hi, s78910_hi,
+ correction, filters);
store_u8_8x4(d, dst_stride, d0, d1, d2, d3);
@@ -545,10 +726,10 @@ void vpx_convolve8_avg_vert_neon(const uint8_t *src, ptrdiff_t src_stride,
s5678 = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[1]);
s6789 = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[2]);
- d0 = convolve8_4_dot_partial(s0123, s4567, correction, filters);
- d1 = convolve8_4_dot_partial(s1234, s5678, correction, filters);
- d2 = convolve8_4_dot_partial(s2345, s6789, correction, filters);
- d3 = convolve8_4_dot_partial(s3456, s78910, correction, filters);
+ d0 = convolve8_4_sdot_partial(s0123, s4567, correction, filters);
+ d1 = convolve8_4_sdot_partial(s1234, s5678, correction, filters);
+ d2 = convolve8_4_sdot_partial(s2345, s6789, correction, filters);
+ d3 = convolve8_4_sdot_partial(s3456, s78910, correction, filters);
d01 = vqrshrun_n_s16(vcombine_s16(vqmovn_s32(d0), vqmovn_s32(d1)), 7);
d23 = vqrshrun_n_s16(vcombine_s16(vqmovn_s32(d2), vqmovn_s32(d3)), 7);
@@ -646,14 +827,14 @@ void vpx_convolve8_avg_vert_neon(const uint8_t *src, ptrdiff_t src_stride,
s5678_hi = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[1]);
s6789_hi = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[2]);
- d0 = convolve8_8_dot_partial(s0123_lo, s4567_lo, s0123_hi, s4567_hi,
- correction, filters);
- d1 = convolve8_8_dot_partial(s1234_lo, s5678_lo, s1234_hi, s5678_hi,
- correction, filters);
- d2 = convolve8_8_dot_partial(s2345_lo, s6789_lo, s2345_hi, s6789_hi,
- correction, filters);
- d3 = convolve8_8_dot_partial(s3456_lo, s78910_lo, s3456_hi, s78910_hi,
- correction, filters);
+ d0 = convolve8_8_sdot_partial(s0123_lo, s4567_lo, s0123_hi, s4567_hi,
+ correction, filters);
+ d1 = convolve8_8_sdot_partial(s1234_lo, s5678_lo, s1234_hi, s5678_hi,
+ correction, filters);
+ d2 = convolve8_8_sdot_partial(s2345_lo, s6789_lo, s2345_hi, s6789_hi,
+ correction, filters);
+ d3 = convolve8_8_sdot_partial(s3456_lo, s78910_lo, s3456_hi, s78910_hi,
+ correction, filters);
load_u8_8x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);
@@ -686,7 +867,9 @@ void vpx_convolve8_avg_vert_neon(const uint8_t *src, ptrdiff_t src_stride,
}
}
-#else // !(defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD))
+#else // !(defined(__aarch64__) &&
+ // (defined(__ARM_FEATURE_DOTPROD) ||
+ // defined(__ARM_FEATURE_MATMUL_INT8)))
void vpx_convolve8_horiz_neon(const uint8_t *src, ptrdiff_t src_stride,
uint8_t *dst, ptrdiff_t dst_stride,
@@ -1528,4 +1711,6 @@ void vpx_convolve8_avg_vert_neon(const uint8_t *src, ptrdiff_t src_stride,
}
}
-#endif // defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD)
+#endif // defined(__aarch64__) &&
+ // (defined(__ARM_FEATURE_DOTPROD) ||
+ // defined(__ARM_FEATURE_MATMUL_INT8))
diff --git a/vpx_dsp/arm/vpx_convolve8_neon.h b/vpx_dsp/arm/vpx_convolve8_neon.h
index 26a5fa688..a62e4f461 100644
--- a/vpx_dsp/arm/vpx_convolve8_neon.h
+++ b/vpx_dsp/arm/vpx_convolve8_neon.h
@@ -18,10 +18,10 @@
#if defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD)
-static INLINE int32x4_t convolve8_4_dot_partial(const int8x16_t samples_lo,
- const int8x16_t samples_hi,
- const int32x4_t correction,
- const int8x8_t filters) {
+static INLINE int32x4_t convolve8_4_sdot_partial(const int8x16_t samples_lo,
+ const int8x16_t samples_hi,
+ const int32x4_t correction,
+ const int8x8_t filters) {
/* Sample range-clamping and permutation are performed by the caller. */
int32x4_t sum;
@@ -33,11 +33,11 @@ static INLINE int32x4_t convolve8_4_dot_partial(const int8x16_t samples_lo,
return sum;
}
-static INLINE int32x4_t convolve8_4_dot(uint8x16_t samples,
- const int8x8_t filters,
- const int32x4_t correction,
- const uint8x16_t range_limit,
- const uint8x16x2_t permute_tbl) {
+static INLINE int32x4_t convolve8_4_sdot(uint8x16_t samples,
+ const int8x8_t filters,
+ const int32x4_t correction,
+ const uint8x16_t range_limit,
+ const uint8x16x2_t permute_tbl) {
int8x16_t clamped_samples, permuted_samples[2];
int32x4_t sum;
@@ -58,12 +58,12 @@ static INLINE int32x4_t convolve8_4_dot(uint8x16_t samples,
return sum;
}
-static INLINE uint8x8_t convolve8_8_dot_partial(const int8x16_t samples0_lo,
- const int8x16_t samples0_hi,
- const int8x16_t samples1_lo,
- const int8x16_t samples1_hi,
- const int32x4_t correction,
- const int8x8_t filters) {
+static INLINE uint8x8_t convolve8_8_sdot_partial(const int8x16_t samples0_lo,
+ const int8x16_t samples0_hi,
+ const int8x16_t samples1_lo,
+ const int8x16_t samples1_hi,
+ const int32x4_t correction,
+ const int8x8_t filters) {
/* Sample range-clamping and permutation are performed by the caller. */
int32x4_t sum0, sum1;
int16x8_t sum;
@@ -81,11 +81,11 @@ static INLINE uint8x8_t convolve8_8_dot_partial(const int8x16_t samples0_lo,
return vqrshrun_n_s16(sum, 7);
}
-static INLINE uint8x8_t convolve8_8_dot(uint8x16_t samples,
- const int8x8_t filters,
- const int32x4_t correction,
- const uint8x16_t range_limit,
- const uint8x16x3_t permute_tbl) {
+static INLINE uint8x8_t convolve8_8_sdot(uint8x16_t samples,
+ const int8x8_t filters,
+ const int32x4_t correction,
+ const uint8x16_t range_limit,
+ const uint8x16x3_t permute_tbl) {
int8x16_t clamped_samples, permuted_samples[3];
int32x4_t sum0, sum1;
int16x8_t sum;
@@ -116,6 +116,57 @@ static INLINE uint8x8_t convolve8_8_dot(uint8x16_t samples,
#endif // defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD)
+#if defined(__aarch64__) && defined(__ARM_FEATURE_MATMUL_INT8)
+
+static INLINE int32x4_t convolve8_4_usdot(uint8x16_t samples,
+ const int8x8_t filters,
+ const uint8x16x2_t permute_tbl) {
+ uint8x16_t permuted_samples[2];
+ int32x4_t sum;
+
+ /* Permute samples ready for dot product. */
+ /* { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 } */
+ permuted_samples[0] = vqtbl1q_u8(samples, permute_tbl.val[0]);
+ /* { 4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10 } */
+ permuted_samples[1] = vqtbl1q_u8(samples, permute_tbl.val[1]);
+
+ /* Accumulate dot product into zero; no range-clamp correction is needed. */
+ sum = vusdotq_lane_s32(vdupq_n_s32(0), permuted_samples[0], filters, 0);
+ sum = vusdotq_lane_s32(sum, permuted_samples[1], filters, 1);
+
+ /* Narrowing and packing is performed by the caller. */
+ return sum;
+}
+
+static INLINE uint8x8_t convolve8_8_usdot(uint8x16_t samples,
+ const int8x8_t filters,
+ const uint8x16x3_t permute_tbl) {
+ uint8x16_t permuted_samples[3];
+ int32x4_t sum0, sum1;
+ int16x8_t sum;
+
+ /* Permute samples ready for dot product. */
+ /* { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 } */
+ permuted_samples[0] = vqtbl1q_u8(samples, permute_tbl.val[0]);
+ /* { 4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10 } */
+ permuted_samples[1] = vqtbl1q_u8(samples, permute_tbl.val[1]);
+ /* { 8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 } */
+ permuted_samples[2] = vqtbl1q_u8(samples, permute_tbl.val[2]);
+
+ /* First 4 output values. */
+ sum0 = vusdotq_lane_s32(vdupq_n_s32(0), permuted_samples[0], filters, 0);
+ sum0 = vusdotq_lane_s32(sum0, permuted_samples[1], filters, 1);
+ /* Second 4 output values. */
+ sum1 = vusdotq_lane_s32(vdupq_n_s32(0), permuted_samples[1], filters, 0);
+ sum1 = vusdotq_lane_s32(sum1, permuted_samples[2], filters, 1);
+
+ /* Narrow and re-pack. */
+ sum = vcombine_s16(vqmovn_s32(sum0), vqmovn_s32(sum1));
+ return vqrshrun_n_s16(sum, 7);
+}
+
+#endif // defined(__aarch64__) && defined(__ARM_FEATURE_MATMUL_INT8)
+
static INLINE int16x4_t convolve8_4(const int16x4_t s0, const int16x4_t s1,
const int16x4_t s2, const int16x4_t s3,
const int16x4_t s4, const int16x4_t s5,