From c7b978b923f4d243d8f67c9d865ba23c18c89ae9 Mon Sep 17 00:00:00 2001
From: Jean-Marc Valin
Date: Tue, 27 Nov 2018 14:37:10 -0500
Subject: [PATCH] Fix reset_after GRU

---
 dnn/nnet.c | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/dnn/nnet.c b/dnn/nnet.c
index 2887b45e..8825d68a 100644
--- a/dnn/nnet.c
+++ b/dnn/nnet.c
@@ -173,6 +173,11 @@ void compute_gru(const GRULayer *gru, float *state, const float *input)
    /* Compute update gate. */
    for (i=0;i<N;i++)
       z[i] = gru->bias[i];
+   if (gru->reset_after)
+   {
+      for (i=0;i<N;i++)
+         z[i] += gru->bias[3*N + i];
+   }
    gemm_accum(z, gru->input_weights, N, M, stride, input);
    gemm_accum(z, gru->recurrent_weights, N, N, stride, state);
    compute_activation(z, z, N, ACTIVATION_SIGMOID);
@@ -180,6 +185,11 @@ void compute_gru(const GRULayer *gru, float *state, const float *input)
    /* Compute reset gate. */
    for (i=0;i<N;i++)
       r[i] = gru->bias[N + i];
+   if (gru->reset_after)
+   {
+      for (i=0;i<N;i++)
+         r[i] += gru->bias[4*N + i];
+   }
    gemm_accum(r, &gru->input_weights[N], N, M, stride, input);
    gemm_accum(r, &gru->recurrent_weights[N], N, N, stride, state);
    compute_activation(r, r, N, ACTIVATION_SIGMOID);
@@ -189,8 +199,8 @@ void compute_gru(const GRULayer *gru, float *state, const float *input)
       h[i] = gru->bias[2*N + i];
    if (gru->reset_after)
    {
-      /* WARNING: The reset_after version was never tested. */
-      RNN_CLEAR(tmp, N);
+      for (i=0;i<N;i++)
+         tmp[i] = gru->bias[5*N + i];
       gemm_accum(tmp, &gru->recurrent_weights[2*N], N, N, stride, state);
       for (i=0;i<N;i++)
          h[i] += tmp[i] * r[i];
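
Note (not part of the commit): reset_after is the convention used by the
CuDNN-compatible GRU (e.g. Keras with reset_after=True), which keeps
separate input and recurrent biases. The patch therefore treats gru->bias
as holding 6*N values laid out as [b_z | b_r | b_h | b'_z | b'_r | b'_h],
and the recurrent candidate bias b'_h must be added to U_h*state before
the reset gate multiplies, which is why tmp is now initialized from
bias[5*N + i] instead of being cleared. Below is a minimal standalone
sketch of that candidate-state step; the function and matrix names are
illustrative (not from nnet.c), and it uses a plain row-major layout
where the real code goes through gemm_accum with stride 3*N:

#include <math.h>

/* Candidate state h~ = tanh(W_h*x + b_h + r .* (U_h*state + b'_h)) for a
   reset_after GRU. Wh is N x M, Uh is N x N, bias holds 6*N values in the
   layout described above; names here are illustrative. */
static void gru_candidate_reset_after(int N, int M, const float *Wh,
      const float *Uh, const float *bias, const float *x,
      const float *state, const float *r, float *h)
{
   int i, j;
   for (i=0;i<N;i++)
   {
      float in = bias[2*N + i];    /* input bias b_h */
      float rec = bias[5*N + i];   /* recurrent bias b'_h */
      for (j=0;j<M;j++)
         in += Wh[i*M + j]*x[j];
      for (j=0;j<N;j++)
         rec += Uh[i*N + j]*state[j];
      /* reset_after: r gates the whole recurrent term, bias included */
      h[i] = tanhf(in + r[i]*rec);
   }
}

In the reset-before convention, by contrast, the reset gate is applied to
the state itself, h~ = tanh(W_h*x + b_h + U_h*(r .* state)), and there is
only a single 3*N bias vector, so clearing tmp would have been correct
there but not here.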