mirror of https://github.com/xiph/opus.git
Fix signed-unsigned biases
parent 51ef273e06
commit e4b4613d05
2 changed files with 8 additions and 2 deletions
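For context, the unsigned (USE_SU_BIAS / DOT_PROD) kernels consume activations in an offset representation, so each weight column contributes a constant term that has to be folded out of the bias; the "subias" tables dumped below do this by subtracting the column sums of the quantized weights. A minimal numpy sketch of that identity, assuming a +1 offset in the float domain (the names, shapes and exact offset are illustrative, not taken from the patch):

import numpy as np

rng = np.random.default_rng(0)
n_in, n_out = 6, 4
W = rng.uniform(-1, 1, size=(n_in, n_out))                 # float weights
qW = np.clip(np.round(128.*W).astype('int'), -128, 127)    # int8 weights, quantized as in dump_grub()
Wq = qW*(1./128.)                                          # what the int8 kernel effectively applies
bias = rng.uniform(-1, 1, size=n_out)
x = rng.uniform(-1, 1, size=n_in)                          # signed activations in [-1, 1]
u = x + 1.0                                                # assumed unsigned (offset) representation

subias = bias - np.sum(Wq, axis=0)                         # same correction as in the patch below
assert np.allclose(x @ Wq + bias, u @ Wq + subias)         # the offset is absorbed by subias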
@@ -324,8 +324,13 @@ void compute_gruB(const GRULayer *gru, const float* gru_b_condition, float *stat
       zrh[i] = gru->bias[i] + gru_b_condition[i];
 #endif
    sparse_sgemv_accum8x4(zrh, gru->input_weights, 3*N, M, gru->input_weights_idx, input);
+#ifdef USE_SU_BIAS
+   for (i=0;i<3*N;i++)
+      recur[i] = gru->subias[3*N + i];
+#else
    for (i=0;i<3*N;i++)
       recur[i] = gru->bias[3*N + i];
+#endif
    sgemv_accum8x4(recur, gru->recurrent_weights, 3*N, N, stride, state);
    for (i=0;i<2*N;i++)
       zrh[i] += recur[i];
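The hunk above applies the offset-corrected table to the recurrent path as well: the flat GRU bias array packs the input-side bias at offset 0 and the recurrent-side bias at offset 3*N, and under USE_SU_BIAS the recurrent half must be read from gru->subias, just as the input/condition path above it already does. A small sketch of that selection, with illustrative names and shapes (the flat (2*3*N,) layout is assumed from the (2, 3*N) bias dumped further down):

import numpy as np

def recurrent_bias(bias, subias, N, use_su_bias):
    # bias and subias stand in for the flat (2*3*N,) tables written by dump_grub();
    # the recurrent half starts at offset 3*N, matching recur[i] = ...[3*N + i] above.
    table = subias if use_su_bias else bias
    return table[3*N:6*N]

N = 4
bias = np.arange(6*N, dtype=np.float32)    # toy values
subias = bias - 0.5
assert np.array_equal(recurrent_bias(bias, subias, N, False), bias[3*N:])
assert np.array_equal(recurrent_bias(bias, subias, N, True), subias[3*N:])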
@@ -140,8 +140,8 @@ def dump_grub(self, f, hf, gru_a_size):
     qweight = printSparseVector(f, weights[0][:gru_a_size, :], name + '_weights', have_diag=False)
 
     f.write('#ifdef DOT_PROD\n')
-    qweight = np.clip(np.round(128.*weights[1]).astype('int'), -128, 127)
-    printVector(f, qweight, name + '_recurrent_weights', dotp=True, dtype='qweight')
+    qweight2 = np.clip(np.round(128.*weights[1]).astype('int'), -128, 127)
+    printVector(f, qweight2, name + '_recurrent_weights', dotp=True, dtype='qweight')
     f.write('#else /*DOT_PROD*/\n')
     printVector(f, weights[1], name + '_recurrent_weights')
     f.write('#endif /*DOT_PROD*/\n')
@@ -149,6 +149,7 @@ def dump_grub(self, f, hf, gru_a_size):
     printVector(f, weights[-1], name + '_bias')
     subias = weights[-1].copy()
     subias[0,:] = subias[0,:] - np.sum(qweight*(1./128.),axis=0)
+    subias[1,:] = subias[1,:] - np.sum(qweight2*(1./128.),axis=0)
     printVector(f, subias, name + '_subias')
     if hasattr(self, 'activation'):
         activation = self.activation.__name__.upper()
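The dump-script change fixes the companion bug: qweight, as returned by printSparseVector for the sparse input weights, was being overwritten by the quantized recurrent weights, so the subias[0,:] correction used the wrong matrix and the recurrent row received no correction at all. With the rename to qweight2 and the added subias[1,:] line, each bias row is corrected with the column sums of its own weight matrix. A sketch of the corrected computation with dense stand-in matrices (in the real script qweight comes from printSparseVector, so names and shapes here are illustrative):

import numpy as np

rng = np.random.default_rng(1)
M, N = 16, 8                                           # hypothetical input / GRU sizes
w_in  = rng.uniform(-0.5, 0.5, size=(M, 3*N))          # stands in for the sparse input weights
w_rec = rng.uniform(-0.5, 0.5, size=(N, 3*N))          # stands in for weights[1]
bias  = rng.uniform(-0.1, 0.1, size=(2, 3*N))          # stands in for weights[-1]

qweight  = np.clip(np.round(128.*w_in ).astype('int'), -128, 127)
qweight2 = np.clip(np.round(128.*w_rec).astype('int'), -128, 127)

subias = bias.copy()
subias[0,:] -= np.sum(qweight *(1./128.), axis=0)      # input-side row, corrected with its own weights
subias[1,:] -= np.sum(qweight2*(1./128.), axis=0)      # recurrent-side row (the newly added line)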