Sharing conditioning network with LPC

This commit is contained in:
Jean-Marc Valin 2021-07-24 18:09:20 -04:00
parent c1532559a2
commit ab9a09266f
6 changed files with 17 additions and 77 deletions

View file

@ -40,7 +40,6 @@ import h5py
import sys
from tf_funcs import *
from diffembed import diff_Embed
import difflpc
frame_size = 160
pcm_bits = 8
@ -226,6 +225,15 @@ def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 20, train
padding = 'valid' if training else 'same'
fconv1 = Conv1D(128, 3, padding=padding, activation='tanh', name='feature_conv1')
fconv2 = Conv1D(128, 3, padding=padding, activation='tanh', name='feature_conv2')
pembed = Embedding(256, 64, name='embed_pitch')
cat_feat = Concatenate()([feat, Reshape((-1, 64))(pembed(pitch))])
cfeat = fconv2(fconv1(cat_feat))
fdense1 = Dense(128, activation='tanh', name='feature_dense1')
fdense2 = Dense(128, activation='tanh', name='feature_dense2')
cfeat = fdense2(fdense1(cfeat))
if not flag_e2e:
embed = Embedding(256, embed_size, embeddings_initializer=PCMInit(), name='embed_sig')
@ -233,8 +241,7 @@ def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 20, train
else:
Input_extractor = Lambda(lambda x: K.expand_dims(x[0][:,:,x[1]],axis = -1))
error_calc = Lambda(lambda x: tf_l2u(tf_u2l(x[0]) - tf.roll(tf_u2l(x[1]),1,axis = 1)))
feat2lpc = difflpc.difflpc(training = training)
lpcoeffs = feat2lpc(feat)
lpcoeffs = diff_rc2lpc(name = "rc2lpc")(cfeat)
tensor_preds = diff_pred(name = "lpc2preds")([Input_extractor([pcm,0]),lpcoeffs])
past_errors = error_calc([Input_extractor([pcm,0]),tensor_preds])
embed = diff_Embed(name='embed_sig',initializer = PCMInit())
@ -243,15 +250,6 @@ def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 20, train
cpcm_decoder = Concatenate()([Input_extractor([pcm,0]),Input_extractor([pcm,1]),Input_extractor([pcm,2])])
cpcm_decoder = Reshape((-1, embed_size*3))(embed(cpcm_decoder))
pembed = Embedding(256, 64, name='embed_pitch')
cat_feat = Concatenate()([feat, Reshape((-1, 64))(pembed(pitch))])
cfeat = fconv2(fconv1(cat_feat))
fdense1 = Dense(128, activation='tanh', name='feature_dense1')
fdense2 = Dense(128, activation='tanh', name='feature_dense2')
cfeat = fdense2(fdense1(cfeat))
rep = Lambda(lambda x: K.repeat_elements(x, frame_size, 1))