mirror of
https://github.com/xiph/opus.git
synced 2025-05-25 04:39:13 +00:00
wip 8x4 sparseness
This commit is contained in:
parent
8e405b44e0
commit
cc28518699
3 changed files with 13 additions and 9 deletions
|
@ -66,15 +66,17 @@ def printSparseVector(f, A, name):
|
|||
A[:,2*N:] = A[:,2*N:] - np.diag(np.diag(A[:,2*N:]))
|
||||
printVector(f, diag, name + '_diag')
|
||||
idx = np.zeros((0,), dtype='int')
|
||||
for i in range(3*N//16):
|
||||
for i in range(3*N//8):
|
||||
pos = idx.shape[0]
|
||||
idx = np.append(idx, -1)
|
||||
nb_nonzero = 0
|
||||
for j in range(N):
|
||||
if np.sum(np.abs(A[j, i*16:(i+1)*16])) > 1e-10:
|
||||
for j in range(N//4):
|
||||
block = A[j*4:(j+1)*4, i*8:(i+1)*8]
|
||||
if np.sum(np.abs(block)) > 1e-10:
|
||||
nb_nonzero = nb_nonzero + 1
|
||||
idx = np.append(idx, j)
|
||||
W = np.concatenate([W, A[j, i*16:(i+1)*16]])
|
||||
vblock = block.transpose((1,0)).reshape((-1,))
|
||||
W = np.concatenate([W, vblock])
|
||||
idx[pos] = nb_nonzero
|
||||
printVector(f, W, name)
|
||||
#idx = np.tile(np.concatenate([np.array([N]), np.arange(N)]), 3*N//16)
|
||||
|
|
|
@ -74,12 +74,14 @@ class Sparsify(Callback):
|
|||
A = p[:, k*N:(k+1)*N]
|
||||
A = A - np.diag(np.diag(A))
|
||||
#A = np.transpose(A, (1, 0))
|
||||
L=np.reshape(A, (N, N//16, 16))
|
||||
L=np.reshape(A, (N//4, 4, N//8, 8))
|
||||
S=np.sum(L*L, axis=-1)
|
||||
S=np.sum(S, axis=1)
|
||||
SS=np.sort(np.reshape(S, (-1,)))
|
||||
thresh = SS[round(N*N//16*(1-density))]
|
||||
thresh = SS[round(N*N//32*(1-density))]
|
||||
mask = (S>=thresh).astype('float32');
|
||||
mask = np.repeat(mask, 16, axis=1)
|
||||
mask = np.repeat(mask, 4, axis=0)
|
||||
mask = np.repeat(mask, 8, axis=1)
|
||||
mask = np.minimum(1, mask + np.diag(np.ones((N,))))
|
||||
#mask = np.transpose(mask, (1, 0))
|
||||
p[:, k*N:(k+1)*N] = p[:, k*N:(k+1)*N]*mask
|
||||
|
|
|
@ -102,7 +102,7 @@ del pred
|
|||
del in_exc
|
||||
|
||||
# dump models to disk as we go
|
||||
checkpoint = ModelCheckpoint('lpcnet32c_384_10_G16_{epoch:02d}.h5')
|
||||
checkpoint = ModelCheckpoint('lpcnet32v_384_10_G16_{epoch:02d}.h5')
|
||||
|
||||
#Set this to True to adapt an existing model (e.g. on new data)
|
||||
adaptation = False
|
||||
|
@ -120,5 +120,5 @@ else:
|
|||
decay = 5e-5
|
||||
|
||||
model.compile(optimizer=Adam(lr, decay=decay, beta_2=0.99), loss='sparse_categorical_crossentropy')
|
||||
model.save_weights('lpcnet32c_384_10_G16_00.h5');
|
||||
model.save_weights('lpcnet32v_384_10_G16_00.h5');
|
||||
model.fit([in_data, features, periods], out_exc, batch_size=batch_size, epochs=nb_epochs, validation_split=0.0, callbacks=[checkpoint, sparsify])
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue