Initial blob loading support
This commit is contained in:
parent
d98c59fb9a
commit
fa7b432eed
8 changed files with 100 additions and 5 deletions
|
@ -62,7 +62,7 @@ dump_data_SOURCES = common.c dump_data.c burg.c freq.c kiss_fft.c pitch.c lpcnet
|
||||||
dump_data_LDADD = $(LIBM)
|
dump_data_LDADD = $(LIBM)
|
||||||
dump_data_CFLAGS = $(AM_CFLAGS)
|
dump_data_CFLAGS = $(AM_CFLAGS)
|
||||||
|
|
||||||
dump_weights_blob_SOURCES = nnet_data.c plc_data.c write_lpcnet_weights.c
|
dump_weights_blob_SOURCES = write_lpcnet_weights.c
|
||||||
dump_weights_blob_LDADD = $(LIBM)
|
dump_weights_blob_LDADD = $(LIBM)
|
||||||
dump_weights_blob_CFLAGS = $(AM_CFLAGS) -DDUMP_BINARY_WEIGHTS
|
dump_weights_blob_CFLAGS = $(AM_CFLAGS) -DDUMP_BINARY_WEIGHTS
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,7 @@ srcdir=`dirname $0`
|
||||||
test -n "$srcdir" && cd "$srcdir"
|
test -n "$srcdir" && cd "$srcdir"
|
||||||
|
|
||||||
#SHA1 of the first commit compatible with the current model
|
#SHA1 of the first commit compatible with the current model
|
||||||
commit=399be7c
|
commit=859bfae
|
||||||
./download_model.sh $commit
|
./download_model.sh $commit
|
||||||
|
|
||||||
echo "Updating build configuration files for lpcnet, please wait...."
|
echo "Updating build configuration files for lpcnet, please wait...."
|
||||||
|
|
|
@ -199,4 +199,7 @@ LPCNET_EXPORT void lpcnet_plc_fec_add(LPCNetPLCState *st, const float *features)
|
||||||
|
|
||||||
LPCNET_EXPORT void lpcnet_plc_fec_clear(LPCNetPLCState *st);
|
LPCNET_EXPORT void lpcnet_plc_fec_clear(LPCNetPLCState *st);
|
||||||
|
|
||||||
|
LPCNET_EXPORT int lpcnet_load_model(LPCNetState *st, const unsigned char *data, int len);
|
||||||
|
LPCNET_EXPORT int lpcnet_plc_load_model(LPCNetPLCState *st, const unsigned char *data, int len);
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
14
dnn/lpcnet.c
14
dnn/lpcnet.c
|
@ -183,11 +183,25 @@ LPCNET_EXPORT int lpcnet_init(LPCNetState *lpcnet)
|
||||||
lpcnet->sampling_logit_table[i] = -log((1-prob)/prob);
|
lpcnet->sampling_logit_table[i] = -log((1-prob)/prob);
|
||||||
}
|
}
|
||||||
kiss99_srand(&lpcnet->rng, (const unsigned char *)rng_string, strlen(rng_string));
|
kiss99_srand(&lpcnet->rng, (const unsigned char *)rng_string, strlen(rng_string));
|
||||||
|
#ifndef USE_WEIGHTS_FILE
|
||||||
ret = init_lpcnet_model(&lpcnet->model, lpcnet_arrays);
|
ret = init_lpcnet_model(&lpcnet->model, lpcnet_arrays);
|
||||||
|
#else
|
||||||
|
ret = 0;
|
||||||
|
#endif
|
||||||
celt_assert(ret == 0);
|
celt_assert(ret == 0);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LPCNET_EXPORT int lpcnet_load_model(LPCNetState *st, const unsigned char *data, int len) {
|
||||||
|
WeightArray *list;
|
||||||
|
int ret;
|
||||||
|
parse_weights(&list, data, len);
|
||||||
|
ret = init_lpcnet_model(&st->model, list);
|
||||||
|
free(list);
|
||||||
|
if (ret == 0) return 0;
|
||||||
|
else return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
LPCNET_EXPORT LPCNetState *lpcnet_create()
|
LPCNET_EXPORT LPCNetState *lpcnet_create()
|
||||||
{
|
{
|
||||||
|
|
|
@ -34,6 +34,49 @@
|
||||||
#include "lpcnet.h"
|
#include "lpcnet.h"
|
||||||
#include "freq.h"
|
#include "freq.h"
|
||||||
|
|
||||||
|
#ifdef USE_WEIGHTS_FILE
|
||||||
|
# if __unix__
|
||||||
|
# include <fcntl.h>
|
||||||
|
# include <sys/mman.h>
|
||||||
|
# include <unistd.h>
|
||||||
|
# include <sys/stat.h>
|
||||||
|
/* When available, mmap() is preferable to reading the file, as it leads to
|
||||||
|
better resource utilization, especially if multiple processes are using the same
|
||||||
|
file (mapping will be shared in cache). */
|
||||||
|
unsigned char *load_blob(const char *filename, int *len) {
|
||||||
|
int fd;
|
||||||
|
unsigned char *data;
|
||||||
|
struct stat st;
|
||||||
|
stat(filename, &st);
|
||||||
|
*len = st.st_size;
|
||||||
|
fd = open(filename, O_RDONLY);
|
||||||
|
data = mmap(NULL, *len, PROT_READ, MAP_SHARED, fd, 0);
|
||||||
|
close(fd);
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
void free_blob(unsigned char *blob, int len) {
|
||||||
|
munmap(blob, len);
|
||||||
|
}
|
||||||
|
# else
|
||||||
|
unsigned char *load_blob(const char *filename, int *len) {
|
||||||
|
FILE *file;
|
||||||
|
unsigned char *data;
|
||||||
|
file = fopen(filename, "r");
|
||||||
|
fseek(file, 0L, SEEK_END);
|
||||||
|
*len = ftell(file);
|
||||||
|
fseek(file, 0L, SEEK_SET);
|
||||||
|
if (*len <= 0) return NULL;
|
||||||
|
data = malloc(*len);
|
||||||
|
*len = fread(data, 1, *len, file);
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
void free_blob(unsigned char *blob, int len) {
|
||||||
|
free(blob);
|
||||||
|
(void)len;
|
||||||
|
}
|
||||||
|
# endif
|
||||||
|
#endif
|
||||||
|
|
||||||
#define MODE_ENCODE 0
|
#define MODE_ENCODE 0
|
||||||
#define MODE_DECODE 1
|
#define MODE_DECODE 1
|
||||||
#define MODE_FEATURES 2
|
#define MODE_FEATURES 2
|
||||||
|
@ -64,6 +107,11 @@ int main(int argc, char **argv) {
|
||||||
FILE *plc_file = NULL;
|
FILE *plc_file = NULL;
|
||||||
const char *plc_options;
|
const char *plc_options;
|
||||||
int plc_flags=-1;
|
int plc_flags=-1;
|
||||||
|
#ifdef USE_WEIGHTS_FILE
|
||||||
|
int len;
|
||||||
|
unsigned char *data;
|
||||||
|
const char *filename = "weights_blob.bin";
|
||||||
|
#endif
|
||||||
if (argc < 4) usage();
|
if (argc < 4) usage();
|
||||||
if (strcmp(argv[1], "-encode") == 0) mode=MODE_ENCODE;
|
if (strcmp(argv[1], "-encode") == 0) mode=MODE_ENCODE;
|
||||||
else if (strcmp(argv[1], "-decode") == 0) mode=MODE_DECODE;
|
else if (strcmp(argv[1], "-decode") == 0) mode=MODE_DECODE;
|
||||||
|
@ -109,7 +157,9 @@ int main(int argc, char **argv) {
|
||||||
fprintf(stderr, "Can't open %s\n", argv[3]);
|
fprintf(stderr, "Can't open %s\n", argv[3]);
|
||||||
exit(1);
|
exit(1);
|
||||||
}
|
}
|
||||||
|
#ifdef USE_WEIGHTS_FILE
|
||||||
|
data = load_blob(filename, &len);
|
||||||
|
#endif
|
||||||
if (mode == MODE_ENCODE) {
|
if (mode == MODE_ENCODE) {
|
||||||
LPCNetEncState *net;
|
LPCNetEncState *net;
|
||||||
net = lpcnet_encoder_create();
|
net = lpcnet_encoder_create();
|
||||||
|
@ -152,6 +202,9 @@ int main(int argc, char **argv) {
|
||||||
} else if (mode == MODE_SYNTHESIS) {
|
} else if (mode == MODE_SYNTHESIS) {
|
||||||
LPCNetState *net;
|
LPCNetState *net;
|
||||||
net = lpcnet_create();
|
net = lpcnet_create();
|
||||||
|
#ifdef USE_WEIGHTS_FILE
|
||||||
|
lpcnet_load_model(net, data, len);
|
||||||
|
#endif
|
||||||
while (1) {
|
while (1) {
|
||||||
float in_features[NB_TOTAL_FEATURES];
|
float in_features[NB_TOTAL_FEATURES];
|
||||||
float features[NB_FEATURES];
|
float features[NB_FEATURES];
|
||||||
|
@ -207,5 +260,8 @@ int main(int argc, char **argv) {
|
||||||
}
|
}
|
||||||
fclose(fin);
|
fclose(fin);
|
||||||
fclose(fout);
|
fclose(fout);
|
||||||
|
#ifdef USE_WEIGHTS_FILE
|
||||||
|
free_blob(data, len);
|
||||||
|
#endif
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
|
@ -68,11 +68,25 @@ LPCNET_EXPORT int lpcnet_plc_init(LPCNetPLCState *st, int options) {
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
st->remove_dc = !!(options&LPCNET_PLC_DC_FILTER);
|
st->remove_dc = !!(options&LPCNET_PLC_DC_FILTER);
|
||||||
|
#ifndef USE_WEIGHTS_FILE
|
||||||
ret = init_plc_model(&st->model, lpcnet_plc_arrays);
|
ret = init_plc_model(&st->model, lpcnet_plc_arrays);
|
||||||
|
#else
|
||||||
|
ret = 0;
|
||||||
|
#endif
|
||||||
celt_assert(ret == 0);
|
celt_assert(ret == 0);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LPCNET_EXPORT int lpcnet_plc_load_model(LPCNetPLCState *st, const unsigned char *data, int len) {
|
||||||
|
WeightArray *list;
|
||||||
|
int ret;
|
||||||
|
parse_weights(&list, data, len);
|
||||||
|
ret = init_plc_model(&st->model, list);
|
||||||
|
free(list);
|
||||||
|
if (ret == 0) return 0;
|
||||||
|
else return -1;
|
||||||
|
}
|
||||||
|
|
||||||
LPCNET_EXPORT LPCNetPLCState *lpcnet_plc_create(int options) {
|
LPCNET_EXPORT LPCNetPLCState *lpcnet_plc_create(int options) {
|
||||||
LPCNetPLCState *st;
|
LPCNetPLCState *st;
|
||||||
st = calloc(sizeof(*st), 1);
|
st = calloc(sizeof(*st), 1);
|
||||||
|
|
|
@ -131,4 +131,6 @@ int lpcnet_compute_single_frame_features(LPCNetEncState *st, const short *pcm, f
|
||||||
void process_single_frame(LPCNetEncState *st, FILE *ffeat);
|
void process_single_frame(LPCNetEncState *st, FILE *ffeat);
|
||||||
|
|
||||||
void run_frame_network(LPCNetState *lpcnet, float *gru_a_condition, float *gru_b_condition, float *lpc, const float *features);
|
void run_frame_network(LPCNetState *lpcnet, float *gru_a_condition, float *gru_b_condition, float *lpc, const float *features);
|
||||||
|
|
||||||
|
int parse_weights(WeightArray **list, const unsigned char *data, int len);
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -31,8 +31,14 @@
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
#include "nnet.h"
|
#include "nnet.h"
|
||||||
|
|
||||||
extern const WeightArray lpcnet_arrays[];
|
/* This is a bit of a hack because we need to build nnet_data.c and plc_data.c without USE_WEIGHTS_FILE,
|
||||||
extern const WeightArray lpcnet_plc_arrays[];
|
but USE_WEIGHTS_FILE is defined in config.h. */
|
||||||
|
#undef HAVE_CONFIG_H
|
||||||
|
#ifdef USE_WEIGHTS_FILE
|
||||||
|
#undef USE_WEIGHTS_FILE
|
||||||
|
#endif
|
||||||
|
#include "nnet_data.c"
|
||||||
|
#include "plc_data.c"
|
||||||
|
|
||||||
void write_weights(const WeightArray *list, FILE *fout)
|
void write_weights(const WeightArray *list, FILE *fout)
|
||||||
{
|
{
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue