diff --git a/dnn/training_tf2/lpcnet.py b/dnn/training_tf2/lpcnet.py index 8e83b42c..cbd73d83 100644 --- a/dnn/training_tf2/lpcnet.py +++ b/dnn/training_tf2/lpcnet.py @@ -313,7 +313,7 @@ def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features=20, batch_s if not flag_e2e: encoder = Model([feat, pitch], cfeat) - dec_rnn_in = Concatenate()([cpcm, dec_feat]) + dec_rnn_in = Concatenate()([cpcm_decoder, dec_feat]) else: encoder = Model([feat, pitch], [cfeat,lpcoeffs]) dec_rnn_in = Concatenate()([cpcm_decoder, dec_feat]) @@ -324,5 +324,5 @@ def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features=20, batch_s if flag_e2e: decoder = Model([dpcm, dec_feat, dec_state1, dec_state2], [dec_ulaw_prob, state1, state2]) else: - decoder = Model([pcm, dec_feat, dec_state1, dec_state2, lpcoeffs], [dec_ulaw_prob, state1, state2]) + decoder = Model([dpcm, dec_feat, dec_state1, dec_state2], [dec_ulaw_prob, state1, state2]) return model, encoder, decoder diff --git a/dnn/training_tf2/test_lpcnet.py b/dnn/training_tf2/test_lpcnet.py index 88439cf1..e1fdf59d 100755 --- a/dnn/training_tf2/test_lpcnet.py +++ b/dnn/training_tf2/test_lpcnet.py @@ -31,16 +31,22 @@ import numpy as np from ulaw import ulaw2lin, lin2ulaw import h5py -# Flag for synthesizing e2e (differentiable lpc) model -flag_e2e = False +filename = sys.argv[1] +with h5py.File(filename, "r") as f: + units = min(f['model_weights']['gru_a']['gru_a']['recurrent_kernel:0'].shape) + units2 = min(f['model_weights']['gru_b']['gru_b']['recurrent_kernel:0'].shape) + cond_size = min(f['model_weights']['feature_dense1']['feature_dense1']['kernel:0'].shape) + e2e = 'rc2lpc' in f['model_weights'] -model, enc, dec = lpcnet.new_lpcnet_model(training = False, flag_e2e = flag_e2e) + +model, enc, dec = lpcnet.new_lpcnet_model(training = False, rnn_units1=units, rnn_units2=units2, flag_e2e = e2e, cond_size=cond_size, batch_size=1) model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy']) #model.summary() -feature_file = sys.argv[1] -out_file = sys.argv[2] + +feature_file = sys.argv[2] +out_file = sys.argv[3] frame_size = model.frame_size nb_features = 36 nb_used_features = model.nb_used_features @@ -56,7 +62,7 @@ periods = (.1 + 50*features[:,:,18:19]+100).astype('int16') -model.load_weights('lpcnet38Sn_384_02.h5'); +model.load_weights(filename); order = 16 @@ -72,13 +78,13 @@ fout = open(out_file, 'wb') skip = order + 1 for c in range(0, nb_frames): - if not flag_e2e: + if not e2e: cfeat = enc.predict([features[c:c+1, :, :nb_used_features], periods[c:c+1, :, :]]) else: cfeat,lpcs = enc.predict([features[c:c+1, :, :nb_used_features], periods[c:c+1, :, :]]) for fr in range(0, feature_chunk_size): f = c*feature_chunk_size + fr - if not flag_e2e: + if not e2e: a = features[c, fr, nb_features-order:] else: a = lpcs[c,fr]