PyTorch code for training the PLC model
Should match the TF2 code, but mostly untested
parent 6ad03ae03e
commit 26ddfd7135
3 changed files with 345 additions and 0 deletions
144  dnn/torch/plc/plc.py  Normal file

@@ -0,0 +1,144 @@
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
import math

fid_dict = {}
def dump_signal(x, filename):
    return   # debug dump disabled; remove this early return to enable it
    if filename in fid_dict:
        fid = fid_dict[filename]
    else:
        fid = open(filename, "w")
        fid_dict[filename] = fid
    x = x.detach().numpy().astype('float32')
    x.tofile(fid)


class IDCT(nn.Module):
    def __init__(self, N, device=None):
        super(IDCT, self).__init__()

        self.N = N
        n = torch.arange(N, device=device)
        k = torch.arange(N, device=device)
        self.table = torch.cos(torch.pi/N * (n[:,None]+.5) * k[None,:])
        self.table[:,0] = self.table[:,0] * math.sqrt(.5)
        self.table = self.table / math.sqrt(N/2)

    def forward(self, x):
        return F.linear(x, self.table, None)

def plc_loss(N, device=None, alpha=1.0, bias=1.):
    idct = IDCT(18, device=device)   # note: band size is hard-coded to 18; the N argument is otherwise unused
    def loss(y_true,y_pred):
        mask = y_true[:,:,-1:]
        y_true = y_true[:,:,:-1]
        e = (y_pred - y_true)*mask
        e_bands = idct(e[:,:,:-2])
        bias_mask = torch.clamp(4*y_true[:,:,-1:], min=0., max=1.)
        l1_loss = torch.mean(torch.abs(e))
        ceps_loss = torch.mean(torch.abs(e[:,:,:-2]))
        band_loss = torch.mean(torch.abs(e_bands))
        biased_loss = torch.mean(bias_mask*torch.clamp(e_bands, min=0.))
        pitch_loss1 = torch.mean(torch.clamp(torch.abs(e[:,:,18:19]),max=1.))
        pitch_loss = torch.mean(torch.clamp(torch.abs(e[:,:,18:19]),max=.4))
        voice_bias = torch.mean(torch.clamp(-e[:,:,-1:], min=0.))
        tot = l1_loss + 0.1*voice_bias + alpha*(band_loss + bias*biased_loss) + pitch_loss1 + 8*pitch_loss
        return tot, l1_loss, ceps_loss, band_loss, pitch_loss
    return loss


# weight initialization and clipping
def init_weights(module):
    if isinstance(module, nn.GRU):
        for p in module.named_parameters():
            if p[0].startswith('weight_hh_'):
                nn.init.orthogonal_(p[1])


class GLU(nn.Module):
    def __init__(self, feat_size):
        super(GLU, self).__init__()

        torch.manual_seed(5)

        self.gate = weight_norm(nn.Linear(feat_size, feat_size, bias=False))

        self.init_weights()

    def init_weights(self):

        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d)\
            or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x):

        out = x * torch.sigmoid(self.gate(x))

        return out

class FWConv(nn.Module):
    def __init__(self, in_size, out_size, kernel_size=2):
        super(FWConv, self).__init__()

        torch.manual_seed(5)

        self.in_size = in_size
        self.kernel_size = kernel_size
        self.conv = weight_norm(nn.Linear(in_size*self.kernel_size, out_size, bias=False))
        self.glu = GLU(out_size)

        self.init_weights()

    def init_weights(self):

        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d)\
            or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x, state):
        xcat = torch.cat((state, x), -1)
        out = self.glu(torch.tanh(self.conv(xcat)))
        return out, xcat[:,self.in_size:]

def n(x):
    # add small dither noise and clamp to [-1, 1]
    return torch.clamp(x + (1./127.)*(torch.rand_like(x)-.5), min=-1., max=1.)

class PLC(nn.Module):
    def __init__(self, features_in=57, features_out=20, cond_size=128, gru_size=128):
        super(PLC, self).__init__()

        self.features_in = features_in
        self.features_out = features_out
        self.cond_size = cond_size
        self.gru_size = gru_size

        self.dense_in = nn.Linear(self.features_in, self.cond_size)
        self.gru1 = nn.GRU(self.cond_size, self.gru_size, batch_first=True)
        self.gru2 = nn.GRU(self.gru_size, self.gru_size, batch_first=True)
        self.dense_out = nn.Linear(self.gru_size, features_out)

        self.apply(init_weights)
        nb_params = sum(p.numel() for p in self.parameters())
        print(f"plc model: {nb_params} weights")

    def forward(self, features, lost, states=None):
        device = features.device
        batch_size = features.size(0)
        if states is None:
            gru1_state = torch.zeros((1, batch_size, self.gru_size), device=device)
            gru2_state = torch.zeros((1, batch_size, self.gru_size), device=device)
        else:
            gru1_state = states[0]
            gru2_state = states[1]
        x = torch.cat([features, lost], dim=-1)
        x = torch.tanh(self.dense_in(x))
        gru1_out, gru1_state = self.gru1(x, gru1_state)
        gru2_out, gru2_state = self.gru2(gru1_out, gru2_state)
        return self.dense_out(gru2_out), [gru1_state, gru2_state]
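For reference, a minimal smoke test of the model and loss above could look like the sketch below; the shapes are assumptions based on the defaults (56 input features plus a one-column loss flag, 20 outputs, 21-column targets whose last column is the loss mask), and the random tensors are stand-ins for real data:

import torch
import plc

model = plc.PLC()                      # 57 inputs = 56 features + 1 loss flag, 20 outputs
loss_fn = plc.plc_loss(18)             # band loss uses an 18-point IDCT

B, T = 4, 100
features = torch.randn(B, T, 56)       # 36 Burg features followed by 20 regular features
lost = torch.randint(0, 2, (B, T, 1)).float()
target = torch.randn(B, T, 21)         # 20 target features + 1 mask column

out, states = model(features, lost)    # out: (B, T, 20), states: [gru1_state, gru2_state]
tot, l1, ceps, band, pitch = loss_fn(target, out)
print(out.shape, float(tot))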
56  dnn/torch/plc/plc_dataset.py  Normal file

@@ -0,0 +1,56 @@
import torch
import numpy as np

class PLCDataset(torch.utils.data.Dataset):
    def __init__(self,
                feature_file,
                loss_file,
                sequence_length=1000,
                nb_features=20,
                nb_burg_features=36,
                lpc_order=16):

        self.features_in = nb_features + nb_burg_features
        self.nb_burg_features = nb_burg_features
        total_features = self.features_in + lpc_order
        self.sequence_length = sequence_length
        self.nb_features = nb_features

        self.features = np.memmap(feature_file, dtype='float32', mode='r')
        self.lost = np.memmap(loss_file, dtype='int8', mode='r')
        self.lost = self.lost.astype('float32')

        self.nb_sequences = self.features.shape[0]//self.sequence_length//total_features

        self.features = self.features[:self.nb_sequences*self.sequence_length*total_features]
        self.features = self.features.reshape((self.nb_sequences, self.sequence_length, total_features))
        self.features = self.features[:,:,:self.features_in]

        #self.lost = self.lost[:(len(self.lost)//features.shape[1]-1)*features.shape[1]]
        #self.lost = self.lost.reshape((-1, self.sequence_length))

    def __len__(self):
        return self.nb_sequences

    def __getitem__(self, index):
        features = self.features[index, :, :]
        burg_lost = (np.random.rand(features.shape[0]) > .1).astype('float32')
        burg_lost = np.reshape(burg_lost, (features.shape[0], 1))
        burg_mask = np.tile(burg_lost, (1,self.nb_burg_features))

        lost_offset = np.random.randint(0, high=self.lost.shape[0]-self.sequence_length)
        lost = self.lost[lost_offset:lost_offset+self.sequence_length]
        lost = np.reshape(lost, (features.shape[0], 1))
        lost_mask = np.tile(lost, (1,features.shape[-1]))
        in_features = features*lost_mask
        in_features[:,:self.nb_burg_features] = in_features[:,:self.nb_burg_features]*burg_mask

        #For the first frame after a loss, we don't have valid features, but the Burg estimate is valid.
        #in_features[:,1:,self.nb_burg_features:] = in_features[:,1:,self.nb_burg_features:]*lost_mask[:,:-1,self.nb_burg_features:]
        out_lost = np.copy(lost)
        #out_lost[:,1:,:] = out_lost[:,1:,:]*out_lost[:,:-1,:]

        out_features = np.concatenate([features[:,self.nb_burg_features:], 1.-out_lost], axis=-1)
        burg_sign = 2*burg_lost - 1
        # last dim is 1 for received packet, 0 for lost packet, and -1 when just the Burg info is missing
        return in_features*lost_mask, lost*burg_sign, out_features
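As a rough illustration of what __getitem__ returns (the file paths below are placeholders, not files from this commit), one item can be pulled directly:

from plc_dataset import PLCDataset

dataset = PLCDataset('features.f32', 'loss.s8', sequence_length=1000)
in_features, lost, out_features = dataset[0]
# in_features:  (1000, 56) features with lost frames (and randomly dropped Burg features) zeroed
# lost:         (1000, 1) flag, 1 = received, 0 = lost, -1 = received but Burg info dropped
# out_features: (1000, 21), the 20 target features plus a final 1.-lost column used as the loss mask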
145  dnn/torch/plc/train_plc.py  Normal file

@@ -0,0 +1,145 @@
import os
import argparse
import random
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
import tqdm

import plc
from plc_dataset import PLCDataset

parser = argparse.ArgumentParser()

parser.add_argument('features', type=str, help='path to feature file in .f32 format')
parser.add_argument('loss', type=str, help='path to packet loss file in .s8 format')
parser.add_argument('output', type=str, help='path to output folder')

parser.add_argument('--suffix', type=str, help="model name suffix", default="")
parser.add_argument('--cuda-visible-devices', type=str, help="comma-separated list of cuda visible device indices, default: CUDA_VISIBLE_DEVICES", default=None)


model_group = parser.add_argument_group(title="model parameters")
model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 128", default=128)
model_group.add_argument('--gru-size', type=int, help="GRU size, default: 128", default=128)

training_group = parser.add_argument_group(title="training parameters")
training_group.add_argument('--batch-size', type=int, help="batch size, default: 512", default=512)
training_group.add_argument('--lr', type=float, help='learning rate, default: 1e-3', default=1e-3)
training_group.add_argument('--epochs', type=int, help='number of training epochs, default: 20', default=20)
training_group.add_argument('--sequence-length', type=int, help='sequence length, default: 15', default=15)
training_group.add_argument('--lr-decay', type=float, help='learning rate decay factor, default: 1e-4', default=1e-4)
training_group.add_argument('--initial-checkpoint', type=str, help='initial checkpoint to start training from, default: None', default=None)

args = parser.parse_args()
if args.cuda_visible_devices is not None:
    os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_visible_devices

# checkpoints
checkpoint_dir = os.path.join(args.output, 'checkpoints')
checkpoint = dict()
os.makedirs(checkpoint_dir, exist_ok=True)


# training parameters
batch_size = args.batch_size
lr = args.lr
epochs = args.epochs
sequence_length = args.sequence_length
lr_decay = args.lr_decay

adam_betas = [0.8, 0.95]
adam_eps = 1e-8
features_file = args.features
loss_file = args.loss

# model parameters
cond_size = args.cond_size


checkpoint['batch_size'] = batch_size
checkpoint['lr'] = lr
checkpoint['lr_decay'] = lr_decay
checkpoint['epochs'] = epochs
checkpoint['sequence_length'] = sequence_length
checkpoint['adam_betas'] = adam_betas


device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

checkpoint['model_args'] = ()
checkpoint['model_kwargs'] = {'cond_size': cond_size, 'gru_size': args.gru_size}
print(checkpoint['model_kwargs'])
model = plc.PLC(*checkpoint['model_args'], **checkpoint['model_kwargs'])

if args.initial_checkpoint is not None:
    # note: this replaces the checkpoint dict built above with the loaded one
    checkpoint = torch.load(args.initial_checkpoint, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'], strict=False)

checkpoint['state_dict'] = model.state_dict()


dataset = PLCDataset(features_file, loss_file, sequence_length=sequence_length)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)


optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=adam_betas, eps=adam_eps)


# learning rate scheduler
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay * x))

states = None

plc_loss = plc.plc_loss(18, device=device)
if __name__ == '__main__':
    model.to(device)

    for epoch in range(1, epochs + 1):

        running_loss = 0
        running_l1_loss = 0
        running_ceps_loss = 0
        running_band_loss = 0
        running_pitch_loss = 0

        print(f"training epoch {epoch}...")
        with tqdm.tqdm(dataloader, unit='batch') as tepoch:
            for i, (features, lost, target) in enumerate(tepoch):
                optimizer.zero_grad()
                features = features.to(device)
                lost = lost.to(device)
                target = target.to(device)

                out, states = model(features, lost)

                loss, l1_loss, ceps_loss, band_loss, pitch_loss = plc_loss(target, out)

                loss.backward()
                optimizer.step()

                #model.clip_weights()

                scheduler.step()

                running_loss += loss.detach().cpu().item()
                running_l1_loss += l1_loss.detach().cpu().item()
                running_ceps_loss += ceps_loss.detach().cpu().item()
                running_band_loss += band_loss.detach().cpu().item()
                running_pitch_loss += pitch_loss.detach().cpu().item()
                tepoch.set_postfix(loss=f"{running_loss/(i+1):8.5f}",
                                   l1_loss=f"{running_l1_loss/(i+1):8.5f}",
                                   ceps_loss=f"{running_ceps_loss/(i+1):8.5f}",
                                   band_loss=f"{running_band_loss/(i+1):8.5f}",
                                   pitch_loss=f"{running_pitch_loss/(i+1):8.5f}",
                                   )

        # save checkpoint
        checkpoint_path = os.path.join(checkpoint_dir, f'plc{args.suffix}_{epoch}.pth')
        checkpoint['state_dict'] = model.state_dict()
        checkpoint['loss'] = running_loss / len(dataloader)
        checkpoint['epoch'] = epoch
        torch.save(checkpoint, checkpoint_path)
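Reloading a saved checkpoint for inference could then look roughly like the sketch below (the path is an example; use whatever file train_plc.py wrote under output/checkpoints):

import torch
import plc

ckpt = torch.load('output/checkpoints/plc_20.pth', map_location='cpu')
model = plc.PLC(*ckpt['model_args'], **ckpt['model_kwargs'])
model.load_state_dict(ckpt['state_dict'])
model.eval()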