mirror of
https://github.com/xiph/opus.git
synced 2025-05-21 02:48:29 +00:00
Controlling per-gate sparsity
This commit is contained in:
parent
b9e0ea23e0
commit
8c271d60c4
2 changed files with 5 additions and 5 deletions
|
@ -64,12 +64,12 @@ class Sparsify(Callback):
|
|||
N = p.shape[0]
|
||||
#print("nb = ", nb, ", N = ", N);
|
||||
#print(p.shape)
|
||||
density = self.final_density
|
||||
if self.batch < self.t_end:
|
||||
r = 1 - (self.batch-self.t_start)/(self.t_end - self.t_start)
|
||||
density = 1 - (1-self.final_density)*(1 - r*r*r)
|
||||
#print ("density = ", density)
|
||||
for k in range(nb):
|
||||
density = self.final_density[k]
|
||||
if self.batch < self.t_end:
|
||||
r = 1 - (self.batch-self.t_start)/(self.t_end - self.t_start)
|
||||
density = 1 - (1-self.final_density[k])*(1 - r*r*r)
|
||||
A = p[:, k*N:(k+1)*N]
|
||||
A = A - np.diag(np.diag(A))
|
||||
A = np.transpose(A, (1, 0))
|
||||
|
|
|
@ -147,4 +147,4 @@ checkpoint = ModelCheckpoint('lpcnet15_384_10_G16_{epoch:02d}.h5')
|
|||
|
||||
#model.load_weights('lpcnet9b_384_10_G16_01.h5')
|
||||
model.compile(optimizer=Adam(0.001, amsgrad=True, decay=5e-5), loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])
|
||||
model.fit([in_data, in_exc, features, periods], out_data, batch_size=batch_size, epochs=nb_epochs, validation_split=0.0, callbacks=[checkpoint, lpcnet.Sparsify(2000, 40000, 400, 0.1)])
|
||||
model.fit([in_data, in_exc, features, periods], out_data, batch_size=batch_size, epochs=nb_epochs, validation_split=0.0, callbacks=[checkpoint, lpcnet.Sparsify(2000, 40000, 400, (0.1, 0.1, 0.1))])
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue