Add blob loading support to decoder

This commit is contained in:
Jean-Marc Valin 2023-06-01 13:44:22 -04:00
parent d43eb241e3
commit 6fb930956f
No known key found for this signature in database
GPG key ID: 531A52533318F00A
3 changed files with 76 additions and 2 deletions

View file

@ -171,6 +171,8 @@ extern "C" {
#define OPUS_GET_IN_DTX_REQUEST 4049
#define OPUS_SET_DRED_DURATION_REQUEST 4050
#define OPUS_GET_DRED_DURATION_REQUEST 4051
#define OPUS_SET_DNN_BLOB_REQUEST 4052
/*#define OPUS_GET_DNN_BLOB_REQUEST 4053 */
/** Defines for the presence of extended APIs. */
#define OPUS_HAVE_OPUS_PROJECTION_H
@ -179,6 +181,7 @@ extern "C" {
#define __opus_check_int(x) (((void)((x) == (opus_int32)0)), (opus_int32)(x))
#define __opus_check_int_ptr(ptr) ((ptr) + ((ptr) - (opus_int32*)(ptr)))
#define __opus_check_uint_ptr(ptr) ((ptr) + ((ptr) - (opus_uint32*)(ptr)))
#define __opus_check_uint8_ptr(ptr) ((ptr) + ((ptr) - (opus_uint8*)(ptr)))
#define __opus_check_val16_ptr(ptr) ((ptr) + ((ptr) - (opus_val16*)(ptr)))
/** @endcond */
@ -629,6 +632,10 @@ extern "C" {
* @hideinitializer */
#define OPUS_GET_DRED_DURATION(x) OPUS_GET_DRED_DURATION_REQUEST, __opus_check_int_ptr(x)
/** Provide external DNN weights from binary object (only when explicitly built without the weights)
* @hideinitializer */
#define OPUS_SET_DNN_BLOB(data, len) OPUS_SET_DNN_BLOB_REQUEST, __opus_check_uint8_ptr(data), __opus_check_int(len)
/**@}*/

View file

@ -995,6 +995,19 @@ int opus_decoder_ctl(OpusDecoder *st, int request, ...)
ret = celt_decoder_ctl(celt_dec, OPUS_GET_PHASE_INVERSION_DISABLED(value));
}
break;
#ifdef USE_WEIGHTS_FILE
case OPUS_SET_DNN_BLOB_REQUEST:
{
const unsigned char *data = va_arg(ap, const unsigned char *);
opus_int32 len = va_arg(ap, opus_int32);
if(len<0 || data == NULL)
{
goto bad_arg;
}
return lpcnet_plc_load_model(&st->lpcnet, data, len);
}
break;
#endif
default:
/*fprintf(stderr, "unknown opus_decoder_ctl() request: %d", request);*/
ret = OPUS_UNIMPLEMENTED;

View file

@ -42,6 +42,50 @@
#define MAX_PACKET 1500
#ifdef USE_WEIGHTS_FILE
# if __unix__
# include <fcntl.h>
# include <sys/mman.h>
# include <unistd.h>
# include <sys/stat.h>
/* When available, mmap() is preferable to reading the file, as it leads to
better resource utilization, especially if multiple processes are using the same
file (mapping will be shared in cache). */
unsigned char *load_blob(const char *filename, int *len) {
int fd;
unsigned char *data;
struct stat st;
stat(filename, &st);
*len = st.st_size;
fd = open(filename, O_RDONLY);
data = mmap(NULL, *len, PROT_READ, MAP_SHARED, fd, 0);
close(fd);
return data;
}
void free_blob(unsigned char *blob, int len) {
munmap(blob, len);
}
# else
unsigned char *load_blob(const char *filename, int *len) {
FILE *file;
unsigned char *data;
file = fopen(filename, "r");
fseek(file, 0L, SEEK_END);
*len = ftell(file);
fseek(file, 0L, SEEK_SET);
if (*len <= 0) return NULL;
data = malloc(*len);
*len = fread(data, 1, *len, file);
return data;
}
void free_blob(unsigned char *blob, int len) {
free(blob);
(void)len;
}
# endif
#endif
void print_usage( char* argv[] )
{
fprintf(stderr, "Usage: %s [-e] <application> <sampling rate (Hz)> <channels (1/2)> "
@ -270,6 +314,12 @@ int main(int argc, char *argv[])
int lost_count=0;
FILE *packet_loss_file=NULL;
int dred_duration=0;
#ifdef USE_WEIGHTS_FILE
int blob_len;
unsigned char *blob_data;
const char *filename = "weights_blob.bin";
blob_data = load_blob(filename, &blob_len);
#endif
if (argc < 5 )
{
@ -567,8 +617,9 @@ int main(int argc, char *argv[])
goto failure;
}
}
#ifdef USE_WEIGHTS_FILE
opus_decoder_ctl(dec, OPUS_SET_DNN_BLOB(blob_data, blob_len));
#endif
switch(bandwidth)
{
case OPUS_BANDWIDTH_NARROWBAND:
@ -928,5 +979,8 @@ failure:
free(in);
free(out);
free(fbytes);
#ifdef USE_WEIGHTS_FILE
free_blob(blob_data, blob_len);
#endif
return ret;
}