opus/dnn/torch/rdovae/train_rdovae.py
Jean-Marc Valin 6a45b767e2
Some checks failed
CMake / CMake/MacOSX/Lib/X64/Release (push) Has been cancelled
CMake / CMake/MacOSX/Framework/X64/Release (push) Has been cancelled
CMake / CMake/Linux/So/X64/Release (push) Has been cancelled
CMake / CMake/MacOSX/So/X64/Release (push) Has been cancelled
CMake / CMake/AssertionsFuzz/MacOSX/Lib/X64/Release (push) Has been cancelled
CMake / CMake/CustomModes/Linux/Lib/X64/Release (push) Has been cancelled
Autotools / AutoMake/Linux/GCC (push) Has been cancelled
Autotools / AutoMake/Linux/GCC/EnableDNN (push) Has been cancelled
Autotools / AutoMake/Linux/GCC/EnableCustomModes (push) Has been cancelled
Autotools / AutoMake/Linux/GCC/EnableAssertions (push) Has been cancelled
CMake / Test build with CMake 3.16.0 (push) Has been cancelled
CMake / CMake MINGW (push) Has been cancelled
CMake / CMake/Linux/Lib/X64/Release (push) Has been cancelled
CMake / CMake/Android/So/ARMv8/Release (push) Has been cancelled
CMake / CMake/Android/Lib/ARMv8/Release (push) Has been cancelled
CMake / CMake/Android/So/X86/Release (push) Has been cancelled
CMake / CMake/Android/Lib/X86/Release (push) Has been cancelled
CMake / CMake/Android/So/X64/Release (push) Has been cancelled
CMake / CMake/Android/Lib/X64/Release (push) Has been cancelled
CMake / CMake/AssertionsFuzz/Linux/Lib/X64/Release (push) Has been cancelled
CMake / CMake/iOS/Framework/arm64/Release (push) Has been cancelled
CMake / CMake/iOS/Dll/arm64/Release (push) Has been cancelled
CMake / CMake/iOS/Lib/arm64/Release (push) Has been cancelled
CMake / CMake/Windows/Dll/ARMv8/Release (push) Has been cancelled
CMake / CMake/Windows/Lib/armv8/Release (push) Has been cancelled
CMake / CMake/Windows/Dll/X64/Release (push) Has been cancelled
CMake / CMake/Windows/Dll/X86/Release (push) Has been cancelled
CMake / CMake/AssertionsFuzz/Windows/Lib/X64/Release (push) Has been cancelled
CMake / CMake/Windows/Lib/X64/Release (push) Has been cancelled
CMake / CMake/Windows/Lib/X86/Release (push) Has been cancelled
DRED / CMake/Android/Lib/ARMv8/Release (push) Has been cancelled
DRED / CMake/Android/Lib/X64/Release (push) Has been cancelled
DRED / CMake/MacOSX/Lib/X64/Release (push) Has been cancelled
DRED / CMake/Linux/Lib/X64/Release (push) Has been cancelled
DRED / CMake/iOS/Lib/arm64/Release (push) Has been cancelled
DRED / CMake/Windows/Lib/armv8/Release (push) Has been cancelled
DRED / CMake/Windows/Lib/X64/Release (push) Has been cancelled
DRED / AutoTools/Linux/Clang (push) Has been cancelled
DRED / AutoTools/Linux/GCC (push) Has been cancelled
Repository / Check trailing white spaces (push) Has been cancelled
Add skewed split for fine-tuning decoder
2025-04-21 11:24:08 -04:00

290 lines
14 KiB
Python

"""
/* Copyright (c) 2022 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
import argparse
import torch
import tqdm
from rdovae import RDOVAE, RDOVAEDataset, distortion_loss, hard_rate_estimate, soft_rate_estimate
def _register(target, specs):
    """Add every (name, add_argument keyword arguments) pair of `specs` to `target`, in order."""
    for name, settings in specs:
        target.add_argument(name, **settings)


# ---- command line interface ----
parser = argparse.ArgumentParser()

# required paths and device selection
_register(parser, [
    ('features', dict(type=str, help='path to feature file in .f32 format')),
    ('output', dict(type=str, help='path to output folder')),
    ('--cuda-visible-devices', dict(type=str, default="", help="comma separates list of cuda visible device indices, default: ''")),
])

# options describing the model architecture and the rate-lambda range
model_group = parser.add_argument_group(title="model parameters")
_register(model_group, [
    ('--latent-dim', dict(type=int, default=80, help="number of symbols produces by encoder, default: 80")),
    ('--cond-size', dict(type=int, default=256, help="first conditioning size, default: 256")),
    ('--cond-size2', dict(type=int, default=256, help="second conditioning size, default: 256")),
    ('--state-dim', dict(type=int, default=24, help="dimensionality of transfered state, default: 24")),
    ('--quant-levels', dict(type=int, default=16, help="number of quantization levels, default: 16")),
    ('--lambda-min', dict(type=float, default=2e-4, help="minimal value for rate lambda, default: 0.0002")),
    ('--lambda-max', dict(type=float, default=0.0104, help="maximal value for rate lambda, default: 0.0104")),
    ('--pvq-num-pulses', dict(type=int, default=82, help="number of pulses for PVQ, default: 82")),
    ('--state-dropout-rate', dict(type=float, default=0.0, help="state dropout rate, default: 0")),
    ('--softquant', dict(action="store_true", help="enables soft quantization during training")),
])

# options controlling the optimisation run itself
training_group = parser.add_argument_group(title="training parameters")
_register(training_group, [
    ('--batch-size', dict(type=int, default=32, help="batch size, default: 32")),
    ('--lr', dict(type=float, default=3e-4, help='learning rate, default: 3e-4')),
    ('--epochs', dict(type=int, default=100, help='number of training epochs, default: 100')),
    ('--sequence-length', dict(type=int, default=400, help='sequence length, needs to be divisible by chunks_per_offset, default: 400')),
    ('--chunks-per-offset', dict(type=int, default=4, help='chunks per offset')),
    ('--lr-decay-factor', dict(type=float, default=2.5e-5, help='learning rate decay factor, default: 2.5e-5')),
    ('--split-mode', dict(type=str, choices=['split', 'random_split', 'skewed_split'], default='split', help='splitting mode for decoder input, default: split')),
    ('--enable-first-frame-loss', dict(action='store_true', default=False, help='enables dedicated distortion loss on first 4 decoder frames')),
    ('--initial-checkpoint', dict(type=str, default=None, help='initial checkpoint to start training from, default: None')),
    ('--train-decoder-only', dict(action='store_true', help='freeze encoder and statistical model and train decoder only')),
])

# parsed once at import time; the rest of the script reads its settings from `args`
args = parser.parse_args()
# Export the device selection from the command line before any CUDA query below.
os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_visible_devices

# Checkpoints are written to <output>/checkpoints, created on demand.
checkpoint_dir = os.path.join(args.output, 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)

# Training hyper-parameters taken from the command line.
batch_size = args.batch_size
lr = args.lr
epochs = args.epochs
sequence_length = args.sequence_length
lr_decay_factor = args.lr_decay_factor
split_mode = args.split_mode

# Optimizer settings that are not exposed as options.
adam_betas = [0.8, 0.95]
adam_eps = 1e-8

# Metadata stored next to the weights in every checkpoint file
# (adam_eps is not recorded).
checkpoint = {
    'batch_size': batch_size,
    'lr': lr,
    'lr_decay_factor': lr_decay_factor,
    'split_mode': split_mode,
    'epochs': epochs,
    'sequence_length': sequence_length,
    'adam_betas': adam_betas,
}

# Number of batches between progress-bar updates.
log_interval = 10

# Use CUDA when torch reports it as available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model parameters taken from the command line
cond_size = args.cond_size
cond_size2 = args.cond_size2
latent_dim = args.latent_dim
quant_levels = args.quant_levels
lambda_min = args.lambda_min
lambda_max = args.lambda_max
state_dim = args.state_dim
softquant = args.softquant

# not exposed
num_features = 20

# training data
feature_file = args.features

# model: the constructor arguments are recorded in the checkpoint metadata and
# the model is built from exactly those recorded values
checkpoint['model_args'] = (num_features, latent_dim, quant_levels, cond_size, cond_size2)
checkpoint['model_kwargs'] = {
    'state_dim': state_dim,
    'split_mode': split_mode,
    'pvq_num_pulses': args.pvq_num_pulses,
    'state_dropout_rate': args.state_dropout_rate,
    'softquant': softquant,
    'chunks_per_offset': args.chunks_per_offset,
}
model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs'])

if args.initial_checkpoint is not None:
    # Load into a separate name. Rebinding `checkpoint` here would throw away
    # the metadata recorded for this run, and every checkpoint saved later
    # would describe the initial checkpoint's configuration instead of the
    # model that was just built from the command line.
    initial_checkpoint = torch.load(args.initial_checkpoint, map_location='cpu')
    # strict=False: parameters missing from / unknown to the stored state dict
    # are tolerated instead of raising
    model.load_state_dict(initial_checkpoint['state_dict'], strict=False)
    # only the weights were needed; do not keep the loaded dict around
    del initial_checkpoint

checkpoint['state_dict'] = model.state_dict()

if args.train_decoder_only:
    if args.initial_checkpoint is None:
        print("warning: training decoder only without providing initial checkpoint")

    # Freeze the encoder and the statistical model; the optimizer is expected
    # to pick up only parameters that still require gradients.
    for p in model.core_encoder.module.parameters():
        p.requires_grad = False

    for p in model.statistical_model.parameters():
        p.requires_grad = False
# Dataset: constructor arguments are recorded in the checkpoint metadata and
# the dataset is built from exactly those recorded values.
dataset_args = (feature_file, sequence_length, num_features, 36)
dataset_kwargs = {
    'lambda_min': lambda_min,
    'lambda_max': lambda_max,
    'enc_stride': model.enc_stride,
    'quant_levels': quant_levels,
}
checkpoint['dataset_args'] = dataset_args
checkpoint['dataset_kwargs'] = dataset_kwargs
dataset = RDOVAEDataset(*checkpoint['dataset_args'], **checkpoint['dataset_kwargs'])

# Shuffled batches; a trailing incomplete batch is dropped.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=4,
)

# Optimizer: only parameters that still require gradients are optimised.
params = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = torch.optim.Adam(params, lr=lr, betas=adam_betas, eps=adam_eps)


def _inverse_decay(step):
    """Return the learning-rate multiplier 1 / (1 + lr_decay_factor * step)."""
    return 1 / (1 + lr_decay_factor * step)


# Learning rate scheduler: scales the base learning rate by _inverse_decay.
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=_inverse_decay)
if __name__ == '__main__':

    # push model to device
    model.to(device)

    # The rate-weight warm-up below is only applied when no initial checkpoint
    # was given; this does not change during the run, so test it once.
    training_from_scratch = args.initial_checkpoint is None

    # global batch counter: starts at 1 and keeps counting across epochs
    batch = 1

    # training loop
    for epoch in range(1, epochs + 1):

        print(f"training epoch {epoch}...")

        # running stats, reset at the start of every epoch
        running_rate_loss = 0
        running_soft_dist_loss = 0
        running_hard_dist_loss = 0
        running_hard_rate_loss = 0
        running_soft_rate_loss = 0
        running_total_loss = 0
        running_rate_metric = 0
        running_states_rate_metric = 0
        previous_total_loss = 0
        running_first_frame_loss = 0

        with tqdm.tqdm(dataloader, unit='batch') as tepoch:
            for i, (features, rate_lambda, q_ids) in enumerate(tepoch):

                # zero out gradients
                optimizer.zero_grad()

                # push inputs to device
                features = features.to(device)
                q_ids = q_ids.to(device)
                rate_lambda = rate_lambda.to(device)

                # every rate_lambda entry is repeated twice along dim 1; the
                # upsampled version is what weights the distortion terms
                rate_lambda_upsamp = torch.repeat_interleave(rate_lambda, 2, 1)

                # run model
                model_output = model(features, q_ids)

                # collect outputs
                z = model_output['z']
                states = model_output['states']
                outputs_hard_quant = model_output['outputs_hard_quant']
                outputs_soft_quant = model_output['outputs_soft_quant']
                statistical_model = model_output['statistical_model']

                if training_from_scratch:
                    # warm-up: both weights rise towards 1 as `batch` grows
                    # (latent from about 0.5, state from about 0.1)
                    latent_lambda = (1. - .5/(1.+batch/1000))
                    state_lambda = (1. - .9/(1.+batch/6000))
                else:
                    latent_lambda = 1.
                    state_lambda = 1.

                # rate loss: the first latent_dim entries of the statistical
                # model outputs are used for z, the remaining ones for states
                hard_rate = hard_rate_estimate(z, statistical_model['r_hard'][:,:,:latent_dim], statistical_model['theta_hard'][:,:,:latent_dim], reduce=False)
                soft_rate = soft_rate_estimate(z, statistical_model['r_soft'][:,:,:latent_dim], reduce=False)
                states_hard_rate = hard_rate_estimate(states, statistical_model['r_hard'][:,:,latent_dim:], statistical_model['theta_hard'][:,:,latent_dim:], reduce=False)
                states_soft_rate = soft_rate_estimate(states, statistical_model['r_soft'][:,:,latent_dim:], reduce=False)
                soft_rate_loss = torch.mean(torch.sqrt(rate_lambda) * (latent_lambda*soft_rate + .04*state_lambda*states_soft_rate))
                hard_rate_loss = torch.mean(torch.sqrt(rate_lambda) * (latent_lambda*hard_rate + .04*state_lambda*states_hard_rate))
                rate_loss = (soft_rate_loss + 0.1 * hard_rate_loss)
                hard_rate_metric = torch.mean(hard_rate)
                states_rate_metric = torch.mean(states_hard_rate)

                ## distortion losses

                # hard quantized decoder input: one pass accumulates both the
                # full-segment distortion and the distortion restricted to the
                # last four frames of each segment (the latter is what
                # --enable-first-frame-loss adds to the objective; presumably
                # these are the first frames the decoder produces - TODO confirm)
                distortion_loss_hard_quant = torch.zeros_like(rate_loss)
                first_frame_loss = torch.zeros_like(rate_loss)
                num_hard_outputs = len(outputs_hard_quant)
                for dec_features, start, stop in outputs_hard_quant:
                    distortion_loss_hard_quant += distortion_loss(features[..., start : stop, :], dec_features, rate_lambda_upsamp[..., start : stop]) / num_hard_outputs
                    first_frame_loss += distortion_loss(features[..., stop-4 : stop, :], dec_features[..., -4:, :], rate_lambda_upsamp[..., stop - 4 : stop]) / num_hard_outputs

                # soft quantized decoder input
                distortion_loss_soft_quant = torch.zeros_like(rate_loss)
                num_soft_outputs = len(outputs_soft_quant)
                for dec_features, start, stop in outputs_soft_quant:
                    distortion_loss_soft_quant += distortion_loss(features[..., start : stop, :], dec_features, rate_lambda_upsamp[..., start : stop]) / num_soft_outputs

                # total loss: rate plus the mean of hard and soft distortion
                total_loss = rate_loss + (distortion_loss_hard_quant + distortion_loss_soft_quant) / 2

                if args.enable_first_frame_loss:
                    total_loss = .97*total_loss + 0.03 * first_frame_loss

                total_loss.backward()

                optimizer.step()

                # model-specific hooks run after every optimizer step
                model.clip_weights()
                model.sparsify()

                # the learning-rate schedule advances once per batch
                scheduler.step()

                # collect running stats as Python floats
                running_hard_dist_loss += distortion_loss_hard_quant.item()
                running_soft_dist_loss += distortion_loss_soft_quant.item()
                running_rate_loss += rate_loss.item()
                running_rate_metric += hard_rate_metric.item()
                running_states_rate_metric += states_rate_metric.item()
                running_total_loss += total_loss.item()
                running_first_frame_loss += first_frame_loss.item()
                running_soft_rate_loss += soft_rate_loss.item()
                running_hard_rate_loss += hard_rate_loss.item()

                if (i + 1) % log_interval == 0:
                    # current_loss: mean total loss over the last log_interval
                    # batches; all other entries are means since epoch start
                    current_loss = (running_total_loss - previous_total_loss) / log_interval
                    tepoch.set_postfix(
                        current_loss=current_loss,
                        total_loss=running_total_loss / (i + 1),
                        dist_hq=running_hard_dist_loss / (i + 1),
                        dist_sq=running_soft_dist_loss / (i + 1),
                        rate_loss=running_rate_loss / (i + 1),
                        rate=running_rate_metric / (i + 1),
                        states_rate=running_states_rate_metric / (i + 1),
                        ffloss=running_first_frame_loss / (i + 1),
                        rateloss_hard=running_hard_rate_loss / (i + 1),
                        rateloss_soft=running_soft_rate_loss / (i + 1)
                    )
                    previous_total_loss = running_total_loss

                batch += 1

        # save checkpoint: one file per epoch holding the weights, the mean
        # total loss of the epoch and the metadata collected in `checkpoint`
        checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pth')
        checkpoint['state_dict'] = model.state_dict()
        checkpoint['loss'] = running_total_loss / len(dataloader)
        checkpoint['epoch'] = epoch
        torch.save(checkpoint, checkpoint_path)