Add validation for weights blob

This commit is contained in:
Jean-Marc Valin 2023-05-20 14:21:58 -04:00
parent 0098fe70ac
commit c7b6935bf2
4 changed files with 62 additions and 25 deletions

View file

@ -340,13 +340,13 @@ if __name__ == "__main__":
W = model.get_layer('gru_a').get_weights()[0][3*embed_size:,:]
#FIXME: dump only half the biases
b = model.get_layer('gru_a').get_weights()[2]
dump_dense_layer_impl('gru_a_dense_feature', W, b, 'LINEAR', f, hf)
dump_dense_layer_impl('gru_a_dense_feature', W, b[:len(b)//2], 'LINEAR', f, hf)
W = model.get_layer('gru_b').get_weights()[0][model.rnn_units1:,:]
b = model.get_layer('gru_b').get_weights()[2]
# Set biases to zero because they'll be included in the GRU input part
# (we need regular and SU biases)
dump_dense_layer_impl('gru_b_dense_feature', W, 0*b, 'LINEAR', f, hf)
dump_dense_layer_impl('gru_b_dense_feature', W, 0*b[:len(b)//2], 'LINEAR', f, hf)
dump_grub(model.get_layer('gru_b'), f, hf, model.rnn_units1)
layer_list = []