From dd114baf4d6fad4501935b23bc14f79b0fb1cdd6 Mon Sep 17 00:00:00 2001 From: Jean-Marc Valin Date: Fri, 16 Sep 2022 00:55:39 -0400 Subject: [PATCH] Fix causal PLC for models with non-zero lookahead --- dnn/lpcnet.c | 3 +++ dnn/lpcnet_plc.c | 33 +++++++++++++++++++++------------ dnn/lpcnet_private.h | 3 ++- 3 files changed, 26 insertions(+), 13 deletions(-) diff --git a/dnn/lpcnet.c b/dnn/lpcnet.c index ead2dcae..2f5c2511 100644 --- a/dnn/lpcnet.c +++ b/dnn/lpcnet.c @@ -89,6 +89,8 @@ void run_frame_network(LPCNetState *lpcnet, float *gru_a_condition, float *gru_b float dense1_out[FEATURE_DENSE1_OUT_SIZE]; int pitch; float rc[LPC_ORDER]; + //static float features[NB_FEATURES]; + //RNN_COPY(features, lpcnet->last_features, NB_FEATURES); /* Matches the Python code -- the 0.1 avoids rounding issues. */ pitch = (int)floor(.1 + 50*features[NB_BANDS]+100); pitch = IMIN(255, IMAX(33, pitch)); @@ -116,6 +118,7 @@ void run_frame_network(LPCNetState *lpcnet, float *gru_a_condition, float *gru_b #ifdef LPC_GAMMA lpc_weighting(lpc, LPC_GAMMA); #endif + //RNN_COPY(lpcnet->last_features, _features, NB_FEATURES); if (lpcnet->frame_count < 1000) lpcnet->frame_count++; } diff --git a/dnn/lpcnet_plc.c b/dnn/lpcnet_plc.c index 15a6b135..9fd80689 100644 --- a/dnn/lpcnet_plc.c +++ b/dnn/lpcnet_plc.c @@ -37,11 +37,6 @@ LPCNET_EXPORT int lpcnet_plc_get_size() { } LPCNET_EXPORT int lpcnet_plc_init(LPCNetPLCState *st, int options) { - if (FEATURES_DELAY != 0) { - fprintf(stderr, "PLC cannot work with non-zero FEATURES_DELAY\n"); - fprintf(stderr, "Recompile with a no-lookahead model (see README.md)\n"); - exit(1); - } RNN_CLEAR(st, 1); lpcnet_init(&st->lpcnet); lpcnet_encoder_init(&st->enc); @@ -130,8 +125,15 @@ static int lpcnet_plc_update_causal(LPCNetPLCState *st, short *pcm) { float zeros[2*NB_BANDS+NB_FEATURES+1] = {0}; RNN_COPY(zeros, plc_features, 2*NB_BANDS); zeros[2*NB_BANDS+NB_FEATURES] = 1; - st->plc_net = st->plc_copy; + st->plc_net = st->plc_copy[FEATURES_DELAY]; compute_plc_pred(&st->plc_net, st->features, zeros); + for (i=0;ilpcnet, gru_a_condition, gru_b_condition, lpc, st->features); + } if (st->enable_blending) { LPCNetState copy; copy = st->lpcnet; @@ -147,14 +149,12 @@ static int lpcnet_plc_update_causal(LPCNetPLCState *st, short *pcm) { RNN_COPY(tmp, pcm, FRAME_SIZE-TRAINING_OFFSET); lpcnet_synthesize_tail_impl(&st->lpcnet, tmp, FRAME_SIZE-TRAINING_OFFSET, FRAME_SIZE-TRAINING_OFFSET); } - st->blend = 0; RNN_COPY(st->pcm, &pcm[FRAME_SIZE-TRAINING_OFFSET], TRAINING_OFFSET); st->pcm_fill = TRAINING_OFFSET; } else { RNN_COPY(&st->pcm[st->pcm_fill], pcm, FRAME_SIZE); st->pcm_fill += FRAME_SIZE; } - /*fprintf(stderr, "fill at %d\n", st->pcm_fill);*/ } /* Update state. */ /*fprintf(stderr, "update state\n");*/ @@ -162,6 +162,11 @@ static int lpcnet_plc_update_causal(LPCNetPLCState *st, short *pcm) { preemphasis(x, &st->enc.mem_preemph, x, PREEMPHASIS, FRAME_SIZE); compute_frame_features(&st->enc, x); process_single_frame(&st->enc, NULL); + if (!st->blend) { + RNN_COPY(&plc_features[2*NB_BANDS], st->enc.features[0], NB_FEATURES); + plc_features[2*NB_BANDS+NB_FEATURES] = 1; + compute_plc_pred(&st->plc_net, st->features, plc_features); + } if (st->skip_analysis) { float lpc[LPC_ORDER]; float gru_a_condition[3*GRU_A_STATE_SIZE]; @@ -170,9 +175,6 @@ static int lpcnet_plc_update_causal(LPCNetPLCState *st, short *pcm) { run_frame_network(&st->lpcnet, gru_a_condition, gru_b_condition, lpc, st->enc.features[0]); st->skip_analysis--; } else { - RNN_COPY(&plc_features[2*NB_BANDS], st->enc.features[0], NB_FEATURES); - plc_features[2*NB_BANDS+NB_FEATURES] = 1; - compute_plc_pred(&st->plc_net, st->features, plc_features); for (i=0;ipcm[PLC_BUF_SIZE+i] = pcm[i]; RNN_COPY(output, &st->pcm[0], FRAME_SIZE); lpcnet_synthesize_impl(&st->lpcnet, st->enc.features[0], output, FRAME_SIZE, FRAME_SIZE); @@ -184,6 +186,7 @@ static int lpcnet_plc_update_causal(LPCNetPLCState *st, short *pcm) { pcm[i] += lp[i]; } } + st->blend = 0; return 0; } @@ -206,7 +209,8 @@ static int lpcnet_plc_conceal_causal(LPCNetPLCState *st, short *pcm) { st->pcm_fill -= update_count; st->skip_analysis++; } - st->plc_copy = st->plc_net; + RNN_MOVE(&st->plc_copy[1], &st->plc_copy[0], FEATURES_DELAY); + st->plc_copy[0] = st->plc_net; lpcnet_synthesize_tail_impl(&st->lpcnet, pcm, FRAME_SIZE-TRAINING_OFFSET, 0); compute_plc_pred(&st->plc_net, st->features, zeros); if (st->loss_count >= 10) st->features[0] = MAX16(-10, st->features[0]+att_table[9] - 2*(st->loss_count-9)); @@ -250,6 +254,11 @@ static int lpcnet_plc_update_non_causal(LPCNetPLCState *st, short *pcm) { short lp[FRAME_SIZE]={0}; double mem_bak=0; int delta = st->syn_dc; + if (FEATURES_DELAY != 0) { + fprintf(stderr, "Non-causal PLC cannot work with non-zero FEATURES_DELAY\n"); + fprintf(stderr, "Recompile with a no-lookahead model (see README.md)\n"); + exit(1); + } process_queued_update(st); if (st->remove_dc) { st->dc_mem += st->syn_dc; diff --git a/dnn/lpcnet_private.h b/dnn/lpcnet_private.h index 83496a6b..dba6912a 100644 --- a/dnn/lpcnet_private.h +++ b/dnn/lpcnet_private.h @@ -26,6 +26,7 @@ struct LPCNetState { NNetState nnet; int last_exc; float last_sig[LPC_ORDER]; + float last_features[NB_FEATURES]; #if FEATURES_DELAY>0 float old_lpc[FEATURES_DELAY][LPC_ORDER]; #endif @@ -76,7 +77,7 @@ struct LPCNetPLCState { float features[NB_TOTAL_FEATURES]; int loss_count; PLCNetState plc_net; - PLCNetState plc_copy; + PLCNetState plc_copy[FEATURES_DELAY+1]; int enable_blending; int non_causal; double dc_mem;