Again, same conversion as 3206cec, for NEON

Jean-Marc Valin 2021-07-08 13:20:15 -04:00
parent 7d8b00f11d
commit a1079c2ce3

@@ -291,15 +291,15 @@ static inline void sgemv_accum8x4(float *_out, const qweight *w, int rows, int c
 {
    int i, j;
    signed char x[MAX_INPUTS];
-   int out[MAX_OUTPUTS];
+   const float32x4_t scale = vdupq_n_f32(SCALE);
+   const float32x4_t scale_1 = vdupq_n_f32(SCALE_1);
    (void)col_stride;
-   for (i=0;i<rows;i++) out[i] = (int)floor(.5+SCALE*_out[i]);
    for (i=0;i<cols;i++) x[i] = (int)floor(.5+127*_x[i]);
    for (i=0;i<rows;i+=8)
    {
       int32x4_t acc0, acc1;
-      acc0 = vld1q_s32(&out[i]);
-      acc1 = vld1q_s32(&out[i+4]);
+      acc0 = vcvtnq_s32_f32(vmulq_f32(scale, vld1q_f32(&_out[i])));
+      acc1 = vcvtnq_s32_f32(vmulq_f32(scale, vld1q_f32(&_out[i+4])));
       for (j=0;j<cols;j+=4)
       {
          int8x16_t vw0, vw1, vx;
@@ -310,25 +310,24 @@ static inline void sgemv_accum8x4(float *_out, const qweight *w, int rows, int c
          acc1 = vdotprod(acc1, vw1, vx);
          w += 32;
       }
-      vst1q_s32(&out[i], acc0);
-      vst1q_s32(&out[i+4], acc1);
+      vst1q_f32(&_out[i], vmulq_f32(scale_1, vcvtq_f32_s32(acc0)));
+      vst1q_f32(&_out[i+4], vmulq_f32(scale_1, vcvtq_f32_s32(acc1)));
    }
-   for (i=0;i<rows;i++) _out[i] = SCALE_1*out[i];
 }
 
 static inline void sparse_sgemv_accum8x4(float *_out, const qweight *w, int rows, int cols, const int *idx, const float *_x)
 {
    int i, j;
    signed char x[MAX_INPUTS];
-   int out[MAX_OUTPUTS];
-   for (i=0;i<rows;i++) out[i] = (int)floor(.5+SCALE*_out[i]);
+   const float32x4_t scale = vdupq_n_f32(SCALE);
+   const float32x4_t scale_1 = vdupq_n_f32(SCALE_1);
    for (i=0;i<cols;i++) x[i] = floor(.5+127*_x[i]);
    for (i=0;i<rows;i+=8)
    {
       int colblocks;
       int32x4_t acc0, acc1;
-      acc0 = vld1q_s32(&out[i]);
-      acc1 = vld1q_s32(&out[i+4]);
+      acc0 = vcvtnq_s32_f32(vmulq_f32(scale, vld1q_f32(&_out[i])));
+      acc1 = vcvtnq_s32_f32(vmulq_f32(scale, vld1q_f32(&_out[i+4])));
       colblocks = *idx++;
       for (j=0;j<colblocks;j++)
       {
@@ -342,8 +341,7 @@ static inline void sparse_sgemv_accum8x4(float *_out, const qweight *w, int rows
          acc1 = vdotprod(acc1, vw1, vx);
          w += 32;
       }
-      vst1q_s32(&out[i], acc0);
-      vst1q_s32(&out[i+4], acc1);
+      vst1q_f32(&_out[i], vmulq_f32(scale_1, vcvtq_f32_s32(acc0)));
+      vst1q_f32(&_out[i+4], vmulq_f32(scale_1, vcvtq_f32_s32(acc1)));
    }
-   for (i=0;i<rows;i++) _out[i] = SCALE_1*out[i];
 }
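For reference, a minimal standalone sketch of the scale/round/unscale pattern the new code uses. It compiles on AArch64 (vcvtnq_s32_f32 needs ARMv8); the SCALE/SCALE_1 values and the small test harness are illustrative stand-ins, not taken from the actual header:

/* Sketch of the pattern adopted by this commit: multiply the float
 * accumulators up by SCALE, round into int32 lanes, then convert back and
 * multiply by SCALE_1 when storing.  SCALE/SCALE_1 below are illustrative. */
#include <arm_neon.h>
#include <stdio.h>

#define SCALE   (128.f*127.f)
#define SCALE_1 (1.f/(128.f*127.f))

static void scale_roundtrip4(float *out)
{
   const float32x4_t scale   = vdupq_n_f32(SCALE);
   const float32x4_t scale_1 = vdupq_n_f32(SCALE_1);
   /* vcvtnq_s32_f32 converts with round-to-nearest, standing in for the
    * old scalar (int)floor(.5+SCALE*x) pass. */
   int32x4_t acc = vcvtnq_s32_f32(vmulq_f32(scale, vld1q_f32(out)));
   /* ... integer dot-product accumulation (vdotprod) would run here ... */
   vst1q_f32(out, vmulq_f32(scale_1, vcvtq_f32_s32(acc)));
}

int main(void)
{
   float v[4] = {0.25f, -0.5f, 1.f, -1.f};
   scale_roundtrip4(v);
   printf("%f %f %f %f\n", v[0], v[1], v[2], v[3]);
   return 0;
}

The net effect of the change is that the int out[MAX_OUTPUTS] scratch buffer and the two scalar pre/post-scaling loops over rows disappear: the accumulators are loaded from and stored to _out directly, with the scaling folded into the NEON loads and stores.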