updated PitchDNN export script

This commit is contained in:
Jan Buethe 2023-09-29 15:34:59 +02:00
parent ce28695844
commit 0459a572f5
No known key found for this signature in database
GPG key ID: 9E32027A35B36314

View file

@ -44,7 +44,7 @@ args = parser.parse_args()
import torch
import numpy as np
from models import large_if_ccode
from models import PitchDNN
from wexchange.torch import dump_torch_weights
from wexchange.c_export import CWriter, print_vector
@ -52,39 +52,51 @@ def c_export(args, model):
message = f"Auto generated from checkpoint {os.path.basename(args.checkpoint)}"
enc_writer = CWriter(os.path.join(args.output_dir, "neural_pitch_data"), message=message, model_struct_name='nnpitch')
enc_writer.header.write(
writer = CWriter(os.path.join(args.output_dir, "neural_pitch_data"), message=message, model_struct_name='PitchDNN')
writer.header.write(
f"""
#include "opus_types.h"
"""
)
# encoder
encoder_dense_layers = [
('initial' , 'initial', 'TANH'),
('upsample' , 'upsample', 'TANH')
layers = [
('if_upsample.0', "dense_if_upsampler_1"),
('if_upsample.2', "dense_if_upsampler_2"),
('conv.1', "conv2d_1"),
('conv.4', "conv2d_2"),
('conv.7', "conv2d_3"),
('downsample.0', "dense_downsampler"),
("upsample.0", "dense_final_upsampler")
]
for name, export_name, _ in encoder_dense_layers:
for name, export_name in layers:
layer = model.get_submodule(name)
dump_torch_weights(enc_writer, layer, name=export_name, verbose=True)
dump_torch_weights(writer, layer, name=export_name, verbose=True)
encoder_gru_layers = [
('gru' , 'gru', 'TANH'),
gru_layers = [
("GRU", "gru_1"),
]
enc_max_rnn_units = max([dump_torch_weights(enc_writer, model.get_submodule(name), export_name, verbose=True, input_sparse=False, quantize=False)
for name, export_name, _ in encoder_gru_layers])
max_rnn_units = max([dump_torch_weights(writer, model.get_submodule(name), export_name, verbose=True, input_sparse=False, quantize=False)
for name, export_name in gru_layers])
del enc_writer
writer.header.write(
f"""
#define PITCH_DNN_MAX_RNN_UNITS {max_rnn_units}
"""
)
writer.close()
if __name__ == "__main__":
os.makedirs(args.output_dir, exist_ok=True)
model = large_if_ccode()
checkpoint = torch.load(args.checkpoint ,map_location='cpu')
model = PitchDNN()
checkpoint = torch.load(args.checkpoint, map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])
c_export(args, model)