added soft quantization to RDOVAE and FARGAN

Jan Buethe 2025-03-08 13:03:02 -08:00 committed by Jean-Marc Valin
parent ebccedd918
commit 1ca6933ac4
No known key found for this signature in database
GPG key ID: 5E5DD9A36F9189C8
5 changed files with 95 additions and 38 deletions


@@ -1,3 +1,6 @@
+import os
+import sys
 import numpy as np
 import torch
 from torch import nn
@@ -7,6 +10,11 @@ from torch.nn.utils import weight_norm
 #from convert_lsp import lpc_to_lsp, lsp_to_lpc
 from rc import lpc2rc, rc2lpc
+source_dir = os.path.split(os.path.abspath(__file__))[0]
+sys.path.append(os.path.join(source_dir, "../dnntools"))
+from dnntools.quantization import soft_quant
 Fs = 16000
 fid_dict = {}
@@ -102,13 +110,16 @@ def gen_phase_embedding(periods, frame_size):
     return torch.cos(embed), torch.sin(embed)
 class GLU(nn.Module):
-    def __init__(self, feat_size):
+    def __init__(self, feat_size, softquant=False):
         super(GLU, self).__init__()
         torch.manual_seed(5)
         self.gate = weight_norm(nn.Linear(feat_size, feat_size, bias=False))
+        if softquant:
+            self.gate = soft_quant(self.gate)
         self.init_weights()
     def init_weights(self):
@@ -125,7 +136,7 @@ class GLU(nn.Module):
         return out
 class FWConv(nn.Module):
-    def __init__(self, in_size, out_size, kernel_size=2):
+    def __init__(self, in_size, out_size, kernel_size=2, softquant=False):
         super(FWConv, self).__init__()
         torch.manual_seed(5)
@@ -133,7 +144,10 @@ class FWConv(nn.Module):
         self.in_size = in_size
         self.kernel_size = kernel_size
         self.conv = weight_norm(nn.Linear(in_size*self.kernel_size, out_size, bias=False))
-        self.glu = GLU(out_size)
+        self.glu = GLU(out_size, softquant=softquant)
+        if softquant:
+            self.conv = soft_quant(self.conv)
         self.init_weights()
@@ -154,7 +168,7 @@ def n(x):
     return torch.clamp(x + (1./127.)*(torch.rand_like(x)-.5), min=-1., max=1.)
 class FARGANCond(nn.Module):
-    def __init__(self, feature_dim=20, cond_size=256, pembed_dims=12):
+    def __init__(self, feature_dim=20, cond_size=256, pembed_dims=12, softquant=False):
         super(FARGANCond, self).__init__()
         self.feature_dim = feature_dim
@@ -165,6 +179,10 @@ class FARGANCond(nn.Module):
         self.fconv1 = nn.Conv1d(64, 128, kernel_size=3, padding='valid', bias=False)
         self.fdense2 = nn.Linear(128, 80*4, bias=False)
+        if softquant:
+            self.fconv1 = soft_quant(self.fconv1)
+            self.fdense2 = soft_quant(self.fdense2)
         self.apply(init_weights)
         nb_params = sum(p.numel() for p in self.parameters())
         print(f"cond model: {nb_params} weights")
@@ -183,7 +201,7 @@ class FARGANCond(nn.Module):
         return tmp
 class FARGANSub(nn.Module):
-    def __init__(self, subframe_size=40, nb_subframes=4, cond_size=256):
+    def __init__(self, subframe_size=40, nb_subframes=4, cond_size=256, softquant=False):
         super(FARGANSub, self).__init__()
         self.subframe_size = subframe_size
@@ -192,21 +210,27 @@ class FARGANSub(nn.Module):
         self.cond_gain_dense = nn.Linear(80, 1)
         #self.sig_dense1 = nn.Linear(4*self.subframe_size+self.passthrough_size+self.cond_size, self.cond_size, bias=False)
-        self.fwc0 = FWConv(2*self.subframe_size+80+4, 192)
+        self.fwc0 = FWConv(2*self.subframe_size+80+4, 192, softquant=softquant)
         self.gru1 = nn.GRUCell(192+2*self.subframe_size, 160, bias=False)
         self.gru2 = nn.GRUCell(160+2*self.subframe_size, 128, bias=False)
         self.gru3 = nn.GRUCell(128+2*self.subframe_size, 128, bias=False)
-        self.gru1_glu = GLU(160)
-        self.gru2_glu = GLU(128)
-        self.gru3_glu = GLU(128)
-        self.skip_glu = GLU(128)
+        self.gru1_glu = GLU(160, softquant=softquant)
+        self.gru2_glu = GLU(128, softquant=softquant)
+        self.gru3_glu = GLU(128, softquant=softquant)
+        self.skip_glu = GLU(128, softquant=softquant)
         #self.ptaps_dense = nn.Linear(4*self.cond_size, 5)
         self.skip_dense = nn.Linear(192+160+2*128+2*self.subframe_size, 128, bias=False)
         self.sig_dense_out = nn.Linear(128, self.subframe_size, bias=False)
         self.gain_dense_out = nn.Linear(192, 4)
+        if softquant:
+            self.gru1 = soft_quant(self.gru1, names=['weight_hh', 'weight_ih'])
+            self.gru2 = soft_quant(self.gru2, names=['weight_hh', 'weight_ih'])
+            self.gru3 = soft_quant(self.gru3, names=['weight_hh', 'weight_ih'])
+            self.skip_dense = soft_quant(self.skip_dense)
+            self.sig_dense_out = soft_quant(self.sig_dense_out)
         self.apply(init_weights)
         nb_params = sum(p.numel() for p in self.parameters())
@@ -271,7 +295,7 @@ class FARGANSub(nn.Module):
         return sig_out, exc_mem, prev_pred, (gru1_state, gru2_state, gru3_state, fwc0_state)
 class FARGAN(nn.Module):
-    def __init__(self, subframe_size=40, nb_subframes=4, feature_dim=20, cond_size=256, passthrough_size=0, has_gain=False, gamma=None):
+    def __init__(self, subframe_size=40, nb_subframes=4, feature_dim=20, cond_size=256, passthrough_size=0, has_gain=False, gamma=None, softquant=False):
         super(FARGAN, self).__init__()
         self.subframe_size = subframe_size
@@ -280,8 +304,8 @@ class FARGAN(nn.Module):
         self.feature_dim = feature_dim
         self.cond_size = cond_size
-        self.cond_net = FARGANCond(feature_dim=feature_dim, cond_size=cond_size)
-        self.sig_net = FARGANSub(subframe_size=subframe_size, nb_subframes=nb_subframes, cond_size=cond_size)
+        self.cond_net = FARGANCond(feature_dim=feature_dim, cond_size=cond_size, softquant=softquant)
+        self.sig_net = FARGANSub(subframe_size=subframe_size, nb_subframes=nb_subframes, cond_size=cond_size, softquant=softquant)
     def forward(self, features, period, nb_frames, pre=None, states=None):
         device = features.device
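
For context, a minimal usage sketch (not part of the commit): the new softquant flag only threads through the constructors, so FARGANCond and FARGANSub wrap their Linear/Conv1d weights with soft_quant, and the GRU cells additionally list both weight_hh and weight_ih via names=[...]; the forward pass is unchanged. The import paths below are assumptions based on the sys.path.append() added at the top of the file.

# Sketch only: assumes fargan.py and the dnntools package are importable.
from fargan import FARGAN

# softquant=True wraps the dense, conv and GRU weights with
# dnntools.quantization.soft_quant at construction time.
model = FARGAN(subframe_size=40, nb_subframes=4, feature_dim=20,
               cond_size=256, softquant=True)

# Same parameter-count check the sub-models already print for themselves.
nb_params = sum(p.numel() for p in model.parameters())
print(f"FARGAN (softquant): {nb_params} weights")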