mirror of
https://github.com/xiph/opus.git
synced 2025-06-05 15:03:39 +00:00
finalized quantization option in export_rdovae_weights.py
This commit is contained in:
parent
88c8b30785
commit
1accd2472e
1 changed file with 10 additions and 10 deletions
|
@@ -121,9 +121,9 @@ f"""
|
||||||
('core_encoder.module.state_dense_2' , 'gdense2' , 'TANH', True)
|
('core_encoder.module.state_dense_2' , 'gdense2' , 'TANH', True)
|
||||||
]
|
]
|
||||||
|
|
||||||
for name, export_name, _, _ in encoder_dense_layers:
|
for name, export_name, _, quantize in encoder_dense_layers:
|
||||||
layer = model.get_submodule(name)
|
layer = model.get_submodule(name)
|
||||||
dump_torch_weights(enc_writer, layer, name=export_name, verbose=True)
|
dump_torch_weights(enc_writer, layer, name=export_name, verbose=True, quantize=quantize, scale=None)
|
||||||
|
|
||||||
|
|
||||||
encoder_gru_layers = [
|
encoder_gru_layers = [
|
||||||
|
@@ -134,8 +134,8 @@ f"""
|
||||||
('core_encoder.module.gru5' , 'enc_gru5', 'TANH', True),
|
('core_encoder.module.gru5' , 'enc_gru5', 'TANH', True),
|
||||||
]
|
]
|
||||||
|
|
||||||
enc_max_rnn_units = max([dump_torch_weights(enc_writer, model.get_submodule(name), export_name, verbose=True, input_sparse=True, quantize=True)
|
enc_max_rnn_units = max([dump_torch_weights(enc_writer, model.get_submodule(name), export_name, verbose=True, input_sparse=True, quantize=quantize, scale=None, recurrent_scale=None)
|
||||||
for name, export_name, _, _ in encoder_gru_layers])
|
for name, export_name, _, quantize in encoder_gru_layers])
|
||||||
|
|
||||||
|
|
||||||
encoder_conv_layers = [
|
encoder_conv_layers = [
|
||||||
|
@@ -146,7 +146,7 @@ f"""
|
||||||
('core_encoder.module.conv5.conv' , 'enc_conv5', 'TANH', True),
|
('core_encoder.module.conv5.conv' , 'enc_conv5', 'TANH', True),
|
||||||
]
|
]
|
||||||
|
|
||||||
enc_max_conv_inputs = max([dump_torch_weights(enc_writer, model.get_submodule(name), export_name, verbose=True, quantize=False) for name, export_name, _, _ in encoder_conv_layers])
|
enc_max_conv_inputs = max([dump_torch_weights(enc_writer, model.get_submodule(name), export_name, verbose=True, quantize=quantize, scale=None) for name, export_name, _, quantize in encoder_conv_layers])
|
||||||
|
|
||||||
|
|
||||||
del enc_writer
|
del enc_writer
|
||||||
|
@@ -159,9 +159,9 @@ f"""
|
||||||
('core_decoder.module.gru_init' , 'dec_gru_init', 'TANH', True),
|
('core_decoder.module.gru_init' , 'dec_gru_init', 'TANH', True),
|
||||||
]
|
]
|
||||||
|
|
||||||
for name, export_name, _, _ in decoder_dense_layers:
|
for name, export_name, _, quantize in decoder_dense_layers:
|
||||||
layer = model.get_submodule(name)
|
layer = model.get_submodule(name)
|
||||||
dump_torch_weights(dec_writer, layer, name=export_name, verbose=True)
|
dump_torch_weights(dec_writer, layer, name=export_name, verbose=True, quantize=quantize, scale=None)
|
||||||
|
|
||||||
|
|
||||||
decoder_gru_layers = [
|
decoder_gru_layers = [
|
||||||
|
@@ -172,8 +172,8 @@ f"""
|
||||||
('core_decoder.module.gru5' , 'dec_gru5', 'TANH', True),
|
('core_decoder.module.gru5' , 'dec_gru5', 'TANH', True),
|
||||||
]
|
]
|
||||||
|
|
||||||
dec_max_rnn_units = max([dump_torch_weights(dec_writer, model.get_submodule(name), export_name, verbose=True, input_sparse=True, quantize=True)
|
dec_max_rnn_units = max([dump_torch_weights(dec_writer, model.get_submodule(name), export_name, verbose=True, input_sparse=True, quantize=quantize, scale=None, recurrent_scale=None)
|
||||||
for name, export_name, _, _ in decoder_gru_layers])
|
for name, export_name, _, quantize in decoder_gru_layers])
|
||||||
|
|
||||||
decoder_conv_layers = [
|
decoder_conv_layers = [
|
||||||
('core_decoder.module.conv1.conv' , 'dec_conv1', 'TANH', True),
|
('core_decoder.module.conv1.conv' , 'dec_conv1', 'TANH', True),
|
||||||
|
@@ -183,7 +183,7 @@ f"""
|
||||||
('core_decoder.module.conv5.conv' , 'dec_conv5', 'TANH', True),
|
('core_decoder.module.conv5.conv' , 'dec_conv5', 'TANH', True),
|
||||||
]
|
]
|
||||||
|
|
||||||
dec_max_conv_inputs = max([dump_torch_weights(dec_writer, model.get_submodule(name), export_name, verbose=True, quantize=False) for name, export_name, _, _ in decoder_conv_layers])
|
dec_max_conv_inputs = max([dump_torch_weights(dec_writer, model.get_submodule(name), export_name, verbose=True, quantize=quantize, scale=None) for name, export_name, _, quantize in decoder_conv_layers])
|
||||||
|
|
||||||
del dec_writer
|
del dec_writer
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue