Quantizing initial state with rdovae too
More efficient than PVQ
parent 2ec31cc5cc
commit b88644b9c7

7 changed files with 64 additions and 108 deletions
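At a glance: before this commit the decoder's initial GRU state was coded with a fixed 24-dimensional PVQ (82 pulses), while the latents went through a learned per-quantizer-level scale, dead zone, and rate model. Afterwards the state reuses that same machinery: the StatisticalModel emits parameters for latent_dim + state_dim channels and each consumer slices off the part it needs. A minimal sketch of the widened layout (dimensions from the patch; the Embedding stub and the 16 quantizer levels are assumptions, not values from this diff):

import torch
import torch.nn as nn
import torch.nn.functional as F

LATENT_DIM, STATE_DIM = 80, 80
TOTAL_DIM = LATENT_DIM + STATE_DIM

# Stub of the widened StatisticalModel: one embedding row per quantizer
# level now carries six parameter groups sized for latents *and* state.
embedding = nn.Embedding(16, 6 * TOTAL_DIM)

x = embedding(torch.tensor([3]))                      # parameters for level 3
quant_scale = F.softplus(x[..., 0 * TOTAL_DIM : 1 * TOTAL_DIM])
latent_scale = quant_scale[..., :LATENT_DIM]          # applied to latents
state_scale  = quant_scale[..., LATENT_DIM:]          # applied to initial state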
@@ -372,7 +372,7 @@ class CoreDecoder(nn.Module):
 
 
 class StatisticalModel(nn.Module):
-    def __init__(self, quant_levels, latent_dim):
+    def __init__(self, quant_levels, latent_dim, state_dim):
         """ Statistical model for latent space
 
         Computes scaling, deadzone, r, and theta
@@ -383,8 +383,10 @@ class StatisticalModel(nn.Module):
 
         # copy parameters
         self.latent_dim    = latent_dim
+        self.state_dim     = state_dim
+        self.total_dim     = latent_dim + state_dim
         self.quant_levels  = quant_levels
-        self.embedding_dim = 6 * latent_dim
+        self.embedding_dim = 6 * self.total_dim
 
         # quantization embedding
         self.quant_embedding = nn.Embedding(quant_levels, self.embedding_dim)
@@ -400,12 +402,12 @@ class StatisticalModel(nn.Module):
         x = self.quant_embedding(quant_ids)
 
         # CAVE: theta_soft is not used anymore. Kick it out?
-        quant_scale = F.softplus(x[..., 0 * self.latent_dim : 1 * self.latent_dim])
-        dead_zone   = F.softplus(x[..., 1 * self.latent_dim : 2 * self.latent_dim])
-        theta_soft  = torch.sigmoid(x[..., 2 * self.latent_dim : 3 * self.latent_dim])
-        r_soft      = torch.sigmoid(x[..., 3 * self.latent_dim : 4 * self.latent_dim])
-        theta_hard  = torch.sigmoid(x[..., 4 * self.latent_dim : 5 * self.latent_dim])
-        r_hard      = torch.sigmoid(x[..., 5 * self.latent_dim : 6 * self.latent_dim])
+        quant_scale = F.softplus(x[..., 0 * self.total_dim : 1 * self.total_dim])
+        dead_zone   = F.softplus(x[..., 1 * self.total_dim : 2 * self.total_dim])
+        theta_soft  = torch.sigmoid(x[..., 2 * self.total_dim : 3 * self.total_dim])
+        r_soft      = torch.sigmoid(x[..., 3 * self.total_dim : 4 * self.total_dim])
+        theta_hard  = torch.sigmoid(x[..., 4 * self.total_dim : 5 * self.total_dim])
+        r_hard      = torch.sigmoid(x[..., 5 * self.total_dim : 6 * self.total_dim])
 
 
         return {
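To make the slicing in this hunk concrete, here is an illustrative helper (not part of the patch) that splits one widened embedding row into the six parameter groups and then into their latent/state halves:

def split_params(x, latent_dim, state_dim):
    # x[..., :] has size 6 * (latent_dim + state_dim), as produced by the
    # widened quant_embedding above.
    total_dim = latent_dim + state_dim
    names = ['quant_scale', 'dead_zone', 'theta_soft', 'r_soft', 'theta_hard', 'r_hard']
    out = {}
    for i, name in enumerate(names):
        group = x[..., i * total_dim : (i + 1) * total_dim]
        out[name] = {'latents': group[..., :latent_dim],
                     'state':   group[..., latent_dim:]}
    return out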
@@ -445,7 +447,7 @@ class RDOVAE(nn.Module):
         self.state_dropout_rate = state_dropout_rate
 
         # submodules encoder and decoder share the statistical model
-        self.statistical_model = StatisticalModel(quant_levels, latent_dim)
+        self.statistical_model = StatisticalModel(quant_levels, latent_dim, state_dim)
         self.core_encoder = nn.DataParallel(CoreEncoder(feature_dim, latent_dim, cond_size, cond_size2, state_size=state_dim))
         self.core_decoder = nn.DataParallel(CoreDecoder(latent_dim, feature_dim, cond_size, cond_size2, state_size=state_dim))
@@ -522,13 +524,18 @@ class RDOVAE(nn.Module):
         z, states = self.core_encoder(features)
 
         # scaling, dead-zone and quantization
-        z = z * statistical_model['quant_scale']
-        z = soft_dead_zone(z, statistical_model['dead_zone'])
+        z = z * statistical_model['quant_scale'][:,:,:self.latent_dim]
+        z = soft_dead_zone(z, statistical_model['dead_zone'][:,:,:self.latent_dim])
 
         # quantization
-        z_q = hard_quantize(z) / statistical_model['quant_scale']
-        z_n = noise_quantize(z) / statistical_model['quant_scale']
-        states_q = soft_pvq(states, self.pvq_num_pulses)
+        z_q = hard_quantize(z) / statistical_model['quant_scale'][:,:,:self.latent_dim]
+        z_n = noise_quantize(z) / statistical_model['quant_scale'][:,:,:self.latent_dim]
+        #states_q = soft_pvq(states, self.pvq_num_pulses)
+        states = states * statistical_model['quant_scale'][:,:,self.latent_dim:]
+        states = soft_dead_zone(states, statistical_model['dead_zone'][:,:,self.latent_dim:])
+
+        states_q = hard_quantize(states) / statistical_model['quant_scale'][:,:,self.latent_dim:]
+        states_n = noise_quantize(states) / statistical_model['quant_scale'][:,:,self.latent_dim:]
 
         if self.state_dropout_rate > 0:
             drop = torch.rand(states_q.size(0)) < self.state_dropout_rate
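soft_dead_zone, hard_quantize, and noise_quantize are defined elsewhere in the file and not shown in this diff. As a rough, hypothetical stand-in, the hard path amounts to scaling, shrinking the region around zero (so more samples land on the cheap zero bin), and rounding, now applied to the state exactly as to the latents:

import torch

def dead_zone_round(x, scale, dead_zone):
    # Hypothetical hard counterpart of the soft_dead_zone/hard_quantize pair
    # used in training (the trained ops are differentiable approximations).
    y = x * scale
    y = torch.sign(y) * torch.clamp(torch.abs(y) - dead_zone, min=0.0)
    return torch.round(y)

# After the patch: latents use the first latent_dim channels of the model
# parameters, the initial state the remaining state_dim channels.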
@@ -551,6 +558,7 @@ class RDOVAE(nn.Module):
 
             # decoder with soft quantized input
             z_dec_reverse     = torch.flip(z_n[..., chunk['z_start'] : chunk['z_stop'] : chunk['z_stride'], :], [1])
+            dec_initial_state = states_n[..., chunk['z_stop'] - 1 : chunk['z_stop'], :]
             features_reverse  = self.core_decoder(z_dec_reverse, dec_initial_state)
             outputs_sq.append((torch.flip(features_reverse, [1]), chunk['features_start'], chunk['features_stop']))
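Note that the chunk['z_stop'] - 1 : chunk['z_stop'] slice picks the noise-quantized state at the last frame of the chunk while keeping the time axis (a plain index would drop it). For example, with made-up shapes:

import torch

states_n = torch.randn(4, 16, 80)   # (batch, frames, state_dim), assumed shapes
z_stop = 8
dec_initial_state = states_n[..., z_stop - 1 : z_stop, :]
print(dec_initial_state.shape)      # torch.Size([4, 1, 80]): time axis kept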
@@ -558,6 +566,7 @@ class RDOVAE(nn.Module):
             'outputs_hard_quant' : outputs_hq,
             'outputs_soft_quant' : outputs_sq,
             'z'                  : z,
+            'states'             : states,
             'statistical_model'  : statistical_model
         }
 
@@ -586,11 +595,11 @@ class RDOVAE(nn.Module):
 
         stats = self.statistical_model(q_ids)
 
-        zq = z * stats['quant_scale']
-        zq = soft_dead_zone(zq, stats['dead_zone'])
+        zq = z * stats['quant_scale'][:self.latent_dim]
+        zq = soft_dead_zone(zq, stats['dead_zone'][:self.latent_dim])
         zq = torch.round(zq)
 
-        sizes = hard_rate_estimate(zq, stats['r_hard'], stats['theta_hard'], reduce=False)
+        sizes = hard_rate_estimate(zq, stats['r_hard'][:,:,:self.latent_dim], stats['theta_hard'][:,:,:self.latent_dim], reduce=False)
 
         return zq, sizes
 
@@ -599,7 +608,7 @@ class RDOVAE(nn.Module):
 
         stats = self.statistical_model(q_ids)
 
-        z = zq / stats['quant_scale']
+        z = zq / stats['quant_scale'][:,:,:self.latent_dim]
 
         return z
@@ -172,6 +172,7 @@ if __name__ == '__main__':
    running_soft_rate_loss = 0
    running_total_loss = 0
    running_rate_metric = 0
+   running_states_rate_metric = 0
    previous_total_loss = 0
    running_first_frame_loss = 0
 
@@ -194,17 +195,21 @@ if __name__ == '__main__':
 
            # collect outputs
            z = model_output['z']
+           states = model_output['states']
            outputs_hard_quant = model_output['outputs_hard_quant']
            outputs_soft_quant = model_output['outputs_soft_quant']
            statistical_model = model_output['statistical_model']
 
            # rate loss
-           hard_rate = hard_rate_estimate(z, statistical_model['r_hard'], statistical_model['theta_hard'], reduce=False)
-           soft_rate = soft_rate_estimate(z, statistical_model['r_soft'], reduce=False)
-           soft_rate_loss = torch.mean(torch.sqrt(rate_lambda) * soft_rate)
-           hard_rate_loss = torch.mean(torch.sqrt(rate_lambda) * hard_rate)
+           hard_rate = hard_rate_estimate(z, statistical_model['r_hard'][:,:,:latent_dim], statistical_model['theta_hard'][:,:,:latent_dim], reduce=False)
+           soft_rate = soft_rate_estimate(z, statistical_model['r_soft'][:,:,:latent_dim], reduce=False)
+           states_hard_rate = hard_rate_estimate(states, statistical_model['r_hard'][:,:,latent_dim:], statistical_model['theta_hard'][:,:,latent_dim:], reduce=False)
+           states_soft_rate = soft_rate_estimate(states, statistical_model['r_soft'][:,:,latent_dim:], reduce=False)
+           soft_rate_loss = torch.mean(torch.sqrt(rate_lambda) * (soft_rate + .02*states_soft_rate))
+           hard_rate_loss = torch.mean(torch.sqrt(rate_lambda) * (hard_rate + .02*states_hard_rate))
            rate_loss = (soft_rate_loss + 0.1 * hard_rate_loss)
            hard_rate_metric = torch.mean(hard_rate)
+           states_rate_metric = torch.mean(states_hard_rate)
 
            ## distortion losses
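As a worked example of the new weighting (rate-estimate internals assumed; only the .02 factor and the sqrt(rate_lambda) scaling come from the patch), the state rate enters both rate losses at 2% of the weight of the latent rate, presumably because one initial state is sent per redundancy block while latents are sent continuously:

import torch

# Made-up per-element rate estimates; real ones come from soft/hard_rate_estimate.
soft_rate        = torch.full((4, 16), 0.5)
states_soft_rate = torch.full((4, 16), 0.5)
rate_lambda      = torch.tensor(0.01)

soft_rate_loss = torch.mean(torch.sqrt(rate_lambda) * (soft_rate + .02 * states_soft_rate))
print(float(soft_rate_loss))  # sqrt(0.01) * (0.5 + 0.01) = 0.051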
@@ -242,6 +247,7 @@ if __name__ == '__main__':
            running_soft_dist_loss += float(distortion_loss_soft_quant.detach().cpu())
            running_rate_loss += float(rate_loss.detach().cpu())
            running_rate_metric += float(hard_rate_metric.detach().cpu())
+           running_states_rate_metric += float(states_rate_metric.detach().cpu())
            running_total_loss += float(total_loss.detach().cpu())
            running_first_frame_loss += float(first_frame_loss.detach().cpu())
            running_soft_rate_loss += float(soft_rate_loss.detach().cpu())
@@ -256,6 +262,7 @@ if __name__ == '__main__':
                dist_sq=running_soft_dist_loss / (i + 1),
                rate_loss=running_rate_loss / (i + 1),
                rate=running_rate_metric / (i + 1),
+               states_rate=running_states_rate_metric / (i + 1),
                ffloss=running_first_frame_loss / (i + 1),
                rateloss_hard=running_hard_rate_loss / (i + 1),
                rateloss_soft=running_soft_rate_loss / (i + 1)
@@ -33,16 +33,13 @@
 #include <stdio.h>
 
 #include "celt/entenc.h"
-#include "celt/vq.h"
-#include "celt/cwrs.h"
 #include "celt/laplace.h"
 #include "os_support.h"
 #include "dred_config.h"
 #include "dred_coding.h"
 
 #define LATENT_DIM 80
-#define PVQ_DIM 24
-#define PVQ_K 82
+#define STATE_DIM 80
 
 int compute_quantizer(int q0, int dQ, int i) {
     int quant;
@@ -53,37 +50,6 @@ int compute_quantizer(int q0, int dQ, int i) {
     return (int) floor(0.5f + DRED_ENC_Q0 + 1.f * (DRED_ENC_Q1 - DRED_ENC_Q0) * i / (DRED_NUM_REDUNDANCY_FRAMES - 2));
 }
 
-static void encode_pvq(const int *iy, int N, int K, ec_enc *enc) {
-    int fits;
-    celt_assert(N==24 || N==12 || N==6);
-    fits = (N==24 && K<=9) || (N==12 && K<=16) || (N==6);
-    /*printf("encode(%d,%d), fits=%d\n", N, K, fits);*/
-    if (fits) {
-        if (K > 0)
-            encode_pulses(iy, N, K, enc);
-    }
-    else {
-        int N2 = N/2;
-        int K0=0;
-        int i;
-        for (i=0;i<N2;i++) K0 += abs(iy[i]);
-        /* FIXME: Don't use uniform probability for K0. */
-        ec_enc_uint(enc, K0, K+1);
-        /*printf("K0 = %d\n", K0);*/
-        encode_pvq(iy, N2, K0, enc);
-        encode_pvq(&iy[N2], N2, K-K0, enc);
-    }
-}
-
-void dred_encode_state(ec_enc *enc, const float *x) {
-    int iy[PVQ_DIM];
-    float x0[PVQ_DIM];
-    /* Copy state because the PVQ search will trash it. */
-    OPUS_COPY(x0, x, PVQ_DIM);
-    op_pvq_search_c(x0, iy, PVQ_K, PVQ_DIM, 0);
-    encode_pvq(iy, PVQ_DIM, PVQ_K, enc);
-}
-
 void dred_encode_latents(ec_enc *enc, const float *x, const opus_uint16 *scale, const opus_uint16 *dzone, const opus_uint16 *r, const opus_uint16 *p0) {
     int i;
     float eps = .1f;
@@ -101,47 +67,6 @@ void dred_encode_latents(ec_enc *enc, const float *x, const opus_uint16 *scale,
     }
 }
 
-
-
-static void decode_pvq(int *iy, int N, int K, ec_dec *dec) {
-    int fits;
-    celt_assert(N==24 || N==12 || N==6);
-    fits = (N==24 && K<=9) || (N==12 && K<=16) || (N==6);
-    /*printf("encode(%d,%d), fits=%d\n", N, K, fits);*/
-    if (fits) {
-        if (K > 0)
-            decode_pulses(iy, N, K, dec);
-        else
-            OPUS_CLEAR(iy, N);
-    }
-    else {
-        int N2 = N/2;
-        int K0;
-        /* FIXME: Don't use uniform probability for K0. */
-        K0 = ec_dec_uint(dec, K+1);
-        /*printf("K0 = %d\n", K0);*/
-        decode_pvq(iy, N2, K0, dec);
-        decode_pvq(&iy[N2], N2, K-K0, dec);
-    }
-}
-
-void dred_decode_state(ec_enc *dec, float *x) {
-    int k;
-    int iy[PVQ_DIM];
-    float norm = 0;
-    decode_pvq(iy, PVQ_DIM, PVQ_K, dec);
-    /*printf("tell: %d\n", ec_tell(dec)-tell1);*/
-    for (k = 0; k < PVQ_DIM; k++)
-    {
-        norm += (float) iy[k] * iy[k];
-    }
-    norm = 1.f / sqrt(norm);
-    for (k = 0; k < PVQ_DIM; k++)
-    {
-        x[k] = iy[k] * norm;
-    }
-}
-
 void dred_decode_latents(ec_dec *dec, float *x, const opus_uint16 *scale, const opus_uint16 *r, const opus_uint16 *p0) {
     int i;
     for (i=0;i<LATENT_DIM;i++) {
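For context on the code being deleted: encode_pvq/decode_pvq recursively split a length-N pulse vector in half whenever the (N, K) codebook is too large for a single encode_pulses/decode_pulses call, coding the first half's pulse count K0 with a (temporarily) uniform distribution and recursing on both halves. A toy Python sketch of just that splitting logic (the CELT entropy-coder calls are stubbed out as list entries):

def split_pvq(iy, K, fits):
    # fits(N, K) stands in for the (N==24 && K<=9) || ... test above.
    N = len(iy)
    if fits(N, K):
        return [('pulses', tuple(iy), K)]          # would call encode_pulses()
    K0 = sum(abs(v) for v in iy[:N // 2])          # pulses in the first half
    return ([('uint', K0, K + 1)]                  # would call ec_enc_uint()
            + split_pvq(iy[:N // 2], K0, fits)
            + split_pvq(iy[N // 2:], K - K0, fits))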
@@ -32,7 +32,7 @@
 #define DRED_EXTENSION_ID 126
 
 /* Remove these two completely once DRED gets an extension number assigned. */
-#define DRED_EXPERIMENTAL_VERSION 1
+#define DRED_EXPERIMENTAL_VERSION 2
 #define DRED_EXPERIMENTAL_BYTES 2
 
 
@@ -41,7 +41,7 @@
 /* these are inpart duplicates to the values defined in dred_rdovae_constants.h */
 #define DRED_NUM_FEATURES 20
 #define DRED_LATENT_DIM 80
-#define DRED_STATE_DIM 24
+#define DRED_STATE_DIM 80
 #define DRED_SILK_ENCODER_DELAY (79+12-80)
 #define DRED_FRAME_SIZE 160
 #define DRED_DFRAME_SIZE (2 * (DRED_FRAME_SIZE))
@@ -54,6 +54,7 @@ int dred_ec_decode(OpusDRED *dec, const opus_uint8 *bytes, int num_bytes, int mi
    int offset;
    int q0;
    int dQ;
+   int state_qoffset;
 
 
    /* since features are decoded in quadruples, it makes no sense to go with an uneven number of redundancy frames */
@@ -66,7 +67,14 @@ int dred_ec_decode(OpusDRED *dec, const opus_uint8 *bytes, int num_bytes, int mi
    dQ = ec_dec_uint(&ec, 8);
    /*printf("%d %d %d\n", dred_offset, q0, dQ);*/
 
-   dred_decode_state(&ec, dec->state);
+   //dred_decode_state(&ec, dec->state);
+   state_qoffset = q0*(DRED_LATENT_DIM+DRED_STATE_DIM) + DRED_STATE_DIM;
+   dred_decode_latents(
+       &ec,
+       dec->state,
+       quant_scales + state_qoffset,
+       r + state_qoffset,
+       p0 + state_qoffset);
 
    /* decode newest to oldest and store oldest to newest */
    for (i = 0; i < IMIN(DRED_NUM_REDUNDANCY_FRAMES, (min_feature_frames+1)/2); i += 2)
@@ -75,7 +83,7 @@ int dred_ec_decode(OpusDRED *dec, const opus_uint8 *bytes, int num_bytes, int mi
       if (8*num_bytes - ec_tell(&ec) <= 7)
         break;
      q_level = compute_quantizer(q0, dQ, i/2);
-     offset = q_level * DRED_LATENT_DIM;
+     offset = q_level * (DRED_LATENT_DIM+DRED_STATE_DIM);
      dred_decode_latents(
          &ec,
          &dec->latents[(i/2)*DRED_LATENT_DIM],
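On the C side, each parameter table (quant_scales, dead_zone, r, p0) is a flat array holding one block of DRED_LATENT_DIM + DRED_STATE_DIM entries per quantizer level, with the state parameters stored after the latent ones inside a block. A small sketch of the indexing (helper names invented; note that with both dims equal to 80, the patch's + DRED_STATE_DIM lands on the same index as + DRED_LATENT_DIM would):

DRED_LATENT_DIM = 80
DRED_STATE_DIM = 80

def latent_param_offset(q_level):
    # Start of the latent parameters for quantizer level q_level.
    return q_level * (DRED_LATENT_DIM + DRED_STATE_DIM)

def state_param_offset(q0):
    # Start of the initial-state parameters; mirrors the patch's
    # q0*(DRED_LATENT_DIM+DRED_STATE_DIM) + DRED_STATE_DIM.
    return q0 * (DRED_LATENT_DIM + DRED_STATE_DIM) + DRED_STATE_DIM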
@@ -197,7 +197,7 @@ void dred_compute_latents(DREDEnc *enc, const float *pcm, int frame_size, int ex
    /* 15 ms (6*2.5 ms) is the ideal offset for DRED because it corresponds to our vocoder look-ahead. */
    if (enc->dred_offset < 6) {
        enc->dred_offset += 8;
-       OPUS_COPY(enc->initial_state, enc->state_buffer, 24);
+       OPUS_COPY(enc->initial_state, enc->state_buffer, DRED_STATE_DIM);
    } else {
        enc->latent_offset++;
    }
@@ -221,6 +221,7 @@ int dred_encode_silk_frame(const DREDEnc *enc, unsigned char *buf, int max_chunk
    int ec_buffer_fill;
    int q0;
    int dQ;
+   int state_qoffset;
 
    /* entropy coding of state and latents */
    ec_enc_init(&ec_encoder, buf, max_bytes);
@@ -229,15 +230,21 @@ int dred_encode_silk_frame(const DREDEnc *enc, unsigned char *buf, int max_chunk
    ec_enc_uint(&ec_encoder, enc->dred_offset, 32);
    ec_enc_uint(&ec_encoder, q0, 16);
    ec_enc_uint(&ec_encoder, dQ, 8);
-   dred_encode_state(&ec_encoder, enc->initial_state);
 
+   state_qoffset = q0*(DRED_LATENT_DIM+DRED_STATE_DIM) + DRED_STATE_DIM;
+   dred_encode_latents(
+       &ec_encoder,
+       enc->initial_state,
+       quant_scales + state_qoffset,
+       dead_zone + state_qoffset,
+       r + state_qoffset,
+       p0 + state_qoffset);
    for (i = 0; i < IMIN(2*max_chunks, enc->latents_buffer_fill-enc->latent_offset-1); i += 2)
    {
        ec_enc ec_bak;
        ec_bak = ec_encoder;
 
        q_level = compute_quantizer(q0, dQ, i/2);
-       offset = q_level * DRED_LATENT_DIM;
+       offset = q_level * (DRED_LATENT_DIM+DRED_STATE_DIM);
 
        dred_encode_latents(
            &ec_encoder,
@@ -50,8 +50,8 @@ typedef struct {
     int latent_offset;
     float latents_buffer[DRED_MAX_FRAMES * DRED_LATENT_DIM];
     int latents_buffer_fill;
-    float state_buffer[24];
-    float initial_state[24];
+    float state_buffer[DRED_STATE_DIM];
+    float initial_state[DRED_STATE_DIM];
     float resample_mem[RESAMPLING_ORDER + 1];
     LPCNetEncState lpcnet_enc_state;
     RDOVAEEncState rdovae_enc;