oops, fix band loss

This commit is contained in:
Jean-Marc Valin 2022-01-30 17:29:33 -05:00
parent c8cbfa7e9b
commit e1181bcad0

View file

@ -99,8 +99,8 @@ def plc_loss(alpha=1.0):
mask = y_true[:,:,-1:]
y_true = y_true[:,:,:-1]
e = (y_true - y_pred)*mask
e_bands = tf.signal.idct(e, norm='ortho')
l1_loss = K.mean(K.abs(e) + alpha*K.abs(e_bands))
e_bands = tf.signal.idct(e[:,:,:-2], norm='ortho')
l1_loss = K.mean(K.abs(e)) + alpha*K.mean(K.abs(e_bands))
return l1_loss
return loss
@ -118,7 +118,7 @@ def plc_band_loss():
mask = y_true[:,:,-1:]
y_true = y_true[:,:,:-1]
e = (y_true - y_pred)*mask
e_bands = tf.signal.idct(e, norm='ortho')
e_bands = tf.signal.idct(e[:,:,:-2], norm='ortho')
l1_loss = K.mean(K.abs(e_bands))
return l1_loss
return L1_band_loss