Fix reset_after GRU

Jean-Marc Valin 2018-11-27 14:37:10 -05:00
parent 3c694db226
commit c7b978b923

@@ -173,6 +173,11 @@ void compute_gru(const GRULayer *gru, float *state, const float *input)
   /* Compute update gate. */
   for (i=0;i<N;i++)
      z[i] = gru->bias[i];
   if (gru->reset_after)
   {
      for (i=0;i<N;i++)
         z[i] += gru->bias[3*N + i];
   }
   gemm_accum(z, gru->input_weights, N, M, stride, input);
   gemm_accum(z, gru->recurrent_weights, N, N, stride, state);
   compute_activation(z, z, N, ACTIVATION_SIGMOID);
@@ -180,6 +185,11 @@ void compute_gru(const GRULayer *gru, float *state, const float *input)
   /* Compute reset gate. */
   for (i=0;i<N;i++)
      r[i] = gru->bias[N + i];
   if (gru->reset_after)
   {
      for (i=0;i<N;i++)
         r[i] += gru->bias[4*N + i];
   }
   gemm_accum(r, &gru->input_weights[N], N, M, stride, input);
   gemm_accum(r, &gru->recurrent_weights[N], N, N, stride, state);
   compute_activation(r, r, N, ACTIVATION_SIGMOID);
@@ -189,8 +199,8 @@ void compute_gru(const GRULayer *gru, float *state, const float *input)
      h[i] = gru->bias[2*N + i];
   if (gru->reset_after)
   {
      /* WARNING: The reset_after version was never tested. */
      RNN_CLEAR(tmp, N);
      for (i=0;i<N;i++)
         tmp[i] = gru->bias[5*N + i];
      gemm_accum(tmp, &gru->recurrent_weights[2*N], N, N, stride, state);
      for (i=0;i<N;i++)
         h[i] += tmp[i] * r[i];
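
For reference, below is a minimal, self-contained sketch of one GRU step with reset_after semantics (as in Keras reset_after=True), consistent with the bias offsets used in the diff: the bias array holds 6*N values, input biases at offsets 0, N, 2*N and recurrent biases at 3*N, 4*N, 5*N, and the reset gate multiplies the recurrent term after the matrix product rather than gating the state itself. The dense row-major weight layout and the names matvec_accum, sigmoid_f, gru_step_reset_after, MAX_GRU_N are illustrative assumptions for this sketch, not the GRULayer/gemm_accum layout used by the real code.

#include <math.h>

#define MAX_GRU_N 128  /* illustrative upper bound on the GRU width (assumes N <= MAX_GRU_N) */

static float sigmoid_f(float x) { return 1.f / (1.f + expf(-x)); }

/* y += W*x for a row-major rows x cols matrix (plain dense, no stride). */
static void matvec_accum(float *y, const float *W, int rows, int cols, const float *x)
{
   int i, j;
   for (i = 0; i < rows; i++)
      for (j = 0; j < cols; j++)
         y[i] += W[i*cols + j] * x[j];
}

/* One GRU step with reset_after semantics. state has length N, input length M.
 * Wi packs the three input kernels [W_z; W_r; W_h], each N x M;
 * Wr packs the three recurrent kernels [U_z; U_r; U_h], each N x N;
 * bias holds 6*N values: input biases at 0, N, 2*N and recurrent biases
 * at 3*N, 4*N, 5*N, matching the offsets in the diff above. */
void gru_step_reset_after(float *state, const float *input,
                          const float *Wi, const float *Wr,
                          const float *bias, int N, int M)
{
   int i;
   float z[MAX_GRU_N], r[MAX_GRU_N], h[MAX_GRU_N], tmp[MAX_GRU_N];

   /* Update gate: z = sigmoid(W_z*x + b_z + U_z*state + b_rz). */
   for (i = 0; i < N; i++) z[i] = bias[i] + bias[3*N + i];
   matvec_accum(z, Wi, N, M, input);
   matvec_accum(z, Wr, N, N, state);
   for (i = 0; i < N; i++) z[i] = sigmoid_f(z[i]);

   /* Reset gate: r = sigmoid(W_r*x + b_r + U_r*state + b_rr). */
   for (i = 0; i < N; i++) r[i] = bias[N + i] + bias[4*N + i];
   matvec_accum(r, &Wi[N*M], N, M, input);
   matvec_accum(r, &Wr[N*N], N, N, state);
   for (i = 0; i < N; i++) r[i] = sigmoid_f(r[i]);

   /* Candidate: h = tanh(W_h*x + b_h + r .* (U_h*state + b_rh)).
    * With reset_after, the recurrent bias b_rh (bias[5*N + i]) is added to the
    * recurrent product, and only then is the result gated element-wise by r. */
   for (i = 0; i < N; i++) h[i] = bias[2*N + i];
   matvec_accum(h, &Wi[2*N*M], N, M, input);
   for (i = 0; i < N; i++) tmp[i] = bias[5*N + i];
   matvec_accum(tmp, &Wr[2*N*N], N, N, state);
   for (i = 0; i < N; i++) h[i] += r[i] * tmp[i];
   for (i = 0; i < N; i++) h[i] = tanhf(h[i]);

   /* Blend: state = z .* state + (1 - z) .* h. */
   for (i = 0; i < N; i++) state[i] = z[i]*state[i] + (1.f - z[i])*h[i];
}

The distinguishing feature of reset_after is visible in the candidate-state step: the reset gate is applied to the already-computed recurrent product (plus its recurrent bias), which is why that branch needs the separate bias region at 5*N rather than reusing the input bias.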