mirror of
https://github.com/xiph/opus.git
synced 2025-05-25 12:49:12 +00:00
Support for biased loss
This commit is contained in:
parent
186fa61680
commit
1db1946f77
1 changed files with 8 additions and 6 deletions
|
@ -46,6 +46,8 @@ parser.add_argument('--batch-size', metavar='<batch size>', default=128, type=in
|
|||
parser.add_argument('--seq-length', metavar='<sequence length>', default=1000, type=int, help='sequence length to use (default 1000)')
|
||||
parser.add_argument('--lr', metavar='<learning rate>', type=float, help='learning rate')
|
||||
parser.add_argument('--decay', metavar='<decay>', type=float, help='learning rate decay')
|
||||
parser.add_argument('--band-loss', metavar='<weight>', default=1.0, type=float, help='weight of band loss (default 1.0)')
|
||||
parser.add_argument('--loss-bias', metavar='<bias>', default=0.0, type=float, help='loss bias towards low energy (default 0.0)')
|
||||
parser.add_argument('--logdir', metavar='<log dir>', help='directory for tensorboard log files')
|
||||
|
||||
|
||||
|
@ -94,13 +96,13 @@ if args.decay is not None:
|
|||
if retrain:
|
||||
input_model = args.retrain
|
||||
|
||||
def plc_loss(alpha=1.0):
|
||||
def plc_loss(alpha=1.0, bias=0.):
|
||||
def loss(y_true,y_pred):
|
||||
mask = y_true[:,:,-1:]
|
||||
y_true = y_true[:,:,:-1]
|
||||
e = (y_true - y_pred)*mask
|
||||
e = (y_pred - y_true)*mask
|
||||
e_bands = tf.signal.idct(e[:,:,:-2], norm='ortho')
|
||||
l1_loss = K.mean(K.abs(e)) + alpha*K.mean(K.abs(e_bands))
|
||||
l1_loss = K.mean(K.abs(e)) + bias*K.mean(K.maximum(e[:,:,:1], 0.)) + alpha*K.mean(K.abs(e_bands) + bias*K.maximum(e_bands, 0.))
|
||||
return l1_loss
|
||||
return loss
|
||||
|
||||
|
@ -108,7 +110,7 @@ def plc_l1_loss():
|
|||
def L1_loss(y_true,y_pred):
|
||||
mask = y_true[:,:,-1:]
|
||||
y_true = y_true[:,:,:-1]
|
||||
e = (y_true - y_pred)*mask
|
||||
e = (y_pred - y_true)*mask
|
||||
l1_loss = K.mean(K.abs(e))
|
||||
return l1_loss
|
||||
return L1_loss
|
||||
|
@ -117,7 +119,7 @@ def plc_band_loss():
|
|||
def L1_band_loss(y_true,y_pred):
|
||||
mask = y_true[:,:,-1:]
|
||||
y_true = y_true[:,:,:-1]
|
||||
e = (y_true - y_pred)*mask
|
||||
e = (y_pred - y_true)*mask
|
||||
e_bands = tf.signal.idct(e[:,:,:-2], norm='ortho')
|
||||
l1_loss = K.mean(K.abs(e_bands))
|
||||
return l1_loss
|
||||
|
@ -128,7 +130,7 @@ strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
|
|||
|
||||
with strategy.scope():
|
||||
model = lpcnet.new_lpcnet_plc_model(rnn_units=args.gru_size, batch_size=batch_size, training=True, quantize=quantize, cond_size=args.cond_size)
|
||||
model.compile(optimizer=opt, loss=plc_loss(alpha=1.), metrics=[plc_l1_loss(), plc_band_loss()])
|
||||
model.compile(optimizer=opt, loss=plc_loss(alpha=args.band_loss, bias=args.loss_bias), metrics=[plc_l1_loss(), plc_band_loss()])
|
||||
model.summary()
|
||||
|
||||
lpc_order = 16
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue