Add blob loading for DRED encoder and decoder

This commit is contained in:
Jean-Marc Valin 2023-06-06 17:19:12 -04:00
parent 0dad5e06ab
commit a8cb719d05
No known key found for this signature in database
GPG key ID: 531A52533318F00A
6 changed files with 91 additions and 5 deletions

View file

@ -547,7 +547,18 @@ OPUS_EXPORT int opus_dred_decoder_init(OpusDREDDecoder *dec);
*/
OPUS_EXPORT void opus_dred_decoder_destroy(OpusDREDDecoder *dec);
/** Perform a CTL function on an Opus DRED decoder.
*
* Generally the request and subsequent arguments are generated
* by a convenience macro.
* @param st <tt>OpusDREDDecoder*</tt>: DRED Decoder state.
* @param request This and all remaining parameters should be replaced by one
* of the convenience macros in @ref opus_genericctls or
* @ref opus_decoderctls.
* @see opus_genericctls
* @see opus_decoderctls
*/
OPUS_EXPORT int opus_dred_decoder_ctl(OpusDREDDecoder *dred_dec, int request, ...);
/** Gets the size of an <code>OpusDRED</code> structure.
* @returns The size in bytes.

View file

@ -44,6 +44,16 @@
#include "float_cast.h"
#include "os_support.h"
int dred_encoder_load_model(DREDEnc* enc, const unsigned char *data, int len)
{
WeightArray *list;
int ret;
parse_weights(&list, data, len);
ret = init_rdovaeenc(&enc->model, list);
free(list);
return (ret == 0) ? OPUS_OK : OPUS_BAD_ARG;
}
void dred_encoder_reset(DREDEnc* enc)
{
RNN_CLEAR((char*)&enc->DREDENC_RESET_START,

View file

@ -54,7 +54,7 @@ typedef struct {
RDOVAEEncState rdovae_enc;
} DREDEnc;
int dred_encoder_load_model(DREDEnc* enc, const unsigned char *data, int len);
void dred_encoder_init(DREDEnc* enc, opus_int32 Fs, int channels);
void dred_encoder_reset(DREDEnc* enc);

View file

@ -1141,6 +1141,16 @@ int opus_dred_decoder_get_size(void)
return sizeof(OpusDREDDecoder);
}
int dred_decoder_load_model(OpusDREDDecoder *dec, const unsigned char *data, int len)
{
WeightArray *list;
int ret;
parse_weights(&list, data, len);
ret = init_rdovaedec(&dec->model, list);
free(list);
return (ret == 0) ? OPUS_OK : OPUS_BAD_ARG;
}
int opus_dred_decoder_init(OpusDREDDecoder *dec)
{
#ifndef USE_WEIGHTS_FILE
@ -1180,7 +1190,47 @@ void opus_dred_decoder_destroy(OpusDREDDecoder *dec)
free(dec);
}
int opus_dred_decoder_ctl(OpusDREDDecoder *dred_dec, int request, ...)
{
#ifdef ENABLE_NEURAL_FEC
int ret = OPUS_OK;
va_list ap;
va_start(ap, request);
(void)dred_dec;
switch (request)
{
# ifdef USE_WEIGHTS_FILE
case OPUS_SET_DNN_BLOB_REQUEST:
{
const unsigned char *data = va_arg(ap, const unsigned char *);
opus_int32 len = va_arg(ap, opus_int32);
if(len<0 || data == NULL)
{
goto bad_arg;
}
return dred_decoder_load_model(dred_dec, data, len);
}
break;
# endif
default:
/*fprintf(stderr, "unknown opus_decoder_ctl() request: %d", request);*/
ret = OPUS_UNIMPLEMENTED;
break;
}
va_end(ap);
return ret;
# ifdef USE_WEIGHTS_FILE
bad_arg:
va_end(ap);
return OPUS_BAD_ARG;
# endif
#else
(void)dred_dec;
(void)request;
return OPUS_UNIMPLEMENTED;
#endif
}
#ifdef ENABLE_NEURAL_FEC
static int dred_find_payload(const unsigned char *data, opus_int32 len, const unsigned char **payload)

View file

@ -617,9 +617,6 @@ int main(int argc, char *argv[])
goto failure;
}
}
#ifdef USE_WEIGHTS_FILE
opus_decoder_ctl(dec, OPUS_SET_DNN_BLOB(blob_data, blob_len));
#endif
switch(bandwidth)
{
case OPUS_BANDWIDTH_NARROWBAND:
@ -684,6 +681,11 @@ int main(int argc, char *argv[])
}
dred_dec = opus_dred_decoder_create(&err);
dred = opus_dred_alloc(&err);
#ifdef USE_WEIGHTS_FILE
opus_encoder_ctl(enc, OPUS_SET_DNN_BLOB(blob_data, blob_len));
opus_decoder_ctl(dec, OPUS_SET_DNN_BLOB(blob_data, blob_len));
opus_dred_decoder_ctl(dred_dec, OPUS_SET_DNN_BLOB(blob_data, blob_len));
#endif
while (!stop)
{
if (delayed_celt)

View file

@ -2847,6 +2847,19 @@ int opus_encoder_ctl(OpusEncoder *st, int request, ...)
}
}
break;
#ifdef USE_WEIGHTS_FILE
case OPUS_SET_DNN_BLOB_REQUEST:
{
const unsigned char *data = va_arg(ap, const unsigned char *);
opus_int32 len = va_arg(ap, opus_int32);
if(len<0 || data == NULL)
{
goto bad_arg;
}
return dred_encoder_load_model(&st->dred_encoder, data, len);
}
break;
#endif
case CELT_GET_MODE_REQUEST:
{
const CELTMode ** value = va_arg(ap, const CELTMode**);