DRED: quantize scale and dead zone to 8 bits

This commit is contained in:
Jean-Marc Valin 2023-11-07 17:46:38 -05:00
parent 4e104555e9
commit 222662dac8
No known key found for this signature in database
GPG key ID: 531A52533318F00A
4 changed files with 23 additions and 17 deletions

View file

@ -59,33 +59,35 @@ def dump_statistical_model(writer, w, name):
p0 = torch.sigmoid(w[:, 4 , :]).numpy()
p0 = 1 - r ** (0.5 + 0.5 * p0)
scales_norm = 255./256./(1e-15+np.max(quant_scales,axis=0))
quant_scales = quant_scales*scales_norm
quant_scales_q8 = np.round(quant_scales * 2**8).astype(np.uint16)
dead_zone_q10 = np.round(dead_zone * 2**10).astype(np.uint16)
dead_zone_q8 = np.clip(np.round(dead_zone * 2**8), 0, 255).astype(np.uint16)
r_q8 = np.clip(np.round(r * 2**8), 0, 255).astype(np.uint8)
p0_q8 = np.clip(np.round(p0 * 2**8), 0, 255).astype(np.uint16)
mask = (np.max(r_q8,axis=0) > 0) * (np.min(p0_q8,axis=0) < 255)
quant_scales_q8 = quant_scales_q8[:, mask]
dead_zone_q10 = dead_zone_q10[:, mask]
dead_zone_q8 = dead_zone_q8[:, mask]
r_q8 = r_q8[:, mask]
p0_q8 = p0_q8[:, mask]
N = r_q8.shape[-1]
print_vector(writer.source, quant_scales_q8, f'dred_{name}_quant_scales_q8', dtype='opus_uint16', static=False)
print_vector(writer.source, dead_zone_q10, f'dred_{name}_dead_zone_q10', dtype='opus_uint16', static=False)
print_vector(writer.source, quant_scales_q8, f'dred_{name}_quant_scales_q8', dtype='opus_uint8', static=False)
print_vector(writer.source, dead_zone_q8, f'dred_{name}_dead_zone_q8', dtype='opus_uint8', static=False)
print_vector(writer.source, r_q8, f'dred_{name}_r_q8', dtype='opus_uint8', static=False)
print_vector(writer.source, p0_q8, f'dred_{name}_p0_q8', dtype='opus_uint8', static=False)
writer.header.write(
f"""
extern const opus_uint16 dred_{name}_quant_scales_q8[{levels * N}];
extern const opus_uint16 dred_{name}_dead_zone_q10[{levels * N}];
extern const opus_uint8 dred_{name}_quant_scales_q8[{levels * N}];
extern const opus_uint8 dred_{name}_dead_zone_q8[{levels * N}];
extern const opus_uint8 dred_{name}_r_q8[{levels * N}];
extern const opus_uint8 dred_{name}_p0_q8[{levels * N}];
"""
)
return N, mask
return N, mask, torch.tensor(scales_norm[mask])
def c_export(args, model):
@ -128,14 +130,16 @@ f"""
levels = qembedding.shape[0]
qembedding = torch.reshape(qembedding, (levels, 6, -1))
latent_dim, latent_mask = dump_statistical_model(stats_writer, qembedding[:, :, :orig_latent_dim], 'latent')
state_dim, state_mask = dump_statistical_model(stats_writer, qembedding[:, :, orig_latent_dim:], 'state')
latent_dim, latent_mask, latent_scale = dump_statistical_model(stats_writer, qembedding[:, :, :orig_latent_dim], 'latent')
state_dim, state_mask, state_scale = dump_statistical_model(stats_writer, qembedding[:, :, orig_latent_dim:], 'state')
padded_latent_dim = (latent_dim+7)//8*8
latent_pad = padded_latent_dim - latent_dim;
w = latent_out.weight[latent_mask,:]
w = w/latent_scale[:, None]
w = torch.cat([w, torch.zeros(latent_pad, w.shape[1])], dim=0)
b = latent_out.bias[latent_mask]
b = b/latent_scale
b = torch.cat([b, torch.zeros(latent_pad)], dim=0)
latent_out.weight = torch.nn.Parameter(w)
latent_out.bias = torch.nn.Parameter(b)
@ -143,16 +147,18 @@ f"""
padded_state_dim = (state_dim+7)//8*8
state_pad = padded_state_dim - state_dim;
w = state_out.weight[state_mask,:]
w = w/state_scale[:, None]
w = torch.cat([w, torch.zeros(state_pad, w.shape[1])], dim=0)
b = state_out.bias[state_mask]
b = b/state_scale
b = torch.cat([b, torch.zeros(state_pad)], dim=0)
state_out.weight = torch.nn.Parameter(w)
state_out.bias = torch.nn.Parameter(b)
latent_in = model.get_submodule('core_decoder.module.dense_1')
state_in = model.get_submodule('core_decoder.module.hidden_init')
latent_in.weight = torch.nn.Parameter(latent_in.weight[:,latent_mask])
state_in.weight = torch.nn.Parameter(state_in.weight[:,state_mask])
latent_in.weight = torch.nn.Parameter(latent_in.weight[:,latent_mask]*latent_scale)
state_in.weight = torch.nn.Parameter(state_in.weight[:,state_mask]*state_scale)
# encoder
encoder_dense_layers = [