class Sparsify(Callback):
    """Keras callback that progressively zeroes low-energy weight blocks.

    After every training batch the callback may prune the second weight
    array (``get_weights()[1]``) of one named layer.  That array, of shape
    ``(N, nb*N)``, is processed as ``nb`` side-by-side ``N x N``
    sub-matrices.  Inside each sub-matrix, runs of ``block_size``
    consecutive entries along the second axis form one block; blocks are
    ranked by their sum of squares and the weakest ones are set to zero.

    NOTE(review): for a Keras GRU/CuDNNGRU ``get_weights()[1]`` is
    presumably the recurrent kernel with one ``N x N`` sub-matrix per
    gate -- confirm against the Keras version in use.

    Schedule (``self.batch`` counts batches seen by this callback):

    * before ``t_start``: nothing is pruned;
    * from ``t_start`` up to ``t_end``: pruning runs every ``interval``
      batches, and the kept fraction falls from 1 to ``density`` along a
      cubic curve;
    * from ``t_end`` on: pruning runs after every batch at ``density``.

    Args:
        t_start: batch count at which pruning begins.
        t_end: batch count at which the final density is reached.
        interval: number of batches between prunings before ``t_end``.
        density: final fraction of blocks kept in each sub-matrix.
        layer_name: name of the layer to prune (default: the name that
            was previously hard-coded).
        block_size: number of consecutive entries per block (default 16,
            the value that was previously hard-coded).
    """

    def __init__(self, t_start, t_end, interval, density,
                 layer_name='cu_dnngru_1', block_size=16):
        super(Sparsify, self).__init__()
        self.batch = 0
        self.t_start = t_start
        self.t_end = t_end
        self.interval = interval
        self.final_density = density
        self.layer_name = layer_name
        self.block_size = block_size

    def _should_constrain(self):
        """Return True when the current batch count calls for pruning."""
        if self.batch < self.t_start:
            return False
        if self.batch >= self.t_end:
            return True
        return (self.batch - self.t_start) % self.interval == 0

    def _target_density(self):
        """Return the fraction of blocks to keep at the current batch."""
        if self.batch >= self.t_end:
            return self.final_density
        # float() keeps this a true division on Python 2 as well; with
        # integer division r would stay at 1 and nothing would be pruned
        # until t_end.
        r = 1 - float(self.batch - self.t_start) / (self.t_end - self.t_start)
        return 1 - (1 - self.final_density) * (1 - r * r * r)

    def _block_mask(self, A, density):
        """Return a 0/1 float32 mask, shaped like ``A``, of blocks to keep."""
        N = A.shape[0]
        bs = self.block_size
        if N % bs != 0:
            raise ValueError(
                "matrix size %d is not a multiple of block size %d" % (N, bs))
        L = np.reshape(A, (N, N // bs, bs))
        S = np.sum(L * L, axis=-1)          # energy of every block
        SS = np.sort(np.reshape(S, (-1,)))
        # Number of blocks to drop.  int() because round() yields a float
        # on Python 2; max() guards against a density above 1.
        n_drop = max(0, int(round(SS.size * (1 - density))))
        # When every block must go there is no element to take the
        # threshold from (indexing SS there raised IndexError); an
        # infinite threshold yields the intended all-zero mask.
        thresh = SS[n_drop] if n_drop < SS.size else np.inf
        mask = (S >= thresh).astype('float32')
        # Expand the per-block decision back to one value per weight.
        return np.repeat(mask, bs, axis=1)

    def on_batch_end(self, batch, logs=None):
        """Count the batch and, when scheduled, prune the target layer."""
        self.batch += 1
        if not self._should_constrain():
            return

        layer = self.model.get_layer(self.layer_name)
        w = layer.get_weights()
        p = w[1]
        N = p.shape[0]
        nb = p.shape[1] // N    # number of N x N sub-matrices side by side
        density = self._target_density()
        for k in range(nb):
            cols = slice(k * N, (k + 1) * N)
            p[:, cols] = p[:, cols] * self._block_mask(p[:, cols], density)
        w[1] = p
        layer.set_weights(w)
callbacks=[checkpoint]) +model.fit([in_data, in_exc, features, periods], out_data, batch_size=batch_size, epochs=60, validation_split=0.2, callbacks=[checkpoint, lpcnet.Sparsify(1000, 20000, 200, 0.25)])