Switching to neural pitch estimator

Remove old pitch estimator and retrain all models
This commit is contained in:
Jean-Marc Valin 2023-10-02 02:23:41 -04:00
parent da7f4c6c99
commit f0ec990dba
No known key found for this signature in database
GPG key ID: 531A52533318F00A
13 changed files with 103 additions and 137 deletions

View file

@ -24,6 +24,7 @@ parser.add_argument('--learning_rate', type=float, help='Learning Rate',default
parser.add_argument('--epochs', type=int, help='Number of training epochs',default = 50,required = False)
parser.add_argument('--choice_cel', type=str, help='Choice of Cross Entropy Loss (default or robust)',choices=['default','robust'],default = 'default',required = False)
parser.add_argument('--prefix', type=str, help="prefix for model export, default: model", default='model')
parser.add_argument('--initial-checkpoint', type=str, help='initial checkpoint to start training from, default: None', default=None)
args = parser.parse_args()
@ -55,6 +56,11 @@ elif args.data_format == 'xcorr':
else:
pitch_nn = PitchDNN(3 * args.freq_keep - 2, 224, args.gru_dim, args.output_dim)
if type(args.initial_checkpoint) != type(None):
checkpoint = torch.load(args.initial_checkpoint, map_location='cpu')
pitch_nn.load_state_dict(checkpoint['state_dict'], strict=False)
dataset_training = PitchDNNDataloader(args.features,args.features_pitch,args.confidence_threshold,args.context,args.data_format)
def loss_custom(logits,labels,confidence,choice = 'default',nmax = 192,q = 0.7):