mirror of
https://github.com/xiph/opus.git
synced 2025-05-30 07:07:42 +00:00
adaptation flag to avoid training the sample rate network
This commit is contained in:
parent
2a7a9fa085
commit
fd1fc693aa
1 changed files with 6 additions and 5 deletions
|
@ -113,7 +113,7 @@ class PCMInit(Initializer):
|
||||||
'seed': self.seed
|
'seed': self.seed
|
||||||
}
|
}
|
||||||
|
|
||||||
def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 38, training=False, use_gpu=True):
|
def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 38, training=False, use_gpu=True, adaptation=False):
|
||||||
pcm = Input(shape=(None, 3))
|
pcm = Input(shape=(None, 3))
|
||||||
feat = Input(shape=(None, nb_used_features))
|
feat = Input(shape=(None, nb_used_features))
|
||||||
pitch = Input(shape=(None, 1))
|
pitch = Input(shape=(None, 1))
|
||||||
|
@ -153,10 +153,11 @@ def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 38, train
|
||||||
gru_out2, _ = rnn2(Concatenate()([gru_out1, rep(cfeat)]))
|
gru_out2, _ = rnn2(Concatenate()([gru_out1, rep(cfeat)]))
|
||||||
ulaw_prob = md(gru_out2)
|
ulaw_prob = md(gru_out2)
|
||||||
|
|
||||||
rnn.trainable=False
|
if adaptation:
|
||||||
rnn2.trainable=False
|
rnn.trainable=False
|
||||||
md.trainable=False
|
rnn2.trainable=False
|
||||||
embed.Trainable=False
|
md.trainable=False
|
||||||
|
embed.Trainable=False
|
||||||
|
|
||||||
model = Model([pcm, feat, pitch], ulaw_prob)
|
model = Model([pcm, feat, pitch], ulaw_prob)
|
||||||
model.rnn_units1 = rnn_units1
|
model.rnn_units1 = rnn_units1
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue