added soft quantization to RDOVAE and FARGAN

Jan Buethe 2025-03-08 13:03:02 -08:00 committed by Jean-Marc Valin
parent ebccedd918
commit 1ca6933ac4
No known key found for this signature in database
GPG key ID: 5E5DD9A36F9189C8
5 changed files with 95 additions and 38 deletions


@@ -44,6 +44,7 @@ parser.add_argument('--cuda-visible-devices', type=str, help="comma separates li
 model_group = parser.add_argument_group(title="model parameters")
 model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256)
 model_group.add_argument('--gamma', type=float, help="Use A(z/gamma), default: 0.9", default=0.9)
+model_group.add_argument('--softquant', action="store_true", help="enables soft quantization during training")
 
 training_group = parser.add_argument_group(title="training parameters")
 training_group.add_argument('--batch-size', type=int, help="batch size, default: 128", default=128)
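
Since the new option is defined with action="store_true", it defaults to False and is enabled by simply appending --softquant to the usual training invocation; no value is required.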
@@ -93,7 +94,7 @@ checkpoint['adam_betas'] = adam_betas
 
 device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 
 checkpoint['model_args'] = ()
-checkpoint['model_kwargs'] = {'cond_size': cond_size, 'gamma': args.gamma}
+checkpoint['model_kwargs'] = {'cond_size': cond_size, 'gamma': args.gamma, 'softquant': args.softquant}
 print(checkpoint['model_kwargs'])
 model = fargan.FARGAN(*checkpoint['model_args'], **checkpoint['model_kwargs'])
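
The hunks shown here only thread the softquant flag through to the FARGAN constructor; the layer-level machinery lives in the other changed files, which this excerpt omits. As a rough, self-contained sketch of what "soft quantization during training" commonly means (assumptions: a uniform quantization grid, a straight-through gradient, and an interpolation factor alpha, none of which are confirmed by this diff):

import torch

def soft_quantize(w: torch.Tensor, scale: float = 1.0 / 128.0, alpha: float = 0.5) -> torch.Tensor:
    # Hard-quantize the weights onto a uniform grid of step `scale`.
    w_hard = scale * torch.round(w / scale)
    # Straight-through estimator: the forward pass sees the quantized
    # values, while gradients flow through as if quantization were the
    # identity.
    w_hard = w + (w_hard - w).detach()
    # Soft interpolation: alpha=0 leaves the float weights untouched,
    # alpha=1 applies hard (straight-through) quantization.
    return (1.0 - alpha) * w + alpha * w_hard

Applying a transform like this to layer weights during training lets the model adapt to the quantization it will undergo at inference time, which would explain why softquant is passed as a model kwarg (affecting how layers are built) rather than as a training-loop parameter.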