Dumping 16-bit linear training data
commit 144b7311bc
parent a3ef596822
7 changed files with 58 additions and 65 deletions
@@ -230,8 +230,9 @@ class WeightClip(Constraint):
 
 constraint = WeightClip(0.992)
 
-def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features=20, batch_size=128, training=False, adaptation=False, quantize=False, flag_e2e = False, cond_size=128):
-    pcm = Input(shape=(None, 3), batch_size=batch_size)
+def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features=20, batch_size=128, training=False, adaptation=False, quantize=False, flag_e2e = False, cond_size=128, lpc_order=16):
+    pcm = Input(shape=(None, 1), batch_size=batch_size)
+    dpcm = Input(shape=(None, 3), batch_size=batch_size)
     feat = Input(shape=(None, nb_used_features), batch_size=batch_size)
     pitch = Input(shape=(None, 1), batch_size=batch_size)
     dec_feat = Input(shape=(None, cond_size))
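The signature change gives the model a single-channel pcm input carrying 16-bit linear samples (matching the commit title), moves the old 3-channel signal to a separate dpcm input for the decoder path, and adds lpc_order for the per-frame LPC input introduced below. A minimal sketch of how the updated builder might be driven for non-e2e training follows; the zero arrays, the 2400-sample/15-frame split (assuming the usual frame_size of 160) and the predict call are illustrative assumptions, not code from the commit.

import numpy as np

# Hypothetical non-e2e training-style call of the new signature; all shapes
# and contents here are placeholders, assuming frame_size == 160.
model, encoder, decoder = new_lpcnet_model(batch_size=128, lpc_order=16)

pcm      = np.zeros((128, 2400, 1), dtype='float32')  # 16-bit linear samples, one channel
feat     = np.zeros((128, 15, 20),  dtype='float32')  # nb_used_features per frame
pitch    = np.zeros((128, 15, 1),   dtype='float32')  # pitch parameter per frame
lpcoeffs = np.zeros((128, 15, 16),  dtype='float32')  # lpc_order coefficients per frame

out = model.predict([pcm, feat, pitch, lpcoeffs])     # m_out: predictions + probabilities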
@@ -257,20 +258,19 @@ def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features=20, batch_s
 
     cfeat = fdense2(fdense1(cfeat))
 
-    if not flag_e2e:
-        embed = Embedding(256, embed_size, embeddings_initializer=PCMInit(), name='embed_sig')
-        cpcm = Reshape((-1, embed_size*3))(embed(pcm))
-    else:
-        Input_extractor = Lambda(lambda x: K.expand_dims(x[0][:,:,x[1]],axis = -1))
-        error_calc = Lambda(lambda x: tf_l2u(tf_u2l(x[0]) - tf.roll(tf_u2l(x[1]),1,axis = 1)))
+    Input_extractor = Lambda(lambda x: K.expand_dims(x[0][:,:,x[1]],axis = -1))
+    error_calc = Lambda(lambda x: tf_l2u(x[0] - tf.roll(x[1],1,axis = 1)))
+    if flag_e2e:
         lpcoeffs = diff_rc2lpc(name = "rc2lpc")(cfeat)
-        tensor_preds = diff_pred(name = "lpc2preds")([Input_extractor([pcm,0]),lpcoeffs])
-        past_errors = error_calc([Input_extractor([pcm,0]),tensor_preds])
-        embed = diff_Embed(name='embed_sig',initializer = PCMInit())
-        cpcm = Concatenate()([Input_extractor([pcm,0]),tensor_preds,past_errors])
-        cpcm = Reshape((-1, embed_size*3))(embed(cpcm))
-        cpcm_decoder = Concatenate()([Input_extractor([pcm,0]),Input_extractor([pcm,1]),Input_extractor([pcm,2])])
-        cpcm_decoder = Reshape((-1, embed_size*3))(embed(cpcm_decoder))
+    else:
+        lpcoeffs = Input(shape=(None, lpc_order), batch_size=batch_size)
+    tensor_preds = diff_pred(name = "lpc2preds")([Input_extractor([pcm,0]),lpcoeffs])
+    past_errors = error_calc([Input_extractor([pcm,0]),tensor_preds])
+    embed = diff_Embed(name='embed_sig',initializer = PCMInit())
+    cpcm = Concatenate()([tf_l2u(Input_extractor([pcm,0])),tf_l2u(tensor_preds),past_errors])
+    cpcm = Reshape((-1, embed_size*3))(embed(cpcm))
+    cpcm_decoder = Concatenate()([Input_extractor([dpcm,0]),Input_extractor([dpcm,1]),Input_extractor([dpcm,2])])
+    cpcm_decoder = Reshape((-1, embed_size*3))(embed(cpcm_decoder))
 
 
     rep = Lambda(lambda x: K.repeat_elements(x, frame_size, 1))
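The e2e-only helpers are now unconditional, and error_calc drops the tf_u2l decode: the signal and the LPC prediction are treated as linear PCM, the prediction is rolled by one step along the time axis, and only the resulting residual is companded back with tf_l2u (tf_l2u is likewise applied explicitly when assembling cpcm). The lpcoeffs themselves come from diff_rc2lpc only in the e2e case and from a new Input otherwise. Below is a rough numpy paraphrase of what the two Lambda helpers compute; it is my reading of the code, with lin2ulaw passed in as a stand-in for the project's tf_l2u.

import numpy as np

def input_extractor(x, c):
    # x: (batch, time, channels) -> keep channel c as (batch, time, 1)
    return x[:, :, c][:, :, np.newaxis]

def error_calc(sig, pred, lin2ulaw):
    # New behaviour: operands are already linear, so there is no u-law decode;
    # subtract the one-step-rolled prediction and re-compand the residual.
    return lin2ulaw(sig - np.roll(pred, 1, axis=1))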
@@ -301,10 +301,10 @@ def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features=20, batch_s
         md.trainable=False
         embed.Trainable=False
 
+    m_out = Concatenate(name='pdf')([tensor_preds,ulaw_prob])
     if not flag_e2e:
-        model = Model([pcm, feat, pitch], ulaw_prob)
+        model = Model([pcm, feat, pitch, lpcoeffs], m_out)
     else:
-        m_out = Concatenate(name='pdf')([tensor_preds,ulaw_prob])
         model = Model([pcm, feat, pitch], [m_out, cfeat])
     model.rnn_units1 = rnn_units1
     model.rnn_units2 = rnn_units2
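With m_out hoisted out of the else branch, both variants now expose the LPC predictions next to the mu-law distribution on the last axis, and the non-e2e model additionally lists lpcoeffs among its inputs. A loss or a caller can split that concatenated output again; the sketch below assumes one prediction channel and a 256-bin distribution, inferred from tensor_preds and ulaw_prob rather than stated in this hunk, and the helper name is mine.

def split_pdf_output(m_out):
    # m_out: (batch, time, 1 + 256), i.e. [tensor_preds | ulaw_prob]
    preds = m_out[..., :1]   # per-sample LPC prediction
    probs = m_out[..., 1:]   # 256-way mu-law output distribution
    return preds, probs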
@@ -321,5 +321,8 @@ def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features=20, batch_s
     dec_gru_out2, state2 = rnn2(Concatenate()([dec_gru_out1, dec_feat]), initial_state=dec_state2)
     dec_ulaw_prob = Lambda(tree_to_pdf_infer)(md(dec_gru_out2))
 
-    decoder = Model([pcm, dec_feat, dec_state1, dec_state2], [dec_ulaw_prob, state1, state2])
+    if flag_e2e:
+        decoder = Model([dpcm, dec_feat, dec_state1, dec_state2], [dec_ulaw_prob, state1, state2])
+    else:
+        decoder = Model([pcm, dec_feat, dec_state1, dec_state2, lpcoeffs], [dec_ulaw_prob, state1, state2])
     return model, encoder, decoder
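On the decoder side, the e2e variant keeps the 3-channel dpcm input while the non-e2e variant now also receives the frame's LPCs. A hedged single-step usage sketch for the non-e2e decoder follows; it assumes the network was built with batch_size=1 for inference and the default rnn_units1=384, rnn_units2=16, cond_size=128 and lpc_order=16, none of which is shown in this hunk.

import numpy as np

# One synthesis step with the non-e2e decoder; every shape here is an assumption.
state1 = np.zeros((1, 384), dtype='float32')           # first GRU state (rnn_units1)
state2 = np.zeros((1, 16),  dtype='float32')           # second GRU state (rnn_units2)
pcm_in = np.zeros((1, 1, 1),   dtype='float32')        # previous linear sample
cond   = np.zeros((1, 1, 128), dtype='float32')        # conditioning vector (cond_size)
lpcs   = np.zeros((1, 1, 16),  dtype='float32')        # LPCs for the current frame

probs, state1, state2 = decoder.predict([pcm_in, cond, state1, state2, lpcs])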