DRED: Add lambda schedule for first epochs
parent a41a344a2e
commit cb7cf92a52
1 changed file with 11 additions and 2 deletions
@@ -163,6 +163,7 @@ if __name__ == '__main__':
 
     # training loop
 
+    batch = 1
     for epoch in range(1, epochs + 1):
 
         print(f"training epoch {epoch}...")
@@ -203,13 +204,20 @@ if __name__ == '__main__':
             outputs_soft_quant = model_output['outputs_soft_quant']
             statistical_model = model_output['statistical_model']
 
+            if type(args.initial_checkpoint) == type(None):
+                latent_lambda = (1. - .5/(1.+batch/1000))
+                state_lambda = (1. - .9/(1.+batch/6000))
+            else:
+                latent_lambda = 1.
+                state_lambda = 1.
+
             # rate loss
             hard_rate = hard_rate_estimate(z, statistical_model['r_hard'][:,:,:latent_dim], statistical_model['theta_hard'][:,:,:latent_dim], reduce=False)
             soft_rate = soft_rate_estimate(z, statistical_model['r_soft'][:,:,:latent_dim], reduce=False)
             states_hard_rate = hard_rate_estimate(states, statistical_model['r_hard'][:,:,latent_dim:], statistical_model['theta_hard'][:,:,latent_dim:], reduce=False)
             states_soft_rate = soft_rate_estimate(states, statistical_model['r_soft'][:,:,latent_dim:], reduce=False)
-            soft_rate_loss = torch.mean(torch.sqrt(rate_lambda) * (soft_rate + .02*states_soft_rate))
-            hard_rate_loss = torch.mean(torch.sqrt(rate_lambda) * (hard_rate + .02*states_hard_rate))
+            soft_rate_loss = torch.mean(torch.sqrt(rate_lambda) * (latent_lambda*soft_rate + .04*state_lambda*states_soft_rate))
+            hard_rate_loss = torch.mean(torch.sqrt(rate_lambda) * (latent_lambda*hard_rate + .04*state_lambda*states_hard_rate))
             rate_loss = (soft_rate_loss + 0.1 * hard_rate_loss)
             hard_rate_metric = torch.mean(hard_rate)
             states_rate_metric = torch.mean(states_hard_rate)
@@ -272,6 +280,7 @@ if __name__ == '__main__':
                 rateloss_soft=running_soft_rate_loss / (i + 1)
             )
             previous_total_loss = running_total_loss
+            batch = batch+1
 
         # save checkpoint
         checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pth')
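For context: when no initial checkpoint is given, the two new weights follow a smooth warm-up of the form 1 - a/(1 + batch/b), so latent_lambda starts at 0.5 and ramps toward 1 on a scale of roughly 1000 batches, while state_lambda starts at 0.1 and ramps toward 1 on a slower scale of roughly 6000 batches. A minimal standalone sketch of the schedule, using the constants from the diff (the helper name lambda_schedule and the sampled batch counts are illustrative, not part of the commit):

    # Sketch of the warm-up schedule introduced above; lambda_schedule and
    # the sampled batch counts are illustrative, not part of the commit.

    def lambda_schedule(batch):
        """Return (latent_lambda, state_lambda) for a given global batch count."""
        latent_lambda = 1. - .5 / (1. + batch / 1000)   # 0.5 at batch 0, tends to 1
        state_lambda = 1. - .9 / (1. + batch / 6000)    # 0.1 at batch 0, tends to 1
        return latent_lambda, state_lambda

    if __name__ == '__main__':
        for batch in [1, 1000, 4800, 6000, 20000]:
            latent_lambda, state_lambda = lambda_schedule(batch)
            print(f"batch {batch:6d}: latent_lambda={latent_lambda:.3f} "
                  f"state_lambda={state_lambda:.3f} "
                  f"effective state weight={.04 * state_lambda:.4f}")

Note that the commit also doubles the state-rate coefficient from .02 to .04, so the effective state weight .04*state_lambda starts near .004, crosses the old constant .02 around batch 4800 (where state_lambda reaches 0.5), and approaches .04 from below.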