Support for plain AVX with no FMA
This commit is contained in:
parent
91d90676e1
commit
771cc7868a
1 changed files with 42 additions and 2 deletions
44
dnn/nnet.c
44
dnn/nnet.c
|
@ -41,8 +41,11 @@
|
||||||
|
|
||||||
#define SOFTMAX_HACK
|
#define SOFTMAX_HACK
|
||||||
|
|
||||||
#ifdef __AVX2__
|
#ifdef __AVX__
|
||||||
#include <immintrin.h>
|
#include <immintrin.h>
|
||||||
|
|
||||||
|
|
||||||
|
#ifdef __AVX2__
|
||||||
static __m256 exp8_approx(__m256 X)
|
static __m256 exp8_approx(__m256 X)
|
||||||
{
|
{
|
||||||
const __m256 K0 = _mm256_set1_ps(0.99992522f);
|
const __m256 K0 = _mm256_set1_ps(0.99992522f);
|
||||||
|
@ -65,7 +68,44 @@ static __m256 exp8_approx(__m256 X)
|
||||||
Y = _mm256_castsi256_ps(_mm256_and_si256(mask, _mm256_add_epi32(I, _mm256_castps_si256(Y))));
|
Y = _mm256_castsi256_ps(_mm256_and_si256(mask, _mm256_add_epi32(I, _mm256_castps_si256(Y))));
|
||||||
return Y;
|
return Y;
|
||||||
}
|
}
|
||||||
|
#else
|
||||||
|
#define _mm256_fmadd_ps(a,b,c) _mm256_add_ps(_mm256_mul_ps(a, b), c)
|
||||||
|
#define _mm_fmadd_ps(a,b,c) _mm_add_ps(_mm_mul_ps(a, b), c)
|
||||||
|
static __m128 exp4_approx(__m128 X)
|
||||||
|
{
|
||||||
|
const __m128 K0 = _mm_set1_ps(0.99992522f);
|
||||||
|
const __m128 K1 = _mm_set1_ps(0.69583354f);
|
||||||
|
const __m128 K2 = _mm_set1_ps(0.22606716f);
|
||||||
|
const __m128 K3 = _mm_set1_ps(0.078024523f);
|
||||||
|
const __m128 log2_E = _mm_set1_ps(1.44269504);
|
||||||
|
const __m128 max_in = _mm_set1_ps(50.f);
|
||||||
|
const __m128 min_in = _mm_set1_ps(-50.f);
|
||||||
|
const __m128i mask = _mm_set1_epi32(0x7fffffff);
|
||||||
|
__m128 XF, Y;
|
||||||
|
__m128i I;
|
||||||
|
X = _mm_mul_ps(X, log2_E);
|
||||||
|
X = _mm_max_ps(min_in, _mm_min_ps(max_in, X));
|
||||||
|
XF = _mm_floor_ps(X);
|
||||||
|
I = _mm_cvtps_epi32(XF);
|
||||||
|
X = _mm_sub_ps(X, XF);
|
||||||
|
Y = _mm_fmadd_ps(_mm_fmadd_ps(_mm_fmadd_ps(K3, X, K2), X, K1), X, K0);
|
||||||
|
I = _mm_slli_epi32(I, 23);
|
||||||
|
Y = _mm_castsi128_ps(_mm_and_si128(mask, _mm_add_epi32(I, _mm_castps_si128(Y))));
|
||||||
|
return Y;
|
||||||
|
}
|
||||||
|
static __m256 exp8_approx(__m256 X)
|
||||||
|
{
|
||||||
|
__m256 Y;
|
||||||
|
__m128 Xhi, Xlo, Yhi, Ylo;
|
||||||
|
Xhi = _mm256_extractf128_ps(X, 1);
|
||||||
|
Xlo = _mm256_extractf128_ps(X, 0);
|
||||||
|
Yhi = exp4_approx(Xhi);
|
||||||
|
Ylo = exp4_approx(Xlo);
|
||||||
|
Y = _mm256_insertf128_ps(_mm256_setzero_ps(), Yhi, 1);
|
||||||
|
Y = _mm256_insertf128_ps(Y, Ylo, 0);
|
||||||
|
return Y;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
static float celt_exp(float x)
|
static float celt_exp(float x)
|
||||||
{
|
{
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue