adaptation flag to avoid training the sample rate network

This commit is contained in:
Jean-Marc Valin 2019-04-01 15:22:00 -04:00
parent 2a7a9fa085
commit fd1fc693aa

View file

@ -113,7 +113,7 @@ class PCMInit(Initializer):
'seed': self.seed 'seed': self.seed
} }
def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 38, training=False, use_gpu=True): def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 38, training=False, use_gpu=True, adaptation=False):
pcm = Input(shape=(None, 3)) pcm = Input(shape=(None, 3))
feat = Input(shape=(None, nb_used_features)) feat = Input(shape=(None, nb_used_features))
pitch = Input(shape=(None, 1)) pitch = Input(shape=(None, 1))
@ -153,10 +153,11 @@ def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 38, train
gru_out2, _ = rnn2(Concatenate()([gru_out1, rep(cfeat)])) gru_out2, _ = rnn2(Concatenate()([gru_out1, rep(cfeat)]))
ulaw_prob = md(gru_out2) ulaw_prob = md(gru_out2)
rnn.trainable=False if adaptation:
rnn2.trainable=False rnn.trainable=False
md.trainable=False rnn2.trainable=False
embed.Trainable=False md.trainable=False
embed.Trainable=False
model = Model([pcm, feat, pitch], ulaw_prob) model = Model([pcm, feat, pitch], ulaw_prob)
model.rnn_units1 = rnn_units1 model.rnn_units1 = rnn_units1