diff --git a/libcelt/bands.c b/libcelt/bands.c index a78fe1ec..75ede945 100644 --- a/libcelt/bands.c +++ b/libcelt/bands.c @@ -448,15 +448,17 @@ int folding_decision(const CELTMode *m, celt_norm *X, celt_word16 *average, int } /* Quantisation of the residual */ -void quant_bands(const CELTMode *m, int start, celt_norm * restrict X, const celt_ener *bandE, int *pulses, int shortBlocks, int fold, int resynth, int total_bits, int encode, void *enc_dec, int M) +void quant_bands(const CELTMode *m, int start, celt_norm * restrict X, const celt_ener *bandE, int *pulses, int shortBlocks, int fold, int resynth, int total_bits, int encode, void *enc_dec, int LM) { int i, j, remaining_bits, balance; const celt_int16 * restrict eBands = m->eBands; celt_norm * restrict norm; VARDECL(celt_norm, _norm); int B; + int M; SAVE_STACK; + M = 1<nbEBands+1], celt_norm); norm = _norm; @@ -472,7 +474,7 @@ void quant_bands(const CELTMode *m, int start, celt_norm * restrict X, const cel int curr_balance, curr_bits; N = M*eBands[i+1]-M*eBands[i]; - BPbits = m->bits[FULL_FRAME(m)]; + BPbits = m->bits[LM]; if (encode) tell = ec_enc_tell(enc_dec, BITRES); @@ -522,7 +524,7 @@ void quant_bands(const CELTMode *m, int start, celt_norm * restrict X, const cel #ifndef DISABLE_STEREO -void quant_bands_stereo(const CELTMode *m, int start, celt_norm *_X, const celt_ener *bandE, int *pulses, int shortBlocks, int fold, int resynth, int total_bits, ec_enc *enc, int M) +void quant_bands_stereo(const CELTMode *m, int start, celt_norm *_X, const celt_ener *bandE, int *pulses, int shortBlocks, int fold, int resynth, int total_bits, ec_enc *enc, int LM) { int i, j, remaining_bits, balance; const celt_int16 * restrict eBands = m->eBands; @@ -530,8 +532,10 @@ void quant_bands_stereo(const CELTMode *m, int start, celt_norm *_X, const celt_ VARDECL(celt_norm, _norm); int B; celt_word16 mid, side; + int M; SAVE_STACK; + M = 1<nbEBands+1], celt_norm); norm = _norm; @@ -552,7 +556,7 @@ void quant_bands_stereo(const CELTMode *m, int start, celt_norm *_X, const celt_ X = _X+M*eBands[i]; Y = X+M*eBands[m->nbEBands+1]; - BPbits = m->bits[FULL_FRAME(m)]; + BPbits = m->bits[LM]; N = M*eBands[i+1]-M*eBands[i]; tell = ec_enc_tell(enc, BITRES); @@ -758,7 +762,7 @@ void quant_bands_stereo(const CELTMode *m, int start, celt_norm *_X, const celt_ #ifndef DISABLE_STEREO -void unquant_bands_stereo(const CELTMode *m, int start, celt_norm *_X, const celt_ener *bandE, int *pulses, int shortBlocks, int fold, int total_bits, ec_dec *dec, int M) +void unquant_bands_stereo(const CELTMode *m, int start, celt_norm *_X, const celt_ener *bandE, int *pulses, int shortBlocks, int fold, int total_bits, ec_dec *dec, int LM) { int i, j, remaining_bits, balance; const celt_int16 * restrict eBands = m->eBands; @@ -766,8 +770,10 @@ void unquant_bands_stereo(const CELTMode *m, int start, celt_norm *_X, const cel VARDECL(celt_norm, _norm); int B; celt_word16 mid, side; + int M; SAVE_STACK; + M = 1<nbEBands+1], celt_norm); norm = _norm; @@ -789,7 +795,7 @@ void unquant_bands_stereo(const CELTMode *m, int start, celt_norm *_X, const cel X = _X+M*eBands[i]; Y = X+M*eBands[m->nbEBands+1]; - BPbits = m->bits[FULL_FRAME(m)]; + BPbits = m->bits[LM]; N = M*eBands[i+1]-M*eBands[i]; tell = ec_dec_tell(dec, BITRES); diff --git a/libcelt/celt.c b/libcelt/celt.c index 8a9aad74..e5a0312f 100644 --- a/libcelt/celt.c +++ b/libcelt/celt.c @@ -323,16 +323,16 @@ static int transient_analysis(const celt_word32 * restrict in, int len, int C, /** Apply window and compute the MDCT for all sub-frames and all channels in a frame */ -static void compute_mdcts(const CELTMode *mode, int shortBlocks, celt_sig * restrict in, celt_sig * restrict out, int _C) +static void compute_mdcts(const CELTMode *mode, int shortBlocks, celt_sig * restrict in, celt_sig * restrict out, int _C, int LM) { const int C = CHANNELS(_C); if (C==1 && !shortBlocks) { - const mdct_lookup *lookup = &mode->mdct[FULL_FRAME(mode)]; + const mdct_lookup *lookup = &mode->mdct[LM]; const int overlap = OVERLAP(mode); clt_mdct_forward(lookup, in, out, mode->window, overlap); } else { - const mdct_lookup *lookup = &mode->mdct[FULL_FRAME(mode)]; + const mdct_lookup *lookup = &mode->mdct[LM]; const int overlap = OVERLAP(mode); int N = FRAMESIZE(mode); int B = 1; @@ -367,7 +367,7 @@ static void compute_mdcts(const CELTMode *mode, int shortBlocks, celt_sig * rest /** Compute the IMDCT and apply window for all sub-frames and all channels in a frame */ -static void compute_inv_mdcts(const CELTMode *mode, int shortBlocks, celt_sig *X, int transient_time, int transient_shift, celt_sig * restrict out_mem, int _C) +static void compute_inv_mdcts(const CELTMode *mode, int shortBlocks, celt_sig *X, int transient_time, int transient_shift, celt_sig * restrict out_mem, int _C, int LM) { int c, N4; const int C = CHANNELS(_C); @@ -378,7 +378,7 @@ static void compute_inv_mdcts(const CELTMode *mode, int shortBlocks, celt_sig *X { int j; if (transient_shift==0 && C==1 && !shortBlocks) { - const mdct_lookup *lookup = &mode->mdct[FULL_FRAME(mode)]; + const mdct_lookup *lookup = &mode->mdct[LM]; clt_mdct_backward(lookup, X, out_mem+C*(MAX_PERIOD-N-N4), mode->window, overlap); } else { VARDECL(celt_word32, x); @@ -387,7 +387,7 @@ static void compute_inv_mdcts(const CELTMode *mode, int shortBlocks, celt_sig *X int N2 = N; int B = 1; int n4offset=0; - const mdct_lookup *lookup = &mode->mdct[FULL_FRAME(mode)]; + const mdct_lookup *lookup = &mode->mdct[LM]; SAVE_STACK; ALLOC(x, 2*N, celt_word32); @@ -544,10 +544,10 @@ static void mdct_shape(const CELTMode *mode, celt_norm *X, int start, #ifdef FIXED_POINT -int celt_encode(CELTEncoder * restrict st, const celt_int16 * pcm, celt_int16 * optional_synthesis, unsigned char *compressed, int nbCompressedBytes) +int celt_encode(CELTEncoder * restrict st, const celt_int16 * pcm, celt_int16 * optional_synthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes) { #else -int celt_encode_float(CELTEncoder * restrict st, const celt_sig * pcm, celt_sig * optional_synthesis, unsigned char *compressed, int nbCompressedBytes) +int celt_encode_float(CELTEncoder * restrict st, const celt_sig * pcm, celt_sig * optional_synthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes) { #endif int i, c, N, NN, N4; @@ -581,7 +581,7 @@ int celt_encode_float(CELTEncoder * restrict st, const celt_sig * pcm, celt_sig int gain_id=0; int norm_rate; int start=0; - const int M=st->mode->nbShortMdcts; + int LM, M; SAVE_STACK; if (check_encoder(st) != CELT_OK) @@ -591,8 +591,14 @@ int celt_encode_float(CELTEncoder * restrict st, const celt_sig * pcm, celt_sig return CELT_INVALID_MODE; if (nbCompressedBytes<0 || pcm==NULL) - return CELT_BAD_ARG; + return CELT_BAD_ARG; + for (LM=0;LM<4;LM++) + if (st->mode->shortMdctSize<=MAX_CONFIG_SIZES) + return CELT_BAD_ARG; + M=1<mode->nbEBands*C, celt_ener); ALLOC(bandLogE,st->mode->nbEBands*C, celt_word16); /* Compute MDCTs */ - compute_mdcts(st->mode, shortBlocks, in, freq, C); + compute_mdcts(st->mode, shortBlocks, in, freq, C, LM); norm_rate = (nbCompressedBytes-5)*8*(celt_uint32)st->mode->Fs/(C*N)>>10; @@ -692,7 +698,7 @@ int celt_encode_float(CELTEncoder * restrict st, const celt_sig * pcm, celt_sig ALLOC(pitch_freq, C*N, celt_sig); /**< Interleaved signal MDCTs */ if (has_pitch) { - compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, pitch_freq, C); + compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, pitch_freq, C, LM); has_pitch = compute_pitch_gain(st->mode, freq, pitch_freq, norm_rate, &gain_id, C, &st->gain_prod); } @@ -873,10 +879,10 @@ int celt_encode_float(CELTEncoder * restrict st, const celt_sig * pcm, celt_sig /* Residual quantisation */ if (C==1) - quant_bands(st->mode, start, X, bandE, pulses, shortBlocks, has_fold, resynth, nbCompressedBytes*8, 1, &enc, M); + quant_bands(st->mode, start, X, bandE, pulses, shortBlocks, has_fold, resynth, nbCompressedBytes*8, 1, &enc, LM); #ifndef DISABLE_STEREO else - quant_bands_stereo(st->mode, start, X, bandE, pulses, shortBlocks, has_fold, resynth, nbCompressedBytes*8, &enc, M); + quant_bands_stereo(st->mode, start, X, bandE, pulses, shortBlocks, has_fold, resynth, nbCompressedBytes*8, &enc, LM); #endif quant_energy_finalise(st->mode, start, bandE, st->oldBandE, error, fine_quant, fine_priority, nbCompressedBytes*8-ec_enc_tell(&enc, 0), &enc, C); @@ -900,7 +906,7 @@ int celt_encode_float(CELTEncoder * restrict st, const celt_sig * pcm, celt_sig if (has_pitch) apply_pitch(st->mode, freq, pitch_freq, gain_id, 0, C); - compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem, C); + compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem, C, LM); /* De-emphasis and put everything back at the right place in the synthesis history */ @@ -918,9 +924,9 @@ int celt_encode_float(CELTEncoder * restrict st, const celt_sig * pcm, celt_sig #ifdef FIXED_POINT #ifndef DISABLE_FLOAT_API -int celt_encode_float(CELTEncoder * restrict st, const float * pcm, float * optional_synthesis, unsigned char *compressed, int nbCompressedBytes) +int celt_encode_float(CELTEncoder * restrict st, const float * pcm, float * optional_synthesis, int frame_size, unsigned char *compressed, int nbCompressedBytes) { - int j, ret, C, N; + int j, ret, C, N, LM, M; VARDECL(celt_int16, in); SAVE_STACK; @@ -933,19 +939,26 @@ int celt_encode_float(CELTEncoder * restrict st, const float * pcm, float * opti if (pcm==NULL) return CELT_BAD_ARG; + for (LM=0;LM<4;LM++) + if (st->mode->shortMdctSize<=MAX_CONFIG_SIZES) + return CELT_BAD_ARG; + M=1<channels); - N = st->mode->nbShortMdcts*st->mode->shortMdctSize; + N = M*st->mode->shortMdctSize; ALLOC(in, C*N, celt_int16); for (j=0;jmode->nbShortMdcts; SAVE_STACK; if (check_encoder(st) != CELT_OK) @@ -969,6 +981,13 @@ int celt_encode(CELTEncoder * restrict st, const celt_int16 * pcm, celt_int16 * if (pcm==NULL) return CELT_BAD_ARG; + for (LM=0;LM<4;LM++) + if (st->mode->shortMdctSize<=MAX_CONFIG_SIZES) + return CELT_BAD_ARG; + M=1<channels); N=M*st->mode->shortMdctSize; ALLOC(in, C*N, celt_sig); @@ -977,11 +996,11 @@ int celt_encode(CELTEncoder * restrict st, const celt_int16 * pcm, celt_int16 * } if (optional_synthesis != NULL) { - ret = celt_encode_float(st,in,in,compressed,nbCompressedBytes); + ret = celt_encode_float(st,in,in,frame_size,compressed,nbCompressedBytes); for (j=0;j= MAX_PERIOD) offset -= pitch_index; - compute_mdcts(st->mode, 0, st->out_mem+offset*C, freq, C); + compute_mdcts(st->mode, 0, st->out_mem+offset*C, freq, C, LM); for (i=0;iout_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->mode->overlap-N)); /* Compute inverse MDCTs */ - compute_inv_mdcts(st->mode, 0, freq, -1, 0, st->out_mem, C); + compute_inv_mdcts(st->mode, 0, freq, -1, 0, st->out_mem, C, LM); #else for (c=0;cmode->nbShortMdcts; + int LM, M; SAVE_STACK; if (check_decoder(st) != CELT_OK) @@ -1471,6 +1490,13 @@ int celt_decode_float(CELTDecoder * restrict st, const unsigned char *data, int if (pcm==NULL) return CELT_BAD_ARG; + for (LM=0;LM<4;LM++) + if (st->mode->shortMdctSize<=MAX_CONFIG_SIZES) + return CELT_BAD_ARG; + M=1<mode->shortMdctSize; N4 = (N-st->overlap)>>1; @@ -1480,7 +1506,7 @@ int celt_decode_float(CELTDecoder * restrict st, const unsigned char *data, int if (data == NULL) { - celt_decode_lost(st, pcm, N); + celt_decode_lost(st, pcm, N, LM); RESTORE_STACK; return 0; } @@ -1546,15 +1572,15 @@ int celt_decode_float(CELTDecoder * restrict st, const unsigned char *data, int if (has_pitch) { /* Pitch MDCT */ - compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, pitch_freq, C); + compute_mdcts(st->mode, 0, st->out_mem+pitch_index*C, pitch_freq, C, LM); } /* Decode fixed codebook and merge with pitch */ if (C==1) - quant_bands(st->mode, start, X, bandE, pulses, isTransient, has_fold, 1, len*8, 0, &dec, M); + quant_bands(st->mode, start, X, bandE, pulses, isTransient, has_fold, 1, len*8, 0, &dec, LM); #ifndef DISABLE_STEREO else - unquant_bands_stereo(st->mode, start, X, bandE, pulses, isTransient, has_fold, len*8, &dec, M); + unquant_bands_stereo(st->mode, start, X, bandE, pulses, isTransient, has_fold, len*8, &dec, LM); #endif unquant_energy_finalise(st->mode, start, bandE, st->oldBandE, fine_quant, fine_priority, len*8-ec_dec_tell(&dec, 0), &dec, C); @@ -1576,7 +1602,7 @@ int celt_decode_float(CELTDecoder * restrict st, const unsigned char *data, int freq[i] = 0; /* Compute inverse MDCTs */ - compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem, C); + compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem, C, LM); deemphasis(st->out_mem, pcm, N, C, preemph, st->preemph_memD); st->loss_count = 0; @@ -1586,9 +1612,9 @@ int celt_decode_float(CELTDecoder * restrict st, const unsigned char *data, int #ifdef FIXED_POINT #ifndef DISABLE_FLOAT_API -int celt_decode_float(CELTDecoder * restrict st, const unsigned char *data, int len, float * restrict pcm) +int celt_decode_float(CELTDecoder * restrict st, const unsigned char *data, int len, float * restrict pcm, int frame_size) { - int j, ret, C, N; + int j, ret, C, N, LM, M; VARDECL(celt_int16, out); SAVE_STACK; @@ -1601,11 +1627,18 @@ int celt_decode_float(CELTDecoder * restrict st, const unsigned char *data, int if (pcm==NULL) return CELT_BAD_ARG; + for (LM=0;LM<4;LM++) + if (st->mode->shortMdctSize<=MAX_CONFIG_SIZES) + return CELT_BAD_ARG; + M=1<channels); - N = st->mode->nbShortMdcts*st->mode->shortMdctSize; + N = M*st->mode->shortMdctSize; ALLOC(out, C*N, celt_int16); - ret=celt_decode(st, data, len, out); + ret=celt_decode(st, data, len, out, frame_size); for (j=0;jmode->shortMdctSize<=MAX_CONFIG_SIZES) + return CELT_BAD_ARG; + M=1<channels); - N = st->mode->nbShortMdcts*st->mode->shortMdctSize; + N = M*st->mode->shortMdctSize; ALLOC(out, C*N, celt_sig); - ret=celt_decode_float(st, data, len, out); + ret=celt_decode_float(st, data, len, out, frame_size); for (j=0;jnbShortMdcts) - i++; - return i; -} - int check_mode(const CELTMode *mode); #endif diff --git a/libcelt/testcelt.c b/libcelt/testcelt.c index a4ada410..ddd851bd 100644 --- a/libcelt/testcelt.c +++ b/libcelt/testcelt.c @@ -134,7 +134,7 @@ int main(int argc, char *argv[]) err = fread(in, sizeof(short), frame_size*channels, fin); if (feof(fin)) break; - len = celt_encode(enc, in, in, data, bytes_per_packet); + len = celt_encode(enc, in, in, frame_size, data, bytes_per_packet); if (len <= 0) { fprintf (stderr, "celt_encode() returned %d\n", len); @@ -166,9 +166,9 @@ int main(int argc, char *argv[]) /* This is to simulate packet loss */ if (argc==9 && rand()%1000