Initial blob loading support

This commit is contained in:
Jean-Marc Valin 2023-05-28 01:53:20 -04:00
parent d98c59fb9a
commit fa7b432eed
8 changed files with 100 additions and 5 deletions

View file

@ -62,7 +62,7 @@ dump_data_SOURCES = common.c dump_data.c burg.c freq.c kiss_fft.c pitch.c lpcnet
dump_data_LDADD = $(LIBM)
dump_data_CFLAGS = $(AM_CFLAGS)
dump_weights_blob_SOURCES = nnet_data.c plc_data.c write_lpcnet_weights.c
dump_weights_blob_SOURCES = write_lpcnet_weights.c
dump_weights_blob_LDADD = $(LIBM)
dump_weights_blob_CFLAGS = $(AM_CFLAGS) -DDUMP_BINARY_WEIGHTS

View file

@ -6,7 +6,7 @@ srcdir=`dirname $0`
test -n "$srcdir" && cd "$srcdir"
#SHA1 of the first commit compatible with the current model
commit=399be7c
commit=859bfae
./download_model.sh $commit
echo "Updating build configuration files for lpcnet, please wait...."

View file

@ -199,4 +199,7 @@ LPCNET_EXPORT void lpcnet_plc_fec_add(LPCNetPLCState *st, const float *features)
LPCNET_EXPORT void lpcnet_plc_fec_clear(LPCNetPLCState *st);
LPCNET_EXPORT int lpcnet_load_model(LPCNetState *st, const unsigned char *data, int len);
LPCNET_EXPORT int lpcnet_plc_load_model(LPCNetPLCState *st, const unsigned char *data, int len);
#endif

View file

@ -183,11 +183,25 @@ LPCNET_EXPORT int lpcnet_init(LPCNetState *lpcnet)
lpcnet->sampling_logit_table[i] = -log((1-prob)/prob);
}
kiss99_srand(&lpcnet->rng, (const unsigned char *)rng_string, strlen(rng_string));
#ifndef USE_WEIGHTS_FILE
ret = init_lpcnet_model(&lpcnet->model, lpcnet_arrays);
#else
ret = 0;
#endif
celt_assert(ret == 0);
return ret;
}
LPCNET_EXPORT int lpcnet_load_model(LPCNetState *st, const unsigned char *data, int len) {
WeightArray *list;
int ret;
parse_weights(&list, data, len);
ret = init_lpcnet_model(&st->model, list);
free(list);
if (ret == 0) return 0;
else return -1;
}
LPCNET_EXPORT LPCNetState *lpcnet_create()
{

View file

@ -34,6 +34,49 @@
#include "lpcnet.h"
#include "freq.h"
#ifdef USE_WEIGHTS_FILE
# if __unix__
# include <fcntl.h>
# include <sys/mman.h>
# include <unistd.h>
# include <sys/stat.h>
/* When available, mmap() is preferable to reading the file, as it leads to
better resource utilization, especially if multiple processes are using the same
file (mapping will be shared in cache). */
unsigned char *load_blob(const char *filename, int *len) {
int fd;
unsigned char *data;
struct stat st;
stat(filename, &st);
*len = st.st_size;
fd = open(filename, O_RDONLY);
data = mmap(NULL, *len, PROT_READ, MAP_SHARED, fd, 0);
close(fd);
return data;
}
void free_blob(unsigned char *blob, int len) {
munmap(blob, len);
}
# else
unsigned char *load_blob(const char *filename, int *len) {
FILE *file;
unsigned char *data;
file = fopen(filename, "r");
fseek(file, 0L, SEEK_END);
*len = ftell(file);
fseek(file, 0L, SEEK_SET);
if (*len <= 0) return NULL;
data = malloc(*len);
*len = fread(data, 1, *len, file);
return data;
}
void free_blob(unsigned char *blob, int len) {
free(blob);
(void)len;
}
# endif
#endif
#define MODE_ENCODE 0
#define MODE_DECODE 1
#define MODE_FEATURES 2
@ -64,6 +107,11 @@ int main(int argc, char **argv) {
FILE *plc_file = NULL;
const char *plc_options;
int plc_flags=-1;
#ifdef USE_WEIGHTS_FILE
int len;
unsigned char *data;
const char *filename = "weights_blob.bin";
#endif
if (argc < 4) usage();
if (strcmp(argv[1], "-encode") == 0) mode=MODE_ENCODE;
else if (strcmp(argv[1], "-decode") == 0) mode=MODE_DECODE;
@ -109,7 +157,9 @@ int main(int argc, char **argv) {
fprintf(stderr, "Can't open %s\n", argv[3]);
exit(1);
}
#ifdef USE_WEIGHTS_FILE
data = load_blob(filename, &len);
#endif
if (mode == MODE_ENCODE) {
LPCNetEncState *net;
net = lpcnet_encoder_create();
@ -152,6 +202,9 @@ int main(int argc, char **argv) {
} else if (mode == MODE_SYNTHESIS) {
LPCNetState *net;
net = lpcnet_create();
#ifdef USE_WEIGHTS_FILE
lpcnet_load_model(net, data, len);
#endif
while (1) {
float in_features[NB_TOTAL_FEATURES];
float features[NB_FEATURES];
@ -207,5 +260,8 @@ int main(int argc, char **argv) {
}
fclose(fin);
fclose(fout);
#ifdef USE_WEIGHTS_FILE
free_blob(data, len);
#endif
return 0;
}

View file

@ -68,11 +68,25 @@ LPCNET_EXPORT int lpcnet_plc_init(LPCNetPLCState *st, int options) {
return -1;
}
st->remove_dc = !!(options&LPCNET_PLC_DC_FILTER);
#ifndef USE_WEIGHTS_FILE
ret = init_plc_model(&st->model, lpcnet_plc_arrays);
#else
ret = 0;
#endif
celt_assert(ret == 0);
return ret;
}
LPCNET_EXPORT int lpcnet_plc_load_model(LPCNetPLCState *st, const unsigned char *data, int len) {
WeightArray *list;
int ret;
parse_weights(&list, data, len);
ret = init_plc_model(&st->model, list);
free(list);
if (ret == 0) return 0;
else return -1;
}
LPCNET_EXPORT LPCNetPLCState *lpcnet_plc_create(int options) {
LPCNetPLCState *st;
st = calloc(sizeof(*st), 1);

View file

@ -131,4 +131,6 @@ int lpcnet_compute_single_frame_features(LPCNetEncState *st, const short *pcm, f
void process_single_frame(LPCNetEncState *st, FILE *ffeat);
void run_frame_network(LPCNetState *lpcnet, float *gru_a_condition, float *gru_b_condition, float *lpc, const float *features);
int parse_weights(WeightArray **list, const unsigned char *data, int len);
#endif

View file

@ -31,8 +31,14 @@
#include <stdio.h>
#include "nnet.h"
extern const WeightArray lpcnet_arrays[];
extern const WeightArray lpcnet_plc_arrays[];
/* This is a bit of a hack because we need to build nnet_data.c and plc_data.c without USE_WEIGHTS_FILE,
but USE_WEIGHTS_FILE is defined in config.h. */
#undef HAVE_CONFIG_H
#ifdef USE_WEIGHTS_FILE
#undef USE_WEIGHTS_FILE
#endif
#include "nnet_data.c"
#include "plc_data.c"
void write_weights(const WeightArray *list, FILE *fout)
{