From e4b4613d05065f128e33f7b1704134ed85f6ea23 Mon Sep 17 00:00:00 2001
From: Jean-Marc Valin
Date: Wed, 21 Jul 2021 22:35:02 -0400
Subject: [PATCH] Fix signed-unsigned biases

---
 dnn/nnet.c                      | 5 +++++
 dnn/training_tf2/dump_lpcnet.py | 5 +++--
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/dnn/nnet.c b/dnn/nnet.c
index ea64e3cf..7f4914c4 100644
--- a/dnn/nnet.c
+++ b/dnn/nnet.c
@@ -324,8 +324,13 @@ void compute_gruB(const GRULayer *gru, const float* gru_b_condition, float *stat
       zrh[i] = gru->bias[i] + gru_b_condition[i];
 #endif
    sparse_sgemv_accum8x4(zrh, gru->input_weights, 3*N, M, gru->input_weights_idx, input);
+#ifdef USE_SU_BIAS
+   for (i=0;i<3*N;i++)
+      recur[i] = gru->subias[3*N + i];
+#else
    for (i=0;i<3*N;i++)
       recur[i] = gru->bias[3*N + i];
+#endif
    sgemv_accum8x4(recur, gru->recurrent_weights, 3*N, N, stride, state);
    for (i=0;i<2*N;i++)
       zrh[i] += recur[i];
diff --git a/dnn/training_tf2/dump_lpcnet.py b/dnn/training_tf2/dump_lpcnet.py
index 8eac0db7..26108dbd 100755
--- a/dnn/training_tf2/dump_lpcnet.py
+++ b/dnn/training_tf2/dump_lpcnet.py
@@ -140,8 +140,8 @@ def dump_grub(self, f, hf, gru_a_size):
     qweight = printSparseVector(f, weights[0][:gru_a_size, :], name + '_weights', have_diag=False)
 
     f.write('#ifdef DOT_PROD\n')
-    qweight = np.clip(np.round(128.*weights[1]).astype('int'), -128, 127)
-    printVector(f, qweight, name + '_recurrent_weights', dotp=True, dtype='qweight')
+    qweight2 = np.clip(np.round(128.*weights[1]).astype('int'), -128, 127)
+    printVector(f, qweight2, name + '_recurrent_weights', dotp=True, dtype='qweight')
     f.write('#else /*DOT_PROD*/\n')
     printVector(f, weights[1], name + '_recurrent_weights')
     f.write('#endif /*DOT_PROD*/\n')
@@ -149,6 +149,7 @@ def dump_grub(self, f, hf, gru_a_size):
     printVector(f, weights[-1], name + '_bias')
     subias = weights[-1].copy()
     subias[0,:] = subias[0,:] - np.sum(qweight*(1./128.),axis=0)
+    subias[1,:] = subias[1,:] - np.sum(qweight2*(1./128.),axis=0)
     printVector(f, subias, name + '_subias')
     if hasattr(self, 'activation'):
         activation = self.activation.__name__.upper()
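
Note (not part of the patch itself): the subias arrays exist because the USE_SU_BIAS/DOT_PROD kernels feed the dot product an unsigned copy of the activations, u = x + 128 at int8 scale (x + 1 at float scale), so W.u picks up an extra sum(W) term that must be subtracted from the bias. Before this patch, dump_grub() applied that correction only to row 0 (the input-weight part of the GRU B bias) and reused the name qweight for the quantized recurrent weights; the patch renames the recurrent copy to qweight2, corrects row 1 as well, and has compute_gruB() read the corrected gru->subias for the recurrent term. Below is a minimal numpy sketch of the identity behind the correction; the shapes and toy values are illustrative assumptions, not LPCNet code.

    import numpy as np

    rng = np.random.default_rng(0)
    # Toy quantized weights/activations, same int8 quantization scheme as dump_lpcnet.py
    W = np.clip(np.round(128.*rng.uniform(-1, 1, (8, 4))).astype('int'), -128, 127)
    x = np.clip(np.round(128.*rng.uniform(-1, 1, 8)).astype('int'), -128, 127)
    b = rng.uniform(-1, 1, 4)                      # float biases

    u = x + 128                                    # unsigned activations, as used by the SU kernels
    subias = b - np.sum(W*(1./128.), axis=0)       # corrected bias, same formula as in the patch

    signed = np.dot(x, W)/(128.*128.) + b          # reference: signed weights x signed activations
    su     = np.dot(u, W)/(128.*128.) + subias     # signed weights x unsigned activations
    print(np.allclose(signed, su))                 # True: the +128 offset is cancelled by subias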