mirror of
https://github.com/xiph/opus.git
synced 2025-05-19 01:48:30 +00:00
Arbitrary 16x1 sparseness
This commit is contained in:
parent
62f330eca3
commit
f13debcf65
3 changed files with 48 additions and 3 deletions
|
@ -5,6 +5,7 @@ from keras.models import Model
|
|||
from keras.layers import Input, LSTM, CuDNNGRU, Dense, Embedding, Reshape, Concatenate, Lambda, Conv1D, Multiply, Add, Bidirectional, MaxPooling1D, Activation
|
||||
from keras import backend as K
|
||||
from keras.initializers import Initializer
|
||||
from keras.callbacks import Callback
|
||||
from mdense import MDense
|
||||
import numpy as np
|
||||
import h5py
|
||||
|
@ -17,6 +18,49 @@ embed_size = 128
|
|||
pcm_levels = 2**pcm_bits
|
||||
nb_used_features = 38
|
||||
|
||||
class Sparsify(Callback):
|
||||
def __init__(self, t_start, t_end, interval, density):
|
||||
super(Sparsify, self).__init__()
|
||||
self.batch = 0
|
||||
self.t_start = t_start
|
||||
self.t_end = t_end
|
||||
self.interval = interval
|
||||
self.final_density = density
|
||||
|
||||
def on_batch_end(self, batch, logs=None):
|
||||
#print("batch number", self.batch)
|
||||
self.batch += 1
|
||||
if self.batch < self.t_start or ((self.batch-self.t_start) % self.interval != 0 and self.batch < self.t_end):
|
||||
#print("don't constrain");
|
||||
pass
|
||||
else:
|
||||
#print("constrain");
|
||||
layer = self.model.get_layer('cu_dnngru_1')
|
||||
w = layer.get_weights()
|
||||
p = w[1]
|
||||
nb = p.shape[1]//p.shape[0]
|
||||
N = p.shape[0]
|
||||
#print("nb = ", nb, ", N = ", N);
|
||||
#print(p.shape)
|
||||
density = self.final_density
|
||||
if self.batch < self.t_end:
|
||||
r = 1 - (self.batch-self.t_start)/(self.t_end - self.t_start)
|
||||
density = 1 - (1-self.final_density)*(1 - r*r*r)
|
||||
#print ("density = ", density)
|
||||
for k in range(nb):
|
||||
A = p[:, k*N:(k+1)*N]
|
||||
L=np.reshape(A, (N, N//16, 16))
|
||||
S=np.sum(L*L, axis=-1)
|
||||
SS=np.sort(np.reshape(S, (-1,)))
|
||||
thresh = SS[round(N*N//16*(1-density))]
|
||||
mask = (S>=thresh).astype('float32');
|
||||
mask = np.repeat(mask, 16, axis=1)
|
||||
p[:, k*N:(k+1)*N] = p[:, k*N:(k+1)*N]*mask
|
||||
#print(thresh, np.mean(mask))
|
||||
w[1] = p
|
||||
layer.set_weights(w)
|
||||
|
||||
|
||||
class PCMInit(Initializer):
|
||||
def __init__(self, gain=.1, seed=None):
|
||||
self.gain = gain
|
||||
|
|
|
@ -42,7 +42,7 @@ periods = (50*features[:,:,36:37]+100).astype('int16')
|
|||
|
||||
|
||||
|
||||
model.load_weights('wavenet5e3_60.h5')
|
||||
model.load_weights('wavenet5p0_30.h5')
|
||||
|
||||
order = 16
|
||||
|
||||
|
|
|
@ -68,6 +68,7 @@ features = features[:nb_frames*feature_chunk_size*nb_features]
|
|||
|
||||
in_data = np.concatenate([data[0:1], data[:-1]]);
|
||||
noise = np.concatenate([np.zeros((len(data)*1//5)), np.random.randint(-3, 3, len(data)*1//5), np.random.randint(-2, 2, len(data)*1//5), np.random.randint(-1, 1, len(data)*2//5)])
|
||||
#noise = np.round(np.concatenate([np.zeros((len(data)*1//5)), np.random.laplace(0, 1.2, len(data)*1//5), np.random.laplace(0, .77, len(data)*1//5), np.random.laplace(0, .33, len(data)*1//5), np.random.randint(-1, 1, len(data)*1//5)]))
|
||||
in_data = in_data + noise
|
||||
in_data = np.clip(in_data, 0, 255)
|
||||
|
||||
|
@ -118,8 +119,8 @@ periods = (50*features[:,:,36:37]+100).astype('int16')
|
|||
in_data = np.concatenate([in_data, pred], axis=-1)
|
||||
|
||||
# dump models to disk as we go
|
||||
checkpoint = ModelCheckpoint('wavenet5d0_{epoch:02d}.h5')
|
||||
checkpoint = ModelCheckpoint('wavenet5p0_{epoch:02d}.h5')
|
||||
|
||||
#model.load_weights('wavenet4f2_30.h5')
|
||||
model.compile(optimizer=Adam(0.001, amsgrad=True, decay=5e-5), loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])
|
||||
model.fit([in_data, in_exc, features, periods], out_data, batch_size=batch_size, epochs=60, validation_split=0.2, callbacks=[checkpoint])
|
||||
model.fit([in_data, in_exc, features, periods], out_data, batch_size=batch_size, epochs=60, validation_split=0.2, callbacks=[checkpoint, lpcnet.Sparsify(1000, 20000, 200, 0.25)])
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue