Using 8-bit recurrent weights for GRU B

Jean-Marc Valin 2021-07-21 16:38:35 -04:00
parent 8783ef0088
commit 51ef273e06
4 changed files with 13 additions and 6 deletions


@@ -138,7 +138,14 @@ def dump_grub(self, f, hf, gru_a_size):
print("printing layer " + name + " of type " + self.__class__.__name__)
weights = self.get_weights()
qweight = printSparseVector(f, weights[0][:gru_a_size, :], name + '_weights', have_diag=False)
f.write('#ifdef DOT_PROD\n')
qweight2 = np.clip(np.round(128.*weights[1]).astype('int'), -128, 127)
printVector(f, qweight2, name + '_recurrent_weights', dotp=True, dtype='qweight')
f.write('#else /*DOT_PROD*/\n')
printVector(f, weights[1], name + '_recurrent_weights')
f.write('#endif /*DOT_PROD*/\n')
printVector(f, weights[-1], name + '_bias')
subias = weights[-1].copy()
subias[0,:] = subias[0,:] - np.sum(qweight*(1./128.),axis=0)
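
For context on what the new lines do: the DOT_PROD branch stores GRU B's recurrent weights as signed 8-bit integers at a fixed scale of 128 (the non-DOT_PROD branch keeps the float weights), and the sub-bias computed afterwards adjusts the bias for the constant shift the 8-bit dot-product kernels apply to their activations, as I read it. Below is a minimal sketch of that quantization-plus-sub-bias idea; the helper name and standalone NumPy arrays are illustrative and not part of the repository's dump code.

import numpy as np

def quantize_8bit_with_subias(w, bias):
    # Signed 8-bit quantization at a fixed 1/128 step, mirroring
    # np.clip(np.round(128.*w), -128, 127) in the diff above.
    qweight = np.clip(np.round(128. * w).astype('int'), -128, 127)
    # Assumption: if the runtime kernel offsets activations by +128
    # (i.e. +1.0 once the 1/128 scale is undone), each output picks up
    # an extra sum of its weights; subtracting that column sum from the
    # bias cancels it, as in the subias line of the diff.
    subias = bias - np.sum(qweight * (1. / 128.), axis=0)
    return qweight, subias

In the dumper itself the correction is applied only to row 0 of weights[-1]: Keras GRU layers with reset_after=True store the input-side and recurrent biases as two stacked rows, so only the input-side bias is adjusted here.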