From 5a51e2eed1166b7435a07f324b10a823721e6752 Mon Sep 17 00:00:00 2001 From: Jean-Marc Valin Date: Tue, 13 Jul 2021 03:09:04 -0400 Subject: [PATCH] Adding command-line options to training script --- dnn/training_tf2/dump_lpcnet.py | 7 ++-- dnn/training_tf2/train_lpcnet.py | 61 ++++++++++++++++++++++---------- 2 files changed, 48 insertions(+), 20 deletions(-) diff --git a/dnn/training_tf2/dump_lpcnet.py b/dnn/training_tf2/dump_lpcnet.py index 9dcdba47..a5edefd6 100755 --- a/dnn/training_tf2/dump_lpcnet.py +++ b/dnn/training_tf2/dump_lpcnet.py @@ -229,12 +229,15 @@ def dump_embedding_layer(self, f, hf): return False Embedding.dump_layer = dump_embedding_layer +filename = sys.argv[1] +with h5py.File(filename, "r") as f: + units = min(f['model_weights']['gru_a']['gru_a']['recurrent_kernel:0'].shape) -model, _, _ = lpcnet.new_lpcnet_model(rnn_units1=384) +model, _, _ = lpcnet.new_lpcnet_model(rnn_units1=units) model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy']) #model.summary() -model.load_weights(sys.argv[1]) +model.load_weights(filename) if len(sys.argv) > 2: cfile = sys.argv[2]; diff --git a/dnn/training_tf2/train_lpcnet.py b/dnn/training_tf2/train_lpcnet.py index 96ef11fb..89c9d3a8 100755 --- a/dnn/training_tf2/train_lpcnet.py +++ b/dnn/training_tf2/train_lpcnet.py @@ -25,9 +25,35 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ''' -# Train a LPCNet model (note not a Wavenet model) +# Train an LPCNet model + +import argparse + +parser = argparse.ArgumentParser(description='Train an LPCNet model') + +parser.add_argument('features', metavar='', help='binary features file (float32)') +parser.add_argument('data', metavar='