Rework 32-bit SSE loads yet again.

The existing code in vec_avx.h produced
  warning: dereferencing type-punned pointer will break
   strict-aliasing rules
 with gcc 6.4.0.
We already had a macro to work around this within the rules of the
 C standard, but trying to use that here does not get optimized
 into a single MOVD like we were hoping.
Replacing it with memcpy() instead does get optimized correctly,
 but requires switching from a macro to an inline function in order
 to be able to declare a local variable and return a value.
We already have such an inline function in NSQ_del_dec_avx2.c, so
 hoist that out and use it everywhere, and then convert vec_avx.h
 to use it also.
This commit is contained in:
Timothy B. Terriberry 2024-02-22 05:48:18 -08:00 committed by Jean-Marc Valin
parent 1186fb8ea4
commit 59dc75fa97
No known key found for this signature in database
GPG key ID: 5E5DD9A36F9189C8
3 changed files with 41 additions and 32 deletions

View file

@ -60,18 +60,33 @@
int opus_select_arch(void); int opus_select_arch(void);
# endif # endif
# if defined(OPUS_X86_MAY_HAVE_SSE2)
# include "opus_defines.h"
/*MOVD should not impose any alignment restrictions, but the C standard does, /*MOVD should not impose any alignment restrictions, but the C standard does,
and UBSan will report errors if we actually make unaligned accesses. and UBSan will report errors if we actually make unaligned accesses.
Use this to work around those restrictions (which should hopefully all get Use this to work around those restrictions (which should hopefully all get
optimized to a single MOVD instruction).*/ optimized to a single MOVD instruction).
#define OP_LOADU_EPI32(x) \ GCC implemented _mm_loadu_si32() since GCC 11; HOWEVER, there is a bug!
(int)((*(unsigned char *)(x) | *((unsigned char *)(x) + 1) << 8U |\ https://gcc.gnu.org/bugzilla/show_bug.cgi?id=99754 */
*((unsigned char *)(x) + 2) << 16U | (opus_uint32)*((unsigned char *)(x) + 3) << 24U)) # if !defined(_MSC_VER) && !OPUS_GNUC_PREREQ(11,3) && !(defined(__clang__) && (__clang_major__ >= 8))
# include <string.h>
# include <emmintrin.h>
#define OP_CVTEPI8_EPI32_M32(x) \ # define _mm_loadu_si32 WORKAROUND_mm_loadu_si32
(_mm_cvtepi8_epi32(_mm_cvtsi32_si128(OP_LOADU_EPI32(x)))) static inline __m128i WORKAROUND_mm_loadu_si32(void const* mem_addr) {
int val;
memcpy(&val, mem_addr, sizeof(val));
return _mm_cvtsi32_si128(val);
}
# endif
#define OP_CVTEPI16_EPI32_M64(x) \ # define OP_CVTEPI8_EPI32_M32(x) \
(_mm_cvtepi8_epi32(_mm_loadu_si32(x)))
# define OP_CVTEPI16_EPI32_M64(x) \
(_mm_cvtepi16_epi32(_mm_loadl_epi64((__m128i *)(void*)(x)))) (_mm_cvtepi16_epi32(_mm_loadl_epi64((__m128i *)(void*)(x))))
# endif
#endif #endif

View file

@ -34,7 +34,7 @@
#include <immintrin.h> #include <immintrin.h>
#include <math.h> #include <math.h>
#include "celt/x86/x86cpu.h"
#define MAX_INPUTS (2048) #define MAX_INPUTS (2048)
@ -196,6 +196,14 @@ static inline void mm256_storeu_si256(mm256i_emu *dst, mm256i_emu src) {
#define _mm256_storeu_si256(dst, src) mm256_storeu_si256(dst, src) #define _mm256_storeu_si256(dst, src) mm256_storeu_si256(dst, src)
/* Emulated counterpart of _mm256_broadcastd_epi32() for the mm256i_emu type.
   _mm_shuffle_epi32(x, 0) replicates the lowest 32-bit element of x into all
   four 32-bit elements of a 128-bit vector; that vector is then stored in both
   the lo and hi members of the result. */
static inline mm256i_emu mm256_broadcastd_epi32(__m128i x) {
mm256i_emu ret;
/* Same broadcast value goes into both members. */
ret.hi = ret.lo = _mm_shuffle_epi32(x, 0);
return ret;
}
/* Route uses of the AVX2 intrinsic name to the emulation above. */
#define _mm256_broadcastd_epi32(x) mm256_broadcastd_epi32(x)
static inline mm256i_emu mm256_set1_epi32(int x) { static inline mm256i_emu mm256_set1_epi32(int x) {
mm256i_emu ret; mm256i_emu ret;
ret.lo = _mm_set1_epi32(x); ret.lo = _mm_set1_epi32(x);
@ -786,19 +794,19 @@ static inline void sparse_cgemv8x4(float *_out, const opus_int8 *w, const int *i
{ {
__m256i vxj; __m256i vxj;
__m256i vw; __m256i vw;
vxj = _mm256_set1_epi32(*(int*)(void*)&x[*idx++]); vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[*idx++]));
vw = _mm256_loadu_si256((const __m256i *)(void*)w); vw = _mm256_loadu_si256((const __m256i *)(void*)w);
vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw); vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
w += 32; w += 32;
vxj = _mm256_set1_epi32(*(int*)(void*)&x[*idx++]); vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[*idx++]));
vw = _mm256_loadu_si256((const __m256i *)(void*)w); vw = _mm256_loadu_si256((const __m256i *)(void*)w);
vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw); vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
w += 32; w += 32;
vxj = _mm256_set1_epi32(*(int*)(void*)&x[*idx++]); vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[*idx++]));
vw = _mm256_loadu_si256((const __m256i *)(void*)w); vw = _mm256_loadu_si256((const __m256i *)(void*)w);
vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw); vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
w += 32; w += 32;
vxj = _mm256_set1_epi32(*(int*)(void*)&x[*idx++]); vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[*idx++]));
vw = _mm256_loadu_si256((const __m256i *)(void*)w); vw = _mm256_loadu_si256((const __m256i *)(void*)w);
vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw); vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
w += 32; w += 32;
@ -808,9 +816,7 @@ static inline void sparse_cgemv8x4(float *_out, const opus_int8 *w, const int *i
{ {
__m256i vxj; __m256i vxj;
__m256i vw; __m256i vw;
int pos; vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[*idx++]));
pos = (*idx++);
vxj = _mm256_set1_epi32(*(int*)(void*)&x[pos]);
vw = _mm256_loadu_si256((const __m256i *)(void*)w); vw = _mm256_loadu_si256((const __m256i *)(void*)w);
vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw); vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
w += 32; w += 32;
@ -837,19 +843,19 @@ static inline void cgemv8x4(float *_out, const opus_int8 *w, const float *scale,
{ {
__m256i vxj; __m256i vxj;
__m256i vw; __m256i vw;
vxj = _mm256_set1_epi32(*(int*)(void*)&x[j]); vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[j]));
vw = _mm256_loadu_si256((const __m256i *)(void*)w); vw = _mm256_loadu_si256((const __m256i *)(void*)w);
vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw); vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
w += 32; w += 32;
vxj = _mm256_set1_epi32(*(int*)(void*)&x[j+4]); vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[j+4]));
vw = _mm256_loadu_si256((const __m256i *)(void*)w); vw = _mm256_loadu_si256((const __m256i *)(void*)w);
vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw); vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
w += 32; w += 32;
vxj = _mm256_set1_epi32(*(int*)(void*)&x[j+8]); vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[j+8]));
vw = _mm256_loadu_si256((const __m256i *)(void*)w); vw = _mm256_loadu_si256((const __m256i *)(void*)w);
vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw); vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
w += 32; w += 32;
vxj = _mm256_set1_epi32(*(int*)(void*)&x[j+12]); vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[j+12]));
vw = _mm256_loadu_si256((const __m256i *)(void*)w); vw = _mm256_loadu_si256((const __m256i *)(void*)w);
vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw); vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
w += 32; w += 32;
@ -859,7 +865,7 @@ static inline void cgemv8x4(float *_out, const opus_int8 *w, const float *scale,
{ {
__m256i vxj; __m256i vxj;
__m256i vw; __m256i vw;
vxj = _mm256_set1_epi32(*(int*)(void*)&x[j]); vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[j]));
vw = _mm256_loadu_si256((const __m256i *)(void*)w); vw = _mm256_loadu_si256((const __m256i *)(void*)w);
vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw); vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
w += 32; w += 32;

View file

@ -86,18 +86,6 @@ static inline int __builtin_ctz(unsigned int x)
} }
#endif #endif
/*
* GCC implemented _mm_loadu_si32() since GCC 11; HOWEVER, there is a bug!
* https://gcc.gnu.org/bugzilla/show_bug.cgi?id=99754
*/
#if !defined(_MSC_VER) && !OPUS_GNUC_PREREQ(11,3) && !(defined(__clang__) && (__clang_major__ >= 8))
#define _mm_loadu_si32 WORKAROUND_mm_loadu_si32
/*
 * Stand-in for _mm_loadu_si32(): reads a 32-bit value from mem_addr with
 * OP_LOADU_EPI32() and passes it to _mm_cvtsi32_si128() to produce the
 * returned __m128i.
 */
static inline __m128i WORKAROUND_mm_loadu_si32(void const* mem_addr)
{
return _mm_cvtsi32_si128(OP_LOADU_EPI32(mem_addr));
}
#endif
static OPUS_INLINE __m128i silk_cvtepi64_epi32_high(__m256i num) static OPUS_INLINE __m128i silk_cvtepi64_epi32_high(__m256i num)
{ {
return _mm256_castsi256_si128(_mm256_permutevar8x32_epi32(num, _mm256_set_epi32(0, 0, 0, 0, 7, 5, 3, 1))); return _mm256_castsi256_si128(_mm256_permutevar8x32_epi32(num, _mm256_set_epi32(0, 0, 0, 0, 7, 5, 3, 1)));