mirror of
https://github.com/xiph/opus.git
synced 2025-05-31 07:37:42 +00:00
Fix reset_after GRU
This commit is contained in:
parent
3c694db226
commit
c7b978b923
1 changed files with 12 additions and 2 deletions
14
dnn/nnet.c
14
dnn/nnet.c
|
@ -173,6 +173,11 @@ void compute_gru(const GRULayer *gru, float *state, const float *input)
|
|||
/* Compute update gate. */
|
||||
for (i=0;i<N;i++)
|
||||
z[i] = gru->bias[i];
|
||||
if (gru->reset_after)
|
||||
{
|
||||
for (i=0;i<N;i++)
|
||||
z[i] += gru->bias[3*N + i];
|
||||
}
|
||||
gemm_accum(z, gru->input_weights, N, M, stride, input);
|
||||
gemm_accum(z, gru->recurrent_weights, N, N, stride, state);
|
||||
compute_activation(z, z, N, ACTIVATION_SIGMOID);
|
||||
|
@ -180,6 +185,11 @@ void compute_gru(const GRULayer *gru, float *state, const float *input)
|
|||
/* Compute reset gate. */
|
||||
for (i=0;i<N;i++)
|
||||
r[i] = gru->bias[N + i];
|
||||
if (gru->reset_after)
|
||||
{
|
||||
for (i=0;i<N;i++)
|
||||
r[i] += gru->bias[4*N + i];
|
||||
}
|
||||
gemm_accum(r, &gru->input_weights[N], N, M, stride, input);
|
||||
gemm_accum(r, &gru->recurrent_weights[N], N, N, stride, state);
|
||||
compute_activation(r, r, N, ACTIVATION_SIGMOID);
|
||||
|
@ -189,8 +199,8 @@ void compute_gru(const GRULayer *gru, float *state, const float *input)
|
|||
h[i] = gru->bias[2*N + i];
|
||||
if (gru->reset_after)
|
||||
{
|
||||
/* WARNING: The reset_after version was never tested. */
|
||||
RNN_CLEAR(tmp, N);
|
||||
for (i=0;i<N;i++)
|
||||
tmp[i] = gru->bias[5*N + i];
|
||||
gemm_accum(tmp, &gru->recurrent_weights[2*N], N, N, stride, state);
|
||||
for (i=0;i<N;i++)
|
||||
h[i] += tmp[i] * r[i];
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue