added statistical model to dump_nfec_model
This commit is contained in:
parent
50966eecc5
commit
ea4d8f54c3
2 changed files with 53 additions and 3 deletions
|
@ -1,6 +1,7 @@
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
parser.add_argument('weights', metavar="<weight file>", type=str, help='model weight file in hdf5 format')
|
parser.add_argument('weights', metavar="<weight file>", type=str, help='model weight file in hdf5 format')
|
||||||
|
@ -10,7 +11,8 @@ parser.add_argument('--latent-dim', type=int, help="dimension of latent space (d
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# now import the heavy stuff
|
# now import the heavy stuff
|
||||||
from keraslayerdump import dump_conv1d_layer, dump_dense_layer, dump_gru_layer
|
import tensorflow as tf
|
||||||
|
from keraslayerdump import dump_conv1d_layer, dump_dense_layer, dump_gru_layer, printVector
|
||||||
from rdovae import new_rdovae_model
|
from rdovae import new_rdovae_model
|
||||||
|
|
||||||
def start_header(header_fid, header_name):
|
def start_header(header_fid, header_name):
|
||||||
|
@ -47,13 +49,38 @@ def finish_source(source_fid):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def dump_statistical_model(qembedding, f, fh):
|
||||||
|
w = qembedding.weights[0].numpy()
|
||||||
|
levels, dim = w.shape
|
||||||
|
N = dim // 6
|
||||||
|
|
||||||
|
quant_scales = tf.math.softplus(w[:, : N]).numpy()
|
||||||
|
dead_zone_theta = 0.5 + 0.05 * tf.math.softplus(w[:, N : 2 * N]).numpy()
|
||||||
|
r = 0.5 + 0.5 * tf.math.sigmoid(w[:, 4 * N : 5 * N]).numpy()
|
||||||
|
theta = tf.math.sigmoid(w[:, 5 * N : 6 * N]).numpy()
|
||||||
|
|
||||||
|
printVector(f, quant_scales[:], 'nfec_stats_quant_scales')
|
||||||
|
printVector(f, dead_zone_theta[:], 'nfec_stats_dead_zone_theta')
|
||||||
|
printVector(f, r, 'nfec_stats_r')
|
||||||
|
printVector(f, theta, 'nfec_stats_theta')
|
||||||
|
|
||||||
|
fh.write(
|
||||||
|
f"""
|
||||||
|
extern float nfec_stats_quant_scales;
|
||||||
|
extern float nfec_stats_dead_zone_theta;
|
||||||
|
extern float nfec_stats_r;
|
||||||
|
extern float nfec_stats_theta;
|
||||||
|
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
model, encoder, decoder, qembedding = new_rdovae_model(20, args.latent_dim, cond_size=args.cond_size)
|
model, encoder, decoder, qembedding = new_rdovae_model(20, args.latent_dim, cond_size=args.cond_size)
|
||||||
model.load_weights(args.weights)
|
model.load_weights(args.weights)
|
||||||
|
|
||||||
|
|
||||||
# for the time being only dump encoder
|
# encoder
|
||||||
encoder_dense_names = [
|
encoder_dense_names = [
|
||||||
'enc_dense1',
|
'enc_dense1',
|
||||||
'enc_dense3',
|
'enc_dense3',
|
||||||
|
@ -121,3 +148,26 @@ f"""
|
||||||
header_fid.close()
|
header_fid.close()
|
||||||
source_fid.close()
|
source_fid.close()
|
||||||
|
|
||||||
|
# statistical model
|
||||||
|
source_fid = open("nfec_stats_data.c", 'w')
|
||||||
|
header_fid = open("nfec_stats_data.h", 'w')
|
||||||
|
|
||||||
|
start_header(header_fid, "nfec_stats_data.h")
|
||||||
|
start_source(source_fid, "nfec_stats_data.h", os.path.basename(args.weights))
|
||||||
|
|
||||||
|
num_levels = qembedding.weights[0].shape[0]
|
||||||
|
header_fid.write(
|
||||||
|
f"""
|
||||||
|
#define NFEC_STATS_NUM_LEVELS {num_levels}
|
||||||
|
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
dump_statistical_model(qembedding, source_fid, header_fid)
|
||||||
|
|
||||||
|
finish_header(header_fid)
|
||||||
|
finish_source(source_fid)
|
||||||
|
|
||||||
|
header_fid.close()
|
||||||
|
source_fid.close()
|
||||||
|
|
||||||
|
|
|
@ -157,4 +157,4 @@ def dump_conv1d_layer(self, f, hf):
|
||||||
hf.write('#define {}_STATE_SIZE ({}*{})\n'.format(name.upper(), weights[0].shape[1], (weights[0].shape[0]-1)))
|
hf.write('#define {}_STATE_SIZE ({}*{})\n'.format(name.upper(), weights[0].shape[1], (weights[0].shape[0]-1)))
|
||||||
hf.write('#define {}_DELAY {}\n'.format(name.upper(), (weights[0].shape[0]-1)//2))
|
hf.write('#define {}_DELAY {}\n'.format(name.upper(), (weights[0].shape[0]-1)//2))
|
||||||
hf.write('extern const Conv1DLayer {};\n\n'.format(name));
|
hf.write('extern const Conv1DLayer {};\n\n'.format(name));
|
||||||
return max_conv_inputs
|
return max_conv_inputs
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue