mirror of
https://github.com/xiph/opus.git
synced 2025-06-03 00:57:43 +00:00
Hard quantization for training
Also, using stateful GRU to randomize initialization
This commit is contained in:
parent
3b8d64d746
commit
c5a17a0716
3 changed files with 77 additions and 28 deletions
|
@ -28,6 +28,7 @@
|
|||
# Train an LPCNet model
|
||||
|
||||
import argparse
|
||||
from dataloader import LPCNetLoader
|
||||
|
||||
parser = argparse.ArgumentParser(description='Train an LPCNet model')
|
||||
|
||||
|
@ -148,10 +149,10 @@ data = data[:nb_frames*4*pcm_chunk_size]
|
|||
|
||||
|
||||
data = np.reshape(data, (nb_frames, pcm_chunk_size, 4))
|
||||
in_data = data[:,:,:3]
|
||||
out_exc = data[:,:,3:4]
|
||||
#in_data = data[:,:,:3]
|
||||
#out_exc = data[:,:,3:4]
|
||||
|
||||
print("ulaw std = ", np.std(out_exc))
|
||||
#print("ulaw std = ", np.std(out_exc))
|
||||
|
||||
sizeof = features.strides[-1]
|
||||
features = np.lib.stride_tricks.as_strided(features, shape=(nb_frames, feature_chunk_size+4, nb_features),
|
||||
|
@ -171,8 +172,12 @@ if args.retrain is not None:
|
|||
if quantize or retrain:
|
||||
#Adapting from an existing model
|
||||
model.load_weights(input_model)
|
||||
sparsify = lpcnet.Sparsify(0, 0, 1, density)
|
||||
grub_sparsify = lpcnet.SparsifyGRUB(0, 0, 1, args.grua_size, grub_density)
|
||||
if quantize:
|
||||
sparsify = lpcnet.Sparsify(10000, 30000, 100, density, quantize=True)
|
||||
grub_sparsify = lpcnet.SparsifyGRUB(10000, 30000, 100, args.grua_size, grub_density, quantize=True)
|
||||
else:
|
||||
sparsify = lpcnet.Sparsify(0, 0, 1, density)
|
||||
grub_sparsify = lpcnet.SparsifyGRUB(0, 0, 1, args.grua_size, grub_density)
|
||||
else:
|
||||
#Training from scratch
|
||||
sparsify = lpcnet.Sparsify(2000, 40000, 400, density)
|
||||
|
@ -180,4 +185,5 @@ else:
|
|||
|
||||
model.save_weights('{}_{}_initial.h5'.format(args.output, args.grua_size))
|
||||
csv_logger = CSVLogger('training_vals.log')
|
||||
model.fit([in_data, features, periods], out_exc, batch_size=batch_size, epochs=nb_epochs, validation_split=0.0, callbacks=[checkpoint, sparsify, grub_sparsify, csv_logger])
|
||||
loader = LPCNetLoader(data, features, periods, batch_size)
|
||||
model.fit(loader, epochs=nb_epochs, validation_split=0.0, callbacks=[checkpoint, sparsify, grub_sparsify, csv_logger])
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue