Dumping 16-bit linear training data

This commit is contained in:
Jean-Marc Valin 2021-10-13 23:18:57 -04:00
parent a3ef596822
commit 144b7311bc
7 changed files with 58 additions and 65 deletions

View file

@ -230,8 +230,9 @@ class WeightClip(Constraint):
constraint = WeightClip(0.992)
def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features=20, batch_size=128, training=False, adaptation=False, quantize=False, flag_e2e = False, cond_size=128):
pcm = Input(shape=(None, 3), batch_size=batch_size)
def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features=20, batch_size=128, training=False, adaptation=False, quantize=False, flag_e2e = False, cond_size=128, lpc_order=16):
pcm = Input(shape=(None, 1), batch_size=batch_size)
dpcm = Input(shape=(None, 3), batch_size=batch_size)
feat = Input(shape=(None, nb_used_features), batch_size=batch_size)
pitch = Input(shape=(None, 1), batch_size=batch_size)
dec_feat = Input(shape=(None, cond_size))
@ -257,20 +258,19 @@ def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features=20, batch_s
cfeat = fdense2(fdense1(cfeat))
if not flag_e2e:
embed = Embedding(256, embed_size, embeddings_initializer=PCMInit(), name='embed_sig')
cpcm = Reshape((-1, embed_size*3))(embed(pcm))
else:
Input_extractor = Lambda(lambda x: K.expand_dims(x[0][:,:,x[1]],axis = -1))
error_calc = Lambda(lambda x: tf_l2u(tf_u2l(x[0]) - tf.roll(tf_u2l(x[1]),1,axis = 1)))
Input_extractor = Lambda(lambda x: K.expand_dims(x[0][:,:,x[1]],axis = -1))
error_calc = Lambda(lambda x: tf_l2u(x[0] - tf.roll(x[1],1,axis = 1)))
if flag_e2e:
lpcoeffs = diff_rc2lpc(name = "rc2lpc")(cfeat)
tensor_preds = diff_pred(name = "lpc2preds")([Input_extractor([pcm,0]),lpcoeffs])
past_errors = error_calc([Input_extractor([pcm,0]),tensor_preds])
embed = diff_Embed(name='embed_sig',initializer = PCMInit())
cpcm = Concatenate()([Input_extractor([pcm,0]),tensor_preds,past_errors])
cpcm = Reshape((-1, embed_size*3))(embed(cpcm))
cpcm_decoder = Concatenate()([Input_extractor([pcm,0]),Input_extractor([pcm,1]),Input_extractor([pcm,2])])
cpcm_decoder = Reshape((-1, embed_size*3))(embed(cpcm_decoder))
else:
lpcoeffs = Input(shape=(None, lpc_order), batch_size=batch_size)
tensor_preds = diff_pred(name = "lpc2preds")([Input_extractor([pcm,0]),lpcoeffs])
past_errors = error_calc([Input_extractor([pcm,0]),tensor_preds])
embed = diff_Embed(name='embed_sig',initializer = PCMInit())
cpcm = Concatenate()([tf_l2u(Input_extractor([pcm,0])),tf_l2u(tensor_preds),past_errors])
cpcm = Reshape((-1, embed_size*3))(embed(cpcm))
cpcm_decoder = Concatenate()([Input_extractor([dpcm,0]),Input_extractor([dpcm,1]),Input_extractor([dpcm,2])])
cpcm_decoder = Reshape((-1, embed_size*3))(embed(cpcm_decoder))
rep = Lambda(lambda x: K.repeat_elements(x, frame_size, 1))
@ -301,10 +301,10 @@ def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features=20, batch_s
md.trainable=False
embed.Trainable=False
m_out = Concatenate(name='pdf')([tensor_preds,ulaw_prob])
if not flag_e2e:
model = Model([pcm, feat, pitch], ulaw_prob)
model = Model([pcm, feat, pitch, lpcoeffs], m_out)
else:
m_out = Concatenate(name='pdf')([tensor_preds,ulaw_prob])
model = Model([pcm, feat, pitch], [m_out, cfeat])
model.rnn_units1 = rnn_units1
model.rnn_units2 = rnn_units2
@ -321,5 +321,8 @@ def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features=20, batch_s
dec_gru_out2, state2 = rnn2(Concatenate()([dec_gru_out1, dec_feat]), initial_state=dec_state2)
dec_ulaw_prob = Lambda(tree_to_pdf_infer)(md(dec_gru_out2))
decoder = Model([pcm, dec_feat, dec_state1, dec_state2], [dec_ulaw_prob, state1, state2])
if flag_e2e:
decoder = Model([dpcm, dec_feat, dec_state1, dec_state2], [dec_ulaw_prob, state1, state2])
else:
decoder = Model([pcm, dec_feat, dec_state1, dec_state2, lpcoeffs], [dec_ulaw_prob, state1, state2])
return model, encoder, decoder