Hard quantization for training
Also, using stateful GRU to randomize initialization
Commit c5a17a0716 (parent 3b8d64d746)
3 changed files with 77 additions and 28 deletions
dnn/training_tf2/dataloader.py (new file, 26 lines)

@@ -0,0 +1,26 @@
+import numpy as np
+from tensorflow.keras.utils import Sequence
+
+class LPCNetLoader(Sequence):
+    def __init__(self, data, features, periods, batch_size):
+        self.batch_size = batch_size
+        self.nb_batches = np.minimum(np.minimum(data.shape[0], features.shape[0]), periods.shape[0])//self.batch_size
+        self.data = data[:self.nb_batches*self.batch_size, :]
+        self.features = features[:self.nb_batches*self.batch_size, :]
+        self.periods = periods[:self.nb_batches*self.batch_size, :]
+        self.on_epoch_end()
+
+    def on_epoch_end(self):
+        self.indices = np.arange(self.nb_batches*self.batch_size)
+        np.random.shuffle(self.indices)
+
+    def __getitem__(self, index):
+        data = self.data[self.indices[index*self.batch_size:(index+1)*self.batch_size], :, :]
+        in_data = data[:, :, :3]
+        out_data = data[:, :, 3:4]
+        features = self.features[self.indices[index*self.batch_size:(index+1)*self.batch_size], :, :]
+        periods = self.periods[self.indices[index*self.batch_size:(index+1)*self.batch_size], :, :]
+        return ([in_data, features, periods], out_data)
+
+    def __len__(self):
+        return self.nb_batches
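Note: a minimal sketch of how this loader is driven (the shapes below are illustrative placeholders, not values from this commit):

import numpy as np
from dataloader import LPCNetLoader

# Hypothetical arrays; axis 0 indexes training sequences.
data = np.zeros((300, 2400, 4), dtype='float32')     # PCM chunks: 3 input channels + 1 target
features = np.zeros((300, 19, 36), dtype='float32')  # conditioning features
periods = np.zeros((300, 19, 1), dtype='int32')      # pitch periods

loader = LPCNetLoader(data, features, periods, batch_size=128)
print(len(loader))            # 2: the ragged tail beyond 2*128 sequences is dropped
inputs, out_data = loader[0]  # ([in_data, features, periods], out_data) for one batch
# on_epoch_end() reshuffles the index permutation, so Keras sees the
# sequences in a new order each epoch without copying the arrays.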
dnn/training_tf2/lpcnet.py

@@ -28,7 +28,7 @@
 import math
 import tensorflow as tf
 from tensorflow.keras.models import Model
-from tensorflow.keras.layers import Input, GRU, Dense, Embedding, Reshape, Concatenate, Lambda, Conv1D, Multiply, Add, Bidirectional, MaxPooling1D, Activation
+from tensorflow.keras.layers import Input, GRU, Dense, Embedding, Reshape, Concatenate, Lambda, Conv1D, Multiply, Add, Bidirectional, MaxPooling1D, Activation, GaussianNoise
 from tensorflow.compat.v1.keras.layers import CuDNNGRU
 from tensorflow.keras import backend as K
 from tensorflow.keras.constraints import Constraint
@@ -70,21 +70,19 @@ def quant_regularizer(x):
     return .01 * tf.reduce_mean(K.sqrt(K.sqrt(1.0001 - tf.math.cos(2*3.1415926535897931*(Q*x-tf.round(Q*x))))))

 class Sparsify(Callback):
-    def __init__(self, t_start, t_end, interval, density):
+    def __init__(self, t_start, t_end, interval, density, quantize=False):
         super(Sparsify, self).__init__()
         self.batch = 0
         self.t_start = t_start
         self.t_end = t_end
         self.interval = interval
         self.final_density = density
+        self.quantize = quantize

     def on_batch_end(self, batch, logs=None):
         #print("batch number", self.batch)
         self.batch += 1
-        if self.batch < self.t_start or ((self.batch-self.t_start) % self.interval != 0 and self.batch < self.t_end):
-            #print("don't constrain");
-            pass
-        else:
-            #print("constrain");
+        if self.quantize or (self.batch > self.t_start and (self.batch-self.t_start) % self.interval == 0) or self.batch >= self.t_end:
+            #print("constrain");
             layer = self.model.get_layer('gru_a')
             w = layer.get_weights()
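Note: the rewritten test inverts the old skip-branch into a positive trigger. A standalone sketch of just the predicate (re-implemented here for illustration, not code from the commit):

def fires(batch, t_start, t_end, interval, quantize=False):
    # quantize=True: run on every batch, so weights are continually re-snapped.
    # Otherwise: run every `interval` batches after t_start, and on every
    # batch once past t_end.
    return quantize or (batch > t_start and (batch - t_start) % interval == 0) or batch >= t_end

assert fires(10100, 10000, 30000, 100)          # on an interval tick
assert not fires(10150, 10000, 30000, 100)      # between ticks
assert fires(10150, 10000, 30000, 100, True)    # quantize: every batch
assert fires(35000, 10000, 30000, 100)          # past t_end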
@@ -96,7 +94,7 @@ class Sparsify(Callback):
             #print ("density = ", density)
             for k in range(nb):
                 density = self.final_density[k]
-                if self.batch < self.t_end:
+                if self.batch < self.t_end and not self.quantize:
                     r = 1 - (self.batch-self.t_start)/(self.t_end - self.t_start)
                     density = 1 - (1-self.final_density[k])*(1 - r*r*r)
                 A = p[:, k*N:(k+1)*N]
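Note: the surrounding (unchanged) lines ramp the density with a cubic ease; with quantize=True the ramp is now skipped, so density stays at its final value. A worked sketch of the ramp (illustrative numbers):

# r falls linearly from 1 to 0 between t_start and t_end, so (1 - r**3)
# eases density from 1.0 down to final_density.
t_start, t_end, final_density = 10000, 30000, 0.1
for batch in (10000, 15000, 20000, 30000):
    r = 1 - (batch - t_start)/(t_end - t_start)
    print(batch, 1 - (1 - final_density)*(1 - r*r*r))   # ~1.0, ~0.48, ~0.21, ~0.1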
@@ -108,7 +106,7 @@ class Sparsify(Callback):
                 S=np.sum(S, axis=1)
                 SS=np.sort(np.reshape(S, (-1,)))
                 thresh = SS[round(N*N//32*(1-density))]
-                mask = (S>=thresh).astype('float32');
+                mask = (S>=thresh).astype('float32')
                 mask = np.repeat(mask, 4, axis=0)
                 mask = np.repeat(mask, 8, axis=1)
                 mask = np.minimum(1, mask + np.diag(np.ones((N,))))
@@ -116,11 +114,21 @@ class Sparsify(Callback):
                 mask = np.transpose(mask, (1, 0))
                 p[:, k*N:(k+1)*N] = p[:, k*N:(k+1)*N]*mask
                 #print(thresh, np.mean(mask))
+            if self.quantize and ((self.batch > self.t_start and (self.batch-self.t_start) % self.interval == 0) or self.batch >= self.t_end):
+                if self.batch < self.t_end:
+                    threshold = .5*(self.batch - self.t_start)/(self.t_end - self.t_start)
+                else:
+                    threshold = .5
+                quant = np.round(p*128.)
+                res = p*128.-quant
+                mask = (np.abs(res) <= threshold).astype('float32')
+                p = mask/128.*quant + (1-mask)*p
+
             w[1] = p
             layer.set_weights(w)

 class SparsifyGRUB(Callback):
-    def __init__(self, t_start, t_end, interval, grua_units, density):
+    def __init__(self, t_start, t_end, interval, grua_units, density, quantize=False):
         super(SparsifyGRUB, self).__init__()
         self.batch = 0
         self.t_start = t_start
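Note: this added block is the hard quantization of the commit title. A toy standalone illustration of the snap (values chosen for the example):

import numpy as np

p = np.array([0.105, 0.1172, 0.0391, 0.25])  # toy weight row
threshold = 0.25                             # ramps from 0 toward .5

quant = np.round(p*128.)                     # nearest point on the 1/128 grid
res = p*128. - quant                         # residual, in grid units
mask = (np.abs(res) <= threshold).astype('float32')
p = mask/128.*quant + (1 - mask)*p

# 0.1172 (~15/128), 0.0391 (~5/128) and 0.25 (32/128) lie near the grid and
# get snapped exactly; 0.105 (13.44/128, residual 0.44) survives until the
# threshold reaches .5, at which point every weight sits on the grid.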
@@ -128,14 +136,12 @@ class SparsifyGRUB(Callback):
         self.interval = interval
         self.final_density = density
         self.grua_units = grua_units
+        self.quantize = quantize

     def on_batch_end(self, batch, logs=None):
         #print("batch number", self.batch)
         self.batch += 1
-        if self.batch < self.t_start or ((self.batch-self.t_start) % self.interval != 0 and self.batch < self.t_end):
-            #print("don't constrain");
-            pass
-        else:
-            #print("constrain");
+        if self.quantize or (self.batch > self.t_start and (self.batch-self.t_start) % self.interval == 0) or self.batch >= self.t_end:
+            #print("constrain");
             layer = self.model.get_layer('gru_b')
             w = layer.get_weights()
@@ -144,7 +150,7 @@ class SparsifyGRUB(Callback):
             M = p.shape[1]//3
             for k in range(3):
                 density = self.final_density[k]
-                if self.batch < self.t_end:
+                if self.batch < self.t_end and not self.quantize:
                     r = 1 - (self.batch-self.t_start)/(self.t_end - self.t_start)
                     density = 1 - (1-self.final_density[k])*(1 - r*r*r)
                 A = p[:, k*M:(k+1)*M]
@@ -158,7 +164,7 @@ class SparsifyGRUB(Callback):
                 S=np.sum(S, axis=1)
                 SS=np.sort(np.reshape(S, (-1,)))
                 thresh = SS[round(M*N2//32*(1-density))]
-                mask = (S>=thresh).astype('float32');
+                mask = (S>=thresh).astype('float32')
                 mask = np.repeat(mask, 4, axis=0)
                 mask = np.repeat(mask, 8, axis=1)
                 A = np.concatenate([A2*mask, A[N2:,:]], axis=0)
@@ -167,6 +173,16 @@ class SparsifyGRUB(Callback):
                 A = np.reshape(A, (N, M))
                 p[:, k*M:(k+1)*M] = A
                 #print(thresh, np.mean(mask))
+            if self.quantize and ((self.batch > self.t_start and (self.batch-self.t_start) % self.interval == 0) or self.batch >= self.t_end):
+                if self.batch < self.t_end:
+                    threshold = .5*(self.batch - self.t_start)/(self.t_end - self.t_start)
+                else:
+                    threshold = .5
+                quant = np.round(p*128.)
+                res = p*128.-quant
+                mask = (np.abs(res) <= threshold).astype('float32')
+                p = mask/128.*quant + (1-mask)*p
+
             w[0] = p
             layer.set_weights(w)
@@ -215,9 +231,9 @@ class WeightClip(Constraint):
 constraint = WeightClip(0.992)

 def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 20, training=False, adaptation=False, quantize=False, flag_e2e = False):
-    pcm = Input(shape=(None, 3))
-    feat = Input(shape=(None, nb_used_features))
-    pitch = Input(shape=(None, 1))
+    pcm = Input(shape=(None, 3), batch_size=128)
+    feat = Input(shape=(None, nb_used_features), batch_size=128)
+    pitch = Input(shape=(None, 1), batch_size=128)
     dec_feat = Input(shape=(None, 128))
     dec_state1 = Input(shape=(rnn_units1,))
     dec_state2 = Input(shape=(rnn_units2,))
@@ -256,19 +272,20 @@ def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 20, train
     quant = quant_regularizer if quantize else None

     if training:
-        rnn = CuDNNGRU(rnn_units1, return_sequences=True, return_state=True, name='gru_a',
+        rnn = CuDNNGRU(rnn_units1, return_sequences=True, return_state=True, name='gru_a', stateful=True,
                        recurrent_constraint = constraint, recurrent_regularizer=quant)
-        rnn2 = CuDNNGRU(rnn_units2, return_sequences=True, return_state=True, name='gru_b',
+        rnn2 = CuDNNGRU(rnn_units2, return_sequences=True, return_state=True, name='gru_b', stateful=True,
                         kernel_constraint=constraint, recurrent_constraint = constraint, kernel_regularizer=quant, recurrent_regularizer=quant)
     else:
-        rnn = GRU(rnn_units1, return_sequences=True, return_state=True, recurrent_activation="sigmoid", reset_after='true', name='gru_a',
+        rnn = GRU(rnn_units1, return_sequences=True, return_state=True, recurrent_activation="sigmoid", reset_after='true', name='gru_a', stateful=True,
                   recurrent_constraint = constraint, recurrent_regularizer=quant)
-        rnn2 = GRU(rnn_units2, return_sequences=True, return_state=True, recurrent_activation="sigmoid", reset_after='true', name='gru_b',
+        rnn2 = GRU(rnn_units2, return_sequences=True, return_state=True, recurrent_activation="sigmoid", reset_after='true', name='gru_b', stateful=True,
                    kernel_constraint=constraint, recurrent_constraint = constraint, kernel_regularizer=quant, recurrent_regularizer=quant)

     rnn_in = Concatenate()([cpcm, rep(cfeat)])
     md = MDense(pcm_levels, activation='sigmoid', name='dual_fc')
     gru_out1, _ = rnn(rnn_in)
+    gru_out1 = GaussianNoise(.005)(gru_out1)
     gru_out2, _ = rnn2(Concatenate()([gru_out1, rep(cfeat)]))
     ulaw_prob = Lambda(tree_to_pdf_train)(md(gru_out2))
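Note: stateful=True makes each batch's GRU start from the final hidden state of the previous batch rather than from zeros, which is the "randomize initialization" half of the commit message; stateful layers need a fixed batch size, hence batch_size=128 on the Inputs. The GaussianNoise(.005) on gru_out1 is active only in training, presumably to make gru_b tolerant of small perturbations in gru_a's output. A toy standalone sketch of the stateful behavior (not the LPCNet model):

import numpy as np
import tensorflow as tf

inp = tf.keras.Input(shape=(None, 3), batch_size=2)  # fixed batch size required
out, state = tf.keras.layers.GRU(4, stateful=True, return_sequences=True,
                                 return_state=True)(inp)
model = tf.keras.Model(inp, [out, state])

x = np.random.randn(2, 5, 3).astype('float32')
_, s1 = model.predict(x, verbose=0)
_, s2 = model.predict(x, verbose=0)
print(np.allclose(s1, s2))   # False: second call resumed from s1, not zeros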
dnn/training_tf2/train_lpcnet.py

@@ -28,6 +28,7 @@
 # Train an LPCNet model

 import argparse
+from dataloader import LPCNetLoader

 parser = argparse.ArgumentParser(description='Train an LPCNet model')
@@ -148,10 +149,10 @@ data = data[:nb_frames*4*pcm_chunk_size]

 data = np.reshape(data, (nb_frames, pcm_chunk_size, 4))
-in_data = data[:,:,:3]
-out_exc = data[:,:,3:4]
+#in_data = data[:,:,:3]
+#out_exc = data[:,:,3:4]

-print("ulaw std = ", np.std(out_exc))
+#print("ulaw std = ", np.std(out_exc))

 sizeof = features.strides[-1]
 features = np.lib.stride_tricks.as_strided(features, shape=(nb_frames, feature_chunk_size+4, nb_features),
@@ -171,6 +172,10 @@ if args.retrain is not None:
 if quantize or retrain:
     #Adapting from an existing model
     model.load_weights(input_model)
-    sparsify = lpcnet.Sparsify(0, 0, 1, density)
-    grub_sparsify = lpcnet.SparsifyGRUB(0, 0, 1, args.grua_size, grub_density)
+    if quantize:
+        sparsify = lpcnet.Sparsify(10000, 30000, 100, density, quantize=True)
+        grub_sparsify = lpcnet.SparsifyGRUB(10000, 30000, 100, args.grua_size, grub_density, quantize=True)
+    else:
+        sparsify = lpcnet.Sparsify(0, 0, 1, density)
+        grub_sparsify = lpcnet.SparsifyGRUB(0, 0, 1, args.grua_size, grub_density)
 else:
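Note: when quantizing, training starts from an already-sparsified checkpoint, the density ramp stays frozen at its final value (the `not self.quantize` guard above), and the callbacks instead apply the hard-quantization snap between batches 10000 and 30000, every 100 batches. Its threshold schedule, sketched:

t_start, t_end = 10000, 30000
for batch in (10000, 20000, 30000, 40000):
    threshold = .5*(batch - t_start)/(t_end - t_start) if batch < t_end else .5
    print(batch, threshold)   # 0.0, 0.25, 0.5, 0.5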
@@ -180,4 +185,5 @@ else:

 model.save_weights('{}_{}_initial.h5'.format(args.output, args.grua_size))
 csv_logger = CSVLogger('training_vals.log')
-model.fit([in_data, features, periods], out_exc, batch_size=batch_size, epochs=nb_epochs, validation_split=0.0, callbacks=[checkpoint, sparsify, grub_sparsify, csv_logger])
+loader = LPCNetLoader(data, features, periods, batch_size)
+model.fit(loader, epochs=nb_epochs, validation_split=0.0, callbacks=[checkpoint, sparsify, grub_sparsify, csv_logger])