mirror of
https://github.com/xiph/opus.git
synced 2025-06-02 00:27:43 +00:00
Adapting to new data format/model
This commit is contained in:
parent
f38b4a317f
commit
733a095ba2
2 changed files with 42 additions and 22 deletions
|
@ -6,8 +6,7 @@ Training the neural pitch estimator
|
|||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument('features_if', type=str, help='.f32 IF Features for training (generated by augmentation script)')
|
||||
parser.add_argument('features_xcorr', type=str, help='.f32 Xcorr Features for training (generated by augmentation script)')
|
||||
parser.add_argument('features', type=str, help='.f32 IF Features for training (generated by augmentation script)')
|
||||
parser.add_argument('features_pitch', type=str, help='.npy Pitch file for training (generated by augmentation script)')
|
||||
parser.add_argument('output_folder', type=str, help='Output directory to store the model weights and config')
|
||||
parser.add_argument('data_format', type=str, help='Choice of Input Data',choices=['if','xcorr','both'])
|
||||
|
@ -45,7 +44,7 @@ from utils import count_parameters
|
|||
import tqdm
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from evaluation import rpa
|
||||
#from evaluation import rpa
|
||||
|
||||
# print(list(range(torch.cuda.device_count())))
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
@ -60,9 +59,9 @@ elif args.data_format == 'xcorr':
|
|||
pitch_nn = model(args.xcorr_dimension,args.gru_dim,args.output_dim)
|
||||
else:
|
||||
from models import large_joint as model
|
||||
pitch_nn = model(args.freq_keep*3,args.xcorr_dimension,args.gru_dim,args.output_dim)
|
||||
pitch_nn = model(88,224,args.gru_dim,args.output_dim)
|
||||
|
||||
dataset_training = loader(args.features_if,args.features_pitch,args.features_xcorr,args.confidence_threshold,args.context,args.data_format)
|
||||
dataset_training = loader(args.features,args.features_pitch,args.confidence_threshold,args.context,args.data_format)
|
||||
|
||||
def loss_custom(logits,labels,confidence,choice = 'default',nmax = 192,q = 0.7):
|
||||
logits_softmax = torch.nn.Softmax(dim = 1)(logits).permute(0,2,1)
|
||||
|
@ -71,7 +70,7 @@ def loss_custom(logits,labels,confidence,choice = 'default',nmax = 192,q = 0.7):
|
|||
if choice == 'default':
|
||||
# Categorical Cross Entropy
|
||||
CE = -torch.sum(torch.log(logits_softmax*labels_one_hot + 1.0e-6)*labels_one_hot,dim=-1)
|
||||
CE = torch.sum(confidence*CE)
|
||||
CE = torch.mean(confidence*CE)
|
||||
|
||||
else:
|
||||
# Robust Cross Entropy
|
||||
|
@ -80,6 +79,14 @@ def loss_custom(logits,labels,confidence,choice = 'default',nmax = 192,q = 0.7):
|
|||
|
||||
return CE
|
||||
|
||||
def accuracy(logits,labels,confidence,choice = 'default',nmax = 192,q = 0.7):
|
||||
logits_softmax = torch.nn.Softmax(dim = 1)(logits).permute(0,2,1)
|
||||
pred_pitch = torch.argmax(logits_softmax, 2)
|
||||
#print(pred_pitch.shape, labels.long().shape)
|
||||
accuracy = (pred_pitch != labels.long())*1.
|
||||
#print(accuracy.shape, confidence.shape)
|
||||
return 1.-torch.mean(confidence*accuracy)
|
||||
|
||||
# features = args.features
|
||||
# pitch = args.crepe_pitch
|
||||
# dataset_training = loader(features,pitch,args.confidence_threshold,args.freq_keep,args.context)
|
||||
|
@ -101,20 +108,25 @@ num_epochs = args.epochs
|
|||
|
||||
for epoch in range(num_epochs):
|
||||
losses = []
|
||||
accs = []
|
||||
pitch_nn.train()
|
||||
with tqdm.tqdm(train_dataloader) as train_epoch:
|
||||
for i, (xi, yi, ci) in enumerate(train_epoch):
|
||||
yi, xi, ci = yi.to(device, non_blocking=True), xi.to(device, non_blocking=True), ci.to(device, non_blocking=True)
|
||||
pi = pitch_nn(xi.float())
|
||||
loss = loss_custom(logits = pi,labels = yi,confidence = ci,choice = args.choice_cel,nmax = args.output_dim)
|
||||
acc = accuracy(logits = pi,labels = yi,confidence = ci,choice = args.choice_cel,nmax = args.output_dim)
|
||||
acc = acc.detach()
|
||||
|
||||
model_opt.zero_grad()
|
||||
loss.backward()
|
||||
model_opt.step()
|
||||
|
||||
losses.append(loss.item())
|
||||
accs.append(acc.item())
|
||||
avg_loss = np.mean(losses)
|
||||
train_epoch.set_postfix({"Train Epoch" : epoch, "Train Loss":avg_loss})
|
||||
avg_acc = np.mean(accs)
|
||||
train_epoch.set_postfix({"Train Epoch" : epoch, "Train Loss":avg_loss, "acc" : avg_acc.item()})
|
||||
|
||||
if epoch % 5 == 0:
|
||||
pitch_nn.eval()
|
||||
|
@ -129,7 +141,7 @@ for epoch in range(num_epochs):
|
|||
test_epoch.set_postfix({"Epoch" : epoch, "Test Loss":avg_loss})
|
||||
|
||||
pitch_nn.eval()
|
||||
rpa(pitch_nn,device,data_format = args.data_format)
|
||||
#rpa(pitch_nn,device,data_format = args.data_format)
|
||||
|
||||
config = dict(
|
||||
data_format = args.data_format,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue