mirror of
https://github.com/xiph/opus.git
synced 2025-05-29 22:57:41 +00:00
Properly align LPC with lookahead in data loader
This commit is contained in:
parent
dd114baf4d
commit
f5c251c5d5
2 changed files with 7 additions and 3 deletions
|
@ -13,13 +13,14 @@ def lpc2rc(lpc):
|
||||||
return rc
|
return rc
|
||||||
|
|
||||||
class LPCNetLoader(Sequence):
|
class LPCNetLoader(Sequence):
|
||||||
def __init__(self, data, features, periods, batch_size, e2e=False):
|
def __init__(self, data, features, periods, batch_size, e2e=False, lookahead=2):
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.nb_batches = np.minimum(np.minimum(data.shape[0], features.shape[0]), periods.shape[0])//self.batch_size
|
self.nb_batches = np.minimum(np.minimum(data.shape[0], features.shape[0]), periods.shape[0])//self.batch_size
|
||||||
self.data = data[:self.nb_batches*self.batch_size, :]
|
self.data = data[:self.nb_batches*self.batch_size, :]
|
||||||
self.features = features[:self.nb_batches*self.batch_size, :]
|
self.features = features[:self.nb_batches*self.batch_size, :]
|
||||||
self.periods = periods[:self.nb_batches*self.batch_size, :]
|
self.periods = periods[:self.nb_batches*self.batch_size, :]
|
||||||
self.e2e = e2e
|
self.e2e = e2e
|
||||||
|
self.lookahead = lookahead
|
||||||
self.on_epoch_end()
|
self.on_epoch_end()
|
||||||
|
|
||||||
def on_epoch_end(self):
|
def on_epoch_end(self):
|
||||||
|
@ -34,7 +35,10 @@ class LPCNetLoader(Sequence):
|
||||||
periods = self.periods[self.indices[index*self.batch_size:(index+1)*self.batch_size], :, :]
|
periods = self.periods[self.indices[index*self.batch_size:(index+1)*self.batch_size], :, :]
|
||||||
outputs = [out_data]
|
outputs = [out_data]
|
||||||
inputs = [in_data, features, periods]
|
inputs = [in_data, features, periods]
|
||||||
lpc = self.features[self.indices[index*self.batch_size:(index+1)*self.batch_size], 2:-2, -16:]
|
if self.lookahead > 0:
|
||||||
|
lpc = self.features[self.indices[index*self.batch_size:(index+1)*self.batch_size], 4-self.lookahead:-self.lookahead, -16:]
|
||||||
|
else:
|
||||||
|
lpc = self.features[self.indices[index*self.batch_size:(index+1)*self.batch_size], 4:, -16:]
|
||||||
if self.e2e:
|
if self.e2e:
|
||||||
outputs.append(lpc2rc(lpc))
|
outputs.append(lpc2rc(lpc))
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -203,7 +203,7 @@ else:
|
||||||
|
|
||||||
model.save_weights('{}_{}_initial.h5'.format(args.output, args.grua_size))
|
model.save_weights('{}_{}_initial.h5'.format(args.output, args.grua_size))
|
||||||
|
|
||||||
loader = LPCNetLoader(data, features, periods, batch_size, e2e=flag_e2e)
|
loader = LPCNetLoader(data, features, periods, batch_size, e2e=flag_e2e, lookahead=args.lookahead)
|
||||||
|
|
||||||
callbacks = [checkpoint, sparsify, grub_sparsify]
|
callbacks = [checkpoint, sparsify, grub_sparsify]
|
||||||
if args.logdir is not None:
|
if args.logdir is not None:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue