Hard quantization for training

Also, using a stateful GRU to randomize initialization
Jean-Marc Valin 2021-10-04 02:53:46 -04:00
parent 3b8d64d746
commit c5a17a0716
3 changed files with 77 additions and 28 deletions

training_tf2/dataloader.py (new file)

@@ -0,0 +1,26 @@
+import numpy as np
+from tensorflow.keras.utils import Sequence
+
+class LPCNetLoader(Sequence):
+    def __init__(self, data, features, periods, batch_size):
+        self.batch_size = batch_size
+        self.nb_batches = np.minimum(np.minimum(data.shape[0], features.shape[0]), periods.shape[0])//self.batch_size
+        self.data = data[:self.nb_batches*self.batch_size, :]
+        self.features = features[:self.nb_batches*self.batch_size, :]
+        self.periods = periods[:self.nb_batches*self.batch_size, :]
+        self.on_epoch_end()
+
+    def on_epoch_end(self):
+        self.indices = np.arange(self.nb_batches*self.batch_size)
+        np.random.shuffle(self.indices)
+
+    def __getitem__(self, index):
+        data = self.data[self.indices[index*self.batch_size:(index+1)*self.batch_size], :, :]
+        in_data = data[:, :, :3]
+        out_data = data[:, :, 3:4]
+        features = self.features[self.indices[index*self.batch_size:(index+1)*self.batch_size], :, :]
+        periods = self.periods[self.indices[index*self.batch_size:(index+1)*self.batch_size], :, :]
+        return ([in_data, features, periods], out_data)
+
+    def __len__(self):
+        return self.nb_batches
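The loader implements keras.utils.Sequence: __len__() reports the number of full batches per epoch, __getitem__() slices one batch out of a shuffled index array, and Keras calls on_epoch_end() between epochs to reshuffle the sequence order. A quick smoke test of the batching logic; the array shapes here are illustrative only, not the real training dimensions:

import numpy as np

data = np.zeros((1000, 2400, 4), dtype='float32')     # per-sequence signal: 3 input channels + 1 target
features = np.zeros((1000, 19, 36), dtype='float32')  # per-sequence conditioning features
periods = np.zeros((1000, 19, 1), dtype='int32')      # per-sequence pitch periods

loader = LPCNetLoader(data, features, periods, batch_size=128)
inputs, target = loader[0]
print(len(loader))        # 7 (1000//128 full batches; the remainder is dropped)
print(inputs[0].shape)    # (128, 2400, 3)
print(target.shape)       # (128, 2400, 1)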

training_tf2/lpcnet.py

@@ -28,7 +28,7 @@
 import math
 import tensorflow as tf
 from tensorflow.keras.models import Model
-from tensorflow.keras.layers import Input, GRU, Dense, Embedding, Reshape, Concatenate, Lambda, Conv1D, Multiply, Add, Bidirectional, MaxPooling1D, Activation
+from tensorflow.keras.layers import Input, GRU, Dense, Embedding, Reshape, Concatenate, Lambda, Conv1D, Multiply, Add, Bidirectional, MaxPooling1D, Activation, GaussianNoise
 from tensorflow.compat.v1.keras.layers import CuDNNGRU
 from tensorflow.keras import backend as K
 from tensorflow.keras.constraints import Constraint
@@ -70,21 +70,19 @@ def quant_regularizer(x):
     return .01 * tf.reduce_mean(K.sqrt(K.sqrt(1.0001 - tf.math.cos(2*3.1415926535897931*(Q*x-tf.round(Q*x))))))
 
 class Sparsify(Callback):
-    def __init__(self, t_start, t_end, interval, density):
+    def __init__(self, t_start, t_end, interval, density, quantize=False):
         super(Sparsify, self).__init__()
         self.batch = 0
         self.t_start = t_start
         self.t_end = t_end
         self.interval = interval
         self.final_density = density
+        self.quantize = quantize
 
     def on_batch_end(self, batch, logs=None):
         #print("batch number", self.batch)
         self.batch += 1
-        if self.batch < self.t_start or ((self.batch-self.t_start) % self.interval != 0 and self.batch < self.t_end):
-            #print("don't constrain");
-            pass
-        else:
+        if self.quantize or (self.batch > self.t_start and (self.batch-self.t_start) % self.interval == 0) or self.batch >= self.t_end:
             #print("constrain");
             layer = self.model.get_layer('gru_a')
             w = layer.get_weights()
@@ -96,7 +94,7 @@ class Sparsify(Callback):
             #print ("density = ", density)
             for k in range(nb):
                 density = self.final_density[k]
-                if self.batch < self.t_end:
+                if self.batch < self.t_end and not self.quantize:
                     r = 1 - (self.batch-self.t_start)/(self.t_end - self.t_start)
                     density = 1 - (1-self.final_density[k])*(1 - r*r*r)
                 A = p[:, k*N:(k+1)*N]
@@ -108,7 +106,7 @@ class Sparsify(Callback):
                 S=np.sum(S, axis=1)
                 SS=np.sort(np.reshape(S, (-1,)))
                 thresh = SS[round(N*N//32*(1-density))]
-                mask = (S>=thresh).astype('float32');
+                mask = (S>=thresh).astype('float32')
                 mask = np.repeat(mask, 4, axis=0)
                 mask = np.repeat(mask, 8, axis=1)
                 mask = np.minimum(1, mask + np.diag(np.ones((N,))))
@@ -116,11 +114,21 @@ class Sparsify(Callback):
                 mask = np.transpose(mask, (1, 0))
                 p[:, k*N:(k+1)*N] = p[:, k*N:(k+1)*N]*mask
                 #print(thresh, np.mean(mask))
+            if self.quantize and ((self.batch > self.t_start and (self.batch-self.t_start) % self.interval == 0) or self.batch >= self.t_end):
+                if self.batch < self.t_end:
+                    threshold = .5*(self.batch - self.t_start)/(self.t_end - self.t_start)
+                else:
+                    threshold = .5
+                quant = np.round(p*128.)
+                res = p*128.-quant
+                mask = (np.abs(res) <= threshold).astype('float32')
+                p = mask/128.*quant + (1-mask)*p
+
             w[1] = p
             layer.set_weights(w)
 
 class SparsifyGRUB(Callback):
-    def __init__(self, t_start, t_end, interval, grua_units, density):
+    def __init__(self, t_start, t_end, interval, grua_units, density, quantize=False):
         super(SparsifyGRUB, self).__init__()
         self.batch = 0
         self.t_start = t_start
@@ -128,14 +136,12 @@ class SparsifyGRUB(Callback):
         self.interval = interval
         self.final_density = density
         self.grua_units = grua_units
+        self.quantize = quantize
 
     def on_batch_end(self, batch, logs=None):
         #print("batch number", self.batch)
         self.batch += 1
-        if self.batch < self.t_start or ((self.batch-self.t_start) % self.interval != 0 and self.batch < self.t_end):
-            #print("don't constrain");
-            pass
-        else:
+        if self.quantize or (self.batch > self.t_start and (self.batch-self.t_start) % self.interval == 0) or self.batch >= self.t_end:
             #print("constrain");
             layer = self.model.get_layer('gru_b')
             w = layer.get_weights()
@@ -144,7 +150,7 @@ class SparsifyGRUB(Callback):
             M = p.shape[1]//3
             for k in range(3):
                 density = self.final_density[k]
-                if self.batch < self.t_end:
+                if self.batch < self.t_end and not self.quantize:
                     r = 1 - (self.batch-self.t_start)/(self.t_end - self.t_start)
                     density = 1 - (1-self.final_density[k])*(1 - r*r*r)
                 A = p[:, k*M:(k+1)*M]
@@ -158,7 +164,7 @@ class SparsifyGRUB(Callback):
                 S=np.sum(S, axis=1)
                 SS=np.sort(np.reshape(S, (-1,)))
                 thresh = SS[round(M*N2//32*(1-density))]
-                mask = (S>=thresh).astype('float32');
+                mask = (S>=thresh).astype('float32')
                 mask = np.repeat(mask, 4, axis=0)
                 mask = np.repeat(mask, 8, axis=1)
                 A = np.concatenate([A2*mask, A[N2:,:]], axis=0)
@@ -167,6 +173,16 @@ class SparsifyGRUB(Callback):
                 A = np.reshape(A, (N, M))
                 p[:, k*M:(k+1)*M] = A
                 #print(thresh, np.mean(mask))
+            if self.quantize and ((self.batch > self.t_start and (self.batch-self.t_start) % self.interval == 0) or self.batch >= self.t_end):
+                if self.batch < self.t_end:
+                    threshold = .5*(self.batch - self.t_start)/(self.t_end - self.t_start)
+                else:
+                    threshold = .5
+                quant = np.round(p*128.)
+                res = p*128.-quant
+                mask = (np.abs(res) <= threshold).astype('float32')
+                p = mask/128.*quant + (1-mask)*p
+
             w[0] = p
             layer.set_weights(w)
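Both callbacks apply the same hard-quantization schedule: once training passes t_start (checked every interval batches), any weight whose distance to the nearest multiple of 1/128 falls below a threshold is snapped to that grid point, and the threshold ramps linearly from 0 to .5, so by t_end every weight sits exactly on the grid an 8-bit (1/128-step) fixed-point representation can store. A standalone numpy restatement of that rule; the helper name and example values are mine, not from the commit:

import numpy as np

def hard_quantize(p, batch, t_start=10000, t_end=30000):
    # threshold grows linearly from 0 at t_start to .5 at t_end; at .5 every
    # residual qualifies, so all weights end up on the 1/128 grid
    threshold = min(.5, .5*(batch - t_start)/(t_end - t_start))
    quant = np.round(p*128.)             # nearest grid point, in 1/128 steps
    res = p*128. - quant                 # signed distance to that grid point
    mask = (np.abs(res) <= threshold).astype('float32')
    return mask/128.*quant + (1-mask)*p  # snap close weights, leave the rest

p = np.random.uniform(-1, 1, (4, 4)).astype('float32')
print(hard_quantize(p, batch=20000))     # halfway: residuals up to .25 get snapped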
@@ -215,9 +231,9 @@ class WeightClip(Constraint):
 constraint = WeightClip(0.992)
 
 def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 20, training=False, adaptation=False, quantize=False, flag_e2e = False):
-    pcm = Input(shape=(None, 3))
-    feat = Input(shape=(None, nb_used_features))
-    pitch = Input(shape=(None, 1))
+    pcm = Input(shape=(None, 3), batch_size=128)
+    feat = Input(shape=(None, nb_used_features), batch_size=128)
+    pitch = Input(shape=(None, 1), batch_size=128)
     dec_feat = Input(shape=(None, 128))
     dec_state1 = Input(shape=(rnn_units1,))
     dec_state2 = Input(shape=(rnn_units2,))
@@ -256,19 +272,20 @@ def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 20, train
     quant = quant_regularizer if quantize else None
 
     if training:
-        rnn = CuDNNGRU(rnn_units1, return_sequences=True, return_state=True, name='gru_a',
+        rnn = CuDNNGRU(rnn_units1, return_sequences=True, return_state=True, name='gru_a', stateful=True,
                        recurrent_constraint = constraint, recurrent_regularizer=quant)
-        rnn2 = CuDNNGRU(rnn_units2, return_sequences=True, return_state=True, name='gru_b',
+        rnn2 = CuDNNGRU(rnn_units2, return_sequences=True, return_state=True, name='gru_b', stateful=True,
                         kernel_constraint=constraint, recurrent_constraint = constraint, kernel_regularizer=quant, recurrent_regularizer=quant)
     else:
-        rnn = GRU(rnn_units1, return_sequences=True, return_state=True, recurrent_activation="sigmoid", reset_after='true', name='gru_a',
+        rnn = GRU(rnn_units1, return_sequences=True, return_state=True, recurrent_activation="sigmoid", reset_after='true', name='gru_a', stateful=True,
                   recurrent_constraint = constraint, recurrent_regularizer=quant)
-        rnn2 = GRU(rnn_units2, return_sequences=True, return_state=True, recurrent_activation="sigmoid", reset_after='true', name='gru_b',
+        rnn2 = GRU(rnn_units2, return_sequences=True, return_state=True, recurrent_activation="sigmoid", reset_after='true', name='gru_b', stateful=True,
                    kernel_constraint=constraint, recurrent_constraint = constraint, kernel_regularizer=quant, recurrent_regularizer=quant)
 
     rnn_in = Concatenate()([cpcm, rep(cfeat)])
     md = MDense(pcm_levels, activation='sigmoid', name='dual_fc')
     gru_out1, _ = rnn(rnn_in)
+    gru_out1 = GaussianNoise(.005)(gru_out1)
     gru_out2, _ = rnn2(Concatenate()([gru_out1, rep(cfeat)]))
     ulaw_prob = Lambda(tree_to_pdf_train)(md(gru_out2))
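Two training-only changes appear in this hunk. stateful=True makes each GRU carry its final hidden state from one batch over to the next instead of resetting to zeros; since the loader shuffles sequences every epoch, the inherited state is effectively random, which is what the commit message means by using a stateful GRU to randomize initialization. Stateful RNNs need a fixed batch size, hence batch_size=128 on the Input layers above. The GaussianNoise(.005) layer perturbs gru_a's output slightly and is only active at training time. A toy illustration of the state carry-over (my own example, not from the commit):

import numpy as np
import tensorflow as tf

gru = tf.keras.layers.GRU(4, stateful=True, return_state=True)
x = np.random.randn(2, 5, 3).astype('float32')  # fixed batch of 2 sequences

_, s1 = gru(x)               # first call starts from an all-zero state
_, s2 = gru(x)               # second call starts from s1, so s2 differs
print(np.allclose(s1, s2))   # typically False: the state carried over
gru.reset_states()           # explicit reset back to zeros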

training_tf2/train_lpcnet.py

@@ -28,6 +28,7 @@
 # Train an LPCNet model
 
 import argparse
+from dataloader import LPCNetLoader
 
 parser = argparse.ArgumentParser(description='Train an LPCNet model')
@@ -148,10 +149,10 @@ data = data[:nb_frames*4*pcm_chunk_size]
 data = np.reshape(data, (nb_frames, pcm_chunk_size, 4))
-in_data = data[:,:,:3]
-out_exc = data[:,:,3:4]
+#in_data = data[:,:,:3]
+#out_exc = data[:,:,3:4]
 
-print("ulaw std = ", np.std(out_exc))
+#print("ulaw std = ", np.std(out_exc))
 
 sizeof = features.strides[-1]
 features = np.lib.stride_tricks.as_strided(features, shape=(nb_frames, feature_chunk_size+4, nb_features),
@@ -171,6 +172,10 @@ if args.retrain is not None:
 if quantize or retrain:
     #Adapting from an existing model
     model.load_weights(input_model)
-    sparsify = lpcnet.Sparsify(0, 0, 1, density)
-    grub_sparsify = lpcnet.SparsifyGRUB(0, 0, 1, args.grua_size, grub_density)
+    if quantize:
+        sparsify = lpcnet.Sparsify(10000, 30000, 100, density, quantize=True)
+        grub_sparsify = lpcnet.SparsifyGRUB(10000, 30000, 100, args.grua_size, grub_density, quantize=True)
+    else:
+        sparsify = lpcnet.Sparsify(0, 0, 1, density)
+        grub_sparsify = lpcnet.SparsifyGRUB(0, 0, 1, args.grua_size, grub_density)
 else:
@@ -180,4 +185,5 @@ else:
 model.save_weights('{}_{}_initial.h5'.format(args.output, args.grua_size))
 
 csv_logger = CSVLogger('training_vals.log')
-model.fit([in_data, features, periods], out_exc, batch_size=batch_size, epochs=nb_epochs, validation_split=0.0, callbacks=[checkpoint, sparsify, grub_sparsify, csv_logger])
+loader = LPCNetLoader(data, features, periods, batch_size)
+model.fit(loader, epochs=nb_epochs, validation_split=0.0, callbacks=[checkpoint, sparsify, grub_sparsify, csv_logger])
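If the schedule behaves as intended, a run trained with quantize=True should end with every constrained weight exactly on the 1/128 grid. A hypothetical post-training sanity check (not part of the commit):

import numpy as np

# recurrent weights of gru_a are w[1] in the callback above
w = model.get_layer('gru_a').get_weights()[1]
assert np.allclose(w*128., np.round(w*128.))  # all values on the 1/128 grid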