mirror of
https://github.com/xiph/opus.git
synced 2025-06-03 17:17:42 +00:00
Adapting to new data format/model
This commit is contained in:
parent
f38b4a317f
commit
733a095ba2
2 changed files with 42 additions and 22 deletions
|
@ -8,7 +8,7 @@ import numpy as np
|
|||
|
||||
class large_if_ccode(torch.nn.Module):
|
||||
|
||||
def __init__(self,input_dim = 90,gru_dim = 64,output_dim = 192):
|
||||
def __init__(self,input_dim = 88,gru_dim = 64,output_dim = 192):
|
||||
super(large_if_ccode,self).__init__()
|
||||
|
||||
self.activation = torch.nn.Tanh()
|
||||
|
@ -89,11 +89,12 @@ class large_joint(torch.nn.Module):
|
|||
1D CNN on IF, merge with xcorr, 2D CNN on merged + GRU
|
||||
"""
|
||||
|
||||
def __init__(self,input_IF_dim = 90,input_xcorr_dim = 257,gru_dim = 64,output_dim = 192):
|
||||
def __init__(self,input_IF_dim = 88,input_xcorr_dim = 224,gru_dim = 64,output_dim = 192):
|
||||
super(large_joint,self).__init__()
|
||||
|
||||
self.activation = torch.nn.Tanh()
|
||||
|
||||
print("dim=", input_IF_dim)
|
||||
self.if_upsample = torch.nn.Sequential(
|
||||
torch.nn.Linear(input_IF_dim,64),
|
||||
self.activation,
|
||||
|
@ -142,8 +143,8 @@ class large_joint(torch.nn.Module):
|
|||
)
|
||||
|
||||
def forward(self, x):
|
||||
xcorr_feat = x[:,:,:257]
|
||||
if_feat = x[:,:,257:]
|
||||
xcorr_feat = x[:,:,:224]
|
||||
if_feat = x[:,:,224:]
|
||||
# x = torch.cat([xcorr_feat.unsqueeze(-1),self.if_upsample(if_feat).unsqueeze(-1)],axis = -1)
|
||||
xcorr_feat = self.conv(xcorr_feat.unsqueeze(-1).permute(0,3,2,1)).squeeze(1).permute(0,2,1)
|
||||
if_feat = self.if_upsample(if_feat)
|
||||
|
@ -186,22 +187,29 @@ class loader(torch.utils.data.Dataset):
|
|||
return torch.from_numpy(self.if_feat[index,:,:]),torch.from_numpy(self.cents[index]),torch.from_numpy(self.confidence[index])
|
||||
|
||||
class loader_joint(torch.utils.data.Dataset):
|
||||
def __init__(self, features_if, file_pitch, features_xcorr,confidence_threshold = 0.4,context = 100, choice_data = 'both'):
|
||||
self.if_feat = np.memmap(features_if, dtype=np.float32).reshape(-1,90)
|
||||
self.xcorr = np.memmap(features_xcorr, dtype=np.float32).reshape(-1,257)
|
||||
self.cents = np.rint(np.load(file_pitch)[0,:]/20)
|
||||
def __init__(self, features, file_pitch, confidence_threshold = 0.4,context = 100, choice_data = 'both'):
|
||||
self.feat = np.memmap(features, mode='r', dtype=np.int8).reshape(-1,312)
|
||||
#Skip first first two frames for dump_data to sync with CREPE
|
||||
self.feat = self.feat[2:,:]
|
||||
|
||||
self.xcorr = self.feat[:,:224]
|
||||
self.if_feat = self.feat[:,224:]
|
||||
ground_truth = np.memmap(file_pitch, mode='r', dtype=np.float32).reshape(-1,2)
|
||||
self.cents = np.rint(60*np.log2(ground_truth[:,0]/62.5))
|
||||
mask = (self.cents>=0).astype('float32') * (self.cents<=180).astype('float32')
|
||||
self.cents = np.clip(self.cents,0,179)
|
||||
self.confidence = np.load(file_pitch)[1,:]
|
||||
self.confidence = ground_truth[:,1] * mask
|
||||
# Filter confidence for CREPE
|
||||
self.confidence[self.confidence < confidence_threshold] = 0
|
||||
self.context = context
|
||||
print(np.mean(self.confidence), np.mean(self.cents))
|
||||
|
||||
self.choice_data = choice_data
|
||||
|
||||
frame_max = self.if_feat.shape[0]//context
|
||||
self.if_feat = np.reshape(self.if_feat[:frame_max*context,:],(frame_max,context,90))
|
||||
self.if_feat = np.reshape(self.if_feat[:frame_max*context,:],(frame_max,context,88))
|
||||
self.cents = np.reshape(self.cents[:frame_max*context],(frame_max,context))
|
||||
self.xcorr = np.reshape(self.xcorr[:frame_max*context,:],(frame_max,context,257))
|
||||
self.xcorr = np.reshape(self.xcorr[:frame_max*context,:],(frame_max,context,224))
|
||||
# self.cents = np.rint(60*np.log2(256/(self.periods + 1.0e-8))).astype('int')
|
||||
# self.cents = np.clip(self.cents,0,239)
|
||||
self.confidence = np.reshape(self.confidence[:frame_max*context],(frame_max,context))
|
||||
|
@ -211,8 +219,8 @@ class loader_joint(torch.utils.data.Dataset):
|
|||
|
||||
def __getitem__(self, index):
|
||||
if self.choice_data == 'both':
|
||||
return torch.cat([torch.from_numpy(self.xcorr[index,:,:]),torch.from_numpy(self.if_feat[index,:,:])],dim=-1),torch.from_numpy(self.cents[index]),torch.from_numpy(self.confidence[index])
|
||||
return torch.cat([torch.from_numpy((1./127)*self.xcorr[index,:,:]),torch.from_numpy((1./127)*self.if_feat[index,:,:])],dim=-1),torch.from_numpy(self.cents[index]),torch.from_numpy(self.confidence[index])
|
||||
elif self.choice_data == 'if':
|
||||
return torch.from_numpy(self.if_feat[index,:,:]),torch.from_numpy(self.cents[index]),torch.from_numpy(self.confidence[index])
|
||||
return torch.from_numpy((1./127)*self.if_feat[index,:,:]),torch.from_numpy(self.cents[index]),torch.from_numpy(self.confidence[index])
|
||||
else:
|
||||
return torch.from_numpy(self.xcorr[index,:,:]),torch.from_numpy(self.cents[index]),torch.from_numpy(self.confidence[index])
|
||||
return torch.from_numpy((1./127)*self.xcorr[index,:,:]),torch.from_numpy(self.cents[index]),torch.from_numpy(self.confidence[index])
|
||||
|
|
|
@ -6,8 +6,7 @@ Training the neural pitch estimator
|
|||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument('features_if', type=str, help='.f32 IF Features for training (generated by augmentation script)')
|
||||
parser.add_argument('features_xcorr', type=str, help='.f32 Xcorr Features for training (generated by augmentation script)')
|
||||
parser.add_argument('features', type=str, help='.f32 IF Features for training (generated by augmentation script)')
|
||||
parser.add_argument('features_pitch', type=str, help='.npy Pitch file for training (generated by augmentation script)')
|
||||
parser.add_argument('output_folder', type=str, help='Output directory to store the model weights and config')
|
||||
parser.add_argument('data_format', type=str, help='Choice of Input Data',choices=['if','xcorr','both'])
|
||||
|
@ -45,7 +44,7 @@ from utils import count_parameters
|
|||
import tqdm
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from evaluation import rpa
|
||||
#from evaluation import rpa
|
||||
|
||||
# print(list(range(torch.cuda.device_count())))
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
@ -60,9 +59,9 @@ elif args.data_format == 'xcorr':
|
|||
pitch_nn = model(args.xcorr_dimension,args.gru_dim,args.output_dim)
|
||||
else:
|
||||
from models import large_joint as model
|
||||
pitch_nn = model(args.freq_keep*3,args.xcorr_dimension,args.gru_dim,args.output_dim)
|
||||
pitch_nn = model(88,224,args.gru_dim,args.output_dim)
|
||||
|
||||
dataset_training = loader(args.features_if,args.features_pitch,args.features_xcorr,args.confidence_threshold,args.context,args.data_format)
|
||||
dataset_training = loader(args.features,args.features_pitch,args.confidence_threshold,args.context,args.data_format)
|
||||
|
||||
def loss_custom(logits,labels,confidence,choice = 'default',nmax = 192,q = 0.7):
|
||||
logits_softmax = torch.nn.Softmax(dim = 1)(logits).permute(0,2,1)
|
||||
|
@ -71,7 +70,7 @@ def loss_custom(logits,labels,confidence,choice = 'default',nmax = 192,q = 0.7):
|
|||
if choice == 'default':
|
||||
# Categorical Cross Entropy
|
||||
CE = -torch.sum(torch.log(logits_softmax*labels_one_hot + 1.0e-6)*labels_one_hot,dim=-1)
|
||||
CE = torch.sum(confidence*CE)
|
||||
CE = torch.mean(confidence*CE)
|
||||
|
||||
else:
|
||||
# Robust Cross Entropy
|
||||
|
@ -80,6 +79,14 @@ def loss_custom(logits,labels,confidence,choice = 'default',nmax = 192,q = 0.7):
|
|||
|
||||
return CE
|
||||
|
||||
def accuracy(logits,labels,confidence,choice = 'default',nmax = 192,q = 0.7):
|
||||
logits_softmax = torch.nn.Softmax(dim = 1)(logits).permute(0,2,1)
|
||||
pred_pitch = torch.argmax(logits_softmax, 2)
|
||||
#print(pred_pitch.shape, labels.long().shape)
|
||||
accuracy = (pred_pitch != labels.long())*1.
|
||||
#print(accuracy.shape, confidence.shape)
|
||||
return 1.-torch.mean(confidence*accuracy)
|
||||
|
||||
# features = args.features
|
||||
# pitch = args.crepe_pitch
|
||||
# dataset_training = loader(features,pitch,args.confidence_threshold,args.freq_keep,args.context)
|
||||
|
@ -101,20 +108,25 @@ num_epochs = args.epochs
|
|||
|
||||
for epoch in range(num_epochs):
|
||||
losses = []
|
||||
accs = []
|
||||
pitch_nn.train()
|
||||
with tqdm.tqdm(train_dataloader) as train_epoch:
|
||||
for i, (xi, yi, ci) in enumerate(train_epoch):
|
||||
yi, xi, ci = yi.to(device, non_blocking=True), xi.to(device, non_blocking=True), ci.to(device, non_blocking=True)
|
||||
pi = pitch_nn(xi.float())
|
||||
loss = loss_custom(logits = pi,labels = yi,confidence = ci,choice = args.choice_cel,nmax = args.output_dim)
|
||||
acc = accuracy(logits = pi,labels = yi,confidence = ci,choice = args.choice_cel,nmax = args.output_dim)
|
||||
acc = acc.detach()
|
||||
|
||||
model_opt.zero_grad()
|
||||
loss.backward()
|
||||
model_opt.step()
|
||||
|
||||
losses.append(loss.item())
|
||||
accs.append(acc.item())
|
||||
avg_loss = np.mean(losses)
|
||||
train_epoch.set_postfix({"Train Epoch" : epoch, "Train Loss":avg_loss})
|
||||
avg_acc = np.mean(accs)
|
||||
train_epoch.set_postfix({"Train Epoch" : epoch, "Train Loss":avg_loss, "acc" : avg_acc.item()})
|
||||
|
||||
if epoch % 5 == 0:
|
||||
pitch_nn.eval()
|
||||
|
@ -129,7 +141,7 @@ for epoch in range(num_epochs):
|
|||
test_epoch.set_postfix({"Epoch" : epoch, "Test Loss":avg_loss})
|
||||
|
||||
pitch_nn.eval()
|
||||
rpa(pitch_nn,device,data_format = args.data_format)
|
||||
#rpa(pitch_nn,device,data_format = args.data_format)
|
||||
|
||||
config = dict(
|
||||
data_format = args.data_format,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue