diff --git a/dnn/training_tf2/dump_lpcnet.py b/dnn/training_tf2/dump_lpcnet.py index 9dcdba47..a5edefd6 100755 --- a/dnn/training_tf2/dump_lpcnet.py +++ b/dnn/training_tf2/dump_lpcnet.py @@ -229,12 +229,15 @@ def dump_embedding_layer(self, f, hf): return False Embedding.dump_layer = dump_embedding_layer +filename = sys.argv[1] +with h5py.File(filename, "r") as f: + units = min(f['model_weights']['gru_a']['gru_a']['recurrent_kernel:0'].shape) -model, _, _ = lpcnet.new_lpcnet_model(rnn_units1=384) +model, _, _ = lpcnet.new_lpcnet_model(rnn_units1=units) model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy']) #model.summary() -model.load_weights(sys.argv[1]) +model.load_weights(filename) if len(sys.argv) > 2: cfile = sys.argv[2]; diff --git a/dnn/training_tf2/train_lpcnet.py b/dnn/training_tf2/train_lpcnet.py index 96ef11fb..89c9d3a8 100755 --- a/dnn/training_tf2/train_lpcnet.py +++ b/dnn/training_tf2/train_lpcnet.py @@ -25,9 +25,35 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ''' -# Train a LPCNet model (note not a Wavenet model) +# Train an LPCNet model + +import argparse + +parser = argparse.ArgumentParser(description='Train an LPCNet model') + +parser.add_argument('features', metavar='', help='binary features file (float32)') +parser.add_argument('data', metavar='