mirror of
https://github.com/xiph/opus.git
synced 2025-06-06 07:21:03 +00:00
oops, fix band loss
This commit is contained in:
parent
c8cbfa7e9b
commit
e1181bcad0
1 changed files with 3 additions and 3 deletions
|
@ -99,8 +99,8 @@ def plc_loss(alpha=1.0):
|
|||
mask = y_true[:,:,-1:]
|
||||
y_true = y_true[:,:,:-1]
|
||||
e = (y_true - y_pred)*mask
|
||||
e_bands = tf.signal.idct(e, norm='ortho')
|
||||
l1_loss = K.mean(K.abs(e) + alpha*K.abs(e_bands))
|
||||
e_bands = tf.signal.idct(e[:,:,:-2], norm='ortho')
|
||||
l1_loss = K.mean(K.abs(e)) + alpha*K.mean(K.abs(e_bands))
|
||||
return l1_loss
|
||||
return loss
|
||||
|
||||
|
@ -118,7 +118,7 @@ def plc_band_loss():
|
|||
mask = y_true[:,:,-1:]
|
||||
y_true = y_true[:,:,:-1]
|
||||
e = (y_true - y_pred)*mask
|
||||
e_bands = tf.signal.idct(e, norm='ortho')
|
||||
e_bands = tf.signal.idct(e[:,:,:-2], norm='ortho')
|
||||
l1_loss = K.mean(K.abs(e_bands))
|
||||
return l1_loss
|
||||
return L1_band_loss
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue