Faster activation functions for AVX

Using rational function approximation for tanh() and sigmoid.
This commit is contained in:
Jean-Marc Valin 2021-06-29 04:05:48 -04:00
parent 5571ef1b8e
commit e35441f2cc
3 changed files with 220 additions and 13 deletions

View file

@ -80,8 +80,9 @@ void compute_activation(float *output, float *input, int N, int activation)
output[i] = relu(input[i]);
} else if (activation == ACTIVATION_SOFTMAX) {
#ifdef SOFTMAX_HACK
for (i=0;i<N;i++)
output[i] = input[i];
RNN_COPY(output, input, N);
/*for (i=0;i<N;i++)
output[i] = input[i];*/
#else
float sum = 0;
softmax(output, input, N);

70
dnn/training_tf2/pade.py Normal file
View file

@ -0,0 +1,70 @@
# Optimizing a rational function to optimize a tanh() approximation
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, GRU, Dense, Embedding, Reshape, Concatenate, Lambda, Conv1D, Multiply, Add, Bidirectional, MaxPooling1D, Activation
import tensorflow.keras.backend as K
from tensorflow.keras.optimizers import Adam, SGD
def my_loss1(y_true, y_pred):
return 1*K.mean(K.square(y_true-y_pred)) + 1*K.max(K.square(y_true-y_pred), axis=1)
def my_loss2(y_true, y_pred):
return .1*K.mean(K.square(y_true-y_pred)) + 1*K.max(K.square(y_true-y_pred), axis=1)
def my_loss3(y_true, y_pred):
return .01*K.mean(K.square(y_true-y_pred)) + 1*K.max(K.square(y_true-y_pred), axis=1)
# Using these initializers to seed the approximation
# with a reasonable starting point
def num_init(shape, dtype=None):
rr = tf.constant([[945], [105], [1]], dtype=dtype)
#rr = tf.constant([[946.56757], [98.01368], [0.66841]], dtype=dtype)
print(rr)
return rr
def den_init(shape, dtype=None):
rr = tf.constant([[945], [420], [15]], dtype=dtype)
#rr = tf.constant([[946.604], [413.342], [12.465]], dtype=dtype)
print(rr)
return rr
x = np.arange(-10, 10, .01)
N = len(x)
x = np.reshape(x, (1, -1, 1))
x2 = x*x
x2in = np.concatenate([x2*0 + 1, x2, x2*x2], axis=2)
yout = np.tanh(x)
model_x = Input(shape=(None, 1,))
model_x2 = Input(shape=(None, 3,))
num = Dense(1, name='num', use_bias=False, kernel_initializer=num_init)
den = Dense(1, name='den', use_bias=False, kernel_initializer=den_init)
def ratio(x):
return tf.minimum(1., tf.maximum(-1., x[0]*x[1]/x[2]))
out_layer = Lambda(ratio)
output = out_layer([model_x, num(model_x2), den(model_x2)])
model = Model([model_x, model_x2], output)
model.summary()
model.compile(Adam(0.05, beta_1=0.9, beta_2=0.9, decay=2e-5), loss='mean_squared_error')
model.fit([x, x2in], yout, batch_size=1, epochs=500000, validation_split=0.0)
model.compile(Adam(0.001, beta_2=0.9, decay=1e-4), loss=my_loss1)
model.fit([x, x2in], yout, batch_size=1, epochs=50000, validation_split=0.0)
model.compile(Adam(0.0001, beta_2=0.9, decay=1e-4), loss=my_loss2)
model.fit([x, x2in], yout, batch_size=1, epochs=50000, validation_split=0.0)
model.compile(Adam(0.00001, beta_2=0.9, decay=1e-4), loss=my_loss3)
model.fit([x, x2in], yout, batch_size=1, epochs=50000, validation_split=0.0)
model.save_weights('tanh.h5')

View file

@ -39,6 +39,11 @@
#define USE_SU_BIAS
#endif
#ifndef __FMA__
#define _mm256_fmadd_ps(a,b,c) _mm256_add_ps(_mm256_mul_ps(a, b), c)
#define _mm_fmadd_ps(a,b,c) _mm_add_ps(_mm_mul_ps(a, b), c)
#endif
#ifdef __AVX2__
static inline __m256 exp8_approx(__m256 X)
{
@ -61,9 +66,66 @@ static inline __m256 exp8_approx(__m256 X)
Y = _mm256_castsi256_ps(_mm256_add_epi32(I, _mm256_castps_si256(Y)));
return Y;
}
/* Approximating tanh() using a Padé-like rational function:
tanh(x) ~= x * (N0 + N1*x^2 + N2*x^4)/(D0 + D1*x^2 + D2*x^4)
subject to the +/- 1 bounds.
The coefficients were determined by gradient descent trying to minimize
the maximum deviation over the whole range (this is only possible because
of the bounds). The max error is around 3e-4 and is dominated by the
reciprocal approximation (the max error of the rational function is
around 6e-5).
*/
static inline __m256 tanh8_approx(__m256 X)
{
const __m256 N0 = _mm256_set1_ps(952.52801514f);
const __m256 N1 = _mm256_set1_ps(96.39235687f);
const __m256 N2 = _mm256_set1_ps(0.60863042f);
const __m256 D0 = _mm256_set1_ps(952.72399902f);
const __m256 D1 = _mm256_set1_ps(413.36801147f);
const __m256 D2 = _mm256_set1_ps(11.88600922f);
const __m256 max_out = _mm256_set1_ps(1.f);
const __m256 min_out = _mm256_set1_ps(-1.f);
__m256 X2, num, den;
X2 = _mm256_mul_ps(X, X);
num = _mm256_fmadd_ps(_mm256_fmadd_ps(N2, X2, N1), X2, N0);
den = _mm256_fmadd_ps(_mm256_fmadd_ps(D2, X2, D1), X2, D0);
num = _mm256_mul_ps(num, X);
den = _mm256_rcp_ps(den);
num = _mm256_mul_ps(num, den);
return _mm256_max_ps(min_out, _mm256_min_ps(max_out, num));
}
/* Sigmoid approximation using a Padé-like rational function:
1/(1+exp(-x)) ~= 0.5 + x * (N0 + N1*x^2 + N2*x^4)/(D0 + D1*x^2 + D2*x^4)
subject to the [0, 1] bounds.
The coefficients are directly derived by dividing the tanh() coefficients
by powers of two to get the correct scaling. The max error is around 1.5e-4
and is dominated by the reciprocal approximation (the max error of the
rational function is around 3e-5).
*/
static inline __m256 sigmoid8_approx(__m256 X)
{
const __m256 N0 = _mm256_set1_ps(238.13200378f);
const __m256 N1 = _mm256_set1_ps(6.02452230f);
const __m256 N2 = _mm256_set1_ps(0.00950985f);
const __m256 D0 = _mm256_set1_ps(952.72399902f);
const __m256 D1 = _mm256_set1_ps(103.34200287f);
const __m256 D2 = _mm256_set1_ps(0.74287558f);
const __m256 half = _mm256_set1_ps(0.5);
const __m256 max_out = _mm256_set1_ps(1.f);
const __m256 min_out = _mm256_set1_ps(0.f);
__m256 X2, num, den;
X2 = _mm256_mul_ps(X, X);
num = _mm256_fmadd_ps(_mm256_fmadd_ps(N2, X2, N1), X2, N0);
den = _mm256_fmadd_ps(_mm256_fmadd_ps(D2, X2, D1), X2, D0);
num = _mm256_mul_ps(num, X);
den = _mm256_rcp_ps(den);
num = _mm256_fmadd_ps(num, den, half);
return _mm256_max_ps(min_out, _mm256_min_ps(max_out, num));
}
#else
#define _mm256_fmadd_ps(a,b,c) _mm256_add_ps(_mm256_mul_ps(a, b), c)
#define _mm_fmadd_ps(a,b,c) _mm_add_ps(_mm_mul_ps(a, b), c)
static inline __m128 exp4_approx(__m128 X)
{
const __m128 K0 = _mm_set1_ps(0.99992522f);
@ -98,6 +160,48 @@ static inline __m256 exp8_approx(__m256 X)
Y = _mm256_insertf128_ps(Y, Ylo, 0);
return Y;
}
static inline __m128 tanh4_approx(__m128 X)
{
const __m128 N0 = _mm_set1_ps(952.52801514f);
const __m128 N1 = _mm_set1_ps(96.39235687f);
const __m128 N2 = _mm_set1_ps(0.60863042f);
const __m128 D0 = _mm_set1_ps(952.72399902f);
const __m128 D1 = _mm_set1_ps(413.36801147f);
const __m128 D2 = _mm_set1_ps(11.88600922f);
const __m128 max_out = _mm_set1_ps(1.f);
const __m128 min_out = _mm_set1_ps(-1.f);
__m128 X2, num, den;
X2 = _mm_mul_ps(X, X);
num = _mm_fmadd_ps(_mm_fmadd_ps(N2, X2, N1), X2, N0);
den = _mm_fmadd_ps(_mm_fmadd_ps(D2, X2, D1), X2, D0);
num = _mm_mul_ps(num, X);
den = _mm_rcp_ps(den);
num = _mm_mul_ps(num, den);
return _mm_max_ps(min_out, _mm_min_ps(max_out, num));
}
static inline __m128 sigmoid4_approx(__m128 X)
{
const __m128 N0 = _mm_set1_ps(238.13200378f);
const __m128 N1 = _mm_set1_ps(6.02452230f);
const __m128 N2 = _mm_set1_ps(0.00950985f);
const __m128 D0 = _mm_set1_ps(952.72399902f);
const __m128 D1 = _mm_set1_ps(103.34200287f);
const __m128 D2 = _mm_set1_ps(0.74287558f);
const __m128 half = _mm_set1_ps(0.5);
const __m128 max_out = _mm_set1_ps(1.f);
const __m128 min_out = _mm_set1_ps(0.f);
__m128 X2, num, den;
X2 = _mm_mul_ps(X, X);
num = _mm_fmadd_ps(_mm_fmadd_ps(N2, X2, N1), X2, N0);
den = _mm_fmadd_ps(_mm_fmadd_ps(D2, X2, D1), X2, D0);
num = _mm_mul_ps(num, X);
den = _mm_rcp_ps(den);
num = _mm_fmadd_ps(num, den, half);
return _mm_max_ps(min_out, _mm_min_ps(max_out, num));
}
#endif
static inline float celt_exp(float x)
@ -124,18 +228,15 @@ static inline void softmax(float *y, const float *x, int N)
y[i] = celt_exp(x[i]);
}
#ifdef __AVX2__
static inline void vec_tanh(float *y, const float *x, int N)
{
int i;
for (i=0;i<N-7;i+=8)
{
const __m256 two = _mm256_set1_ps(2.f);
const __m256 one = _mm256_set1_ps(1.f);
__m256 X, Y;
X = _mm256_loadu_ps(&x[i]);
X = _mm256_mul_ps(X, two);
Y = exp8_approx(X);
Y = _mm256_mul_ps(_mm256_sub_ps(Y, one), _mm256_rcp_ps(_mm256_add_ps(Y, one)));
Y = tanh8_approx(X);
_mm256_storeu_ps(&y[i], Y);
}
for (;i<N;i++)
@ -151,12 +252,9 @@ static inline void vec_sigmoid(float *y, const float *x, int N)
int i;
for (i=0;i<N-7;i+=8)
{
const __m256 one = _mm256_set1_ps(1.f);
__m256 X, Y;
X = _mm256_loadu_ps(&x[i]);
Y = exp8_approx(X);
/* Compute as 1-1/(1+e^x) to avoid >1 values caused by the reciprocal approximation. */
Y = _mm256_sub_ps(one, _mm256_rcp_ps(_mm256_add_ps(Y, one)));
Y = sigmoid8_approx(X);
_mm256_storeu_ps(&y[i], Y);
}
for (;i<N;i++)
@ -166,6 +264,44 @@ static inline void vec_sigmoid(float *y, const float *x, int N)
y[i] = (ex)/(ex+1);
}
}
#else
static inline void vec_tanh(float *y, const float *x, int N)
{
int i;
for (i=0;i<N-3;i+=4)
{
__m128 X, Y;
X = _mm_loadu_ps(&x[i]);
Y = tanh4_approx(X);
_mm_storeu_ps(&y[i], Y);
}
for (;i<N;i++)
{
float ex2;
ex2 = celt_exp(2*x[i]);
y[i] = (ex2-1)/(ex2+1);
}
}
static inline void vec_sigmoid(float *y, const float *x, int N)
{
int i;
for (i=0;i<N-3;i+=4)
{
__m128 X, Y;
X = _mm_loadu_ps(&x[i]);
Y = sigmoid4_approx(X);
_mm_storeu_ps(&y[i], Y);
}
for (;i<N;i++)
{
float ex;
ex = celt_exp(x[i]);
y[i] = (ex)/(ex+1);
}
}
#endif
static inline void sgemv_accum16(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
{