Using sparse GRUs in DRED decoder
Saves ~270 kB of weights in the decoder
This commit is contained in:
parent
58923f61c2
commit
b0620c0bf9
7 changed files with 105 additions and 19 deletions
|
@ -9,7 +9,7 @@ set -e
|
|||
srcdir=`dirname $0`
|
||||
test -n "$srcdir" && cd "$srcdir"
|
||||
|
||||
dnn/download_model.sh b6095cf
|
||||
dnn/download_model.sh 58923f6
|
||||
|
||||
echo "Updating build configuration files, please wait...."
|
||||
|
||||
|
|
|
@ -98,35 +98,35 @@ void dred_rdovae_decode_qframe(
|
|||
output_index += DEC_DENSE1_OUT_SIZE;
|
||||
|
||||
compute_generic_gru(&model->dec_gru1_input, &model->dec_gru1_recurrent, dec_state->gru1_state, buffer);
|
||||
OPUS_COPY(&buffer[output_index], dec_state->gru1_state, DEC_GRU1_OUT_SIZE);
|
||||
compute_glu(&model->dec_glu1, &buffer[output_index], dec_state->gru1_state);
|
||||
output_index += DEC_GRU1_OUT_SIZE;
|
||||
conv1_cond_init(dec_state->conv1_state, output_index, 1, &dec_state->initialized);
|
||||
compute_generic_conv1d(&model->dec_conv1, &buffer[output_index], dec_state->conv1_state, buffer, output_index, ACTIVATION_TANH);
|
||||
output_index += DEC_CONV1_OUT_SIZE;
|
||||
|
||||
compute_generic_gru(&model->dec_gru2_input, &model->dec_gru2_recurrent, dec_state->gru2_state, buffer);
|
||||
OPUS_COPY(&buffer[output_index], dec_state->gru2_state, DEC_GRU2_OUT_SIZE);
|
||||
compute_glu(&model->dec_glu2, &buffer[output_index], dec_state->gru2_state);
|
||||
output_index += DEC_GRU2_OUT_SIZE;
|
||||
conv1_cond_init(dec_state->conv2_state, output_index, 1, &dec_state->initialized);
|
||||
compute_generic_conv1d(&model->dec_conv2, &buffer[output_index], dec_state->conv2_state, buffer, output_index, ACTIVATION_TANH);
|
||||
output_index += DEC_CONV2_OUT_SIZE;
|
||||
|
||||
compute_generic_gru(&model->dec_gru3_input, &model->dec_gru3_recurrent, dec_state->gru3_state, buffer);
|
||||
OPUS_COPY(&buffer[output_index], dec_state->gru3_state, DEC_GRU3_OUT_SIZE);
|
||||
compute_glu(&model->dec_glu3, &buffer[output_index], dec_state->gru3_state);
|
||||
output_index += DEC_GRU3_OUT_SIZE;
|
||||
conv1_cond_init(dec_state->conv3_state, output_index, 1, &dec_state->initialized);
|
||||
compute_generic_conv1d(&model->dec_conv3, &buffer[output_index], dec_state->conv3_state, buffer, output_index, ACTIVATION_TANH);
|
||||
output_index += DEC_CONV3_OUT_SIZE;
|
||||
|
||||
compute_generic_gru(&model->dec_gru4_input, &model->dec_gru4_recurrent, dec_state->gru4_state, buffer);
|
||||
OPUS_COPY(&buffer[output_index], dec_state->gru4_state, DEC_GRU4_OUT_SIZE);
|
||||
compute_glu(&model->dec_glu4, &buffer[output_index], dec_state->gru4_state);
|
||||
output_index += DEC_GRU4_OUT_SIZE;
|
||||
conv1_cond_init(dec_state->conv4_state, output_index, 1, &dec_state->initialized);
|
||||
compute_generic_conv1d(&model->dec_conv4, &buffer[output_index], dec_state->conv4_state, buffer, output_index, ACTIVATION_TANH);
|
||||
output_index += DEC_CONV4_OUT_SIZE;
|
||||
|
||||
compute_generic_gru(&model->dec_gru5_input, &model->dec_gru5_recurrent, dec_state->gru5_state, buffer);
|
||||
OPUS_COPY(&buffer[output_index], dec_state->gru5_state, DEC_GRU5_OUT_SIZE);
|
||||
compute_glu(&model->dec_glu5, &buffer[output_index], dec_state->gru5_state);
|
||||
output_index += DEC_GRU5_OUT_SIZE;
|
||||
conv1_cond_init(dec_state->conv5_state, output_index, 1, &dec_state->initialized);
|
||||
compute_generic_conv1d(&model->dec_conv5, &buffer[output_index], dec_state->conv5_state, buffer, output_index, ACTIVATION_TANH);
|
||||
|
|
|
@ -29,7 +29,7 @@
|
|||
|
||||
import torch
|
||||
|
||||
def sparsify_matrix(matrix : torch.tensor, density : float, block_size : list[int, int], keep_diagonal : bool=False, return_mask : bool=False):
|
||||
def sparsify_matrix(matrix : torch.tensor, density : float, block_size, keep_diagonal : bool=False, return_mask : bool=False):
|
||||
""" sparsifies matrix with specified block size
|
||||
|
||||
Parameters:
|
||||
|
|
|
@ -226,6 +226,11 @@ f"""
|
|||
# decoder
|
||||
decoder_dense_layers = [
|
||||
('core_decoder.module.dense_1' , 'dec_dense1', 'TANH', False),
|
||||
('core_decoder.module.glu1.gate' , 'dec_glu1', 'TANH', True),
|
||||
('core_decoder.module.glu2.gate' , 'dec_glu2', 'TANH', True),
|
||||
('core_decoder.module.glu3.gate' , 'dec_glu3', 'TANH', True),
|
||||
('core_decoder.module.glu4.gate' , 'dec_glu4', 'TANH', True),
|
||||
('core_decoder.module.glu5.gate' , 'dec_glu5', 'TANH', True),
|
||||
('core_decoder.module.output' , 'dec_output', 'LINEAR', True),
|
||||
('core_decoder.module.hidden_init' , 'dec_hidden_init', 'TANH', False),
|
||||
('core_decoder.module.gru_init' , 'dec_gru_init','TANH', True),
|
||||
|
@ -338,6 +343,13 @@ if __name__ == "__main__":
|
|||
checkpoint = torch.load(args.checkpoint, map_location='cpu')
|
||||
model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs'])
|
||||
missing_keys, unmatched_keys = model.load_state_dict(checkpoint['state_dict'], strict=False)
|
||||
def _remove_weight_norm(m):
|
||||
try:
|
||||
torch.nn.utils.remove_weight_norm(m)
|
||||
except ValueError: # this module didn't have weight norm
|
||||
return
|
||||
model.apply(_remove_weight_norm)
|
||||
|
||||
|
||||
if len(missing_keys) > 0:
|
||||
raise ValueError(f"error: missing keys in state dict")
|
||||
|
|
|
@ -34,6 +34,12 @@ import math as m
|
|||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import sys
|
||||
import os
|
||||
source_dir = os.path.split(os.path.abspath(__file__))[0]
|
||||
sys.path.append(os.path.join(source_dir, "../../lpcnet/"))
|
||||
from utils.sparsification import GRUSparsifier
|
||||
from torch.nn.utils import weight_norm
|
||||
|
||||
# Quantization and rate related utily functions
|
||||
|
||||
|
@ -227,6 +233,32 @@ def n(x):
|
|||
|
||||
# RDOVAE module and submodules
|
||||
|
||||
sparsify_start = 12000
|
||||
sparsify_stop = 24000
|
||||
sparsify_interval = 100
|
||||
sparsify_exponent = 3
|
||||
#sparsify_start = 0
|
||||
#sparsify_stop = 0
|
||||
|
||||
sparse_params1 = {
|
||||
# 'W_hr' : (1.0, [8, 4], True),
|
||||
# 'W_hz' : (1.0, [8, 4], True),
|
||||
# 'W_hn' : (1.0, [8, 4], True),
|
||||
'W_ir' : (0.6, [8, 4], False),
|
||||
'W_iz' : (0.4, [8, 4], False),
|
||||
'W_in' : (0.8, [8, 4], False)
|
||||
}
|
||||
|
||||
sparse_params2 = {
|
||||
# 'W_hr' : (1.0, [8, 4], True),
|
||||
# 'W_hz' : (1.0, [8, 4], True),
|
||||
# 'W_hn' : (1.0, [8, 4], True),
|
||||
'W_ir' : (0.3, [8, 4], False),
|
||||
'W_iz' : (0.2, [8, 4], False),
|
||||
'W_in' : (0.4, [8, 4], False)
|
||||
}
|
||||
|
||||
|
||||
class MyConv(nn.Module):
|
||||
def __init__(self, input_dim, output_dim, dilation=1):
|
||||
super(MyConv, self).__init__()
|
||||
|
@ -239,6 +271,29 @@ class MyConv(nn.Module):
|
|||
conv_in = torch.cat([torch.zeros_like(x[:,0:self.dilation,:], device=device), x], -2).permute(0, 2, 1)
|
||||
return torch.tanh(self.conv(conv_in)).permute(0, 2, 1)
|
||||
|
||||
class GLU(nn.Module):
|
||||
def __init__(self, feat_size):
|
||||
super(GLU, self).__init__()
|
||||
|
||||
torch.manual_seed(5)
|
||||
|
||||
self.gate = weight_norm(nn.Linear(feat_size, feat_size, bias=False))
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def init_weights(self):
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d)\
|
||||
or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
|
||||
nn.init.orthogonal_(m.weight.data)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
out = x * torch.sigmoid(self.gate(x))
|
||||
|
||||
return out
|
||||
|
||||
class CoreEncoder(nn.Module):
|
||||
STATE_HIDDEN = 128
|
||||
FRAMES_PER_STEP = 2
|
||||
|
@ -355,7 +410,11 @@ class CoreDecoder(nn.Module):
|
|||
self.gru5 = nn.GRU(608, 96, batch_first=True)
|
||||
self.conv5 = MyConv(704, 32)
|
||||
self.output = nn.Linear(736, self.FRAMES_PER_STEP * self.output_dim)
|
||||
|
||||
self.glu1 = GLU(96)
|
||||
self.glu2 = GLU(96)
|
||||
self.glu3 = GLU(96)
|
||||
self.glu4 = GLU(96)
|
||||
self.glu5 = GLU(96)
|
||||
self.hidden_init = nn.Linear(self.state_size, 128)
|
||||
self.gru_init = nn.Linear(128, 480)
|
||||
|
||||
|
@ -363,6 +422,16 @@ class CoreDecoder(nn.Module):
|
|||
print(f"decoder: {nb_params} weights")
|
||||
# initialize weights
|
||||
self.apply(init_weights)
|
||||
self.sparsifier = []
|
||||
self.sparsifier.append(GRUSparsifier([(self.gru1, sparse_params1)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))
|
||||
self.sparsifier.append(GRUSparsifier([(self.gru2, sparse_params1)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))
|
||||
self.sparsifier.append(GRUSparsifier([(self.gru3, sparse_params1)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))
|
||||
self.sparsifier.append(GRUSparsifier([(self.gru4, sparse_params2)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))
|
||||
self.sparsifier.append(GRUSparsifier([(self.gru5, sparse_params2)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))
|
||||
|
||||
def sparsify(self):
|
||||
for sparsifier in self.sparsifier:
|
||||
sparsifier.step()
|
||||
|
||||
def forward(self, z, initial_state):
|
||||
|
||||
|
@ -377,15 +446,15 @@ class CoreDecoder(nn.Module):
|
|||
# run decoding layer stack
|
||||
x = n(torch.tanh(self.dense_1(z)))
|
||||
|
||||
x = torch.cat([x, n(self.gru1(x, h1_state)[0])], -1)
|
||||
x = torch.cat([x, n(self.glu1(n(self.gru1(x, h1_state)[0])))], -1)
|
||||
x = torch.cat([x, n(self.conv1(x))], -1)
|
||||
x = torch.cat([x, n(self.gru2(x, h2_state)[0])], -1)
|
||||
x = torch.cat([x, n(self.glu2(n(self.gru2(x, h2_state)[0])))], -1)
|
||||
x = torch.cat([x, n(self.conv2(x))], -1)
|
||||
x = torch.cat([x, n(self.gru3(x, h3_state)[0])], -1)
|
||||
x = torch.cat([x, n(self.glu3(n(self.gru3(x, h3_state)[0])))], -1)
|
||||
x = torch.cat([x, n(self.conv3(x))], -1)
|
||||
x = torch.cat([x, n(self.gru4(x, h4_state)[0])], -1)
|
||||
x = torch.cat([x, n(self.glu4(n(self.gru4(x, h4_state)[0])))], -1)
|
||||
x = torch.cat([x, n(self.conv4(x))], -1)
|
||||
x = torch.cat([x, n(self.gru5(x, h5_state)[0])], -1)
|
||||
x = torch.cat([x, n(self.glu5(n(self.gru5(x, h5_state)[0])))], -1)
|
||||
x = torch.cat([x, n(self.conv5(x))], -1)
|
||||
|
||||
# output layer and reshaping
|
||||
|
@ -490,6 +559,10 @@ class RDOVAE(nn.Module):
|
|||
if not type(self.weight_clip_fn) == type(None):
|
||||
self.apply(self.weight_clip_fn)
|
||||
|
||||
def sparsify(self):
|
||||
#self.core_encoder.module.sparsify()
|
||||
self.core_decoder.module.sparsify()
|
||||
|
||||
def get_decoder_chunks(self, z_frames, mode='split', chunks_per_offset = 4):
|
||||
|
||||
enc_stride = self.enc_stride
|
||||
|
|
|
@ -84,7 +84,7 @@ sequence_length = args.sequence_length
|
|||
lr_decay_factor = args.lr_decay_factor
|
||||
split_mode = args.split_mode
|
||||
# not exposed
|
||||
adam_betas = [0.9, 0.99]
|
||||
adam_betas = [0.8, 0.95]
|
||||
adam_eps = 1e-8
|
||||
|
||||
checkpoint['batch_size'] = batch_size
|
||||
|
@ -239,6 +239,7 @@ if __name__ == '__main__':
|
|||
optimizer.step()
|
||||
|
||||
model.clip_weights()
|
||||
model.sparsify()
|
||||
|
||||
scheduler.step()
|
||||
|
||||
|
|
|
@ -32,7 +32,7 @@
|
|||
#define DRED_EXTENSION_ID 126
|
||||
|
||||
/* Remove these two completely once DRED gets an extension number assigned. */
|
||||
#define DRED_EXPERIMENTAL_VERSION 7
|
||||
#define DRED_EXPERIMENTAL_VERSION 8
|
||||
#define DRED_EXPERIMENTAL_BYTES 2
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue