DRED: quantize scale and dead zone to 8 bits

parent 4e104555e9
commit 222662dac8

4 changed files with 23 additions and 17 deletions
@@ -9,7 +9,7 @@ set -e
 srcdir=`dirname $0`
 test -n "$srcdir" && cd "$srcdir"

-dnn/download_model.sh 2386a60
+dnn/download_model.sh b6095cf

 echo "Updating build configuration files, please wait...."

@@ -59,33 +59,35 @@ def dump_statistical_model(writer, w, name):
     p0 = torch.sigmoid(w[:, 4 , :]).numpy()
     p0 = 1 - r ** (0.5 + 0.5 * p0)

+    scales_norm = 255./256./(1e-15+np.max(quant_scales,axis=0))
+    quant_scales = quant_scales*scales_norm
     quant_scales_q8 = np.round(quant_scales * 2**8).astype(np.uint16)
-    dead_zone_q10 = np.round(dead_zone * 2**10).astype(np.uint16)
+    dead_zone_q8 = np.clip(np.round(dead_zone * 2**8), 0, 255).astype(np.uint16)
     r_q8 = np.clip(np.round(r * 2**8), 0, 255).astype(np.uint8)
     p0_q8 = np.clip(np.round(p0 * 2**8), 0, 255).astype(np.uint16)

     mask = (np.max(r_q8,axis=0) > 0) * (np.min(p0_q8,axis=0) < 255)
     quant_scales_q8 = quant_scales_q8[:, mask]
-    dead_zone_q10 = dead_zone_q10[:, mask]
+    dead_zone_q8 = dead_zone_q8[:, mask]
     r_q8 = r_q8[:, mask]
     p0_q8 = p0_q8[:, mask]
     N = r_q8.shape[-1]

-    print_vector(writer.source, quant_scales_q8, f'dred_{name}_quant_scales_q8', dtype='opus_uint16', static=False)
-    print_vector(writer.source, dead_zone_q10, f'dred_{name}_dead_zone_q10', dtype='opus_uint16', static=False)
+    print_vector(writer.source, quant_scales_q8, f'dred_{name}_quant_scales_q8', dtype='opus_uint8', static=False)
+    print_vector(writer.source, dead_zone_q8, f'dred_{name}_dead_zone_q8', dtype='opus_uint8', static=False)
     print_vector(writer.source, r_q8, f'dred_{name}_r_q8', dtype='opus_uint8', static=False)
     print_vector(writer.source, p0_q8, f'dred_{name}_p0_q8', dtype='opus_uint8', static=False)

     writer.header.write(
 f"""
-extern const opus_uint16 dred_{name}_quant_scales_q8[{levels * N}];
-extern const opus_uint16 dred_{name}_dead_zone_q10[{levels * N}];
+extern const opus_uint8 dred_{name}_quant_scales_q8[{levels * N}];
+extern const opus_uint8 dred_{name}_dead_zone_q8[{levels * N}];
 extern const opus_uint8 dred_{name}_r_q8[{levels * N}];
 extern const opus_uint8 dred_{name}_p0_q8[{levels * N}];

 """
     )
-    return N, mask
+    return N, mask, torch.tensor(scales_norm[mask])


 def c_export(args, model):
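The two added lines are what make the uint8 export safe: each latent dimension's scales are divided by slightly more than their maximum across quantization levels, so quant_scales * 256 is guaranteed to land in 0..255. A minimal sketch of that bound, using synthetic numbers rather than the exported model:

    import numpy as np

    # synthetic per-(level, dim) scales with an arbitrary positive range
    quant_scales = np.abs(np.random.randn(16, 80)) * 3.0

    # normalize so the largest scale in each dimension maps to 255/256;
    # the 1e-15 guards against an all-zero column
    scales_norm = 255. / 256. / (1e-15 + np.max(quant_scales, axis=0))
    quant_scales = quant_scales * scales_norm

    q8 = np.round(quant_scales * 2**8).astype(np.uint16)
    assert q8.max() <= 255  # every entry now fits an opus_uint8

The normalization is not free, which is why dump_statistical_model now also returns scales_norm[mask]: the caller folds the correction back into the network weights in the hunks below.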
@@ -128,14 +130,16 @@ f"""
     levels = qembedding.shape[0]
     qembedding = torch.reshape(qembedding, (levels, 6, -1))

-    latent_dim, latent_mask = dump_statistical_model(stats_writer, qembedding[:, :, :orig_latent_dim], 'latent')
-    state_dim, state_mask = dump_statistical_model(stats_writer, qembedding[:, :, orig_latent_dim:], 'state')
+    latent_dim, latent_mask, latent_scale = dump_statistical_model(stats_writer, qembedding[:, :, :orig_latent_dim], 'latent')
+    state_dim, state_mask, state_scale = dump_statistical_model(stats_writer, qembedding[:, :, orig_latent_dim:], 'state')

     padded_latent_dim = (latent_dim+7)//8*8
     latent_pad = padded_latent_dim - latent_dim;
     w = latent_out.weight[latent_mask,:]
+    w = w/latent_scale[:, None]
     w = torch.cat([w, torch.zeros(latent_pad, w.shape[1])], dim=0)
     b = latent_out.bias[latent_mask]
+    b = b/latent_scale
     b = torch.cat([b, torch.zeros(latent_pad)], dim=0)
     latent_out.weight = torch.nn.Parameter(w)
     latent_out.bias = torch.nn.Parameter(b)
@@ -143,16 +147,18 @@ f"""
     padded_state_dim = (state_dim+7)//8*8
     state_pad = padded_state_dim - state_dim;
     w = state_out.weight[state_mask,:]
+    w = w/state_scale[:, None]
     w = torch.cat([w, torch.zeros(state_pad, w.shape[1])], dim=0)
     b = state_out.bias[state_mask]
+    b = b/state_scale
     b = torch.cat([b, torch.zeros(state_pad)], dim=0)
     state_out.weight = torch.nn.Parameter(w)
     state_out.bias = torch.nn.Parameter(b)

     latent_in = model.get_submodule('core_decoder.module.dense_1')
     state_in = model.get_submodule('core_decoder.module.hidden_init')
-    latent_in.weight = torch.nn.Parameter(latent_in.weight[:,latent_mask])
-    state_in.weight = torch.nn.Parameter(state_in.weight[:,state_mask])
+    latent_in.weight = torch.nn.Parameter(latent_in.weight[:,latent_mask]*latent_scale)
+    state_in.weight = torch.nn.Parameter(state_in.weight[:,state_mask]*state_scale)

     # encoder
     encoder_dense_layers = [
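These two hunks compensate for the normalization inside the model itself: the layers that produce latents and state are divided by the per-dimension scale, and the decoder-side input layers are multiplied by it, so the end-to-end mapping is unchanged (up to the quantization that happens in between). A hedged sketch of that identity, with stand-in layers and shapes rather than the real RDOVAE modules:

    import torch

    torch.manual_seed(0)
    out = torch.nn.Linear(8, 4)   # stand-in for latent_out
    inp = torch.nn.Linear(4, 8)   # stand-in for latent_in
    scale = torch.rand(4) + 0.5   # stand-in for latent_scale

    x = torch.randn(3, 8)
    y_ref = inp(out(x))           # original pipeline, no scaling

    # fold the scale out of the producer and back into the consumer
    out.weight = torch.nn.Parameter(out.weight / scale[:, None])
    out.bias = torch.nn.Parameter(out.bias / scale)
    inp.weight = torch.nn.Parameter(inp.weight * scale)

    assert torch.allclose(y_ref, inp(out(x)), atol=1e-5)  # scaling cancels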
@@ -45,7 +45,7 @@ static int sign_extend(int x, int b) {
    return (x ^ m) - m;
 }

-static void dred_decode_latents(ec_dec *dec, float *x, const opus_uint16 *scale, const opus_uint8 *r, const opus_uint8 *p0, int dim) {
+static void dred_decode_latents(ec_dec *dec, float *x, const opus_uint8 *scale, const opus_uint8 *r, const opus_uint8 *p0, int dim) {
    int i;
    for (i=0;i<dim;i++) {
       int q;
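With this change all four tables consumed by the decoder share one convention: a uint8 entry v represents v/256.0, so a Q8 scale is exact to within 1/512. A tiny sanity check of that bound, with illustrative values rather than real table contents:

    import numpy as np

    scale = np.float32(0.7315)                  # a float scale in (0, 255/256]
    scale_q8 = np.uint8(np.round(scale * 256))  # stored opus_uint8 entry
    recovered = scale_q8 * (1.0 / 256.0)        # what the C side computes
    assert abs(recovered - scale) <= 0.5 / 256  # Q8 rounding error bound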
@@ -223,7 +223,7 @@ void dred_compute_latents(DREDEnc *enc, const float *pcm, int frame_size, int ex
    }
 }

-static void dred_encode_latents(ec_enc *enc, const float *x, const opus_uint16 *scale, const opus_uint16 *dzone, const opus_uint8 *r, const opus_uint8 *p0, int dim) {
+static void dred_encode_latents(ec_enc *enc, const float *x, const opus_uint8 *scale, const opus_uint8 *dzone, const opus_uint8 *r, const opus_uint8 *p0, int dim) {
    int i;
    int q[IMAX(DRED_LATENT_DIM,DRED_STATE_DIM)];
    float xq[IMAX(DRED_LATENT_DIM,DRED_STATE_DIM)];
@@ -233,7 +233,7 @@ static void dred_encode_latents(ec_enc *enc, const float *x, const opus_uint16 *
    /* This is split into multiple loops (with temporary arrays) so that the compiler
       can vectorize all of it, and so we can call the vector tanh(). */
    for (i=0;i<dim;i++) {
-      delta[i] = dzone[i]*(1.f/1024.f);
+      delta[i] = dzone[i]*(1.f/256.f);
       xq[i] = x[i]*scale[i]*(1.f/256.f);
       deadzone[i] = xq[i]/(delta[i]+eps);
    }
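delta moves from Q10 to Q8 to match the new table format; the dead zone itself is applied a few lines below this hunk, where (per the vector-tanh comment above) the scaled value is shrunk toward zero before rounding. A hedged Python sketch of such a soft dead-zone quantizer; only the three loop lines come from the hunk, while the tanh shrinkage step and the eps value are assumptions about the surrounding code:

    import numpy as np

    def encode_latents_sketch(x, scale_q8, dzone_q8, eps=0.1):
        # Q8 tables: a uint8 entry v represents v/256.0
        delta = dzone_q8 * (1.0 / 256.0)    # was dzone/1024 before this commit
        xq = x * scale_q8 * (1.0 / 256.0)   # scale into quantizer units
        # soft dead zone: values within ~delta of zero get pulled toward zero
        xq = xq - delta * np.tanh(xq / (delta + eps))
        return np.floor(0.5 + xq).astype(int)  # round to nearest symbol

    x = np.array([0.02, 0.4, -1.3, 3.7], dtype=np.float32)
    scale_q8 = np.array([200, 200, 180, 150], dtype=np.uint8)
    dzone_q8 = np.array([64, 64, 32, 0], dtype=np.uint8)
    print(encode_latents_sketch(x, scale_q8, dzone_q8))  # [0 0 -1 2]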
@@ -272,7 +272,7 @@ int dred_encode_silk_frame(const DREDEnc *enc, unsigned char *buf, int max_chunk
          &ec_encoder,
          enc->initial_state,
          dred_state_quant_scales_q8 + state_qoffset,
-         dred_state_dead_zone_q10 + state_qoffset,
+         dred_state_dead_zone_q8 + state_qoffset,
          dred_state_r_q8 + state_qoffset,
          dred_state_p0_q8 + state_qoffset,
          DRED_STATE_DIM);
@@ -291,7 +291,7 @@ int dred_encode_silk_frame(const DREDEnc *enc, unsigned char *buf, int max_chunk
          &ec_encoder,
          enc->latents_buffer + (i+enc->latent_offset) * DRED_LATENT_DIM,
          dred_latent_quant_scales_q8 + offset,
-         dred_latent_dead_zone_q10 + offset,
+         dred_latent_dead_zone_q8 + offset,
          dred_latent_r_q8 + offset,
          dred_latent_p0_q8 + offset,
          DRED_LATENT_DIM