From e1181bcad026b0c1f64dacb631a28f833dc21f42 Mon Sep 17 00:00:00 2001 From: Jean-Marc Valin Date: Sun, 30 Jan 2022 17:29:33 -0500 Subject: [PATCH] oops, fix band loss --- dnn/training_tf2/train_plc.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dnn/training_tf2/train_plc.py b/dnn/training_tf2/train_plc.py index ed075f53..a8bb4027 100644 --- a/dnn/training_tf2/train_plc.py +++ b/dnn/training_tf2/train_plc.py @@ -99,8 +99,8 @@ def plc_loss(alpha=1.0): mask = y_true[:,:,-1:] y_true = y_true[:,:,:-1] e = (y_true - y_pred)*mask - e_bands = tf.signal.idct(e, norm='ortho') - l1_loss = K.mean(K.abs(e) + alpha*K.abs(e_bands)) + e_bands = tf.signal.idct(e[:,:,:-2], norm='ortho') + l1_loss = K.mean(K.abs(e)) + alpha*K.mean(K.abs(e_bands)) return l1_loss return loss @@ -118,7 +118,7 @@ def plc_band_loss(): mask = y_true[:,:,-1:] y_true = y_true[:,:,:-1] e = (y_true - y_pred)*mask - e_bands = tf.signal.idct(e, norm='ortho') + e_bands = tf.signal.idct(e[:,:,:-2], norm='ortho') l1_loss = K.mean(K.abs(e_bands)) return l1_loss return L1_band_loss