RDOVAE model update

This commit is contained in:
Jean-Marc Valin 2023-09-25 15:29:27 -04:00
parent 574c766c0c
commit c4b83ae62d
No known key found for this signature in database
GPG key ID: 531A52533318F00A
3 changed files with 3 additions and 3 deletions

View file

@ -231,7 +231,7 @@ if __name__ == '__main__':
total_loss = rate_loss + (distortion_loss_hard_quant + distortion_loss_soft_quant) / 2
if args.enable_first_frame_loss:
total_loss = total_loss + 0.5 * torch.relu(first_frame_loss - distortion_loss_hard_quant)
total_loss = .97*total_loss + 0.03 * first_frame_loss
total_loss.backward()