added statistical model to dump_nfec_model
commit ea4d8f54c3 (parent 50966eecc5)
2 changed files with 53 additions and 3 deletions
dump_nfec_model.py
@@ -1,6 +1,7 @@
 import argparse
+import os


 parser = argparse.ArgumentParser()

 parser.add_argument('weights', metavar="<weight file>", type=str, help='model weight file in hdf5 format')
@@ -10,7 +11,8 @@ parser.add_argument('--latent-dim', type=int, help="dimension of latent space (d
 args = parser.parse_args()

 # now import the heavy stuff
-from keraslayerdump import dump_conv1d_layer, dump_dense_layer, dump_gru_layer
+import tensorflow as tf
+from keraslayerdump import dump_conv1d_layer, dump_dense_layer, dump_gru_layer, printVector
 from rdovae import new_rdovae_model

 def start_header(header_fid, header_name):
@@ -47,13 +49,38 @@ def finish_source(source_fid):
     pass


+def dump_statistical_model(qembedding, f, fh):
+    w = qembedding.weights[0].numpy()
+    levels, dim = w.shape
+    N = dim // 6
+
+    quant_scales = tf.math.softplus(w[:, : N]).numpy()
+    dead_zone_theta = 0.5 + 0.05 * tf.math.softplus(w[:, N : 2 * N]).numpy()
+    r = 0.5 + 0.5 * tf.math.sigmoid(w[:, 4 * N : 5 * N]).numpy()
+    theta = tf.math.sigmoid(w[:, 5 * N : 6 * N]).numpy()
+
+    printVector(f, quant_scales[:], 'nfec_stats_quant_scales')
+    printVector(f, dead_zone_theta[:], 'nfec_stats_dead_zone_theta')
+    printVector(f, r, 'nfec_stats_r')
+    printVector(f, theta, 'nfec_stats_theta')
+
+    fh.write(
+f"""
+extern float nfec_stats_quant_scales;
+extern float nfec_stats_dead_zone_theta;
+extern float nfec_stats_r;
+extern float nfec_stats_theta;
+
+"""
+    )
+
 if __name__ == "__main__":

     model, encoder, decoder, qembedding = new_rdovae_model(20, args.latent_dim, cond_size=args.cond_size)
     model.load_weights(args.weights)

     # for the time being only dump encoder
     # encoder
     encoder_dense_names = [
         'enc_dense1',
         'enc_dense3',
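Note on the new dump_statistical_model() above (an editor's sketch, not part of the commit): the quantizer embedding weight matrix has shape (levels, 6*N), i.e. six parameter blocks of width N per quantization level. Softplus maps the first two blocks to a positive quant_scales table and a dead-zone parameter, while sigmoid maps the last two blocks to values confined to (0.5, 1) and (0, 1) for r and theta, before everything is written out with printVector. The snippet below mirrors those transforms on a random stand-in matrix; the level count and N are assumed example values.

# Editor's sketch (not from the commit): mirrors the slicing/activation logic of
# dump_statistical_model, applied to a random stand-in for qembedding.weights[0].
import tensorflow as tf

levels, N = 16, 80                       # assumed example sizes
w = tf.random.normal((levels, 6 * N))    # stand-in for the trained embedding matrix

quant_scales    = tf.math.softplus(w[:, : N]).numpy()                       # > 0
dead_zone_theta = 0.5 + 0.05 * tf.math.softplus(w[:, N : 2 * N]).numpy()    # > 0.5
r               = 0.5 + 0.5 * tf.math.sigmoid(w[:, 4 * N : 5 * N]).numpy()  # in (0.5, 1)
theta           = tf.math.sigmoid(w[:, 5 * N : 6 * N]).numpy()              # in (0, 1)

# every table has one row per quantization level and N columns
assert quant_scales.shape == dead_zone_theta.shape == r.shape == theta.shape == (levels, N)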
@@ -121,3 +148,26 @@ f"""
     header_fid.close()
     source_fid.close()
+
+    # statistical model
+    source_fid = open("nfec_stats_data.c", 'w')
+    header_fid = open("nfec_stats_data.h", 'w')
+
+    start_header(header_fid, "nfec_stats_data.h")
+    start_source(source_fid, "nfec_stats_data.h", os.path.basename(args.weights))
+
+    num_levels = qembedding.weights[0].shape[0]
+    header_fid.write(
+f"""
+#define NFEC_STATS_NUM_LEVELS {num_levels}
+
+"""
+    )
+
+    dump_statistical_model(qembedding, source_fid, header_fid)
+
+    finish_header(header_fid)
+    finish_source(source_fid)
+
+    header_fid.close()
+    source_fid.close()
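The __main__ additions above write two generated files, nfec_stats_data.c and nfec_stats_data.h, next to the existing encoder dump. Below is a standalone sketch (editor's illustration, not part of the commit) of that header-generation flow with stand-in helpers; the real start_header()/finish_header() bodies are not shown in this diff, so the include-guard style and names here are assumptions.

# Editor's sketch of the generation flow, with hypothetical stand-in helpers.
def demo_start_header(fid, name):
    guard = name.upper().replace('.', '_')          # assumed guard convention
    fid.write(f"#ifndef {guard}\n#define {guard}\n\n")

def demo_finish_header(fid):
    fid.write("#endif\n")

num_levels = 16  # stand-in; the script uses qembedding.weights[0].shape[0]

with open("nfec_stats_data_demo.h", "w") as header_fid:
    demo_start_header(header_fid, "nfec_stats_data_demo.h")
    header_fid.write(f"#define NFEC_STATS_NUM_LEVELS {num_levels}\n\n")
    demo_finish_header(header_fid)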
keraslayerdump.py
@@ -157,4 +157,4 @@ def dump_conv1d_layer(self, f, hf):
     hf.write('#define {}_STATE_SIZE ({}*{})\n'.format(name.upper(), weights[0].shape[1], (weights[0].shape[0]-1)))
     hf.write('#define {}_DELAY {}\n'.format(name.upper(), (weights[0].shape[0]-1)//2))
     hf.write('extern const Conv1DLayer {};\n\n'.format(name));
-    return max_conv_inputs
+    return max_conv_inputs
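To make the two macros in this last hunk concrete (an editor's worked example, assuming the usual Keras Conv1D kernel layout of (kernel_size, in_channels, out_channels)): weights[0].shape[1] is the input channel count and weights[0].shape[0] - 1 is kernel_size - 1.

# Editor's example with an assumed kernel shape; not part of the commit.
kernel_size, in_channels, out_channels = 3, 128, 192   # hypothetical Conv1D kernel shape

state_size = in_channels * (kernel_size - 1)           # {}_STATE_SIZE -> 128*2 = 256
delay      = (kernel_size - 1) // 2                     # {}_DELAY      -> 1

print(state_size, delay)    # 256 1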