changed checkpoint format

This commit is contained in:
Jan Buethe 2023-09-26 14:35:36 +02:00 committed by Jean-Marc Valin
parent 733a095ba2
commit 41a4c9515d
No known key found for this signature in database
GPG key ID: 531A52533318F00A
4 changed files with 29 additions and 132 deletions

View file

@ -3,6 +3,7 @@ Training the neural pitch estimator
"""
import os
import argparse
parser = argparse.ArgumentParser()
@ -22,6 +23,7 @@ parser.add_argument('--output_dim', type=int, help='Output dimension',default =
parser.add_argument('--learning_rate', type=float, help='Learning Rate',default = 1.0e-3,required = False)
parser.add_argument('--epochs', type=int, help='Number of training epochs',default = 50,required = False)
parser.add_argument('--choice_cel', type=str, help='Choice of Cross Entropy Loss (default or robust)',choices=['default','robust'],default = 'default',required = False)
parser.add_argument('--prefix', type=str, help="prefix for model export, default: model", default='model')
args = parser.parse_args()
@ -163,12 +165,9 @@ choice_cel = args.choice_cel,
context = args.context,
)
now = datetime.now()
dir_pth_save = args.output_folder
dir_network = dir_pth_save + str(now) + '_net_' + args.data_format + '.pth'
dir_dictparams = dir_pth_save + str(now) + '_config_' + args.data_format + '.json'
# Save Weights
torch.save(pitch_nn.state_dict(), dir_network)
# Save Config
with open(dir_dictparams, 'w') as fp:
json.dump(config, fp)
model_save_path = os.path.join(args.output, f"{args.prefix}_{args.data_format}.pth")
checkpoint = {
'state_dict': pitch_nn.state_dict(),
'config': config
}
torch.save(checkpoint, model_save_path)