added soft quantization to RDOVAE and FARGAN

Jan Buethe, 2025-03-08 13:03:02 -08:00 (committed by Jean-Marc Valin)
parent ebccedd918
commit 1ca6933ac4
5 changed files with 95 additions and 38 deletions
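
Soft quantization here means simulating weight quantization during training so the network learns weights that survive the later integer conversion. The diff only wraps layers in soft_quant from dnntools; as a rough mental model, a soft-quantization wrapper can be built on torch.nn.utils.parametrize with a straight-through estimator. This is an illustrative sketch under assumed details (the grid step 1/128 and the helper name are made up here), not the dnntools implementation, and it does not address composing with weight_norm-wrapped layers the way the real helper must:

    # Illustrative sketch only: re-parametrize a module's weights so the
    # forward pass sees values snapped to a quantization grid while gradients
    # still flow to the underlying float weights (straight-through estimator).
    import torch
    import torch.nn as nn
    import torch.nn.utils.parametrize as parametrize

    class SoftQuantSketch(nn.Module):
        def __init__(self, step=1.0/128):
            super().__init__()
            self.step = step

        def forward(self, w):
            q = torch.round(w / self.step) * self.step   # snap to the grid
            return w + (q - w).detach()                  # identity gradient

    def soft_quant_sketch(module, names=['weight'], step=1.0/128):
        for name in names:
            parametrize.register_parametrization(module, name, SoftQuantSketch(step))
        return module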

File 1 of 5: FARGAN training script

@@ -44,6 +44,7 @@ parser.add_argument('--cuda-visible-devices', type=str, help="comma separates li
 model_group = parser.add_argument_group(title="model parameters")
 model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256)
 model_group.add_argument('--gamma', type=float, help="Use A(z/gamma), default: 0.9", default=0.9)
+model_group.add_argument('--softquant', action="store_true", help="enables soft quantization during training")
 
 training_group = parser.add_argument_group(title="training parameters")
 training_group.add_argument('--batch-size', type=int, help="batch size, default: 128", default=128)
@@ -93,7 +94,7 @@ checkpoint['adam_betas'] = adam_betas
 device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 
 checkpoint['model_args'] = ()
-checkpoint['model_kwargs'] = {'cond_size': cond_size, 'gamma': args.gamma}
+checkpoint['model_kwargs'] = {'cond_size': cond_size, 'gamma': args.gamma, 'softquant': args.softquant}
 print(checkpoint['model_kwargs'])
 model = fargan.FARGAN(*checkpoint['model_args'], **checkpoint['model_kwargs'])
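
The flag is a plain argparse store_true option, so it defaults to False and existing command lines are unaffected; the setting is also recorded in checkpoint['model_kwargs'] alongside the other model parameters. A quick check of the flag semantics:

    # store_true: absent -> False, present -> True; nothing changes for runs
    # that never pass --softquant.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--softquant', action="store_true")
    print(parser.parse_args([]).softquant)               # False
    print(parser.parse_args(['--softquant']).softquant)  # True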

File 2 of 5: FARGAN model definition

@@ -1,3 +1,6 @@
+import os
+import sys
+
 import numpy as np
 import torch
 from torch import nn
@@ -7,6 +10,11 @@ from torch.nn.utils import weight_norm
 #from convert_lsp import lpc_to_lsp, lsp_to_lpc
 from rc import lpc2rc, rc2lpc
 
+source_dir = os.path.split(os.path.abspath(__file__))[0]
+sys.path.append(os.path.join(source_dir, "../dnntools"))
+
+from dnntools.quantization import soft_quant
+
 Fs = 16000
 
 fid_dict = {}
@@ -102,13 +110,16 @@ def gen_phase_embedding(periods, frame_size):
     return torch.cos(embed), torch.sin(embed)
 
 class GLU(nn.Module):
-    def __init__(self, feat_size):
+    def __init__(self, feat_size, softquant=False):
         super(GLU, self).__init__()
 
         torch.manual_seed(5)
 
         self.gate = weight_norm(nn.Linear(feat_size, feat_size, bias=False))
+
+        if softquant:
+            self.gate = soft_quant(self.gate)
 
         self.init_weights()
 
     def init_weights(self):
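
Note that the GLU gate is a weight_norm-wrapped Linear, so the commit applies soft_quant on top of an already-reparametrized layer, something the illustrative sketch in the header does not attempt. On a plain layer, though, the sketch behaves the same way from the caller's perspective (continuing with the imports and hypothetical soft_quant_sketch defined above):

    # Sanity check: the wrapped layer keeps its interface, only the effective
    # weight seen by the forward pass changes.
    gate = nn.Linear(128, 128, bias=False)
    gate = soft_quant_sketch(gate)      # wraps 'weight' by default
    out = gate(torch.randn(1, 128))     # forward uses soft-quantized weights
    print(out.shape)                    # torch.Size([1, 128])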
@@ -125,7 +136,7 @@ class GLU(nn.Module):
         return out
 
 class FWConv(nn.Module):
-    def __init__(self, in_size, out_size, kernel_size=2):
+    def __init__(self, in_size, out_size, kernel_size=2, softquant=False):
         super(FWConv, self).__init__()
 
         torch.manual_seed(5)
@@ -133,7 +144,10 @@
         self.in_size = in_size
         self.kernel_size = kernel_size
         self.conv = weight_norm(nn.Linear(in_size*self.kernel_size, out_size, bias=False))
-        self.glu = GLU(out_size)
+        self.glu = GLU(out_size, softquant=softquant)
+
+        if softquant:
+            self.conv = soft_quant(self.conv)
 
         self.init_weights()
@@ -154,7 +168,7 @@ def n(x):
     return torch.clamp(x + (1./127.)*(torch.rand_like(x)-.5), min=-1., max=1.)
 
 class FARGANCond(nn.Module):
-    def __init__(self, feature_dim=20, cond_size=256, pembed_dims=12):
+    def __init__(self, feature_dim=20, cond_size=256, pembed_dims=12, softquant=False):
         super(FARGANCond, self).__init__()
 
         self.feature_dim = feature_dim
@@ -165,6 +179,10 @@
         self.fconv1 = nn.Conv1d(64, 128, kernel_size=3, padding='valid', bias=False)
         self.fdense2 = nn.Linear(128, 80*4, bias=False)
 
+        if softquant:
+            self.fconv1 = soft_quant(self.fconv1)
+            self.fdense2 = soft_quant(self.fdense2)
+
         self.apply(init_weights)
         nb_params = sum(p.numel() for p in self.parameters())
         print(f"cond model: {nb_params} weights")
@@ -183,7 +201,7 @@ class FARGANCond(nn.Module):
         return tmp
 
 class FARGANSub(nn.Module):
-    def __init__(self, subframe_size=40, nb_subframes=4, cond_size=256):
+    def __init__(self, subframe_size=40, nb_subframes=4, cond_size=256, softquant=False):
         super(FARGANSub, self).__init__()
 
         self.subframe_size = subframe_size
@@ -192,21 +210,27 @@
         self.cond_gain_dense = nn.Linear(80, 1)
 
         #self.sig_dense1 = nn.Linear(4*self.subframe_size+self.passthrough_size+self.cond_size, self.cond_size, bias=False)
-        self.fwc0 = FWConv(2*self.subframe_size+80+4, 192)
+        self.fwc0 = FWConv(2*self.subframe_size+80+4, 192, softquant=softquant)
         self.gru1 = nn.GRUCell(192+2*self.subframe_size, 160, bias=False)
         self.gru2 = nn.GRUCell(160+2*self.subframe_size, 128, bias=False)
         self.gru3 = nn.GRUCell(128+2*self.subframe_size, 128, bias=False)
 
-        self.gru1_glu = GLU(160)
-        self.gru2_glu = GLU(128)
-        self.gru3_glu = GLU(128)
-        self.skip_glu = GLU(128)
+        self.gru1_glu = GLU(160, softquant=softquant)
+        self.gru2_glu = GLU(128, softquant=softquant)
+        self.gru3_glu = GLU(128, softquant=softquant)
+        self.skip_glu = GLU(128, softquant=softquant)
         #self.ptaps_dense = nn.Linear(4*self.cond_size, 5)
 
         self.skip_dense = nn.Linear(192+160+2*128+2*self.subframe_size, 128, bias=False)
         self.sig_dense_out = nn.Linear(128, self.subframe_size, bias=False)
         self.gain_dense_out = nn.Linear(192, 4)
 
+        if softquant:
+            self.gru1 = soft_quant(self.gru1, names=['weight_hh', 'weight_ih'])
+            self.gru2 = soft_quant(self.gru2, names=['weight_hh', 'weight_ih'])
+            self.gru3 = soft_quant(self.gru3, names=['weight_hh', 'weight_ih'])
+            self.skip_dense = soft_quant(self.skip_dense)
+            self.sig_dense_out = soft_quant(self.sig_dense_out)
+
         self.apply(init_weights)
         nb_params = sum(p.numel() for p in self.parameters())
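
Note the explicit names list for the GRU cells: unlike a Linear, an nn.GRUCell carries two weight matrices, and both the input and recurrent weights are soft-quantized here. For illustration, reusing the hypothetical soft_quant_sketch from the commit header:

    # GRUCell exposes weight_ih and weight_hh (plus biases when enabled);
    # both matrices get wrapped by name.
    cell = nn.GRUCell(192, 160, bias=False)
    print([name for name, _ in cell.named_parameters()])  # ['weight_ih', 'weight_hh']
    cell = soft_quant_sketch(cell, names=['weight_hh', 'weight_ih'])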
@@ -271,7 +295,7 @@ class FARGANSub(nn.Module):
         return sig_out, exc_mem, prev_pred, (gru1_state, gru2_state, gru3_state, fwc0_state)
 
 class FARGAN(nn.Module):
-    def __init__(self, subframe_size=40, nb_subframes=4, feature_dim=20, cond_size=256, passthrough_size=0, has_gain=False, gamma=None):
+    def __init__(self, subframe_size=40, nb_subframes=4, feature_dim=20, cond_size=256, passthrough_size=0, has_gain=False, gamma=None, softquant=False):
         super(FARGAN, self).__init__()
 
         self.subframe_size = subframe_size
@@ -280,8 +304,8 @@
         self.feature_dim = feature_dim
         self.cond_size = cond_size
 
-        self.cond_net = FARGANCond(feature_dim=feature_dim, cond_size=cond_size)
-        self.sig_net = FARGANSub(subframe_size=subframe_size, nb_subframes=nb_subframes, cond_size=cond_size)
+        self.cond_net = FARGANCond(feature_dim=feature_dim, cond_size=cond_size, softquant=softquant)
+        self.sig_net = FARGANSub(subframe_size=subframe_size, nb_subframes=nb_subframes, cond_size=cond_size, softquant=softquant)
 
     def forward(self, features, period, nb_frames, pre=None, states=None):
         device = features.device
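
With the flag threaded through FARGANCond and FARGANSub, a single constructor argument turns on quantization-aware training for the whole model. A hypothetical call mirroring the kwargs stored by the training script:

    # Hypothetical construction matching checkpoint['model_kwargs'] above.
    import fargan

    model = fargan.FARGAN(cond_size=256, gamma=0.9, softquant=True)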

File 3 of 5: second FARGAN training script

@@ -25,6 +25,7 @@ parser.add_argument('--cuda-visible-devices', type=str, help="comma separates li
 model_group = parser.add_argument_group(title="model parameters")
 model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256)
 model_group.add_argument('--gamma', type=float, help="Use A(z/gamma), default: 0.9", default=0.9)
+model_group.add_argument('--softquant', action="store_true", help="enables soft quantization during training")
 
 training_group = parser.add_argument_group(title="training parameters")
 training_group.add_argument('--batch-size', type=int, help="batch size, default: 512", default=512)
@@ -72,7 +73,7 @@ checkpoint['adam_betas'] = adam_betas
 device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 
 checkpoint['model_args'] = ()
-checkpoint['model_kwargs'] = {'cond_size': cond_size, 'gamma': args.gamma}
+checkpoint['model_kwargs'] = {'cond_size': cond_size, 'gamma': args.gamma, 'softquant': args.softquant}
 print(checkpoint['model_kwargs'])
 model = fargan.FARGAN(*checkpoint['model_args'], **checkpoint['model_kwargs'])

File 4 of 5: RDOVAE model definition

@@ -40,6 +40,8 @@ source_dir = os.path.split(os.path.abspath(__file__))[0]
 sys.path.append(os.path.join(source_dir, "../../lpcnet/"))
 from utils.sparsification import GRUSparsifier
 from torch.nn.utils import weight_norm
+sys.path.append(os.path.join(source_dir, "../../dnntools"))
+from dnntools.quantization import soft_quant
 
 # Quantization and rate related utily functions
@@ -260,25 +262,32 @@ sparse_params2 = {
 class MyConv(nn.Module):
-    def __init__(self, input_dim, output_dim, dilation=1):
+    def __init__(self, input_dim, output_dim, dilation=1, softquant=False):
         super(MyConv, self).__init__()
         self.input_dim = input_dim
         self.output_dim = output_dim
         self.dilation=dilation
         self.conv = nn.Conv1d(input_dim, output_dim, kernel_size=2, padding='valid', dilation=dilation)
 
+        if softquant:
+            self.conv = soft_quant(self.conv)
+
     def forward(self, x, state=None):
         device = x.device
         conv_in = torch.cat([torch.zeros_like(x[:,0:self.dilation,:], device=device), x], -2).permute(0, 2, 1)
         return torch.tanh(self.conv(conv_in)).permute(0, 2, 1)
 
 class GLU(nn.Module):
-    def __init__(self, feat_size):
+    def __init__(self, feat_size, softquant=False):
         super(GLU, self).__init__()
 
         torch.manual_seed(5)
 
         self.gate = weight_norm(nn.Linear(feat_size, feat_size, bias=False))
+
+        if softquant:
+            self.gate = soft_quant(self.gate)
 
         self.init_weights()
 
     def init_weights(self):
@@ -299,7 +308,7 @@ class CoreEncoder(nn.Module):
     FRAMES_PER_STEP = 2
     CONV_KERNEL_SIZE = 4
 
-    def __init__(self, feature_dim, output_dim, cond_size, cond_size2, state_size=24):
+    def __init__(self, feature_dim, output_dim, cond_size, cond_size2, state_size=24, softquant=False):
         """ core encoder for RDOVAE
 
             Computes latents, initial states, and rate estimates from features and lambda parameter
@@ -321,15 +330,15 @@
         # layers
         self.dense_1 = nn.Linear(self.input_dim, 64)
         self.gru1 = nn.GRU(64, 64, batch_first=True)
-        self.conv1 = MyConv(128, 96)
+        self.conv1 = MyConv(128, 96, softquant=True)
         self.gru2 = nn.GRU(224, 64, batch_first=True)
-        self.conv2 = MyConv(288, 96, dilation=2)
+        self.conv2 = MyConv(288, 96, dilation=2, softquant=True)
         self.gru3 = nn.GRU(384, 64, batch_first=True)
-        self.conv3 = MyConv(448, 96, dilation=2)
+        self.conv3 = MyConv(448, 96, dilation=2, softquant=True)
         self.gru4 = nn.GRU(544, 64, batch_first=True)
-        self.conv4 = MyConv(608, 96, dilation=2)
+        self.conv4 = MyConv(608, 96, dilation=2, softquant=True)
         self.gru5 = nn.GRU(704, 64, batch_first=True)
-        self.conv5 = MyConv(768, 96, dilation=2)
+        self.conv5 = MyConv(768, 96, dilation=2, softquant=True)
         self.z_dense = nn.Linear(864, self.output_dim)
@@ -343,6 +352,16 @@
         # initialize weights
         self.apply(init_weights)
 
+        if softquant:
+            self.gru1 = soft_quant(self.gru1, names=['weight_hh_l0', 'weight_ih_l0'])
+            self.gru2 = soft_quant(self.gru2, names=['weight_hh_l0', 'weight_ih_l0'])
+            self.gru3 = soft_quant(self.gru3, names=['weight_hh_l0', 'weight_ih_l0'])
+            self.gru4 = soft_quant(self.gru4, names=['weight_hh_l0', 'weight_ih_l0'])
+            self.gru5 = soft_quant(self.gru5, names=['weight_hh_l0', 'weight_ih_l0'])
+            self.z_dense = soft_quant(self.z_dense)
+            self.state_dense_1 = soft_quant(self.state_dense_1)
+            self.state_dense_2 = soft_quant(self.state_dense_2)
+
     def forward(self, features):
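
The parameter names differ from the FARGAN case because these are full nn.GRU modules rather than nn.GRUCells: PyTorch suffixes RNN weights with the layer index, hence weight_ih_l0/weight_hh_l0 for the single layer used here:

    # nn.GRU names its layer-0 matrices with an _l0 suffix, which is why the
    # names list here differs from the GRUCell case in fargan.py.
    gru = nn.GRU(64, 64, batch_first=True)
    print([name for name, _ in gru.named_parameters()])
    # ['weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0']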
@@ -379,7 +398,7 @@ class CoreDecoder(nn.Module):
     FRAMES_PER_STEP = 4
 
-    def __init__(self, input_dim, output_dim, cond_size, cond_size2, state_size=24):
+    def __init__(self, input_dim, output_dim, cond_size, cond_size2, state_size=24, softquant=False):
         """ core decoder for RDOVAE
 
            Computes features from latents, initial state, and quantization index
@@ -400,21 +419,21 @@
         # layers
         self.dense_1 = nn.Linear(self.input_size, 96)
         self.gru1 = nn.GRU(96, 96, batch_first=True)
-        self.conv1 = MyConv(192, 32)
+        self.conv1 = MyConv(192, 32, softquant=softquant)
         self.gru2 = nn.GRU(224, 96, batch_first=True)
-        self.conv2 = MyConv(320, 32)
+        self.conv2 = MyConv(320, 32, softquant=softquant)
         self.gru3 = nn.GRU(352, 96, batch_first=True)
-        self.conv3 = MyConv(448, 32)
+        self.conv3 = MyConv(448, 32, softquant=softquant)
         self.gru4 = nn.GRU(480, 96, batch_first=True)
-        self.conv4 = MyConv(576, 32)
+        self.conv4 = MyConv(576, 32, softquant=softquant)
         self.gru5 = nn.GRU(608, 96, batch_first=True)
-        self.conv5 = MyConv(704, 32)
+        self.conv5 = MyConv(704, 32, softquant=softquant)
         self.output = nn.Linear(736, self.FRAMES_PER_STEP * self.output_dim)
-        self.glu1 = GLU(96)
-        self.glu2 = GLU(96)
-        self.glu3 = GLU(96)
-        self.glu4 = GLU(96)
-        self.glu5 = GLU(96)
+        self.glu1 = GLU(96, softquant=softquant)
+        self.glu2 = GLU(96, softquant=softquant)
+        self.glu3 = GLU(96, softquant=softquant)
+        self.glu4 = GLU(96, softquant=softquant)
+        self.glu5 = GLU(96, softquant=softquant)
 
         self.hidden_init = nn.Linear(self.state_size, 128)
         self.gru_init = nn.Linear(128, 480)
@@ -429,6 +448,15 @@
         self.sparsifier.append(GRUSparsifier([(self.gru4, sparse_params2)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))
         self.sparsifier.append(GRUSparsifier([(self.gru5, sparse_params2)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))
 
+        if softquant:
+            self.gru1 = soft_quant(self.gru1, names=['weight_hh_l0', 'weight_ih_l0'])
+            self.gru2 = soft_quant(self.gru2, names=['weight_hh_l0', 'weight_ih_l0'])
+            self.gru3 = soft_quant(self.gru3, names=['weight_hh_l0', 'weight_ih_l0'])
+            self.gru4 = soft_quant(self.gru4, names=['weight_hh_l0', 'weight_ih_l0'])
+            self.gru5 = soft_quant(self.gru5, names=['weight_hh_l0', 'weight_ih_l0'])
+            self.output = soft_quant(self.output)
+            self.gru_init = soft_quant(self.gru_init)
+
     def sparsify(self):
         for sparsifier in self.sparsifier:
             sparsifier.step()
@@ -525,7 +553,8 @@ class RDOVAE(nn.Module):
                  split_mode='split',
                  clip_weights=False,
                  pvq_num_pulses=82,
-                 state_dropout_rate=0):
+                 state_dropout_rate=0,
+                 softquant=False):
 
         super(RDOVAE, self).__init__()
@@ -541,8 +570,8 @@
         # submodules encoder and decoder share the statistical model
         self.statistical_model = StatisticalModel(quant_levels, latent_dim, state_dim)
-        self.core_encoder = nn.DataParallel(CoreEncoder(feature_dim, latent_dim, cond_size, cond_size2, state_size=state_dim))
-        self.core_decoder = nn.DataParallel(CoreDecoder(latent_dim, feature_dim, cond_size, cond_size2, state_size=state_dim))
+        self.core_encoder = nn.DataParallel(CoreEncoder(feature_dim, latent_dim, cond_size, cond_size2, state_size=state_dim, softquant=softquant))
+        self.core_decoder = nn.DataParallel(CoreDecoder(latent_dim, feature_dim, cond_size, cond_size2, state_size=state_dim, softquant=softquant))
 
         self.enc_stride = CoreEncoder.FRAMES_PER_STEP
         self.dec_stride = CoreDecoder.FRAMES_PER_STEP

File 5 of 5: RDOVAE training script

@@ -54,6 +54,7 @@ model_group.add_argument('--lambda-min', type=float, help="minimal value for rat
 model_group.add_argument('--lambda-max', type=float, help="maximal value for rate lambda, default: 0.0104", default=0.0104)
 model_group.add_argument('--pvq-num-pulses', type=int, help="number of pulses for PVQ, default: 82", default=82)
 model_group.add_argument('--state-dropout-rate', type=float, help="state dropout rate, default: 0", default=0.0)
+model_group.add_argument('--softquant', action="store_true", help="enables soft quantization during training")
 
 training_group = parser.add_argument_group(title="training parameters")
 training_group.add_argument('--batch-size', type=int, help="batch size, default: 32", default=32)
@@ -109,6 +110,7 @@ quant_levels = args.quant_levels
 lambda_min = args.lambda_min
 lambda_max = args.lambda_max
 state_dim = args.state_dim
+softquant = args.softquant
 
 # not expsed
 num_features = 20
@@ -118,7 +120,7 @@ feature_file = args.features
 # model
 checkpoint['model_args'] = (num_features, latent_dim, quant_levels, cond_size, cond_size2)
-checkpoint['model_kwargs'] = {'state_dim': state_dim, 'split_mode' : split_mode, 'pvq_num_pulses': args.pvq_num_pulses, 'state_dropout_rate': args.state_dropout_rate}
+checkpoint['model_kwargs'] = {'state_dim': state_dim, 'split_mode' : split_mode, 'pvq_num_pulses': args.pvq_num_pulses, 'state_dropout_rate': args.state_dropout_rate, 'softquant': softquant}
 model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs'])
 
 if type(args.initial_checkpoint) != type(None):
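
Because 'softquant' travels inside model_kwargs in the checkpoint, a model trained with the flag is rebuilt with the same soft-quantization wrappers on restore. A hypothetical round trip (the checkpoint path is illustrative):

    # Hypothetical restore path: the saved kwargs carry 'softquant', so the
    # re-instantiated RDOVAE gets the same quantization-aware layers.
    checkpoint = torch.load("rdovae_checkpoint.pth", map_location="cpu")
    model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs'])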