Don't hardcode the number of bands

This commit is contained in:
Jean-Marc Valin 2021-10-20 17:20:32 -04:00
parent b5b1d5013e
commit a9bf6cee8a
2 changed files with 5 additions and 3 deletions

View file

@ -89,7 +89,7 @@ void run_frame_network(LPCNetState *lpcnet, float *gru_a_condition, float *gru_b
int pitch;
float rc[LPC_ORDER];
/* Matches the Python code -- the 0.1 avoids rounding issues. */
pitch = (int)floor(.1 + 50*features[18]+100);
pitch = (int)floor(.1 + 50*features[NB_BANDS]+100);
pitch = IMIN(255, IMAX(33, pitch));
net = &lpcnet->nnet;
RNN_COPY(in, features, NB_FEATURES);

View file

@ -100,6 +100,8 @@ batch_size = args.batch_size
quantize = args.quantize is not None
retrain = args.retrain is not None
lpc_order = 16
if quantize:
lr = 0.00003
decay = 0
@ -133,7 +135,7 @@ with strategy.scope():
feature_file = args.features
pcm_file = args.data # 16 bit unsigned short PCM samples
frame_size = model.frame_size
nb_features = 36
nb_features = model.nb_used_features + lpc_order
nb_used_features = model.nb_used_features
feature_chunk_size = 15
pcm_chunk_size = frame_size*feature_chunk_size
@ -160,7 +162,7 @@ features = np.lib.stride_tricks.as_strided(features, shape=(nb_frames, feature_c
#features = features[:, :, :nb_used_features]
periods = (.1 + 50*features[:,:,18:19]+100).astype('int16')
periods = (.1 + 50*features[:,:,nb_used_features-2:nb_used_features-1]+100).astype('int16')
#periods = np.minimum(periods, 255)
# dump models to disk as we go