Rework 32-bit SSE loads yet again.

The existing code in vec_avx.h produced
  warning: dereferencing type-punned pointer will break
   strict-aliasing rules
 with gcc 6.4.0.
We already had a macro to work around this within the rules of the
 C standard, but trying to use that here does not get optimized
 into a single MOVD like we were hoping.
Replacing it with memcpy() instead does get optimized correctly,
 but requires switching from a macro to an inline function in order
 to be able to declare a local variable and return a value.
We already have such an inline function in NSQ_del_dec_avx2.c, so
 hoist that out and use it everywhere, and then convert vec_avx.h
 to use it also.
This commit is contained in:
Timothy B. Terriberry 2024-02-22 05:48:18 -08:00 committed by Jean-Marc Valin
parent 1186fb8ea4
commit 59dc75fa97
No known key found for this signature in database
GPG key ID: 5E5DD9A36F9189C8
3 changed files with 41 additions and 32 deletions

View file

@ -34,7 +34,7 @@
#include <immintrin.h>
#include <math.h>
#include "celt/x86/x86cpu.h"
#define MAX_INPUTS (2048)
@ -196,6 +196,14 @@ static inline void mm256_storeu_si256(mm256i_emu *dst, mm256i_emu src) {
#define _mm256_storeu_si256(dst, src) mm256_storeu_si256(dst, src)
static inline mm256i_emu mm256_broadcastd_epi32(__m128i x) {
mm256i_emu ret;
ret.hi = ret.lo = _mm_shuffle_epi32(x, 0);
return ret;
}
#define _mm256_broadcastd_epi32(x) mm256_broadcastd_epi32(x)
static inline mm256i_emu mm256_set1_epi32(int x) {
mm256i_emu ret;
ret.lo = _mm_set1_epi32(x);
@ -786,19 +794,19 @@ static inline void sparse_cgemv8x4(float *_out, const opus_int8 *w, const int *i
{
__m256i vxj;
__m256i vw;
vxj = _mm256_set1_epi32(*(int*)(void*)&x[*idx++]);
vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[*idx++]));
vw = _mm256_loadu_si256((const __m256i *)(void*)w);
vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
w += 32;
vxj = _mm256_set1_epi32(*(int*)(void*)&x[*idx++]);
vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[*idx++]));
vw = _mm256_loadu_si256((const __m256i *)(void*)w);
vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
w += 32;
vxj = _mm256_set1_epi32(*(int*)(void*)&x[*idx++]);
vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[*idx++]));
vw = _mm256_loadu_si256((const __m256i *)(void*)w);
vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
w += 32;
vxj = _mm256_set1_epi32(*(int*)(void*)&x[*idx++]);
vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[*idx++]));
vw = _mm256_loadu_si256((const __m256i *)(void*)w);
vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
w += 32;
@ -808,9 +816,7 @@ static inline void sparse_cgemv8x4(float *_out, const opus_int8 *w, const int *i
{
__m256i vxj;
__m256i vw;
int pos;
pos = (*idx++);
vxj = _mm256_set1_epi32(*(int*)(void*)&x[pos]);
vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[*idx++]));
vw = _mm256_loadu_si256((const __m256i *)(void*)w);
vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
w += 32;
@ -837,19 +843,19 @@ static inline void cgemv8x4(float *_out, const opus_int8 *w, const float *scale,
{
__m256i vxj;
__m256i vw;
vxj = _mm256_set1_epi32(*(int*)(void*)&x[j]);
vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[j]));
vw = _mm256_loadu_si256((const __m256i *)(void*)w);
vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
w += 32;
vxj = _mm256_set1_epi32(*(int*)(void*)&x[j+4]);
vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[j+4]));
vw = _mm256_loadu_si256((const __m256i *)(void*)w);
vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
w += 32;
vxj = _mm256_set1_epi32(*(int*)(void*)&x[j+8]);
vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[j+8]));
vw = _mm256_loadu_si256((const __m256i *)(void*)w);
vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
w += 32;
vxj = _mm256_set1_epi32(*(int*)(void*)&x[j+12]);
vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[j+12]));
vw = _mm256_loadu_si256((const __m256i *)(void*)w);
vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
w += 32;
@ -859,7 +865,7 @@ static inline void cgemv8x4(float *_out, const opus_int8 *w, const float *scale,
{
__m256i vxj;
__m256i vw;
vxj = _mm256_set1_epi32(*(int*)(void*)&x[j]);
vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[j]));
vw = _mm256_loadu_si256((const __m256i *)(void*)w);
vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
w += 32;