DRED: Add lambda schedule for first epochs

Jean-Marc Valin 2025-04-21 11:23:29 -04:00
parent a41a344a2e
commit cb7cf92a52


@@ -163,6 +163,7 @@ if __name__ == '__main__':
     # training loop
+    batch = 1
     for epoch in range(1, epochs + 1):
         print(f"training epoch {epoch}...")
@@ -203,13 +204,20 @@ if __name__ == '__main__':
                 outputs_soft_quant = model_output['outputs_soft_quant']
                 statistical_model = model_output['statistical_model']
+                if type(args.initial_checkpoint) == type(None):
+                    latent_lambda = (1. - .5/(1.+batch/1000))
+                    state_lambda = (1. - .9/(1.+batch/6000))
+                else:
+                    latent_lambda = 1.
+                    state_lambda = 1.
                 # rate loss
                 hard_rate = hard_rate_estimate(z, statistical_model['r_hard'][:,:,:latent_dim], statistical_model['theta_hard'][:,:,:latent_dim], reduce=False)
                 soft_rate = soft_rate_estimate(z, statistical_model['r_soft'][:,:,:latent_dim], reduce=False)
                 states_hard_rate = hard_rate_estimate(states, statistical_model['r_hard'][:,:,latent_dim:], statistical_model['theta_hard'][:,:,latent_dim:], reduce=False)
                 states_soft_rate = soft_rate_estimate(states, statistical_model['r_soft'][:,:,latent_dim:], reduce=False)
-                soft_rate_loss = torch.mean(torch.sqrt(rate_lambda) * (soft_rate + .02*states_soft_rate))
-                hard_rate_loss = torch.mean(torch.sqrt(rate_lambda) * (hard_rate + .02*states_hard_rate))
+                soft_rate_loss = torch.mean(torch.sqrt(rate_lambda) * (latent_lambda*soft_rate + .04*state_lambda*states_soft_rate))
+                hard_rate_loss = torch.mean(torch.sqrt(rate_lambda) * (latent_lambda*hard_rate + .04*state_lambda*states_hard_rate))
                 rate_loss = (soft_rate_loss + 0.1 * hard_rate_loss)
                 hard_rate_metric = torch.mean(hard_rate)
                 states_rate_metric = torch.mean(states_hard_rate)
@@ -272,6 +280,7 @@ if __name__ == '__main__':
                     rateloss_soft=running_soft_rate_loss / (i + 1)
                 )
                 previous_total_loss = running_total_loss
+                batch = batch+1
         # save checkpoint
         checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pth')
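
The warm-up added here only applies when training from scratch (no initial checkpoint is given): latent_lambda ramps from 0.5 toward 1 with a half-way point around batch 1000, and state_lambda ramps from 0.1 toward 1 much more slowly (half-way around batch 6000), so the rate penalty on the latents and especially on the decoder states is relaxed during the first epochs. Below is a standalone sketch of how the two factors evolve; the helper functions and the batch counts chosen are illustrative only and not part of the training script.

# Sketch (not part of the patch): evaluate the two warm-up factors at a few
# illustrative batch counts to show how the schedule behaves.
def latent_lambda(batch):
    # ramps from 0.5 at batch 0 toward 1.0, half-way point around batch 1000
    return 1. - .5 / (1. + batch / 1000)

def state_lambda(batch):
    # ramps from 0.1 at batch 0 toward 1.0, half-way point around batch 6000
    return 1. - .9 / (1. + batch / 6000)

for b in [0, 1000, 6000, 20000, 100000]:
    print(f"batch {b:6d}: latent_lambda={latent_lambda(b):.3f}  state_lambda={state_lambda(b):.3f}")
# batch      0: latent_lambda=0.500  state_lambda=0.100
# batch   1000: latent_lambda=0.750  state_lambda=0.229
# batch   6000: latent_lambda=0.929  state_lambda=0.550
# batch  20000: latent_lambda=0.976  state_lambda=0.792
# batch 100000: latent_lambda=0.995  state_lambda=0.949

Note that the same change also doubles the weight on the state-rate terms from .02 to .04; since state_lambda starts at 0.1, the effective state-rate weight still begins below the old value and only exceeds it once state_lambda passes 0.5, roughly after the first few thousand batches.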