Using sparse GRUs in DRED decoder

Saves ~270 kB of weights in the decoder
Jean-Marc Valin 2023-11-15 04:08:50 -05:00
parent 58923f61c2
commit b0620c0bf9
7 changed files with 105 additions and 19 deletions

View file

@@ -9,7 +9,7 @@ set -e
srcdir=`dirname $0`
test -n "$srcdir" && cd "$srcdir"
-dnn/download_model.sh b6095cf
+dnn/download_model.sh 58923f6
echo "Updating build configuration files, please wait...."

View file

@@ -98,35 +98,35 @@ void dred_rdovae_decode_qframe(
output_index += DEC_DENSE1_OUT_SIZE;
compute_generic_gru(&model->dec_gru1_input, &model->dec_gru1_recurrent, dec_state->gru1_state, buffer);
-OPUS_COPY(&buffer[output_index], dec_state->gru1_state, DEC_GRU1_OUT_SIZE);
+compute_glu(&model->dec_glu1, &buffer[output_index], dec_state->gru1_state);
output_index += DEC_GRU1_OUT_SIZE;
conv1_cond_init(dec_state->conv1_state, output_index, 1, &dec_state->initialized);
compute_generic_conv1d(&model->dec_conv1, &buffer[output_index], dec_state->conv1_state, buffer, output_index, ACTIVATION_TANH);
output_index += DEC_CONV1_OUT_SIZE;
compute_generic_gru(&model->dec_gru2_input, &model->dec_gru2_recurrent, dec_state->gru2_state, buffer);
-OPUS_COPY(&buffer[output_index], dec_state->gru2_state, DEC_GRU2_OUT_SIZE);
+compute_glu(&model->dec_glu2, &buffer[output_index], dec_state->gru2_state);
output_index += DEC_GRU2_OUT_SIZE;
conv1_cond_init(dec_state->conv2_state, output_index, 1, &dec_state->initialized);
compute_generic_conv1d(&model->dec_conv2, &buffer[output_index], dec_state->conv2_state, buffer, output_index, ACTIVATION_TANH);
output_index += DEC_CONV2_OUT_SIZE;
compute_generic_gru(&model->dec_gru3_input, &model->dec_gru3_recurrent, dec_state->gru3_state, buffer);
-OPUS_COPY(&buffer[output_index], dec_state->gru3_state, DEC_GRU3_OUT_SIZE);
+compute_glu(&model->dec_glu3, &buffer[output_index], dec_state->gru3_state);
output_index += DEC_GRU3_OUT_SIZE;
conv1_cond_init(dec_state->conv3_state, output_index, 1, &dec_state->initialized);
compute_generic_conv1d(&model->dec_conv3, &buffer[output_index], dec_state->conv3_state, buffer, output_index, ACTIVATION_TANH);
output_index += DEC_CONV3_OUT_SIZE;
compute_generic_gru(&model->dec_gru4_input, &model->dec_gru4_recurrent, dec_state->gru4_state, buffer);
-OPUS_COPY(&buffer[output_index], dec_state->gru4_state, DEC_GRU4_OUT_SIZE);
+compute_glu(&model->dec_glu4, &buffer[output_index], dec_state->gru4_state);
output_index += DEC_GRU4_OUT_SIZE;
conv1_cond_init(dec_state->conv4_state, output_index, 1, &dec_state->initialized);
compute_generic_conv1d(&model->dec_conv4, &buffer[output_index], dec_state->conv4_state, buffer, output_index, ACTIVATION_TANH);
output_index += DEC_CONV4_OUT_SIZE;
compute_generic_gru(&model->dec_gru5_input, &model->dec_gru5_recurrent, dec_state->gru5_state, buffer);
-OPUS_COPY(&buffer[output_index], dec_state->gru5_state, DEC_GRU5_OUT_SIZE);
+compute_glu(&model->dec_glu5, &buffer[output_index], dec_state->gru5_state);
output_index += DEC_GRU5_OUT_SIZE;
conv1_cond_init(dec_state->conv5_state, output_index, 1, &dec_state->initialized);
compute_generic_conv1d(&model->dec_conv5, &buffer[output_index], dec_state->conv5_state, buffer, output_index, ACTIVATION_TANH);
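
The compute_glu() calls replace what were previously plain copies of each GRU state into the feature buffer: every GRU output is now gated before being concatenated with the other activations. The gate matches the GLU module added to the Python model later in this commit; roughly, in PyTorch terms (a sketch of the operation, not the C routine itself):

import torch

def glu_gate(x, gate_weight):
    # x: GRU output, gate_weight: square, bias-free gate matrix (96x96 here)
    return x * torch.sigmoid(x @ gate_weight.T)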

View file

@@ -29,7 +29,7 @@
import torch
-def sparsify_matrix(matrix : torch.tensor, density : float, block_size : list[int, int], keep_diagonal : bool=False, return_mask : bool=False):
+def sparsify_matrix(matrix : torch.tensor, density : float, block_size, keep_diagonal : bool=False, return_mask : bool=False):
""" sparsifies matrix with specified block size
Parameters:
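
The signature change drops the list[int, int] annotation; list takes a single type parameter, and subscripting the builtin list in a signature also fails outright on Python versions before 3.9, which is presumably the motivation. For orientation, block-wise magnitude pruning of the kind sparsify_matrix performs can be sketched as below. This is a simplified stand-in assuming 8x4 blocks, not the helper's exact implementation:

import torch

def block_sparsify(matrix, density, block_size=(8, 4)):
    # Score each block by its total magnitude and zero out the weakest blocks
    # so that roughly `density` of the blocks survive.
    rows, cols = matrix.shape
    bh, bw = block_size
    blocks = matrix.reshape(rows // bh, bh, cols // bw, bw)
    energy = blocks.abs().sum(dim=(1, 3))                 # one score per block
    keep = max(1, int(round(density * energy.numel())))
    threshold = energy.flatten().sort(descending=True).values[keep - 1]
    mask = (energy >= threshold).to(matrix.dtype)[:, None, :, None]
    return (blocks * mask).reshape(rows, cols)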

View file

@@ -226,6 +226,11 @@ f"""
# decoder
decoder_dense_layers = [
('core_decoder.module.dense_1' , 'dec_dense1', 'TANH', False),
('core_decoder.module.glu1.gate' , 'dec_glu1', 'TANH', True),
('core_decoder.module.glu2.gate' , 'dec_glu2', 'TANH', True),
('core_decoder.module.glu3.gate' , 'dec_glu3', 'TANH', True),
('core_decoder.module.glu4.gate' , 'dec_glu4', 'TANH', True),
('core_decoder.module.glu5.gate' , 'dec_glu5', 'TANH', True),
('core_decoder.module.output' , 'dec_output', 'LINEAR', True),
('core_decoder.module.hidden_init' , 'dec_hidden_init', 'TANH', False),
('core_decoder.module.gru_init' , 'dec_gru_init','TANH', True),
@@ -338,6 +343,13 @@ if __name__ == "__main__":
checkpoint = torch.load(args.checkpoint, map_location='cpu')
model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs'])
missing_keys, unmatched_keys = model.load_state_dict(checkpoint['state_dict'], strict=False)
def _remove_weight_norm(m):
try:
torch.nn.utils.remove_weight_norm(m)
except ValueError: # this module didn't have weight norm
return
model.apply(_remove_weight_norm)
if len(missing_keys) > 0:
raise ValueError(f"error: missing keys in state dict")
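
The weight-norm removal pass matters for the export step: while a layer is wrapped in weight_norm, its weight lives as the pair weight_g/weight_v and the combined matrix is only recomputed on the fly, so folding it back into a single tensor gives the dump code an ordinary weight to write out. A quick illustration of the effect (standalone example, not part of the script):

import torch
from torch import nn
from torch.nn.utils import weight_norm, remove_weight_norm

gate = weight_norm(nn.Linear(96, 96, bias=False))
print(hasattr(gate, "weight_v"))   # True: weight is parametrized as g * v / ||v||
remove_weight_norm(gate)
print(gate.weight.shape)           # torch.Size([96, 96]), a plain parameter again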

View file

@@ -34,6 +34,12 @@ import math as m
import torch
from torch import nn
import torch.nn.functional as F
import sys
import os
source_dir = os.path.split(os.path.abspath(__file__))[0]
sys.path.append(os.path.join(source_dir, "../../lpcnet/"))
from utils.sparsification import GRUSparsifier
from torch.nn.utils import weight_norm
# Quantization and rate related utily functions
@@ -227,6 +233,32 @@ def n(x):
# RDOVAE module and submodules
sparsify_start = 12000
sparsify_stop = 24000
sparsify_interval = 100
sparsify_exponent = 3
#sparsify_start = 0
#sparsify_stop = 0
sparse_params1 = {
# 'W_hr' : (1.0, [8, 4], True),
# 'W_hz' : (1.0, [8, 4], True),
# 'W_hn' : (1.0, [8, 4], True),
'W_ir' : (0.6, [8, 4], False),
'W_iz' : (0.4, [8, 4], False),
'W_in' : (0.8, [8, 4], False)
}
sparse_params2 = {
# 'W_hr' : (1.0, [8, 4], True),
# 'W_hz' : (1.0, [8, 4], True),
# 'W_hn' : (1.0, [8, 4], True),
'W_ir' : (0.3, [8, 4], False),
'W_iz' : (0.2, [8, 4], False),
'W_in' : (0.4, [8, 4], False)
}
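# Rough sanity check on the "~270 kB" figure from the commit message (this
# snippet is an illustration for the reader, not part of the module): only the
# GRU input matrices W_ir/W_iz/W_in are sparsified, at the densities above.
# The per-layer input widths are inferred from the layer definitions further
# down (only gru5's 608-wide input is explicit in this diff), so treat the
# result as an estimate rather than an exact accounting.
_hidden = 96
_dec_gru_inputs = [(96, sparse_params1), (224, sparse_params1), (352, sparse_params1),
                   (480, sparse_params2), (608, sparse_params2)]
_removed = sum((1.0 - d) * _hidden * in_dim
               for in_dim, params in _dec_gru_inputs
               for (d, _, _) in params.values())
# _removed comes out close to 3e5 weights; at one byte per quantized weight
# that is in the same ballpark as the ~270 kB quoted in the commit message.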
class MyConv(nn.Module):
def __init__(self, input_dim, output_dim, dilation=1):
super(MyConv, self).__init__()
@@ -239,6 +271,29 @@ class MyConv(nn.Module):
conv_in = torch.cat([torch.zeros_like(x[:,0:self.dilation,:], device=device), x], -2).permute(0, 2, 1)
return torch.tanh(self.conv(conv_in)).permute(0, 2, 1)
class GLU(nn.Module):
def __init__(self, feat_size):
super(GLU, self).__init__()
torch.manual_seed(5)
self.gate = weight_norm(nn.Linear(feat_size, feat_size, bias=False))
self.init_weights()
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d)\
or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
nn.init.orthogonal_(m.weight.data)
def forward(self, x):
out = x * torch.sigmoid(self.gate(x))
return out
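# The gate is a square, bias-free linear layer, so the GLU leaves the feature
# size unchanged: GLU(96) maps a (batch, time, 96) tensor to the same shape.
# That is why the decoder layer widths below stay exactly as they were, even
# though every GRU output now passes through a GLU before concatenation.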
class CoreEncoder(nn.Module):
STATE_HIDDEN = 128
FRAMES_PER_STEP = 2
@@ -355,7 +410,11 @@ class CoreDecoder(nn.Module):
self.gru5 = nn.GRU(608, 96, batch_first=True)
self.conv5 = MyConv(704, 32)
self.output = nn.Linear(736, self.FRAMES_PER_STEP * self.output_dim)
self.glu1 = GLU(96)
self.glu2 = GLU(96)
self.glu3 = GLU(96)
self.glu4 = GLU(96)
self.glu5 = GLU(96)
self.hidden_init = nn.Linear(self.state_size, 128)
self.gru_init = nn.Linear(128, 480)
@@ -363,6 +422,16 @@ class CoreDecoder(nn.Module):
print(f"decoder: {nb_params} weights")
# initialize weights
self.apply(init_weights)
self.sparsifier = []
self.sparsifier.append(GRUSparsifier([(self.gru1, sparse_params1)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))
self.sparsifier.append(GRUSparsifier([(self.gru2, sparse_params1)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))
self.sparsifier.append(GRUSparsifier([(self.gru3, sparse_params1)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))
self.sparsifier.append(GRUSparsifier([(self.gru4, sparse_params2)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))
self.sparsifier.append(GRUSparsifier([(self.gru5, sparse_params2)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))
def sparsify(self):
for sparsifier in self.sparsifier:
sparsifier.step()
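# Each GRUSparsifier wraps one GRU together with its per-matrix targets; the
# remaining arguments give the schedule, which is assumed to ramp pruning in
# between iterations 12000 and 24000, re-applying the 8x4 block masks every
# 100 steps with a cubic (exponent 3) annealing curve. sparsify() is meant to
# be called once per optimizer update (train_rdovae.py below adds exactly that
# call), so blocks the masks have zeroed stay at zero as training continues.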
def forward(self, z, initial_state):
@@ -377,15 +446,15 @@ class CoreDecoder(nn.Module):
# run decoding layer stack
x = n(torch.tanh(self.dense_1(z)))
-x = torch.cat([x, n(self.gru1(x, h1_state)[0])], -1)
+x = torch.cat([x, n(self.glu1(n(self.gru1(x, h1_state)[0])))], -1)
x = torch.cat([x, n(self.conv1(x))], -1)
-x = torch.cat([x, n(self.gru2(x, h2_state)[0])], -1)
+x = torch.cat([x, n(self.glu2(n(self.gru2(x, h2_state)[0])))], -1)
x = torch.cat([x, n(self.conv2(x))], -1)
-x = torch.cat([x, n(self.gru3(x, h3_state)[0])], -1)
+x = torch.cat([x, n(self.glu3(n(self.gru3(x, h3_state)[0])))], -1)
x = torch.cat([x, n(self.conv3(x))], -1)
-x = torch.cat([x, n(self.gru4(x, h4_state)[0])], -1)
+x = torch.cat([x, n(self.glu4(n(self.gru4(x, h4_state)[0])))], -1)
x = torch.cat([x, n(self.conv4(x))], -1)
-x = torch.cat([x, n(self.gru5(x, h5_state)[0])], -1)
+x = torch.cat([x, n(self.glu5(n(self.gru5(x, h5_state)[0])))], -1)
x = torch.cat([x, n(self.conv5(x))], -1)
# output layer and reshaping
@@ -490,6 +559,10 @@ class RDOVAE(nn.Module):
if not type(self.weight_clip_fn) == type(None):
self.apply(self.weight_clip_fn)
def sparsify(self):
#self.core_encoder.module.sparsify()
self.core_decoder.module.sparsify()
def get_decoder_chunks(self, z_frames, mode='split', chunks_per_offset = 4):
enc_stride = self.enc_stride

View file

@@ -84,7 +84,7 @@ sequence_length = args.sequence_length
lr_decay_factor = args.lr_decay_factor
split_mode = args.split_mode
# not exposed
-adam_betas = [0.9, 0.99]
+adam_betas = [0.8, 0.95]
adam_eps = 1e-8
checkpoint['batch_size'] = batch_size
@@ -239,6 +239,7 @@ if __name__ == '__main__':
optimizer.step()
model.clip_weights()
model.sparsify()
scheduler.step()
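
The new model.sparsify() call slots in after the optimizer update and the existing weight clipping, so the block masks can re-zero pruned weights that the gradient step would otherwise revive, and the learning-rate scheduler still runs last. Condensed, the per-iteration order is:

optimizer.step()
model.clip_weights()   # existing weight clipping
model.sparsify()       # new: re-apply / anneal the decoder GRU block masks
scheduler.step()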

View file

@@ -32,7 +32,7 @@
#define DRED_EXTENSION_ID 126
/* Remove these two completely once DRED gets an extension number assigned. */
-#define DRED_EXPERIMENTAL_VERSION 7
+#define DRED_EXPERIMENTAL_VERSION 8
#define DRED_EXPERIMENTAL_BYTES 2