diff --git a/dnn/lpcnet.py b/dnn/lpcnet.py index d0fb01c5..b07292dc 100644 --- a/dnn/lpcnet.py +++ b/dnn/lpcnet.py @@ -44,25 +44,14 @@ class PCMInit(Initializer): def new_wavernn_model(): pcm = Input(shape=(None, 2)) exc = Input(shape=(None, 1)) - pitch = Input(shape=(None, 1)) feat = Input(shape=(None, nb_used_features)) pitch = Input(shape=(None, 1)) dec_feat = Input(shape=(None, 128)) dec_state = Input(shape=(rnn_units,)) - conv1 = Conv1D(16, 7, padding='causal', activation='tanh') - pconv1 = Conv1D(16, 5, padding='same', activation='tanh') - pconv2 = Conv1D(16, 5, padding='same', activation='tanh') fconv1 = Conv1D(128, 3, padding='same', activation='tanh') fconv2 = Conv1D(102, 3, padding='same', activation='tanh') - if False: - cpcm = conv1(pcm) - cpitch = pconv2(pconv1(pitch)) - else: - cpcm = pcm - cpitch = pitch - embed = Embedding(256, embed_size, embeddings_initializer=PCMInit()) cpcm = Reshape((-1, embed_size*2))(embed(pcm)) embed2 = Embedding(256, embed_size, embeddings_initializer=PCMInit()) diff --git a/dnn/train_wavenet_audio.py b/dnn/train_wavenet_audio.py index 5acc6c4d..229f02de 100755 --- a/dnn/train_wavenet_audio.py +++ b/dnn/train_wavenet_audio.py @@ -58,14 +58,11 @@ upred = upred[:nb_frames*pcm_chunk_size] pred_in = ulaw2lin(in_data) for i in range(2, nb_frames*feature_chunk_size): upred[i*frame_size:(i+1)*frame_size] = 0 - #if i % 100000 == 0: - # print(i) for k in range(16): upred[i*frame_size:(i+1)*frame_size] = upred[i*frame_size:(i+1)*frame_size] - \ pred_in[i*frame_size-k:(i+1)*frame_size-k]*features[i, nb_features-16+k] pred = lin2ulaw(upred) -#pred = pred + np.random.randint(-1, 1, len(data)) in_data = np.reshape(in_data, (nb_frames, pcm_chunk_size, 1)) @@ -89,12 +86,6 @@ periods = (50*features[:,:,36:37]+100).astype('int16') in_data = np.concatenate([in_data, pred], axis=-1) -#in_data = np.concatenate([in_data, in_pitch], axis=-1) - -#with h5py.File('in_data.h5', 'w') as f: -# f.create_dataset('data', data=in_data[:50000, :, :]) -# f.create_dataset('feat', data=features[:50000, :, :]) - checkpoint = ModelCheckpoint('wavenet5b_{epoch:02d}.h5') #model.load_weights('wavenet4f2_30.h5')