mirror of https://github.com/xiph/opus.git
Adding command-line options to training script
commit 5a51e2eed1
parent 1edf5d7986
2 changed files with 48 additions and 20 deletions
@@ -229,12 +229,15 @@ def dump_embedding_layer(self, f, hf):
     return False
 Embedding.dump_layer = dump_embedding_layer
 
+filename = sys.argv[1]
+with h5py.File(filename, "r") as f:
+    units = min(f['model_weights']['gru_a']['gru_a']['recurrent_kernel:0'].shape)
 
-model, _, _ = lpcnet.new_lpcnet_model(rnn_units1=384)
+model, _, _ = lpcnet.new_lpcnet_model(rnn_units1=units)
 model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])
 #model.summary()
 
-model.load_weights(sys.argv[1])
+model.load_weights(filename)
 
 if len(sys.argv) > 2:
     cfile = sys.argv[2];
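Note: the units lookup added above works because Keras stores a GRU's recurrent kernel with shape (units, 3 * units), so the smaller dimension is always the unit count. A minimal standalone sketch of the same trick (the file name is hypothetical):

    import h5py

    # any LPCNet .h5 checkpoint would do; 'lpcnet_model.h5' is a placeholder
    with h5py.File('lpcnet_model.h5', 'r') as f:
        shape = f['model_weights']['gru_a']['gru_a']['recurrent_kernel:0'].shape
        # the recurrent kernel is (units, 3 * units): one block per GRU gate
        units = min(shape)
    print(units)  # e.g. 384 for the default GRU A size

This lets the dump script pick the GRU A size up from the weight file itself instead of hard-coding 384.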
@@ -25,9 +25,35 @@
 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 '''
 
-# Train a LPCNet model (note not a Wavenet model)
+# Train an LPCNet model
 
+import argparse
+
+parser = argparse.ArgumentParser(description='Train an LPCNet model')
+
+parser.add_argument('features', metavar='<features file>', help='binary features file (float32)')
+parser.add_argument('data', metavar='<audio data file>', help='binary audio data file (uint8)')
+parser.add_argument('output', metavar='<output>', help='trained model file (.h5)')
+parser.add_argument('--model', metavar='<model>', default='lpcnet', help='LPCNet model python definition (without .py)')
+parser.add_argument('--quantize', metavar='<input weights>', help='quantize model')
+parser.add_argument('--density', metavar='<global density>', type=float, help='average density of the recurrent weights (default 0.1)')
+parser.add_argument('--density-split', nargs=3, metavar=('<update>', '<reset>', '<state>'), type=float, help='density of each recurrent gate (default 0.05, 0.05, 0.2)')
+parser.add_argument('--grua-size', metavar='<units>', default=384, type=int, help='number of units in GRU A (default 384)')
+parser.add_argument('--epochs', metavar='<epochs>', default=120, type=int, help='number of epochs to train for (default 120)')
+parser.add_argument('--batch-size', metavar='<batch size>', default=128, type=int, help='batch size to use (default 128)')
+
+
+args = parser.parse_args()
+
+density = (0.05, 0.05, 0.2)
+if args.density_split is not None:
+    density = args.density_split
+elif args.density is not None:
+    density = [0.5*args.density, 0.5*args.density, 2.0*args.density];
+
+import importlib
+lpcnet = importlib.import_module(args.model)
+
-import lpcnet
 import sys
 import numpy as np
 from tensorflow.keras.optimizers import Adam
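Note: the density fallback keeps the documented average. A global --density d is split as (0.5d, 0.5d, 2d) over the update, reset and state gates, whose mean is (0.5 + 0.5 + 2)d/3 = d, so the default tuple (0.05, 0.05, 0.2) corresponds exactly to the "default 0.1" mentioned in the --density help text. As a quick sanity check:

    d = 0.1                                 # documented default for --density
    density = [0.5*d, 0.5*d, 2.0*d]         # -> [0.05, 0.05, 0.2]
    assert abs(sum(density)/3 - d) < 1e-9   # average density stays at d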
@@ -44,16 +70,15 @@ import tensorflow as tf
 # except RuntimeError as e:
 #     print(e)
 
-nb_epochs = 120
+nb_epochs = args.epochs
 
 # Try reducing batch_size if you run out of memory on your GPU
-batch_size = 128
+batch_size = args.batch_size
 
-#Set this to True to adapt an existing model (e.g. on new data)
-adaptation = False
+quantize = args.quantize is not None
 
-if adaptation:
-    lr = 0.0001
+if quantize:
+    lr = 0.00003
     decay = 0
 else:
     lr = 0.001
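Note: the old manual adaptation switch becomes a --quantize switch driven by the command line, and quantization fine-tuning gets a much lower learning rate (3e-5) than from-scratch training (1e-3), since it only nudges already-trained weights. A sketch of the resulting optimizer setup, reusing the Adam call visible in the next hunk's context line (the non-quantize decay value lies outside this diff and is a placeholder here):

    from tensorflow.keras.optimizers import Adam

    quantize = True          # i.e. --quantize <input weights> was given
    if quantize:
        lr = 0.00003         # small steps: only fine-tuning existing weights
        decay = 0
    else:
        lr = 0.001           # training from scratch
        decay = 5e-5         # placeholder; the real value is not shown in this diff
    opt = Adam(lr, decay=decay, beta_2=0.99)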
@@ -63,12 +88,12 @@ opt = Adam(lr, decay=decay, beta_2=0.99)
 strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
 
 with strategy.scope():
-    model, _, _ = lpcnet.new_lpcnet_model(training=True)
-    model.compile(optimizer=opt, loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])
+    model, _, _ = lpcnet.new_lpcnet_model(rnn_units1=args.grua_size, training=True, quantize=quantize)
+    model.compile(optimizer=opt, loss='sparse_categorical_crossentropy', metrics='sparse_categorical_crossentropy')
     model.summary()
 
-feature_file = sys.argv[1]
-pcm_file = sys.argv[2]     # 16 bit unsigned short PCM samples
+feature_file = args.features
+pcm_file = args.data     # 16 bit unsigned short PCM samples
 frame_size = model.frame_size
 nb_features = 55
 nb_used_features = model.nb_used_features
@@ -115,15 +140,15 @@ del pred
 del in_exc
 
 # dump models to disk as we go
-checkpoint = ModelCheckpoint('lpcnet33e_384_{epoch:02d}.h5')
+checkpoint = ModelCheckpoint('{}_{}_{}.h5'.format(args.output, args.grua_size, '{epoch:02d}'))
 
-if adaptation:
+if quantize:
     #Adapting from an existing model
-    model.load_weights('lpcnet33a_384_100.h5')
-    sparsify = lpcnet.Sparsify(0, 0, 1, (0.05, 0.05, 0.2))
+    model.load_weights(args.quantize)
+    sparsify = lpcnet.Sparsify(0, 0, 1, density)
 else:
     #Training from scratch
-    sparsify = lpcnet.Sparsify(2000, 40000, 400, (0.05, 0.05, 0.2))
+    sparsify = lpcnet.Sparsify(2000, 40000, 400, density)
 
-model.save_weights('lpcnet33e_384_00.h5');
+model.save_weights('{}_{}_initial.h5'.format(args.output, args.grua_size))
 model.fit([in_data, features, periods], out_exc, batch_size=batch_size, epochs=nb_epochs, validation_split=0.0, callbacks=[checkpoint, sparsify])
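With these options in place, typical invocations might look as follows (file names are hypothetical, and the script is assumed to be train_lpcnet.py):

    # train from scratch; checkpoints land in mymodel_384_01.h5, mymodel_384_02.h5, ...
    python train_lpcnet.py features.f32 data.u8 mymodel

    # fine-tune and quantize an existing checkpoint with sparser recurrent weights
    python train_lpcnet.py features.f32 data.u8 quantized --quantize mymodel_384_120.h5 --density 0.05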