author | Yunqing Wang <yunqingwang@google.com> | 2023-05-09 15:57:09 +0000
---|---|---
committer | Gerrit Code Review <noreply-gerritcodereview@google.com> | 2023-05-09 15:57:09 +0000
commit | cc1b3886f25af47aee03117380451037ad2236a0 (patch) |
tree | 8e39d2e6d689e33e73531c2294c65f72c4133d28 |
parent | 19ec57e14938bcb12d87123b7c369212f19792eb (diff) |
parent | 457b7f59860955415a23c20c535fc13fde51936f (diff) |
download | libvpx-cc1b3886f25af47aee03117380451037ad2236a0.tar.gz |
Merge "Add AVX2 intrinsic for vpx_comp_avg_pred() function" into main
-rw-r--r-- | test/comp_avg_pred_test.cc | 21
-rw-r--r-- | vp9/encoder/vp9_mcomp.c | 8
-rw-r--r-- | vp9/encoder/vp9_rdopt.c | 4
-rw-r--r-- | vpx_dsp/sad.c | 2
-rw-r--r-- | vpx_dsp/variance.c | 2
-rw-r--r-- | vpx_dsp/vpx_dsp.mk | 1
-rw-r--r-- | vpx_dsp/vpx_dsp_rtcd_defs.pl | 2
-rw-r--r-- | vpx_dsp/x86/avg_pred_avx2.c | 111
8 files changed, 134 insertions(+), 17 deletions(-)
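For orientation before the per-file diff: `vpx_comp_avg_pred(comp_pred, pred, width, height, ref, ref_stride)` forms the compound prediction consumed by the averaging SAD/variance paths touched below. The following is a minimal scalar sketch of its contract, mirroring the C reference `vpx_comp_avg_pred_c` (defined in `vpx_dsp/variance.c`) rather than quoting it verbatim:

```c
#include <stdint.h>

// Sketch of the scalar behavior every specialized version of
// vpx_comp_avg_pred() must match bit-exactly: each output pixel is the
// rounded average (pred + ref + 1) >> 1 -- the same per-byte operation
// _mm256_avg_epu8 performs 32 at a time in the new AVX2 kernel.
static void comp_avg_pred_scalar(uint8_t *comp_pred, const uint8_t *pred,
                                 int width, int height, const uint8_t *ref,
                                 int ref_stride) {
  int i, j;
  for (i = 0; i < height; ++i) {
    for (j = 0; j < width; ++j) {
      comp_pred[j] = (uint8_t)((pred[j] + ref[j] + 1) >> 1);
    }
    // comp_pred and pred are packed with stride == width; only ref may
    // carry an independent stride, as the test comment below notes.
    comp_pred += width;
    pred += width;
    ref += ref_stride;
  }
}
```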
```diff
diff --git a/test/comp_avg_pred_test.cc b/test/comp_avg_pred_test.cc
index f747c3524..d8fabd5be 100644
--- a/test/comp_avg_pred_test.cc
+++ b/test/comp_avg_pred_test.cc
@@ -81,11 +81,11 @@ void AvgPredTest<bitdepth, Pixel>::TestSizeCombinations() {
         // Only the reference buffer may have a stride not equal to width.
         Buffer<Pixel> ref = Buffer<Pixel>(width, height, ref_padding ? 8 : 0);
         ASSERT_TRUE(ref.Init());
-        Buffer<Pixel> pred = Buffer<Pixel>(width, height, 0, 16);
+        Buffer<Pixel> pred = Buffer<Pixel>(width, height, 0, 32);
         ASSERT_TRUE(pred.Init());
-        Buffer<Pixel> avg_ref = Buffer<Pixel>(width, height, 0, 16);
+        Buffer<Pixel> avg_ref = Buffer<Pixel>(width, height, 0, 32);
         ASSERT_TRUE(avg_ref.Init());
-        Buffer<Pixel> avg_chk = Buffer<Pixel>(width, height, 0, 16);
+        Buffer<Pixel> avg_chk = Buffer<Pixel>(width, height, 0, 32);
         ASSERT_TRUE(avg_chk.Init());
         const int bitdepth_mask = (1 << bitdepth) - 1;
         for (int h = 0; h < height; ++h) {
@@ -121,11 +121,11 @@ void AvgPredTest<bitdepth, Pixel>::TestCompareReferenceRandom() {
   const int height = 32;
   Buffer<Pixel> ref = Buffer<Pixel>(width, height, 8);
   ASSERT_TRUE(ref.Init());
-  Buffer<Pixel> pred = Buffer<Pixel>(width, height, 0, 16);
+  Buffer<Pixel> pred = Buffer<Pixel>(width, height, 0, 32);
   ASSERT_TRUE(pred.Init());
-  Buffer<Pixel> avg_ref = Buffer<Pixel>(width, height, 0, 16);
+  Buffer<Pixel> avg_ref = Buffer<Pixel>(width, height, 0, 32);
   ASSERT_TRUE(avg_ref.Init());
-  Buffer<Pixel> avg_chk = Buffer<Pixel>(width, height, 0, 16);
+  Buffer<Pixel> avg_chk = Buffer<Pixel>(width, height, 0, 32);
   ASSERT_TRUE(avg_chk.Init());
 
   for (int i = 0; i < 500; ++i) {
@@ -167,9 +167,9 @@ void AvgPredTest<bitdepth, Pixel>::TestSpeed() {
       const int height = 1 << height_pow;
      Buffer<Pixel> ref = Buffer<Pixel>(width, height, ref_padding ? 8 : 0);
       ASSERT_TRUE(ref.Init());
-      Buffer<Pixel> pred = Buffer<Pixel>(width, height, 0, 16);
+      Buffer<Pixel> pred = Buffer<Pixel>(width, height, 0, 32);
       ASSERT_TRUE(pred.Init());
-      Buffer<Pixel> avg = Buffer<Pixel>(width, height, 0, 16);
+      Buffer<Pixel> avg = Buffer<Pixel>(width, height, 0, 32);
       ASSERT_TRUE(avg.Init());
       const int bitdepth_mask = (1 << bitdepth) - 1;
       for (int h = 0; h < height; ++h) {
@@ -217,6 +217,11 @@ INSTANTIATE_TEST_SUITE_P(SSE2, AvgPredTestLBD,
                          ::testing::Values(&vpx_comp_avg_pred_sse2));
 #endif  // HAVE_SSE2
 
+#if HAVE_AVX2
+INSTANTIATE_TEST_SUITE_P(AVX2, AvgPredTestLBD,
+                         ::testing::Values(&vpx_comp_avg_pred_avx2));
+#endif  // HAVE_AVX2
+
 #if HAVE_NEON
 INSTANTIATE_TEST_SUITE_P(NEON, AvgPredTestLBD,
                          ::testing::Values(&vpx_comp_avg_pred_neon));
diff --git a/vp9/encoder/vp9_mcomp.c b/vp9/encoder/vp9_mcomp.c
index 64e9ef0f9..0ea0f85e4 100644
--- a/vp9/encoder/vp9_mcomp.c
+++ b/vp9/encoder/vp9_mcomp.c
@@ -297,7 +297,7 @@ static unsigned int setup_center_error(
     besterr =
         vfp->vf(CONVERT_TO_BYTEPTR(comp_pred16), w, src, src_stride, sse1);
   } else {
-    DECLARE_ALIGNED(16, uint8_t, comp_pred[64 * 64]);
+    DECLARE_ALIGNED(32, uint8_t, comp_pred[64 * 64]);
     vpx_comp_avg_pred(comp_pred, second_pred, w, h, y + offset, y_stride);
     besterr = vfp->vf(comp_pred, w, src, src_stride, sse1);
   }
@@ -312,7 +312,7 @@ static unsigned int setup_center_error(
   uint32_t besterr;
   (void)xd;
   if (second_pred != NULL) {
-    DECLARE_ALIGNED(16, uint8_t, comp_pred[64 * 64]);
+    DECLARE_ALIGNED(32, uint8_t, comp_pred[64 * 64]);
     vpx_comp_avg_pred(comp_pred, second_pred, w, h, y + offset, y_stride);
     besterr = vfp->vf(comp_pred, w, src, src_stride, sse1);
   } else {
@@ -635,7 +635,7 @@ static int accurate_sub_pel_search(
   vp9_build_inter_predictor(pre_address, y_stride, pred, w, this_mv, sf, w, h,
                             0, kernel, MV_PRECISION_Q3, 0, 0);
   if (second_pred != NULL) {
-    DECLARE_ALIGNED(16, uint8_t, comp_pred[64 * 64]);
+    DECLARE_ALIGNED(32, uint8_t, comp_pred[64 * 64]);
     vpx_comp_avg_pred(comp_pred, second_pred, w, h, pred, w);
     besterr = vfp->vf(comp_pred, w, src_address, src_stride, sse);
   } else {
@@ -654,7 +654,7 @@ static int accurate_sub_pel_search(
   vp9_build_inter_predictor(pre_address, y_stride, pred, w, this_mv, sf, w, h,
                             0, kernel, MV_PRECISION_Q3, 0, 0);
   if (second_pred != NULL) {
-    DECLARE_ALIGNED(16, uint8_t, comp_pred[64 * 64]);
+    DECLARE_ALIGNED(32, uint8_t, comp_pred[64 * 64]);
     vpx_comp_avg_pred(comp_pred, second_pred, w, h, pred, w);
     besterr = vfp->vf(comp_pred, w, src_address, src_stride, sse);
   } else {
diff --git a/vp9/encoder/vp9_rdopt.c b/vp9/encoder/vp9_rdopt.c
index f051c6279..464705a67 100644
--- a/vp9/encoder/vp9_rdopt.c
+++ b/vp9/encoder/vp9_rdopt.c
@@ -1937,10 +1937,10 @@ static void joint_motion_search(VP9_COMP *cpi, MACROBLOCK *x, BLOCK_SIZE bsize,
   // Prediction buffer from second frame.
 #if CONFIG_VP9_HIGHBITDEPTH
-  DECLARE_ALIGNED(16, uint16_t, second_pred_alloc_16[64 * 64]);
+  DECLARE_ALIGNED(32, uint16_t, second_pred_alloc_16[64 * 64]);
   uint8_t *second_pred;
 #else
-  DECLARE_ALIGNED(16, uint8_t, second_pred[64 * 64]);
+  DECLARE_ALIGNED(32, uint8_t, second_pred[64 * 64]);
 #endif  // CONFIG_VP9_HIGHBITDEPTH
 
   // Check number of iterations do not exceed the max
diff --git a/vpx_dsp/sad.c b/vpx_dsp/sad.c
index 619d7aa95..2a4c81d58 100644
--- a/vpx_dsp/sad.c
+++ b/vpx_dsp/sad.c
@@ -40,7 +40,7 @@ static INLINE unsigned int sad(const uint8_t *src_ptr, int src_stride,
   unsigned int vpx_sad##m##x##n##_avg_c(                                    \
       const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,      \
       int ref_stride, const uint8_t *second_pred) {                        \
-    DECLARE_ALIGNED(16, uint8_t, comp_pred[m * n]);                        \
+    DECLARE_ALIGNED(32, uint8_t, comp_pred[m * n]);                        \
     vpx_comp_avg_pred_c(comp_pred, second_pred, m, n, ref_ptr, ref_stride); \
     return sad(src_ptr, src_stride, comp_pred, m, m, n);                   \
   }                                                                        \
diff --git a/vpx_dsp/variance.c b/vpx_dsp/variance.c
index ce1e8382b..a6793efb6 100644
--- a/vpx_dsp/variance.c
+++ b/vpx_dsp/variance.c
@@ -156,7 +156,7 @@ static void var_filter_block2d_bil_second_pass(
       const uint8_t *second_pred) {                                          \
     uint16_t fdata3[(H + 1) * W];                                            \
     uint8_t temp2[H * W];                                                    \
-    DECLARE_ALIGNED(16, uint8_t, temp3[H * W]);                              \
+    DECLARE_ALIGNED(32, uint8_t, temp3[H * W]);                              \
                                                                              \
     var_filter_block2d_bil_first_pass(src_ptr, fdata3, src_stride, 1, H + 1, \
                                       W, bilinear_filters[x_offset]);        \
diff --git a/vpx_dsp/vpx_dsp.mk b/vpx_dsp/vpx_dsp.mk
index 67d3fb0e2..04969f37e 100644
--- a/vpx_dsp/vpx_dsp.mk
+++ b/vpx_dsp/vpx_dsp.mk
@@ -424,6 +424,7 @@ DSP_SRCS-$(HAVE_LSX) += loongarch/avg_pred_lsx.c
 DSP_SRCS-$(HAVE_MMI) += mips/variance_mmi.c
 
 DSP_SRCS-$(HAVE_SSE2) += x86/avg_pred_sse2.c
+DSP_SRCS-$(HAVE_AVX2) += x86/avg_pred_avx2.c
 DSP_SRCS-$(HAVE_SSE2) += x86/variance_sse2.c  # Contains SSE2 and SSSE3
 DSP_SRCS-$(HAVE_AVX2) += x86/variance_avx2.c
 DSP_SRCS-$(HAVE_VSX) += ppc/variance_vsx.c
diff --git a/vpx_dsp/vpx_dsp_rtcd_defs.pl b/vpx_dsp/vpx_dsp_rtcd_defs.pl
index cae4ca811..f20f4e045 100644
--- a/vpx_dsp/vpx_dsp_rtcd_defs.pl
+++ b/vpx_dsp/vpx_dsp_rtcd_defs.pl
@@ -1321,7 +1321,7 @@ add_proto qw/unsigned int vpx_get4x4sse_cs/, "const unsigned char *src_ptr, int
 specialize qw/vpx_get4x4sse_cs neon msa vsx/;
 
 add_proto qw/void vpx_comp_avg_pred/, "uint8_t *comp_pred, const uint8_t *pred, int width, int height, const uint8_t *ref, int ref_stride";
-specialize qw/vpx_comp_avg_pred neon sse2 vsx lsx/;
+specialize qw/vpx_comp_avg_pred neon sse2 avx2 vsx lsx/;
 
 #
 # Subpixel Variance
diff --git a/vpx_dsp/x86/avg_pred_avx2.c b/vpx_dsp/x86/avg_pred_avx2.c
new file mode 100644
index 000000000..f4357998c
--- /dev/null
+++ b/vpx_dsp/x86/avg_pred_avx2.c
@@ -0,0 +1,111 @@
+/*
+ *  Copyright (c) 2023 The WebM project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include <assert.h>
+#include <immintrin.h>
+
+#include "./vpx_dsp_rtcd.h"
+
+void vpx_comp_avg_pred_avx2(uint8_t *comp_pred, const uint8_t *pred, int width,
+                            int height, const uint8_t *ref, int ref_stride) {
+  int row = 0;
+  // comp_pred and pred must be 32 byte aligned.
+  assert(((intptr_t)comp_pred % 32) == 0);
+  assert(((intptr_t)pred % 32) == 0);
+
+  if (width == 8) {
+    assert(height % 4 == 0);
+    do {
+      const __m256i p = _mm256_load_si256((const __m256i *)pred);
+      const __m128i r_0 = _mm_loadl_epi64((const __m128i *)ref);
+      const __m128i r_1 =
+          _mm_loadl_epi64((const __m128i *)(ref + 2 * ref_stride));
+
+      const __m128i r1 = _mm_castps_si128(_mm_loadh_pi(
+          _mm_castsi128_ps(r_0), (const __m64 *)(ref + ref_stride)));
+      const __m128i r2 = _mm_castps_si128(_mm_loadh_pi(
+          _mm_castsi128_ps(r_1), (const __m64 *)(ref + 3 * ref_stride)));
+
+      const __m256i ref_0123 =
+          _mm256_inserti128_si256(_mm256_castsi128_si256(r1), r2, 1);
+      const __m256i avg = _mm256_avg_epu8(p, ref_0123);
+
+      _mm256_store_si256((__m256i *)comp_pred, avg);
+
+      row += 4;
+      pred += 32;
+      comp_pred += 32;
+      ref += 4 * ref_stride;
+    } while (row < height);
+  } else if (width == 16) {
+    assert(height % 4 == 0);
+    do {
+      const __m256i pred_0 = _mm256_load_si256((const __m256i *)pred);
+      const __m256i pred_1 = _mm256_load_si256((const __m256i *)(pred + 32));
+      const __m256i tmp0 =
+          _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)ref));
+      const __m256i ref_0 = _mm256_inserti128_si256(
+          tmp0, _mm_loadu_si128((const __m128i *)(ref + ref_stride)), 1);
+      const __m256i tmp1 = _mm256_castsi128_si256(
+          _mm_loadu_si128((const __m128i *)(ref + 2 * ref_stride)));
+      const __m256i ref_1 = _mm256_inserti128_si256(
+          tmp1, _mm_loadu_si128((const __m128i *)(ref + 3 * ref_stride)), 1);
+      const __m256i average_0 = _mm256_avg_epu8(pred_0, ref_0);
+      const __m256i average_1 = _mm256_avg_epu8(pred_1, ref_1);
+      _mm256_store_si256((__m256i *)comp_pred, average_0);
+      _mm256_store_si256((__m256i *)(comp_pred + 32), average_1);
+
+      row += 4;
+      pred += 64;
+      comp_pred += 64;
+      ref += 4 * ref_stride;
+    } while (row < height);
+  } else if (width == 32) {
+    assert(height % 2 == 0);
+    do {
+      const __m256i pred_0 = _mm256_load_si256((const __m256i *)pred);
+      const __m256i pred_1 = _mm256_load_si256((const __m256i *)(pred + 32));
+      const __m256i ref_0 = _mm256_loadu_si256((const __m256i *)ref);
+      const __m256i ref_1 =
+          _mm256_loadu_si256((const __m256i *)(ref + ref_stride));
+      const __m256i average_0 = _mm256_avg_epu8(pred_0, ref_0);
+      const __m256i average_1 = _mm256_avg_epu8(pred_1, ref_1);
+      _mm256_store_si256((__m256i *)comp_pred, average_0);
+      _mm256_store_si256((__m256i *)(comp_pred + 32), average_1);
+
+      row += 2;
+      pred += 64;
+      comp_pred += 64;
+      ref += 2 * ref_stride;
+    } while (row < height);
+  } else if (width % 64 == 0) {
+    do {
+      int x;
+      for (x = 0; x < width; x += 64) {
+        const __m256i pred_0 = _mm256_load_si256((const __m256i *)(pred + x));
+        const __m256i pred_1 =
+            _mm256_load_si256((const __m256i *)(pred + x + 32));
+        const __m256i ref_0 = _mm256_loadu_si256((const __m256i *)(ref + x));
+        const __m256i ref_1 =
+            _mm256_loadu_si256((const __m256i *)(ref + x + 32));
+        const __m256i average_0 = _mm256_avg_epu8(pred_0, ref_0);
+        const __m256i average_1 = _mm256_avg_epu8(pred_1, ref_1);
+        _mm256_store_si256((__m256i *)(comp_pred + x), average_0);
+        _mm256_store_si256((__m256i *)(comp_pred + x + 32), average_1);
+      }
+      row++;
+      pred += width;
+      comp_pred += width;
+      ref += ref_stride;
+    } while (row < height);
+  } else {
+    vpx_comp_avg_pred_sse2(comp_pred, pred, width, height, ref, ref_stride);
+  }
+}
```
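One detail worth calling out: the new kernel uses aligned 256-bit loads and stores (`_mm256_load_si256` / `_mm256_store_si256`) on `pred` and `comp_pred` and asserts 32-byte alignment, which is why every caller-side `DECLARE_ALIGNED(16, ...)` buffer in this change becomes `DECLARE_ALIGNED(32, ...)` and the test buffers move from 16- to 32-byte alignment. A caller-side sketch under those assumptions (the helper name is hypothetical; `DECLARE_ALIGNED` is libvpx's alignment macro from `vpx_ports/mem.h`, and `vpx_comp_avg_pred` is the RTCD-dispatched pointer that resolves to the AVX2/SSE2/C version once RTCD is initialized):

```c
#include <stdint.h>

#include "./vpx_dsp_rtcd.h"  // RTCD dispatch table (generated at build time)
#include "vpx_ports/mem.h"   // DECLARE_ALIGNED

// Hypothetical caller: both comp_pred and the pred/second_pred argument must
// be 32-byte aligned so the AVX2 asserts and aligned loads/stores hold; only
// ref may be unaligned and arbitrarily strided. 64 * 64 covers the largest
// block size the vp9 encoder routes through this path.
static void build_compound_pred(const uint8_t *second_pred /* 32B aligned */,
                                int w, int h, const uint8_t *ref,
                                int ref_stride) {
  DECLARE_ALIGNED(32, uint8_t, comp_pred[64 * 64]);
  vpx_comp_avg_pred(comp_pred, second_pred, w, h, ref, ref_stride);
  // comp_pred now holds the rounded average; typically it feeds a SAD or
  // variance kernel next, as in the vp9_mcomp.c hunks above.
}
```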