Properly align LPC with lookahead in data loader

This commit is contained in:
Jean-Marc Valin 2022-09-19 19:03:09 -04:00
parent dd114baf4d
commit f5c251c5d5
2 changed files with 7 additions and 3 deletions

View file

@ -13,13 +13,14 @@ def lpc2rc(lpc):
return rc
class LPCNetLoader(Sequence):
def __init__(self, data, features, periods, batch_size, e2e=False):
def __init__(self, data, features, periods, batch_size, e2e=False, lookahead=2):
self.batch_size = batch_size
self.nb_batches = np.minimum(np.minimum(data.shape[0], features.shape[0]), periods.shape[0])//self.batch_size
self.data = data[:self.nb_batches*self.batch_size, :]
self.features = features[:self.nb_batches*self.batch_size, :]
self.periods = periods[:self.nb_batches*self.batch_size, :]
self.e2e = e2e
self.lookahead = lookahead
self.on_epoch_end()
def on_epoch_end(self):
@ -34,7 +35,10 @@ class LPCNetLoader(Sequence):
periods = self.periods[self.indices[index*self.batch_size:(index+1)*self.batch_size], :, :]
outputs = [out_data]
inputs = [in_data, features, periods]
lpc = self.features[self.indices[index*self.batch_size:(index+1)*self.batch_size], 2:-2, -16:]
if self.lookahead > 0:
lpc = self.features[self.indices[index*self.batch_size:(index+1)*self.batch_size], 4-self.lookahead:-self.lookahead, -16:]
else:
lpc = self.features[self.indices[index*self.batch_size:(index+1)*self.batch_size], 4:, -16:]
if self.e2e:
outputs.append(lpc2rc(lpc))
else:

View file

@ -203,7 +203,7 @@ else:
model.save_weights('{}_{}_initial.h5'.format(args.output, args.grua_size))
loader = LPCNetLoader(data, features, periods, batch_size, e2e=flag_e2e)
loader = LPCNetLoader(data, features, periods, batch_size, e2e=flag_e2e, lookahead=args.lookahead)
callbacks = [checkpoint, sparsify, grub_sparsify]
if args.logdir is not None: