mirror of
https://github.com/xiph/opus.git
synced 2025-05-19 01:48:30 +00:00
changed checkpoint format
This commit is contained in:
parent
733a095ba2
commit
41a4c9515d
4 changed files with 29 additions and 132 deletions
|
@ -3,6 +3,7 @@ Training the neural pitch estimator
|
|||
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
|
@ -22,6 +23,7 @@ parser.add_argument('--output_dim', type=int, help='Output dimension',default =
|
|||
parser.add_argument('--learning_rate', type=float, help='Learning Rate',default = 1.0e-3,required = False)
|
||||
parser.add_argument('--epochs', type=int, help='Number of training epochs',default = 50,required = False)
|
||||
parser.add_argument('--choice_cel', type=str, help='Choice of Cross Entropy Loss (default or robust)',choices=['default','robust'],default = 'default',required = False)
|
||||
parser.add_argument('--prefix', type=str, help="prefix for model export, default: model", default='model')
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
|
@ -163,12 +165,9 @@ choice_cel = args.choice_cel,
|
|||
context = args.context,
|
||||
)
|
||||
|
||||
now = datetime.now()
|
||||
dir_pth_save = args.output_folder
|
||||
dir_network = dir_pth_save + str(now) + '_net_' + args.data_format + '.pth'
|
||||
dir_dictparams = dir_pth_save + str(now) + '_config_' + args.data_format + '.json'
|
||||
# Save Weights
|
||||
torch.save(pitch_nn.state_dict(), dir_network)
|
||||
# Save Config
|
||||
with open(dir_dictparams, 'w') as fp:
|
||||
json.dump(config, fp)
|
||||
model_save_path = os.path.join(args.output, f"{args.prefix}_{args.data_format}.pth")
|
||||
checkpoint = {
|
||||
'state_dict': pitch_nn.state_dict(),
|
||||
'config': config
|
||||
}
|
||||
torch.save(checkpoint, model_save_path)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue