diff --git a/dnn/lpcnet.c b/dnn/lpcnet.c
index acadbdc2..30451284 100644
--- a/dnn/lpcnet.c
+++ b/dnn/lpcnet.c
@@ -121,10 +121,10 @@ void run_sample_network(NNetState *net, float *pdf, const float *condition, int
    compute_embedding(&embed_sig, &in_a[EMBED_SIG_OUT_SIZE], pred);
    compute_embedding(&embed_exc, &in_a[2*EMBED_SIG_OUT_SIZE], last_exc);
    RNN_COPY(&in_a[2*EMBED_SIG_OUT_SIZE + EMBED_EXC_OUT_SIZE], condition, FEATURE_DENSE2_OUT_SIZE);
-   compute_gru(&gru_a, net->gru_a_state, in_a);
+   compute_gru2(&gru_a, net->gru_a_state, in_a);
    RNN_COPY(in_b, net->gru_a_state, GRU_A_STATE_SIZE);
    RNN_COPY(&in_b[GRU_A_STATE_SIZE], condition, FEATURE_DENSE2_OUT_SIZE);
-   compute_gru(&gru_b, net->gru_b_state, in_b);
+   compute_gru2(&gru_b, net->gru_b_state, in_b);
    compute_mdense(&dual_fc, pdf, net->gru_b_state);
 }
 
diff --git a/dnn/nnet.c b/dnn/nnet.c
index 4f5491b3..a8fa704f 100644
--- a/dnn/nnet.c
+++ b/dnn/nnet.c
@@ -218,3 +218,45 @@ void compute_gru(const GRULayer *gru, float *state, const float *input)
       state[i] = h[i];
 }
 
+/* Variant of compute_gru for "reset-after" GRUs (CuDNN/Keras layout):
+   biases and weights are stored as consecutive [z | r | h] blocks of
+   size N, i.e. with a row stride of 3*N. */
+void compute_gru2(const GRULayer *gru, float *state, const float *input)
+{
+   int i;
+   int N, M;
+   int stride;
+   float zrh[3*MAX_RNN_NEURONS];
+   float recur[3*MAX_RNN_NEURONS];
+   float *z;
+   float *r;
+   float *h;
+   M = gru->nb_inputs;
+   N = gru->nb_neurons;
+   z = zrh;
+   r = &zrh[N];
+   h = &zrh[2*N];
+   celt_assert(gru->nb_neurons <= MAX_RNN_NEURONS);
+   celt_assert(input != state);
+   celt_assert(gru->reset_after);
+   stride = 3*N;
+   /* Compute update gate. */
+   for (i=0;i<3*N;i++)
+      zrh[i] = gru->bias[i];
+   gemm_accum(zrh, gru->input_weights, 3*N, M, stride, input);
+   for (i=0;i<3*N;i++)
+      recur[i] = gru->bias[3*N + i];
+   gemm_accum(recur, gru->recurrent_weights, 3*N, N, stride, state);
+   for (i=0;i<2*N;i++)
+      zrh[i] += recur[i];
+   compute_activation(zrh, zrh, 2*N, ACTIVATION_SIGMOID);
+   /* Reset gate scales the recurrent contribution to the candidate. */
+   for (i=0;i<N;i++)
+      h[i] += recur[2*N+i]*r[i];
+   compute_activation(h, h, N, gru->activation);
+   /* New state: h' = z*old + (1-z)*candidate. */
+   for (i=0;i<N;i++)
+      h[i] = z[i]*state[i] + (1-z[i])*h[i];
+   for (i=0;i<N;i++)
+      state[i] = h[i];
+}