Adds end-to-end LPC training

Making LPC computation and prediction differentiable
Krishna Subramani 2021-07-29 03:36:13 -04:00 committed by Jean-Marc Valin
parent cba0ecd483
commit c1532559a2
11 changed files with 357 additions and 17 deletions
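
The script changes below wire the end-to-end option into training; the differentiable LPC machinery itself comes from the new tf_funcs and lossfuncs modules imported further down. As a rough, hypothetical illustration of the core idea (not this commit's exact code), LPC coefficients can be computed from network-predicted reflection coefficients using only differentiable ops, here as a Levinson-style step-up recursion in TensorFlow:

import tensorflow as tf

def rc2lpc(rc):
    # Hypothetical sketch: reflection coefficients -> LPC coefficients.
    # rc has shape (..., order) with values in (-1, 1); every op here is
    # differentiable, so gradients flow back into the network predicting rc.
    lpc = rc[..., 0:1]
    for i in range(1, rc.shape[-1]):
        k = rc[..., i:i + 1]
        lpc = tf.concat([lpc - k * tf.reverse(lpc, axis=[-1]), k], axis=-1)
    return lpc

Because the recursion is built from concat, multiply and reverse, it stays inside the TensorFlow graph and can train jointly with the rest of LPCNet.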


@@ -44,7 +44,7 @@ parser.add_argument('--grua-size', metavar='<units>', default=384, type=int, hel
parser.add_argument('--grub-size', metavar='<units>', default=16, type=int, help='number of units in GRU B (default 16)')
parser.add_argument('--epochs', metavar='<epochs>', default=120, type=int, help='number of epochs to train for (default 120)')
parser.add_argument('--batch-size', metavar='<batch size>', default=128, type=int, help='batch size to use (default 128)')
parser.add_argument('--end2end', dest='flag_e2e', action='store_true', help='Enable end-to-end training (with differentiable LPC computation)')
args = parser.parse_args()
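
With the flag parsed above, an end-to-end run is the usual invocation plus --end2end; the positional file names here are placeholders for the feature/data/output arguments the script already takes:

python3 train_lpcnet.py features.f32 data.u8 my_model --end2end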
@@ -66,12 +66,14 @@ lpcnet = importlib.import_module(args.model)
import sys
import numpy as np
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger
from ulaw import ulaw2lin, lin2ulaw
import tensorflow.keras.backend as K
import h5py
import tensorflow as tf
from tf_funcs import *
from lossfuncs import *
#gpus = tf.config.experimental.list_physical_devices('GPU')
#if gpus:
# try:
@@ -93,12 +95,17 @@ else:
lr = 0.001
decay = 2.5e-5
flag_e2e = args.flag_e2e
opt = Adam(lr, decay=decay, beta_2=0.99)
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
with strategy.scope():
model, _, _ = lpcnet.new_lpcnet_model(rnn_units1=args.grua_size, rnn_units2=args.grub_size, training=True, quantize=quantize)
model.compile(optimizer=opt, loss='sparse_categorical_crossentropy', metrics='sparse_categorical_crossentropy')
model, _, _ = lpcnet.new_lpcnet_model(rnn_units1=args.grua_size, rnn_units2=args.grub_size, training=True, quantize=quantize, flag_e2e=flag_e2e)
if not flag_e2e:
model.compile(optimizer=opt, loss='sparse_categorical_crossentropy', metrics='sparse_categorical_crossentropy')
else:
model.compile(optimizer=opt, loss=interp_mulaw(gamma=2), metrics=[metric_cel, metric_icel, metric_exc_sd, metric_oginterploss])
model.summary()
feature_file = args.features
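
interp_mulaw and the metric_* functions come from the new lossfuncs module (its diff is not shown in this excerpt). The motivation: once the excitation target is produced by a differentiable prediction, it is no longer an integer mu-law class, so the cross-entropy has to be interpolated between the two nearest bins. A hypothetical minimal version of such a loss (the gamma argument, which presumably shapes the interpolation weighting, is omitted here):

import tensorflow as tf

def interp_crossentropy(e_mulaw, probs):
    # Hypothetical sketch, not the commit's lossfuncs code.
    # e_mulaw: continuous mu-law target, shape (B, T), float in [0, 255].
    # probs:   softmax output over 256 mu-law classes, shape (B, T, 256).
    lo = tf.math.floor(e_mulaw)
    frac = e_mulaw - lo                 # fractional position between bins
    lo = tf.cast(lo, 'int32')
    hi = tf.minimum(lo + 1, 255)
    p_lo = tf.gather(probs, lo, batch_dims=2)
    p_hi = tf.gather(probs, hi, batch_dims=2)
    # Interpolate the probability mass assigned to the continuous target.
    return -tf.math.log((1.0 - frac) * p_lo + frac * p_hi + 1e-7)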
@@ -150,4 +157,5 @@ else:
grub_sparsify = lpcnet.SparsifyGRUB(2000, 40000, 400, args.grua_size, grub_density)
model.save_weights('{}_{}_initial.h5'.format(args.output, args.grua_size))
model.fit([in_data, features, periods], out_exc, batch_size=batch_size, epochs=nb_epochs, validation_split=0.0, callbacks=[checkpoint, sparsify, grub_sparsify])
csv_logger = CSVLogger('training_vals.log')
model.fit([in_data, features, periods], out_exc, batch_size=batch_size, epochs=nb_epochs, validation_split=0.0, callbacks=[checkpoint, sparsify, grub_sparsify, csv_logger])
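
CSVLogger appends one row of metrics per epoch to training_vals.log, so the loss and the new metric_* columns can be inspected after (or during) a run, for example:

import pandas as pd

log = pd.read_csv('training_vals.log')  # columns: epoch, loss, metric_*, ...
print(log.tail())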