Sparse matrix indexing optimization

The 4* scaling is now stored in the index table to avoid computing it in the loop
Jean-Marc Valin 2021-07-06 17:05:07 -04:00
parent 2681822c18
commit 54abdb6f5d
4 changed files with 10 additions and 10 deletions
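
For context, a minimal sketch of the scalar kernel this commit touches, assuming the index table stores, for each block of 8 output rows, a count of nonzero 4-wide input blocks followed by their positions (names and loop structure mirror the diffs below; this is a sketch, not the shipped code):

    static void sparse_sgemv_sketch(float *out, const float *w,
                                    const int *idx, int rows, const float *x)
    {
       int i, j, k, r;
       for (i=0;i<rows;i+=8)
       {
          int colblocks = *idx++;       /* nonzero 4-wide blocks for these 8 rows */
          for (j=0;j<colblocks;j++)
          {
             /* Old: pos = 4 * (*idx++);  table stored block indices.
                New: pos = *idx++;        table stores pre-scaled offsets. */
             int pos = *idx++;
             for (k=0;k<4;k++)           /* each of 4 inputs feeds 8 outputs */
             {
                float xj = x[pos+k];
                for (r=0;r<8;r++)
                   out[i+r] += w[r]*xj;
                w += 8;
             }
          }
       }
    }

Folding the scaling into the table at export time removes one multiply (or shift) per nonzero block from the innermost hot loop, at no run-time cost.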

View file

@@ -80,7 +80,7 @@ def printSparseVector(f, A, name):
             qblock = AQ[j*4:(j+1)*4, i*8:(i+1)*8]
             if np.sum(np.abs(block)) > 1e-10:
                 nb_nonzero = nb_nonzero + 1
-                idx = np.append(idx, j)
+                idx = np.append(idx, j*4)
                 vblock = qblock.transpose((1,0)).reshape((-1,))
                 W0 = np.concatenate([W0, block.reshape((-1,))])
                 W = np.concatenate([W, vblock])
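
The exporter-side effect on the table contents, as a toy example (values invented for illustration): if the nonzero 4-wide blocks of one row group sit at block indices 0, 2 and 5, the emitted index stream changes from block indices to ready-to-use element offsets:

    /* before: count, then block indices (scaled by 4 at run time) */
    static const int idx_before[] = { 3, 0, 2, 5 };
    /* after: count, then pre-scaled offsets into x (j*4) */
    static const int idx_after[]  = { 3, 0, 8, 20 };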

View file

@@ -250,7 +250,7 @@ static inline void sparse_sgemv_accum8x4(float *out, const qweight *w, int rows,
          int pos;
          float * restrict y;
          int xj0, xj1, xj2, xj3;
-         pos = 4 * (*idx++);
+         pos = (*idx++);
          xj0 = x[pos+0];
          xj1 = x[pos+1];
          xj2 = x[pos+2];
@@ -318,7 +318,7 @@ static inline void sparse_sgemv_accum8x4(float *out, const qweight *w, int rows,
          int pos;
          float * restrict y;
          int xj0, xj1, xj2, xj3;
-         pos = 4 * (*idx++);
+         pos = (*idx++);
          xj0 = x[pos+0];
          xj1 = x[pos+1];
          xj2 = x[pos+2];
@@ -357,7 +357,7 @@ static inline void sparse_sgemv_accum8x4(float *out, const qweight *w, int rows,
          int pos;
          float * restrict y;
          float xj0, xj1, xj2, xj3;
-         pos = 4 * (*idx++);
+         pos = (*idx++);
          xj0 = x[pos+0];
          xj1 = x[pos+1];
          xj2 = x[pos+2];
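
All three variants in this file (two with integer xj0..xj3 and one float path) pick up the same one-line change: because the exporter now writes j*4, pos can be consumed directly and the per-block multiply disappears from each inner loop.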

View file

@@ -508,7 +508,7 @@ static inline void sparse_sgemv_accum8x4(float *_out, const qweight *w, int rows
          __m256i vxj;
          __m256i vw;
          int pos;
-         pos = 4 * (*idx++);
+         pos = (*idx++);
          vxj = _mm256_set1_epi32(*(int*)&x[pos]);
          vw = _mm256_loadu_si256((const __m256i *)w); //_mm256_lddqu_si256?
          tmp = _mm256_maddubs_epi16(vxj, vw); //swap?
@@ -544,19 +544,19 @@ static inline void sparse_sgemv_accum8x4(float *out, const qweight *weights, int
          __m256 vxj;
          __m256 vw;
          id = *idx++;
-         vxj = _mm256_broadcast_ss(&x[4*id]);
+         vxj = _mm256_broadcast_ss(&x[id]);
          vw = _mm256_loadu_ps(&weights[0]);
          vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
-         vxj = _mm256_broadcast_ss(&x[4*id+1]);
+         vxj = _mm256_broadcast_ss(&x[id+1]);
          vw = _mm256_loadu_ps(&weights[8]);
          vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
-         vxj = _mm256_broadcast_ss(&x[4*id+2]);
+         vxj = _mm256_broadcast_ss(&x[id+2]);
          vw = _mm256_loadu_ps(&weights[16]);
          vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
-         vxj = _mm256_broadcast_ss(&x[4*id+3]);
+         vxj = _mm256_broadcast_ss(&x[id+3]);
          vw = _mm256_loadu_ps(&weights[24]);
          vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
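
A self-contained restatement of the float AVX2/FMA block above (hypothetical standalone helper, not part of the commit): given the pre-scaled offset id, each of the four inputs x[id..id+3] is broadcast and multiplied against its own group of 8 weights:

    #include <immintrin.h>

    /* One 8x4 block of the float sparse accumulation: adds the
       contribution of inputs x[id..id+3] to the 8 outputs in vy0.
       id is already an element offset thanks to the pre-scaled table. */
    static inline __m256 sparse_block8x4(__m256 vy0, const float *weights,
                                         const float *x, int id)
    {
       __m256 vxj, vw;
       vxj = _mm256_broadcast_ss(&x[id]);
       vw  = _mm256_loadu_ps(&weights[0]);
       vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
       vxj = _mm256_broadcast_ss(&x[id+1]);
       vw  = _mm256_loadu_ps(&weights[8]);
       vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
       vxj = _mm256_broadcast_ss(&x[id+2]);
       vw  = _mm256_loadu_ps(&weights[16]);
       vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
       vxj = _mm256_broadcast_ss(&x[id+3]);
       vw  = _mm256_loadu_ps(&weights[24]);
       vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
       return vy0;
    }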

View file

@@ -333,7 +333,7 @@ static inline void sparse_sgemv_accum8x4(float *_out, const qweight *w, int rows
       for (j=0;j<colblocks;j++)
       {
          int pos;
-         pos = 4 * (*idx++);
+         pos = (*idx++);
          int8x16_t vw0, vw1, vx;
          vx = (int8x16_t)vld1q_dup_s32((int*)&x[pos]);
          vw0 = vld1q_s8(w);
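
The NEON path benefits the same way: the pre-scaled offset feeds straight into a broadcast load of four packed 8-bit inputs. A sketch of just that load, using vreinterpretq instead of the diff's GCC-style cast and assuming x points at 8-bit data:

    #include <arm_neon.h>

    /* Load the 4 int8 inputs at x[pos..pos+3] as one 32-bit word and
       duplicate it across all lanes, so every 4-byte group of weights
       sees the same inputs.  pos is a ready-to-use element offset. */
    static inline int8x16_t load_x4_dup(const int8_t *x, int pos)
    {
       return vreinterpretq_s8_s32(vld1q_dup_s32((const int32_t *)&x[pos]));
    }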