mirror of
https://github.com/xiph/opus.git
synced 2025-06-02 08:37:43 +00:00
Adds end-to-end LPC training
Making LPC computation and prediction differentiable
This commit is contained in:
parent
cba0ecd483
commit
c1532559a2
11 changed files with 357 additions and 17 deletions
|
@ -44,7 +44,7 @@ parser.add_argument('--grua-size', metavar='<units>', default=384, type=int, hel
|
|||
parser.add_argument('--grub-size', metavar='<units>', default=16, type=int, help='number of units in GRU B (default 16)')
|
||||
parser.add_argument('--epochs', metavar='<epochs>', default=120, type=int, help='number of epochs to train for (default 120)')
|
||||
parser.add_argument('--batch-size', metavar='<batch size>', default=128, type=int, help='batch size to use (default 128)')
|
||||
|
||||
parser.add_argument('--end2end', dest='flag_e2e', action='store_true', help='Enable end-to-end training (with differentiable LPC computation')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
@ -66,12 +66,14 @@ lpcnet = importlib.import_module(args.model)
|
|||
import sys
|
||||
import numpy as np
|
||||
from tensorflow.keras.optimizers import Adam
|
||||
from tensorflow.keras.callbacks import ModelCheckpoint
|
||||
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger
|
||||
from ulaw import ulaw2lin, lin2ulaw
|
||||
import tensorflow.keras.backend as K
|
||||
import h5py
|
||||
|
||||
import tensorflow as tf
|
||||
from tf_funcs import *
|
||||
from lossfuncs import *
|
||||
#gpus = tf.config.experimental.list_physical_devices('GPU')
|
||||
#if gpus:
|
||||
# try:
|
||||
|
@ -93,12 +95,17 @@ else:
|
|||
lr = 0.001
|
||||
decay = 2.5e-5
|
||||
|
||||
flag_e2e = args.flag_e2e
|
||||
|
||||
opt = Adam(lr, decay=decay, beta_2=0.99)
|
||||
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
|
||||
|
||||
with strategy.scope():
|
||||
model, _, _ = lpcnet.new_lpcnet_model(rnn_units1=args.grua_size, rnn_units2=args.grub_size, training=True, quantize=quantize)
|
||||
model.compile(optimizer=opt, loss='sparse_categorical_crossentropy', metrics='sparse_categorical_crossentropy')
|
||||
model, _, _ = lpcnet.new_lpcnet_model(rnn_units1=args.grua_size, rnn_units2=args.grub_size, training=True, quantize=quantize, flag_e2e = flag_e2e)
|
||||
if not flag_e2e:
|
||||
model.compile(optimizer=opt, loss='sparse_categorical_crossentropy', metrics='sparse_categorical_crossentropy')
|
||||
else:
|
||||
model.compile(optimizer=opt, loss = interp_mulaw(gamma = 2),metrics=[metric_cel,metric_icel,metric_exc_sd,metric_oginterploss])
|
||||
model.summary()
|
||||
|
||||
feature_file = args.features
|
||||
|
@ -150,4 +157,5 @@ else:
|
|||
grub_sparsify = lpcnet.SparsifyGRUB(2000, 40000, 400, args.grua_size, grub_density)
|
||||
|
||||
model.save_weights('{}_{}_initial.h5'.format(args.output, args.grua_size))
|
||||
model.fit([in_data, features, periods], out_exc, batch_size=batch_size, epochs=nb_epochs, validation_split=0.0, callbacks=[checkpoint, sparsify, grub_sparsify])
|
||||
csv_logger = CSVLogger('training_vals.log')
|
||||
model.fit([in_data, features, periods], out_exc, batch_size=batch_size, epochs=nb_epochs, validation_split=0.0, callbacks=[checkpoint, sparsify, grub_sparsify, csv_logger])
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue