Adds end-to-end LPC training

Making LPC computation and prediction differentiable
Krishna Subramani 2021-07-29 03:36:13 -04:00 committed by Jean-Marc Valin
parent cba0ecd483
commit c1532559a2
11 changed files with 357 additions and 17 deletions

View file

@@ -92,7 +92,11 @@ void write_audio(LPCNetEncState *st, const short *pcm, const int *noise, FILE *f
/* Excitation in. */
data[4*i+2] = st->exc_mem;
/* Excitation out. */
#ifdef END2END
data[4*i+3] = lin2ulaw(pcm[k*FRAME_SIZE+i]);
#else
data[4*i+3] = e;
#endif
/* Simulate error on excitation. */
e += noise[k*FRAME_SIZE+i];
e = IMIN(255, IMAX(0, e));
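Under END2END the stored target is the µ-law code of the clean PCM sample rather than the pre-computed, quantized excitation e, since the excitation is now derived inside the differentiable model. As a rough NumPy illustration of the companding that lin2ulaw() performs (a sketch only; the project's ulaw.py may differ in rounding details):

    import numpy as np

    SCALE = 255.0 / 32768.0   # same companding constant used in tf_funcs.py below

    def lin2ulaw_sketch(x):
        # Map 16-bit linear samples to 8-bit mu-law codes in [0, 255].
        s = np.sign(x)
        u = s * 128.0 * np.log(1.0 + SCALE * np.abs(x)) / np.log(256.0)
        return np.clip(128.0 + np.round(u), 0, 255).astype(np.int32)

    print(lin2ulaw_sketch(np.array([0, 1000, -1000, 32767])))  # roughly [128 178 78 255]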

View file

@@ -131,6 +131,51 @@ LPCNET_EXPORT void lpcnet_destroy(LPCNetState *lpcnet)
free(lpcnet);
}
#ifdef END2END
void rc2lpc(float *lpc, const float *rc)
{
float tmp[LPC_ORDER];
float ntmp[LPC_ORDER] = {0.0};
RNN_COPY(tmp, rc, LPC_ORDER);
for(int i = 0; i < LPC_ORDER ; i++)
{
for(int j = 0; j <= i-1; j++)
{
ntmp[j] = tmp[j] + tmp[i]*tmp[i - j - 1];
}
for(int k = 0; k <= i-1; k++)
{
tmp[k] = ntmp[k];
}
}
for(int i = 0; i < LPC_ORDER ; i++)
{
lpc[i] = tmp[i];
}
}
void lpc_from_features(LPCNetState *lpcnet,const float *features)
{
NNetState *net;
float in[NB_FEATURES];
float conv1_out[F2RC_CONV1_OUT_SIZE];
float conv2_out[F2RC_CONV2_OUT_SIZE];
float dense1_out[F2RC_DENSE3_OUT_SIZE];
float rc[LPC_ORDER];
net = &lpcnet->nnet;
RNN_COPY(in, features, NB_FEATURES);
compute_conv1d(&f2rc_conv1, conv1_out, net->f2rc_conv1_state, in);
if (lpcnet->frame_count < F2RC_CONV1_DELAY + 1) RNN_CLEAR(conv1_out, F2RC_CONV1_OUT_SIZE);
compute_conv1d(&f2rc_conv2, conv2_out, net->f2rc_conv2_state, conv1_out);
if (lpcnet->frame_count < (FEATURES_DELAY_F2RC + 1)) RNN_CLEAR(conv2_out, F2RC_CONV2_OUT_SIZE);
memmove(lpcnet->old_input_f2rc[1], lpcnet->old_input_f2rc[0], (FEATURES_DELAY_F2RC-1)*NB_FEATURES*sizeof(in[0]));
memcpy(lpcnet->old_input_f2rc[0], in, NB_FEATURES*sizeof(in[0]));
compute_dense(&f2rc_dense3, dense1_out, conv2_out);
compute_dense(&f2rc_dense4_outp_rc, rc, dense1_out);
rc2lpc(lpcnet->old_lpc[0], rc);
}
#endif
LPCNET_EXPORT void lpcnet_synthesize(LPCNetState *lpcnet, const float *features, short *output, int N)
{
int i;
@@ -144,9 +189,15 @@ LPCNET_EXPORT void lpcnet_synthesize(LPCNetState *lpcnet, const float *features,
memmove(&lpcnet->old_gain[1], &lpcnet->old_gain[0], (FEATURES_DELAY-1)*sizeof(lpcnet->old_gain[0]));
lpcnet->old_gain[0] = features[PITCH_GAIN_FEATURE];
run_frame_network(lpcnet, gru_a_condition, gru_b_condition, features, pitch);
#ifdef END2END
lpc_from_features(lpcnet,features);
memcpy(lpc, lpcnet->old_lpc[0], LPC_ORDER*sizeof(lpc[0]));
#else
memcpy(lpc, lpcnet->old_lpc[FEATURES_DELAY-1], LPC_ORDER*sizeof(lpc[0]));
memmove(lpcnet->old_lpc[1], lpcnet->old_lpc[0], (FEATURES_DELAY-1)*LPC_ORDER*sizeof(lpc[0]));
lpc_from_cepstrum(lpcnet->old_lpc[0], features);
#endif
if (lpcnet->frame_count <= FEATURES_DELAY)
{
RNN_CLEAR(output, N);
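The rc2lpc() routine above is the Levinson-Durbin "step-up" recursion that turns the network's tanh-bounded reflection coefficients into direct-form LP coefficients; lpc_from_features() runs the small feature-to-RC network and applies it once per frame. A minimal NumPy sketch of the same recursion (illustrative only, not the project's code):

    import numpy as np

    def rc2lpc_sketch(rc):
        # Step-up recursion: reflection coefficients -> LP coefficients.
        lpc = np.array(rc, dtype=np.float64)
        for i in range(len(lpc)):
            prev = lpc[:i].copy()
            lpc[:i] = prev + lpc[i] * prev[::-1]   # mirrors ntmp[j] = tmp[j] + tmp[i]*tmp[i-j-1]
        return lpc

    print(rc2lpc_sketch([0.5, -0.3, 0.1, 0.05]))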

View file

@@ -22,11 +22,18 @@
#define FEATURES_DELAY (FEATURE_CONV1_DELAY + FEATURE_CONV2_DELAY)
#ifdef END2END
#define FEATURES_DELAY_F2RC (F2RC_CONV1_DELAY + F2RC_CONV2_DELAY)
#endif
struct LPCNetState {
NNetState nnet;
int last_exc;
float last_sig[LPC_ORDER];
float old_input[FEATURES_DELAY][FEATURE_CONV2_OUT_SIZE];
#ifdef END2END
float old_input_f2rc[FEATURES_DELAY_F2RC][F2RC_CONV2_OUT_SIZE];
#endif
float old_lpc[FEATURES_DELAY][LPC_ORDER];
float old_gain[FEATURES_DELAY];
float sampling_logit_table[256];

View file

@@ -0,0 +1,49 @@
"""
Modification of Tensorflow's Embedding Layer:
1. Not restricted to be the first layer of a model
2. Differentiable (allows non-integer lookups)
- For a non-integer lookup, this layer linearly interpolates between the adjacent embeddings as follows to preserve gradient flow
- E = (1 - frac(x))*embed(floor(x)) + frac(x)*embed(ceil(x))
"""
import tensorflow as tf
from tensorflow.keras.layers import Layer
class diff_Embed(Layer):
"""
Parameters:
- units: int
Dimension of the Embedding
- dict_size: int
Number of Embeddings to lookup
- pcm_init: boolean
Whether to use the provided initializer for the embedding matrix
"""
def __init__(self, units=128, dict_size = 256, pcm_init = True, initializer = None, **kwargs):
super(diff_Embed, self).__init__(**kwargs)
self.units = units
self.dict_size = dict_size
self.pcm_init = pcm_init
self.initializer = initializer
def build(self, input_shape):
w_init = tf.random_normal_initializer()
if self.pcm_init:
w_init = self.initializer
self.w = tf.Variable(initial_value=w_init(shape=(self.dict_size, self.units),dtype='float32'),trainable=True)
def call(self, inputs):
alpha = inputs - tf.math.floor(inputs)
alpha = tf.expand_dims(alpha,axis = -1)
alpha = tf.tile(alpha,[1,1,1,self.units])
inputs = tf.cast(inputs,'int32')
M = (1 - alpha)*tf.gather(self.w,inputs) + alpha*tf.gather(self.w,tf.clip_by_value(inputs + 1, 0, 255))
return M
def get_config(self):
config = super(diff_Embed, self).get_config()
config.update({"units": self.units})
config.update({"dict_size" : self.dict_size})
config.update({"pcm_init" : self.pcm_init})
config.update({"initializer" : self.initializer})
return config
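A small eager-mode sketch of the interpolation diff_Embed implements (shapes simplified to a single lookup per position; assumes TensorFlow 2.x): the gradient with respect to a fractional code flows through the frac(x) term, which is what makes the lookup differentiable.

    import tensorflow as tf

    W = tf.Variable(tf.random.normal((256, 4)))   # toy embedding table
    x = tf.constant([[12.3, 200.7]])              # fractional mu-law codes

    with tf.GradientTape() as tape:
        tape.watch(x)
        alpha = x - tf.math.floor(x)              # frac(x)
        idx = tf.cast(x, 'int32')                 # floor(x) for x >= 0
        e = (1 - alpha[..., None]) * tf.gather(W, idx) \
            + alpha[..., None] * tf.gather(W, tf.clip_by_value(idx + 1, 0, 255))
        y = tf.reduce_sum(e)

    print(tape.gradient(y, x))   # per element: sum over dims of W[idx+1] - W[idx], generally non-zero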

View file

@@ -0,0 +1,27 @@
"""
TensorFlow model (differentiable LPC) that learns the LPCs from the features
"""
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Concatenate, Lambda, Conv1D, Multiply, Layer, LeakyReLU
from tensorflow.keras import backend as K
from tf_funcs import diff_rc2lpc
frame_size = 160
lpcoeffs_N = 16
def difflpc(nb_used_features = 20, training=False):
feat = Input(shape=(None, nb_used_features)) # BFCC
padding = 'valid' if training else 'same'
L1 = Conv1D(100, 3, padding=padding, activation='tanh', name='f2rc_conv1')
L2 = Conv1D(75, 3, padding=padding, activation='tanh', name='f2rc_conv2')
L3 = Dense(50, activation='tanh',name = 'f2rc_dense3')
L4 = Dense(lpcoeffs_N, activation='tanh',name = "f2rc_dense4_outp_rc")
rc = L4(L3(L2(L1(feat))))
# Differentiable RC 2 LPC
lpcoeffs = diff_rc2lpc(name = "rc2lpc")(rc)
model = Model(feat,lpcoeffs,name = 'f2lpc')
model.nb_used_features = nb_used_features
model.frame_size = frame_size
return model
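An illustrative use of this sub-model (assuming the training_tf2 files from this commit, including tf_funcs.py, are importable): it maps 20-dimensional feature frames to 16 LP coefficients per frame.

    import numpy as np
    import difflpc

    f2lpc = difflpc.difflpc(training=False)        # 'same' padding, so any sequence length works
    feats = np.random.randn(1, 15, 20).astype('float32')
    lpcs = f2lpc.predict(feats)
    print(lpcs.shape)                              # expected (1, 15, 16)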

View file

@@ -35,6 +35,9 @@ from mdense import MDense
import h5py
import re
# Flag for dumping e2e (differentiable lpc) network weights
flag_e2e = False
max_rnn_neurons = 1
max_conv_inputs = 1
max_mdense_tmp = 1
@@ -237,7 +240,7 @@ with h5py.File(filename, "r") as f:
units = min(f['model_weights']['gru_a']['gru_a']['recurrent_kernel:0'].shape)
units2 = min(f['model_weights']['gru_b']['gru_b']['recurrent_kernel:0'].shape)
model, _, _ = lpcnet.new_lpcnet_model(rnn_units1=units, rnn_units2=units2)
model, _, _ = lpcnet.new_lpcnet_model(rnn_units1=units, rnn_units2=units2, flag_e2e = flag_e2e)
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])
#model.summary()
@@ -288,6 +291,12 @@ for i, layer in enumerate(model.layers):
if layer.dump_layer(f, hf):
layer_list.append(layer.name)
if flag_e2e:
print("-- Weight Dumping for the Differentiable LPC Block --")
for i, layer in enumerate(model.get_layer("f2lpc").layers):
if layer.dump_layer(f, hf):
layer_list.append(layer.name)
dump_sparse_gru(model.get_layer('gru_a'), f, hf)
hf.write('#define MAX_RNN_NEURONS {}\n\n'.format(max_rnn_neurons))

View file

@@ -0,0 +1,85 @@
"""
Custom Loss functions and metrics for training/analysis
"""
from tf_funcs import *
import tensorflow as tf
# The following loss functions all expect the lpcnet model to output the lpc prediction
# The excitation is computed by subtracting the LP prediction from the target, and the cross entropy of the predicted distribution is then minimized on that excitation
def res_from_sigloss():
def loss(y_true,y_pred):
p = y_pred[:,:,0:1]
model_out = y_pred[:,:,1:]
e_gt = tf_l2u(tf_u2l(y_true) - tf_u2l(p))
e_gt = tf.round(e_gt)
e_gt = tf.cast(e_gt,'int32')
sparse_cel = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)(e_gt,model_out)
return sparse_cel
return loss
# Interpolated and compensated loss (for end-to-end LPCNet)
# Interpolates between adjacent output probabilities based on the fractional part of the computed excitation (mirroring the embedding interpolation)
# Also adds a probability compensation (to account for matching the cross entropy in the linear domain), weighted by gamma
def interp_mulaw(gamma = 1):
def loss(y_true,y_pred):
p = y_pred[:,:,0:1]
model_out = y_pred[:,:,1:]
e_gt = tf_l2u(tf_u2l(y_true) - tf_u2l(p))
prob_compensation = tf.squeeze((K.abs(e_gt - 128)/128.0)*K.log(256.0))
alpha = e_gt - tf.math.floor(e_gt)
alpha = tf.tile(alpha,[1,1,256])
e_gt = tf.cast(e_gt,'int32')
e_gt = tf.clip_by_value(e_gt,0,254)
interp_probab = (1 - alpha)*model_out + alpha*tf.roll(model_out,shift = -1,axis = -1)
sparse_cel = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)(e_gt,interp_probab)
loss_mod = sparse_cel + gamma*prob_compensation
return loss_mod
return loss
# Same as above, but used as a metric (compensation applied with unit weight)
def metric_oginterploss(y_true,y_pred):
p = y_pred[:,:,0:1]
model_out = y_pred[:,:,1:]
e_gt = tf_l2u(tf_u2l(y_true) - tf_u2l(p))
prob_compensation = tf.squeeze((K.abs(e_gt - 128)/128.0)*K.log(256.0))
alpha = e_gt - tf.math.floor(e_gt)
alpha = tf.tile(alpha,[1,1,256])
e_gt = tf.cast(e_gt,'int32')
e_gt = tf.clip_by_value(e_gt,0,254)
interp_probab = (1 - alpha)*model_out + alpha*tf.roll(model_out,shift = -1,axis = -1)
sparse_cel = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)(e_gt,interp_probab)
loss_mod = sparse_cel + prob_compensation
return loss_mod
# Interpolated cross entropy loss metric
def metric_icel(y_true, y_pred):
p = y_pred[:,:,0:1]
model_out = y_pred[:,:,1:]
e_gt = tf_l2u(tf_u2l(y_true) - tf_u2l(p))
alpha = e_gt - tf.math.floor(e_gt)
alpha = tf.tile(alpha,[1,1,256])
e_gt = tf.cast(e_gt,'int32')
e_gt = tf.clip_by_value(e_gt,0,254) #Check direction
interp_probab = (1 - alpha)*model_out + alpha*tf.roll(model_out,shift = -1,axis = -1)
sparse_cel = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)(e_gt,interp_probab)
return sparse_cel
# Non-interpolated (rounded) cross entropy loss metric
def metric_cel(y_true, y_pred):
p = y_pred[:,:,0:1]
model_out = y_pred[:,:,1:]
e_gt = tf_l2u(tf_u2l(y_true) - tf_u2l(p))
e_gt = tf.round(e_gt)
e_gt = tf.cast(e_gt,'int32')
e_gt = tf.clip_by_value(e_gt,0,255)
sparse_cel = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)(e_gt,model_out)
return sparse_cel
# Variance metric of the output excitation
def metric_exc_sd(y_true,y_pred):
p = y_pred[:,:,0:1]
e_gt = tf_l2u(tf_u2l(y_true) - tf_u2l(p))
sd_egt = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)(e_gt,128)
return sd_egt
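A sketch of how these functions consume the end-to-end model output (assuming this lossfuncs.py is importable): channel 0 of y_pred is the LP prediction from the 'lpc2preds' layer and the remaining 256 channels are the µ-law probabilities, i.e. the concatenation the e2e model returns.

    import tensorflow as tf
    from lossfuncs import interp_mulaw, metric_cel

    y_true = tf.random.uniform((1, 8, 1), 0.0, 255.0)            # mu-law target samples
    pred   = tf.random.uniform((1, 8, 1), 0.0, 255.0)            # mu-law LP prediction
    probs  = tf.nn.softmax(tf.random.normal((1, 8, 256)), axis=-1)
    y_pred = tf.concat([pred, probs], axis=-1)                   # (1, 8, 257)

    print(interp_mulaw(gamma=2)(y_true, y_pred).shape)           # per-sample losses, expected (1, 8)
    print(metric_cel(y_true, y_pred).shape)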

View file

@@ -38,6 +38,9 @@ from mdense import MDense
import numpy as np
import h5py
import sys
from tf_funcs import *
from diffembed import diff_Embed
import difflpc
frame_size = 160
pcm_bits = 8
@@ -186,7 +189,7 @@ class PCMInit(Initializer):
#a[:,0] = math.sqrt(12)*np.arange(-.5*num_rows+.5,.5*num_rows-.4)/num_rows
#a[:,1] = .5*a[:,0]*a[:,0]*a[:,0]
a = a + np.reshape(math.sqrt(12)*np.arange(-.5*num_rows+.5,.5*num_rows-.4)/num_rows, (num_rows, 1))
return self.gain * a
return self.gain * a.astype("float32")
def get_config(self):
return {
@@ -212,7 +215,7 @@ class WeightClip(Constraint):
constraint = WeightClip(0.992)
def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 20, training=False, adaptation=False, quantize=False):
def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 20, training=False, adaptation=False, quantize=False, flag_e2e = False):
pcm = Input(shape=(None, 3))
feat = Input(shape=(None, nb_used_features))
pitch = Input(shape=(None, 1))
@@ -224,8 +227,21 @@ def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 20, train
fconv1 = Conv1D(128, 3, padding=padding, activation='tanh', name='feature_conv1')
fconv2 = Conv1D(128, 3, padding=padding, activation='tanh', name='feature_conv2')
if not flag_e2e:
embed = Embedding(256, embed_size, embeddings_initializer=PCMInit(), name='embed_sig')
cpcm = Reshape((-1, embed_size*3))(embed(pcm))
else:
Input_extractor = Lambda(lambda x: K.expand_dims(x[0][:,:,x[1]],axis = -1))
error_calc = Lambda(lambda x: tf_l2u(tf_u2l(x[0]) - tf.roll(tf_u2l(x[1]),1,axis = 1)))
feat2lpc = difflpc.difflpc(training = training)
lpcoeffs = feat2lpc(feat)
tensor_preds = diff_pred(name = "lpc2preds")([Input_extractor([pcm,0]),lpcoeffs])
past_errors = error_calc([Input_extractor([pcm,0]),tensor_preds])
embed = diff_Embed(name='embed_sig',initializer = PCMInit())
cpcm = Concatenate()([Input_extractor([pcm,0]),tensor_preds,past_errors])
cpcm = Reshape((-1, embed_size*3))(embed(cpcm))
cpcm_decoder = Concatenate()([Input_extractor([pcm,0]),Input_extractor([pcm,1]),Input_extractor([pcm,2])])
cpcm_decoder = Reshape((-1, embed_size*3))(embed(cpcm_decoder))
pembed = Embedding(256, 64, name='embed_pitch')
cat_feat = Concatenate()([feat, Reshape((-1, 64))(pembed(pitch))])
@@ -264,15 +280,22 @@ def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 20, train
md.trainable=False
embed.Trainable=False
if not flag_e2e:
model = Model([pcm, feat, pitch], ulaw_prob)
else:
m_out = Concatenate()([tensor_preds,ulaw_prob])
model = Model([pcm, feat, pitch], m_out)
model.rnn_units1 = rnn_units1
model.rnn_units2 = rnn_units2
model.nb_used_features = nb_used_features
model.frame_size = frame_size
if not flag_e2e:
encoder = Model([feat, pitch], cfeat)
dec_rnn_in = Concatenate()([cpcm, dec_feat])
else:
encoder = Model([feat, pitch], [cfeat,lpcoeffs])
dec_rnn_in = Concatenate()([cpcm_decoder, dec_feat])
dec_gru_out1, state1 = rnn(dec_rnn_in, initial_state=dec_state1)
dec_gru_out2, state2 = rnn2(Concatenate()([dec_gru_out1, dec_feat]), initial_state=dec_state2)
dec_ulaw_prob = Lambda(tree_to_pdf_infer)(md(dec_gru_out2))
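With flag_e2e set, the model output is the concatenation of tensor_preds and ulaw_prob, so it gains one leading channel relative to the plain model, and the encoder additionally returns the predicted LPCs. A quick way to see the wiring (a sketch, assuming lpcnet.py and its dependencies import cleanly):

    import lpcnet

    model, enc, dec = lpcnet.new_lpcnet_model(training=False, flag_e2e=True)
    print(model.output_shape)   # expected (None, None, 257): LP prediction + 256 mu-law probabilities
    print(enc.output_shape)     # expected [(None, None, 128), (None, None, 16)]: conditioning features and LPCs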

View file

@@ -31,8 +31,10 @@ import numpy as np
from ulaw import ulaw2lin, lin2ulaw
import h5py
# Flag for synthesizing e2e (differentiable lpc) model
flag_e2e = False
model, enc, dec = lpcnet.new_lpcnet_model()
model, enc, dec = lpcnet.new_lpcnet_model(training = False, flag_e2e = flag_e2e)
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])
#model.summary()
@@ -70,10 +72,16 @@ fout = open(out_file, 'wb')
skip = order + 1
for c in range(0, nb_frames):
if not flag_e2e:
cfeat = enc.predict([features[c:c+1, :, :nb_used_features], periods[c:c+1, :, :]])
else:
cfeat,lpcs = enc.predict([features[c:c+1, :, :nb_used_features], periods[c:c+1, :, :]])
for fr in range(0, feature_chunk_size):
f = c*feature_chunk_size + fr
if not flag_e2e:
a = features[c, fr, nb_features-order:]
else:
a = lpcs[c,fr]
for i in range(skip, frame_size):
pred = -sum(a*pcm[f*frame_size + i - 1:f*frame_size + i - order-1:-1])
fexc[0, 0, 1] = lin2ulaw(pred)

View file

@@ -0,0 +1,69 @@
"""
Tensorflow/Keras helper functions to do the following:
1. \mu law <-> Linear domain conversion
2. Differentiable prediction from the input signal and LP coefficients
3. Differentiable transformations between Reflection Coefficients (RCs) and LP Coefficients
"""
from tensorflow.keras.layers import Lambda, Multiply, Layer, Concatenate
from tensorflow.keras import backend as K
import tensorflow as tf
# \mu law <-> Linear conversion functions
scale = 255.0/32768.0
scale_1 = 32768.0/255.0
def tf_l2u(x):
s = K.sign(x)
x = K.abs(x)
u = (s*(128*K.log(1+scale*x)/K.log(256.0)))
u = K.clip(128 + u, 0, 255)
return u
def tf_u2l(u):
u = tf.cast(u,"float32")
u = u - 128.0
s = K.sign(u)
u = K.abs(u)
return s*scale_1*(K.exp(u/128.*K.log(256.0))-1)
# Differentiable Prediction Layer
# Computes the LP prediction from the input lag signal and the LP coefficients
# The inputs xt and lpc follow the shapes used in lpcnet.py (the hard-coded '2400' assumes the 2400-sample training sequences, i.e. 15 frames of 160 samples)
class diff_pred(Layer):
def call(self, inputs, lpcoeffs_N = 16, frame_size = 160):
xt = tf_u2l(inputs[0])
lpc = inputs[1]
rept = Lambda(lambda x: K.repeat_elements(x , frame_size, 1))
zpX = Lambda(lambda x: K.concatenate([0*x[:,0:lpcoeffs_N,:], x],axis = 1))
cX = Lambda(lambda x: K.concatenate([x[:,(lpcoeffs_N - i):(lpcoeffs_N - i + 2400),:] for i in range(lpcoeffs_N)],axis = 2))
pred = -Multiply()([rept(lpc),cX(zpX(xt))])
return tf_l2u(K.sum(pred,axis = 2,keepdims = True))
# Differentiable transformations (RC <-> LPC) computed using the Levinson-Durbin recursion
class diff_rc2lpc(Layer):
def call(self, inputs, lpcoeffs_N = 16):
def pred_lpc_recursive(input):
temp = (input[0] + K.repeat_elements(input[1],input[0].shape[2],2)*K.reverse(input[0],axes = 2))
temp = Concatenate(axis = 2)([temp,input[1]])
return temp
Llpc = Lambda(pred_lpc_recursive)
lpc_init = inputs
for i in range(1,lpcoeffs_N):
lpc_init = Llpc([lpc_init[:,:,:i],K.expand_dims(inputs[:,:,i],axis = -1)])
return lpc_init
class diff_lpc2rc(Layer):
def call(self, inputs, lpcoeffs_N = 16):
def pred_rc_recursive(input):
ki = K.repeat_elements(K.expand_dims(input[1][:,:,0],axis = -1),input[0].shape[2],2)
temp = (input[0] - ki*K.reverse(input[0],axes = 2))/(1 - ki*ki)
temp = Concatenate(axis = 2)([temp,input[1]])
return temp
Lrc = Lambda(pred_rc_recursive)
rc_init = inputs
for i in range(1,lpcoeffs_N):
j = (lpcoeffs_N - i + 1)
rc_init = Lrc([rc_init[:,:,:(j - 1)],rc_init[:,:,(j - 1):]])
return rc_init
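A sketch exercising the differentiable pieces above (assuming this tf_funcs.py is importable). diff_pred expects 2400 µ-law samples (15 frames of 160) plus 16 LPCs per frame, and diff_lpc2rc should approximately invert diff_rc2lpc for reflection coefficients inside (-1, 1):

    import tensorflow as tf
    from tf_funcs import diff_pred, diff_rc2lpc, diff_lpc2rc, tf_l2u

    sig  = tf_l2u(tf.random.uniform((1, 2400, 1), -8000.0, 8000.0))   # signal in the mu-law domain
    rc   = tf.random.uniform((1, 15, 16), -0.9, 0.9)                  # per-frame reflection coefficients
    lpc  = diff_rc2lpc()(rc)                                          # (1, 15, 16) LP coefficients
    pred = diff_pred()([sig, lpc])                                    # (1, 2400, 1) mu-law LP prediction

    rc_back = diff_lpc2rc()(lpc)
    print(float(tf.reduce_max(tf.abs(rc_back - rc))))                 # expected to be near 0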

View file

@@ -44,7 +44,7 @@ parser.add_argument('--grua-size', metavar='<units>', default=384, type=int, hel
parser.add_argument('--grub-size', metavar='<units>', default=16, type=int, help='number of units in GRU B (default 16)')
parser.add_argument('--epochs', metavar='<epochs>', default=120, type=int, help='number of epochs to train for (default 120)')
parser.add_argument('--batch-size', metavar='<batch size>', default=128, type=int, help='batch size to use (default 128)')
parser.add_argument('--end2end', dest='flag_e2e', action='store_true', help='Enable end-to-end training (with differentiable LPC computation)')
args = parser.parse_args()
@@ -66,12 +66,14 @@ lpcnet = importlib.import_module(args.model)
import sys
import numpy as np
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger
from ulaw import ulaw2lin, lin2ulaw
import tensorflow.keras.backend as K
import h5py
import tensorflow as tf
from tf_funcs import *
from lossfuncs import *
#gpus = tf.config.experimental.list_physical_devices('GPU')
#if gpus:
# try:
@@ -93,12 +95,17 @@ else:
lr = 0.001
decay = 2.5e-5
flag_e2e = args.flag_e2e
opt = Adam(lr, decay=decay, beta_2=0.99)
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
with strategy.scope():
model, _, _ = lpcnet.new_lpcnet_model(rnn_units1=args.grua_size, rnn_units2=args.grub_size, training=True, quantize=quantize)
model, _, _ = lpcnet.new_lpcnet_model(rnn_units1=args.grua_size, rnn_units2=args.grub_size, training=True, quantize=quantize, flag_e2e = flag_e2e)
if not flag_e2e:
model.compile(optimizer=opt, loss='sparse_categorical_crossentropy', metrics='sparse_categorical_crossentropy')
else:
model.compile(optimizer=opt, loss = interp_mulaw(gamma = 2),metrics=[metric_cel,metric_icel,metric_exc_sd,metric_oginterploss])
model.summary()
feature_file = args.features
@@ -150,4 +157,5 @@ else:
grub_sparsify = lpcnet.SparsifyGRUB(2000, 40000, 400, args.grua_size, grub_density)
model.save_weights('{}_{}_initial.h5'.format(args.output, args.grua_size))
model.fit([in_data, features, periods], out_exc, batch_size=batch_size, epochs=nb_epochs, validation_split=0.0, callbacks=[checkpoint, sparsify, grub_sparsify])
csv_logger = CSVLogger('training_vals.log')
model.fit([in_data, features, periods], out_exc, batch_size=batch_size, epochs=nb_epochs, validation_split=0.0, callbacks=[checkpoint, sparsify, grub_sparsify, csv_logger])
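End-to-end training is then selected from the command line with the new flag, e.g. train_lpcnet.py <features> <data> <output> --end2end, where the positional arguments stand in for the usual training inputs; only the --end2end flag is added by this change, and the CSVLogger callback writes the per-epoch metrics to training_vals.log.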