mirror of
https://github.com/xiph/opus.git
synced 2025-05-24 12:19:15 +00:00
Add bidirectional quantizer
This commit is contained in:
parent
543ee94037
commit
90d74bbbe9
1 changed files with 163 additions and 6 deletions
|
@ -6,9 +6,12 @@
|
||||||
#include <math.h>
|
#include <math.h>
|
||||||
|
|
||||||
#define MIN(a,b) ((a)<(b)?(a):(b))
|
#define MIN(a,b) ((a)<(b)?(a):(b))
|
||||||
#define COEF 0.75f
|
#define COEF 0.0f
|
||||||
#define MAX_ENTRIES 16384
|
#define MAX_ENTRIES 16384
|
||||||
|
|
||||||
|
#define MULTI 4
|
||||||
|
#define MULTI_MASK (MULTI-1)
|
||||||
|
|
||||||
void compute_weights(const float *x, float *w, int ndim)
|
void compute_weights(const float *x, float *w, int ndim)
|
||||||
{
|
{
|
||||||
int i;
|
int i;
|
||||||
|
@ -48,6 +51,31 @@ int find_nearest(const float *codebook, int nb_entries, const float *x, int ndim
|
||||||
return nearest;
|
return nearest;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int find_nearest_multi(const float *codebook, int nb_entries, const float *x, int ndim, float *dist)
|
||||||
|
{
|
||||||
|
int i, j;
|
||||||
|
float min_dist = 1e15;
|
||||||
|
int nearest = 0;
|
||||||
|
|
||||||
|
for (i=0;i<nb_entries;i++)
|
||||||
|
{
|
||||||
|
int offset;
|
||||||
|
float dist=0;
|
||||||
|
offset = (i&MULTI_MASK)*ndim;
|
||||||
|
for (j=0;j<ndim;j++)
|
||||||
|
dist += (x[offset+j]-codebook[i*ndim+j])*(x[offset+j]-codebook[i*ndim+j]);
|
||||||
|
if (dist<min_dist)
|
||||||
|
{
|
||||||
|
min_dist = dist;
|
||||||
|
nearest = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (dist)
|
||||||
|
*dist = min_dist;
|
||||||
|
return nearest;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
int find_nearest_weighted(const float *codebook, int nb_entries, float *x, const float *w, int ndim)
|
int find_nearest_weighted(const float *codebook, int nb_entries, float *x, const float *w, int ndim)
|
||||||
{
|
{
|
||||||
int i, j;
|
int i, j;
|
||||||
|
@ -203,6 +231,45 @@ void update(float *data, int nb_vectors, float *codebook, int nb_entries, int nd
|
||||||
//fprintf(stderr, "%f / %d\n", 1./w2, nb_entries);
|
//fprintf(stderr, "%f / %d\n", 1./w2, nb_entries);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void update_multi(float *data, int nb_vectors, float *codebook, int nb_entries, int ndim)
|
||||||
|
{
|
||||||
|
int i,j;
|
||||||
|
int count[nb_entries];
|
||||||
|
int nearest[nb_vectors];
|
||||||
|
double err=0;
|
||||||
|
|
||||||
|
for (i=0;i<nb_entries;i++)
|
||||||
|
count[i] = 0;
|
||||||
|
|
||||||
|
for (i=0;i<nb_vectors;i++)
|
||||||
|
{
|
||||||
|
float dist;
|
||||||
|
nearest[i] = find_nearest_multi(codebook, nb_entries, data+MULTI*i*ndim, ndim, &dist);
|
||||||
|
err += dist;
|
||||||
|
}
|
||||||
|
printf("RMS error = %f\n", sqrt(err/nb_vectors/ndim));
|
||||||
|
for (i=0;i<nb_entries*ndim;i++)
|
||||||
|
codebook[i] = 0;
|
||||||
|
|
||||||
|
for (i=0;i<nb_vectors;i++)
|
||||||
|
{
|
||||||
|
int n = nearest[i];
|
||||||
|
count[n]++;
|
||||||
|
for (j=0;j<ndim;j++)
|
||||||
|
codebook[n*ndim+j] += data[(MULTI*i + (n&MULTI_MASK))*ndim+j];
|
||||||
|
}
|
||||||
|
|
||||||
|
float w2=0;
|
||||||
|
for (i=0;i<nb_entries;i++)
|
||||||
|
{
|
||||||
|
for (j=0;j<ndim;j++)
|
||||||
|
codebook[i*ndim+j] *= (1./count[i]);
|
||||||
|
w2 += (count[i]/(float)nb_vectors)*(count[i]/(float)nb_vectors);
|
||||||
|
}
|
||||||
|
//fprintf(stderr, "%f / %d\n", 1./w2, nb_entries);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
void update_weighted(float *data, float *weight, int nb_vectors, float *codebook, int nb_entries, int ndim)
|
void update_weighted(float *data, float *weight, int nb_vectors, float *codebook, int nb_entries, int ndim)
|
||||||
{
|
{
|
||||||
int i,j;
|
int i,j;
|
||||||
|
@ -271,6 +338,38 @@ void vq_train(float *data, int nb_vectors, float *codebook, int nb_entries, int
|
||||||
update(data, nb_vectors, codebook, e, ndim);
|
update(data, nb_vectors, codebook, e, ndim);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void vq_train_multi(float *data, int nb_vectors, float *codebook, int nb_entries, int ndim)
|
||||||
|
{
|
||||||
|
int i, j, e;
|
||||||
|
for (e=0;e<MULTI;e++) {
|
||||||
|
for (j=0;j<ndim;j++)
|
||||||
|
codebook[e*ndim+j] = 0;
|
||||||
|
for (i=0;i<nb_vectors;i++)
|
||||||
|
for (j=0;j<ndim;j++)
|
||||||
|
codebook[e*ndim+j] += data[(MULTI*i+e)*ndim+j];
|
||||||
|
for (j=0;j<ndim;j++) {
|
||||||
|
float delta = .01*(rand()/(float)RAND_MAX-.5);
|
||||||
|
codebook[e*ndim+j] *= (1./nb_vectors);
|
||||||
|
codebook[e*ndim+j] += delta;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
e = MULTI;
|
||||||
|
for (j=0;j<10;j++)
|
||||||
|
update_multi(data, nb_vectors, codebook, e, ndim);
|
||||||
|
|
||||||
|
while (e < nb_entries)
|
||||||
|
{
|
||||||
|
split(codebook, e, ndim);
|
||||||
|
e<<=1;
|
||||||
|
fprintf(stderr, "%d\n", e);
|
||||||
|
for (j=0;j<4;j++)
|
||||||
|
update_multi(data, nb_vectors, codebook, e, ndim);
|
||||||
|
}
|
||||||
|
for (j=0;j<ndim*2;j++)
|
||||||
|
update_multi(data, nb_vectors, codebook, e, ndim);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
void vq_train_weighted(float *data, float *weight, int nb_vectors, float *codebook, int nb_entries, int ndim)
|
void vq_train_weighted(float *data, float *weight, int nb_vectors, float *codebook, int nb_entries, int ndim)
|
||||||
{
|
{
|
||||||
int i, j, e;
|
int i, j, e;
|
||||||
|
@ -303,8 +402,9 @@ void vq_train_weighted(float *data, float *weight, int nb_vectors, float *codebo
|
||||||
int main(int argc, char **argv)
|
int main(int argc, char **argv)
|
||||||
{
|
{
|
||||||
int i,j;
|
int i,j;
|
||||||
int nb_vectors, nb_entries, ndim, ndim0, total_dim;
|
int nb_vectors, nb_entries, nb_entries2, ndim, ndim0, total_dim;
|
||||||
float *data, *pred, *codebook, *codebook2;
|
float *data, *pred, *multi_data, *multi_data2, *qdata;
|
||||||
|
float *codebook, *codebook2, *codebook_diff2, *codebook_diff4;
|
||||||
float *delta;
|
float *delta;
|
||||||
double err;
|
double err;
|
||||||
FILE *fout;
|
FILE *fout;
|
||||||
|
@ -314,11 +414,17 @@ int main(int argc, char **argv)
|
||||||
total_dim = atoi(argv[2]);
|
total_dim = atoi(argv[2]);
|
||||||
nb_vectors = atoi(argv[3]);
|
nb_vectors = atoi(argv[3]);
|
||||||
nb_entries = 1<<atoi(argv[4]);
|
nb_entries = 1<<atoi(argv[4]);
|
||||||
|
nb_entries2 = 64;
|
||||||
|
|
||||||
data = malloc((nb_vectors*ndim+total_dim)*sizeof(*data));
|
data = malloc((nb_vectors*ndim+total_dim)*sizeof(*data));
|
||||||
|
qdata = malloc((nb_vectors*ndim+total_dim)*sizeof(*qdata));
|
||||||
pred = malloc(nb_vectors*ndim0*sizeof(*pred));
|
pred = malloc(nb_vectors*ndim0*sizeof(*pred));
|
||||||
|
multi_data = malloc(MULTI*nb_vectors*ndim*sizeof(*multi_data));
|
||||||
|
multi_data2 = malloc(MULTI*nb_vectors*ndim*sizeof(*multi_data));
|
||||||
codebook = malloc(nb_entries*ndim0*sizeof(*codebook));
|
codebook = malloc(nb_entries*ndim0*sizeof(*codebook));
|
||||||
codebook2 = malloc(nb_entries*ndim0*sizeof(*codebook2));
|
codebook2 = malloc(nb_entries*ndim0*sizeof(*codebook2));
|
||||||
|
codebook_diff4 = malloc(nb_entries*ndim*sizeof(*codebook_diff4));
|
||||||
|
codebook_diff2 = malloc(nb_entries2*ndim*sizeof(*codebook_diff2));
|
||||||
|
|
||||||
for (i=0;i<nb_vectors;i++)
|
for (i=0;i<nb_vectors;i++)
|
||||||
{
|
{
|
||||||
|
@ -348,8 +454,10 @@ int main(int argc, char **argv)
|
||||||
for (i=0;i<nb_vectors;i++)
|
for (i=0;i<nb_vectors;i++)
|
||||||
{
|
{
|
||||||
int nearest = find_nearest(codebook, nb_entries, &pred[i*ndim0], ndim0, NULL);
|
int nearest = find_nearest(codebook, nb_entries, &pred[i*ndim0], ndim0, NULL);
|
||||||
|
qdata[i*ndim+j] = data[i*ndim+j];
|
||||||
for (j=0;j<ndim0;j++)
|
for (j=0;j<ndim0;j++)
|
||||||
{
|
{
|
||||||
|
qdata[i*ndim+j+1] = codebook[nearest*ndim0+j];
|
||||||
delta[i*ndim0+j] = pred[i*ndim0+j] - codebook[nearest*ndim0+j];
|
delta[i*ndim0+j] = pred[i*ndim0+j] - codebook[nearest*ndim0+j];
|
||||||
err += delta[i*ndim0+j]*delta[i*ndim0+j];
|
err += delta[i*ndim0+j]*delta[i*ndim0+j];
|
||||||
}
|
}
|
||||||
|
@ -366,12 +474,44 @@ int main(int argc, char **argv)
|
||||||
n1 = find_nearest(codebook2, nb_entries, &delta[i*ndim0], ndim0, NULL);
|
n1 = find_nearest(codebook2, nb_entries, &delta[i*ndim0], ndim0, NULL);
|
||||||
for (j=0;j<ndim0;j++)
|
for (j=0;j<ndim0;j++)
|
||||||
{
|
{
|
||||||
delta[i*ndim0+j] = delta[i*ndim0+j] - codebook2[n1*ndim0+j];
|
qdata[i*ndim+j+1] += codebook2[n1*ndim0+j];
|
||||||
|
//delta[i*ndim0+j] = delta[i*ndim0+j] - codebook2[n1*ndim0+j];
|
||||||
|
delta[i*ndim0+j] = qdata[i*ndim+j+1] - data[i*ndim+j+1];
|
||||||
err += delta[i*ndim0+j]*delta[i*ndim0+j];
|
err += delta[i*ndim0+j]*delta[i*ndim0+j];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
fprintf(stderr, "Cepstrum RMS error after stage 2: %f)\n", sqrt(err/nb_vectors/ndim));
|
fprintf(stderr, "Cepstrum RMS error after stage 2: %f)\n", sqrt(err/nb_vectors/ndim));
|
||||||
|
|
||||||
|
for (i=0;i<nb_vectors-4;i++)
|
||||||
|
{
|
||||||
|
for (j=0;j<ndim0;j++)
|
||||||
|
multi_data[MULTI*i*ndim+j] = data[(i+1)*ndim+j+1] - .5*(qdata[i*ndim+j+1]+qdata[(i+2)*ndim+j+1]);
|
||||||
|
for (j=0;j<ndim0;j++)
|
||||||
|
multi_data[(MULTI*i+1)*ndim+j] = data[(i+1)*ndim+j+1] - .5*(qdata[i*ndim+j+1]+qdata[(i+2)*ndim+j+1]);
|
||||||
|
for (j=0;j<ndim0;j++)
|
||||||
|
multi_data[(MULTI*i+2)*ndim+j] = data[(i+1)*ndim+j+1] - qdata[i*ndim+j+1];
|
||||||
|
for (j=0;j<ndim0;j++)
|
||||||
|
multi_data[(MULTI*i+3)*ndim+j] = data[(i+1)*ndim+j+1] - qdata[(i+2)*ndim+j+1];
|
||||||
|
}
|
||||||
|
|
||||||
|
for (i=0;i<nb_vectors-4;i++)
|
||||||
|
{
|
||||||
|
for (j=0;j<ndim0;j++)
|
||||||
|
multi_data2[MULTI*i*ndim+j] = data[(i+2)*ndim+j+1] - .5*(qdata[i*ndim+j+1]+qdata[(i+4)*ndim+j+1]);
|
||||||
|
for (j=0;j<ndim0;j++)
|
||||||
|
multi_data2[(MULTI*i+1)*ndim+j] = data[(i+2)*ndim+j+1] - .5*(qdata[i*ndim+j+1]+qdata[(i+4)*ndim+j+1]);
|
||||||
|
for (j=0;j<ndim0;j++)
|
||||||
|
multi_data2[(MULTI*i+2)*ndim+j] = data[(i+2)*ndim+j+1] - qdata[i*ndim+j+1];
|
||||||
|
for (j=0;j<ndim0;j++)
|
||||||
|
multi_data2[(MULTI*i+3)*ndim+j] = data[(i+2)*ndim+j+1] - qdata[(i+4)*ndim+j+1];
|
||||||
|
}
|
||||||
|
|
||||||
|
vq_train_multi(multi_data2, nb_vectors-4, codebook_diff4, nb_entries, ndim);
|
||||||
|
|
||||||
|
printf("done\n");
|
||||||
|
vq_train_multi(multi_data, nb_vectors-4, codebook_diff2, 64, ndim);
|
||||||
|
|
||||||
|
|
||||||
fout = fopen("ceps_codebooks.c", "w");
|
fout = fopen("ceps_codebooks.c", "w");
|
||||||
fprintf(fout, "/* This file is automatically generated */\n\n");
|
fprintf(fout, "/* This file is automatically generated */\n\n");
|
||||||
fprintf(fout, "float ceps_codebook1[%d*%d] = {\n",nb_entries, ndim0);
|
fprintf(fout, "float ceps_codebook1[%d*%d] = {\n",nb_entries, ndim0);
|
||||||
|
@ -385,7 +525,6 @@ int main(int argc, char **argv)
|
||||||
fprintf(fout, "};\n\n");
|
fprintf(fout, "};\n\n");
|
||||||
|
|
||||||
fprintf(fout, "float ceps_codebook2[%d*%d] = {\n",nb_entries, ndim0);
|
fprintf(fout, "float ceps_codebook2[%d*%d] = {\n",nb_entries, ndim0);
|
||||||
|
|
||||||
for (i=0;i<nb_entries;i++)
|
for (i=0;i<nb_entries;i++)
|
||||||
{
|
{
|
||||||
for (j=0;j<ndim0;j++)
|
for (j=0;j<ndim0;j++)
|
||||||
|
@ -394,6 +533,24 @@ int main(int argc, char **argv)
|
||||||
}
|
}
|
||||||
fprintf(fout, "};\n\n");
|
fprintf(fout, "};\n\n");
|
||||||
|
|
||||||
|
fprintf(fout, "float ceps_codebook_diff4[%d*%d] = {\n",nb_entries, ndim);
|
||||||
|
for (i=0;i<nb_entries;i++)
|
||||||
|
{
|
||||||
|
for (j=0;j<ndim;j++)
|
||||||
|
fprintf(fout, "%f, ", codebook_diff4[i*ndim+j]);
|
||||||
|
fprintf(fout, "\n");
|
||||||
|
}
|
||||||
|
fprintf(fout, "};\n\n");
|
||||||
|
|
||||||
|
fprintf(fout, "float ceps_codebook_diff2[%d*%d] = {\n",nb_entries2, ndim);
|
||||||
|
for (i=0;i<nb_entries2;i++)
|
||||||
|
{
|
||||||
|
for (j=0;j<ndim;j++)
|
||||||
|
fprintf(fout, "%f, ", codebook_diff2[i*ndim+j]);
|
||||||
|
fprintf(fout, "\n");
|
||||||
|
}
|
||||||
|
fprintf(fout, "};\n\n");
|
||||||
|
|
||||||
fclose(fout);
|
fclose(fout);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue