Switching to neural pitch estimator

Remove old pitch estimator and retrain all models
This commit is contained in:
Jean-Marc Valin 2023-10-02 02:23:41 -04:00
parent da7f4c6c99
commit f0ec990dba
No known key found for this signature in database
GPG key ID: 531A52533318F00A
13 changed files with 103 additions and 137 deletions

View file

@ -24,6 +24,7 @@ parser.add_argument('--learning_rate', type=float, help='Learning Rate',default
# Training-length, loss-selection, export and warm-start options.
parser.add_argument(
    '--epochs',
    type=int,
    help='Number of training epochs',
    default=50,
    required=False,
)
parser.add_argument(
    '--choice_cel',
    type=str,
    help='Choice of Cross Entropy Loss (default or robust)',
    choices=['default', 'robust'],
    default='default',
    required=False,
)
parser.add_argument(
    '--prefix',
    type=str,
    help="prefix for model export, default: model",
    default='model',
)
# argparse maps the hyphenated flag to the attribute `args.initial_checkpoint`.
parser.add_argument(
    '--initial-checkpoint',
    type=str,
    help='initial checkpoint to start training from, default: None',
    default=None,
)

# Parse the command line once all options have been registered.
args = parser.parse_args()
@ -55,6 +56,11 @@ elif args.data_format == 'xcorr':
else:
pitch_nn = PitchDNN(3 * args.freq_keep - 2, 224, args.gru_dim, args.output_dim)
# Optionally warm-start the model from a checkpoint given via --initial-checkpoint.
if args.initial_checkpoint is not None:
    # NOTE(review): torch.load unpickles the file, so only pass checkpoints
    # from a trusted source. map_location='cpu' remaps stored tensors to CPU.
    checkpoint = torch.load(args.initial_checkpoint, map_location='cpu')
    # strict=False: missing or unexpected keys in the checkpoint's
    # 'state_dict' entry are tolerated instead of raising an error.
    pitch_nn.load_state_dict(checkpoint['state_dict'], strict=False)

# Build the training dataset from the CLI-provided feature inputs and settings.
dataset_training = PitchDNNDataloader(
    args.features,
    args.features_pitch,
    args.confidence_threshold,
    args.context,
    args.data_format,
)
def loss_custom(logits,labels,confidence,choice = 'default',nmax = 192,q = 0.7):

View file

@ -159,7 +159,7 @@ def distortion_loss(y_true, y_pred, rate_lambda=None):
raise ValueError('distortion loss is designed to work with 20 features')
ceps_error = y_pred[..., :18] - y_true[..., :18]
pitch_error = 2 * (y_pred[..., 18:19] - y_true[..., 18:19]) / (2 + y_true[..., 18:19])
pitch_error = 2*(y_pred[..., 18:19] - y_true[..., 18:19])
corr_error = y_pred[..., 19:] - y_true[..., 19:]
pitch_weight = torch.relu(y_true[..., 19:] + 0.5) ** 2