wip 8x4 sparseness

This commit is contained in:
Jean-Marc Valin 2020-12-19 19:25:59 -05:00
parent 8e405b44e0
commit cc28518699
3 changed files with 13 additions and 9 deletions

View file

@ -74,12 +74,14 @@ class Sparsify(Callback):
A = p[:, k*N:(k+1)*N]
A = A - np.diag(np.diag(A))
#A = np.transpose(A, (1, 0))
L=np.reshape(A, (N, N//16, 16))
L=np.reshape(A, (N//4, 4, N//8, 8))
S=np.sum(L*L, axis=-1)
S=np.sum(S, axis=1)
SS=np.sort(np.reshape(S, (-1,)))
thresh = SS[round(N*N//16*(1-density))]
thresh = SS[round(N*N//32*(1-density))]
mask = (S>=thresh).astype('float32');
mask = np.repeat(mask, 16, axis=1)
mask = np.repeat(mask, 4, axis=0)
mask = np.repeat(mask, 8, axis=1)
mask = np.minimum(1, mask + np.diag(np.ones((N,))))
#mask = np.transpose(mask, (1, 0))
p[:, k*N:(k+1)*N] = p[:, k*N:(k+1)*N]*mask