diff --git a/dnn/nnet.c b/dnn/nnet.c
index 2887b45e..8825d68a 100644
--- a/dnn/nnet.c
+++ b/dnn/nnet.c
@@ -173,6 +173,11 @@ void compute_gru(const GRULayer *gru, float *state, const float *input)
    /* Compute update gate. */
    for (i=0;i<N;i++)
       z[i] = gru->bias[i];
+   if (gru->reset_after)
+   {
+      for (i=0;i<N;i++)
+         z[i] += gru->bias[3*N + i];
+   }
    gemm_accum(z, gru->input_weights, N, M, stride, input);
    gemm_accum(z, gru->recurrent_weights, N, N, stride, state);
    compute_activation(z, z, N, ACTIVATION_SIGMOID);
@@ -180,6 +185,11 @@ void compute_gru(const GRULayer *gru, float *state, const float *input)
    /* Compute reset gate. */
    for (i=0;i<N;i++)
       r[i] = gru->bias[N + i];
+   if (gru->reset_after)
+   {
+      for (i=0;i<N;i++)
+         r[i] += gru->bias[4*N + i];
+   }
    gemm_accum(r, &gru->input_weights[N], N, M, stride, input);
    gemm_accum(r, &gru->recurrent_weights[N], N, N, stride, state);
    compute_activation(r, r, N, ACTIVATION_SIGMOID);
@@ -189,8 +199,8 @@ void compute_gru(const GRULayer *gru, float *state, const float *input)
       h[i] = gru->bias[2*N + i];
    if (gru->reset_after)
    {
-      /* WARNING: The reset_after version was never tested. */
-      RNN_CLEAR(tmp, N);
+      for (i=0;i<N;i++)
+         tmp[i] = gru->bias[5*N + i];
       gemm_accum(tmp, &gru->recurrent_weights[2*N], N, N, stride, state);
       for (i=0;i<N;i++)
          h[i] += tmp[i] * r[i];
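
Note (not part of the patch): the change above implements the reset_after GRU variant, in which the recurrent bias for the candidate state is added before the elementwise multiplication by the reset gate, rather than after. For reference, here is a minimal standalone sketch of one such GRU step, assuming the same 6*N bias layout as the patch (bias[0..3N) are the input biases for z/r/h, bias[3N..6N) the matching recurrent biases) and the z*state + (1-z)*candidate state update already used by compute_gru. The helper names (gru_reset_after_step, dense_accum) are hypothetical, and the row-major weight layout is a simplification of the strided layout gemm_accum uses in nnet.c.

#include <math.h>

#define SKETCH_MAX_NEURONS 128

/* out[i] += sum_j w[i*cols + j] * x[j]  (row-major, unlike nnet.c's strided layout) */
static void dense_accum(float *out, const float *w, int rows, int cols, const float *x)
{
   int i, j;
   for (i = 0; i < rows; i++)
      for (j = 0; j < cols; j++)
         out[i] += w[i*cols + j]*x[j];
}

/* One reset_after GRU step.  w_in is 3N x M ([Wz; Wr; Wh]), w_rec is 3N x N
   ([Uz; Ur; Uh]), bias has 6N entries as described above.  N must not exceed
   SKETCH_MAX_NEURONS. */
void gru_reset_after_step(float *state, const float *input,
                          const float *w_in, const float *w_rec,
                          const float *bias, int N, int M)
{
   int i;
   float z[SKETCH_MAX_NEURONS], r[SKETCH_MAX_NEURONS];
   float h[SKETCH_MAX_NEURONS], tmp[SKETCH_MAX_NEURONS];

   /* Update gate: z = sigmoid(Wz*x + bz + Uz*h_prev + rbz) */
   for (i = 0; i < N; i++) z[i] = bias[i] + bias[3*N + i];
   dense_accum(z, w_in, N, M, input);
   dense_accum(z, w_rec, N, N, state);
   for (i = 0; i < N; i++) z[i] = 1.f/(1.f + expf(-z[i]));

   /* Reset gate: r = sigmoid(Wr*x + br + Ur*h_prev + rbr) */
   for (i = 0; i < N; i++) r[i] = bias[N + i] + bias[4*N + i];
   dense_accum(r, &w_in[N*M], N, M, input);
   dense_accum(r, &w_rec[N*N], N, N, state);
   for (i = 0; i < N; i++) r[i] = 1.f/(1.f + expf(-r[i]));

   /* Candidate: h = tanh(Wh*x + bh + r .* (Uh*h_prev + rbh)).
      The recurrent bias goes into tmp before the multiply by r, which is the
      reset_after behaviour the patch adds. */
   for (i = 0; i < N; i++) tmp[i] = bias[5*N + i];
   dense_accum(tmp, &w_rec[2*N*N], N, N, state);
   for (i = 0; i < N; i++) h[i] = bias[2*N + i] + r[i]*tmp[i];
   dense_accum(h, &w_in[2*N*M], N, M, input);
   for (i = 0; i < N; i++) h[i] = tanhf(h[i]);

   /* State update, following compute_gru's convention: keep old state with
      weight z, take the candidate with weight (1 - z). */
   for (i = 0; i < N; i++) state[i] = z[i]*state[i] + (1.f - z[i])*h[i];
}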