adaptation flag to avoid training the sample rate network

This commit is contained in:
Jean-Marc Valin 2019-04-01 15:22:00 -04:00
parent 2a7a9fa085
commit fd1fc693aa

View file

@ -113,7 +113,7 @@ class PCMInit(Initializer):
'seed': self.seed
}
def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 38, training=False, use_gpu=True):
def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 38, training=False, use_gpu=True, adaptation=False):
pcm = Input(shape=(None, 3))
feat = Input(shape=(None, nb_used_features))
pitch = Input(shape=(None, 1))
@ -153,10 +153,11 @@ def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 38, train
gru_out2, _ = rnn2(Concatenate()([gru_out1, rep(cfeat)]))
ulaw_prob = md(gru_out2)
rnn.trainable=False
rnn2.trainable=False
md.trainable=False
embed.Trainable=False
if adaptation:
rnn.trainable=False
rnn2.trainable=False
md.trainable=False
embed.Trainable=False
model = Model([pcm, feat, pitch], ulaw_prob)
model.rnn_units1 = rnn_units1