added FWGAN weight dumping code
This commit is contained in:
parent 9691440a5f
commit 902d763622
5 changed files with 805 additions and 0 deletions

dnn/torch/fwgan/dump_model_weights.py (new file, 89 lines)
@@ -0,0 +1,89 @@
import os
import sys
import argparse

import torch
from torch import nn

sys.path.append(os.path.join(os.path.split(__file__)[0], '../weight-exchange'))
import wexchange.torch
import wexchange.c_export.c_writer  # explicit import for the CWriter used below

from models import model_dict

unquantized = [
    'feat_in_conv1.conv',
    'bfcc_with_corr_upsampler.fc',
    'cont_net.0',
    'fwc6.cont_fc.0',
    'fwc6.fc.0',
    'fwc6.fc.1.gate',
    'fwc7.cont_fc.0',
    'fwc7.fc.0',
    'fwc7.fc.1.gate'
]

description = f"""
This is an unsafe dumping script for FWGAN models. It assumes that all weights are contained in Linear, Conv1d or GRU layers
and will fail to export any other weights.

Furthermore, the quantize option relies on the following explicit list of layers being excluded from quantization:
{unquantized}.

Modify this script manually if adjustments are needed.
"""

parser = argparse.ArgumentParser(description=description)
parser.add_argument('model', choices=['fwgan400', 'fwgan500'], help='model name')
parser.add_argument('weightfile', type=str, help='weight file path')
parser.add_argument('export_folder', type=str)
parser.add_argument('--export-filename', type=str, default='fwgan_data', help='filename for source and header file (.c and .h will be added), defaults to fwgan_data')
parser.add_argument('--struct-name', type=str, default='FWGAN', help='name for C struct, defaults to FWGAN')
parser.add_argument('--quantize', action='store_true', help='apply quantization')


if __name__ == "__main__":
    args = parser.parse_args()

    model = model_dict[args.model]()

    print(f"loading weights from {args.weightfile}...")
    saved_gen = torch.load(args.weightfile, map_location='cpu')
    model.load_state_dict(saved_gen)

    # strip the weight_norm reparametrization where present
    def _remove_weight_norm(m):
        try:
            torch.nn.utils.remove_weight_norm(m)
        except ValueError:  # this module didn't have weight norm
            return
    model.apply(_remove_weight_norm)

    print("dumping model...")
    quantize_model = args.quantize

    output_folder = args.export_folder
    os.makedirs(output_folder, exist_ok=True)

    writer = wexchange.c_export.c_writer.CWriter(os.path.join(output_folder, args.export_filename), model_struct_name=args.struct_name)

    for name, module in model.named_modules():
        if quantize_model:
            quantize = name not in unquantized
            scale = None if quantize else 1/128
        else:
            quantize = False
            scale = 1/128

        if isinstance(module, nn.Linear):
            print(f"dumping linear layer {name}...")
            wexchange.torch.dump_torch_dense_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale)

        if isinstance(module, nn.Conv1d):
            print(f"dumping conv1d layer {name}...")
            wexchange.torch.dump_torch_conv1d_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale)

        if isinstance(module, nn.GRU):
            print(f"dumping GRU layer {name}...")
            wexchange.torch.dump_torch_gru_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale, recurrent_scale=scale)

    writer.close()
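
A hypothetical invocation of the script above (paths are placeholders):
python dump_model_weights.py fwgan400 checkpoint.pth exported --quantize
With the default --export-filename this should write exported/fwgan_data.c and exported/fwgan_data.h.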

dnn/torch/fwgan/inference.py (new file, 141 lines)
@@ -0,0 +1,141 @@
import os
import time
import torch
import numpy as np
from scipy import signal as si
from scipy.io import wavfile
import argparse

from models import model_dict

parser = argparse.ArgumentParser()
parser.add_argument('model', choices=['fwgan400', 'fwgan500'], help='model name')
parser.add_argument('weightfile', type=str, help='weight file')
parser.add_argument('input', type=str, help='input: feature file or folder with feature files')
parser.add_argument('output', type=str, help='output: wav file name or folder name, depending on input')


########################### Signal Processing Layers ###########################

def preemphasis(x, coef=-0.85):
    return si.lfilter(np.array([1.0, coef]), np.array([1.0]), x).astype('float32')

def deemphasis(x, coef=-0.85):
    return si.lfilter(np.array([1.0]), np.array([1.0, coef]), x).astype('float32')
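
deemphasis applies the exact inverse of the preemphasis filter, so the round trip should reproduce its input to within float32 precision. A minimal sanity sketch (not part of this commit):

    x = np.random.randn(1600).astype('float32')
    assert np.allclose(x, deemphasis(preemphasis(x)), atol=1e-4)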

gamma = 0.92
weighting_vector = np.array([gamma**i for i in range(16, 0, -1)])


def lpc_synthesis_one_frame(frame, filt, buffer, weighting_vector=np.ones(16)):
    out = np.zeros_like(frame)
    # flip so the oldest buffer sample pairs with the highest-order LPC coefficient
    filt = np.flip(filt)
    inp = frame[:]

    for i in range(0, inp.shape[0]):
        s = inp[i] - np.dot(buffer * weighting_vector, filt)
        buffer[0] = s
        buffer = np.roll(buffer, -1)
        out[i] = s

    return out
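
With zero history, the loop above is the all-pole synthesis filter 1/A(z/gamma) with A(z) = 1 + sum_k a_k z^-k, i.e. the bandwidth-expanded LPC synthesis filter. A minimal equivalence sketch against scipy, folding gamma^k into the coefficients (toy values, not part of this commit):

    rng = np.random.default_rng(0)
    frame = rng.standard_normal(160).astype('float32')
    lpc = 0.01 * rng.standard_normal(16)                          # toy, safely stable coefficients
    a = np.concatenate(([1.0], lpc * gamma ** np.arange(1, 17)))  # A(z/gamma)
    y_ref = si.lfilter([1.0], a, frame)
    y_loop = lpc_synthesis_one_frame(frame, lpc, np.zeros(16), weighting_vector)
    assert np.allclose(y_ref, y_loop, atol=1e-4)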

def inverse_perceptual_weighting(pw_signal, filters, weighting_vector):
    # inverse perceptual weighting = H_preemph / W(z/gamma)
    pw_signal = preemphasis(pw_signal)
    signal = np.zeros_like(pw_signal)
    buffer = np.zeros(16)
    num_frames = pw_signal.shape[0] // 160
    assert num_frames == filters.shape[0]

    for frame_idx in range(0, num_frames):
        in_frame = pw_signal[frame_idx*160: (frame_idx+1)*160][:]
        out_sig_frame = lpc_synthesis_one_frame(in_frame, filters[frame_idx, :], buffer, weighting_vector)
        signal[frame_idx*160: (frame_idx+1)*160] = out_sig_frame[:]
        buffer[:] = out_sig_frame[-16:]

    return signal


def process_item(generator, feature_filename, output_filename, verbose=False):
    # feature layout (36 floats per frame): 18 BFCCs, pitch at 18, correlation at 19, 16 LPCs at the end
    feat = np.memmap(feature_filename, dtype='float32', mode='r')
    num_feat_frames = len(feat) // 36
    feat = np.reshape(feat, (num_feat_frames, 36))

    bfcc = np.copy(feat[:, :18])
    corr = np.copy(feat[:, 19:20]) + 0.5
    bfcc_with_corr = torch.from_numpy(np.hstack((bfcc, corr))).type(torch.FloatTensor).unsqueeze(0)  #.to(device)

    period = torch.from_numpy((0.1 + 50 * np.copy(feat[:, 18:19]) + 100)
                              .astype('int32')).type(torch.long).view(1, -1)  #.to(device)

    lpc_filters = np.copy(feat[:, -16:])

    start_time = time.time()
    # zero history audio frames: the vocoder runs in complete synthesis mode
    x1 = generator(period, bfcc_with_corr, torch.zeros(1, 320))
    end_time = time.time()
    total_time = end_time - start_time
    x1 = x1.squeeze(1).squeeze(0).detach().cpu().numpy()
    gen_seconds = len(x1) / 16000
    out = deemphasis(inverse_perceptual_weighting(x1, lpc_filters, weighting_vector))
    if verbose:
        print(f"Took {total_time:.3f}s to generate {len(x1)} samples ({gen_seconds}s) -> {gen_seconds/total_time:.2f}x real time")

    out = np.clip(np.round(2**15 * out), -2**15, 2**15 - 1).astype(np.int16)
    wavfile.write(output_filename, 16000, out)


########################### The inference loop over a folder containing LPCNet feature files #################################
if __name__ == "__main__":

    args = parser.parse_args()

    generator = model_dict[args.model]()

    # load the selected model checkpoint
    saved_gen = torch.load(args.weightfile, map_location='cpu')
    generator.load_state_dict(saved_gen)

    # remove weight_norm from the model layers; it is no longer needed for inference
    def _remove_weight_norm(m):
        try:
            torch.nn.utils.remove_weight_norm(m)
        except ValueError:  # this module didn't have weight norm
            return
    generator.apply(_remove_weight_norm)

    # enable inference mode
    generator = generator.eval()

    print('Successfully loaded the generator model ... start generation:')

    if os.path.isdir(args.input):
        os.makedirs(args.output, exist_ok=True)

        for fn in os.listdir(args.input):
            print(f"processing input {fn}...")
            feature_filename = os.path.join(args.input, fn)
            output_filename = os.path.join(args.output, os.path.splitext(fn)[0] + f"_{args.model}.wav")
            process_item(generator, feature_filename, output_filename)
    else:
        process_item(generator, args.input, args.output)

    print("Finished!")

dnn/torch/fwgan/models/__init__.py (new file, 7 lines)
@@ -0,0 +1,7 @@
from .fwgan400 import FWGAN400ContLarge
from .fwgan500 import FWGAN500Cont

model_dict = {
    'fwgan400': FWGAN400ContLarge,
    'fwgan500': FWGAN500Cont
}

dnn/torch/fwgan/models/fwgan400.py (new file, 308 lines)
@@ -0,0 +1,308 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
import numpy as np

which_norm = weight_norm

#################### Definition of basic model components ####################

# Convolutional layer with 1 frame look-ahead (used for the feature PreCondNet)
class ConvLookahead(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size, dilation=1, groups=1, bias=False):
        super(ConvLookahead, self).__init__()
        torch.manual_seed(5)

        # asymmetric padding: kernel_size - 2 frames of history, 1 frame of look-ahead
        self.padding_left = (kernel_size - 2) * dilation
        self.padding_right = 1 * dilation

        self.conv = which_norm(nn.Conv1d(in_ch, out_ch, kernel_size, dilation=dilation, groups=groups, bias=bias))

        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear, nn.Embedding)):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x):
        x = F.pad(x, (self.padding_left, self.padding_right))
        conv_out = self.conv(x)
        return conv_out

# (modified) GLU activation layer definition
class GLU(nn.Module):
    def __init__(self, feat_size):
        super(GLU, self).__init__()

        torch.manual_seed(5)

        self.gate = which_norm(nn.Linear(feat_size, feat_size, bias=False))

        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear, nn.Embedding)):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x):
        out = torch.tanh(x) * torch.sigmoid(self.gate(x))

        return out

# GRU layer definition
class ContForwardGRU(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers=1):
        super(ContForwardGRU, self).__init__()

        torch.manual_seed(5)

        self.hidden_size = hidden_size

        # maps the continuation embedding to the initial GRU state
        self.cont_fc = nn.Sequential(which_norm(nn.Linear(64, self.hidden_size, bias=False)),
                                     nn.Tanh())

        self.gru = nn.GRU(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers,
                          batch_first=True, bias=False)

        self.nl = GLU(self.hidden_size)

        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear, nn.Embedding)):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x, x0):
        self.gru.flatten_parameters()

        h0 = self.cont_fc(x0).unsqueeze(0)

        output, h0 = self.gru(x, h0)

        return self.nl(output)

# Frame-wise convolution layer definition
class ContFramewiseConv(torch.nn.Module):

    def __init__(self, frame_len, out_dim, frame_kernel_size=3, act='glu', causal=True):
        super(ContFramewiseConv, self).__init__()
        torch.manual_seed(5)

        self.frame_kernel_size = frame_kernel_size
        self.frame_len = frame_len

        if causal or (self.frame_kernel_size == 2):
            self.required_pad_left = (self.frame_kernel_size - 1) * self.frame_len
            self.required_pad_right = 0

            # maps the continuation embedding to the left padding (history frames)
            self.cont_fc = nn.Sequential(which_norm(nn.Linear(64, self.required_pad_left, bias=False)),
                                         nn.Tanh())

        else:
            # non-causal frame-wise convolution, currently unused
            self.required_pad_left = (self.frame_kernel_size - 1)//2 * self.frame_len
            self.required_pad_right = (self.frame_kernel_size - 1)//2 * self.frame_len

        self.fc_input_dim = self.frame_kernel_size * self.frame_len
        self.fc_out_dim = out_dim

        if act == 'glu':
            self.fc = nn.Sequential(which_norm(nn.Linear(self.fc_input_dim, self.fc_out_dim, bias=False)),
                                    GLU(self.fc_out_dim))
        if act == 'tanh':
            self.fc = nn.Sequential(which_norm(nn.Linear(self.fc_input_dim, self.fc_out_dim, bias=False)),
                                    nn.Tanh())

        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear, nn.Embedding)):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x, x0):
        if self.frame_kernel_size == 1:
            return self.fc(x)

        # flatten the frame sequence, prepend learned history, then slide a window of
        # frame_kernel_size frames over it (one window per output frame)
        x_flat = x.reshape(x.size(0), 1, -1)
        pad = self.cont_fc(x0).view(x0.size(0), 1, -1)
        x_flat_padded = torch.cat((pad, x_flat), dim=-1).unsqueeze(2)

        x_flat_padded_unfolded = F.unfold(x_flat_padded, kernel_size=(1, self.fc_input_dim),
                                          stride=self.frame_len).permute(0, 2, 1).contiguous()

        out = self.fc(x_flat_padded_unfolded)
        return out
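
The F.unfold call above is what implements the frame-wise convolution: it slides a window of frame_kernel_size consecutive frames over the flattened, history-padded signal, yielding one window per output frame. A standalone sketch of the same trick with toy sizes (not part of this commit):

    frame_len, kernel, n_frames = 4, 3, 5
    x = torch.arange(n_frames * frame_len, dtype=torch.float).view(1, 1, -1)
    pad = torch.zeros(1, 1, (kernel - 1) * frame_len)   # stands in for cont_fc(x0)
    xp = torch.cat((pad, x), dim=-1).unsqueeze(2)       # (1, 1, 1, padded_len)
    win = F.unfold(xp, kernel_size=(1, kernel * frame_len), stride=frame_len).permute(0, 2, 1)
    assert win.shape == (1, n_frames, kernel * frame_len)
    # each window ends with its own frame, preceded by kernel - 1 frames of context
    assert torch.equal(win[0, 2, -frame_len:], x[0, 0, 2*frame_len:3*frame_len])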

# A fully-connected upsampling layer definition
class UpsampleFC(nn.Module):
    def __init__(self, in_ch, out_ch, upsample_factor):
        super(UpsampleFC, self).__init__()
        torch.manual_seed(5)

        self.in_ch = in_ch
        self.out_ch = out_ch
        self.upsample_factor = upsample_factor
        self.fc = nn.Linear(in_ch, out_ch * upsample_factor, bias=False)
        self.nl = nn.Tanh()

        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear, nn.Embedding)):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x):
        batch_size = x.size(0)
        x = x.permute(0, 2, 1)
        x = self.nl(self.fc(x))
        x = x.reshape((batch_size, -1, self.out_ch))
        x = x.permute(0, 2, 1)
        return x

########################### The complete model definition #################################

class FWGAN400ContLarge(nn.Module):
    def __init__(self):
        super().__init__()
        torch.manual_seed(5)

        self.bfcc_with_corr_upsampler = UpsampleFC(19, 80, 4)

        self.feat_in_conv1 = ConvLookahead(160, 256, kernel_size=5)
        self.feat_in_nl1 = GLU(256)

        self.cont_net = nn.Sequential(which_norm(nn.Linear(321, 160, bias=False)),
                                      nn.Tanh(),
                                      which_norm(nn.Linear(160, 160, bias=False)),
                                      nn.Tanh(),
                                      which_norm(nn.Linear(160, 80, bias=False)),
                                      nn.Tanh(),
                                      which_norm(nn.Linear(80, 80, bias=False)),
                                      nn.Tanh(),
                                      which_norm(nn.Linear(80, 64, bias=False)),
                                      nn.Tanh(),
                                      which_norm(nn.Linear(64, 64, bias=False)),
                                      nn.Tanh())

        self.rnn = ContForwardGRU(256, 256)

        self.fwc1 = ContFramewiseConv(256, 256)
        self.fwc2 = ContFramewiseConv(256, 128)
        self.fwc3 = ContFramewiseConv(128, 128)
        self.fwc4 = ContFramewiseConv(128, 64)
        self.fwc5 = ContFramewiseConv(64, 64)
        self.fwc6 = ContFramewiseConv(64, 40)
        self.fwc7 = ContFramewiseConv(40, 40)

        self.init_weights()
        self.count_parameters()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear, nn.Embedding)):
                nn.init.orthogonal_(m.weight.data)

    def count_parameters(self):
        num_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"Total number of {self.__class__.__name__} network parameters = {num_params}\n")

    def create_phase_signals(self, periods):
        batch_size = periods.size(0)
        progression = torch.arange(1, 160 + 1, dtype=periods.dtype, device=periods.device).view((1, -1))
        progression = torch.repeat_interleave(progression, batch_size, 0)

        # phase0 carries the accumulated phase across frames so the sinusoids stay continuous
        phase0 = torch.zeros(batch_size, dtype=periods.dtype, device=periods.device).unsqueeze(-1)
        chunks = []
        for sframe in range(periods.size(1)):
            f = (2.0 * torch.pi / periods[:, sframe]).unsqueeze(-1)

            chunk_sin = torch.sin(f * progression + phase0)
            chunk_sin = chunk_sin.reshape(chunk_sin.size(0), -1, 40)

            chunk_cos = torch.cos(f * progression + phase0)
            chunk_cos = chunk_cos.reshape(chunk_cos.size(0), -1, 40)

            chunk = torch.cat((chunk_sin, chunk_cos), dim=-1)

            phase0 = phase0 + 160 * f

            chunks.append(chunk)

        phase_signals = torch.cat(chunks, dim=1)

        return phase_signals

    def gain_multiply(self, x, c0):
        # c0, the first BFCC, acts as a frame energy; map it to a linear gain over 160 samples per frame
        gain = 10**(0.5*c0/np.sqrt(18.0))
        gain = torch.repeat_interleave(gain, 160, dim=-1)
        gain = gain.reshape(gain.size(0), 1, -1).squeeze(1)

        return x * gain

    def forward(self, pitch_period, bfcc_with_corr, x0):
        # normalize the history frame and prepend its log energy for the continuation network
        norm_x0 = torch.norm(x0, 2, dim=-1, keepdim=True)
        x0 = x0 / torch.sqrt(1e-8 + norm_x0**2)
        x0 = torch.cat((torch.log(norm_x0 + 1e-7), x0), dim=-1)

        p_embed = self.create_phase_signals(pitch_period).permute(0, 2, 1).contiguous()

        envelope = self.bfcc_with_corr_upsampler(bfcc_with_corr.permute(0, 2, 1).contiguous())

        feat_in = torch.cat((p_embed, envelope), dim=1)

        wav_latent1 = self.feat_in_nl1(self.feat_in_conv1(feat_in).permute(0, 2, 1).contiguous())

        cont_latent = self.cont_net(x0)

        rnn_out = self.rnn(wav_latent1, cont_latent)

        fwc1_out = self.fwc1(rnn_out, cont_latent)
        fwc2_out = self.fwc2(fwc1_out, cont_latent)
        fwc3_out = self.fwc3(fwc2_out, cont_latent)
        fwc4_out = self.fwc4(fwc3_out, cont_latent)
        fwc5_out = self.fwc5(fwc4_out, cont_latent)
        fwc6_out = self.fwc6(fwc5_out, cont_latent)
        fwc7_out = self.fwc7(fwc6_out, cont_latent)

        waveform = fwc7_out.reshape(fwc7_out.size(0), 1, -1).squeeze(1)

        waveform = self.gain_multiply(waveform, bfcc_with_corr[:, :, :1])

        return waveform
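
A sanity sketch for the model above (assumed shapes and dtypes mirroring inference.py; not part of this commit): with a constant pitch period the concatenated sin features should trace one continuous sinusoid thanks to the phase0 accumulation, and every feature frame should yield 160 output samples:

    model = FWGAN400ContLarge()

    # phase continuity across frames for a constant period of 50 samples
    ps = model.create_phase_signals(torch.full((1, 4), 50, dtype=torch.long))   # (1, 16, 80)
    sin_flat = ps[..., :40].reshape(-1)
    n = torch.arange(1, sin_flat.numel() + 1, dtype=torch.float)
    assert torch.allclose(sin_flat, torch.sin(2 * torch.pi * n / 50), atol=1e-4)

    # 160 samples per feature frame, zero history
    frames = 10
    waveform = model(torch.full((1, frames), 80, dtype=torch.long),
                     torch.randn(1, frames, 19),
                     torch.zeros(1, 320))
    assert waveform.shape == (1, 160 * frames)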

dnn/torch/fwgan/models/fwgan500.py (new file, 260 lines)
@@ -0,0 +1,260 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
import numpy as np


which_norm = weight_norm

#################### Definition of basic model components ####################

# Convolutional layer with 1 frame look-ahead (used for the feature PreCondNet)
class ConvLookahead(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size, dilation=1, groups=1, bias=False):
        super(ConvLookahead, self).__init__()
        torch.manual_seed(5)

        self.padding_left = (kernel_size - 2) * dilation
        self.padding_right = 1 * dilation

        self.conv = which_norm(nn.Conv1d(in_ch, out_ch, kernel_size, dilation=dilation, groups=groups, bias=bias))

        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear, nn.Embedding)):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x):
        x = F.pad(x, (self.padding_left, self.padding_right))
        conv_out = self.conv(x)
        return conv_out

# (modified) GLU activation layer definition
class GLU(nn.Module):
    def __init__(self, feat_size):
        super(GLU, self).__init__()

        torch.manual_seed(5)

        self.gate = which_norm(nn.Linear(feat_size, feat_size, bias=False))

        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear, nn.Embedding)):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x):
        out = torch.tanh(x) * torch.sigmoid(self.gate(x))

        return out

# GRU layer definition
class ContForwardGRU(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers=1):
        super(ContForwardGRU, self).__init__()

        torch.manual_seed(5)

        self.hidden_size = hidden_size

        # This initializes the layer with history audio samples for continuation.
        self.cont_fc = nn.Sequential(which_norm(nn.Linear(320, self.hidden_size, bias=False)),
                                     nn.Tanh())

        self.gru = nn.GRU(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers,
                          batch_first=True, bias=False)

        self.nl = GLU(self.hidden_size)

        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear, nn.Embedding)):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x, x0):
        self.gru.flatten_parameters()

        h0 = self.cont_fc(x0).unsqueeze(0)

        output, h0 = self.gru(x, h0)

        return self.nl(output)

# Frame-wise convolution layer definition
class ContFramewiseConv(torch.nn.Module):

    def __init__(self, frame_len, out_dim, frame_kernel_size=3, act='glu', causal=True):
        super(ContFramewiseConv, self).__init__()
        torch.manual_seed(5)

        self.frame_kernel_size = frame_kernel_size
        self.frame_len = frame_len

        if causal or (self.frame_kernel_size == 2):
            self.required_pad_left = (self.frame_kernel_size - 1) * self.frame_len
            self.required_pad_right = 0

            # This initializes the layer with history audio samples for continuation.
            self.cont_fc = nn.Sequential(which_norm(nn.Linear(320, self.required_pad_left, bias=False)),
                                         nn.Tanh())

        else:
            # Non-causal frame-wise convolution; currently unused.
            self.required_pad_left = (self.frame_kernel_size - 1)//2 * self.frame_len
            self.required_pad_right = (self.frame_kernel_size - 1)//2 * self.frame_len

        self.fc_input_dim = self.frame_kernel_size * self.frame_len
        self.fc_out_dim = out_dim

        if act == 'glu':
            self.fc = nn.Sequential(which_norm(nn.Linear(self.fc_input_dim, self.fc_out_dim, bias=False)),
                                    GLU(self.fc_out_dim))
        if act == 'tanh':
            self.fc = nn.Sequential(which_norm(nn.Linear(self.fc_input_dim, self.fc_out_dim, bias=False)),
                                    nn.Tanh())

        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear, nn.Embedding)):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x, x0):
        if self.frame_kernel_size == 1:
            return self.fc(x)

        x_flat = x.reshape(x.size(0), 1, -1)
        pad = self.cont_fc(x0).view(x0.size(0), 1, -1)
        x_flat_padded = torch.cat((pad, x_flat), dim=-1).unsqueeze(2)

        x_flat_padded_unfolded = F.unfold(x_flat_padded, kernel_size=(1, self.fc_input_dim),
                                          stride=self.frame_len).permute(0, 2, 1).contiguous()

        out = self.fc(x_flat_padded_unfolded)
        return out

########################### The complete model definition #################################

class FWGAN500Cont(nn.Module):
    def __init__(self):
        super().__init__()
        torch.manual_seed(5)

        # PrecondNet:
        self.bfcc_with_corr_upsampler = nn.Sequential(nn.ConvTranspose1d(19, 64, kernel_size=5, stride=5,
                                                                         padding=0, bias=False),
                                                      nn.Tanh())

        self.feat_in_conv = ConvLookahead(128, 256, kernel_size=5)
        self.feat_in_nl = GLU(256)

        # GRU:
        self.rnn = ContForwardGRU(256, 256)

        # Frame-wise convolution stack:
        self.fwc1 = ContFramewiseConv(256, 256)
        self.fwc2 = ContFramewiseConv(256, 128)
        self.fwc3 = ContFramewiseConv(128, 128)
        self.fwc4 = ContFramewiseConv(128, 64)
        self.fwc5 = ContFramewiseConv(64, 64)
        self.fwc6 = ContFramewiseConv(64, 32)
        self.fwc7 = ContFramewiseConv(32, 32, act='tanh')

        self.init_weights()
        self.count_parameters()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear, nn.Embedding)):
                nn.init.orthogonal_(m.weight.data)

    def count_parameters(self):
        num_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"Total number of {self.__class__.__name__} network parameters = {num_params}\n")

    def create_phase_signals(self, periods):
        batch_size = periods.size(0)
        progression = torch.arange(1, 160 + 1, dtype=periods.dtype, device=periods.device).view((1, -1))
        progression = torch.repeat_interleave(progression, batch_size, 0)

        phase0 = torch.zeros(batch_size, dtype=periods.dtype, device=periods.device).unsqueeze(-1)
        chunks = []
        for sframe in range(periods.size(1)):
            f = (2.0 * torch.pi / periods[:, sframe]).unsqueeze(-1)

            chunk_sin = torch.sin(f * progression + phase0)
            chunk_sin = chunk_sin.reshape(chunk_sin.size(0), -1, 32)

            chunk_cos = torch.cos(f * progression + phase0)
            chunk_cos = chunk_cos.reshape(chunk_cos.size(0), -1, 32)

            chunk = torch.cat((chunk_sin, chunk_cos), dim=-1)

            phase0 = phase0 + 160 * f

            chunks.append(chunk)

        phase_signals = torch.cat(chunks, dim=1)

        return phase_signals


    def gain_multiply(self, x, c0):
        gain = 10**(0.5*c0/np.sqrt(18.0))
        gain = torch.repeat_interleave(gain, 160, dim=-1)
        gain = gain.reshape(gain.size(0), 1, -1).squeeze(1)

        return x * gain

    def forward(self, pitch_period, bfcc_with_corr, x0):
        # This should create a latent representation of shape [batch_dim, 500 frames, 256 elements per frame]
        p_embed = self.create_phase_signals(pitch_period).permute(0, 2, 1).contiguous()
        envelope = self.bfcc_with_corr_upsampler(bfcc_with_corr.permute(0, 2, 1).contiguous())
        feat_in = torch.cat((p_embed, envelope), dim=1)
        wav_latent = self.feat_in_nl(self.feat_in_conv(feat_in).permute(0, 2, 1).contiguous())

        # Generation with continuation using history samples x0 starts from here:

        rnn_out = self.rnn(wav_latent, x0)

        fwc1_out = self.fwc1(rnn_out, x0)
        fwc2_out = self.fwc2(fwc1_out, x0)
        fwc3_out = self.fwc3(fwc2_out, x0)
        fwc4_out = self.fwc4(fwc3_out, x0)
        fwc5_out = self.fwc5(fwc4_out, x0)
        fwc6_out = self.fwc6(fwc5_out, x0)
        fwc7_out = self.fwc7(fwc6_out, x0)

        waveform_unscaled = fwc7_out.reshape(fwc7_out.size(0), 1, -1).squeeze(1)
        waveform = self.gain_multiply(waveform_unscaled, bfcc_with_corr[:, :, :1])

        return waveform
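
A matching sanity sketch for the 500 Hz variant (assumed shapes, not part of this commit); here the 320 history samples in x0 feed the continuation layers directly, and each feature frame again yields 160 samples (five subframes of 32):

    model = FWGAN500Cont()
    frames = 10
    waveform = model(torch.full((1, frames), 80, dtype=torch.long),
                     torch.randn(1, frames, 19),
                     torch.zeros(1, 320))
    assert waveform.shape == (1, 160 * frames)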