Split stats in two and remove useless dimensions
parent 2386a60ec6
commit 0ab0640d4a

9 changed files with 98 additions and 78 deletions
@@ -49,37 +49,43 @@ from wexchange.torch import dump_torch_weights
 from wexchange.c_export import CWriter, print_vector


-def dump_statistical_model(writer, qembedding):
-    w = qembedding.weight.detach()
-    levels, dim = w.shape
-    N = dim // 6
+def dump_statistical_model(writer, w, name):
+    levels = w.shape[0]

     print("printing statistical model")
-    quant_scales = torch.nn.functional.softplus(w[:, : N]).numpy()
-    dead_zone = 0.05 * torch.nn.functional.softplus(w[:, N : 2 * N]).numpy()
-    r = torch.sigmoid(w[:, 5 * N : 6 * N]).numpy()
-    p0 = torch.sigmoid(w[:, 4 * N : 5 * N]).numpy()
+    quant_scales = torch.nn.functional.softplus(w[:, 0, :]).numpy()
+    dead_zone = 0.05 * torch.nn.functional.softplus(w[:, 1, :]).numpy()
+    r = torch.sigmoid(w[:, 5, :]).numpy()
+    p0 = torch.sigmoid(w[:, 4, :]).numpy()
     p0 = 1 - r ** (0.5 + 0.5 * p0)

     quant_scales_q8 = np.round(quant_scales * 2**8).astype(np.uint16)
     dead_zone_q10 = np.round(dead_zone * 2**10).astype(np.uint16)
-    r_q15 = np.clip(np.round(r * 2**8), 0, 255).astype(np.uint8)
-    p0_q15 = np.clip(np.round(p0 * 2**8), 0, 255).astype(np.uint16)
+    r_q8 = np.clip(np.round(r * 2**8), 0, 255).astype(np.uint8)
+    p0_q8 = np.clip(np.round(p0 * 2**8), 0, 255).astype(np.uint16)

-    print_vector(writer.source, quant_scales_q8, 'dred_quant_scales_q8', dtype='opus_uint16', static=False)
-    print_vector(writer.source, dead_zone_q10, 'dred_dead_zone_q10', dtype='opus_uint16', static=False)
-    print_vector(writer.source, r_q15, 'dred_r_q8', dtype='opus_uint8', static=False)
-    print_vector(writer.source, p0_q15, 'dred_p0_q8', dtype='opus_uint8', static=False)
+    mask = (np.max(r_q8,axis=0) > 0) * (np.min(p0_q8,axis=0) < 255)
+    quant_scales_q8 = quant_scales_q8[:, mask]
+    dead_zone_q10 = dead_zone_q10[:, mask]
+    r_q8 = r_q8[:, mask]
+    p0_q8 = p0_q8[:, mask]
+    N = r_q8.shape[-1]
+
+    print_vector(writer.source, quant_scales_q8, f'dred_{name}_quant_scales_q8', dtype='opus_uint16', static=False)
+    print_vector(writer.source, dead_zone_q10, f'dred_{name}_dead_zone_q10', dtype='opus_uint16', static=False)
+    print_vector(writer.source, r_q8, f'dred_{name}_r_q8', dtype='opus_uint8', static=False)
+    print_vector(writer.source, p0_q8, f'dred_{name}_p0_q8', dtype='opus_uint8', static=False)

     writer.header.write(
 f"""
-extern const opus_uint16 dred_quant_scales_q8[{levels * N}];
-extern const opus_uint16 dred_dead_zone_q10[{levels * N}];
-extern const opus_uint8 dred_r_q8[{levels * N}];
-extern const opus_uint8 dred_p0_q8[{levels * N}];
+extern const opus_uint16 dred_{name}_quant_scales_q8[{levels * N}];
+extern const opus_uint16 dred_{name}_dead_zone_q10[{levels * N}];
+extern const opus_uint8 dred_{name}_r_q8[{levels * N}];
+extern const opus_uint8 dred_{name}_p0_q8[{levels * N}];

 """
     )
+    return N, mask


 def c_export(args, model):
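Note (not part of the commit): the new code keeps only the dimensions the trained model actually uses before printing the tables. A minimal NumPy sketch of that pruning rule, with made-up toy shapes and random data:

    import numpy as np

    # Toy stand-ins for the exported tables: 16 quantization levels, 8 dimensions.
    levels, dim = 16, 8
    rng = np.random.default_rng(0)
    r_q8 = rng.integers(0, 4, size=(levels, dim)).astype(np.uint8)       # per-level rate parameter
    p0_q8 = rng.integers(200, 256, size=(levels, dim)).astype(np.uint8)  # per-level zero-probability

    # A dimension is kept only if some level gives it a nonzero rate and a
    # non-saturated zero-probability; everything else is dead weight.
    mask = (np.max(r_q8, axis=0) > 0) & (np.min(p0_q8, axis=0) < 255)
    r_q8 = r_q8[:, mask]
    print(mask.sum(), "of", dim, "dimensions kept; pruned shape:", r_q8.shape)

The boolean mask is also returned so that c_export() can prune the matching rows and columns of the encoder/decoder layers, as the next hunk shows.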
@@ -113,6 +119,41 @@ f"""
 """
     )

+    latent_out = model.get_submodule('core_encoder.module.z_dense')
+    state_out = model.get_submodule('core_encoder.module.state_dense_2')
+    orig_latent_dim = latent_out.weight.shape[0]
+    orig_state_dim = state_out.weight.shape[0]
+    # statistical model
+    qembedding = model.statistical_model.quant_embedding.weight.detach()
+    levels = qembedding.shape[0]
+    qembedding = torch.reshape(qembedding, (levels, 6, -1))
+
+    latent_dim, latent_mask = dump_statistical_model(stats_writer, qembedding[:, :, :orig_latent_dim], 'latent')
+    state_dim, state_mask = dump_statistical_model(stats_writer, qembedding[:, :, orig_latent_dim:], 'state')
+
+    padded_latent_dim = (latent_dim+7)//8*8
+    latent_pad = padded_latent_dim - latent_dim;
+    w = latent_out.weight[latent_mask,:]
+    w = torch.cat([w, torch.zeros(latent_pad, w.shape[1])], dim=0)
+    b = latent_out.bias[latent_mask]
+    b = torch.cat([b, torch.zeros(latent_pad)], dim=0)
+    latent_out.weight = torch.nn.Parameter(w)
+    latent_out.bias = torch.nn.Parameter(b)
+
+    padded_state_dim = (state_dim+7)//8*8
+    state_pad = padded_state_dim - state_dim;
+    w = state_out.weight[state_mask,:]
+    w = torch.cat([w, torch.zeros(state_pad, w.shape[1])], dim=0)
+    b = state_out.bias[state_mask]
+    b = torch.cat([b, torch.zeros(state_pad)], dim=0)
+    state_out.weight = torch.nn.Parameter(w)
+    state_out.bias = torch.nn.Parameter(b)
+
+    latent_in = model.get_submodule('core_decoder.module.dense_1')
+    state_in = model.get_submodule('core_decoder.module.hidden_init')
+    latent_in.weight = torch.nn.Parameter(latent_in.weight[:,latent_mask])
+    state_in.weight = torch.nn.Parameter(state_in.weight[:,state_mask])
+
     # encoder
     encoder_dense_layers = [
         ('core_encoder.module.dense_1' , 'enc_dense1', 'TANH', False,),
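For context, the rounding above presumably keeps the pruned layer sizes SIMD-friendly: row counts are rounded up to the next multiple of 8 and the extra rows are zero-filled. A small PyTorch sketch of that step (the helper name is made up, not from the repo):

    import torch

    def pad_rows_to_multiple_of_8(w: torch.Tensor) -> torch.Tensor:
        padded_rows = (w.shape[0] + 7) // 8 * 8   # round row count up to a multiple of 8
        pad = padded_rows - w.shape[0]
        return torch.cat([w, torch.zeros(pad, w.shape[1])], dim=0)

    w = torch.randn(13, 24)
    print(pad_rows_to_multiple_of_8(w).shape)     # torch.Size([16, 24])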
@@ -187,10 +228,6 @@ f"""

     del dec_writer

-    # statistical model
-    qembedding = model.statistical_model.quant_embedding
-    dump_statistical_model(stats_writer, qembedding)
-
     del stats_writer

     # constants
@@ -198,9 +235,13 @@ f"""
 f"""
 #define DRED_NUM_FEATURES {model.feature_dim}

-#define DRED_LATENT_DIM {model.latent_dim}
+#define DRED_LATENT_DIM {latent_dim}

-#define DRED_STATE_DIME {model.state_dim}
+#define DRED_STATE_DIM {state_dim}

+#define DRED_PADDED_LATENT_DIM {padded_latent_dim}
+
+#define DRED_PADDED_STATE_DIM {padded_state_dim}
+
 #define DRED_NUM_QUANTIZATION_LEVELS {model.quant_levels}

@@ -124,6 +124,7 @@ def extract_diagonal(A):
     return diag, B

 def quantize_weight(weight, scale):
+    scale = scale + 1e-30
     Aq = np.round(weight / scale).astype('int')
     if Aq.max() > 127 or Aq.min() <= -128:
         raise ValueError("value out of bounds in quantize_weight")
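The added 1e-30 presumably guards against all-zero scale entries (for example a fully pruned column), which would otherwise make the division blow up. A tiny illustrative example with made-up values:

    import numpy as np

    scale = np.array([0.5, 0.0])                 # second column has a zero scale
    weight = np.array([[1.0, 0.0],
                       [2.0, 0.0]])
    Aq = np.round(weight / (scale + 1e-30)).astype('int')
    print(Aq)                                    # [[2 0] [4 0]] instead of a 0/0 -> NaN warning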
@@ -227,7 +228,7 @@ def print_linear_layer(writer : CWriter,

     nb_inputs, nb_outputs = weight.shape

-    if scale is None:
+    if scale is None and quantize:
         scale = compute_scaling(weight)


@@ -359,4 +360,4 @@ def print_gru_layer(writer : CWriter,
     writer.header.write(f"\n#define {name.upper()}_OUT_SIZE {N}\n")
     writer.header.write(f"\n#define {name.upper()}_STATE_SIZE {N}\n")

-    return N
+    return N