diff --git a/dnn/torch/rdovae/train_rdovae.py b/dnn/torch/rdovae/train_rdovae.py index 71c8a656..0e899713 100644 --- a/dnn/torch/rdovae/train_rdovae.py +++ b/dnn/torch/rdovae/train_rdovae.py @@ -163,6 +163,7 @@ if __name__ == '__main__': # training loop + batch = 1 for epoch in range(1, epochs + 1): print(f"training epoch {epoch}...") @@ -203,13 +204,20 @@ if __name__ == '__main__': outputs_soft_quant = model_output['outputs_soft_quant'] statistical_model = model_output['statistical_model'] + if type(args.initial_checkpoint) == type(None): + latent_lambda = (1. - .5/(1.+batch/1000)) + state_lambda = (1. - .9/(1.+batch/6000)) + else: + latent_lambda = 1. + state_lambda = 1. + # rate loss hard_rate = hard_rate_estimate(z, statistical_model['r_hard'][:,:,:latent_dim], statistical_model['theta_hard'][:,:,:latent_dim], reduce=False) soft_rate = soft_rate_estimate(z, statistical_model['r_soft'][:,:,:latent_dim], reduce=False) states_hard_rate = hard_rate_estimate(states, statistical_model['r_hard'][:,:,latent_dim:], statistical_model['theta_hard'][:,:,latent_dim:], reduce=False) states_soft_rate = soft_rate_estimate(states, statistical_model['r_soft'][:,:,latent_dim:], reduce=False) - soft_rate_loss = torch.mean(torch.sqrt(rate_lambda) * (soft_rate + .02*states_soft_rate)) - hard_rate_loss = torch.mean(torch.sqrt(rate_lambda) * (hard_rate + .02*states_hard_rate)) + soft_rate_loss = torch.mean(torch.sqrt(rate_lambda) * (latent_lambda*soft_rate + .04*state_lambda*states_soft_rate)) + hard_rate_loss = torch.mean(torch.sqrt(rate_lambda) * (latent_lambda*hard_rate + .04*state_lambda*states_hard_rate)) rate_loss = (soft_rate_loss + 0.1 * hard_rate_loss) hard_rate_metric = torch.mean(hard_rate) states_rate_metric = torch.mean(states_hard_rate) @@ -272,6 +280,7 @@ if __name__ == '__main__': rateloss_soft=running_soft_rate_loss / (i + 1) ) previous_total_loss = running_total_loss + batch = batch+1 # save checkpoint checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pth')