Simpler GRU implementation just for reset_after.

Jean-Marc Valin 2018-11-28 12:37:18 -05:00
parent 6c2f7e58fd
commit 040aa437c3
3 changed files with 42 additions and 2 deletions


@@ -218,6 +218,44 @@ void compute_gru(const GRULayer *gru, float *state, const float *input)
      state[i] = h[i];
}

void compute_gru2(const GRULayer *gru, float *state, const float *input)
{
   int i;
   int N, M;
   int stride;
   float zrh[3*MAX_RNN_NEURONS];
   float recur[3*MAX_RNN_NEURONS];
   float *z;
   float *r;
   float *h;
   M = gru->nb_inputs;
   N = gru->nb_neurons;
   z = zrh;
   r = &zrh[N];
   h = &zrh[2*N];
   celt_assert(gru->nb_neurons <= MAX_RNN_NEURONS);
   celt_assert(input != state);
   celt_assert(gru->reset_after);
   stride = 3*N;
   /* Compute the update (z), reset (r) and candidate (h) pre-activations:
      input contribution in zrh, recurrent contribution in recur. */
   for (i=0;i<3*N;i++)
      zrh[i] = gru->bias[i];
   gemm_accum(zrh, gru->input_weights, 3*N, M, stride, input);
   for (i=0;i<3*N;i++)
      recur[i] = gru->bias[3*N + i];
   gemm_accum(recur, gru->recurrent_weights, 3*N, N, stride, state);
   for (i=0;i<2*N;i++)
      zrh[i] += recur[i];
   compute_activation(zrh, zrh, 2*N, ACTIVATION_SIGMOID);
   /* reset_after: the reset gate scales the recurrent part of the
      candidate *after* the recurrent matrix product. */
   for (i=0;i<N;i++)
      h[i] += recur[2*N+i]*r[i];
   compute_activation(h, h, N, gru->activation);
   /* Interpolate between the previous state and the candidate. */
   for (i=0;i<N;i++)
      h[i] = z[i]*state[i] + (1-z[i])*h[i];
   for (i=0;i<N;i++)
      state[i] = h[i];
}

void compute_conv1d(const Conv1DLayer *layer, float *output, float *mem, const float *input)
{
   int i;
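
The new code path follows the standard reset_after GRU equations: z = sigmoid(Wz x + Uz h + bz + rbz), r = sigmoid(Wr x + Ur h + br + rbr), hcand = act(Wh x + bh + r * (Uh h + rbh)), h' = z * h + (1 - z) * hcand. Below is a minimal, self-contained sketch of one such step using plain loops instead of the layer struct and gemm_accum; the names (gru_reset_after_step, W, U, b, rb) and the row-major weight layout are illustrative assumptions and do not match gemm_accum's internal layout, and the candidate activation is fixed to tanh here, whereas compute_gru2 uses gru->activation.

#include <math.h>

/* Illustrative only: one reset_after GRU step for M inputs and N neurons.
   W is the 3*N x M input weight matrix, U the 3*N x N recurrent weight
   matrix, b the input bias and rb the recurrent bias, all laid out as
   [z | r | h] blocks of N rows each, matching zrh above. */
static void gru_reset_after_step(int M, int N,
                                 const float *W, const float *U,
                                 const float *b, const float *rb,
                                 const float *input, float *state)
{
   int i, j;
   float zrh[3*64], recur[3*64];   /* sketch assumes N <= 64 */
   /* Input contribution: zrh = b + W*input. */
   for (i=0;i<3*N;i++) {
      zrh[i] = b[i];
      for (j=0;j<M;j++) zrh[i] += W[i*M + j]*input[j];
   }
   /* Recurrent contribution for all three gates at once: recur = rb + U*state. */
   for (i=0;i<3*N;i++) {
      recur[i] = rb[i];
      for (j=0;j<N;j++) recur[i] += U[i*N + j]*state[j];
   }
   /* Update and reset gates see the sum of both contributions. */
   for (i=0;i<2*N;i++)
      zrh[i] = 1.f/(1.f + expf(-(zrh[i] + recur[i])));
   /* reset_after: r scales the recurrent part of the candidate, then the
      state is interpolated between its old value and the candidate. */
   for (i=0;i<N;i++) {
      float hcand = tanhf(zrh[2*N+i] + zrh[N+i]*recur[2*N+i]);
      state[i] = zrh[i]*state[i] + (1.f - zrh[i])*hcand;
   }
}

This is what makes the reset_after variant simpler and faster to implement: in the conventional GRU, the candidate's recurrent term is Uh (r * h), so r must be computed before the candidate's matrix product, forcing a separate matrix-vector product. Applying r after the product lets compute_gru2 fold all three recurrent products into a single gemm_accum call.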