added soft quantization to RDOVAE and FARGAN

This commit is contained in:
Jan Buethe 2025-03-08 13:03:02 -08:00 committed by Jean-Marc Valin
parent ebccedd918
commit 1ca6933ac4
No known key found for this signature in database
GPG key ID: 5E5DD9A36F9189C8
5 changed files with 95 additions and 38 deletions

View file

@ -54,6 +54,7 @@ model_group.add_argument('--lambda-min', type=float, help="minimal value for rat
model_group.add_argument('--lambda-max', type=float, help="maximal value for rate lambda, default: 0.0104", default=0.0104)
model_group.add_argument('--pvq-num-pulses', type=int, help="number of pulses for PVQ, default: 82", default=82)
model_group.add_argument('--state-dropout-rate', type=float, help="state dropout rate, default: 0", default=0.0)
model_group.add_argument('--softquant', action="store_true", help="enables soft quantization during training")
training_group = parser.add_argument_group(title="training parameters")
training_group.add_argument('--batch-size', type=int, help="batch size, default: 32", default=32)
@ -109,6 +110,7 @@ quant_levels = args.quant_levels
lambda_min = args.lambda_min
lambda_max = args.lambda_max
state_dim = args.state_dim
softquant = args.softquant
# not exposed
num_features = 20
@ -118,7 +120,7 @@ feature_file = args.features
# model
checkpoint['model_args'] = (num_features, latent_dim, quant_levels, cond_size, cond_size2)
checkpoint['model_kwargs'] = {'state_dim': state_dim, 'split_mode' : split_mode, 'pvq_num_pulses': args.pvq_num_pulses, 'state_dropout_rate': args.state_dropout_rate}
checkpoint['model_kwargs'] = {'state_dim': state_dim, 'split_mode' : split_mode, 'pvq_num_pulses': args.pvq_num_pulses, 'state_dropout_rate': args.state_dropout_rate, 'softquant': softquant}
model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs'])
if type(args.initial_checkpoint) != type(None):