Fixes vnni macro redefinition with clang

This commit is contained in:
Michael Klingbeil 2023-09-01 23:18:21 -04:00 committed by Jean-Marc Valin
parent 4a47b1a15b
commit d431c321f1
No known key found for this signature in database
GPG key ID: 531A52533318F00A

View file

@ -621,27 +621,27 @@ static inline void vec_sigmoid(float *y, const float *x, int N)
#if defined(__AVXVNNI__) || defined(__AVX512VNNI__) #if defined(__AVXVNNI__) || defined(__AVX512VNNI__)
#define opus_mm256_dpbusds_epi32(src, a, b) _mm256_dpbusds_epi32(src, a, b)
#elif defined(__AVX2__) #elif defined(__AVX2__)
static inline __m256i mm256_dpbusds_epi32(__m256i src, __m256i a, __m256i b) { static inline __m256i opus_mm256_dpbusds_epi32(__m256i src, __m256i a, __m256i b) {
__m256i ones, tmp; __m256i ones, tmp;
ones = _mm256_set1_epi16(1); ones = _mm256_set1_epi16(1);
tmp = _mm256_maddubs_epi16(a, b); tmp = _mm256_maddubs_epi16(a, b);
tmp = _mm256_madd_epi16(tmp, ones); tmp = _mm256_madd_epi16(tmp, ones);
return _mm256_add_epi32(src, tmp); return _mm256_add_epi32(src, tmp);
} }
#define _mm256_dpbusds_epi32(src, a, b) mm256_dpbusds_epi32(src, a, b)
#elif defined(__SSSE3__) #elif defined(__SSSE3__)
static inline mm256i_emu mm256_dpbusds_epi32(mm256i_emu src, mm256i_emu a, mm256i_emu b) { static inline mm256i_emu opus_mm256_dpbusds_epi32(mm256i_emu src, mm256i_emu a, mm256i_emu b) {
mm256i_emu ones, tmp; mm256i_emu ones, tmp;
ones = _mm256_set1_epi16(1); ones = _mm256_set1_epi16(1);
tmp = _mm256_maddubs_epi16(a, b); tmp = _mm256_maddubs_epi16(a, b);
tmp = _mm256_madd_epi16(tmp, ones); tmp = _mm256_madd_epi16(tmp, ones);
return _mm256_add_epi32(src, tmp); return _mm256_add_epi32(src, tmp);
} }
#define _mm256_dpbusds_epi32(src, a, b) mm256_dpbusds_epi32(src, a, b)
#elif defined(__SSE2__) #elif defined(__SSE2__)
@ -655,13 +655,12 @@ static inline __m128i mm_dpbusds_epi32(__m128i src, __m128i a, __m128i b) {
return _mm_add_epi32(src, tmp); return _mm_add_epi32(src, tmp);
} }
static inline mm256i_emu mm256_dpbusds_epi32(mm256i_emu src, mm256i_emu a, mm256i_emu b) { static inline mm256i_emu opus_mm256_dpbusds_epi32(mm256i_emu src, mm256i_emu a, mm256i_emu b) {
mm256i_emu res; mm256i_emu res;
res.hi = mm_dpbusds_epi32(src.hi, a.hi, b.hi); res.hi = mm_dpbusds_epi32(src.hi, a.hi, b.hi);
res.lo = mm_dpbusds_epi32(src.lo, a.lo, b.lo); res.lo = mm_dpbusds_epi32(src.lo, a.lo, b.lo);
return res; return res;
} }
#define _mm256_dpbusds_epi32(src, a, b) mm256_dpbusds_epi32(src, a, b)
#if defined(_MSC_VER) #if defined(_MSC_VER)
#pragma message ("Only SSE and SSE2 are available. On newer machines, enable SSSE3/AVX/AVX2 to get better performance") #pragma message ("Only SSE and SSE2 are available. On newer machines, enable SSSE3/AVX/AVX2 to get better performance")
@ -797,19 +796,19 @@ static inline void sparse_cgemv8x4(float *_out, const opus_int8 *w, const int *i
__m256i vw; __m256i vw;
vxj = _mm256_set1_epi32(*(int*)&x[*idx++]); vxj = _mm256_set1_epi32(*(int*)&x[*idx++]);
vw = _mm256_loadu_si256((const __m256i *)w); vw = _mm256_loadu_si256((const __m256i *)w);
vy0 = _mm256_dpbusds_epi32(vy0, vxj, vw); vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
w += 32; w += 32;
vxj = _mm256_set1_epi32(*(int*)&x[*idx++]); vxj = _mm256_set1_epi32(*(int*)&x[*idx++]);
vw = _mm256_loadu_si256((const __m256i *)w); vw = _mm256_loadu_si256((const __m256i *)w);
vy0 = _mm256_dpbusds_epi32(vy0, vxj, vw); vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
w += 32; w += 32;
vxj = _mm256_set1_epi32(*(int*)&x[*idx++]); vxj = _mm256_set1_epi32(*(int*)&x[*idx++]);
vw = _mm256_loadu_si256((const __m256i *)w); vw = _mm256_loadu_si256((const __m256i *)w);
vy0 = _mm256_dpbusds_epi32(vy0, vxj, vw); vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
w += 32; w += 32;
vxj = _mm256_set1_epi32(*(int*)&x[*idx++]); vxj = _mm256_set1_epi32(*(int*)&x[*idx++]);
vw = _mm256_loadu_si256((const __m256i *)w); vw = _mm256_loadu_si256((const __m256i *)w);
vy0 = _mm256_dpbusds_epi32(vy0, vxj, vw); vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
w += 32; w += 32;
} }
#endif #endif
@ -821,7 +820,7 @@ static inline void sparse_cgemv8x4(float *_out, const opus_int8 *w, const int *i
pos = (*idx++); pos = (*idx++);
vxj = _mm256_set1_epi32(*(int*)&x[pos]); vxj = _mm256_set1_epi32(*(int*)&x[pos]);
vw = _mm256_loadu_si256((const __m256i *)w); vw = _mm256_loadu_si256((const __m256i *)w);
vy0 = _mm256_dpbusds_epi32(vy0, vxj, vw); vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
w += 32; w += 32;
} }
vout = _mm256_cvtepi32_ps(vy0); vout = _mm256_cvtepi32_ps(vy0);
@ -848,19 +847,19 @@ static inline void cgemv8x4(float *_out, const opus_int8 *w, const float *scale,
__m256i vw; __m256i vw;
vxj = _mm256_set1_epi32(*(int*)&x[j]); vxj = _mm256_set1_epi32(*(int*)&x[j]);
vw = _mm256_loadu_si256((const __m256i *)w); vw = _mm256_loadu_si256((const __m256i *)w);
vy0 = _mm256_dpbusds_epi32(vy0, vxj, vw); vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
w += 32; w += 32;
vxj = _mm256_set1_epi32(*(int*)&x[j+4]); vxj = _mm256_set1_epi32(*(int*)&x[j+4]);
vw = _mm256_loadu_si256((const __m256i *)w); vw = _mm256_loadu_si256((const __m256i *)w);
vy0 = _mm256_dpbusds_epi32(vy0, vxj, vw); vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
w += 32; w += 32;
vxj = _mm256_set1_epi32(*(int*)&x[j+8]); vxj = _mm256_set1_epi32(*(int*)&x[j+8]);
vw = _mm256_loadu_si256((const __m256i *)w); vw = _mm256_loadu_si256((const __m256i *)w);
vy0 = _mm256_dpbusds_epi32(vy0, vxj, vw); vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
w += 32; w += 32;
vxj = _mm256_set1_epi32(*(int*)&x[j+12]); vxj = _mm256_set1_epi32(*(int*)&x[j+12]);
vw = _mm256_loadu_si256((const __m256i *)w); vw = _mm256_loadu_si256((const __m256i *)w);
vy0 = _mm256_dpbusds_epi32(vy0, vxj, vw); vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
w += 32; w += 32;
} }
#endif #endif
@ -870,7 +869,7 @@ static inline void cgemv8x4(float *_out, const opus_int8 *w, const float *scale,
__m256i vw; __m256i vw;
vxj = _mm256_set1_epi32(*(int*)&x[j]); vxj = _mm256_set1_epi32(*(int*)&x[j]);
vw = _mm256_loadu_si256((const __m256i *)w); vw = _mm256_loadu_si256((const __m256i *)w);
vy0 = _mm256_dpbusds_epi32(vy0, vxj, vw); vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
w += 32; w += 32;
} }
vout = _mm256_cvtepi32_ps(vy0); vout = _mm256_cvtepi32_ps(vy0);