mirror of
https://github.com/xiph/opus.git
synced 2025-05-18 09:28:30 +00:00
89 lines
3 KiB
Python
"""
|
|
/* Copyright (c) 2022 Amazon
|
|
Written by Jan Buethe */
|
|
/*
|
|
Redistribution and use in source and binary forms, with or without
|
|
modification, are permitted provided that the following conditions
|
|
are met:
|
|
|
|
- Redistributions of source code must retain the above copyright
|
|
notice, this list of conditions and the following disclaimer.
|
|
|
|
- Redistributions in binary form must reproduce the above copyright
|
|
notice, this list of conditions and the following disclaimer in the
|
|
documentation and/or other materials provided with the distribution.
|
|
|
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
|
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
|
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
|
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
|
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
|
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
|
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
|
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
|
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
|
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
|
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
*/
|
|
"""
|
|
|
|
import os
|
|
import argparse
|
|
import sys
|
|
|
|
# Make the sibling weight-exchange package (imported below as `wexchange`)
# importable when this script is run from its own directory.
sys.path.append(os.path.join(os.path.dirname(__file__), '../weight-exchange'))

# Command-line interface. Arguments are parsed before the heavier imports that
# follow (torch, models), presumably so that `--help` and usage errors return
# without loading torch.
parser = argparse.ArgumentParser()

# Path to the trained neural pitch checkpoint (a state dict loaded with torch.load).
parser.add_argument('checkpoint', type=str, help='neural pitch model checkpoint')
# Folder receiving the generated C files; created by the entry point if missing.
parser.add_argument('output_dir', type=str, help='output folder')

args = parser.parse_args()
|
|
|
|
import torch
|
|
import numpy as np
|
|
|
|
from models import large_if_ccode
|
|
from wexchange.torch import dump_torch_weights
|
|
from wexchange.c_export import CWriter, print_vector
|
|
|
|
def c_export(args, model):
    """Export the neural pitch model weights as generated C code.

    Creates a ``CWriter`` for ``<output_dir>/neural_pitch_data`` with model
    struct name ``nnpitch``. It then dumps the dense layers ``initial`` and
    ``upsample`` and the GRU layer ``gru`` through ``dump_torch_weights``.

    Args:
        args: Parsed command-line namespace. ``args.checkpoint`` is only used
            (via its basename) in the provenance message. ``args.output_dir``
            is the target folder and is assumed to already exist.
        model: A torch module exposing ``initial``, ``upsample`` and ``gru``
            submodules (looked up with ``get_submodule``).
    """
    # Provenance note embedded in the generated files.
    message = f"Auto generated from checkpoint {os.path.basename(args.checkpoint)}"

    enc_writer = CWriter(os.path.join(args.output_dir, "neural_pitch_data"), message=message, model_struct_name='nnpitch')
    # Same text the original triple-quoted string produced; no interpolation is needed.
    enc_writer.header.write('\n#include "opus_types.h"\n')

    # Dense layers: (submodule name, exported C name, activation).
    # The activation entry is not used by this exporter.
    encoder_dense_layers = [
        ('initial', 'initial', 'TANH'),
        ('upsample', 'upsample', 'TANH'),
    ]

    for name, export_name, _ in encoder_dense_layers:
        layer = model.get_submodule(name)
        dump_torch_weights(enc_writer, layer, name=export_name, verbose=True)

    # GRU layers: same tuple layout as above.
    encoder_gru_layers = [
        ('gru', 'gru', 'TANH'),
    ]

    # GRU weights are dumped with input sparsity and quantization disabled.
    # The return value of dump_torch_weights is not needed here.
    for name, export_name, _ in encoder_gru_layers:
        dump_torch_weights(enc_writer, model.get_submodule(name), export_name,
                           verbose=True, input_sparse=False, quantize=False)

    # Drop the last reference to the writer before returning. This presumably
    # relies on CWriter finalizing/closing its output files when it is
    # garbage-collected (TODO: confirm against wexchange.c_export).
    del enc_writer
|
|
|
|
|
|
if __name__ == "__main__":
    # Entry point: make sure the destination exists, rebuild the network,
    # restore its trained parameters on the CPU, then emit the C export.
    os.makedirs(args.output_dir, exist_ok=True)

    model = large_if_ccode()

    # NOTE(review): torch.load unpickles the file, which (depending on the
    # torch version's weights_only default) can execute arbitrary code.
    # Only run this on trusted checkpoints.
    checkpoint_state = torch.load(args.checkpoint, map_location='cpu')
    model.load_state_dict(checkpoint_state)

    c_export(args, model)
|