diff --git a/libcelt/celt.c b/libcelt/celt.c index 4d585f6a..c02f5e9d 100644 --- a/libcelt/celt.c +++ b/libcelt/celt.c @@ -400,74 +400,6 @@ static void compute_inv_mdcts(const CELTMode *mode, int shortBlocks, celt_sig *X } } -#define FLAG_NONE 0 -#define FLAG_INTRA (1U<<13) -#define FLAG_PITCH (1U<<12) -#define FLAG_SHORT (1U<<11) -#define FLAG_FOLD (1U<<10) -#define FLAG_MASK (FLAG_INTRA|FLAG_PITCH|FLAG_SHORT|FLAG_FOLD) - -static const int flaglist[8] = { - 0 /*00 */ | FLAG_FOLD, - 1 /*01 */ | FLAG_PITCH|FLAG_FOLD, - 8 /*1000*/ | FLAG_NONE, - 9 /*1001*/ | FLAG_SHORT|FLAG_FOLD, - 10 /*1010*/ | FLAG_PITCH, - 11 /*1011*/ | FLAG_INTRA, - 6 /*110 */ | FLAG_INTRA|FLAG_FOLD, - 7 /*111 */ | FLAG_INTRA|FLAG_SHORT|FLAG_FOLD -}; - -static void encode_flags(ec_enc *enc, int intra_ener, int shortBlocks, int has_fold) -{ - int i; - int flags=FLAG_NONE; - int flag_bits; - flags |= intra_ener ? FLAG_INTRA : 0; - flags |= shortBlocks ? FLAG_SHORT : 0; - flags |= has_fold ? FLAG_FOLD : 0; - for (i=0;i<8;i++) - { - if (flags == (flaglist[i]&FLAG_MASK)) - { - flag_bits = flaglist[i]&0xf; - break; - } - } - celt_assert(i<8); - /*printf ("enc %d: %d %d %d %d\n", flag_bits, intra_ener, has_pitch, shortBlocks, has_fold);*/ - if (i<2) - ec_enc_uint(enc, flag_bits, 4); - else if (i<6) - { - ec_enc_uint(enc, flag_bits>>2, 4); - ec_enc_uint(enc, flag_bits&0x3, 4); - } else { - ec_enc_uint(enc, flag_bits>>1, 4); - ec_enc_uint(enc, flag_bits&0x1, 2); - } -} - -static void decode_flags(ec_dec *dec, int *intra_ener, int *shortBlocks, int *has_fold) -{ - int i; - int flag_bits; - flag_bits = ec_dec_uint(dec, 4); - /*printf ("(%d) ", flag_bits);*/ - if (flag_bits==2) - flag_bits = (flag_bits<<2) | ec_dec_uint(dec, 4); - else if (flag_bits==3) - flag_bits = (flag_bits<<1) | ec_dec_uint(dec, 2); - for (i=0;i<8;i++) - if (flag_bits == (flaglist[i]&0xf)) - break; - celt_assert(i<8); - *intra_ener = (flaglist[i]&FLAG_INTRA) != 0; - *shortBlocks = (flaglist[i]&FLAG_SHORT) != 0; - *has_fold = (flaglist[i]&FLAG_FOLD ) != 0; - /*printf ("dec %d: %d %d %d %d\n", flag_bits, *intra_ener, *has_pitch, *shortBlocks, *has_fold);*/ -} - static void deemphasis(celt_sig *in, celt_word16 *pcm, int N, int _C, const celt_word16 *coef, celt_sig *mem) { const int C = CHANNELS(_C); @@ -877,8 +809,12 @@ int celt_encode_with_ec_float(CELTEncoder * restrict st, const celt_sig * pcm, c mdct_shape(st->mode, X, mdct_weight_pos+1, M, N, mdct_weight_shift, effEnd, C, 0, M); } + /* Encode the global flags using a simple probability model + (first symbols in the stream) */ + ec_enc_bit_prob(enc, intra_ener, 8192); + ec_enc_bit_prob(enc, shortBlocks!=0, 8192); + ec_enc_bit_prob(enc, has_fold, 57344); - encode_flags(enc, intra_ener, shortBlocks, has_fold); if (shortBlocks) { if (transient_shift) @@ -1676,7 +1612,11 @@ int celt_decode_with_ec_float(CELTDecoder * restrict st, const unsigned char *da } nbAvailableBytes = len-nbFilledBytes; - decode_flags(dec, &intra_ener, &isTransient, &has_fold); + /* Decode the global flags (first symbols in the stream) */ + intra_ener = ec_dec_bit_prob(dec, 8192); + isTransient = ec_dec_bit_prob(dec, 8192); + has_fold = ec_dec_bit_prob(dec, 57344); + if (isTransient) shortBlocks = M; else