mirror of
https://github.com/xiph/opus.git
synced 2025-06-03 09:07:42 +00:00
Using 8-bit recurrent weights for GRU B
This commit is contained in:
parent
8783ef0088
commit
51ef273e06
4 changed files with 13 additions and 6 deletions
|
@ -138,7 +138,14 @@ def dump_grub(self, f, hf, gru_a_size):
|
|||
print("printing layer " + name + " of type " + self.__class__.__name__)
|
||||
weights = self.get_weights()
|
||||
qweight = printSparseVector(f, weights[0][:gru_a_size, :], name + '_weights', have_diag=False)
|
||||
|
||||
f.write('#ifdef DOT_PROD\n')
|
||||
qweight = np.clip(np.round(128.*weights[1]).astype('int'), -128, 127)
|
||||
printVector(f, qweight, name + '_recurrent_weights', dotp=True, dtype='qweight')
|
||||
f.write('#else /*DOT_PROD*/\n')
|
||||
printVector(f, weights[1], name + '_recurrent_weights')
|
||||
f.write('#endif /*DOT_PROD*/\n')
|
||||
|
||||
printVector(f, weights[-1], name + '_bias')
|
||||
subias = weights[-1].copy()
|
||||
subias[0,:] = subias[0,:] - np.sum(qweight*(1./128.),axis=0)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue