diff --git a/autogen.sh b/autogen.sh index efd3ef3d..d6888f81 100755 --- a/autogen.sh +++ b/autogen.sh @@ -9,7 +9,7 @@ set -e srcdir=`dirname $0` test -n "$srcdir" && cd "$srcdir" -dnn/download_model.sh 98b8be0 +dnn/download_model.sh 2386a60 echo "Updating build configuration files, please wait...." diff --git a/dnn/dred_rdovae.c b/dnn/dred_rdovae.c index b4797b5e..748a463a 100644 --- a/dnn/dred_rdovae.c +++ b/dnn/dred_rdovae.c @@ -77,24 +77,3 @@ void DRED_rdovae_decode_qframe(RDOVAEDecState *h, const RDOVAEDec *model, float { dred_rdovae_decode_qframe(h, model, qframe, z); } - - -const opus_uint8 * DRED_rdovae_get_p0_pointer(void) -{ - return &dred_p0_q8[0]; -} - -const opus_uint16 * DRED_rdovae_get_dead_zone_pointer(void) -{ - return &dred_dead_zone_q10[0]; -} - -const opus_uint8 * DRED_rdovae_get_r_pointer(void) -{ - return &dred_r_q8[0]; -} - -const opus_uint16 * DRED_rdovae_get_quant_scales_pointer(void) -{ - return &dred_quant_scales_q8[0]; -} diff --git a/dnn/dred_rdovae_enc.c b/dnn/dred_rdovae_enc.c index 98ffba8c..e159e632 100644 --- a/dnn/dred_rdovae_enc.c +++ b/dnn/dred_rdovae_enc.c @@ -34,6 +34,7 @@ #include "dred_rdovae_enc.h" #include "os_support.h" +#include "dred_rdovae_constants.h" static void conv1_cond_init(float *mem, int len, int dilation, int *init) { @@ -52,6 +53,8 @@ void dred_rdovae_encode_dframe( const float *input /* i: double feature frame (concatenated) */ ) { + float padded_latents[DRED_PADDED_LATENT_DIM]; + float padded_state[DRED_PADDED_STATE_DIM]; float buffer[ENC_DENSE1_OUT_SIZE + ENC_GRU1_OUT_SIZE + ENC_GRU2_OUT_SIZE + ENC_GRU3_OUT_SIZE + ENC_GRU4_OUT_SIZE + ENC_GRU5_OUT_SIZE + ENC_CONV1_OUT_SIZE + ENC_CONV2_OUT_SIZE + ENC_CONV3_OUT_SIZE + ENC_CONV4_OUT_SIZE + ENC_CONV5_OUT_SIZE]; float state_hidden[GDENSE1_OUT_SIZE]; @@ -96,9 +99,11 @@ void dred_rdovae_encode_dframe( compute_generic_conv1d_dilation(&model->enc_conv5, &buffer[output_index], enc_state->conv5_state, buffer, output_index, 2, ACTIVATION_TANH); output_index += ENC_CONV5_OUT_SIZE; - compute_generic_dense(&model->enc_zdense, latents, buffer, ACTIVATION_LINEAR); + compute_generic_dense(&model->enc_zdense, padded_latents, buffer, ACTIVATION_LINEAR); + OPUS_COPY(latents, padded_latents, DRED_LATENT_DIM); /* next, calculate initial state */ compute_generic_dense(&model->gdense1, state_hidden, buffer, ACTIVATION_TANH); - compute_generic_dense(&model->gdense2, initial_state, state_hidden, ACTIVATION_LINEAR); + compute_generic_dense(&model->gdense2, padded_state, state_hidden, ACTIVATION_LINEAR); + OPUS_COPY(initial_state, padded_state, DRED_STATE_DIM); } diff --git a/dnn/torch/rdovae/export_rdovae_weights.py b/dnn/torch/rdovae/export_rdovae_weights.py index 3bcc7712..001999c6 100644 --- a/dnn/torch/rdovae/export_rdovae_weights.py +++ b/dnn/torch/rdovae/export_rdovae_weights.py @@ -49,37 +49,43 @@ from wexchange.torch import dump_torch_weights from wexchange.c_export import CWriter, print_vector -def dump_statistical_model(writer, qembedding): - w = qembedding.weight.detach() - levels, dim = w.shape - N = dim // 6 +def dump_statistical_model(writer, w, name): + levels = w.shape[0] print("printing statistical model") - quant_scales = torch.nn.functional.softplus(w[:, : N]).numpy() - dead_zone = 0.05 * torch.nn.functional.softplus(w[:, N : 2 * N]).numpy() - r = torch.sigmoid(w[:, 5 * N : 6 * N]).numpy() - p0 = torch.sigmoid(w[:, 4 * N : 5 * N]).numpy() + quant_scales = torch.nn.functional.softplus(w[:, 0, :]).numpy() + dead_zone = 0.05 * torch.nn.functional.softplus(w[:, 1, :]).numpy() + r = torch.sigmoid(w[:, 5 , :]).numpy() + p0 = torch.sigmoid(w[:, 4 , :]).numpy() p0 = 1 - r ** (0.5 + 0.5 * p0) quant_scales_q8 = np.round(quant_scales * 2**8).astype(np.uint16) dead_zone_q10 = np.round(dead_zone * 2**10).astype(np.uint16) - r_q15 = np.clip(np.round(r * 2**8), 0, 255).astype(np.uint8) - p0_q15 = np.clip(np.round(p0 * 2**8), 0, 255).astype(np.uint16) + r_q8 = np.clip(np.round(r * 2**8), 0, 255).astype(np.uint8) + p0_q8 = np.clip(np.round(p0 * 2**8), 0, 255).astype(np.uint16) - print_vector(writer.source, quant_scales_q8, 'dred_quant_scales_q8', dtype='opus_uint16', static=False) - print_vector(writer.source, dead_zone_q10, 'dred_dead_zone_q10', dtype='opus_uint16', static=False) - print_vector(writer.source, r_q15, 'dred_r_q8', dtype='opus_uint8', static=False) - print_vector(writer.source, p0_q15, 'dred_p0_q8', dtype='opus_uint8', static=False) + mask = (np.max(r_q8,axis=0) > 0) * (np.min(p0_q8,axis=0) < 255) + quant_scales_q8 = quant_scales_q8[:, mask] + dead_zone_q10 = dead_zone_q10[:, mask] + r_q8 = r_q8[:, mask] + p0_q8 = p0_q8[:, mask] + N = r_q8.shape[-1] + + print_vector(writer.source, quant_scales_q8, f'dred_{name}_quant_scales_q8', dtype='opus_uint16', static=False) + print_vector(writer.source, dead_zone_q10, f'dred_{name}_dead_zone_q10', dtype='opus_uint16', static=False) + print_vector(writer.source, r_q8, f'dred_{name}_r_q8', dtype='opus_uint8', static=False) + print_vector(writer.source, p0_q8, f'dred_{name}_p0_q8', dtype='opus_uint8', static=False) writer.header.write( f""" -extern const opus_uint16 dred_quant_scales_q8[{levels * N}]; -extern const opus_uint16 dred_dead_zone_q10[{levels * N}]; -extern const opus_uint8 dred_r_q8[{levels * N}]; -extern const opus_uint8 dred_p0_q8[{levels * N}]; +extern const opus_uint16 dred_{name}_quant_scales_q8[{levels * N}]; +extern const opus_uint16 dred_{name}_dead_zone_q10[{levels * N}]; +extern const opus_uint8 dred_{name}_r_q8[{levels * N}]; +extern const opus_uint8 dred_{name}_p0_q8[{levels * N}]; """ ) + return N, mask def c_export(args, model): @@ -113,6 +119,41 @@ f""" """ ) + latent_out = model.get_submodule('core_encoder.module.z_dense') + state_out = model.get_submodule('core_encoder.module.state_dense_2') + orig_latent_dim = latent_out.weight.shape[0] + orig_state_dim = state_out.weight.shape[0] + # statistical model + qembedding = model.statistical_model.quant_embedding.weight.detach() + levels = qembedding.shape[0] + qembedding = torch.reshape(qembedding, (levels, 6, -1)) + + latent_dim, latent_mask = dump_statistical_model(stats_writer, qembedding[:, :, :orig_latent_dim], 'latent') + state_dim, state_mask = dump_statistical_model(stats_writer, qembedding[:, :, orig_latent_dim:], 'state') + + padded_latent_dim = (latent_dim+7)//8*8 + latent_pad = padded_latent_dim - latent_dim; + w = latent_out.weight[latent_mask,:] + w = torch.cat([w, torch.zeros(latent_pad, w.shape[1])], dim=0) + b = latent_out.bias[latent_mask] + b = torch.cat([b, torch.zeros(latent_pad)], dim=0) + latent_out.weight = torch.nn.Parameter(w) + latent_out.bias = torch.nn.Parameter(b) + + padded_state_dim = (state_dim+7)//8*8 + state_pad = padded_state_dim - state_dim; + w = state_out.weight[state_mask,:] + w = torch.cat([w, torch.zeros(state_pad, w.shape[1])], dim=0) + b = state_out.bias[state_mask] + b = torch.cat([b, torch.zeros(state_pad)], dim=0) + state_out.weight = torch.nn.Parameter(w) + state_out.bias = torch.nn.Parameter(b) + + latent_in = model.get_submodule('core_decoder.module.dense_1') + state_in = model.get_submodule('core_decoder.module.hidden_init') + latent_in.weight = torch.nn.Parameter(latent_in.weight[:,latent_mask]) + state_in.weight = torch.nn.Parameter(state_in.weight[:,state_mask]) + # encoder encoder_dense_layers = [ ('core_encoder.module.dense_1' , 'enc_dense1', 'TANH', False,), @@ -187,10 +228,6 @@ f""" del dec_writer - # statistical model - qembedding = model.statistical_model.quant_embedding - dump_statistical_model(stats_writer, qembedding) - del stats_writer # constants @@ -198,9 +235,13 @@ f""" f""" #define DRED_NUM_FEATURES {model.feature_dim} -#define DRED_LATENT_DIM {model.latent_dim} +#define DRED_LATENT_DIM {latent_dim} -#define DRED_STATE_DIME {model.state_dim} +#define DRED_STATE_DIM {state_dim} + +#define DRED_PADDED_LATENT_DIM {padded_latent_dim} + +#define DRED_PADDED_STATE_DIM {padded_state_dim} #define DRED_NUM_QUANTIZATION_LEVELS {model.quant_levels} diff --git a/dnn/torch/weight-exchange/wexchange/c_export/common.py b/dnn/torch/weight-exchange/wexchange/c_export/common.py index 98a45c3f..5dd9f138 100644 --- a/dnn/torch/weight-exchange/wexchange/c_export/common.py +++ b/dnn/torch/weight-exchange/wexchange/c_export/common.py @@ -124,6 +124,7 @@ def extract_diagonal(A): return diag, B def quantize_weight(weight, scale): + scale = scale + 1e-30 Aq = np.round(weight / scale).astype('int') if Aq.max() > 127 or Aq.min() <= -128: raise ValueError("value out of bounds in quantize_weight") @@ -227,7 +228,7 @@ def print_linear_layer(writer : CWriter, nb_inputs, nb_outputs = weight.shape - if scale is None: + if scale is None and quantize: scale = compute_scaling(weight) @@ -359,4 +360,4 @@ def print_gru_layer(writer : CWriter, writer.header.write(f"\n#define {name.upper()}_OUT_SIZE {N}\n") writer.header.write(f"\n#define {name.upper()}_STATE_SIZE {N}\n") - return N \ No newline at end of file + return N diff --git a/silk/dred_config.h b/silk/dred_config.h index 5e3e74a3..207908fc 100644 --- a/silk/dred_config.h +++ b/silk/dred_config.h @@ -39,9 +39,6 @@ #define DRED_MIN_BYTES 16 /* these are inpart duplicates to the values defined in dred_rdovae_constants.h */ -#define DRED_NUM_FEATURES 20 -#define DRED_LATENT_DIM 80 -#define DRED_STATE_DIM 80 #define DRED_SILK_ENCODER_DELAY (79+12-80) #define DRED_FRAME_SIZE 160 #define DRED_DFRAME_SIZE (2 * (DRED_FRAME_SIZE)) diff --git a/silk/dred_decoder.c b/silk/dred_decoder.c index 68ba8559..c1489f3c 100644 --- a/silk/dred_decoder.c +++ b/silk/dred_decoder.c @@ -36,6 +36,8 @@ #include "dred_coding.h" #include "celt/entdec.h" #include "celt/laplace.h" +#include "dred_rdovae_stats_data.h" +#include "dred_rdovae_constants.h" /* From http://graphics.stanford.edu/~seander/bithacks.html#FixedSignExtend */ static int sign_extend(int x, int b) { @@ -55,9 +57,6 @@ static void dred_decode_latents(ec_dec *dec, float *x, const opus_uint16 *scale, int dred_ec_decode(OpusDRED *dec, const opus_uint8 *bytes, int num_bytes, int min_feature_frames) { - const opus_uint8 *p0 = DRED_rdovae_get_p0_pointer(); - const opus_uint16 *quant_scales = DRED_rdovae_get_quant_scales_pointer(); - const opus_uint8 *r = DRED_rdovae_get_r_pointer(); ec_dec ec; int q_level; int i; @@ -78,13 +77,13 @@ int dred_ec_decode(OpusDRED *dec, const opus_uint8 *bytes, int num_bytes, int mi /*printf("%d %d %d\n", dred_offset, q0, dQ);*/ //dred_decode_state(&ec, dec->state); - state_qoffset = q0*(DRED_LATENT_DIM+DRED_STATE_DIM) + DRED_LATENT_DIM; + state_qoffset = q0*DRED_STATE_DIM; dred_decode_latents( &ec, dec->state, - quant_scales + state_qoffset, - r + state_qoffset, - p0 + state_qoffset, + dred_state_quant_scales_q8 + state_qoffset, + dred_state_r_q8 + state_qoffset, + dred_state_p0_q8 + state_qoffset, DRED_STATE_DIM); /* decode newest to oldest and store oldest to newest */ @@ -94,13 +93,13 @@ int dred_ec_decode(OpusDRED *dec, const opus_uint8 *bytes, int num_bytes, int mi if (8*num_bytes - ec_tell(&ec) <= 7) break; q_level = compute_quantizer(q0, dQ, i/2); - offset = q_level * (DRED_LATENT_DIM+DRED_STATE_DIM); + offset = q_level*DRED_LATENT_DIM; dred_decode_latents( &ec, &dec->latents[(i/2)*DRED_LATENT_DIM], - quant_scales + offset, - r + offset, - p0 + offset, + dred_latent_quant_scales_q8 + offset, + dred_latent_r_q8 + offset, + dred_latent_p0_q8 + offset, DRED_LATENT_DIM ); diff --git a/silk/dred_decoder.h b/silk/dred_decoder.h index c7355ed1..f8d050ff 100644 --- a/silk/dred_decoder.h +++ b/silk/dred_decoder.h @@ -32,6 +32,7 @@ #include "dred_config.h" #include "dred_rdovae.h" #include "entcode.h" +#include "dred_rdovae_constants.h" struct OpusDRED { float fec_features[2*DRED_NUM_REDUNDANCY_FRAMES*DRED_NUM_FEATURES]; diff --git a/silk/dred_encoder.c b/silk/dred_encoder.c index 3f842af0..e4959b30 100644 --- a/silk/dred_encoder.c +++ b/silk/dred_encoder.c @@ -44,6 +44,7 @@ #include "float_cast.h" #include "os_support.h" #include "celt/laplace.h" +#include "dred_rdovae_stats_data.h" int dred_encoder_load_model(DREDEnc* enc, const unsigned char *data, int len) @@ -244,10 +245,6 @@ static void dred_encode_latents(ec_enc *enc, const float *x, const opus_uint16 * } int dred_encode_silk_frame(const DREDEnc *enc, unsigned char *buf, int max_chunks, int max_bytes) { - const opus_uint16 *dead_zone = DRED_rdovae_get_dead_zone_pointer(); - const opus_uint8 *p0 = DRED_rdovae_get_p0_pointer(); - const opus_uint16 *quant_scales = DRED_rdovae_get_quant_scales_pointer(); - const opus_uint8 *r = DRED_rdovae_get_r_pointer(); ec_enc ec_encoder; int q_level; @@ -265,14 +262,14 @@ int dred_encode_silk_frame(const DREDEnc *enc, unsigned char *buf, int max_chunk ec_enc_uint(&ec_encoder, enc->dred_offset, 32); ec_enc_uint(&ec_encoder, q0, 16); ec_enc_uint(&ec_encoder, dQ, 8); - state_qoffset = q0*(DRED_LATENT_DIM+DRED_STATE_DIM) + DRED_LATENT_DIM; + state_qoffset = q0*DRED_STATE_DIM; dred_encode_latents( &ec_encoder, enc->initial_state, - quant_scales + state_qoffset, - dead_zone + state_qoffset, - r + state_qoffset, - p0 + state_qoffset, + dred_state_quant_scales_q8 + state_qoffset, + dred_state_dead_zone_q10 + state_qoffset, + dred_state_r_q8 + state_qoffset, + dred_state_p0_q8 + state_qoffset, DRED_STATE_DIM); if (ec_tell(&ec_encoder) > 8*max_bytes) { return 0; @@ -283,15 +280,15 @@ int dred_encode_silk_frame(const DREDEnc *enc, unsigned char *buf, int max_chunk ec_bak = ec_encoder; q_level = compute_quantizer(q0, dQ, i/2); - offset = q_level * (DRED_LATENT_DIM+DRED_STATE_DIM); + offset = q_level * DRED_LATENT_DIM; dred_encode_latents( &ec_encoder, enc->latents_buffer + (i+enc->latent_offset) * DRED_LATENT_DIM, - quant_scales + offset, - dead_zone + offset, - r + offset, - p0 + offset, + dred_latent_quant_scales_q8 + offset, + dred_latent_dead_zone_q10 + offset, + dred_latent_r_q8 + offset, + dred_latent_p0_q8 + offset, DRED_LATENT_DIM ); if (ec_tell(&ec_encoder) > 8*max_bytes) {