added statistical model to dump_nfec_model

This commit is contained in:
Jan Buethe 2022-10-19 17:18:25 +00:00
parent 50966eecc5
commit ea4d8f54c3
2 changed files with 53 additions and 3 deletions

View file

@ -1,6 +1,7 @@
import argparse
import os
parser = argparse.ArgumentParser()
parser.add_argument('weights', metavar="<weight file>", type=str, help='model weight file in hdf5 format')
@ -10,7 +11,8 @@ parser.add_argument('--latent-dim', type=int, help="dimension of latent space (d
args = parser.parse_args()
# now import the heavy stuff
from keraslayerdump import dump_conv1d_layer, dump_dense_layer, dump_gru_layer
import tensorflow as tf
from keraslayerdump import dump_conv1d_layer, dump_dense_layer, dump_gru_layer, printVector
from rdovae import new_rdovae_model
def start_header(header_fid, header_name):
@ -47,13 +49,38 @@ def finish_source(source_fid):
pass
def dump_statistical_model(qembedding, f, fh):
w = qembedding.weights[0].numpy()
levels, dim = w.shape
N = dim // 6
quant_scales = tf.math.softplus(w[:, : N]).numpy()
dead_zone_theta = 0.5 + 0.05 * tf.math.softplus(w[:, N : 2 * N]).numpy()
r = 0.5 + 0.5 * tf.math.sigmoid(w[:, 4 * N : 5 * N]).numpy()
theta = tf.math.sigmoid(w[:, 5 * N : 6 * N]).numpy()
printVector(f, quant_scales[:], 'nfec_stats_quant_scales')
printVector(f, dead_zone_theta[:], 'nfec_stats_dead_zone_theta')
printVector(f, r, 'nfec_stats_r')
printVector(f, theta, 'nfec_stats_theta')
fh.write(
f"""
extern float nfec_stats_quant_scales;
extern float nfec_stats_dead_zone_theta;
extern float nfec_stats_r;
extern float nfec_stats_theta;
"""
)
if __name__ == "__main__":
model, encoder, decoder, qembedding = new_rdovae_model(20, args.latent_dim, cond_size=args.cond_size)
model.load_weights(args.weights)
# for the time being only dump encoder
# encoder
encoder_dense_names = [
'enc_dense1',
'enc_dense3',
@ -121,3 +148,26 @@ f"""
header_fid.close()
source_fid.close()
# statistical model
source_fid = open("nfec_stats_data.c", 'w')
header_fid = open("nfec_stats_data.h", 'w')
start_header(header_fid, "nfec_stats_data.h")
start_source(source_fid, "nfec_stats_data.h", os.path.basename(args.weights))
num_levels = qembedding.weights[0].shape[0]
header_fid.write(
f"""
#define NFEC_STATS_NUM_LEVELS {num_levels}
"""
)
dump_statistical_model(qembedding, source_fid, header_fid)
finish_header(header_fid)
finish_source(source_fid)
header_fid.close()
source_fid.close()