Update blob loading code

This commit is contained in:
Jean-Marc Valin 2023-10-29 16:33:57 -04:00
parent 0b75501270
commit d53531d0bd
No known key found for this signature in database
GPG key ID: 531A52533318F00A
8 changed files with 28 additions and 1 deletions

View file

@ -174,7 +174,11 @@ void fargan_init(FARGANState *st)
{
int ret;
OPUS_CLEAR(st, 1);
#ifndef USE_WEIGHTS_FILE
ret = init_fargan(&st->model, fargan_arrays);
#else
ret = 0;
#endif
celt_assert(ret == 0);
/* FIXME: perform arch detection. */
}

View file

@ -94,6 +94,8 @@ int lpcnet_encoder_get_size(void);
*/
int lpcnet_encoder_init(LPCNetEncState *st);
int lpcnet_encoder_load_model(LPCNetEncState *st, const unsigned char *data, int len);
/** Allocates and initializes an encoder state.
* @returns The newly created state
*/

View file

@ -148,7 +148,7 @@ int main(int argc, char **argv) {
float zeros[320] = {0};
fargan_init(&fargan);
#ifdef USE_WEIGHTS_FILE
fargan_load_model(fwgan, data, len);
fargan_load_model(&fargan, data, len);
#endif
/* uncomment the following to align with Python code */
/*ret = fread(&in_features[0], sizeof(in_features[0]), NB_TOTAL_FEATURES, fin);*/

View file

@ -56,6 +56,10 @@ int lpcnet_encoder_init(LPCNetEncState *st) {
return 0;
}
int lpcnet_encoder_load_model(LPCNetEncState *st, const unsigned char *data, int len) {
return pitchdnn_load_model(&st->pitchdnn, data, len);
}
LPCNetEncState *lpcnet_encoder_create(void) {
LPCNetEncState *st;
st = malloc(lpcnet_encoder_get_size());

View file

@ -73,6 +73,9 @@ int lpcnet_plc_load_model(LPCNetPLCState *st, const unsigned char *data, int len
parse_weights(&list, data, len);
ret = init_plc_model(&st->model, list);
free(list);
if (ret == 0) {
ret = lpcnet_encoder_load_model(&st->enc, data, len);
} else return -1;
if (ret == 0) {
return fargan_load_model(&st->fargan, data, len);
}

View file

@ -67,3 +67,13 @@ void pitchdnn_init(PitchDNNState *st)
celt_assert(ret == 0);
/* FIXME: perform arch detection. */
}
int pitchdnn_load_model(PitchDNNState *st, const unsigned char *data, int len) {
WeightArray *list;
int ret;
parse_weights(&list, data, len);
ret = init_pitchdnn(&st->model, list);
free(list);
if (ret == 0) return 0;
else return -1;
}

View file

@ -22,6 +22,7 @@ typedef struct {
void pitchdnn_init(PitchDNNState *st);
int pitchdnn_load_model(PitchDNNState *st, const unsigned char *data, int len);
float compute_pitchdnn(
PitchDNNState *st,

View file

@ -54,6 +54,9 @@ int dred_encoder_load_model(DREDEnc* enc, const unsigned char *data, int len)
parse_weights(&list, data, len);
ret = init_rdovaeenc(&enc->model, list);
free(list);
if (ret == 0) {
ret = lpcnet_encoder_load_model(&enc->lpcnet_enc_state, data, len);
}
return (ret == 0) ? OPUS_OK : OPUS_BAD_ARG;
}