auto-detect end-to-end models

This commit is contained in:
Jean-Marc Valin 2021-10-13 22:12:39 -04:00
parent d5b6087f48
commit a3ef596822
2 changed files with 26 additions and 10 deletions

View file

@ -250,6 +250,7 @@ with h5py.File(filename, "r") as f:
units = min(f['model_weights']['gru_a']['gru_a']['recurrent_kernel:0'].shape)
units2 = min(f['model_weights']['gru_b']['gru_b']['recurrent_kernel:0'].shape)
cond_size = min(f['model_weights']['feature_dense1']['feature_dense1']['kernel:0'].shape)
e2e = 'rc2lpc' in f['model_weights']
model, _, _ = lpcnet.new_lpcnet_model(rnn_units1=units, rnn_units2=units2, flag_e2e = flag_e2e, cond_size=cond_size)
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])
@ -276,6 +277,13 @@ f.write('#ifdef HAVE_CONFIG_H\n#include "config.h"\n#endif\n\n#include "nnet.h"\
hf.write('/*This file is automatically generated from a Keras model*/\n\n')
hf.write('#ifndef RNN_DATA_H\n#define RNN_DATA_H\n\n#include "nnet.h"\n\n')
if e2e:
hf.write('/* This is an end-to-end model */\n')
hf.write('#define END2END\n\n')
else:
hf.write('/* This is *not* an end-to-end model */\n')
hf.write('/* #define END2END */\n\n')
embed_size = lpcnet.embed_size
E = model.get_layer('embed_sig').get_weights()[0]