mirror of
https://github.com/xiph/opus.git
synced 2025-05-29 14:49:14 +00:00
updated PitchDNN export script
This commit is contained in:
parent
ce28695844
commit
0459a572f5
1 changed files with 29 additions and 17 deletions
|
@ -44,7 +44,7 @@ args = parser.parse_args()
|
|||
import torch
|
||||
import numpy as np
|
||||
|
||||
from models import large_if_ccode
|
||||
from models import PitchDNN
|
||||
from wexchange.torch import dump_torch_weights
|
||||
from wexchange.c_export import CWriter, print_vector
|
||||
|
||||
|
@ -52,39 +52,51 @@ def c_export(args, model):
|
|||
|
||||
message = f"Auto generated from checkpoint {os.path.basename(args.checkpoint)}"
|
||||
|
||||
enc_writer = CWriter(os.path.join(args.output_dir, "neural_pitch_data"), message=message, model_struct_name='nnpitch')
|
||||
enc_writer.header.write(
|
||||
writer = CWriter(os.path.join(args.output_dir, "neural_pitch_data"), message=message, model_struct_name='PitchDNN')
|
||||
writer.header.write(
|
||||
f"""
|
||||
#include "opus_types.h"
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
# encoder
|
||||
encoder_dense_layers = [
|
||||
('initial' , 'initial', 'TANH'),
|
||||
('upsample' , 'upsample', 'TANH')
|
||||
layers = [
|
||||
('if_upsample.0', "dense_if_upsampler_1"),
|
||||
('if_upsample.2', "dense_if_upsampler_2"),
|
||||
('conv.1', "conv2d_1"),
|
||||
('conv.4', "conv2d_2"),
|
||||
('conv.7', "conv2d_3"),
|
||||
('downsample.0', "dense_downsampler"),
|
||||
("upsample.0", "dense_final_upsampler")
|
||||
]
|
||||
|
||||
for name, export_name, _ in encoder_dense_layers:
|
||||
|
||||
for name, export_name in layers:
|
||||
layer = model.get_submodule(name)
|
||||
dump_torch_weights(enc_writer, layer, name=export_name, verbose=True)
|
||||
dump_torch_weights(writer, layer, name=export_name, verbose=True)
|
||||
|
||||
|
||||
encoder_gru_layers = [
|
||||
('gru' , 'gru', 'TANH'),
|
||||
gru_layers = [
|
||||
("GRU", "gru_1"),
|
||||
]
|
||||
|
||||
enc_max_rnn_units = max([dump_torch_weights(enc_writer, model.get_submodule(name), export_name, verbose=True, input_sparse=False, quantize=False)
|
||||
for name, export_name, _ in encoder_gru_layers])
|
||||
max_rnn_units = max([dump_torch_weights(writer, model.get_submodule(name), export_name, verbose=True, input_sparse=False, quantize=False)
|
||||
for name, export_name in gru_layers])
|
||||
|
||||
del enc_writer
|
||||
writer.header.write(
|
||||
f"""
|
||||
|
||||
#define PITCH_DNN_MAX_RNN_UNITS {max_rnn_units}
|
||||
|
||||
"""
|
||||
)
|
||||
|
||||
writer.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
model = large_if_ccode()
|
||||
checkpoint = torch.load(args.checkpoint ,map_location='cpu')
|
||||
model = PitchDNN()
|
||||
checkpoint = torch.load(args.checkpoint, map_location='cpu')
|
||||
model.load_state_dict(checkpoint['state_dict'])
|
||||
c_export(args, model)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue