diff --git a/dnn/Makefile.am b/dnn/Makefile.am index 2f48e613..d4966a76 100644 --- a/dnn/Makefile.am +++ b/dnn/Makefile.am @@ -36,7 +36,8 @@ liblpcnet_la_SOURCES = \ pitch.c \ freq.c \ kiss_fft.c \ - celt_lpc.c + celt_lpc.c \ + lpcnet_plc.c liblpcnet_la_LIBADD = $(DEPS_LIBS) $(lrintf_lib) $(LIBM) liblpcnet_la_LDFLAGS = -no-undefined \ diff --git a/dnn/include/lpcnet.h b/dnn/include/lpcnet.h index 4b36830b..b3a92341 100644 --- a/dnn/include/lpcnet.h +++ b/dnn/include/lpcnet.h @@ -58,6 +58,8 @@ typedef struct LPCNetDecState LPCNetDecState; typedef struct LPCNetEncState LPCNetEncState; +typedef struct LPCNetPLCState LPCNetPLCState; + /** Gets the size of an LPCNetDecState structure. * @returns The size in bytes. @@ -174,4 +176,15 @@ LPCNET_EXPORT void lpcnet_destroy(LPCNetState *st); */ LPCNET_EXPORT void lpcnet_synthesize(LPCNetState *st, const float *features, short *output, int N); + +LPCNET_EXPORT void lpcnet_plc_init(LPCNetPLCState *st); + +LPCNET_EXPORT LPCNetPLCState *lpcnet_plc_create(); + +LPCNET_EXPORT void lpcnet_plc_destroy(LPCNetPLCState *st); + +LPCNET_EXPORT int lpcnet_plc_update(LPCNetPLCState *st, short *pcm); + +LPCNET_EXPORT int lpcnet_plc_conceal(LPCNetPLCState *st, short *pcm); + #endif diff --git a/dnn/lpcnet.c b/dnn/lpcnet.c index 5fd451ec..020f828c 100644 --- a/dnn/lpcnet.c +++ b/dnn/lpcnet.c @@ -171,13 +171,9 @@ LPCNET_EXPORT void lpcnet_destroy(LPCNetState *lpcnet) } -LPCNET_EXPORT void lpcnet_synthesize(LPCNetState *lpcnet, const float *features, short *output, int N) +void lpcnet_synthesize_tail_impl(LPCNetState *lpcnet, short *output, int N, int preload) { int i; - float lpc[LPC_ORDER]; - float gru_a_condition[3*GRU_A_STATE_SIZE]; - float gru_b_condition[3*GRU_B_STATE_SIZE]; - run_frame_network(lpcnet, gru_a_condition, gru_b_condition, lpc, features); if (lpcnet->frame_count <= FEATURES_DELAY) { @@ -192,10 +188,11 @@ LPCNET_EXPORT void lpcnet_synthesize(LPCNetState *lpcnet, const float *features, int last_sig_ulaw; int pred_ulaw; float pred = 0; - for (j=0;jlast_sig[j]*lpc[j]; + for (j=0;jlast_sig[j]*lpcnet->lpc[j]; last_sig_ulaw = lin2ulaw(lpcnet->last_sig[0]); pred_ulaw = lin2ulaw(pred); - exc = run_sample_network(&lpcnet->nnet, gru_a_condition, gru_b_condition, lpcnet->last_exc, last_sig_ulaw, pred_ulaw, lpcnet->sampling_logit_table, &lpcnet->rng); + exc = run_sample_network(&lpcnet->nnet, lpcnet->gru_a_condition, lpcnet->gru_b_condition, lpcnet->last_exc, last_sig_ulaw, pred_ulaw, lpcnet->sampling_logit_table, &lpcnet->rng); + if (i < preload) exc = lin2ulaw(output[i]-PREEMPH*lpcnet->deemph_mem - pred); pcm = pred + ulaw2lin(exc); RNN_MOVE(&lpcnet->last_sig[1], &lpcnet->last_sig[0], LPC_ORDER-1); lpcnet->last_sig[0] = pcm; @@ -208,6 +205,15 @@ LPCNET_EXPORT void lpcnet_synthesize(LPCNetState *lpcnet, const float *features, } } +void lpcnet_synthesize_impl(LPCNetState *lpcnet, const float *features, short *output, int N, int preload) +{ + run_frame_network(lpcnet, lpcnet->gru_a_condition, lpcnet->gru_b_condition, lpcnet->lpc, features); + lpcnet_synthesize_tail_impl(lpcnet, output, N, preload); +} + +LPCNET_EXPORT void lpcnet_synthesize(LPCNetState *lpcnet, const float *features, short *output, int N) { + lpcnet_synthesize_impl(lpcnet, features, output, N, 0); +} LPCNET_EXPORT int lpcnet_decoder_get_size() { diff --git a/dnn/lpcnet_demo.c b/dnn/lpcnet_demo.c index e05755b0..ed004045 100644 --- a/dnn/lpcnet_demo.c +++ b/dnn/lpcnet_demo.c @@ -38,23 +38,30 @@ #define MODE_DECODE 1 #define MODE_FEATURES 2 #define MODE_SYNTHESIS 3 +#define MODE_PLC 4 int main(int argc, char **argv) { int mode; + int plc_percent=0; FILE *fin, *fout; - if (argc != 4) + if (argc != 4 && !(argc == 5 && strcmp(argv[1], "-plc") == 0)) { fprintf(stderr, "usage: lpcnet_demo -encode \n"); fprintf(stderr, " lpcnet_demo -decode \n"); fprintf(stderr, " lpcnet_demo -features \n"); fprintf(stderr, " lpcnet_demo -synthesis \n"); + fprintf(stderr, " lpcnet_demo -plc \n"); return 0; } if (strcmp(argv[1], "-encode") == 0) mode=MODE_ENCODE; else if (strcmp(argv[1], "-decode") == 0) mode=MODE_DECODE; else if (strcmp(argv[1], "-features") == 0) mode=MODE_FEATURES; else if (strcmp(argv[1], "-synthesis") == 0) mode=MODE_SYNTHESIS; - else { + else if (strcmp(argv[1], "-plc") == 0) { + mode=MODE_PLC; + plc_percent = atoi(argv[2]); + argv++; + } else { exit(1); } fin = fopen(argv[2], "rb"); @@ -123,6 +130,23 @@ int main(int argc, char **argv) { fwrite(pcm, sizeof(pcm[0]), LPCNET_FRAME_SIZE, fout); } lpcnet_destroy(net); + } else if (mode == MODE_PLC) { + int count=0; + int loss=0; + LPCNetPLCState *net; + net = lpcnet_plc_create(); + while (1) { + short pcm[FRAME_SIZE]; + size_t ret; + ret = fread(pcm, sizeof(pcm[0]), FRAME_SIZE, fin); + if (feof(fin) || ret != FRAME_SIZE) break; + if (count % 2 == 0) loss = rand() < RAND_MAX*(float)plc_percent/100.f; + if (loss) lpcnet_plc_conceal(net, pcm); + else lpcnet_plc_update(net, pcm); + fwrite(pcm, sizeof(pcm[0]), FRAME_SIZE, fout); + count++; + } + lpcnet_plc_destroy(net); } else { fprintf(stderr, "unknown action\n"); } diff --git a/dnn/lpcnet_enc.c b/dnn/lpcnet_enc.c index 55d5ef28..34a85a24 100644 --- a/dnn/lpcnet_enc.c +++ b/dnn/lpcnet_enc.c @@ -894,7 +894,7 @@ LPCNET_EXPORT int lpcnet_compute_features(LPCNetEncState *st, const short *pcm, return 0; } -LPCNET_EXPORT int lpcnet_compute_single_frame_features(LPCNetEncState *st, const short *pcm, float features[NB_TOTAL_FEATURES]) { +int lpcnet_compute_single_frame_features(LPCNetEncState *st, const short *pcm, float features[NB_TOTAL_FEATURES]) { int i; float x[FRAME_SIZE]; for (i=0;ilpcnet); + lpcnet_encoder_init(&st->enc); + RNN_CLEAR(st->pcm, PLC_BUF_SIZE); + st->pcm_fill = PLC_BUF_SIZE; + st->skip_analysis = 0; + st->blend = 0; +} + +LPCNET_EXPORT LPCNetPLCState *lpcnet_plc_create() { + LPCNetPLCState *st; + st = malloc(sizeof(*st)); + lpcnet_plc_init(st); + return st; +} + +LPCNET_EXPORT void lpcnet_plc_destroy(LPCNetPLCState *st) { + free(st); +} + +LPCNET_EXPORT int lpcnet_plc_update(LPCNetPLCState *st, short *pcm) { + int i; + float x[FRAME_SIZE]; + short output[FRAME_SIZE]; + st->enc.pcount = 0; + if (st->skip_analysis) { + //fprintf(stderr, "skip update\n"); + if (st->blend) { + short tmp[FRAME_SIZE-TRAINING_OFFSET]; + lpcnet_synthesize_tail_impl(&st->lpcnet, tmp, FRAME_SIZE-TRAINING_OFFSET, 0); + for (i=0;iblend = 0; + RNN_COPY(st->pcm, &pcm[FRAME_SIZE-TRAINING_OFFSET], TRAINING_OFFSET); + st->pcm_fill = TRAINING_OFFSET; + } else { + RNN_COPY(&st->pcm[st->pcm_fill], pcm, FRAME_SIZE); + st->pcm_fill += FRAME_SIZE; + } + //fprintf(stderr, "fill at %d\n", st->pcm_fill); + } + /* Update state. */ + //fprintf(stderr, "update state\n"); + for (i=0;ienc.mem_preemph, x, PREEMPHASIS, FRAME_SIZE); + compute_frame_features(&st->enc, x); + process_single_frame(&st->enc, NULL); + if (st->skip_analysis) { + float lpc[LPC_ORDER]; + float gru_a_condition[3*GRU_A_STATE_SIZE]; + float gru_b_condition[3*GRU_B_STATE_SIZE]; + /* FIXME: backtrack state, replace features. */ + run_frame_network(&st->lpcnet, gru_a_condition, gru_b_condition, lpc, st->enc.features[0]); + st->skip_analysis--; + } else { + for (i=0;ipcm[PLC_BUF_SIZE+i] = pcm[i]; + RNN_COPY(output, &st->pcm[0], FRAME_SIZE); + lpcnet_synthesize_impl(&st->lpcnet, st->enc.features[0], output, FRAME_SIZE, FRAME_SIZE); + + RNN_MOVE(st->pcm, &st->pcm[FRAME_SIZE], PLC_BUF_SIZE); + } + RNN_COPY(st->features, st->enc.features[0], NB_TOTAL_FEATURES); + return 0; +} + +LPCNET_EXPORT int lpcnet_plc_conceal(LPCNetPLCState *st, short *pcm) { + short output[FRAME_SIZE]; + st->enc.pcount = 0; + /* If we concealed the previous frame, finish synthesizing the rest of the samples. */ + /* FIXME: Copy/predict features. */ + while (st->pcm_fill > 0) { + //fprintf(stderr, "update state for PLC %d\n", st->pcm_fill); + int update_count; + update_count = IMIN(st->pcm_fill, FRAME_SIZE); + RNN_COPY(output, &st->pcm[0], update_count); + + lpcnet_synthesize_impl(&st->lpcnet, &st->features[0], output, update_count, update_count); + RNN_MOVE(st->pcm, &st->pcm[FRAME_SIZE], PLC_BUF_SIZE); + st->pcm_fill -= update_count; + st->skip_analysis++; + } + lpcnet_synthesize_tail_impl(&st->lpcnet, pcm, FRAME_SIZE-TRAINING_OFFSET, 0); + lpcnet_synthesize_impl(&st->lpcnet, &st->features[0], &pcm[FRAME_SIZE-TRAINING_OFFSET], TRAINING_OFFSET, 0); + { + int i; + float x[FRAME_SIZE]; + /* FIXME: Can we do better? */ + for (i=0;ienc.mem_preemph, x, PREEMPHASIS, FRAME_SIZE); + compute_frame_features(&st->enc, x); + process_single_frame(&st->enc, NULL); + } + st->blend = 1; + return 0; +} diff --git a/dnn/lpcnet_private.h b/dnn/lpcnet_private.h index c7e88f5b..71628f30 100644 --- a/dnn/lpcnet_private.h +++ b/dnn/lpcnet_private.h @@ -32,8 +32,11 @@ struct LPCNetState { float old_lpc[FEATURES_DELAY][LPC_ORDER]; #endif float sampling_logit_table[256]; + float gru_a_condition[3*GRU_A_STATE_SIZE]; + float gru_b_condition[3*GRU_B_STATE_SIZE]; int frame_count; float deemph_mem; + float lpc[LPC_ORDER]; kiss99_ctx rng; }; @@ -63,6 +66,16 @@ struct LPCNetEncState{ int exc_mem; }; +#define PLC_BUF_SIZE (FEATURES_DELAY*FRAME_SIZE + TRAINING_OFFSET) +struct LPCNetPLCState { + LPCNetState lpcnet; + LPCNetEncState enc; + short pcm[PLC_BUF_SIZE+FRAME_SIZE]; + int pcm_fill; + int skip_analysis; + int blend; + float features[NB_TOTAL_FEATURES]; +}; extern float ceps_codebook1[]; extern float ceps_codebook2[]; @@ -79,6 +92,13 @@ void compute_frame_features(LPCNetEncState *st, const float *in); void decode_packet(float features[4][NB_TOTAL_FEATURES], float *vq_mem, const unsigned char buf[8]); +void run_frame_network(LPCNetState *lpcnet, float *gru_a_condition, float *gru_b_condition, float *lpc, const float *features); +void lpcnet_synthesize_tail_impl(LPCNetState *lpcnet, short *output, int N, int preload); +void lpcnet_synthesize_impl(LPCNetState *lpcnet, const float *features, short *output, int N, int preload); +void lpcnet_synthesize_blend_impl(LPCNetState *lpcnet, const short *pcm_in, short *output, int N); +void process_single_frame(LPCNetEncState *st, FILE *ffeat); +int lpcnet_compute_single_frame_features(LPCNetEncState *st, const short *pcm, float features[NB_TOTAL_FEATURES]); + void process_single_frame(LPCNetEncState *st, FILE *ffeat); void run_frame_network(LPCNetState *lpcnet, float *gru_a_condition, float *gru_b_condition, float *lpc, const float *features);