Support for biased loss

Author: Jean-Marc Valin
Date:   2022-02-01 02:57:50 -05:00
commit 1db1946f77
parent 186fa61680

@@ -46,6 +46,8 @@ parser.add_argument('--batch-size', metavar='<batch size>', default=128, type=in
 parser.add_argument('--seq-length', metavar='<sequence length>', default=1000, type=int, help='sequence length to use (default 1000)')
 parser.add_argument('--lr', metavar='<learning rate>', type=float, help='learning rate')
 parser.add_argument('--decay', metavar='<decay>', type=float, help='learning rate decay')
+parser.add_argument('--band-loss', metavar='<weight>', default=1.0, type=float, help='weight of band loss (default 1.0)')
+parser.add_argument('--loss-bias', metavar='<bias>', default=0.0, type=float, help='loss bias towards low energy (default 0.0)')
 parser.add_argument('--logdir', metavar='<log dir>', help='directory for tensorboard log files')
@@ -94,13 +96,13 @@ if args.decay is not None:
 if retrain:
     input_model = args.retrain

-def plc_loss(alpha=1.0):
+def plc_loss(alpha=1.0, bias=0.):
     def loss(y_true,y_pred):
         mask = y_true[:,:,-1:]
         y_true = y_true[:,:,:-1]
-        e = (y_true - y_pred)*mask
+        e = (y_pred - y_true)*mask
         e_bands = tf.signal.idct(e[:,:,:-2], norm='ortho')
-        l1_loss = K.mean(K.abs(e)) + alpha*K.mean(K.abs(e_bands))
+        l1_loss = K.mean(K.abs(e)) + bias*K.mean(K.maximum(e[:,:,:1], 0.)) + alpha*K.mean(K.abs(e_bands) + bias*K.maximum(e_bands, 0.))
         return l1_loss
     return loss
@@ -108,7 +110,7 @@ def plc_l1_loss():
     def L1_loss(y_true,y_pred):
         mask = y_true[:,:,-1:]
         y_true = y_true[:,:,:-1]
-        e = (y_true - y_pred)*mask
+        e = (y_pred - y_true)*mask
         l1_loss = K.mean(K.abs(e))
         return l1_loss
     return L1_loss
@@ -117,7 +119,7 @@ def plc_band_loss():
     def L1_band_loss(y_true,y_pred):
         mask = y_true[:,:,-1:]
         y_true = y_true[:,:,:-1]
-        e = (y_true - y_pred)*mask
+        e = (y_pred - y_true)*mask
         e_bands = tf.signal.idct(e[:,:,:-2], norm='ortho')
         l1_loss = K.mean(K.abs(e_bands))
         return l1_loss
@@ -128,7 +130,7 @@ strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
 with strategy.scope():
     model = lpcnet.new_lpcnet_plc_model(rnn_units=args.gru_size, batch_size=batch_size, training=True, quantize=quantize, cond_size=args.cond_size)
-    model.compile(optimizer=opt, loss=plc_loss(alpha=1.), metrics=[plc_l1_loss(), plc_band_loss()])
+    model.compile(optimizer=opt, loss=plc_loss(alpha=args.band_loss, bias=args.loss_bias), metrics=[plc_l1_loss(), plc_band_loss()])
     model.summary()
 lpc_order = 16