From 1db1946f77bed48cdaf6fb1c00611b27275e96ce Mon Sep 17 00:00:00 2001
From: Jean-Marc Valin
Date: Tue, 1 Feb 2022 02:57:50 -0500
Subject: [PATCH] Support for biased loss

---
 dnn/training_tf2/train_plc.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/dnn/training_tf2/train_plc.py b/dnn/training_tf2/train_plc.py
index a8bb4027..0e489c29 100644
--- a/dnn/training_tf2/train_plc.py
+++ b/dnn/training_tf2/train_plc.py
@@ -46,6 +46,8 @@ parser.add_argument('--batch-size', metavar='', default=128, type=in
 parser.add_argument('--seq-length', metavar='', default=1000, type=int, help='sequence length to use (default 1000)')
 parser.add_argument('--lr', metavar='', type=float, help='learning rate')
 parser.add_argument('--decay', metavar='', type=float, help='learning rate decay')
+parser.add_argument('--band-loss', metavar='', default=1.0, type=float, help='weight of band loss (default 1.0)')
+parser.add_argument('--loss-bias', metavar='', default=0.0, type=float, help='loss bias towards low energy (default 0.0)')
 parser.add_argument('--logdir', metavar='', help='directory for tensorboard log files')
 
 
@@ -94,13 +96,13 @@ if args.decay is not None:
 if retrain:
     input_model = args.retrain
 
-def plc_loss(alpha=1.0):
+def plc_loss(alpha=1.0, bias=0.):
     def loss(y_true,y_pred):
         mask = y_true[:,:,-1:]
         y_true = y_true[:,:,:-1]
-        e = (y_true - y_pred)*mask
+        e = (y_pred - y_true)*mask
         e_bands = tf.signal.idct(e[:,:,:-2], norm='ortho')
-        l1_loss = K.mean(K.abs(e)) + alpha*K.mean(K.abs(e_bands))
+        l1_loss = K.mean(K.abs(e)) + bias*K.mean(K.maximum(e[:,:,:1], 0.)) + alpha*K.mean(K.abs(e_bands) + bias*K.maximum(e_bands, 0.))
         return l1_loss
     return loss
 
@@ -108,7 +110,7 @@ def plc_l1_loss():
     def L1_loss(y_true,y_pred):
         mask = y_true[:,:,-1:]
         y_true = y_true[:,:,:-1]
-        e = (y_true - y_pred)*mask
+        e = (y_pred - y_true)*mask
         l1_loss = K.mean(K.abs(e))
         return l1_loss
     return L1_loss
@@ -117,7 +119,7 @@ def plc_band_loss():
     def L1_band_loss(y_true,y_pred):
         mask = y_true[:,:,-1:]
         y_true = y_true[:,:,:-1]
-        e = (y_true - y_pred)*mask
+        e = (y_pred - y_true)*mask
         e_bands = tf.signal.idct(e[:,:,:-2], norm='ortho')
         l1_loss = K.mean(K.abs(e_bands))
         return l1_loss
@@ -128,7 +130,7 @@ strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
 
 with strategy.scope():
     model = lpcnet.new_lpcnet_plc_model(rnn_units=args.gru_size, batch_size=batch_size, training=True, quantize=quantize, cond_size=args.cond_size)
-    model.compile(optimizer=opt, loss=plc_loss(alpha=1.), metrics=[plc_l1_loss(), plc_band_loss()])
+    model.compile(optimizer=opt, loss=plc_loss(alpha=args.band_loss, bias=args.loss_bias), metrics=[plc_l1_loss(), plc_band_loss()])
     model.summary()
 
 lpc_order = 16