Quantizing initial state with rdovae too

More efficient than PVQ
Jean-Marc Valin 2023-09-15 17:27:44 -04:00
parent 2ec31cc5cc
commit b88644b9c7
7 changed files with 64 additions and 108 deletions
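The change replaces the PVQ (pyramid vector quantization) coding of the decoder's initial state with the same per-dimension scheme already used for the latent vectors: a learned scale, a soft dead zone, rounding, and entropy coding driven by the shared statistical model. As a rough illustration only, here is a minimal PyTorch sketch of that scale / dead-zone / round path, using simplified stand-ins for the repository's soft_dead_zone and hard_quantize helpers (their exact definitions are not part of this diff):

import torch

def soft_dead_zone(x, dead_zone):
    # Simplified stand-in: a differentiable dead zone that pulls small values toward zero.
    d = dead_zone * 0.05
    return x - d * torch.tanh(x / (0.1 + d))

def hard_quantize(x):
    # Round to the nearest integer, with a straight-through gradient for training.
    return x + (torch.round(x) - x).detach()

# Hypothetical shapes: (batch, frames, state_dim) initial state and per-dimension
# quantization parameters taken from the statistical model's state slice.
states = torch.randn(4, 10, 80)
quant_scale = 0.5 + torch.rand(4, 10, 80)
dead_zone = torch.rand(4, 10, 80)

states_scaled = soft_dead_zone(states * quant_scale, dead_zone)
states_q = hard_quantize(states_scaled) / quant_scale  # decoder initial state after quantization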

View file

@@ -372,7 +372,7 @@ class CoreDecoder(nn.Module):
class StatisticalModel(nn.Module):
- def __init__(self, quant_levels, latent_dim):
+ def __init__(self, quant_levels, latent_dim, state_dim):
""" Statistical model for latent space
Computes scaling, deadzone, r, and theta
@@ -383,8 +383,10 @@ class StatisticalModel(nn.Module):
# copy parameters
self.latent_dim = latent_dim
+ self.state_dim = state_dim
+ self.total_dim = latent_dim + state_dim
self.quant_levels = quant_levels
- self.embedding_dim = 6 * latent_dim
+ self.embedding_dim = 6 * self.total_dim
# quantization embedding
self.quant_embedding = nn.Embedding(quant_levels, self.embedding_dim)
@@ -400,12 +402,12 @@ class StatisticalModel(nn.Module):
x = self.quant_embedding(quant_ids)
# CAVE: theta_soft is not used anymore. Kick it out?
- quant_scale = F.softplus(x[..., 0 * self.latent_dim : 1 * self.latent_dim])
- dead_zone = F.softplus(x[..., 1 * self.latent_dim : 2 * self.latent_dim])
- theta_soft = torch.sigmoid(x[..., 2 * self.latent_dim : 3 * self.latent_dim])
- r_soft = torch.sigmoid(x[..., 3 * self.latent_dim : 4 * self.latent_dim])
- theta_hard = torch.sigmoid(x[..., 4 * self.latent_dim : 5 * self.latent_dim])
- r_hard = torch.sigmoid(x[..., 5 * self.latent_dim : 6 * self.latent_dim])
+ quant_scale = F.softplus(x[..., 0 * self.total_dim : 1 * self.total_dim])
+ dead_zone = F.softplus(x[..., 1 * self.total_dim : 2 * self.total_dim])
+ theta_soft = torch.sigmoid(x[..., 2 * self.total_dim : 3 * self.total_dim])
+ r_soft = torch.sigmoid(x[..., 3 * self.total_dim : 4 * self.total_dim])
+ theta_hard = torch.sigmoid(x[..., 4 * self.total_dim : 5 * self.total_dim])
+ r_hard = torch.sigmoid(x[..., 5 * self.total_dim : 6 * self.total_dim])
return {
@@ -445,7 +447,7 @@ class RDOVAE(nn.Module):
self.state_dropout_rate = state_dropout_rate
# submodules encoder and decoder share the statistical model
- self.statistical_model = StatisticalModel(quant_levels, latent_dim)
+ self.statistical_model = StatisticalModel(quant_levels, latent_dim, state_dim)
self.core_encoder = nn.DataParallel(CoreEncoder(feature_dim, latent_dim, cond_size, cond_size2, state_size=state_dim))
self.core_decoder = nn.DataParallel(CoreDecoder(latent_dim, feature_dim, cond_size, cond_size2, state_size=state_dim))
@@ -522,13 +524,18 @@ class RDOVAE(nn.Module):
z, states = self.core_encoder(features)
# scaling, dead-zone and quantization
- z = z * statistical_model['quant_scale']
- z = soft_dead_zone(z, statistical_model['dead_zone'])
+ z = z * statistical_model['quant_scale'][:,:,:self.latent_dim]
+ z = soft_dead_zone(z, statistical_model['dead_zone'][:,:,:self.latent_dim])
# quantization
- z_q = hard_quantize(z) / statistical_model['quant_scale']
- z_n = noise_quantize(z) / statistical_model['quant_scale']
- states_q = soft_pvq(states, self.pvq_num_pulses)
+ z_q = hard_quantize(z) / statistical_model['quant_scale'][:,:,:self.latent_dim]
+ z_n = noise_quantize(z) / statistical_model['quant_scale'][:,:,:self.latent_dim]
+ #states_q = soft_pvq(states, self.pvq_num_pulses)
+ states = states * statistical_model['quant_scale'][:,:,self.latent_dim:]
+ states = soft_dead_zone(states, statistical_model['dead_zone'][:,:,self.latent_dim:])
+ states_q = hard_quantize(states) / statistical_model['quant_scale'][:,:,self.latent_dim:]
+ states_n = noise_quantize(states) / statistical_model['quant_scale'][:,:,self.latent_dim:]
if self.state_dropout_rate > 0:
drop = torch.rand(states_q.size(0)) < self.state_dropout_rate
@@ -551,6 +558,7 @@ class RDOVAE(nn.Module):
# decoder with soft quantized input
z_dec_reverse = torch.flip(z_n[..., chunk['z_start'] : chunk['z_stop'] : chunk['z_stride'], :], [1])
+ dec_initial_state = states_n[..., chunk['z_stop'] - 1 : chunk['z_stop'], :]
features_reverse = self.core_decoder(z_dec_reverse, dec_initial_state)
outputs_sq.append((torch.flip(features_reverse, [1]), chunk['features_start'], chunk['features_stop']))
@@ -558,6 +566,7 @@ class RDOVAE(nn.Module):
'outputs_hard_quant' : outputs_hq,
'outputs_soft_quant' : outputs_sq,
'z' : z,
+ 'states' : states,
'statistical_model' : statistical_model
}
@@ -586,11 +595,11 @@ class RDOVAE(nn.Module):
stats = self.statistical_model(q_ids)
- zq = z * stats['quant_scale']
- zq = soft_dead_zone(zq, stats['dead_zone'])
+ zq = z * stats['quant_scale'][:self.latent_dim]
+ zq = soft_dead_zone(zq, stats['dead_zone'][:self.latent_dim])
zq = torch.round(zq)
- sizes = hard_rate_estimate(zq, stats['r_hard'], stats['theta_hard'], reduce=False)
+ sizes = hard_rate_estimate(zq, stats['r_hard'][:,:,:self.latent_dim], stats['theta_hard'][:,:,:self.latent_dim], reduce=False)
return zq, sizes
@@ -599,7 +608,7 @@ class RDOVAE(nn.Module):
stats = self.statistical_model(q_ids)
- z = zq / stats['quant_scale']
+ z = zq / stats['quant_scale'][:,:,:self.latent_dim]
return z

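With the model change above, StatisticalModel now emits its six parameter groups (quant_scale, dead_zone, theta_soft, r_soft, theta_hard, r_hard) over latent_dim + state_dim entries, and the forward pass takes the first latent_dim entries for z and the remaining state_dim entries for the initial state. A small sketch of that slicing convention, with sizes borrowed from the C headers in this commit purely for illustration:

import torch
import torch.nn.functional as F

latent_dim, state_dim = 80, 80  # DRED_LATENT_DIM / DRED_STATE_DIM in this commit
total_dim = latent_dim + state_dim

# One embedding row per quantizer level now carries 6 * total_dim values.
embedding_row = torch.randn(6 * total_dim)
quant_scale = F.softplus(embedding_row[0 * total_dim : 1 * total_dim])
dead_zone = F.softplus(embedding_row[1 * total_dim : 2 * total_dim])

# The first latent_dim entries apply to z, the rest to the decoder's initial state.
z_scale, state_scale = quant_scale[:latent_dim], quant_scale[latent_dim:]
z_dead_zone, state_dead_zone = dead_zone[:latent_dim], dead_zone[latent_dim:]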
View file

@@ -172,6 +172,7 @@ if __name__ == '__main__':
running_soft_rate_loss = 0
running_total_loss = 0
running_rate_metric = 0
+ running_states_rate_metric = 0
previous_total_loss = 0
running_first_frame_loss = 0
@@ -194,17 +195,21 @@ if __name__ == '__main__':
# collect outputs
z = model_output['z']
+ states = model_output['states']
outputs_hard_quant = model_output['outputs_hard_quant']
outputs_soft_quant = model_output['outputs_soft_quant']
statistical_model = model_output['statistical_model']
# rate loss
- hard_rate = hard_rate_estimate(z, statistical_model['r_hard'], statistical_model['theta_hard'], reduce=False)
- soft_rate = soft_rate_estimate(z, statistical_model['r_soft'], reduce=False)
- soft_rate_loss = torch.mean(torch.sqrt(rate_lambda) * soft_rate)
- hard_rate_loss = torch.mean(torch.sqrt(rate_lambda) * hard_rate)
+ hard_rate = hard_rate_estimate(z, statistical_model['r_hard'][:,:,:latent_dim], statistical_model['theta_hard'][:,:,:latent_dim], reduce=False)
+ soft_rate = soft_rate_estimate(z, statistical_model['r_soft'][:,:,:latent_dim], reduce=False)
+ states_hard_rate = hard_rate_estimate(states, statistical_model['r_hard'][:,:,latent_dim:], statistical_model['theta_hard'][:,:,latent_dim:], reduce=False)
+ states_soft_rate = soft_rate_estimate(states, statistical_model['r_soft'][:,:,latent_dim:], reduce=False)
+ soft_rate_loss = torch.mean(torch.sqrt(rate_lambda) * (soft_rate + .02*states_soft_rate))
+ hard_rate_loss = torch.mean(torch.sqrt(rate_lambda) * (hard_rate + .02*states_hard_rate))
rate_loss = (soft_rate_loss + 0.1 * hard_rate_loss)
hard_rate_metric = torch.mean(hard_rate)
+ states_rate_metric = torch.mean(states_hard_rate)
## distortion losses
@@ -242,6 +247,7 @@ if __name__ == '__main__':
running_soft_dist_loss += float(distortion_loss_soft_quant.detach().cpu())
running_rate_loss += float(rate_loss.detach().cpu())
running_rate_metric += float(hard_rate_metric.detach().cpu())
+ running_states_rate_metric += float(states_rate_metric.detach().cpu())
running_total_loss += float(total_loss.detach().cpu())
running_first_frame_loss += float(first_frame_loss.detach().cpu())
running_soft_rate_loss += float(soft_rate_loss.detach().cpu())
@@ -256,6 +262,7 @@ if __name__ == '__main__':
dist_sq=running_soft_dist_loss / (i + 1),
rate_loss=running_rate_loss / (i + 1),
rate=running_rate_metric / (i + 1),
+ states_rate=running_states_rate_metric / (i + 1),
ffloss=running_first_frame_loss / (i + 1),
rateloss_hard=running_hard_rate_loss / (i + 1),
rateloss_soft=running_soft_rate_loss / (i + 1)

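The training loop now also tracks the rate of the quantized initial state: its soft and hard rate estimates enter the existing rate loss with a small 0.02 weight, and the hard state rate is reported as a separate states_rate metric. A minimal sketch of that loss combination, with made-up tensors standing in for the outputs of soft_rate_estimate() and hard_rate_estimate():

import torch

# Hypothetical per-element rate estimates in bits, shaped (batch, frames, dim).
soft_rate = torch.full((8, 16, 80), 0.9)
hard_rate = torch.full((8, 16, 80), 1.0)
states_soft_rate = torch.full((8, 16, 80), 1.4)
states_hard_rate = torch.full((8, 16, 80), 1.5)
rate_lambda = torch.tensor(0.0007)  # hypothetical rate/distortion trade-off

soft_rate_loss = torch.mean(torch.sqrt(rate_lambda) * (soft_rate + 0.02 * states_soft_rate))
hard_rate_loss = torch.mean(torch.sqrt(rate_lambda) * (hard_rate + 0.02 * states_hard_rate))
rate_loss = soft_rate_loss + 0.1 * hard_rate_loss

states_rate_metric = torch.mean(states_hard_rate)  # logged separately in the progress output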
View file

@@ -33,16 +33,13 @@
#include <stdio.h>
#include "celt/entenc.h"
- #include "celt/vq.h"
- #include "celt/cwrs.h"
#include "celt/laplace.h"
#include "os_support.h"
#include "dred_config.h"
#include "dred_coding.h"
#define LATENT_DIM 80
- #define PVQ_DIM 24
- #define PVQ_K 82
+ #define STATE_DIM 80
int compute_quantizer(int q0, int dQ, int i) {
int quant;
@@ -53,37 +50,6 @@ int compute_quantizer(int q0, int dQ, int i) {
return (int) floor(0.5f + DRED_ENC_Q0 + 1.f * (DRED_ENC_Q1 - DRED_ENC_Q0) * i / (DRED_NUM_REDUNDANCY_FRAMES - 2));
}
- static void encode_pvq(const int *iy, int N, int K, ec_enc *enc) {
- int fits;
- celt_assert(N==24 || N==12 || N==6);
- fits = (N==24 && K<=9) || (N==12 && K<=16) || (N==6);
- /*printf("encode(%d,%d), fits=%d\n", N, K, fits);*/
- if (fits) {
- if (K > 0)
- encode_pulses(iy, N, K, enc);
- }
- else {
- int N2 = N/2;
- int K0=0;
- int i;
- for (i=0;i<N2;i++) K0 += abs(iy[i]);
- /* FIXME: Don't use uniform probability for K0. */
- ec_enc_uint(enc, K0, K+1);
- /*printf("K0 = %d\n", K0);*/
- encode_pvq(iy, N2, K0, enc);
- encode_pvq(&iy[N2], N2, K-K0, enc);
- }
- }
- void dred_encode_state(ec_enc *enc, const float *x) {
- int iy[PVQ_DIM];
- float x0[PVQ_DIM];
- /* Copy state because the PVQ search will trash it. */
- OPUS_COPY(x0, x, PVQ_DIM);
- op_pvq_search_c(x0, iy, PVQ_K, PVQ_DIM, 0);
- encode_pvq(iy, PVQ_DIM, PVQ_K, enc);
- }
void dred_encode_latents(ec_enc *enc, const float *x, const opus_uint16 *scale, const opus_uint16 *dzone, const opus_uint16 *r, const opus_uint16 *p0) {
int i;
float eps = .1f;
@@ -101,47 +67,6 @@ void dred_encode_latents(ec_enc *enc, const float *x, const opus_uint16 *scale,
}
}
- static void decode_pvq(int *iy, int N, int K, ec_dec *dec) {
- int fits;
- celt_assert(N==24 || N==12 || N==6);
- fits = (N==24 && K<=9) || (N==12 && K<=16) || (N==6);
- /*printf("encode(%d,%d), fits=%d\n", N, K, fits);*/
- if (fits) {
- if (K > 0)
- decode_pulses(iy, N, K, dec);
- else
- OPUS_CLEAR(iy, N);
- }
- else {
- int N2 = N/2;
- int K0;
- /* FIXME: Don't use uniform probability for K0. */
- K0 = ec_dec_uint(dec, K+1);
- /*printf("K0 = %d\n", K0);*/
- decode_pvq(iy, N2, K0, dec);
- decode_pvq(&iy[N2], N2, K-K0, dec);
- }
- }
- void dred_decode_state(ec_enc *dec, float *x) {
- int k;
- int iy[PVQ_DIM];
- float norm = 0;
- decode_pvq(iy, PVQ_DIM, PVQ_K, dec);
- /*printf("tell: %d\n", ec_tell(dec)-tell1);*/
- for (k = 0; k < PVQ_DIM; k++)
- {
- norm += (float) iy[k] * iy[k];
- }
- norm = 1.f / sqrt(norm);
- for (k = 0; k < PVQ_DIM; k++)
- {
- x[k] = iy[k] * norm;
- }
- }
void dred_decode_latents(ec_dec *dec, float *x, const opus_uint16 *scale, const opus_uint16 *r, const opus_uint16 *p0) {
int i;
for (i=0;i<LATENT_DIM;i++) {

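For contrast with the new scheme, the deleted dred_decode_state() above reconstructed the state as a unit-norm vector from PVQ pulses. A Python paraphrase of that removed logic, for illustration only (the real code works on the pulse vector produced by decode_pulses()):

import math

def pvq_reconstruct(iy):
    # iy: signed integer pulse counts; the result is the unit-norm state vector.
    norm = math.sqrt(sum(float(k) * k for k in iy))
    return [k / norm for k in iy] if norm > 0 else [0.0] * len(iy)

print(pvq_reconstruct([3, -2, 0, 1]))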
View file

@@ -32,7 +32,7 @@
#define DRED_EXTENSION_ID 126
/* Remove these two completely once DRED gets an extension number assigned. */
- #define DRED_EXPERIMENTAL_VERSION 1
+ #define DRED_EXPERIMENTAL_VERSION 2
#define DRED_EXPERIMENTAL_BYTES 2
@@ -41,7 +41,7 @@
/* these are inpart duplicates to the values defined in dred_rdovae_constants.h */
#define DRED_NUM_FEATURES 20
#define DRED_LATENT_DIM 80
- #define DRED_STATE_DIM 24
+ #define DRED_STATE_DIM 80
#define DRED_SILK_ENCODER_DELAY (79+12-80)
#define DRED_FRAME_SIZE 160
#define DRED_DFRAME_SIZE (2 * (DRED_FRAME_SIZE))

View file

@@ -54,6 +54,7 @@ int dred_ec_decode(OpusDRED *dec, const opus_uint8 *bytes, int num_bytes, int mi
int offset;
int q0;
int dQ;
+ int state_qoffset;
/* since features are decoded in quadruples, it makes no sense to go with an uneven number of redundancy frames */
@@ -66,7 +67,14 @@ int dred_ec_decode(OpusDRED *dec, const opus_uint8 *bytes, int num_bytes, int mi
dQ = ec_dec_uint(&ec, 8);
/*printf("%d %d %d\n", dred_offset, q0, dQ);*/
- dred_decode_state(&ec, dec->state);
+ //dred_decode_state(&ec, dec->state);
+ state_qoffset = q0*(DRED_LATENT_DIM+DRED_STATE_DIM) + DRED_STATE_DIM;
+ dred_decode_latents(
+ &ec,
+ dec->state,
+ quant_scales + state_qoffset,
+ r + state_qoffset,
+ p0 + state_qoffset);
/* decode newest to oldest and store oldest to newest */
for (i = 0; i < IMIN(DRED_NUM_REDUNDANCY_FRAMES, (min_feature_frames+1)/2); i += 2)
@@ -75,7 +83,7 @@ int dred_ec_decode(OpusDRED *dec, const opus_uint8 *bytes, int num_bytes, int mi
if (8*num_bytes - ec_tell(&ec) <= 7)
break;
q_level = compute_quantizer(q0, dQ, i/2);
- offset = q_level * DRED_LATENT_DIM;
+ offset = q_level * (DRED_LATENT_DIM+DRED_STATE_DIM);
dred_decode_latents(
&ec,
&dec->latents[(i/2)*DRED_LATENT_DIM],

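On the bitstream side, the initial state is no longer PVQ-coded: the decoder reads it with dred_decode_latents() using a dedicated slice of the per-level model tables, so each quantizer level's row now spans DRED_LATENT_DIM + DRED_STATE_DIM entries (with both set to 80 in this commit, the + DRED_STATE_DIM term in state_qoffset lands at the same position as + DRED_LATENT_DIM would). A small illustration of the offset arithmetic, written in Python only to keep all sketches here in one language:

# Values from dred_config.h as changed in this commit.
DRED_LATENT_DIM = 80
DRED_STATE_DIM = 80

def latent_offset(q_level):
    # Offset of quantizer level q_level's row in the flattened parameter tables.
    return q_level * (DRED_LATENT_DIM + DRED_STATE_DIM)

def state_offset(q0):
    # Mirrors state_qoffset above: where the state parameters sit within the q0 row.
    return q0 * (DRED_LATENT_DIM + DRED_STATE_DIM) + DRED_STATE_DIM

print(latent_offset(3), state_offset(3))  # 480 560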
View file

@@ -197,7 +197,7 @@ void dred_compute_latents(DREDEnc *enc, const float *pcm, int frame_size, int ex
/* 15 ms (6*2.5 ms) is the ideal offset for DRED because it corresponds to our vocoder look-ahead. */
if (enc->dred_offset < 6) {
enc->dred_offset += 8;
- OPUS_COPY(enc->initial_state, enc->state_buffer, 24);
+ OPUS_COPY(enc->initial_state, enc->state_buffer, DRED_STATE_DIM);
} else {
enc->latent_offset++;
}
@@ -221,6 +221,7 @@ int dred_encode_silk_frame(const DREDEnc *enc, unsigned char *buf, int max_chunk
int ec_buffer_fill;
int q0;
int dQ;
+ int state_qoffset;
/* entropy coding of state and latents */
ec_enc_init(&ec_encoder, buf, max_bytes);
@@ -229,15 +230,21 @@ int dred_encode_silk_frame(const DREDEnc *enc, unsigned char *buf, int max_chunk
ec_enc_uint(&ec_encoder, enc->dred_offset, 32);
ec_enc_uint(&ec_encoder, q0, 16);
ec_enc_uint(&ec_encoder, dQ, 8);
- dred_encode_state(&ec_encoder, enc->initial_state);
+ state_qoffset = q0*(DRED_LATENT_DIM+DRED_STATE_DIM) + DRED_STATE_DIM;
+ dred_encode_latents(
+ &ec_encoder,
+ enc->initial_state,
+ quant_scales + state_qoffset,
+ dead_zone + state_qoffset,
+ r + state_qoffset,
+ p0 + state_qoffset);
for (i = 0; i < IMIN(2*max_chunks, enc->latents_buffer_fill-enc->latent_offset-1); i += 2)
{
ec_enc ec_bak;
ec_bak = ec_encoder;
q_level = compute_quantizer(q0, dQ, i/2);
- offset = q_level * DRED_LATENT_DIM;
+ offset = q_level * (DRED_LATENT_DIM+DRED_STATE_DIM);
dred_encode_latents(
&ec_encoder,

View file

@@ -50,8 +50,8 @@ typedef struct {
int latent_offset;
float latents_buffer[DRED_MAX_FRAMES * DRED_LATENT_DIM];
int latents_buffer_fill;
- float state_buffer[24];
- float initial_state[24];
+ float state_buffer[DRED_STATE_DIM];
+ float initial_state[DRED_STATE_DIM];
float resample_mem[RESAMPLING_ORDER + 1];
LPCNetEncState lpcnet_enc_state;
RDOVAEEncState rdovae_enc;