From d883670bf792ad02750fe481646bcd8b4cf6ad72 Mon Sep 17 00:00:00 2001 From: "Timothy B. Terriberry" Date: Fri, 4 Apr 2008 10:16:19 -0400 Subject: [PATCH] Rework CWRS code. This eliminates an extra O(nm) lookups on decode, and reduces the rate control from O(nm^2) to O(nm), in addition to eliminating O(m) lookups on both encode and decode. Although the interface is slightly more complex, the internal code is also simpler. --- libcelt/cwrs.c | 438 ++++++++++++++++++++------------------------ libcelt/cwrs.h | 31 ++-- libcelt/rate.c | 8 +- tests/cwrs32-test.c | 27 +-- tests/cwrs64-test.c | 29 +-- 5 files changed, 260 insertions(+), 273 deletions(-) diff --git a/libcelt/cwrs.c b/libcelt/cwrs.c index 12835de5..60880c61 100644 --- a/libcelt/cwrs.c +++ b/libcelt/cwrs.c @@ -1,4 +1,4 @@ -/* (C) 2007 Timothy B. Terriberry +/* (C) 2007-2008 Timothy B. Terriberry (C) 2008 Jean-Marc Valin */ /* Redistribution and use in source and binary forms, with or without @@ -29,8 +29,13 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ -/* Functions for encoding and decoding pulse vectors. For more details, see: - http://people.xiph.org/~tterribe/notes/cwrs.html +/* Functions for encoding and decoding pulse vectors. + These are based on the function + U(n,m) = U(n-1,m) + U(n,m-1) + U(n-1,m-1), + U(n,1) = U(1,m) = 2, + which counts the number of ways of placing m pulses in n dimensions, where + at least one pulse lies in dimension 0. + For more details, see: http://people.xiph.org/~tterribe/notes/cwrs.html */ #ifdef HAVE_CONFIG_H @@ -38,133 +43,130 @@ #endif #include +#include #include "cwrs.h" #include "mathops.h" -/* Knowing ncwrs() for a fixed number of pulses m and for all vector sizes n, - compute ncwrs() for m+1, for all n. Could also be used when m and n are - swapped just by changing nc */ -static inline void next_ncwrs32(celt_uint32_t *nc, int len, int nc0) -{ - int i; - celt_uint32_t mem; - - mem = nc[0]; - nc[0] = nc0; - for (i=1;i0){ t=p>>1; if(t<=_i||_s[k-1])_i+=t; @@ -172,155 +174,85 @@ void cwrsi(int _n,int _m,celt_uint32_t _i,int * restrict _x,int * restrict _s){ while(p<=_i){ _i-=p; j++; - p=pn; - /*pn=ncwrs(_n-j-1,_m-k-1);*/ - pn=nc[_n-j-1]; - p+=pn; + p=_u[_n-j-1]; } t=p>>1; _s[k]=_i>=t; _x[k]=j; if(_s[k])_i-=t; - if (k<_m-2) - prev_ncwrs32(nc, _n-j+1, 0); - else - prev_ncwrs32(nc, _n-j+1, 1); + uprev32(_u,_n-j,2); + } +} + +void cwrsi64(int _n,int _m,celt_uint64_t _i,int *_x,int *_s,celt_uint64_t *_u){ + int j; + int k; + for(k=j=0;k<_m;k++){ + celt_uint64_t p; + celt_uint64_t t; + p=_u[_n-j-1]; + if(k>0){ + t=p>>1; + if(t<=_i||_s[k-1])_i+=t; + } + while(p<=_i){ + _i-=p; + j++; + p=_u[_n-j-1]; + } + t=p>>1; + _s[k]=_i>=t; + _x[k]=j; + if(_s[k])_i-=t; + uprev64(_u,_n-j,2); } - RESTORE_STACK; } /*Returns the index of the given combination of _m elements chosen from a set of size _n with associated sign bits. - _x: The combination with elements sorted in ascending order. - _s: The associated sign bits.*/ -celt_uint32_t icwrs(int _n,int _m,const int *_x,const int *_s, celt_uint32_t *bound){ + _x: The combination with elements sorted in ascending order. + _s: The associated sign bits. + _u: Temporary storage already initialized to column _m of U(n,m). + Its contents will be overwritten.*/ +celt_uint32_t icwrs32(int _n,int _m,const int *_x,const int *_s, + celt_uint32_t *_u){ celt_uint32_t i; - int j; - int k; - VARDECL(celt_uint32_t, nc); - SAVE_STACK; - ALLOC(nc,_n+1, celt_uint32_t); - for (j=0;j<_n+1;j++) - nc[j] = 1; - for (k=0;k<_m;k++) - next_ncwrs32(nc, _n+1, 0); - if (bound) - *bound = nc[_n]; + int j; + int k; i=0; for(k=j=0;k<_m;k++){ - celt_uint32_t pn; celt_uint32_t p; - if (k<_m-1) - prev_ncwrs32(nc, _n-j+1, 0); - else - prev_ncwrs32(nc, _n-j+1, 1); - /*p=ncwrs(_n-j,_m-k-1); - pn=ncwrs(_n-j-1,_m-k-1);*/ - p=nc[_n-j]; - pn=nc[_n-j-1]; - p+=pn; + p=_u[_n-j-1]; if(k>0)p>>=1; while(j<_x[k]){ i+=p; j++; - p=pn; - /*pn=ncwrs(_n-j-1,_m-k-1);*/ - pn=nc[_n-j-1]; - p+=pn; + p=_u[_n-j-1]; } if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p>>1; + uprev32(_u,_n-j,2); } - RESTORE_STACK; return i; } -/*Returns the _i'th combination of _m elements chosen from a set of size _n - with associated sign bits. - _x: Returns the combination with elements sorted in ascending order. - _s: Returns the associated sign bits.*/ -void cwrsi64(int _n,int _m,celt_uint64_t _i,int * restrict _x,int * restrict _s){ - int j; - int k; - VARDECL(celt_uint64_t, nc); - SAVE_STACK; - ALLOC(nc,_n+1, celt_uint64_t); - for (j=0;j<_n+1;j++) - nc[j] = 1; - for (k=0;k<_m-1;k++) - next_ncwrs64(nc, _n+1, 0); - for(k=j=0;k<_m;k++){ - celt_uint64_t pn, p, t; - /*p=ncwrs64(_n-j,_m-k-1); - pn=ncwrs64(_n-j-1,_m-k-1);*/ - p=nc[_n-j]; - pn=nc[_n-j-1]; - p+=pn; - if(k>0){ - t=p>>1; - if(t<=_i||_s[k-1])_i+=t; - } - while(p<=_i){ - _i-=p; - j++; - p=pn; - /*pn=ncwrs64(_n-j-1,_m-k-1);*/ - pn=nc[_n-j-1]; - p+=pn; - } - t=p>>1; - _s[k]=_i>=t; - _x[k]=j; - if(_s[k])_i-=t; - if (k<_m-2) - prev_ncwrs64(nc, _n-j+1, 0); - else - prev_ncwrs64(nc, _n-j+1, 1); - } - RESTORE_STACK; -} - -/*Returns the index of the given combination of _m elements chosen from a set - of size _n with associated sign bits. - _x: The combination with elements sorted in ascending order. - _s: The associated sign bits.*/ -celt_uint64_t icwrs64(int _n,int _m,const int *_x,const int *_s, celt_uint64_t *bound){ +celt_uint64_t icwrs64(int _n,int _m,const int *_x,const int *_s, + celt_uint64_t *_u){ celt_uint64_t i; int j; int k; - VARDECL(celt_uint64_t, nc); - SAVE_STACK; - ALLOC(nc,_n+1, celt_uint64_t); - for (j=0;j<_n+1;j++) - nc[j] = 1; - for (k=0;k<_m;k++) - next_ncwrs64(nc, _n+1, 0); - if (bound) - *bound = nc[_n]; i=0; for(k=j=0;k<_m;k++){ - celt_uint64_t pn; celt_uint64_t p; - if (k<_m-1) - prev_ncwrs64(nc, _n-j+1, 0); - else - prev_ncwrs64(nc, _n-j+1, 1); - /*p=ncwrs64(_n-j,_m-k-1); - pn=ncwrs64(_n-j-1,_m-k-1);*/ - p=nc[_n-j]; - pn=nc[_n-j-1]; - p+=pn; + p=_u[_n-j-1]; if(k>0)p>>=1; while(j<_x[k]){ i+=p; j++; - p=pn; - /*pn=ncwrs64(_n-j-1,_m-k-1);*/ - pn=nc[_n-j-1]; - p+=pn; + p=_u[_n-j-1]; } if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p>>1; + uprev64(_u,_n-j,2); } - RESTORE_STACK; return i; } @@ -363,47 +295,83 @@ void pulse2comb(int _n,int _m,int *_x,int *_s,const int *_y){ } } +static inline void encode_comb32(int _n,int _m,const int *_x,const int *_s, + ec_enc *_enc){ + VARDECL(celt_uint32_t,u); + celt_uint32_t nc; + celt_uint32_t i; + SAVE_STACK; + ALLOC(u,_n,celt_uint32_t); + nc=ncwrs_u32(_n,_m,u); + i=icwrs32(_n,_m,_x,_s,u); + ec_enc_uint(_enc,i,nc); + RESTORE_STACK; +} + +static inline void encode_comb64(int _n,int _m,const int *_x,const int *_s, + ec_enc *_enc){ + VARDECL(celt_uint64_t,u); + celt_uint64_t nc; + celt_uint64_t i; + SAVE_STACK; + ALLOC(u,_n,celt_uint64_t); + nc=ncwrs_u64(_n,_m,u); + i=icwrs64(_n,_m,_x,_s,u); + ec_enc_uint64(_enc,i,nc); + RESTORE_STACK; +} + void encode_pulses(int *_y, int N, int K, ec_enc *enc) { VARDECL(int, comb); VARDECL(int, signs); SAVE_STACK; - + ALLOC(comb, K, int); ALLOC(signs, K, int); - + pulse2comb(N, K, comb, signs, _y); /* Simple heuristic to figure out whether it fits in 32 bits */ if((N+4)*(K+4)<250 || (celt_ilog2(N)+1)*K<31) { - celt_uint32_t bound, id; - id = icwrs(N, K, comb, signs, &bound); - ec_enc_uint(enc,id,bound); + encode_comb32(N, K, comb, signs, enc); } else { - celt_uint64_t bound, id; - id = icwrs64(N, K, comb, signs, &bound); - ec_enc_uint64(enc,id,bound); + encode_comb64(N, K, comb, signs, enc); } RESTORE_STACK; } +static inline void decode_comb32(int _n,int _m,int *_x,int *_s,ec_dec *_dec){ + VARDECL(celt_uint32_t,u); + SAVE_STACK; + ALLOC(u,_n,celt_uint32_t); + cwrsi32(_n,_m,ec_dec_uint(_dec,ncwrs_u32(_n,_m,u)),_x,_s,u); + RESTORE_STACK; +} + +static inline void decode_comb64(int _n,int _m,int *_x,int *_s,ec_dec *_dec){ + VARDECL(celt_uint64_t,u); + SAVE_STACK; + ALLOC(u,_n,celt_uint64_t); + cwrsi64(_n,_m,ec_dec_uint64(_dec,ncwrs_u64(_n,_m,u)),_x,_s,u); + RESTORE_STACK; +} + void decode_pulses(int *_y, int N, int K, ec_dec *dec) { VARDECL(int, comb); VARDECL(int, signs); SAVE_STACK; - + ALLOC(comb, K, int); ALLOC(signs, K, int); /* Simple heuristic to figure out whether it fits in 32 bits */ if((N+4)*(K+4)<250 || (celt_ilog2(N)+1)*K<31) { - cwrsi(N, K, ec_dec_uint(dec, ncwrs(N, K)), comb, signs); - comb2pulse(N, K, _y, comb, signs); + decode_comb32(N, K, comb, signs, dec); } else { - cwrsi64(N, K, ec_dec_uint64(dec, ncwrs64(N, K)), comb, signs); - comb2pulse(N, K, _y, comb, signs); + decode_comb64(N, K, comb, signs, dec); } + comb2pulse(N, K, _y, comb, signs); RESTORE_STACK; } - diff --git a/libcelt/cwrs.h b/libcelt/cwrs.h index f3f8bdd4..25909d89 100644 --- a/libcelt/cwrs.h +++ b/libcelt/cwrs.h @@ -1,5 +1,4 @@ -/* (C) 2007 Timothy Terriberry -*/ +/* (C) 2007-2008 Timothy Terriberry */ /* Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions @@ -37,23 +36,31 @@ #include "entenc.h" #include "entdec.h" -celt_uint32_t ncwrs(int _n,int _m); +/* 32-bit versions */ +celt_uint32_t ncwrs_u32(int _n,int _m,celt_uint32_t *_u); -void cwrsi(int _n,int _m,celt_uint32_t _i,int *_x,int *_s); +void cwrsi32(int _n,int _m,celt_uint32_t _i,int *_x,int *_s, + celt_uint32_t *_u); + +celt_uint32_t icwrs32(int _n,int _m,const int *_x,const int *_s, + celt_uint32_t *_u); + +/* 64-bit versions */ +celt_uint64_t ncwrs_u64(int _n,int _m,celt_uint64_t *_u); + +celt_uint64_t ncwrs_unext64(int _n,celt_uint64_t *_u); + +void cwrsi64(int _n,int _m,celt_uint64_t _i,int *_x,int *_s, + celt_uint64_t *_u); + +celt_uint64_t icwrs64(int _n,int _m,const int *_x,const int *_s, + celt_uint64_t *_u); -celt_uint32_t icwrs(int _n,int _m,const int *_x,const int *_s, celt_uint32_t *bound); void comb2pulse(int _n,int _m,int *_y,const int *_x,const int *_s); void pulse2comb(int _n,int _m,int *_x,int *_s,const int *_y); -/* 64-bit versions */ -celt_uint64_t ncwrs64(int _n,int _m); - -void cwrsi64(int _n,int _m,celt_uint64_t _i,int *_x,int *_s); - -celt_uint64_t icwrs64(int _n,int _m,const int *_x,const int *_s, celt_uint64_t *bound); - void encode_pulses(int *_y, int N, int K, ec_enc *enc); void decode_pulses(int *_y, int N, int K, ec_dec *dec); diff --git a/libcelt/rate.c b/libcelt/rate.c index 2576159e..912361db 100644 --- a/libcelt/rate.c +++ b/libcelt/rate.c @@ -112,6 +112,9 @@ void compute_alloc_cache(CELTMode *m) { bits[i] = bits[i-1]; } else { + VARDECL(celt_uint64_t, u); + SAVE_STACK; + ALLOC(u, N, celt_uint64_t); int j; /* FIXME: We could save memory here */ bits[i] = celt_alloc(MAX_PULSES*sizeof(celt_int16_t)); @@ -126,7 +129,9 @@ void compute_alloc_cache(CELTMode *m) if (pulses < 0) bits[i][j] = 0; else { - bits[i][j] = log2_frac64(ncwrs64(N, pulses),BITRES); + celt_uint64_t nc; + nc=pulses?ncwrs_unext64(N, u):ncwrs_u64(N, 0, u); + bits[i][j] = log2_frac64(nc,BITRES); /* FIXME: Could there be a better test for the max number of pulses that fit in 64 bits? */ if (bits[i][j] > (60<bits = (const celt_int16_t * const *)bits; diff --git a/tests/cwrs32-test.c b/tests/cwrs32-test.c index 4bbfecf1..0e37c696 100644 --- a/tests/cwrs32-test.c +++ b/tests/cwrs32-test.c @@ -13,27 +13,30 @@ int main(int _argc,char **_argv){ for(n=0;n<=NMAX;n++){ int m; for(m=0;m<=MMAX;m++){ + celt_uint32_t uu[NMAX]; celt_uint32_t inc; celt_uint32_t nc; celt_uint32_t i; - nc=ncwrs(n,m); - inc = nc/10000; - if (inc<1) - inc = 1; + nc=ncwrs_u32(n,m,uu); + inc=nc/10000; + if(inc<1)inc=1; for(i=0;i0&&x[k]==x[k-1]?' ':s[k]?'-':'+',x[k]); } printf(" ->");*/ - if(icwrs(n,m,x,s, NULL)!=i){ + memcpy(u,uu,n*sizeof(*u)); + if(icwrs32(n,m,x,s,u)!=i){ fprintf(stderr,"Combination-index mismatch.\n"); return 1; } diff --git a/tests/cwrs64-test.c b/tests/cwrs64-test.c index 5fa1ee2a..cc76374c 100644 --- a/tests/cwrs64-test.c +++ b/tests/cwrs64-test.c @@ -12,29 +12,32 @@ int main(int _argc,char **_argv){ for(n=0;n<=NMAX;n+=3){ int m; for(m=0;m<=MMAX;m++){ + celt_uint64_t uu[NMAX]; celt_uint64_t inc; celt_uint64_t nc; celt_uint64_t i; - nc=ncwrs64(n,m); - /* Testing all cases just wouldn't work! */ - inc = nc/1000; - if (inc<1) - inc = 1; + nc=ncwrs_u64(n,m,uu); + /*Testing all cases just wouldn't work!*/ + inc=nc/1000; + if(inc<1)inc=1; /*printf("%d/%d: %llu",n,m, nc);*/ for(i=0;i0&&x[k]==x[k-1]?' ':s[k]?'-':'+',x[k]); } printf(" ->");*/ - if(icwrs64(n,m,x,s, NULL)!=i){ + memcpy(u,uu,n*sizeof(*u)); + if(icwrs64(n,m,x,s,u)!=i){ fprintf(stderr,"Combination-index mismatch.\n"); return 1; }