mirror of
https://github.com/xiph/opus.git
synced 2025-06-01 08:07:41 +00:00
using features (except pitch gain which has NaNs for now)
This commit is contained in:
parent
b65031ef64
commit
617e462be3
2 changed files with 12 additions and 5 deletions
|
@ -12,12 +12,19 @@ import sys
|
||||||
rnn_units=64
|
rnn_units=64
|
||||||
pcm_bits = 8
|
pcm_bits = 8
|
||||||
pcm_levels = 2**pcm_bits
|
pcm_levels = 2**pcm_bits
|
||||||
|
nb_used_features = 37
|
||||||
|
|
||||||
|
|
||||||
def new_wavernn_model():
|
def new_wavernn_model():
|
||||||
pcm = Input(shape=(None, 1))
|
pcm = Input(shape=(None, 1))
|
||||||
|
feat = Input(shape=(None, nb_used_features))
|
||||||
|
|
||||||
|
rep = Lambda(lambda x: K.repeat_elements(x, 160, 1))
|
||||||
|
|
||||||
rnn = CuDNNGRU(rnn_units, return_sequences=True)
|
rnn = CuDNNGRU(rnn_units, return_sequences=True)
|
||||||
|
rnn_in = Concatenate()([pcm, rep(feat)])
|
||||||
md = MDense(pcm_levels, activation='softmax')
|
md = MDense(pcm_levels, activation='softmax')
|
||||||
ulaw_prob = md(rnn(pcm))
|
ulaw_prob = md(rnn(rnn_in))
|
||||||
|
|
||||||
model = Model(pcm, ulaw_prob)
|
model = Model([pcm, feat], ulaw_prob)
|
||||||
return model
|
return model
|
||||||
|
|
|
@ -24,7 +24,7 @@ model.summary()
|
||||||
pcmfile = sys.argv[1]
|
pcmfile = sys.argv[1]
|
||||||
feature_file = sys.argv[2]
|
feature_file = sys.argv[2]
|
||||||
nb_features = 54
|
nb_features = 54
|
||||||
nb_used_features = 38
|
nb_used_features = lpcnet.nb_used_features
|
||||||
feature_chunk_size = 15
|
feature_chunk_size = 15
|
||||||
pcm_chunk_size = 160*feature_chunk_size
|
pcm_chunk_size = 160*feature_chunk_size
|
||||||
|
|
||||||
|
@ -44,8 +44,8 @@ out_data = (out_data.astype('int16')+128).astype('uint8')
|
||||||
features = np.reshape(features, (nb_frames, feature_chunk_size, nb_features))
|
features = np.reshape(features, (nb_frames, feature_chunk_size, nb_features))
|
||||||
features = features[:, :, :nb_used_features]
|
features = features[:, :, :nb_used_features]
|
||||||
|
|
||||||
checkpoint = ModelCheckpoint('lpcnet1b_{epoch:02d}.h5')
|
checkpoint = ModelCheckpoint('lpcnet1c_{epoch:02d}.h5')
|
||||||
|
|
||||||
#model.load_weights('wavernn1c_01.h5')
|
#model.load_weights('wavernn1c_01.h5')
|
||||||
model.compile(optimizer=Adam(0.002, amsgrad=True, decay=2e-4), loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])
|
model.compile(optimizer=Adam(0.002, amsgrad=True, decay=2e-4), loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])
|
||||||
model.fit(in_data, out_data, batch_size=batch_size, epochs=30, validation_split=0.2, callbacks=[checkpoint])
|
model.fit([in_data, features], out_data, batch_size=batch_size, epochs=30, validation_split=0.2, callbacks=[checkpoint])
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue