From 1ca6933ac480c8584e5e47b4fea5ad2d857a336d Mon Sep 17 00:00:00 2001
From: Jan Buethe
Date: Sat, 8 Mar 2025 13:03:02 -0800
Subject: [PATCH] added soft quantization to RDOVAE and FARGAN

---
 dnn/torch/fargan/adv_train_fargan.py |  3 +-
 dnn/torch/fargan/fargan.py           | 50 ++++++++++++++-----
 dnn/torch/fargan/train_fargan.py     |  3 +-
 dnn/torch/rdovae/rdovae/rdovae.py    | 73 +++++++++++++++++++---------
 dnn/torch/rdovae/train_rdovae.py     |  4 +-
 5 files changed, 95 insertions(+), 38 deletions(-)

diff --git a/dnn/torch/fargan/adv_train_fargan.py b/dnn/torch/fargan/adv_train_fargan.py
index c2977644..37816c38 100644
--- a/dnn/torch/fargan/adv_train_fargan.py
+++ b/dnn/torch/fargan/adv_train_fargan.py
@@ -44,6 +44,7 @@ parser.add_argument('--cuda-visible-devices', type=str, help="comma separates li
 model_group = parser.add_argument_group(title="model parameters")
 model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256)
 model_group.add_argument('--gamma', type=float, help="Use A(z/gamma), default: 0.9", default=0.9)
+model_group.add_argument('--softquant', action="store_true", help="enables soft quantization during training")
 
 training_group = parser.add_argument_group(title="training parameters")
 training_group.add_argument('--batch-size', type=int, help="batch size, default: 128", default=128)
@@ -93,7 +94,7 @@ checkpoint['adam_betas'] = adam_betas
 device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 
 checkpoint['model_args'] = ()
-checkpoint['model_kwargs'] = {'cond_size': cond_size, 'gamma': args.gamma}
+checkpoint['model_kwargs'] = {'cond_size': cond_size, 'gamma': args.gamma, 'softquant': args.softquant}
 print(checkpoint['model_kwargs'])
 model = fargan.FARGAN(*checkpoint['model_args'], **checkpoint['model_kwargs'])
 
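Note: `soft_quant` is imported below from the `dnntools.quantization` package, whose implementation is not part of this patch. For readers unfamiliar with the technique, the following sketch shows one common way such a wrapper can be built: fake-quantizing selected weight tensors through a torch parametrization with a straight-through gradient. The class and helper names here are illustrative only and are not the actual dnntools API.

    import torch
    from torch import nn
    from torch.nn.utils import parametrize

    class SoftQuantSketch(nn.Module):
        # Fake-quantize a weight tensor to a symmetric 8-bit grid; gradients
        # flow straight through the rounding.
        def __init__(self, levels=127):
            super().__init__()
            self.levels = levels

        def forward(self, w):
            scale = w.abs().max().clamp(min=1e-8) / self.levels  # per-tensor step size
            w_hard = torch.round(w / scale) * scale              # hard-quantized weights
            return w + (w_hard - w).detach()                     # straight-through estimator

    def soft_quant_sketch(module, names=('weight',)):
        # Illustrative counterpart of soft_quant(module, names=...): wrap the
        # listed tensors so every forward pass sees quantized values.
        for name in names:
            parametrize.register_parametrization(module, name, SoftQuantSketch())
        return module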
diff --git a/dnn/torch/fargan/fargan.py b/dnn/torch/fargan/fargan.py
index 8dbb694d..4a5a41dc 100644
--- a/dnn/torch/fargan/fargan.py
+++ b/dnn/torch/fargan/fargan.py
@@ -1,3 +1,6 @@
+import os
+import sys
+
 import numpy as np
 import torch
 from torch import nn
@@ -7,6 +10,11 @@ from torch.nn.utils import weight_norm
 #from convert_lsp import lpc_to_lsp, lsp_to_lpc
 from rc import lpc2rc, rc2lpc
 
+source_dir = os.path.split(os.path.abspath(__file__))[0]
+sys.path.append(os.path.join(source_dir, "../dnntools"))
+from dnntools.quantization import soft_quant
+
+
 Fs = 16000
 
 fid_dict = {}
@@ -102,13 +110,16 @@ def gen_phase_embedding(periods, frame_size):
     return torch.cos(embed), torch.sin(embed)
 
 class GLU(nn.Module):
-    def __init__(self, feat_size):
+    def __init__(self, feat_size, softquant=False):
         super(GLU, self).__init__()
 
         torch.manual_seed(5)
 
         self.gate = weight_norm(nn.Linear(feat_size, feat_size, bias=False))
 
+        if softquant:
+            self.gate = soft_quant(self.gate)
+
         self.init_weights()
 
     def init_weights(self):
@@ -125,7 +136,7 @@ class GLU(nn.Module):
         return out
 
 class FWConv(nn.Module):
-    def __init__(self, in_size, out_size, kernel_size=2):
+    def __init__(self, in_size, out_size, kernel_size=2, softquant=False):
         super(FWConv, self).__init__()
 
         torch.manual_seed(5)
@@ -133,7 +144,10 @@ class FWConv(nn.Module):
         self.in_size = in_size
         self.kernel_size = kernel_size
         self.conv = weight_norm(nn.Linear(in_size*self.kernel_size, out_size, bias=False))
-        self.glu = GLU(out_size)
+        self.glu = GLU(out_size, softquant=softquant)
+
+        if softquant:
+            self.conv = soft_quant(self.conv)
 
         self.init_weights()
 
@@ -154,7 +168,7 @@ def n(x):
     return torch.clamp(x + (1./127.)*(torch.rand_like(x)-.5), min=-1., max=1.)
 
 class FARGANCond(nn.Module):
-    def __init__(self, feature_dim=20, cond_size=256, pembed_dims=12):
+    def __init__(self, feature_dim=20, cond_size=256, pembed_dims=12, softquant=False):
         super(FARGANCond, self).__init__()
 
         self.feature_dim = feature_dim
@@ -165,6 +179,10 @@ class FARGANCond(nn.Module):
         self.fconv1 = nn.Conv1d(64, 128, kernel_size=3, padding='valid', bias=False)
         self.fdense2 = nn.Linear(128, 80*4, bias=False)
 
+        if softquant:
+            self.fconv1 = soft_quant(self.fconv1)
+            self.fdense2 = soft_quant(self.fdense2)
+
         self.apply(init_weights)
         nb_params = sum(p.numel() for p in self.parameters())
         print(f"cond model: {nb_params} weights")
@@ -183,7 +201,7 @@ class FARGANCond(nn.Module):
         return tmp
 
 class FARGANSub(nn.Module):
-    def __init__(self, subframe_size=40, nb_subframes=4, cond_size=256):
+    def __init__(self, subframe_size=40, nb_subframes=4, cond_size=256, softquant=False):
         super(FARGANSub, self).__init__()
 
         self.subframe_size = subframe_size
@@ -192,21 +210,27 @@ class FARGANSub(nn.Module):
 
         self.cond_gain_dense = nn.Linear(80, 1)
 
         #self.sig_dense1 = nn.Linear(4*self.subframe_size+self.passthrough_size+self.cond_size, self.cond_size, bias=False)
-        self.fwc0 = FWConv(2*self.subframe_size+80+4, 192)
+        self.fwc0 = FWConv(2*self.subframe_size+80+4, 192, softquant=softquant)
         self.gru1 = nn.GRUCell(192+2*self.subframe_size, 160, bias=False)
         self.gru2 = nn.GRUCell(160+2*self.subframe_size, 128, bias=False)
         self.gru3 = nn.GRUCell(128+2*self.subframe_size, 128, bias=False)
 
-        self.gru1_glu = GLU(160)
-        self.gru2_glu = GLU(128)
-        self.gru3_glu = GLU(128)
-        self.skip_glu = GLU(128)
+        self.gru1_glu = GLU(160, softquant=softquant)
+        self.gru2_glu = GLU(128, softquant=softquant)
+        self.gru3_glu = GLU(128, softquant=softquant)
+        self.skip_glu = GLU(128, softquant=softquant)
         #self.ptaps_dense = nn.Linear(4*self.cond_size, 5)
 
         self.skip_dense = nn.Linear(192+160+2*128+2*self.subframe_size, 128, bias=False)
         self.sig_dense_out = nn.Linear(128, self.subframe_size, bias=False)
         self.gain_dense_out = nn.Linear(192, 4)
 
+        if softquant:
+            self.gru1 = soft_quant(self.gru1, names=['weight_hh', 'weight_ih'])
+            self.gru2 = soft_quant(self.gru2, names=['weight_hh', 'weight_ih'])
+            self.gru3 = soft_quant(self.gru3, names=['weight_hh', 'weight_ih'])
+            self.skip_dense = soft_quant(self.skip_dense)
+            self.sig_dense_out = soft_quant(self.sig_dense_out)
         self.apply(init_weights)
         nb_params = sum(p.numel() for p in self.parameters())
@@ -271,7 +295,7 @@ class FARGANSub(nn.Module):
         return sig_out, exc_mem, prev_pred, (gru1_state, gru2_state, gru3_state, fwc0_state)
 
 class FARGAN(nn.Module):
-    def __init__(self, subframe_size=40, nb_subframes=4, feature_dim=20, cond_size=256, passthrough_size=0, has_gain=False, gamma=None):
+    def __init__(self, subframe_size=40, nb_subframes=4, feature_dim=20, cond_size=256, passthrough_size=0, has_gain=False, gamma=None, softquant=False):
         super(FARGAN, self).__init__()
 
         self.subframe_size = subframe_size
@@ -280,8 +304,8 @@ class FARGAN(nn.Module):
         self.feature_dim = feature_dim
         self.cond_size = cond_size
 
-        self.cond_net = FARGANCond(feature_dim=feature_dim, cond_size=cond_size)
-        self.sig_net = FARGANSub(subframe_size=subframe_size, nb_subframes=nb_subframes, cond_size=cond_size)
+        self.cond_net = FARGANCond(feature_dim=feature_dim, cond_size=cond_size, softquant=softquant)
+        self.sig_net = FARGANSub(subframe_size=subframe_size, nb_subframes=nb_subframes, cond_size=cond_size, softquant=softquant)
 
     def forward(self, features, period, nb_frames, pre=None, states=None):
         device = features.device
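A naming detail that explains the different `names` arguments in this patch: `nn.GRUCell` (used above in fargan.py) registers its weight matrices as `weight_ih`/`weight_hh`, while `nn.GRU` (used below in rdovae.py) registers the per-layer names `weight_ih_l0`/`weight_hh_l0`:

    import torch
    print([n for n, _ in torch.nn.GRUCell(8, 8).named_parameters()])
    # -> ['weight_ih', 'weight_hh', 'bias_ih', 'bias_hh']
    print([n for n, _ in torch.nn.GRU(8, 8).named_parameters()])
    # -> ['weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0']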
diff --git a/dnn/torch/fargan/train_fargan.py b/dnn/torch/fargan/train_fargan.py
index 1b2e2009..4846f995 100644
--- a/dnn/torch/fargan/train_fargan.py
+++ b/dnn/torch/fargan/train_fargan.py
@@ -25,6 +25,7 @@ parser.add_argument('--cuda-visible-devices', type=str, help="comma separates li
 model_group = parser.add_argument_group(title="model parameters")
 model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256)
 model_group.add_argument('--gamma', type=float, help="Use A(z/gamma), default: 0.9", default=0.9)
+model_group.add_argument('--softquant', action="store_true", help="enables soft quantization during training")
 
 training_group = parser.add_argument_group(title="training parameters")
 training_group.add_argument('--batch-size', type=int, help="batch size, default: 512", default=512)
@@ -72,7 +73,7 @@ checkpoint['adam_betas'] = adam_betas
 device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 
 checkpoint['model_args'] = ()
-checkpoint['model_kwargs'] = {'cond_size': cond_size, 'gamma': args.gamma}
+checkpoint['model_kwargs'] = {'cond_size': cond_size, 'gamma': args.gamma, 'softquant': args.softquant}
 print(checkpoint['model_kwargs'])
 model = fargan.FARGAN(*checkpoint['model_args'], **checkpoint['model_kwargs'])
 
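With the flag wired through the checkpoint kwargs, running train_fargan.py with --softquant is equivalent to constructing the model as follows (cond_size and gamma shown at their defaults):

    import fargan
    model = fargan.FARGAN(cond_size=256, gamma=0.9, softquant=True)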
diff --git a/dnn/torch/rdovae/rdovae/rdovae.py b/dnn/torch/rdovae/rdovae/rdovae.py
index cdb07b46..450cc7b4 100644
--- a/dnn/torch/rdovae/rdovae/rdovae.py
+++ b/dnn/torch/rdovae/rdovae/rdovae.py
@@ -40,6 +40,8 @@ source_dir = os.path.split(os.path.abspath(__file__))[0]
 sys.path.append(os.path.join(source_dir, "../../lpcnet/"))
 from utils.sparsification import GRUSparsifier
 from torch.nn.utils import weight_norm
+sys.path.append(os.path.join(source_dir, "../../dnntools"))
+from dnntools.quantization import soft_quant
 
 # Quantization and rate related utily functions
 
@@ -260,25 +262,32 @@ sparse_params2 = {
 
 class MyConv(nn.Module):
-    def __init__(self, input_dim, output_dim, dilation=1):
+    def __init__(self, input_dim, output_dim, dilation=1, softquant=False):
         super(MyConv, self).__init__()
         self.input_dim = input_dim
         self.output_dim = output_dim
         self.dilation=dilation
         self.conv = nn.Conv1d(input_dim, output_dim, kernel_size=2, padding='valid', dilation=dilation)
+
+        if softquant:
+            self.conv = soft_quant(self.conv)
+
     def forward(self, x, state=None):
         device = x.device
         conv_in = torch.cat([torch.zeros_like(x[:,0:self.dilation,:], device=device), x], -2).permute(0, 2, 1)
         return torch.tanh(self.conv(conv_in)).permute(0, 2, 1)
 
 class GLU(nn.Module):
-    def __init__(self, feat_size):
+    def __init__(self, feat_size, softquant=False):
         super(GLU, self).__init__()
 
         torch.manual_seed(5)
 
         self.gate = weight_norm(nn.Linear(feat_size, feat_size, bias=False))
 
+        if softquant:
+            self.gate = soft_quant(self.gate)
+
         self.init_weights()
 
     def init_weights(self):
@@ -299,7 +308,7 @@ class CoreEncoder(nn.Module):
     FRAMES_PER_STEP = 2
     CONV_KERNEL_SIZE = 4
 
-    def __init__(self, feature_dim, output_dim, cond_size, cond_size2, state_size=24):
+    def __init__(self, feature_dim, output_dim, cond_size, cond_size2, state_size=24, softquant=False):
         """ core encoder for RDOVAE
 
             Computes latents, initial states, and rate estimates from features and lambda parameter
@@ -321,15 +330,15 @@ class CoreEncoder(nn.Module):
 
         # layers
         self.dense_1 = nn.Linear(self.input_dim, 64)
         self.gru1 = nn.GRU(64, 64, batch_first=True)
-        self.conv1 = MyConv(128, 96)
+        self.conv1 = MyConv(128, 96, softquant=softquant)
         self.gru2 = nn.GRU(224, 64, batch_first=True)
-        self.conv2 = MyConv(288, 96, dilation=2)
+        self.conv2 = MyConv(288, 96, dilation=2, softquant=softquant)
         self.gru3 = nn.GRU(384, 64, batch_first=True)
-        self.conv3 = MyConv(448, 96, dilation=2)
+        self.conv3 = MyConv(448, 96, dilation=2, softquant=softquant)
         self.gru4 = nn.GRU(544, 64, batch_first=True)
-        self.conv4 = MyConv(608, 96, dilation=2)
+        self.conv4 = MyConv(608, 96, dilation=2, softquant=softquant)
         self.gru5 = nn.GRU(704, 64, batch_first=True)
-        self.conv5 = MyConv(768, 96, dilation=2)
+        self.conv5 = MyConv(768, 96, dilation=2, softquant=softquant)
         self.z_dense = nn.Linear(864, self.output_dim)
 
@@ -343,6 +352,16 @@ class CoreEncoder(nn.Module):
         # initialize weights
         self.apply(init_weights)
 
+        if softquant:
+            self.gru1 = soft_quant(self.gru1, names=['weight_hh_l0', 'weight_ih_l0'])
+            self.gru2 = soft_quant(self.gru2, names=['weight_hh_l0', 'weight_ih_l0'])
+            self.gru3 = soft_quant(self.gru3, names=['weight_hh_l0', 'weight_ih_l0'])
+            self.gru4 = soft_quant(self.gru4, names=['weight_hh_l0', 'weight_ih_l0'])
+            self.gru5 = soft_quant(self.gru5, names=['weight_hh_l0', 'weight_ih_l0'])
+            self.z_dense = soft_quant(self.z_dense)
+            self.state_dense_1 = soft_quant(self.state_dense_1)
+            self.state_dense_2 = soft_quant(self.state_dense_2)
+
 
     def forward(self, features):
 
@@ -379,7 +398,7 @@ class CoreDecoder(nn.Module):
 
     FRAMES_PER_STEP = 4
 
-    def __init__(self, input_dim, output_dim, cond_size, cond_size2, state_size=24):
+    def __init__(self, input_dim, output_dim, cond_size, cond_size2, state_size=24, softquant=False):
         """ core decoder for RDOVAE
 
             Computes features from latents, initial state, and quantization index
@@ -400,21 +419,21 @@ class CoreDecoder(nn.Module):
         # layers
         self.dense_1 = nn.Linear(self.input_size, 96)
         self.gru1 = nn.GRU(96, 96, batch_first=True)
-        self.conv1 = MyConv(192, 32)
+        self.conv1 = MyConv(192, 32, softquant=softquant)
         self.gru2 = nn.GRU(224, 96, batch_first=True)
-        self.conv2 = MyConv(320, 32)
+        self.conv2 = MyConv(320, 32, softquant=softquant)
         self.gru3 = nn.GRU(352, 96, batch_first=True)
-        self.conv3 = MyConv(448, 32)
+        self.conv3 = MyConv(448, 32, softquant=softquant)
         self.gru4 = nn.GRU(480, 96, batch_first=True)
-        self.conv4 = MyConv(576, 32)
+        self.conv4 = MyConv(576, 32, softquant=softquant)
         self.gru5 = nn.GRU(608, 96, batch_first=True)
-        self.conv5 = MyConv(704, 32)
+        self.conv5 = MyConv(704, 32, softquant=softquant)
         self.output = nn.Linear(736, self.FRAMES_PER_STEP * self.output_dim)
-        self.glu1 = GLU(96)
-        self.glu2 = GLU(96)
-        self.glu3 = GLU(96)
-        self.glu4 = GLU(96)
-        self.glu5 = GLU(96)
+        self.glu1 = GLU(96, softquant=softquant)
+        self.glu2 = GLU(96, softquant=softquant)
+        self.glu3 = GLU(96, softquant=softquant)
+        self.glu4 = GLU(96, softquant=softquant)
+        self.glu5 = GLU(96, softquant=softquant)
 
         self.hidden_init = nn.Linear(self.state_size, 128)
         self.gru_init = nn.Linear(128, 480)
@@ -429,6 +448,15 @@ class CoreDecoder(nn.Module):
         self.sparsifier.append(GRUSparsifier([(self.gru4, sparse_params2)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))
         self.sparsifier.append(GRUSparsifier([(self.gru5, sparse_params2)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))
 
+        if softquant:
+            self.gru1 = soft_quant(self.gru1, names=['weight_hh_l0', 'weight_ih_l0'])
+            self.gru2 = soft_quant(self.gru2, names=['weight_hh_l0', 'weight_ih_l0'])
+            self.gru3 = soft_quant(self.gru3, names=['weight_hh_l0', 'weight_ih_l0'])
+            self.gru4 = soft_quant(self.gru4, names=['weight_hh_l0', 'weight_ih_l0'])
+            self.gru5 = soft_quant(self.gru5, names=['weight_hh_l0', 'weight_ih_l0'])
+            self.output = soft_quant(self.output)
+            self.gru_init = soft_quant(self.gru_init)
+
     def sparsify(self):
         for sparsifier in self.sparsifier:
             sparsifier.step()
@@ -525,7 +553,8 @@ class RDOVAE(nn.Module):
                  split_mode='split',
                  clip_weights=False,
                  pvq_num_pulses=82,
-                 state_dropout_rate=0):
+                 state_dropout_rate=0,
+                 softquant=False):
 
         super(RDOVAE, self).__init__()
 
@@ -541,8 +570,8 @@ class RDOVAE(nn.Module):
 
         # submodules encoder and decoder share the statistical model
         self.statistical_model = StatisticalModel(quant_levels, latent_dim, state_dim)
-        self.core_encoder = nn.DataParallel(CoreEncoder(feature_dim, latent_dim, cond_size, cond_size2, state_size=state_dim))
-        self.core_decoder = nn.DataParallel(CoreDecoder(latent_dim, feature_dim, cond_size, cond_size2, state_size=state_dim))
+        self.core_encoder = nn.DataParallel(CoreEncoder(feature_dim, latent_dim, cond_size, cond_size2, state_size=state_dim, softquant=softquant))
+        self.core_decoder = nn.DataParallel(CoreDecoder(latent_dim, feature_dim, cond_size, cond_size2, state_size=state_dim, softquant=softquant))
 
         self.enc_stride = CoreEncoder.FRAMES_PER_STEP
         self.dec_stride = CoreDecoder.FRAMES_PER_STEP
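Assuming soft_quant registers a torch parametrization (as in the sketch near the top of this patch), the wrapping can be verified on any layer; using the illustrative soft_quant_sketch helper from above:

    import torch
    from torch import nn
    from torch.nn.utils import parametrize

    lin = soft_quant_sketch(nn.Linear(128, 128, bias=False))
    assert parametrize.is_parametrized(lin, "weight")
    y = lin(torch.randn(1, 128))  # forward pass sees the fake-quantized weight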
diff --git a/dnn/torch/rdovae/train_rdovae.py b/dnn/torch/rdovae/train_rdovae.py
index d9a43b33..543e326f 100644
--- a/dnn/torch/rdovae/train_rdovae.py
+++ b/dnn/torch/rdovae/train_rdovae.py
@@ -54,6 +54,7 @@ model_group.add_argument('--lambda-min', type=float, help="minimal value for rat
 model_group.add_argument('--lambda-max', type=float, help="maximal value for rate lambda, default: 0.0104", default=0.0104)
 model_group.add_argument('--pvq-num-pulses', type=int, help="number of pulses for PVQ, default: 82", default=82)
 model_group.add_argument('--state-dropout-rate', type=float, help="state dropout rate, default: 0", default=0.0)
+model_group.add_argument('--softquant', action="store_true", help="enables soft quantization during training")
 
 training_group = parser.add_argument_group(title="training parameters")
 training_group.add_argument('--batch-size', type=int, help="batch size, default: 32", default=32)
@@ -109,6 +110,7 @@ quant_levels = args.quant_levels
 lambda_min = args.lambda_min
 lambda_max = args.lambda_max
 state_dim = args.state_dim
+softquant = args.softquant
 
 # not expsed
 num_features = 20
@@ -118,7 +120,7 @@ feature_file = args.features
 
 # model
 checkpoint['model_args'] = (num_features, latent_dim, quant_levels, cond_size, cond_size2)
-checkpoint['model_kwargs'] = {'state_dim': state_dim, 'split_mode' : split_mode, 'pvq_num_pulses': args.pvq_num_pulses, 'state_dropout_rate': args.state_dropout_rate}
+checkpoint['model_kwargs'] = {'state_dim': state_dim, 'split_mode' : split_mode, 'pvq_num_pulses': args.pvq_num_pulses, 'state_dropout_rate': args.state_dropout_rate, 'softquant': softquant}
 
 model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs'])
 
 if type(args.initial_checkpoint) != type(None):
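Correspondingly, running train_rdovae.py with --softquant constructs the model as below. The positional values stand in for the script's command-line arguments; except for num_features, the numbers shown are placeholders, not the script's defaults.

    from rdovae import RDOVAE

    num_features = 20                                                    # fixed in the script
    latent_dim, quant_levels, cond_size, cond_size2 = 80, 16, 128, 256   # placeholders
    model = RDOVAE(num_features, latent_dim, quant_levels, cond_size, cond_size2,
                   state_dim=24,              # placeholder for args.state_dim
                   split_mode='split',
                   pvq_num_pulses=82,         # argparse default
                   state_dropout_rate=0.0,    # argparse default
                   softquant=True)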