diff --git a/dnn/lpcnet.c b/dnn/lpcnet.c index 86429e5d..2d5d635d 100644 --- a/dnn/lpcnet.c +++ b/dnn/lpcnet.c @@ -119,6 +119,30 @@ void run_frame_network(LPCNetState *lpcnet, float *gru_a_condition, float *gru_b if (lpcnet->frame_count < 1000) lpcnet->frame_count++; } +void run_frame_network_deferred(LPCNetState *lpcnet, const float *features) +{ + int max_buffer_size = lpcnet->model.feature_conv1.kernel_size + lpcnet->model.feature_conv2.kernel_size - 2; + celt_assert(max_buffer_size <= MAX_FEATURE_BUFFER_SIZE); + if (lpcnet->feature_buffer_fill == max_buffer_size) { + RNN_MOVE(lpcnet->feature_buffer, &lpcnet->feature_buffer[NB_FEATURES], (max_buffer_size-1)*NB_FEATURES); + } else { + lpcnet->feature_buffer_fill++; + } + RNN_COPY(&lpcnet->feature_buffer[(lpcnet->feature_buffer_fill-1)*NB_FEATURES], features, NB_FEATURES); +} + +void run_frame_network_flush(LPCNetState *lpcnet) +{ + int i; + for (i=0;ifeature_buffer_fill;i++) { + float lpc[LPC_ORDER]; + float gru_a_condition[3*GRU_A_STATE_SIZE]; + float gru_b_condition[3*GRU_B_STATE_SIZE]; + run_frame_network(lpcnet, gru_a_condition, gru_b_condition, lpc, &lpcnet->feature_buffer[i*NB_FEATURES]); + } + lpcnet->feature_buffer_fill = 0; +} + int run_sample_network(LPCNetState *lpcnet, const float *gru_a_condition, const float *gru_b_condition, int last_exc, int last_sig, int pred, const float *sampling_logit_table, kiss99_ctx *rng) { NNetState *net; diff --git a/dnn/lpcnet_plc.c b/dnn/lpcnet_plc.c index c84199bf..b39ca51f 100644 --- a/dnn/lpcnet_plc.c +++ b/dnn/lpcnet_plc.c @@ -189,11 +189,8 @@ static int lpcnet_plc_update_causal(LPCNetPLCState *st, short *pcm) { st->plc_net = st->plc_copy[FEATURES_DELAY]; compute_plc_pred(&st->plc_net, st->features, zeros); for (i=0;ilpcnet, gru_a_condition, gru_b_condition, lpc, st->features); + run_frame_network_deferred(&st->lpcnet, st->features); } copy = st->lpcnet; lpcnet_synthesize_impl(&st->lpcnet, &st->features[0], tmp, FRAME_SIZE-TRAINING_OFFSET, 0); @@ -238,11 +235,8 @@ static int lpcnet_plc_update_causal(LPCNetPLCState *st, short *pcm) { } if (st->skip_analysis) { if (st->enable_blending) { - float lpc[LPC_ORDER]; - float gru_a_condition[3*GRU_A_STATE_SIZE]; - float gru_b_condition[3*GRU_B_STATE_SIZE]; /* FIXME: backtrack state, replace features. */ - run_frame_network(&st->lpcnet, gru_a_condition, gru_b_condition, lpc, st->enc.features[0]); + run_frame_network_deferred(&st->lpcnet, st->enc.features[0]); } st->skip_analysis--; } else { @@ -250,10 +244,7 @@ static int lpcnet_plc_update_causal(LPCNetPLCState *st, short *pcm) { RNN_COPY(output, &st->pcm[0], FRAME_SIZE); #ifdef PLC_SKIP_UPDATES { - float lpc[LPC_ORDER]; - float gru_a_condition[3*GRU_A_STATE_SIZE]; - float gru_b_condition[3*GRU_B_STATE_SIZE]; - run_frame_network(&st->lpcnet, gru_a_condition, gru_b_condition, lpc, st->enc.features[0]); + run_frame_network_deferred(&st->lpcnet, st->enc.features[0]); } #else lpcnet_synthesize_impl(&st->lpcnet, st->enc.features[0], output, FRAME_SIZE, FRAME_SIZE); @@ -274,6 +265,7 @@ static const float att_table[10] = {0, 0, -.2, -.2, -.4, -.4, -.8, -.8, -1.6, static int lpcnet_plc_conceal_causal(LPCNetPLCState *st, short *pcm) { int i; short output[FRAME_SIZE]; + run_frame_network_flush(&st->lpcnet); st->enc.pcount = 0; /* If we concealed the previous frame, finish synthesizing the rest of the samples. */ /* FIXME: Copy/predict features. */ diff --git a/dnn/lpcnet_private.h b/dnn/lpcnet_private.h index 07f32ac1..3b588c08 100644 --- a/dnn/lpcnet_private.h +++ b/dnn/lpcnet_private.h @@ -23,11 +23,14 @@ #define FORBIDDEN_INTERP 7 #define PLC_MAX_FEC 100 +#define MAX_FEATURE_BUFFER_SIZE 4 struct LPCNetState { NNetState nnet; int last_exc; float last_sig[LPC_ORDER]; + float feature_buffer[NB_FEATURES*MAX_FEATURE_BUFFER_SIZE]; + int feature_buffer_fill; float last_features[NB_FEATURES]; #if FEATURES_DELAY>0 float old_lpc[FEATURES_DELAY][LPC_ORDER]; @@ -114,6 +117,10 @@ void decode_packet(float features[4][NB_TOTAL_FEATURES], float *vq_mem, const un void lpcnet_reset_signal(LPCNetState *lpcnet); void run_frame_network(LPCNetState *lpcnet, float *gru_a_condition, float *gru_b_condition, float *lpc, const float *features); +void run_frame_network_deferred(LPCNetState *lpcnet, const float *features); +void run_frame_network_flush(LPCNetState *lpcnet); + + void lpcnet_synthesize_tail_impl(LPCNetState *lpcnet, short *output, int N, int preload); void lpcnet_synthesize_impl(LPCNetState *lpcnet, const float *features, short *output, int N, int preload); void lpcnet_synthesize_blend_impl(LPCNetState *lpcnet, const short *pcm_in, short *output, int N);