Adds end-to-end LPC training

Making LPC computation and prediction differentiable
This commit is contained in:
Krishna Subramani 2021-07-29 03:36:13 -04:00 committed by Jean-Marc Valin
parent cba0ecd483
commit c1532559a2
11 changed files with 357 additions and 17 deletions

View file

@ -38,6 +38,9 @@ from mdense import MDense
import numpy as np
import h5py
import sys
from tf_funcs import *
from diffembed import diff_Embed
import difflpc
frame_size = 160
pcm_bits = 8
@ -186,7 +189,7 @@ class PCMInit(Initializer):
#a[:,0] = math.sqrt(12)*np.arange(-.5*num_rows+.5,.5*num_rows-.4)/num_rows
#a[:,1] = .5*a[:,0]*a[:,0]*a[:,0]
a = a + np.reshape(math.sqrt(12)*np.arange(-.5*num_rows+.5,.5*num_rows-.4)/num_rows, (num_rows, 1))
return self.gain * a
return self.gain * a.astype("float32")
def get_config(self):
return {
@ -212,7 +215,7 @@ class WeightClip(Constraint):
constraint = WeightClip(0.992)
def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 20, training=False, adaptation=False, quantize=False):
def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 20, training=False, adaptation=False, quantize=False, flag_e2e = False):
pcm = Input(shape=(None, 3))
feat = Input(shape=(None, nb_used_features))
pitch = Input(shape=(None, 1))
@ -224,8 +227,21 @@ def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 20, train
fconv1 = Conv1D(128, 3, padding=padding, activation='tanh', name='feature_conv1')
fconv2 = Conv1D(128, 3, padding=padding, activation='tanh', name='feature_conv2')
embed = Embedding(256, embed_size, embeddings_initializer=PCMInit(), name='embed_sig')
cpcm = Reshape((-1, embed_size*3))(embed(pcm))
if not flag_e2e:
embed = Embedding(256, embed_size, embeddings_initializer=PCMInit(), name='embed_sig')
cpcm = Reshape((-1, embed_size*3))(embed(pcm))
else:
Input_extractor = Lambda(lambda x: K.expand_dims(x[0][:,:,x[1]],axis = -1))
error_calc = Lambda(lambda x: tf_l2u(tf_u2l(x[0]) - tf.roll(tf_u2l(x[1]),1,axis = 1)))
feat2lpc = difflpc.difflpc(training = training)
lpcoeffs = feat2lpc(feat)
tensor_preds = diff_pred(name = "lpc2preds")([Input_extractor([pcm,0]),lpcoeffs])
past_errors = error_calc([Input_extractor([pcm,0]),tensor_preds])
embed = diff_Embed(name='embed_sig',initializer = PCMInit())
cpcm = Concatenate()([Input_extractor([pcm,0]),tensor_preds,past_errors])
cpcm = Reshape((-1, embed_size*3))(embed(cpcm))
cpcm_decoder = Concatenate()([Input_extractor([pcm,0]),Input_extractor([pcm,1]),Input_extractor([pcm,2])])
cpcm_decoder = Reshape((-1, embed_size*3))(embed(cpcm_decoder))
pembed = Embedding(256, 64, name='embed_pitch')
cat_feat = Concatenate()([feat, Reshape((-1, 64))(pembed(pitch))])
@ -264,15 +280,22 @@ def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 20, train
md.trainable=False
embed.Trainable=False
model = Model([pcm, feat, pitch], ulaw_prob)
if not flag_e2e:
model = Model([pcm, feat, pitch], ulaw_prob)
else:
m_out = Concatenate()([tensor_preds,ulaw_prob])
model = Model([pcm, feat, pitch], m_out)
model.rnn_units1 = rnn_units1
model.rnn_units2 = rnn_units2
model.nb_used_features = nb_used_features
model.frame_size = frame_size
encoder = Model([feat, pitch], cfeat)
dec_rnn_in = Concatenate()([cpcm, dec_feat])
if not flag_e2e:
encoder = Model([feat, pitch], cfeat)
dec_rnn_in = Concatenate()([cpcm, dec_feat])
else:
encoder = Model([feat, pitch], [cfeat,lpcoeffs])
dec_rnn_in = Concatenate()([cpcm_decoder, dec_feat])
dec_gru_out1, state1 = rnn(dec_rnn_in, initial_state=dec_state1)
dec_gru_out2, state2 = rnn2(Concatenate()([dec_gru_out1, dec_feat]), initial_state=dec_state2)
dec_ulaw_prob = Lambda(tree_to_pdf_infer)(md(dec_gru_out2))