diff --git a/dnn/nnet.c b/dnn/nnet.c index 97ac74f3..179e7063 100644 --- a/dnn/nnet.c +++ b/dnn/nnet.c @@ -394,6 +394,34 @@ void conv2d_float(float *out, const float *weights, int in_channels, int out_cha } } +void conv2d_3x3_float(float *out, const float *weights, int in_channels, int out_channels, const float *in, int height, int hstride) +{ + int i; + int in_stride; + int kheight, ktime; + kheight = ktime = 3; + in_stride = height+kheight-1; + for (i=0;iktime-1)*time_stride], in, time_stride); OPUS_COPY(mem, &in_buf[time_stride], (conv->ktime-1)*time_stride); bias = conv->bias; - conv2d_float(out, conv->float_weights, conv->in_channels, conv->out_channels, conv->ktime, conv->kheight, in_buf, height, hstride); + if (conv->kheight == 3 && conv->ktime == 3) + conv2d_3x3_float(out, conv->float_weights, conv->in_channels, conv->out_channels, in_buf, height, hstride); + else + conv2d_float(out, conv->float_weights, conv->in_channels, conv->out_channels, conv->ktime, conv->kheight, in_buf, height, hstride); if (bias != NULL) { for (i=0;iout_channels;i++) { int j;