Adding RTCD for DNN code

Starting with compute_linear()
This commit is contained in:
Jean-Marc Valin 2023-11-13 18:26:31 -05:00
parent b0620c0bf9
commit 2e034f6f31
No known key found for this signature in database
GPG key ID: 531A52533318F00A
31 changed files with 539 additions and 165 deletions

View file

@ -12,7 +12,8 @@
float compute_pitchdnn(
PitchDNNState *st,
const float *if_features,
const float *xcorr_features
const float *xcorr_features,
int arch
)
{
float if1_out[DENSE_IF_UPSAMPLER_1_OUT_SIZE];
@ -28,16 +29,16 @@ float compute_pitchdnn(
float count=0;
PitchDNN *model = &st->model;
/* IF */
compute_generic_dense(&model->dense_if_upsampler_1, if1_out, if_features, ACTIVATION_TANH);
compute_generic_dense(&model->dense_if_upsampler_2, &downsampler_in[NB_XCORR_FEATURES], if1_out, ACTIVATION_TANH);
compute_generic_dense(&model->dense_if_upsampler_1, if1_out, if_features, ACTIVATION_TANH, arch);
compute_generic_dense(&model->dense_if_upsampler_2, &downsampler_in[NB_XCORR_FEATURES], if1_out, ACTIVATION_TANH, arch);
/* xcorr*/
OPUS_COPY(&conv1_tmp1[1], xcorr_features, NB_XCORR_FEATURES);
compute_conv2d(&model->conv2d_1, &conv1_tmp2[1], st->xcorr_mem1, conv1_tmp1, NB_XCORR_FEATURES, NB_XCORR_FEATURES+2, ACTIVATION_TANH);
compute_conv2d(&model->conv2d_2, downsampler_in, st->xcorr_mem2, conv1_tmp2, NB_XCORR_FEATURES, NB_XCORR_FEATURES, ACTIVATION_TANH);
compute_generic_dense(&model->dense_downsampler, downsampler_out, downsampler_in, ACTIVATION_TANH);
compute_generic_gru(&model->gru_1_input, &model->gru_1_recurrent, st->gru_state, downsampler_out);
compute_generic_dense(&model->dense_final_upsampler, output, st->gru_state, ACTIVATION_LINEAR);
compute_generic_dense(&model->dense_downsampler, downsampler_out, downsampler_in, ACTIVATION_TANH, arch);
compute_generic_gru(&model->gru_1_input, &model->gru_1_recurrent, st->gru_state, downsampler_out, arch);
compute_generic_dense(&model->dense_final_upsampler, output, st->gru_state, ACTIVATION_LINEAR, arch);
for (i=0;i<180;i++) {
if (output[i] > maxval) {
pos = i;
@ -65,7 +66,6 @@ void pitchdnn_init(PitchDNNState *st)
ret = 0;
#endif
celt_assert(ret == 0);
/* FIXME: perform arch detection. */
}
int pitchdnn_load_model(PitchDNNState *st, const unsigned char *data, int len) {