mirror of
https://github.com/xiph/opus.git
synced 2025-05-28 14:19:13 +00:00
refactoring and cleanup
This commit is contained in:
parent
4901445490
commit
ce28695844
4 changed files with 75 additions and 173 deletions
|
@ -37,33 +37,25 @@ import time
|
|||
np_seed = int(time.time())
|
||||
torch_seed = int(time.time())
|
||||
|
||||
import json
|
||||
import torch
|
||||
torch.manual_seed(torch_seed)
|
||||
import numpy as np
|
||||
np.random.seed(np_seed)
|
||||
from utils import count_parameters
|
||||
import tqdm
|
||||
import sys
|
||||
from datetime import datetime
|
||||
#from evaluation import rpa
|
||||
from models import PitchDNN, PitchDNNIF, PitchDNNXcorr, PitchDNNDataloader
|
||||
|
||||
# print(list(range(torch.cuda.device_count())))
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
# device = 'cpu'
|
||||
|
||||
from models import loader_joint as loader
|
||||
|
||||
if args.data_format == 'if':
|
||||
from models import large_if_ccode as model
|
||||
pitch_nn = model(args.freq_keep*3,args.gru_dim,args.output_dim)
|
||||
pitch_nn = PitchDNNIF(3 * args.freq_keep - 2, args.gru_dim, args.output_dim)
|
||||
elif args.data_format == 'xcorr':
|
||||
from models import large_xcorr as model
|
||||
pitch_nn = model(args.xcorr_dimension,args.gru_dim,args.output_dim)
|
||||
pitch_nn = PitchDNNXcorr(args.xcorr_dimension, args.gru_dim, args.output_dim)
|
||||
else:
|
||||
from models import large_joint as model
|
||||
pitch_nn = model(88,224,args.gru_dim,args.output_dim)
|
||||
pitch_nn = PitchDNN(3 * args.freq_keep - 2, 224, args.gru_dim, args.output_dim)
|
||||
|
||||
dataset_training = loader(args.features,args.features_pitch,args.confidence_threshold,args.context,args.data_format)
|
||||
dataset_training = PitchDNNDataloader(args.features,args.features_pitch,args.confidence_threshold,args.context,args.data_format)
|
||||
|
||||
def loss_custom(logits,labels,confidence,choice = 'default',nmax = 192,q = 0.7):
|
||||
logits_softmax = torch.nn.Softmax(dim = 1)(logits).permute(0,2,1)
|
||||
|
@ -84,23 +76,15 @@ def loss_custom(logits,labels,confidence,choice = 'default',nmax = 192,q = 0.7):
|
|||
def accuracy(logits, labels, confidence, choice='default', nmax=192, q=0.7):
    """Return one minus the mean confidence-weighted mismatch indicator.

    The predicted class per position is the argmax over dim 1 of ``logits``
    (softmax is applied over that dim first, then the class axis is moved
    last). A position counts as a mismatch when the prediction differs from
    ``labels`` cast to integer.

    Args:
        logits: 3-D tensor with the class scores on dim 1
            (assumed (batch, classes, frames) -- TODO confirm against caller).
        labels: target class indices; must broadcast against the
            (dim 0, dim 2) prediction grid. Cast with ``.long()`` before
            comparison.
        confidence: per-position weights multiplied into the mismatch
            indicator; must broadcast against the same grid.
        choice, nmax, q: accepted for signature parity with ``loss_custom``;
            not used by this function.

    Returns:
        0-dim tensor equal to ``1 - mean(confidence * mismatch)``.
        NOTE(review): the mean is taken over all positions and is not
        normalised by the sum of ``confidence``, so this equals plain
        accuracy only when every weight is 1 -- confirm this is intended.
    """
    logits_softmax = torch.nn.Softmax(dim=1)(logits).permute(0, 2, 1)
    pred_pitch = torch.argmax(logits_softmax, 2)
    # 1.0 where the prediction is wrong, 0.0 where it is right. Named
    # `mismatch` so the local no longer shadows the function name.
    mismatch = (pred_pitch != labels.long()) * 1.
    return 1. - torch.mean(confidence * mismatch)
|
||||
|
||||
# features = args.features
|
||||
# pitch = args.crepe_pitch
|
||||
# dataset_training = loader(features,pitch,args.confidence_threshold,args.freq_keep,args.context)
|
||||
# dataset_training = loader(features,pitch,'../../../../testing/testing_features_10pct_xcorr.f32')
|
||||
|
||||
train_dataset, test_dataset = torch.utils.data.random_split(dataset_training, [0.95,0.05],generator=torch.Generator().manual_seed(torch_seed))
|
||||
train_dataset, test_dataset = torch.utils.data.random_split(dataset_training, [0.95,0.05], generator=torch.Generator().manual_seed(torch_seed))
|
||||
|
||||
batch_size = 256
|
||||
train_dataloader = torch.utils.data.DataLoader(dataset = train_dataset,batch_size = batch_size,shuffle = True,num_workers = 0, pin_memory = False)
|
||||
test_dataloader = torch.utils.data.DataLoader(dataset = test_dataset,batch_size = batch_size,shuffle = True,num_workers = 0, pin_memory = False)
|
||||
train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=False)
|
||||
test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=False)
|
||||
|
||||
# pitch_nn = model(args.freq_keep*3,args.gru_dim,args.output_dim).to(device)
|
||||
pitch_nn = pitch_nn.to(device)
|
||||
num_params = count_parameters(pitch_nn)
|
||||
learning_rate = args.learning_rate
|
||||
|
@ -143,26 +127,25 @@ for epoch in range(num_epochs):
|
|||
test_epoch.set_postfix({"Epoch" : epoch, "Test Loss":avg_loss})
|
||||
|
||||
pitch_nn.eval()
|
||||
#rpa(pitch_nn,device,data_format = args.data_format)
|
||||
|
||||
config = dict(
|
||||
data_format = args.data_format,
|
||||
epochs = num_epochs,
|
||||
window_size = args.N,
|
||||
hop_factor = args.H,
|
||||
freq_keep = args.freq_keep,
|
||||
batch_size = batch_size,
|
||||
learning_rate = learning_rate,
|
||||
confidence_threshold = args.confidence_threshold,
|
||||
model_parameters = num_params,
|
||||
np_seed = np_seed,
|
||||
torch_seed = torch_seed,
|
||||
xcorr_dim = args.xcorr_dimension,
|
||||
dim_input = 3*args.freq_keep,
|
||||
gru_dim = args.gru_dim,
|
||||
output_dim = args.output_dim,
|
||||
choice_cel = args.choice_cel,
|
||||
context = args.context,
|
||||
data_format=args.data_format,
|
||||
epochs=num_epochs,
|
||||
window_size= args.N,
|
||||
hop_factor= args.H,
|
||||
freq_keep=args.freq_keep,
|
||||
batch_size=batch_size,
|
||||
learning_rate=learning_rate,
|
||||
confidence_threshold=args.confidence_threshold,
|
||||
model_parameters=num_params,
|
||||
np_seed=np_seed,
|
||||
torch_seed=torch_seed,
|
||||
xcorr_dim=args.xcorr_dimension,
|
||||
dim_input=3*args.freq_keep - 2,
|
||||
gru_dim=args.gru_dim,
|
||||
output_dim=args.output_dim,
|
||||
choice_cel=args.choice_cel,
|
||||
context=args.context,
|
||||
)
|
||||
|
||||
model_save_path = os.path.join(args.output, f"{args.prefix}_{args.data_format}.pth")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue