Hard quantization for training

Also, using stateful GRU to randomize initialization
Jean-Marc Valin 2021-10-04 02:53:46 -04:00
parent 3b8d64d746
commit c5a17a0716
3 changed files with 77 additions and 28 deletions
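"Hard quantization" means the sparsification callbacks now progressively snap weights to a 1/128 (signed 8-bit) grid during fine-tuning, instead of only nudging them toward it through the existing quant_regularizer; the stateful GRUs, combined with the new shuffling LPCNetLoader, make each training sequence start from the hidden state left over by an unrelated sequence rather than from zeros.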

dataloader.py (new file)

@@ -0,0 +1,26 @@
import numpy as np
from tensorflow.keras.utils import Sequence
class LPCNetLoader(Sequence):
def __init__(self, data, features, periods, batch_size):
self.batch_size = batch_size
self.nb_batches = np.minimum(np.minimum(data.shape[0], features.shape[0]), periods.shape[0])//self.batch_size
self.data = data[:self.nb_batches*self.batch_size, :]
self.features = features[:self.nb_batches*self.batch_size, :]
self.periods = periods[:self.nb_batches*self.batch_size, :]
self.on_epoch_end()
def on_epoch_end(self):
self.indices = np.arange(self.nb_batches*self.batch_size)
np.random.shuffle(self.indices)
def __getitem__(self, index):
data = self.data[self.indices[index*self.batch_size:(index+1)*self.batch_size], :, :]
in_data = data[: , :, :3]
out_data = data[: , :, 3:4]
features = self.features[self.indices[index*self.batch_size:(index+1)*self.batch_size], :, :]
periods = self.periods[self.indices[index*self.batch_size:(index+1)*self.batch_size], :, :]
return ([in_data, features, periods], out_data)
def __len__(self):
return self.nb_batches
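A usage sketch for the loader (the array sizes here are illustrative assumptions, not values from this commit): the first axis of data, features and periods indexes sequences, truncated to a whole number of batches, and on_epoch_end() reshuffles which sequences land in which batch.

import numpy as np
from dataloader import LPCNetLoader

# hypothetical shapes: 1000 sequences, 2400 samples and 15 feature frames each
data = np.zeros((1000, 2400, 4), dtype='uint8')       # 3 input signals + 1 target
features = np.zeros((1000, 15, 36), dtype='float32')  # 36 is an assumed feature width
periods = np.zeros((1000, 15, 1), dtype='int32')

loader = LPCNetLoader(data, features, periods, batch_size=128)
inputs, out_data = loader[0]   # inputs == [in_data, features, periods]
print(len(loader))             # number of whole batches: 1000//128 = 7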

lpcnet.py

@@ -28,7 +28,7 @@
import math
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, GRU, Dense, Embedding, Reshape, Concatenate, Lambda, Conv1D, Multiply, Add, Bidirectional, MaxPooling1D, Activation
from tensorflow.keras.layers import Input, GRU, Dense, Embedding, Reshape, Concatenate, Lambda, Conv1D, Multiply, Add, Bidirectional, MaxPooling1D, Activation, GaussianNoise
from tensorflow.compat.v1.keras.layers import CuDNNGRU
from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint
@@ -70,21 +70,19 @@ def quant_regularizer(x):
return .01 * tf.reduce_mean(K.sqrt(K.sqrt(1.0001 - tf.math.cos(2*3.1415926535897931*(Q*x-tf.round(Q*x))))))
class Sparsify(Callback):
def __init__(self, t_start, t_end, interval, density):
def __init__(self, t_start, t_end, interval, density, quantize=False):
super(Sparsify, self).__init__()
self.batch = 0
self.t_start = t_start
self.t_end = t_end
self.interval = interval
self.final_density = density
self.quantize = quantize
def on_batch_end(self, batch, logs=None):
#print("batch number", self.batch)
self.batch += 1
if self.batch < self.t_start or ((self.batch-self.t_start) % self.interval != 0 and self.batch < self.t_end):
#print("don't constrain");
pass
else:
if self.quantize or (self.batch > self.t_start and (self.batch-self.t_start) % self.interval == 0) or self.batch >= self.t_end:
#print("constrain");
layer = self.model.get_layer('gru_a')
w = layer.get_weights()
@@ -96,7 +94,7 @@ class Sparsify(Callback):
#print ("density = ", density)
for k in range(nb):
density = self.final_density[k]
if self.batch < self.t_end:
if self.batch < self.t_end and not self.quantize:
r = 1 - (self.batch-self.t_start)/(self.t_end - self.t_start)
density = 1 - (1-self.final_density[k])*(1 - r*r*r)
A = p[:, k*N:(k+1)*N]
@@ -108,7 +106,7 @@ class Sparsify(Callback):
S=np.sum(S, axis=1)
SS=np.sort(np.reshape(S, (-1,)))
thresh = SS[round(N*N//32*(1-density))]
mask = (S>=thresh).astype('float32');
mask = (S>=thresh).astype('float32')
mask = np.repeat(mask, 4, axis=0)
mask = np.repeat(mask, 8, axis=1)
mask = np.minimum(1, mask + np.diag(np.ones((N,))))
@@ -116,11 +114,21 @@ class Sparsify(Callback):
mask = np.transpose(mask, (1, 0))
p[:, k*N:(k+1)*N] = p[:, k*N:(k+1)*N]*mask
#print(thresh, np.mean(mask))
if self.quantize and ((self.batch > self.t_start and (self.batch-self.t_start) % self.interval == 0) or self.batch >= self.t_end):
if self.batch < self.t_end:
threshold = .5*(self.batch - self.t_start)/(self.t_end - self.t_start)
else:
threshold = .5
quant = np.round(p*128.)
res = p*128.-quant
mask = (np.abs(res) <= threshold).astype('float32')
p = mask/128.*quant + (1-mask)*p
w[1] = p
layer.set_weights(w)
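The new block above is the hard quantization itself: each weight is compared against the nearest multiple of 1/128 (a signed 8-bit grid), and any weight whose residual falls within the threshold is snapped to the grid while the rest keep training freely; the threshold ramps linearly from 0 at t_start to .5 at t_end, so by t_end every weight has been captured. The same rule as a standalone sketch:

import numpy as np

def hard_quantize(p, batch, t_start, t_end):
    # threshold grows linearly to .5, at which point all residuals qualify
    if batch < t_end:
        threshold = .5*(batch - t_start)/(t_end - t_start)
    else:
        threshold = .5
    quant = np.round(p*128.)             # nearest point on the 1/128 grid
    res = p*128. - quant                 # distance from that point
    mask = (np.abs(res) <= threshold).astype('float32')
    return mask*quant/128. + (1-mask)*p  # snap close weights, leave the rest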
class SparsifyGRUB(Callback):
def __init__(self, t_start, t_end, interval, grua_units, density):
def __init__(self, t_start, t_end, interval, grua_units, density, quantize=False):
super(SparsifyGRUB, self).__init__()
self.batch = 0
self.t_start = t_start
@@ -128,14 +136,12 @@ class SparsifyGRUB(Callback):
self.interval = interval
self.final_density = density
self.grua_units = grua_units
self.quantize = quantize
def on_batch_end(self, batch, logs=None):
#print("batch number", self.batch)
self.batch += 1
if self.batch < self.t_start or ((self.batch-self.t_start) % self.interval != 0 and self.batch < self.t_end):
#print("don't constrain");
pass
else:
if self.quantize or (self.batch > self.t_start and (self.batch-self.t_start) % self.interval == 0) or self.batch >= self.t_end:
#print("constrain");
layer = self.model.get_layer('gru_b')
w = layer.get_weights()
@@ -144,7 +150,7 @@ class SparsifyGRUB(Callback):
M = p.shape[1]//3
for k in range(3):
density = self.final_density[k]
if self.batch < self.t_end:
if self.batch < self.t_end and not self.quantize:
r = 1 - (self.batch-self.t_start)/(self.t_end - self.t_start)
density = 1 - (1-self.final_density[k])*(1 - r*r*r)
A = p[:, k*M:(k+1)*M]
@@ -158,7 +164,7 @@ class SparsifyGRUB(Callback):
S=np.sum(S, axis=1)
SS=np.sort(np.reshape(S, (-1,)))
thresh = SS[round(M*N2//32*(1-density))]
mask = (S>=thresh).astype('float32');
mask = (S>=thresh).astype('float32')
mask = np.repeat(mask, 4, axis=0)
mask = np.repeat(mask, 8, axis=1)
A = np.concatenate([A2*mask, A[N2:,:]], axis=0)
@@ -167,6 +173,16 @@ class SparsifyGRUB(Callback):
A = np.reshape(A, (N, M))
p[:, k*M:(k+1)*M] = A
#print(thresh, np.mean(mask))
if self.quantize and ((self.batch > self.t_start and (self.batch-self.t_start) % self.interval == 0) or self.batch >= self.t_end):
if self.batch < self.t_end:
threshold = .5*(self.batch - self.t_start)/(self.t_end - self.t_start)
else:
threshold = .5
quant = np.round(p*128.)
res = p*128.-quant
mask = (np.abs(res) <= threshold).astype('float32')
p = mask/128.*quant + (1-mask)*p
w[0] = p
layer.set_weights(w)
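This mirrors the Sparsify callback, with two differences visible in the diff: the sparsity mask is computed on the first N2 rows of gru_b's input kernel (the part fed by gru_a's output), and the quantized matrix is written back to w[0] (the kernel) rather than w[1] (the recurrent weights).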
@@ -215,9 +231,9 @@ class WeightClip(Constraint):
constraint = WeightClip(0.992)
def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 20, training=False, adaptation=False, quantize=False, flag_e2e = False):
pcm = Input(shape=(None, 3))
feat = Input(shape=(None, nb_used_features))
pitch = Input(shape=(None, 1))
pcm = Input(shape=(None, 3), batch_size=128)
feat = Input(shape=(None, nb_used_features), batch_size=128)
pitch = Input(shape=(None, 1), batch_size=128)
dec_feat = Input(shape=(None, 128))
dec_state1 = Input(shape=(rnn_units1,))
dec_state2 = Input(shape=(rnn_units2,))
@@ -256,19 +272,20 @@ def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 20, train
quant = quant_regularizer if quantize else None
if training:
rnn = CuDNNGRU(rnn_units1, return_sequences=True, return_state=True, name='gru_a',
rnn = CuDNNGRU(rnn_units1, return_sequences=True, return_state=True, name='gru_a', stateful=True,
recurrent_constraint = constraint, recurrent_regularizer=quant)
rnn2 = CuDNNGRU(rnn_units2, return_sequences=True, return_state=True, name='gru_b',
rnn2 = CuDNNGRU(rnn_units2, return_sequences=True, return_state=True, name='gru_b', stateful=True,
kernel_constraint=constraint, recurrent_constraint = constraint, kernel_regularizer=quant, recurrent_regularizer=quant)
else:
rnn = GRU(rnn_units1, return_sequences=True, return_state=True, recurrent_activation="sigmoid", reset_after='true', name='gru_a',
rnn = GRU(rnn_units1, return_sequences=True, return_state=True, recurrent_activation="sigmoid", reset_after='true', name='gru_a', stateful=True,
recurrent_constraint = constraint, recurrent_regularizer=quant)
rnn2 = GRU(rnn_units2, return_sequences=True, return_state=True, recurrent_activation="sigmoid", reset_after='true', name='gru_b',
rnn2 = GRU(rnn_units2, return_sequences=True, return_state=True, recurrent_activation="sigmoid", reset_after='true', name='gru_b', stateful=True,
kernel_constraint=constraint, recurrent_constraint = constraint, kernel_regularizer=quant, recurrent_regularizer=quant)
rnn_in = Concatenate()([cpcm, rep(cfeat)])
md = MDense(pcm_levels, activation='sigmoid', name='dual_fc')
gru_out1, _ = rnn(rnn_in)
gru_out1 = GaussianNoise(.005)(gru_out1)
gru_out2, _ = rnn2(Concatenate()([gru_out1, rep(cfeat)]))
ulaw_prob = Lambda(tree_to_pdf_train)(md(gru_out2))
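Two training-time changes land here. stateful=True (which is why the Input layers above gained a fixed batch_size=128, required for stateful RNNs) means the GRU states are never reset between batches; because LPCNetLoader shuffles batches every epoch, each sequence now starts from the leftover state of an unrelated sequence instead of zeros, which is the "randomize initialization" from the commit message. The GaussianNoise(.005) layer additionally jitters gru_a's output during training only. A toy demonstration of the state carry-over (shapes are illustrative):

import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import GRU

m = Sequential([GRU(4, stateful=True, batch_input_shape=(2, 5, 3))])
x = np.random.randn(2, 5, 3).astype('float32')
y1 = m.predict(x)
y2 = m.predict(x)   # same input, different output: the state carried over
# m.reset_states() would zero the state again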

train_lpcnet.py

@@ -28,6 +28,7 @@
# Train an LPCNet model
import argparse
from dataloader import LPCNetLoader
parser = argparse.ArgumentParser(description='Train an LPCNet model')
@@ -148,10 +149,10 @@ data = data[:nb_frames*4*pcm_chunk_size]
data = np.reshape(data, (nb_frames, pcm_chunk_size, 4))
in_data = data[:,:,:3]
out_exc = data[:,:,3:4]
#in_data = data[:,:,:3]
#out_exc = data[:,:,3:4]
print("ulaw std = ", np.std(out_exc))
#print("ulaw std = ", np.std(out_exc))
sizeof = features.strides[-1]
features = np.lib.stride_tricks.as_strided(features, shape=(nb_frames, feature_chunk_size+4, nb_features),
@@ -171,6 +172,10 @@ if args.retrain is not None:
if quantize or retrain:
#Adapting from an existing model
model.load_weights(input_model)
if quantize:
sparsify = lpcnet.Sparsify(10000, 30000, 100, density, quantize=True)
grub_sparsify = lpcnet.SparsifyGRUB(10000, 30000, 100, args.grua_size, grub_density, quantize=True)
else:
sparsify = lpcnet.Sparsify(0, 0, 1, density)
grub_sparsify = lpcnet.SparsifyGRUB(0, 0, 1, args.grua_size, grub_density)
else:
@@ -180,4 +185,5 @@ else:
model.save_weights('{}_{}_initial.h5'.format(args.output, args.grua_size))
csv_logger = CSVLogger('training_vals.log')
model.fit([in_data, features, periods], out_exc, batch_size=batch_size, epochs=nb_epochs, validation_split=0.0, callbacks=[checkpoint, sparsify, grub_sparsify, csv_logger])
loader = LPCNetLoader(data, features, periods, batch_size)
model.fit(loader, epochs=nb_epochs, validation_split=0.0, callbacks=[checkpoint, sparsify, grub_sparsify, csv_logger])
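With a Sequence input, shuffling moves out of fit() and into the loader's on_epoch_end(), so batches stay intact (as the fixed batch_size=128 and stateful GRUs require) while their order is randomized. The validation_split of 0.0 is ignored by Keras here; a nonzero split is not supported for Sequence inputs anyway.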