author    Jonathan Wright <jonathan.wright@arm.com>  2023-05-04 16:33:38 +0100
committer Jonathan Wright <jonathan.wright@arm.com>  2023-05-13 20:43:20 +0100
commit    3e1e38d1176c34f71a87f8402c07cdcc2e20083e (patch)
tree      f11c4c4b8acb0fad783e0cc37b47a6da0df482c4
parent    8ecf58432118b672fe3f4a54725bc63caac262aa (diff)
download  libvpx-3e1e38d1176c34f71a87f8402c07cdcc2e20083e.tar.gz
Add 2D-specific Neon horizontal convolution functions
2D 8-tap convolution filtering is performed in two passes - horizontal and
vertical. The horizontal pass must produce enough input data for the
subsequent vertical pass - 3 rows above and 4 rows below, in addition to the
actual block height.

At present, all Neon horizontal convolution algorithms process 4 rows at a
time, but this means we end up doing at least 1 row too much work in the 2D
first-pass case, where we need h + 7, not h + 8 rows of output.

This patch adds additional dot-product (SDOT and USDOT) Neon paths that
process h + 7 rows of data exactly, saving the work of the unnecessary extra
row. It is impractical to take a similar approach for the Armv8.0 MLA paths
since we have to transpose the data block both before and after calling the
convolution helper functions.

vpx_convolve_neon performance impact: we observe a speedup of ~9% for smaller
(and wider) blocks, and a speedup of 0-3% for larger blocks. This is to be
expected since the proportion of redundant work decreases as the block height
increases.

Change-Id: Ie77ad1848707d2d48bb8851345a469aae9d097e1
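To make the row-count arithmetic concrete, here is a minimal standalone sketch
(not part of the patch; the helper names intermediate_rows and
sketch_2d_horiz_loop are hypothetical) showing why the horizontal pass needs
exactly h + 7 intermediate rows and how a 4-rows-per-iteration loop leaves a
3-row tail:

#include <assert.h>

/* 8-tap vertical filter: 3 rows above and 4 rows below the h output rows. */
static int intermediate_rows(int h) { return h + 3 + 4; }

static void sketch_2d_horiz_loop(int h) {
  int rows = intermediate_rows(h); /* h is a multiple of 4, so rows % 4 == 3 */
  assert(rows % 4 == 3);
  while (rows > 3) {
    /* ... filter 4 rows of intermediate data here ... */
    rows -= 4;
  }
  /* Exactly 3 rows remain: one trimmed tail iteration replaces the full
   * 4-row iteration that would otherwise compute a redundant row. */
}

With h = 16, for example, intermediate_rows(16) returns 23: five 4-row
iterations followed by the single 3-row tail iteration.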
-rw-r--r--  vpx_dsp/arm/mem_neon.h             20
-rw-r--r--  vpx_dsp/arm/vpx_convolve8_neon.c  221
-rw-r--r--  vpx_dsp/arm/vpx_convolve8_neon.h    9
-rw-r--r--  vpx_dsp/arm/vpx_convolve_neon.c    55
4 files changed, 301 insertions(+), 4 deletions(-)
diff --git a/vpx_dsp/arm/mem_neon.h b/vpx_dsp/arm/mem_neon.h
index 1a20da70e..586bfb85a 100644
--- a/vpx_dsp/arm/mem_neon.h
+++ b/vpx_dsp/arm/mem_neon.h
@@ -263,6 +263,16 @@ static INLINE void store_u8(uint8_t *buf, ptrdiff_t stride, const uint8x8_t a) {
vst1_lane_u32((uint32_t *)buf, a_u32, 1);
}
+static INLINE void store_u8_8x3(uint8_t *s, const ptrdiff_t p,
+ const uint8x8_t s0, const uint8x8_t s1,
+ const uint8x8_t s2) {
+ vst1_u8(s, s0);
+ s += p;
+ vst1_u8(s, s1);
+ s += p;
+ vst1_u8(s, s2);
+}
+
static INLINE void load_u8_8x4(const uint8_t *s, const ptrdiff_t p,
uint8x8_t *const s0, uint8x8_t *const s1,
uint8x8_t *const s2, uint8x8_t *const s3) {
@@ -287,6 +297,16 @@ static INLINE void store_u8_8x4(uint8_t *s, const ptrdiff_t p,
vst1_u8(s, s3);
}
+static INLINE void load_u8_16x3(const uint8_t *s, const ptrdiff_t p,
+ uint8x16_t *const s0, uint8x16_t *const s1,
+ uint8x16_t *const s2) {
+ *s0 = vld1q_u8(s);
+ s += p;
+ *s1 = vld1q_u8(s);
+ s += p;
+ *s2 = vld1q_u8(s);
+}
+
static INLINE void load_u8_16x4(const uint8_t *s, const ptrdiff_t p,
uint8x16_t *const s0, uint8x16_t *const s1,
uint8x16_t *const s2, uint8x16_t *const s3) {
diff --git a/vpx_dsp/arm/vpx_convolve8_neon.c b/vpx_dsp/arm/vpx_convolve8_neon.c
index f217a3f35..505d0672f 100644
--- a/vpx_dsp/arm/vpx_convolve8_neon.c
+++ b/vpx_dsp/arm/vpx_convolve8_neon.c
@@ -57,6 +57,111 @@ DECLARE_ALIGNED(16, static const uint8_t, dot_prod_merge_block_tbl[48]) = {
#if defined(__ARM_FEATURE_MATMUL_INT8)
+void vpx_convolve8_2d_horiz_neon(const uint8_t *src, ptrdiff_t src_stride,
+ uint8_t *dst, ptrdiff_t dst_stride,
+ const InterpKernel *filter, int x0_q4,
+ int x_step_q4, int y0_q4, int y_step_q4, int w,
+ int h) {
+ const int8x8_t filters = vmovn_s16(vld1q_s16(filter[x0_q4]));
+ uint8x16_t s0, s1, s2, s3;
+
+ assert((intptr_t)dst % 4 == 0);
+ assert(dst_stride % 4 == 0);
+ assert(x_step_q4 == 16);
+ assert(h % 4 == 3);
+
+ (void)x_step_q4;
+ (void)y0_q4;
+ (void)y_step_q4;
+
+ src -= 3;
+
+ if (w == 4) {
+ const uint8x16x2_t perm_tbl = vld1q_u8_x2(dot_prod_permute_tbl);
+ int16x4_t d0, d1, d2, d3;
+ uint8x8_t d01, d23;
+
+ do {
+ load_u8_16x4(src, src_stride, &s0, &s1, &s2, &s3);
+
+ d0 = convolve8_4_usdot(s0, filters, perm_tbl);
+ d1 = convolve8_4_usdot(s1, filters, perm_tbl);
+ d2 = convolve8_4_usdot(s2, filters, perm_tbl);
+ d3 = convolve8_4_usdot(s3, filters, perm_tbl);
+ d01 = vqrshrun_n_s16(vcombine_s16(d0, d1), FILTER_BITS);
+ d23 = vqrshrun_n_s16(vcombine_s16(d2, d3), FILTER_BITS);
+
+ store_u8(dst + 0 * dst_stride, dst_stride, d01);
+ store_u8(dst + 2 * dst_stride, dst_stride, d23);
+
+ src += 4 * src_stride;
+ dst += 4 * dst_stride;
+ h -= 4;
+ } while (h > 3);
+
+ /* Process final three rows (h % 4 == 3). See vpx_convolve_neon.c for
+ * further details on possible values of block height. */
+ load_u8_16x3(src, src_stride, &s0, &s1, &s2);
+
+ d0 = convolve8_4_usdot(s0, filters, perm_tbl);
+ d1 = convolve8_4_usdot(s1, filters, perm_tbl);
+ d2 = convolve8_4_usdot(s2, filters, perm_tbl);
+ d01 = vqrshrun_n_s16(vcombine_s16(d0, d1), FILTER_BITS);
+ d23 = vqrshrun_n_s16(vcombine_s16(d2, vdup_n_s16(0)), FILTER_BITS);
+
+ store_u8(dst + 0 * dst_stride, dst_stride, d01);
+ store_u8_4x1(dst + 2 * dst_stride, d23);
+ } else {
+ const uint8x16x3_t perm_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
+ const uint8_t *s;
+ uint8_t *d;
+ int width;
+ uint8x8_t d0, d1, d2, d3;
+
+ do {
+ width = w;
+ s = src;
+ d = dst;
+ do {
+ load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);
+
+ d0 = convolve8_8_usdot(s0, filters, perm_tbl);
+ d1 = convolve8_8_usdot(s1, filters, perm_tbl);
+ d2 = convolve8_8_usdot(s2, filters, perm_tbl);
+ d3 = convolve8_8_usdot(s3, filters, perm_tbl);
+
+ store_u8_8x4(d, dst_stride, d0, d1, d2, d3);
+
+ s += 8;
+ d += 8;
+ width -= 8;
+ } while (width > 0);
+ src += 4 * src_stride;
+ dst += 4 * dst_stride;
+ h -= 4;
+ } while (h > 3);
+
+ /* Process final three rows (h % 4 == 3). See vpx_convolve_neon.c for
+ * further details on possible values of block height. */
+ width = w;
+ s = src;
+ d = dst;
+ do {
+ load_u8_16x3(s, src_stride, &s0, &s1, &s2);
+
+ d0 = convolve8_8_usdot(s0, filters, perm_tbl);
+ d1 = convolve8_8_usdot(s1, filters, perm_tbl);
+ d2 = convolve8_8_usdot(s2, filters, perm_tbl);
+
+ store_u8_8x3(d, dst_stride, d0, d1, d2);
+
+ s += 8;
+ d += 8;
+ width -= 8;
+ } while (width > 0);
+ }
+}
+
void vpx_convolve8_horiz_neon(const uint8_t *src, ptrdiff_t src_stride,
uint8_t *dst, ptrdiff_t dst_stride,
const InterpKernel *filter, int x0_q4,
@@ -96,7 +201,7 @@ void vpx_convolve8_horiz_neon(const uint8_t *src, ptrdiff_t src_stride,
src += 4 * src_stride;
dst += 4 * dst_stride;
h -= 4;
- } while (h > 0);
+ } while (h != 0);
} else {
const uint8x16x3_t perm_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
const uint8_t *s;
@@ -125,7 +230,7 @@ void vpx_convolve8_horiz_neon(const uint8_t *src, ptrdiff_t src_stride,
src += 4 * src_stride;
dst += 4 * dst_stride;
h -= 4;
- } while (h > 0);
+ } while (h != 0);
}
}
@@ -611,6 +716,114 @@ void vpx_convolve8_avg_vert_neon(const uint8_t *src, ptrdiff_t src_stride,
#else // !defined(__ARM_FEATURE_MATMUL_INT8)
+void vpx_convolve8_2d_horiz_neon(const uint8_t *src, ptrdiff_t src_stride,
+ uint8_t *dst, ptrdiff_t dst_stride,
+ const InterpKernel *filter, int x0_q4,
+ int x_step_q4, int y0_q4, int y_step_q4, int w,
+ int h) {
+ const int8x8_t filters = vmovn_s16(vld1q_s16(filter[x0_q4]));
+ const int16x8_t correct_tmp = vmulq_n_s16(vld1q_s16(filter[x0_q4]), 128);
+ const int32x4_t correction = vdupq_n_s32((int32_t)vaddvq_s16(correct_tmp));
+ const uint8x16_t range_limit = vdupq_n_u8(128);
+ uint8x16_t s0, s1, s2, s3;
+
+ assert((intptr_t)dst % 4 == 0);
+ assert(dst_stride % 4 == 0);
+ assert(x_step_q4 == 16);
+ assert(h % 4 == 3);
+
+ (void)x_step_q4;
+ (void)y0_q4;
+ (void)y_step_q4;
+
+ src -= 3;
+
+ if (w == 4) {
+ const uint8x16x2_t perm_tbl = vld1q_u8_x2(dot_prod_permute_tbl);
+ int16x4_t d0, d1, d2, d3;
+ uint8x8_t d01, d23;
+
+ do {
+ load_u8_16x4(src, src_stride, &s0, &s1, &s2, &s3);
+
+ d0 = convolve8_4_sdot(s0, filters, correction, range_limit, perm_tbl);
+ d1 = convolve8_4_sdot(s1, filters, correction, range_limit, perm_tbl);
+ d2 = convolve8_4_sdot(s2, filters, correction, range_limit, perm_tbl);
+ d3 = convolve8_4_sdot(s3, filters, correction, range_limit, perm_tbl);
+ d01 = vqrshrun_n_s16(vcombine_s16(d0, d1), FILTER_BITS);
+ d23 = vqrshrun_n_s16(vcombine_s16(d2, d3), FILTER_BITS);
+
+ store_u8(dst + 0 * dst_stride, dst_stride, d01);
+ store_u8(dst + 2 * dst_stride, dst_stride, d23);
+
+ src += 4 * src_stride;
+ dst += 4 * dst_stride;
+ h -= 4;
+ } while (h > 3);
+
+ /* Process final three rows (h % 4 == 3). See vpx_convolve_neon.c for
+ * further details on possible values of block height. */
+ load_u8_16x3(src, src_stride, &s0, &s1, &s2);
+
+ d0 = convolve8_4_sdot(s0, filters, correction, range_limit, perm_tbl);
+ d1 = convolve8_4_sdot(s1, filters, correction, range_limit, perm_tbl);
+ d2 = convolve8_4_sdot(s2, filters, correction, range_limit, perm_tbl);
+ d01 = vqrshrun_n_s16(vcombine_s16(d0, d1), FILTER_BITS);
+ d23 = vqrshrun_n_s16(vcombine_s16(d2, vdup_n_s16(0)), FILTER_BITS);
+
+ store_u8(dst + 0 * dst_stride, dst_stride, d01);
+ store_u8_4x1(dst + 2 * dst_stride, d23);
+ } else {
+ const uint8x16x3_t perm_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
+ const uint8_t *s;
+ uint8_t *d;
+ int width;
+ uint8x8_t d0, d1, d2, d3;
+
+ do {
+ width = w;
+ s = src;
+ d = dst;
+ do {
+ load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);
+
+ d0 = convolve8_8_sdot(s0, filters, correction, range_limit, perm_tbl);
+ d1 = convolve8_8_sdot(s1, filters, correction, range_limit, perm_tbl);
+ d2 = convolve8_8_sdot(s2, filters, correction, range_limit, perm_tbl);
+ d3 = convolve8_8_sdot(s3, filters, correction, range_limit, perm_tbl);
+
+ store_u8_8x4(d, dst_stride, d0, d1, d2, d3);
+
+ s += 8;
+ d += 8;
+ width -= 8;
+ } while (width != 0);
+ src += 4 * src_stride;
+ dst += 4 * dst_stride;
+ h -= 4;
+ } while (h > 3);
+
+ /* Process final three rows (h % 4 == 3). See vpx_convolve_neon.c for
+ * further details on possible values of block height. */
+ width = w;
+ s = src;
+ d = dst;
+ do {
+ load_u8_16x3(s, src_stride, &s0, &s1, &s2);
+
+ d0 = convolve8_8_sdot(s0, filters, correction, range_limit, perm_tbl);
+ d1 = convolve8_8_sdot(s1, filters, correction, range_limit, perm_tbl);
+ d2 = convolve8_8_sdot(s2, filters, correction, range_limit, perm_tbl);
+
+ store_u8_8x3(d, dst_stride, d0, d1, d2);
+
+ s += 8;
+ d += 8;
+ width -= 8;
+ } while (width != 0);
+ }
+}
+
void vpx_convolve8_horiz_neon(const uint8_t *src, ptrdiff_t src_stride,
uint8_t *dst, ptrdiff_t dst_stride,
const InterpKernel *filter, int x0_q4,
@@ -653,7 +866,7 @@ void vpx_convolve8_horiz_neon(const uint8_t *src, ptrdiff_t src_stride,
src += 4 * src_stride;
dst += 4 * dst_stride;
h -= 4;
- } while (h > 0);
+ } while (h != 0);
} else {
const uint8x16x3_t perm_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
const uint8_t *s;
@@ -682,7 +895,7 @@ void vpx_convolve8_horiz_neon(const uint8_t *src, ptrdiff_t src_stride,
src += 4 * src_stride;
dst += 4 * dst_stride;
h -= 4;
- } while (h > 0);
+ } while (h != 0);
}
}
diff --git a/vpx_dsp/arm/vpx_convolve8_neon.h b/vpx_dsp/arm/vpx_convolve8_neon.h
index c838d4047..2f78583af 100644
--- a/vpx_dsp/arm/vpx_convolve8_neon.h
+++ b/vpx_dsp/arm/vpx_convolve8_neon.h
@@ -17,6 +17,15 @@
#include "./vpx_dsp_rtcd.h"
#include "vpx_dsp/vpx_filter.h"
+#if VPX_ARCH_AARCH64 && \
+ (defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8))
+void vpx_convolve8_2d_horiz_neon(const uint8_t *src, ptrdiff_t src_stride,
+ uint8_t *dst, ptrdiff_t dst_stride,
+ const InterpKernel *filter, int x0_q4,
+ int x_step_q4, int y0_q4, int y_step_q4, int w,
+ int h);
+#endif
+
#if VPX_ARCH_AARCH64 && defined(__ARM_FEATURE_DOTPROD)
static INLINE int16x4_t convolve8_4_sdot_partial(const int8x16_t samples_lo,
diff --git a/vpx_dsp/arm/vpx_convolve_neon.c b/vpx_dsp/arm/vpx_convolve_neon.c
index 830f3176d..f7db3e6a9 100644
--- a/vpx_dsp/arm/vpx_convolve_neon.c
+++ b/vpx_dsp/arm/vpx_convolve_neon.c
@@ -14,6 +14,57 @@
#include "vpx_dsp/vpx_dsp_common.h"
#include "vpx_ports/mem.h"
+#if VPX_ARCH_AARCH64 && \
+ (defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8))
+#include "vpx_dsp/arm/vpx_convolve8_neon.h"
+
+void vpx_convolve8_neon(const uint8_t *src, ptrdiff_t src_stride, uint8_t *dst,
+ ptrdiff_t dst_stride, const InterpKernel *filter,
+ int x0_q4, int x_step_q4, int y0_q4, int y_step_q4,
+ int w, int h) {
+ /* Given our constraints: w <= 64, h <= 64, taps == 8 we can reduce the
+ * maximum buffer size to 64 * (64 + 7). */
+ uint8_t temp[64 * 71];
+
+ /* Account for the vertical phase needing 3 lines prior and 4 lines post. */
+ const int intermediate_height = h + 7;
+
+ assert(y_step_q4 == 16);
+ assert(x_step_q4 == 16);
+
+ /* Filter starting 3 lines back. */
+ vpx_convolve8_2d_horiz_neon(src - src_stride * 3, src_stride, temp, w, filter,
+ x0_q4, x_step_q4, y0_q4, y_step_q4, w,
+ intermediate_height);
+
+ /* Step into the temp buffer 3 lines to get the actual frame data */
+ vpx_convolve8_vert_neon(temp + w * 3, w, dst, dst_stride, filter, x0_q4,
+ x_step_q4, y0_q4, y_step_q4, w, h);
+}
+
+void vpx_convolve8_avg_neon(const uint8_t *src, ptrdiff_t src_stride,
+ uint8_t *dst, ptrdiff_t dst_stride,
+ const InterpKernel *filter, int x0_q4,
+ int x_step_q4, int y0_q4, int y_step_q4, int w,
+ int h) {
+ uint8_t temp[64 * 71];
+ const int intermediate_height = h + 7;
+
+ assert(y_step_q4 == 16);
+ assert(x_step_q4 == 16);
+
+ vpx_convolve8_2d_horiz_neon(src - src_stride * 3, src_stride, temp, w, filter,
+ x0_q4, x_step_q4, y0_q4, y_step_q4, w,
+ intermediate_height);
+
+ vpx_convolve8_avg_vert_neon(temp + w * 3, w, dst, dst_stride, filter, x0_q4,
+ x_step_q4, y0_q4, y_step_q4, w, h);
+}
+
+#else // !(VPX_ARCH_AARCH64 &&
+ // (defined(__ARM_FEATURE_DOTPROD) ||
+ // defined(__ARM_FEATURE_MATMUL_INT8)))
+
void vpx_convolve8_neon(const uint8_t *src, ptrdiff_t src_stride, uint8_t *dst,
ptrdiff_t dst_stride, const InterpKernel *filter,
int x0_q4, int x_step_q4, int y0_q4, int y_step_q4,
@@ -63,3 +114,7 @@ void vpx_convolve8_avg_neon(const uint8_t *src, ptrdiff_t src_stride,
vpx_convolve8_avg_vert_neon(temp + w * 3, w, dst, dst_stride, filter, x0_q4,
x_step_q4, y0_q4, y_step_q4, w, h);
}
+
+#endif // #if VPX_ARCH_AARCH64 &&
+ // (defined(__ARM_FEATURE_DOTPROD) ||
+ // defined(__ARM_FEATURE_MATMUL_INT8))