mirror of
https://github.com/xiph/opus.git
synced 2025-05-25 20:59:13 +00:00
Adds end-to-end LPC training
Makes LPC computation and prediction differentiable.
This commit is contained in:
parent
cba0ecd483
commit
c1532559a2
11 changed files with 357 additions and 17 deletions
|
@ -38,6 +38,9 @@ from mdense import MDense
|
|||
import numpy as np
|
||||
import h5py
|
||||
import sys
|
||||
from tf_funcs import *
|
||||
from diffembed import diff_Embed
|
||||
import difflpc
|
||||
|
||||
frame_size = 160
|
||||
pcm_bits = 8
|
||||
|
@ -186,7 +189,7 @@ class PCMInit(Initializer):
|
|||
#a[:,0] = math.sqrt(12)*np.arange(-.5*num_rows+.5,.5*num_rows-.4)/num_rows
|
||||
#a[:,1] = .5*a[:,0]*a[:,0]*a[:,0]
|
||||
a = a + np.reshape(math.sqrt(12)*np.arange(-.5*num_rows+.5,.5*num_rows-.4)/num_rows, (num_rows, 1))
|
||||
return self.gain * a
|
||||
return self.gain * a.astype("float32")
|
||||
|
||||
def get_config(self):
|
||||
return {
|
||||
|
@ -212,7 +215,7 @@ class WeightClip(Constraint):
|
|||
|
||||
constraint = WeightClip(0.992)
|
||||
|
||||
def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 20, training=False, adaptation=False, quantize=False):
|
||||
def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 20, training=False, adaptation=False, quantize=False, flag_e2e = False):
|
||||
pcm = Input(shape=(None, 3))
|
||||
feat = Input(shape=(None, nb_used_features))
|
||||
pitch = Input(shape=(None, 1))
|
||||
|
@ -224,8 +227,21 @@ def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 20, train
|
|||
fconv1 = Conv1D(128, 3, padding=padding, activation='tanh', name='feature_conv1')
|
||||
fconv2 = Conv1D(128, 3, padding=padding, activation='tanh', name='feature_conv2')
|
||||
|
||||
embed = Embedding(256, embed_size, embeddings_initializer=PCMInit(), name='embed_sig')
|
||||
cpcm = Reshape((-1, embed_size*3))(embed(pcm))
|
||||
if not flag_e2e:
|
||||
embed = Embedding(256, embed_size, embeddings_initializer=PCMInit(), name='embed_sig')
|
||||
cpcm = Reshape((-1, embed_size*3))(embed(pcm))
|
||||
else:
|
||||
Input_extractor = Lambda(lambda x: K.expand_dims(x[0][:,:,x[1]],axis = -1))
|
||||
error_calc = Lambda(lambda x: tf_l2u(tf_u2l(x[0]) - tf.roll(tf_u2l(x[1]),1,axis = 1)))
|
||||
feat2lpc = difflpc.difflpc(training = training)
|
||||
lpcoeffs = feat2lpc(feat)
|
||||
tensor_preds = diff_pred(name = "lpc2preds")([Input_extractor([pcm,0]),lpcoeffs])
|
||||
past_errors = error_calc([Input_extractor([pcm,0]),tensor_preds])
|
||||
embed = diff_Embed(name='embed_sig',initializer = PCMInit())
|
||||
cpcm = Concatenate()([Input_extractor([pcm,0]),tensor_preds,past_errors])
|
||||
cpcm = Reshape((-1, embed_size*3))(embed(cpcm))
|
||||
cpcm_decoder = Concatenate()([Input_extractor([pcm,0]),Input_extractor([pcm,1]),Input_extractor([pcm,2])])
|
||||
cpcm_decoder = Reshape((-1, embed_size*3))(embed(cpcm_decoder))
|
||||
|
||||
pembed = Embedding(256, 64, name='embed_pitch')
|
||||
cat_feat = Concatenate()([feat, Reshape((-1, 64))(pembed(pitch))])
|
||||
|
@ -264,15 +280,22 @@ def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 20, train
|
|||
md.trainable=False
|
||||
embed.Trainable=False
|
||||
|
||||
model = Model([pcm, feat, pitch], ulaw_prob)
|
||||
if not flag_e2e:
|
||||
model = Model([pcm, feat, pitch], ulaw_prob)
|
||||
else:
|
||||
m_out = Concatenate()([tensor_preds,ulaw_prob])
|
||||
model = Model([pcm, feat, pitch], m_out)
|
||||
model.rnn_units1 = rnn_units1
|
||||
model.rnn_units2 = rnn_units2
|
||||
model.nb_used_features = nb_used_features
|
||||
model.frame_size = frame_size
|
||||
|
||||
encoder = Model([feat, pitch], cfeat)
|
||||
|
||||
dec_rnn_in = Concatenate()([cpcm, dec_feat])
|
||||
if not flag_e2e:
|
||||
encoder = Model([feat, pitch], cfeat)
|
||||
dec_rnn_in = Concatenate()([cpcm, dec_feat])
|
||||
else:
|
||||
encoder = Model([feat, pitch], [cfeat,lpcoeffs])
|
||||
dec_rnn_in = Concatenate()([cpcm_decoder, dec_feat])
|
||||
dec_gru_out1, state1 = rnn(dec_rnn_in, initial_state=dec_state1)
|
||||
dec_gru_out2, state2 = rnn2(Concatenate()([dec_gru_out1, dec_feat]), initial_state=dec_state2)
|
||||
dec_ulaw_prob = Lambda(tree_to_pdf_infer)(md(dec_gru_out2))
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue