Defer calls to run_frame_network() to save CPU

Calls are deferred to the actual loss and we only process the minimum
required.
This commit is contained in:
Jean-Marc Valin 2023-05-19 18:12:18 -04:00
parent 87f9fbc50c
commit 0098fe70ac
3 changed files with 35 additions and 12 deletions

View file

@ -119,6 +119,30 @@ void run_frame_network(LPCNetState *lpcnet, float *gru_a_condition, float *gru_b
if (lpcnet->frame_count < 1000) lpcnet->frame_count++;
}
void run_frame_network_deferred(LPCNetState *lpcnet, const float *features)
{
int max_buffer_size = lpcnet->model.feature_conv1.kernel_size + lpcnet->model.feature_conv2.kernel_size - 2;
celt_assert(max_buffer_size <= MAX_FEATURE_BUFFER_SIZE);
if (lpcnet->feature_buffer_fill == max_buffer_size) {
RNN_MOVE(lpcnet->feature_buffer, &lpcnet->feature_buffer[NB_FEATURES], (max_buffer_size-1)*NB_FEATURES);
} else {
lpcnet->feature_buffer_fill++;
}
RNN_COPY(&lpcnet->feature_buffer[(lpcnet->feature_buffer_fill-1)*NB_FEATURES], features, NB_FEATURES);
}
void run_frame_network_flush(LPCNetState *lpcnet)
{
int i;
for (i=0;i<lpcnet->feature_buffer_fill;i++) {
float lpc[LPC_ORDER];
float gru_a_condition[3*GRU_A_STATE_SIZE];
float gru_b_condition[3*GRU_B_STATE_SIZE];
run_frame_network(lpcnet, gru_a_condition, gru_b_condition, lpc, &lpcnet->feature_buffer[i*NB_FEATURES]);
}
lpcnet->feature_buffer_fill = 0;
}
int run_sample_network(LPCNetState *lpcnet, const float *gru_a_condition, const float *gru_b_condition, int last_exc, int last_sig, int pred, const float *sampling_logit_table, kiss99_ctx *rng)
{
NNetState *net;