Controlling per-gate sparsity

This commit is contained in:
Jean-Marc Valin 2018-12-10 16:15:50 -05:00
parent b9e0ea23e0
commit 8c271d60c4
2 changed files with 5 additions and 5 deletions

View file

@ -64,12 +64,12 @@ class Sparsify(Callback):
N = p.shape[0]
#print("nb = ", nb, ", N = ", N);
#print(p.shape)
density = self.final_density
if self.batch < self.t_end:
r = 1 - (self.batch-self.t_start)/(self.t_end - self.t_start)
density = 1 - (1-self.final_density)*(1 - r*r*r)
#print ("density = ", density)
for k in range(nb):
density = self.final_density[k]
if self.batch < self.t_end:
r = 1 - (self.batch-self.t_start)/(self.t_end - self.t_start)
density = 1 - (1-self.final_density[k])*(1 - r*r*r)
A = p[:, k*N:(k+1)*N]
A = A - np.diag(np.diag(A))
A = np.transpose(A, (1, 0))