Rework CWRS code.

This eliminates an extra O(nm) lookups on decode, and reduces the rate control
 from O(nm^2) to O(nm), in addition to eliminating O(m) lookups on both encode
 and decode.
Although the interface is slightly more complex, the internal code is also
 simpler.
This commit is contained in:
Timothy B. Terriberry 2008-04-04 10:16:19 -04:00 committed by Jean-Marc Valin
parent ae76e553db
commit d883670bf7
5 changed files with 260 additions and 273 deletions

View file

@ -1,4 +1,4 @@
/* (C) 2007 Timothy B. Terriberry
/* (C) 2007-2008 Timothy B. Terriberry
(C) 2008 Jean-Marc Valin */
/*
Redistribution and use in source and binary forms, with or without
@ -29,8 +29,13 @@
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
/* Functions for encoding and decoding pulse vectors. For more details, see:
http://people.xiph.org/~tterribe/notes/cwrs.html
/* Functions for encoding and decoding pulse vectors.
These are based on the function
U(n,m) = U(n-1,m) + U(n,m-1) + U(n-1,m-1),
U(n,1) = U(1,m) = 2,
which counts the number of ways of placing m pulses in n dimensions, where
at least one pulse lies in dimension 0.
For more details, see: http://people.xiph.org/~tterribe/notes/cwrs.html
*/
#ifdef HAVE_CONFIG_H
@ -38,133 +43,130 @@
#endif
#include <stdlib.h>
#include <string.h>
#include "cwrs.h"
#include "mathops.h"
/* Knowing ncwrs() for a fixed number of pulses m and for all vector sizes n,
compute ncwrs() for m+1, for all n. Could also be used when m and n are
swapped just by changing nc */
static inline void next_ncwrs32(celt_uint32_t *nc, int len, int nc0)
{
int i;
celt_uint32_t mem;
mem = nc[0];
nc[0] = nc0;
for (i=1;i<len;i++)
{
celt_uint32_t tmp = nc[i]+nc[i-1]+mem;
mem = nc[i];
nc[i] = tmp;
}
/*Computes the next row/column of any recurrence that obeys the relation
u[i][j]=u[i-1][j]+u[i][j-1]+u[i-1][j-1].
_ui0 is the base case for the new row/column.*/
static inline void unext32(celt_uint32_t *_ui,int _len,celt_uint32_t _ui0){
celt_uint32_t ui1;
int j;
for(j=1;j<_len;j++){
ui1=_ui[j]+_ui[j-1]+_ui0;
_ui[j-1]=_ui0;
_ui0=ui1;
}
_ui[j-1]=_ui0;
}
/* Knowing ncwrs() for a fixed number of pulses m and for all vector sizes n,
compute ncwrs() for m-1, for all n. Could also be used when m and n are
swapped just by changing nc */
static inline void prev_ncwrs32(celt_uint32_t *nc, int len, int nc0)
{
int i;
celt_uint32_t mem;
mem = nc[0];
nc[0] = nc0;
for (i=1;i<len;i++)
{
celt_uint32_t tmp = nc[i]-nc[i-1]-mem;
mem = nc[i];
nc[i] = tmp;
}
static inline void unext64(celt_uint64_t *_ui,int _len,celt_uint64_t _ui0){
celt_uint64_t ui1;
int j;
for(j=1;j<_len;j++){
ui1=_ui[j]+_ui[j-1]+_ui0;
_ui[j-1]=_ui0;
_ui0=ui1;
}
_ui[j-1]=_ui0;
}
static inline void next_ncwrs64(celt_uint64_t *nc, int len, int nc0)
{
int i;
celt_uint64_t mem;
mem = nc[0];
nc[0] = nc0;
for (i=1;i<len;i++)
{
celt_uint64_t tmp = nc[i]+nc[i-1]+mem;
mem = nc[i];
nc[i] = tmp;
}
/*Computes the previous row/column of any recurrence that obeys the relation
u[i-1][j]=u[i][j]-u[i][j-1]-u[i-1][j-1].
_ui0 is the base case for the new row/column.*/
static inline void uprev32(celt_uint32_t *_ui,int _n,celt_uint32_t _ui0){
celt_uint32_t ui1;
int j;
for(j=1;j<_n;j++){
ui1=_ui[j]-_ui[j-1]-_ui0;
_ui[j-1]=_ui0;
_ui0=ui1;
}
_ui[j-1]=_ui0;
}
static inline void prev_ncwrs64(celt_uint64_t *nc, int len, int nc0)
{
int i;
celt_uint64_t mem;
mem = nc[0];
nc[0] = nc0;
for (i=1;i<len;i++)
{
celt_uint64_t tmp = nc[i]-nc[i-1]-mem;
mem = nc[i];
nc[i] = tmp;
}
static inline void uprev64(celt_uint64_t *_ui,int _n,celt_uint64_t _ui0){
celt_uint64_t ui1;
int j;
for(j=1;j<_n;j++){
ui1=_ui[j]-_ui[j-1]-_ui0;
_ui[j-1]=_ui0;
_ui0=ui1;
}
_ui[j-1]=_ui0;
}
/*Returns the numer of ways of choosing _m elements from a set of size _n with
replacement when a sign bit is needed for each unique element.*/
celt_uint32_t ncwrs(int _n,int _m)
{
int i;
celt_uint32_t ret;
VARDECL(celt_uint32_t, nc);
SAVE_STACK;
ALLOC(nc,_n+1, celt_uint32_t);
for (i=0;i<_n+1;i++)
nc[i] = 1;
for (i=0;i<_m;i++)
next_ncwrs32(nc, _n+1, 0);
ret = nc[_n];
RESTORE_STACK;
return ret;
/*Returns the number of ways of choosing _m elements from a set of size _n with
replacement when a sign bit is needed for each unique element.
On input, _u should be initialized to column (_m-1) of U(n,m).
On exit, _u will be initialized to column _m of U(n,m).*/
celt_uint32_t ncwrs_unext32(int _n,celt_uint32_t *_ui){
celt_uint32_t ret;
celt_uint32_t ui0;
celt_uint32_t ui1;
int j;
ret=ui0=2;
for(j=1;j<_n;j++){
ui1=_ui[j]+_ui[j-1]+ui0;
_ui[j-1]=ui0;
ui0=ui1;
ret+=ui0;
}
_ui[j-1]=ui0;
return ret;
}
/*Returns the numer of ways of choosing _m elements from a set of size _n with
replacement when a sign bit is needed for each unique element.*/
celt_uint64_t ncwrs64(int _n,int _m)
{
int i;
celt_uint64_t ret;
VARDECL(celt_uint64_t, nc);
SAVE_STACK;
ALLOC(nc,_n+1, celt_uint64_t);
for (i=0;i<_n+1;i++)
nc[i] = 1;
for (i=0;i<_m;i++)
next_ncwrs64(nc, _n+1, 0);
ret = nc[_n];
RESTORE_STACK;
return ret;
celt_uint64_t ncwrs_unext64(int _n,celt_uint64_t *_ui){
celt_uint64_t ret;
celt_uint64_t ui0;
celt_uint64_t ui1;
int j;
ret=ui0=2;
for(j=1;j<_n;j++){
ui1=_ui[j]+_ui[j-1]+ui0;
_ui[j-1]=ui0;
ui0=ui1;
ret+=ui0;
}
_ui[j-1]=ui0;
return ret;
}
/*Returns the number of ways of choosing _m elements from a set of size _n with
replacement when a sign bit is needed for each unique element.
On exit, _u will be initialized to column _m of U(n,m).*/
celt_uint32_t ncwrs_u32(int _n,int _m,celt_uint32_t *_u){
int k;
memset(_u,0,_n*sizeof(*_u));
if(_m<=0)return 1;
if(_n<=0)return 0;
for(k=1;k<_m;k++)unext32(_u,_n,2);
return ncwrs_unext32(_n,_u);
}
celt_uint64_t ncwrs_u64(int _n,int _m,celt_uint64_t *_u){
int k;
memset(_u,0,_n*sizeof(*_u));
if(_m<=0)return 1;
if(_n<=0)return 0;
for(k=1;k<_m;k++)unext64(_u,_n,2);
return ncwrs_unext64(_n,_u);
}
/*Returns the _i'th combination of _m elements chosen from a set of size _n
with associated sign bits.
_x: Returns the combination with elements sorted in ascending order.
_s: Returns the associated sign bits.*/
void cwrsi(int _n,int _m,celt_uint32_t _i,int * restrict _x,int * restrict _s){
_x: Returns the combination with elements sorted in ascending order.
_s: Returns the associated sign bits.
_u: Temporary storage already initialized to column _m of U(n,m).
Its contents will be overwritten.*/
void cwrsi32(int _n,int _m,celt_uint32_t _i,int *_x,int *_s,celt_uint32_t *_u){
int j;
int k;
VARDECL(celt_uint32_t, nc);
SAVE_STACK;
ALLOC(nc,_n+1, celt_uint32_t);
for (j=0;j<_n+1;j++)
nc[j] = 1;
for (k=0;k<_m-1;k++)
next_ncwrs32(nc, _n+1, 0);
for(k=j=0;k<_m;k++){
celt_uint32_t pn, p, t;
/*p=ncwrs(_n-j,_m-k-1);
pn=ncwrs(_n-j-1,_m-k-1);*/
p=nc[_n-j];
pn=nc[_n-j-1];
p+=pn;
celt_uint32_t p;
celt_uint32_t t;
p=_u[_n-j-1];
if(k>0){
t=p>>1;
if(t<=_i||_s[k-1])_i+=t;
@ -172,155 +174,85 @@ void cwrsi(int _n,int _m,celt_uint32_t _i,int * restrict _x,int * restrict _s){
while(p<=_i){
_i-=p;
j++;
p=pn;
/*pn=ncwrs(_n-j-1,_m-k-1);*/
pn=nc[_n-j-1];
p+=pn;
p=_u[_n-j-1];
}
t=p>>1;
_s[k]=_i>=t;
_x[k]=j;
if(_s[k])_i-=t;
if (k<_m-2)
prev_ncwrs32(nc, _n-j+1, 0);
else
prev_ncwrs32(nc, _n-j+1, 1);
uprev32(_u,_n-j,2);
}
}
void cwrsi64(int _n,int _m,celt_uint64_t _i,int *_x,int *_s,celt_uint64_t *_u){
int j;
int k;
for(k=j=0;k<_m;k++){
celt_uint64_t p;
celt_uint64_t t;
p=_u[_n-j-1];
if(k>0){
t=p>>1;
if(t<=_i||_s[k-1])_i+=t;
}
while(p<=_i){
_i-=p;
j++;
p=_u[_n-j-1];
}
t=p>>1;
_s[k]=_i>=t;
_x[k]=j;
if(_s[k])_i-=t;
uprev64(_u,_n-j,2);
}
RESTORE_STACK;
}
/*Returns the index of the given combination of _m elements chosen from a set
of size _n with associated sign bits.
_x: The combination with elements sorted in ascending order.
_s: The associated sign bits.*/
celt_uint32_t icwrs(int _n,int _m,const int *_x,const int *_s, celt_uint32_t *bound){
_x: The combination with elements sorted in ascending order.
_s: The associated sign bits.
_u: Temporary storage already initialized to column _m of U(n,m).
Its contents will be overwritten.*/
celt_uint32_t icwrs32(int _n,int _m,const int *_x,const int *_s,
celt_uint32_t *_u){
celt_uint32_t i;
int j;
int k;
VARDECL(celt_uint32_t, nc);
SAVE_STACK;
ALLOC(nc,_n+1, celt_uint32_t);
for (j=0;j<_n+1;j++)
nc[j] = 1;
for (k=0;k<_m;k++)
next_ncwrs32(nc, _n+1, 0);
if (bound)
*bound = nc[_n];
int j;
int k;
i=0;
for(k=j=0;k<_m;k++){
celt_uint32_t pn;
celt_uint32_t p;
if (k<_m-1)
prev_ncwrs32(nc, _n-j+1, 0);
else
prev_ncwrs32(nc, _n-j+1, 1);
/*p=ncwrs(_n-j,_m-k-1);
pn=ncwrs(_n-j-1,_m-k-1);*/
p=nc[_n-j];
pn=nc[_n-j-1];
p+=pn;
p=_u[_n-j-1];
if(k>0)p>>=1;
while(j<_x[k]){
i+=p;
j++;
p=pn;
/*pn=ncwrs(_n-j-1,_m-k-1);*/
pn=nc[_n-j-1];
p+=pn;
p=_u[_n-j-1];
}
if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p>>1;
uprev32(_u,_n-j,2);
}
RESTORE_STACK;
return i;
}
/*Returns the _i'th combination of _m elements chosen from a set of size _n
with associated sign bits.
_x: Returns the combination with elements sorted in ascending order.
_s: Returns the associated sign bits.*/
void cwrsi64(int _n,int _m,celt_uint64_t _i,int * restrict _x,int * restrict _s){
int j;
int k;
VARDECL(celt_uint64_t, nc);
SAVE_STACK;
ALLOC(nc,_n+1, celt_uint64_t);
for (j=0;j<_n+1;j++)
nc[j] = 1;
for (k=0;k<_m-1;k++)
next_ncwrs64(nc, _n+1, 0);
for(k=j=0;k<_m;k++){
celt_uint64_t pn, p, t;
/*p=ncwrs64(_n-j,_m-k-1);
pn=ncwrs64(_n-j-1,_m-k-1);*/
p=nc[_n-j];
pn=nc[_n-j-1];
p+=pn;
if(k>0){
t=p>>1;
if(t<=_i||_s[k-1])_i+=t;
}
while(p<=_i){
_i-=p;
j++;
p=pn;
/*pn=ncwrs64(_n-j-1,_m-k-1);*/
pn=nc[_n-j-1];
p+=pn;
}
t=p>>1;
_s[k]=_i>=t;
_x[k]=j;
if(_s[k])_i-=t;
if (k<_m-2)
prev_ncwrs64(nc, _n-j+1, 0);
else
prev_ncwrs64(nc, _n-j+1, 1);
}
RESTORE_STACK;
}
/*Returns the index of the given combination of _m elements chosen from a set
of size _n with associated sign bits.
_x: The combination with elements sorted in ascending order.
_s: The associated sign bits.*/
celt_uint64_t icwrs64(int _n,int _m,const int *_x,const int *_s, celt_uint64_t *bound){
celt_uint64_t icwrs64(int _n,int _m,const int *_x,const int *_s,
celt_uint64_t *_u){
celt_uint64_t i;
int j;
int k;
VARDECL(celt_uint64_t, nc);
SAVE_STACK;
ALLOC(nc,_n+1, celt_uint64_t);
for (j=0;j<_n+1;j++)
nc[j] = 1;
for (k=0;k<_m;k++)
next_ncwrs64(nc, _n+1, 0);
if (bound)
*bound = nc[_n];
i=0;
for(k=j=0;k<_m;k++){
celt_uint64_t pn;
celt_uint64_t p;
if (k<_m-1)
prev_ncwrs64(nc, _n-j+1, 0);
else
prev_ncwrs64(nc, _n-j+1, 1);
/*p=ncwrs64(_n-j,_m-k-1);
pn=ncwrs64(_n-j-1,_m-k-1);*/
p=nc[_n-j];
pn=nc[_n-j-1];
p+=pn;
p=_u[_n-j-1];
if(k>0)p>>=1;
while(j<_x[k]){
i+=p;
j++;
p=pn;
/*pn=ncwrs64(_n-j-1,_m-k-1);*/
pn=nc[_n-j-1];
p+=pn;
p=_u[_n-j-1];
}
if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p>>1;
uprev64(_u,_n-j,2);
}
RESTORE_STACK;
return i;
}
@ -363,47 +295,83 @@ void pulse2comb(int _n,int _m,int *_x,int *_s,const int *_y){
}
}
static inline void encode_comb32(int _n,int _m,const int *_x,const int *_s,
ec_enc *_enc){
VARDECL(celt_uint32_t,u);
celt_uint32_t nc;
celt_uint32_t i;
SAVE_STACK;
ALLOC(u,_n,celt_uint32_t);
nc=ncwrs_u32(_n,_m,u);
i=icwrs32(_n,_m,_x,_s,u);
ec_enc_uint(_enc,i,nc);
RESTORE_STACK;
}
static inline void encode_comb64(int _n,int _m,const int *_x,const int *_s,
ec_enc *_enc){
VARDECL(celt_uint64_t,u);
celt_uint64_t nc;
celt_uint64_t i;
SAVE_STACK;
ALLOC(u,_n,celt_uint64_t);
nc=ncwrs_u64(_n,_m,u);
i=icwrs64(_n,_m,_x,_s,u);
ec_enc_uint64(_enc,i,nc);
RESTORE_STACK;
}
void encode_pulses(int *_y, int N, int K, ec_enc *enc)
{
VARDECL(int, comb);
VARDECL(int, signs);
SAVE_STACK;
ALLOC(comb, K, int);
ALLOC(signs, K, int);
pulse2comb(N, K, comb, signs, _y);
/* Simple heuristic to figure out whether it fits in 32 bits */
if((N+4)*(K+4)<250 || (celt_ilog2(N)+1)*K<31)
{
celt_uint32_t bound, id;
id = icwrs(N, K, comb, signs, &bound);
ec_enc_uint(enc,id,bound);
encode_comb32(N, K, comb, signs, enc);
} else {
celt_uint64_t bound, id;
id = icwrs64(N, K, comb, signs, &bound);
ec_enc_uint64(enc,id,bound);
encode_comb64(N, K, comb, signs, enc);
}
RESTORE_STACK;
}
static inline void decode_comb32(int _n,int _m,int *_x,int *_s,ec_dec *_dec){
VARDECL(celt_uint32_t,u);
SAVE_STACK;
ALLOC(u,_n,celt_uint32_t);
cwrsi32(_n,_m,ec_dec_uint(_dec,ncwrs_u32(_n,_m,u)),_x,_s,u);
RESTORE_STACK;
}
static inline void decode_comb64(int _n,int _m,int *_x,int *_s,ec_dec *_dec){
VARDECL(celt_uint64_t,u);
SAVE_STACK;
ALLOC(u,_n,celt_uint64_t);
cwrsi64(_n,_m,ec_dec_uint64(_dec,ncwrs_u64(_n,_m,u)),_x,_s,u);
RESTORE_STACK;
}
void decode_pulses(int *_y, int N, int K, ec_dec *dec)
{
VARDECL(int, comb);
VARDECL(int, signs);
SAVE_STACK;
ALLOC(comb, K, int);
ALLOC(signs, K, int);
/* Simple heuristic to figure out whether it fits in 32 bits */
if((N+4)*(K+4)<250 || (celt_ilog2(N)+1)*K<31)
{
cwrsi(N, K, ec_dec_uint(dec, ncwrs(N, K)), comb, signs);
comb2pulse(N, K, _y, comb, signs);
decode_comb32(N, K, comb, signs, dec);
} else {
cwrsi64(N, K, ec_dec_uint64(dec, ncwrs64(N, K)), comb, signs);
comb2pulse(N, K, _y, comb, signs);
decode_comb64(N, K, comb, signs, dec);
}
comb2pulse(N, K, _y, comb, signs);
RESTORE_STACK;
}