From 617e462be36bba4f3a1a0fba4a212cb9dded3236 Mon Sep 17 00:00:00 2001 From: Jean-Marc Valin Date: Tue, 26 Jun 2018 01:31:44 -0400 Subject: [PATCH] using features (except pitch gain which has NaNs for now) --- dnn/lpcnet.py | 11 +++++++++-- dnn/train_lpcnet.py | 6 +++--- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/dnn/lpcnet.py b/dnn/lpcnet.py index 1aa6b7cd..9eca5061 100644 --- a/dnn/lpcnet.py +++ b/dnn/lpcnet.py @@ -12,12 +12,19 @@ import sys rnn_units=64 pcm_bits = 8 pcm_levels = 2**pcm_bits +nb_used_features = 37 + def new_wavernn_model(): pcm = Input(shape=(None, 1)) + feat = Input(shape=(None, nb_used_features)) + + rep = Lambda(lambda x: K.repeat_elements(x, 160, 1)) + rnn = CuDNNGRU(rnn_units, return_sequences=True) + rnn_in = Concatenate()([pcm, rep(feat)]) md = MDense(pcm_levels, activation='softmax') - ulaw_prob = md(rnn(pcm)) + ulaw_prob = md(rnn(rnn_in)) - model = Model(pcm, ulaw_prob) + model = Model([pcm, feat], ulaw_prob) return model diff --git a/dnn/train_lpcnet.py b/dnn/train_lpcnet.py index ba8163fb..4f49ad15 100755 --- a/dnn/train_lpcnet.py +++ b/dnn/train_lpcnet.py @@ -24,7 +24,7 @@ model.summary() pcmfile = sys.argv[1] feature_file = sys.argv[2] nb_features = 54 -nb_used_features = 38 +nb_used_features = lpcnet.nb_used_features feature_chunk_size = 15 pcm_chunk_size = 160*feature_chunk_size @@ -44,8 +44,8 @@ out_data = (out_data.astype('int16')+128).astype('uint8') features = np.reshape(features, (nb_frames, feature_chunk_size, nb_features)) features = features[:, :, :nb_used_features] -checkpoint = ModelCheckpoint('lpcnet1b_{epoch:02d}.h5') +checkpoint = ModelCheckpoint('lpcnet1c_{epoch:02d}.h5') #model.load_weights('wavernn1c_01.h5') model.compile(optimizer=Adam(0.002, amsgrad=True, decay=2e-4), loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy']) -model.fit(in_data, out_data, batch_size=batch_size, epochs=30, validation_split=0.2, callbacks=[checkpoint]) +model.fit([in_data, features], out_data, batch_size=batch_size, epochs=30, validation_split=0.2, callbacks=[checkpoint])