Defer calls to run_frame_network() to save CPU

Calls are deferred to the actual loss and we only process the minimum
required.
This commit is contained in:
Jean-Marc Valin 2023-05-19 18:12:18 -04:00
parent 87f9fbc50c
commit 0098fe70ac
3 changed files with 35 additions and 12 deletions

View file

@ -119,6 +119,30 @@ void run_frame_network(LPCNetState *lpcnet, float *gru_a_condition, float *gru_b
if (lpcnet->frame_count < 1000) lpcnet->frame_count++; if (lpcnet->frame_count < 1000) lpcnet->frame_count++;
} }
void run_frame_network_deferred(LPCNetState *lpcnet, const float *features)
{
int max_buffer_size = lpcnet->model.feature_conv1.kernel_size + lpcnet->model.feature_conv2.kernel_size - 2;
celt_assert(max_buffer_size <= MAX_FEATURE_BUFFER_SIZE);
if (lpcnet->feature_buffer_fill == max_buffer_size) {
RNN_MOVE(lpcnet->feature_buffer, &lpcnet->feature_buffer[NB_FEATURES], (max_buffer_size-1)*NB_FEATURES);
} else {
lpcnet->feature_buffer_fill++;
}
RNN_COPY(&lpcnet->feature_buffer[(lpcnet->feature_buffer_fill-1)*NB_FEATURES], features, NB_FEATURES);
}
void run_frame_network_flush(LPCNetState *lpcnet)
{
int i;
for (i=0;i<lpcnet->feature_buffer_fill;i++) {
float lpc[LPC_ORDER];
float gru_a_condition[3*GRU_A_STATE_SIZE];
float gru_b_condition[3*GRU_B_STATE_SIZE];
run_frame_network(lpcnet, gru_a_condition, gru_b_condition, lpc, &lpcnet->feature_buffer[i*NB_FEATURES]);
}
lpcnet->feature_buffer_fill = 0;
}
int run_sample_network(LPCNetState *lpcnet, const float *gru_a_condition, const float *gru_b_condition, int last_exc, int last_sig, int pred, const float *sampling_logit_table, kiss99_ctx *rng) int run_sample_network(LPCNetState *lpcnet, const float *gru_a_condition, const float *gru_b_condition, int last_exc, int last_sig, int pred, const float *sampling_logit_table, kiss99_ctx *rng)
{ {
NNetState *net; NNetState *net;

View file

@ -189,11 +189,8 @@ static int lpcnet_plc_update_causal(LPCNetPLCState *st, short *pcm) {
st->plc_net = st->plc_copy[FEATURES_DELAY]; st->plc_net = st->plc_copy[FEATURES_DELAY];
compute_plc_pred(&st->plc_net, st->features, zeros); compute_plc_pred(&st->plc_net, st->features, zeros);
for (i=0;i<FEATURES_DELAY;i++) { for (i=0;i<FEATURES_DELAY;i++) {
float lpc[LPC_ORDER];
float gru_a_condition[3*GRU_A_STATE_SIZE];
float gru_b_condition[3*GRU_B_STATE_SIZE];
/* FIXME: backtrack state, replace features. */ /* FIXME: backtrack state, replace features. */
run_frame_network(&st->lpcnet, gru_a_condition, gru_b_condition, lpc, st->features); run_frame_network_deferred(&st->lpcnet, st->features);
} }
copy = st->lpcnet; copy = st->lpcnet;
lpcnet_synthesize_impl(&st->lpcnet, &st->features[0], tmp, FRAME_SIZE-TRAINING_OFFSET, 0); lpcnet_synthesize_impl(&st->lpcnet, &st->features[0], tmp, FRAME_SIZE-TRAINING_OFFSET, 0);
@ -238,11 +235,8 @@ static int lpcnet_plc_update_causal(LPCNetPLCState *st, short *pcm) {
} }
if (st->skip_analysis) { if (st->skip_analysis) {
if (st->enable_blending) { if (st->enable_blending) {
float lpc[LPC_ORDER];
float gru_a_condition[3*GRU_A_STATE_SIZE];
float gru_b_condition[3*GRU_B_STATE_SIZE];
/* FIXME: backtrack state, replace features. */ /* FIXME: backtrack state, replace features. */
run_frame_network(&st->lpcnet, gru_a_condition, gru_b_condition, lpc, st->enc.features[0]); run_frame_network_deferred(&st->lpcnet, st->enc.features[0]);
} }
st->skip_analysis--; st->skip_analysis--;
} else { } else {
@ -250,10 +244,7 @@ static int lpcnet_plc_update_causal(LPCNetPLCState *st, short *pcm) {
RNN_COPY(output, &st->pcm[0], FRAME_SIZE); RNN_COPY(output, &st->pcm[0], FRAME_SIZE);
#ifdef PLC_SKIP_UPDATES #ifdef PLC_SKIP_UPDATES
{ {
float lpc[LPC_ORDER]; run_frame_network_deferred(&st->lpcnet, st->enc.features[0]);
float gru_a_condition[3*GRU_A_STATE_SIZE];
float gru_b_condition[3*GRU_B_STATE_SIZE];
run_frame_network(&st->lpcnet, gru_a_condition, gru_b_condition, lpc, st->enc.features[0]);
} }
#else #else
lpcnet_synthesize_impl(&st->lpcnet, st->enc.features[0], output, FRAME_SIZE, FRAME_SIZE); lpcnet_synthesize_impl(&st->lpcnet, st->enc.features[0], output, FRAME_SIZE, FRAME_SIZE);
@ -274,6 +265,7 @@ static const float att_table[10] = {0, 0, -.2, -.2, -.4, -.4, -.8, -.8, -1.6,
static int lpcnet_plc_conceal_causal(LPCNetPLCState *st, short *pcm) { static int lpcnet_plc_conceal_causal(LPCNetPLCState *st, short *pcm) {
int i; int i;
short output[FRAME_SIZE]; short output[FRAME_SIZE];
run_frame_network_flush(&st->lpcnet);
st->enc.pcount = 0; st->enc.pcount = 0;
/* If we concealed the previous frame, finish synthesizing the rest of the samples. */ /* If we concealed the previous frame, finish synthesizing the rest of the samples. */
/* FIXME: Copy/predict features. */ /* FIXME: Copy/predict features. */

View file

@ -23,11 +23,14 @@
#define FORBIDDEN_INTERP 7 #define FORBIDDEN_INTERP 7
#define PLC_MAX_FEC 100 #define PLC_MAX_FEC 100
#define MAX_FEATURE_BUFFER_SIZE 4
struct LPCNetState { struct LPCNetState {
NNetState nnet; NNetState nnet;
int last_exc; int last_exc;
float last_sig[LPC_ORDER]; float last_sig[LPC_ORDER];
float feature_buffer[NB_FEATURES*MAX_FEATURE_BUFFER_SIZE];
int feature_buffer_fill;
float last_features[NB_FEATURES]; float last_features[NB_FEATURES];
#if FEATURES_DELAY>0 #if FEATURES_DELAY>0
float old_lpc[FEATURES_DELAY][LPC_ORDER]; float old_lpc[FEATURES_DELAY][LPC_ORDER];
@ -114,6 +117,10 @@ void decode_packet(float features[4][NB_TOTAL_FEATURES], float *vq_mem, const un
void lpcnet_reset_signal(LPCNetState *lpcnet); void lpcnet_reset_signal(LPCNetState *lpcnet);
void run_frame_network(LPCNetState *lpcnet, float *gru_a_condition, float *gru_b_condition, float *lpc, const float *features); void run_frame_network(LPCNetState *lpcnet, float *gru_a_condition, float *gru_b_condition, float *lpc, const float *features);
void run_frame_network_deferred(LPCNetState *lpcnet, const float *features);
void run_frame_network_flush(LPCNetState *lpcnet);
void lpcnet_synthesize_tail_impl(LPCNetState *lpcnet, short *output, int N, int preload); void lpcnet_synthesize_tail_impl(LPCNetState *lpcnet, short *output, int N, int preload);
void lpcnet_synthesize_impl(LPCNetState *lpcnet, const float *features, short *output, int N, int preload); void lpcnet_synthesize_impl(LPCNetState *lpcnet, const float *features, short *output, int N, int preload);
void lpcnet_synthesize_blend_impl(LPCNetState *lpcnet, const short *pcm_in, short *output, int N); void lpcnet_synthesize_blend_impl(LPCNetState *lpcnet, const short *pcm_in, short *output, int N);