mirror of
https://github.com/xiph/opus.git
synced 2025-05-25 04:39:13 +00:00
WIP: 8-bit SIMD for GRU B
This commit is contained in:
parent
e695355ba5
commit
40b309d92b
5 changed files with 58 additions and 7 deletions
|
@ -39,7 +39,10 @@ max_rnn_neurons = 1
|
|||
max_conv_inputs = 1
|
||||
max_mdense_tmp = 1
|
||||
|
||||
def printVector(f, vector, name, dtype='float'):
|
||||
def printVector(f, vector, name, dtype='float', dotp=False):
|
||||
if dotp:
|
||||
vector = vector.reshape((vector.shape[0]//4, 4, vector.shape[1]//8, 8))
|
||||
vector = vector.transpose((2, 0, 3, 1))
|
||||
v = np.reshape(vector, (-1));
|
||||
#print('static const float ', name, '[', len(v), '] = \n', file=f)
|
||||
f.write('static const {} {}[{}] = {{\n '.format(dtype, name, len(v)))
|
||||
|
@ -127,7 +130,12 @@ def dump_gru_layer(self, f, hf):
|
|||
name = self.name
|
||||
print("printing layer " + name + " of type " + self.__class__.__name__)
|
||||
weights = self.get_weights()
|
||||
f.write('#ifdef DOT_PROD\n')
|
||||
qweight = np.clip((128*weights[0]).astype('int'), -128, 127)
|
||||
printVector(f, qweight, name + '_weights', dotp=True, dtype='qweight')
|
||||
f.write('#else /*DOT_PROD*/\n')
|
||||
printVector(f, weights[0], name + '_weights')
|
||||
f.write('#endif /*DOT_PROD*/\n')
|
||||
printVector(f, weights[1], name + '_recurrent_weights')
|
||||
printVector(f, weights[-1], name + '_bias')
|
||||
if hasattr(self, 'activation'):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue