mirror of
https://github.com/xiph/opus.git
synced 2025-05-15 07:58:29 +00:00

Some checks failed
CMake / CMake/MacOSX/Lib/X64/Release (push) Has been cancelled
CMake / CMake/MacOSX/Framework/X64/Release (push) Has been cancelled
CMake / CMake/Linux/So/X64/Release (push) Has been cancelled
CMake / CMake/MacOSX/So/X64/Release (push) Has been cancelled
CMake / CMake/AssertionsFuzz/MacOSX/Lib/X64/Release (push) Has been cancelled
CMake / CMake/CustomModes/Linux/Lib/X64/Release (push) Has been cancelled
Autotools / AutoMake/Linux/GCC (push) Has been cancelled
Autotools / AutoMake/Linux/GCC/EnableDNN (push) Has been cancelled
Autotools / AutoMake/Linux/GCC/EnableCustomModes (push) Has been cancelled
Autotools / AutoMake/Linux/GCC/EnableAssertions (push) Has been cancelled
CMake / Test build with CMake 3.16.0 (push) Has been cancelled
CMake / CMake MINGW (push) Has been cancelled
CMake / CMake/Linux/Lib/X64/Release (push) Has been cancelled
CMake / CMake/Android/So/ARMv8/Release (push) Has been cancelled
CMake / CMake/Android/Lib/ARMv8/Release (push) Has been cancelled
CMake / CMake/Android/So/X86/Release (push) Has been cancelled
CMake / CMake/Android/Lib/X86/Release (push) Has been cancelled
CMake / CMake/Android/So/X64/Release (push) Has been cancelled
CMake / CMake/Android/Lib/X64/Release (push) Has been cancelled
CMake / CMake/AssertionsFuzz/Linux/Lib/X64/Release (push) Has been cancelled
CMake / CMake/iOS/Framework/arm64/Release (push) Has been cancelled
CMake / CMake/iOS/Dll/arm64/Release (push) Has been cancelled
CMake / CMake/iOS/Lib/arm64/Release (push) Has been cancelled
CMake / CMake/Windows/Dll/ARMv8/Release (push) Has been cancelled
CMake / CMake/Windows/Lib/armv8/Release (push) Has been cancelled
CMake / CMake/Windows/Dll/X64/Release (push) Has been cancelled
CMake / CMake/Windows/Dll/X86/Release (push) Has been cancelled
CMake / CMake/AssertionsFuzz/Windows/Lib/X64/Release (push) Has been cancelled
CMake / CMake/Windows/Lib/X64/Release (push) Has been cancelled
CMake / CMake/Windows/Lib/X86/Release (push) Has been cancelled
DRED / CMake/Android/Lib/ARMv8/Release (push) Has been cancelled
DRED / CMake/Android/Lib/X64/Release (push) Has been cancelled
DRED / CMake/MacOSX/Lib/X64/Release (push) Has been cancelled
DRED / CMake/Linux/Lib/X64/Release (push) Has been cancelled
DRED / CMake/iOS/Lib/arm64/Release (push) Has been cancelled
DRED / CMake/Windows/Lib/armv8/Release (push) Has been cancelled
DRED / CMake/Windows/Lib/X64/Release (push) Has been cancelled
DRED / AutoTools/Linux/Clang (push) Has been cancelled
DRED / AutoTools/Linux/GCC (push) Has been cancelled
Repository / Check trailing white spaces (push) Has been cancelled
290 lines
14 KiB
Python
290 lines
14 KiB
Python
"""
|
|
/* Copyright (c) 2022 Amazon
|
|
Written by Jan Buethe */
|
|
/*
|
|
Redistribution and use in source and binary forms, with or without
|
|
modification, are permitted provided that the following conditions
|
|
are met:
|
|
|
|
- Redistributions of source code must retain the above copyright
|
|
notice, this list of conditions and the following disclaimer.
|
|
|
|
- Redistributions in binary form must reproduce the above copyright
|
|
notice, this list of conditions and the following disclaimer in the
|
|
documentation and/or other materials provided with the distribution.
|
|
|
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
|
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
|
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
|
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
|
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
|
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
|
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
|
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
|
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
|
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
|
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
*/
|
|
"""
|
|
|
|
import os
|
|
import argparse
|
|
|
|
import torch
|
|
import tqdm
|
|
|
|
from rdovae import RDOVAE, RDOVAEDataset, distortion_loss, hard_rate_estimate, soft_rate_estimate
|
|
|
|
|
|
parser = argparse.ArgumentParser()
|
|
|
|
parser.add_argument('features', type=str, help='path to feature file in .f32 format')
|
|
parser.add_argument('output', type=str, help='path to output folder')
|
|
|
|
parser.add_argument('--cuda-visible-devices', type=str, help="comma separates list of cuda visible device indices, default: ''", default="")
|
|
|
|
|
|
model_group = parser.add_argument_group(title="model parameters")
|
|
model_group.add_argument('--latent-dim', type=int, help="number of symbols produces by encoder, default: 80", default=80)
|
|
model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256)
|
|
model_group.add_argument('--cond-size2', type=int, help="second conditioning size, default: 256", default=256)
|
|
model_group.add_argument('--state-dim', type=int, help="dimensionality of transfered state, default: 24", default=24)
|
|
model_group.add_argument('--quant-levels', type=int, help="number of quantization levels, default: 16", default=16)
|
|
model_group.add_argument('--lambda-min', type=float, help="minimal value for rate lambda, default: 0.0002", default=2e-4)
|
|
model_group.add_argument('--lambda-max', type=float, help="maximal value for rate lambda, default: 0.0104", default=0.0104)
|
|
model_group.add_argument('--pvq-num-pulses', type=int, help="number of pulses for PVQ, default: 82", default=82)
|
|
model_group.add_argument('--state-dropout-rate', type=float, help="state dropout rate, default: 0", default=0.0)
|
|
model_group.add_argument('--softquant', action="store_true", help="enables soft quantization during training")
|
|
|
|
training_group = parser.add_argument_group(title="training parameters")
|
|
training_group.add_argument('--batch-size', type=int, help="batch size, default: 32", default=32)
|
|
training_group.add_argument('--lr', type=float, help='learning rate, default: 3e-4', default=3e-4)
|
|
training_group.add_argument('--epochs', type=int, help='number of training epochs, default: 100', default=100)
|
|
training_group.add_argument('--sequence-length', type=int, help='sequence length, needs to be divisible by chunks_per_offset, default: 400', default=400)
|
|
training_group.add_argument('--chunks-per-offset', type=int, help='chunks per offset', default=4)
|
|
training_group.add_argument('--lr-decay-factor', type=float, help='learning rate decay factor, default: 2.5e-5', default=2.5e-5)
|
|
training_group.add_argument('--split-mode', type=str, choices=['split', 'random_split', 'skewed_split'], help='splitting mode for decoder input, default: split', default='split')
|
|
training_group.add_argument('--enable-first-frame-loss', action='store_true', default=False, help='enables dedicated distortion loss on first 4 decoder frames')
|
|
training_group.add_argument('--initial-checkpoint', type=str, help='initial checkpoint to start training from, default: None', default=None)
|
|
training_group.add_argument('--train-decoder-only', action='store_true', help='freeze encoder and statistical model and train decoder only')
|
|
|
|
args = parser.parse_args()
|
|
|
|
# set visible devices
|
|
os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_visible_devices
|
|
|
|
# checkpoints
|
|
checkpoint_dir = os.path.join(args.output, 'checkpoints')
|
|
checkpoint = dict()
|
|
os.makedirs(checkpoint_dir, exist_ok=True)
|
|
|
|
# training parameters
|
|
batch_size = args.batch_size
|
|
lr = args.lr
|
|
epochs = args.epochs
|
|
sequence_length = args.sequence_length
|
|
lr_decay_factor = args.lr_decay_factor
|
|
split_mode = args.split_mode
|
|
# not exposed
|
|
adam_betas = [0.8, 0.95]
|
|
adam_eps = 1e-8
|
|
|
|
checkpoint['batch_size'] = batch_size
|
|
checkpoint['lr'] = lr
|
|
checkpoint['lr_decay_factor'] = lr_decay_factor
|
|
checkpoint['split_mode'] = split_mode
|
|
checkpoint['epochs'] = epochs
|
|
checkpoint['sequence_length'] = sequence_length
|
|
checkpoint['adam_betas'] = adam_betas
|
|
|
|
# logging
|
|
log_interval = 10
|
|
|
|
# device
|
|
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
# model parameters
|
|
cond_size = args.cond_size
|
|
cond_size2 = args.cond_size2
|
|
latent_dim = args.latent_dim
|
|
quant_levels = args.quant_levels
|
|
lambda_min = args.lambda_min
|
|
lambda_max = args.lambda_max
|
|
state_dim = args.state_dim
|
|
softquant = args.softquant
|
|
# not expsed
|
|
num_features = 20
|
|
|
|
|
|
# training data
|
|
feature_file = args.features
|
|
|
|
# model
|
|
checkpoint['model_args'] = (num_features, latent_dim, quant_levels, cond_size, cond_size2)
|
|
checkpoint['model_kwargs'] = {'state_dim': state_dim, 'split_mode' : split_mode, 'pvq_num_pulses': args.pvq_num_pulses, 'state_dropout_rate': args.state_dropout_rate, 'softquant': softquant, 'chunks_per_offset': args.chunks_per_offset}
|
|
model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs'])
|
|
|
|
if type(args.initial_checkpoint) != type(None):
|
|
checkpoint = torch.load(args.initial_checkpoint, map_location='cpu')
|
|
model.load_state_dict(checkpoint['state_dict'], strict=False)
|
|
|
|
checkpoint['state_dict'] = model.state_dict()
|
|
|
|
if args.train_decoder_only:
|
|
if args.initial_checkpoint is None:
|
|
print("warning: training decoder only without providing initial checkpoint")
|
|
|
|
for p in model.core_encoder.module.parameters():
|
|
p.requires_grad = False
|
|
|
|
for p in model.statistical_model.parameters():
|
|
p.requires_grad = False
|
|
|
|
# dataloader
|
|
checkpoint['dataset_args'] = (feature_file, sequence_length, num_features, 36)
|
|
checkpoint['dataset_kwargs'] = {'lambda_min': lambda_min, 'lambda_max': lambda_max, 'enc_stride': model.enc_stride, 'quant_levels': quant_levels}
|
|
dataset = RDOVAEDataset(*checkpoint['dataset_args'], **checkpoint['dataset_kwargs'])
|
|
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)
|
|
|
|
|
|
|
|
# optimizer
|
|
params = [p for p in model.parameters() if p.requires_grad]
|
|
optimizer = torch.optim.Adam(params, lr=lr, betas=adam_betas, eps=adam_eps)
|
|
|
|
|
|
# learning rate scheduler
|
|
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay_factor * x))
|
|
|
|
if __name__ == '__main__':
|
|
|
|
# push model to device
|
|
model.to(device)
|
|
|
|
# training loop
|
|
|
|
batch = 1
|
|
for epoch in range(1, epochs + 1):
|
|
|
|
print(f"training epoch {epoch}...")
|
|
|
|
# running stats
|
|
running_rate_loss = 0
|
|
running_soft_dist_loss = 0
|
|
running_hard_dist_loss = 0
|
|
running_hard_rate_loss = 0
|
|
running_soft_rate_loss = 0
|
|
running_total_loss = 0
|
|
running_rate_metric = 0
|
|
running_states_rate_metric = 0
|
|
previous_total_loss = 0
|
|
running_first_frame_loss = 0
|
|
|
|
with tqdm.tqdm(dataloader, unit='batch') as tepoch:
|
|
for i, (features, rate_lambda, q_ids) in enumerate(tepoch):
|
|
|
|
# zero out gradients
|
|
optimizer.zero_grad()
|
|
|
|
# push inputs to device
|
|
features = features.to(device)
|
|
q_ids = q_ids.to(device)
|
|
rate_lambda = rate_lambda.to(device)
|
|
|
|
|
|
rate_lambda_upsamp = torch.repeat_interleave(rate_lambda, 2, 1)
|
|
|
|
# run model
|
|
model_output = model(features, q_ids)
|
|
|
|
# collect outputs
|
|
z = model_output['z']
|
|
states = model_output['states']
|
|
outputs_hard_quant = model_output['outputs_hard_quant']
|
|
outputs_soft_quant = model_output['outputs_soft_quant']
|
|
statistical_model = model_output['statistical_model']
|
|
|
|
if type(args.initial_checkpoint) == type(None):
|
|
latent_lambda = (1. - .5/(1.+batch/1000))
|
|
state_lambda = (1. - .9/(1.+batch/6000))
|
|
else:
|
|
latent_lambda = 1.
|
|
state_lambda = 1.
|
|
|
|
# rate loss
|
|
hard_rate = hard_rate_estimate(z, statistical_model['r_hard'][:,:,:latent_dim], statistical_model['theta_hard'][:,:,:latent_dim], reduce=False)
|
|
soft_rate = soft_rate_estimate(z, statistical_model['r_soft'][:,:,:latent_dim], reduce=False)
|
|
states_hard_rate = hard_rate_estimate(states, statistical_model['r_hard'][:,:,latent_dim:], statistical_model['theta_hard'][:,:,latent_dim:], reduce=False)
|
|
states_soft_rate = soft_rate_estimate(states, statistical_model['r_soft'][:,:,latent_dim:], reduce=False)
|
|
soft_rate_loss = torch.mean(torch.sqrt(rate_lambda) * (latent_lambda*soft_rate + .04*state_lambda*states_soft_rate))
|
|
hard_rate_loss = torch.mean(torch.sqrt(rate_lambda) * (latent_lambda*hard_rate + .04*state_lambda*states_hard_rate))
|
|
rate_loss = (soft_rate_loss + 0.1 * hard_rate_loss)
|
|
hard_rate_metric = torch.mean(hard_rate)
|
|
states_rate_metric = torch.mean(states_hard_rate)
|
|
|
|
## distortion losses
|
|
|
|
# hard quantized decoder input
|
|
distortion_loss_hard_quant = torch.zeros_like(rate_loss)
|
|
for dec_features, start, stop in outputs_hard_quant:
|
|
distortion_loss_hard_quant += distortion_loss(features[..., start : stop, :], dec_features, rate_lambda_upsamp[..., start : stop]) / len(outputs_hard_quant)
|
|
|
|
first_frame_loss = torch.zeros_like(rate_loss)
|
|
for dec_features, start, stop in outputs_hard_quant:
|
|
first_frame_loss += distortion_loss(features[..., stop-4 : stop, :], dec_features[..., -4:, :], rate_lambda_upsamp[..., stop - 4 : stop]) / len(outputs_hard_quant)
|
|
|
|
# soft quantized decoder input
|
|
distortion_loss_soft_quant = torch.zeros_like(rate_loss)
|
|
for dec_features, start, stop in outputs_soft_quant:
|
|
distortion_loss_soft_quant += distortion_loss(features[..., start : stop, :], dec_features, rate_lambda_upsamp[..., start : stop]) / len(outputs_soft_quant)
|
|
|
|
# total loss
|
|
total_loss = rate_loss + (distortion_loss_hard_quant + distortion_loss_soft_quant) / 2
|
|
|
|
if args.enable_first_frame_loss:
|
|
total_loss = .97*total_loss + 0.03 * first_frame_loss
|
|
|
|
|
|
total_loss.backward()
|
|
|
|
optimizer.step()
|
|
|
|
model.clip_weights()
|
|
model.sparsify()
|
|
|
|
scheduler.step()
|
|
|
|
# collect running stats
|
|
running_hard_dist_loss += float(distortion_loss_hard_quant.detach().cpu())
|
|
running_soft_dist_loss += float(distortion_loss_soft_quant.detach().cpu())
|
|
running_rate_loss += float(rate_loss.detach().cpu())
|
|
running_rate_metric += float(hard_rate_metric.detach().cpu())
|
|
running_states_rate_metric += float(states_rate_metric.detach().cpu())
|
|
running_total_loss += float(total_loss.detach().cpu())
|
|
running_first_frame_loss += float(first_frame_loss.detach().cpu())
|
|
running_soft_rate_loss += float(soft_rate_loss.detach().cpu())
|
|
running_hard_rate_loss += float(hard_rate_loss.detach().cpu())
|
|
|
|
if (i + 1) % log_interval == 0:
|
|
current_loss = (running_total_loss - previous_total_loss) / log_interval
|
|
tepoch.set_postfix(
|
|
current_loss=current_loss,
|
|
total_loss=running_total_loss / (i + 1),
|
|
dist_hq=running_hard_dist_loss / (i + 1),
|
|
dist_sq=running_soft_dist_loss / (i + 1),
|
|
rate_loss=running_rate_loss / (i + 1),
|
|
rate=running_rate_metric / (i + 1),
|
|
states_rate=running_states_rate_metric / (i + 1),
|
|
ffloss=running_first_frame_loss / (i + 1),
|
|
rateloss_hard=running_hard_rate_loss / (i + 1),
|
|
rateloss_soft=running_soft_rate_loss / (i + 1)
|
|
)
|
|
previous_total_loss = running_total_loss
|
|
batch = batch+1
|
|
|
|
# save checkpoint
|
|
checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pth')
|
|
checkpoint['state_dict'] = model.state_dict()
|
|
checkpoint['loss'] = running_total_loss / len(dataloader)
|
|
checkpoint['epoch'] = epoch
|
|
torch.save(checkpoint, checkpoint_path)
|