mirror of
https://github.com/xiph/opus.git
synced 2025-05-30 07:07:42 +00:00
updated PitchDNN export script
This commit is contained in:
parent
ce28695844
commit
0459a572f5
1 changed files with 29 additions and 17 deletions
|
@ -44,7 +44,7 @@ args = parser.parse_args()
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from models import large_if_ccode
|
from models import PitchDNN
|
||||||
from wexchange.torch import dump_torch_weights
|
from wexchange.torch import dump_torch_weights
|
||||||
from wexchange.c_export import CWriter, print_vector
|
from wexchange.c_export import CWriter, print_vector
|
||||||
|
|
||||||
|
@ -52,39 +52,51 @@ def c_export(args, model):
|
||||||
|
|
||||||
message = f"Auto generated from checkpoint {os.path.basename(args.checkpoint)}"
|
message = f"Auto generated from checkpoint {os.path.basename(args.checkpoint)}"
|
||||||
|
|
||||||
enc_writer = CWriter(os.path.join(args.output_dir, "neural_pitch_data"), message=message, model_struct_name='nnpitch')
|
writer = CWriter(os.path.join(args.output_dir, "neural_pitch_data"), message=message, model_struct_name='PitchDNN')
|
||||||
enc_writer.header.write(
|
writer.header.write(
|
||||||
f"""
|
f"""
|
||||||
#include "opus_types.h"
|
#include "opus_types.h"
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
|
layers = [
|
||||||
# encoder
|
('if_upsample.0', "dense_if_upsampler_1"),
|
||||||
encoder_dense_layers = [
|
('if_upsample.2', "dense_if_upsampler_2"),
|
||||||
('initial' , 'initial', 'TANH'),
|
('conv.1', "conv2d_1"),
|
||||||
('upsample' , 'upsample', 'TANH')
|
('conv.4', "conv2d_2"),
|
||||||
|
('conv.7', "conv2d_3"),
|
||||||
|
('downsample.0', "dense_downsampler"),
|
||||||
|
("upsample.0", "dense_final_upsampler")
|
||||||
]
|
]
|
||||||
|
|
||||||
for name, export_name, _ in encoder_dense_layers:
|
|
||||||
|
for name, export_name in layers:
|
||||||
layer = model.get_submodule(name)
|
layer = model.get_submodule(name)
|
||||||
dump_torch_weights(enc_writer, layer, name=export_name, verbose=True)
|
dump_torch_weights(writer, layer, name=export_name, verbose=True)
|
||||||
|
|
||||||
|
|
||||||
encoder_gru_layers = [
|
gru_layers = [
|
||||||
('gru' , 'gru', 'TANH'),
|
("GRU", "gru_1"),
|
||||||
]
|
]
|
||||||
|
|
||||||
enc_max_rnn_units = max([dump_torch_weights(enc_writer, model.get_submodule(name), export_name, verbose=True, input_sparse=False, quantize=False)
|
max_rnn_units = max([dump_torch_weights(writer, model.get_submodule(name), export_name, verbose=True, input_sparse=False, quantize=False)
|
||||||
for name, export_name, _ in encoder_gru_layers])
|
for name, export_name in gru_layers])
|
||||||
|
|
||||||
del enc_writer
|
writer.header.write(
|
||||||
|
f"""
|
||||||
|
|
||||||
|
#define PITCH_DNN_MAX_RNN_UNITS {max_rnn_units}
|
||||||
|
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
writer.close()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
os.makedirs(args.output_dir, exist_ok=True)
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
model = large_if_ccode()
|
model = PitchDNN()
|
||||||
checkpoint = torch.load(args.checkpoint ,map_location='cpu')
|
checkpoint = torch.load(args.checkpoint, map_location='cpu')
|
||||||
model.load_state_dict(checkpoint['state_dict'])
|
model.load_state_dict(checkpoint['state_dict'])
|
||||||
c_export(args, model)
|
c_export(args, model)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue