added FWGAN weight dumping code

Jan Buethe 2023-08-01 18:18:28 +02:00
parent 9691440a5f
commit 902d763622
5 changed files with 805 additions and 0 deletions


@@ -0,0 +1,89 @@
import os
import sys
import argparse

import torch
from torch import nn

sys.path.append(os.path.join(os.path.split(__file__)[0], '../weight-exchange'))
import wexchange.torch
import wexchange.c_export  # imported explicitly for the CWriter used below

from models import model_dict

# layers that remain unquantized even when --quantize is given
unquantized = [
    'feat_in_conv1.conv',
    'bfcc_with_corr_upsampler.fc',
    'cont_net.0',
    'fwc6.cont_fc.0',
    'fwc6.fc.0',
    'fwc6.fc.1.gate',
    'fwc7.cont_fc.0',
    'fwc7.fc.0',
    'fwc7.fc.1.gate'
]

description = f"""
This is an unsafe dumping script for FWGAN models. It assumes that all weights are contained in Linear, Conv1d or GRU layers
and will fail to export any other weights.

Furthermore, the quantize option relies on the following explicit list of layers to be excluded:
{unquantized}.

Modify this script manually if adjustments are needed.
"""

parser = argparse.ArgumentParser(description=description)
parser.add_argument('model', choices=['fwgan400', 'fwgan500'], help='model name')
parser.add_argument('weightfile', type=str, help='weight file path')
parser.add_argument('export_folder', type=str)
parser.add_argument('--export-filename', type=str, default='fwgan_data', help='filename for source and header file (.c and .h will be added), defaults to fwgan_data')
parser.add_argument('--struct-name', type=str, default='FWGAN', help='name for C struct, defaults to FWGAN')
parser.add_argument('--quantize', action='store_true', help='apply quantization')

if __name__ == "__main__":
    args = parser.parse_args()

    model = model_dict[args.model]()

    print(f"loading weights from {args.weightfile}...")
    saved_gen = torch.load(args.weightfile, map_location='cpu')
    model.load_state_dict(saved_gen)

    # strip the weight_norm reparametrization so that plain weights are exported
    def _remove_weight_norm(m):
        try:
            torch.nn.utils.remove_weight_norm(m)
        except ValueError:  # this module didn't have weight norm
            return
    model.apply(_remove_weight_norm)

    print("dumping model...")
    quantize_model = args.quantize
    output_folder = args.export_folder
    os.makedirs(output_folder, exist_ok=True)

    writer = wexchange.c_export.c_writer.CWriter(os.path.join(output_folder, args.export_filename), model_struct_name=args.struct_name)

    for name, module in model.named_modules():
        if quantize_model:
            quantize = name not in unquantized
            scale = None if quantize else 1/128
        else:
            quantize = False
            scale = 1/128

        if isinstance(module, nn.Linear):
            print(f"dumping linear layer {name}...")
            wexchange.torch.dump_torch_dense_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale)

        if isinstance(module, nn.Conv1d):
            print(f"dumping conv1d layer {name}...")
            wexchange.torch.dump_torch_conv1d_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale)

        if isinstance(module, nn.GRU):
            print(f"dumping GRU layer {name}...")
            wexchange.torch.dump_torch_gru_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale, recurrent_scale=scale)

    writer.close()
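
For a single layer, the export loop above reduces to the following minimal sketch. The output directory 'exported' and the stand-in Linear layer are placeholders chosen only for illustration; the sketch assumes the weight-exchange package is importable and uses only the calls already appearing in this script.

import os
import torch
from torch import nn
import wexchange.torch
import wexchange.c_export

os.makedirs('exported', exist_ok=True)
writer = wexchange.c_export.c_writer.CWriter(os.path.join('exported', 'fwgan_data'), model_struct_name='FWGAN')

layer = nn.Linear(64, 256, bias=False)  # stand-in for one of the model's Linear layers
wexchange.torch.dump_torch_dense_weights(writer, layer, 'example_dense', quantize=False, scale=1/128)

writer.close()  # finalizes the generated .c and .h files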


@@ -0,0 +1,141 @@
import os
import time
import torch
import numpy as np
from scipy import signal as si
from scipy.io import wavfile
import argparse

from models import model_dict

parser = argparse.ArgumentParser()
parser.add_argument('model', choices=['fwgan400', 'fwgan500'], help='model name')
parser.add_argument('weightfile', type=str, help='weight file')
parser.add_argument('input', type=str, help='input: feature file or folder with feature files')
parser.add_argument('output', type=str, help='output: wav file name or folder name, depending on input')


########################### Signal Processing Layers ###########################

def preemphasis(x, coef=-0.85):
    return si.lfilter(np.array([1.0, coef]), np.array([1.0]), x).astype('float32')

def deemphasis(x, coef=-0.85):
    return si.lfilter(np.array([1.0]), np.array([1.0, coef]), x).astype('float32')

gamma = 0.92
weighting_vector = np.array([gamma**i for i in range(16, 0, -1)])

def lpc_synthesis_one_frame(frame, filt, buffer, weighting_vector=np.ones(16)):
    out = np.zeros_like(frame)

    filt = np.flip(filt)
    inp = frame[:]

    for i in range(0, inp.shape[0]):
        s = inp[i] - np.dot(buffer*weighting_vector, filt)
        buffer[0] = s
        buffer = np.roll(buffer, -1)
        out[i] = s

    return out

def inverse_perceptual_weighting(pw_signal, filters, weighting_vector):
    # inverse perceptual weighting = H_preemph / W(z/gamma)
    pw_signal = preemphasis(pw_signal)
    signal = np.zeros_like(pw_signal)
    buffer = np.zeros(16)
    num_frames = pw_signal.shape[0] // 160
    assert num_frames == filters.shape[0]

    for frame_idx in range(0, num_frames):
        in_frame = pw_signal[frame_idx*160: (frame_idx+1)*160][:]
        out_sig_frame = lpc_synthesis_one_frame(in_frame, filters[frame_idx, :], buffer, weighting_vector)
        signal[frame_idx*160: (frame_idx+1)*160] = out_sig_frame[:]
        buffer[:] = out_sig_frame[-16:]

    return signal


def process_item(generator, feature_filename, output_filename, verbose=False):
    feat = np.memmap(feature_filename, dtype='float32', mode='r')
    num_feat_frames = len(feat) // 36
    feat = np.reshape(feat, (num_feat_frames, 36))

    bfcc = np.copy(feat[:, :18])
    corr = np.copy(feat[:, 19:20]) + 0.5
    bfcc_with_corr = torch.from_numpy(np.hstack((bfcc, corr))).type(torch.FloatTensor).unsqueeze(0)  # .to(device)

    period = torch.from_numpy((0.1 + 50 * np.copy(feat[:, 18:19]) + 100)
                              .astype('int32')).type(torch.long).view(1, -1)  # .to(device)

    lpc_filters = np.copy(feat[:, -16:])

    start_time = time.time()
    x1 = generator(period, bfcc_with_corr, torch.zeros(1, 320))  # this means the vocoder runs in complete synthesis mode with zero history audio frames
    end_time = time.time()
    total_time = end_time - start_time

    x1 = x1.squeeze(1).squeeze(0).detach().cpu().numpy()
    gen_seconds = len(x1) / 16000
    out = deemphasis(inverse_perceptual_weighting(x1, lpc_filters, weighting_vector))

    if verbose:
        print(f"Took {total_time:.3f}s to generate {len(x1)} samples ({gen_seconds}s) -> {gen_seconds/total_time:.2f}x real time")

    out = np.clip(np.round(2**15 * out), -2**15, 2**15 - 1).astype(np.int16)
    wavfile.write(output_filename, 16000, out)


########################### The inference loop over a folder containing LPCNet feature files #################################

if __name__ == "__main__":
    args = parser.parse_args()

    generator = model_dict[args.model]()

    # load the generator checkpoint
    saved_gen = torch.load(args.weightfile, map_location='cpu')
    generator.load_state_dict(saved_gen)

    # this is just to remove the weight_norm from the model layers as it's no longer needed
    def _remove_weight_norm(m):
        try:
            torch.nn.utils.remove_weight_norm(m)
        except ValueError:  # this module didn't have weight norm
            return
    generator.apply(_remove_weight_norm)

    # switch to inference (eval) mode
    generator = generator.eval()
    print('Successfully loaded the generator model ... starting generation:')

    if os.path.isdir(args.input):
        os.makedirs(args.output, exist_ok=True)
        for fn in os.listdir(args.input):
            print(f"processing input {fn}...")
            feature_filename = os.path.join(args.input, fn)
            output_filename = os.path.join(args.output, os.path.splitext(fn)[0] + f"_{args.model}.wav")
            process_item(generator, feature_filename, output_filename)
    else:
        process_item(generator, args.input, args.output)

    print("Finished!")


@@ -0,0 +1,7 @@
from .fwgan400 import FWGAN400ContLarge
from .fwgan500 import FWGAN500Cont

model_dict = {
    'fwgan400': FWGAN400ContLarge,
    'fwgan500': FWGAN500Cont
}


@@ -0,0 +1,308 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
import numpy as np

which_norm = weight_norm

#################### Definition of basic model components ####################

# Convolutional layer with 1 frame look-ahead (used for the feature PreCondNet)
class ConvLookahead(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size, dilation=1, groups=1, bias=False):
        super(ConvLookahead, self).__init__()
        torch.manual_seed(5)

        self.padding_left = (kernel_size - 2) * dilation
        self.padding_right = 1 * dilation

        self.conv = which_norm(nn.Conv1d(in_ch, out_ch, kernel_size, dilation=dilation, groups=groups, bias=bias))
        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d) or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x):
        x = F.pad(x, (self.padding_left, self.padding_right))
        conv_out = self.conv(x)
        return conv_out

# (modified) GLU activation layer definition
class GLU(nn.Module):
    def __init__(self, feat_size):
        super(GLU, self).__init__()
        torch.manual_seed(5)

        self.gate = which_norm(nn.Linear(feat_size, feat_size, bias=False))
        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d) \
                    or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x):
        out = torch.tanh(x) * torch.sigmoid(self.gate(x))
        return out

# GRU layer definition
class ContForwardGRU(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers=1):
        super(ContForwardGRU, self).__init__()
        torch.manual_seed(5)

        self.hidden_size = hidden_size
        # maps the continuation embedding to the initial GRU state
        self.cont_fc = nn.Sequential(which_norm(nn.Linear(64, self.hidden_size, bias=False)),
                                     nn.Tanh())
        self.gru = nn.GRU(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True,
                          bias=False)
        self.nl = GLU(self.hidden_size)
        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d) or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x, x0):
        self.gru.flatten_parameters()
        h0 = self.cont_fc(x0).unsqueeze(0)
        output, h0 = self.gru(x, h0)
        return self.nl(output)

# Framewise convolution layer definition
class ContFramewiseConv(torch.nn.Module):
    def __init__(self, frame_len, out_dim, frame_kernel_size=3, act='glu', causal=True):
        super(ContFramewiseConv, self).__init__()
        torch.manual_seed(5)

        self.frame_kernel_size = frame_kernel_size
        self.frame_len = frame_len

        if (causal == True) or (self.frame_kernel_size == 2):
            self.required_pad_left = (self.frame_kernel_size - 1) * self.frame_len
            self.required_pad_right = 0
            self.cont_fc = nn.Sequential(which_norm(nn.Linear(64, self.required_pad_left, bias=False)),
                                         nn.Tanh()
                                         )
        else:
            self.required_pad_left = (self.frame_kernel_size - 1)//2 * self.frame_len
            self.required_pad_right = (self.frame_kernel_size - 1)//2 * self.frame_len

        self.fc_input_dim = self.frame_kernel_size * self.frame_len
        self.fc_out_dim = out_dim

        if act == 'glu':
            self.fc = nn.Sequential(which_norm(nn.Linear(self.fc_input_dim, self.fc_out_dim, bias=False)),
                                    GLU(self.fc_out_dim)
                                    )
        if act == 'tanh':
            self.fc = nn.Sequential(which_norm(nn.Linear(self.fc_input_dim, self.fc_out_dim, bias=False)),
                                    nn.Tanh()
                                    )
        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d) or isinstance(m, nn.Linear) or \
                    isinstance(m, nn.Embedding):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x, x0):
        if self.frame_kernel_size == 1:
            return self.fc(x)

        x_flat = x.reshape(x.size(0), 1, -1)
        pad = self.cont_fc(x0).view(x0.size(0), 1, -1)
        x_flat_padded = torch.cat((pad, x_flat), dim=-1).unsqueeze(2)
        x_flat_padded_unfolded = F.unfold(x_flat_padded,
                                          kernel_size=(1, self.fc_input_dim), stride=self.frame_len).permute(0, 2, 1).contiguous()
        out = self.fc(x_flat_padded_unfolded)
        return out

# A fully-connected based upsampling layer definition
class UpsampleFC(nn.Module):
    def __init__(self, in_ch, out_ch, upsample_factor):
        super(UpsampleFC, self).__init__()
        torch.manual_seed(5)

        self.in_ch = in_ch
        self.out_ch = out_ch
        self.upsample_factor = upsample_factor
        self.fc = nn.Linear(in_ch, out_ch * upsample_factor, bias=False)
        self.nl = nn.Tanh()
        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d) or \
                    isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x):
        batch_size = x.size(0)
        x = x.permute(0, 2, 1)
        x = self.nl(self.fc(x))
        x = x.reshape((batch_size, -1, self.out_ch))
        x = x.permute(0, 2, 1)
        return x

########################### The complete model definition #################################

class FWGAN400ContLarge(nn.Module):
    def __init__(self):
        super().__init__()
        torch.manual_seed(5)

        self.bfcc_with_corr_upsampler = UpsampleFC(19, 80, 4)
        self.feat_in_conv1 = ConvLookahead(160, 256, kernel_size=5)
        self.feat_in_nl1 = GLU(256)

        self.cont_net = nn.Sequential(which_norm(nn.Linear(321, 160, bias=False)),
                                      nn.Tanh(),
                                      which_norm(nn.Linear(160, 160, bias=False)),
                                      nn.Tanh(),
                                      which_norm(nn.Linear(160, 80, bias=False)),
                                      nn.Tanh(),
                                      which_norm(nn.Linear(80, 80, bias=False)),
                                      nn.Tanh(),
                                      which_norm(nn.Linear(80, 64, bias=False)),
                                      nn.Tanh(),
                                      which_norm(nn.Linear(64, 64, bias=False)),
                                      nn.Tanh())

        self.rnn = ContForwardGRU(256, 256)

        self.fwc1 = ContFramewiseConv(256, 256)
        self.fwc2 = ContFramewiseConv(256, 128)
        self.fwc3 = ContFramewiseConv(128, 128)
        self.fwc4 = ContFramewiseConv(128, 64)
        self.fwc5 = ContFramewiseConv(64, 64)
        self.fwc6 = ContFramewiseConv(64, 40)
        self.fwc7 = ContFramewiseConv(40, 40)

        self.init_weights()
        self.count_parameters()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d) or isinstance(m, nn.Linear) or \
                    isinstance(m, nn.Embedding):
                nn.init.orthogonal_(m.weight.data)

    def count_parameters(self):
        num_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"Total number of {self.__class__.__name__} network parameters = {num_params}\n")

    def create_phase_signals(self, periods):
        batch_size = periods.size(0)
        progression = torch.arange(1, 160 + 1, dtype=periods.dtype, device=periods.device).view((1, -1))
        progression = torch.repeat_interleave(progression, batch_size, 0)

        phase0 = torch.zeros(batch_size, dtype=periods.dtype, device=periods.device).unsqueeze(-1)
        chunks = []
        for sframe in range(periods.size(1)):
            f = (2.0 * torch.pi / periods[:, sframe]).unsqueeze(-1)

            chunk_sin = torch.sin(f * progression + phase0)
            chunk_sin = chunk_sin.reshape(chunk_sin.size(0), -1, 40)

            chunk_cos = torch.cos(f * progression + phase0)
            chunk_cos = chunk_cos.reshape(chunk_cos.size(0), -1, 40)

            chunk = torch.cat((chunk_sin, chunk_cos), dim=-1)

            phase0 = phase0 + 160 * f

            chunks.append(chunk)

        phase_signals = torch.cat(chunks, dim=1)
        return phase_signals

    def gain_multiply(self, x, c0):
        gain = 10**(0.5*c0/np.sqrt(18.0))
        gain = torch.repeat_interleave(gain, 160, dim=-1)
        gain = gain.reshape(gain.size(0), 1, -1).squeeze(1)
        return x * gain

    def forward(self, pitch_period, bfcc_with_corr, x0):
        # continuation embedding from the history samples: log energy plus normalized samples
        norm_x0 = torch.norm(x0, 2, dim=-1, keepdim=True)
        x0 = x0 / torch.sqrt((1e-8) + norm_x0**2)
        x0 = torch.cat((torch.log(norm_x0 + 1e-7), x0), dim=-1)

        p_embed = self.create_phase_signals(pitch_period).permute(0, 2, 1).contiguous()
        envelope = self.bfcc_with_corr_upsampler(bfcc_with_corr.permute(0, 2, 1).contiguous())
        feat_in = torch.cat((p_embed, envelope), dim=1)

        wav_latent1 = self.feat_in_nl1(self.feat_in_conv1(feat_in).permute(0, 2, 1).contiguous())

        cont_latent = self.cont_net(x0)

        rnn_out = self.rnn(wav_latent1, cont_latent)

        fwc1_out = self.fwc1(rnn_out, cont_latent)
        fwc2_out = self.fwc2(fwc1_out, cont_latent)
        fwc3_out = self.fwc3(fwc2_out, cont_latent)
        fwc4_out = self.fwc4(fwc3_out, cont_latent)
        fwc5_out = self.fwc5(fwc4_out, cont_latent)
        fwc6_out = self.fwc6(fwc5_out, cont_latent)
        fwc7_out = self.fwc7(fwc6_out, cont_latent)

        waveform = fwc7_out.reshape(fwc7_out.size(0), 1, -1).squeeze(1)
        waveform = self.gain_multiply(waveform, bfcc_with_corr[:, :, :1])

        return waveform
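
A small dry run of the model with placeholder inputs (random features and a constant pitch period, chosen only for illustration; real inputs come from the feature files used by the inference script) makes the tensor shapes above concrete:

import torch

model = FWGAN400ContLarge()
periods = torch.full((1, 100), 80, dtype=torch.long)  # 100 feature frames of 160 samples each
bfcc_with_corr = torch.randn(1, 100, 19)              # 18 BFCCs + pitch correlation
x0 = torch.zeros(1, 320)                              # no history audio -> full synthesis mode
wav = model(periods, bfcc_with_corr, x0)
print(wav.shape)                                      # torch.Size([1, 16000]), i.e. 1 s at 16 kHz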


@@ -0,0 +1,260 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
import numpy as np

which_norm = weight_norm

#################### Definition of basic model components ####################

# Convolutional layer with 1 frame look-ahead (used for the feature PreCondNet)
class ConvLookahead(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size, dilation=1, groups=1, bias=False):
        super(ConvLookahead, self).__init__()
        torch.manual_seed(5)

        self.padding_left = (kernel_size - 2) * dilation
        self.padding_right = 1 * dilation

        self.conv = which_norm(nn.Conv1d(in_ch, out_ch, kernel_size, dilation=dilation, groups=groups, bias=bias))
        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d) or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x):
        x = F.pad(x, (self.padding_left, self.padding_right))
        conv_out = self.conv(x)
        return conv_out

# (modified) GLU activation layer definition
class GLU(nn.Module):
    def __init__(self, feat_size):
        super(GLU, self).__init__()
        torch.manual_seed(5)

        self.gate = which_norm(nn.Linear(feat_size, feat_size, bias=False))
        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d) \
                    or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x):
        out = torch.tanh(x) * torch.sigmoid(self.gate(x))
        return out

# GRU layer definition
class ContForwardGRU(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers=1):
        super(ContForwardGRU, self).__init__()
        torch.manual_seed(5)

        self.hidden_size = hidden_size

        # This initializes the layer with history audio samples for continuation.
        self.cont_fc = nn.Sequential(which_norm(nn.Linear(320, self.hidden_size, bias=False)),
                                     nn.Tanh())

        self.gru = nn.GRU(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True,
                          bias=False)
        self.nl = GLU(self.hidden_size)
        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d) or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x, x0):
        self.gru.flatten_parameters()
        h0 = self.cont_fc(x0).unsqueeze(0)
        output, h0 = self.gru(x, h0)
        return self.nl(output)

# Framewise convolution layer definition
class ContFramewiseConv(torch.nn.Module):
    def __init__(self, frame_len, out_dim, frame_kernel_size=3, act='glu', causal=True):
        super(ContFramewiseConv, self).__init__()
        torch.manual_seed(5)

        self.frame_kernel_size = frame_kernel_size
        self.frame_len = frame_len

        if (causal == True) or (self.frame_kernel_size == 2):
            self.required_pad_left = (self.frame_kernel_size - 1) * self.frame_len
            self.required_pad_right = 0

            # This initializes the layer with history audio samples for continuation.
            self.cont_fc = nn.Sequential(which_norm(nn.Linear(320, self.required_pad_left, bias=False)),
                                         nn.Tanh()
                                         )
        else:
            # This means non-causal frame-wise convolution. We don't use it at the moment.
            self.required_pad_left = (self.frame_kernel_size - 1)//2 * self.frame_len
            self.required_pad_right = (self.frame_kernel_size - 1)//2 * self.frame_len

        self.fc_input_dim = self.frame_kernel_size * self.frame_len
        self.fc_out_dim = out_dim

        if act == 'glu':
            self.fc = nn.Sequential(which_norm(nn.Linear(self.fc_input_dim, self.fc_out_dim, bias=False)),
                                    GLU(self.fc_out_dim)
                                    )
        if act == 'tanh':
            self.fc = nn.Sequential(which_norm(nn.Linear(self.fc_input_dim, self.fc_out_dim, bias=False)),
                                    nn.Tanh()
                                    )
        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d) or isinstance(m, nn.Linear) or \
                    isinstance(m, nn.Embedding):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x, x0):
        if self.frame_kernel_size == 1:
            return self.fc(x)

        x_flat = x.reshape(x.size(0), 1, -1)
        pad = self.cont_fc(x0).view(x0.size(0), 1, -1)
        x_flat_padded = torch.cat((pad, x_flat), dim=-1).unsqueeze(2)
        x_flat_padded_unfolded = F.unfold(x_flat_padded,
                                          kernel_size=(1, self.fc_input_dim), stride=self.frame_len).permute(0, 2, 1).contiguous()
        out = self.fc(x_flat_padded_unfolded)
        return out

########################### The complete model definition #################################

class FWGAN500Cont(nn.Module):
    def __init__(self):
        super().__init__()
        torch.manual_seed(5)

        # PrecondNet:
        self.bfcc_with_corr_upsampler = nn.Sequential(nn.ConvTranspose1d(19, 64, kernel_size=5, stride=5, padding=0,
                                                                         bias=False),
                                                      nn.Tanh())
        self.feat_in_conv = ConvLookahead(128, 256, kernel_size=5)
        self.feat_in_nl = GLU(256)

        # GRU:
        self.rnn = ContForwardGRU(256, 256)

        # Frame-wise convolution stack:
        self.fwc1 = ContFramewiseConv(256, 256)
        self.fwc2 = ContFramewiseConv(256, 128)
        self.fwc3 = ContFramewiseConv(128, 128)
        self.fwc4 = ContFramewiseConv(128, 64)
        self.fwc5 = ContFramewiseConv(64, 64)
        self.fwc6 = ContFramewiseConv(64, 32)
        self.fwc7 = ContFramewiseConv(32, 32, act='tanh')

        self.init_weights()
        self.count_parameters()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d) or isinstance(m, nn.Linear) or \
                    isinstance(m, nn.Embedding):
                nn.init.orthogonal_(m.weight.data)

    def count_parameters(self):
        num_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"Total number of {self.__class__.__name__} network parameters = {num_params}\n")

    def create_phase_signals(self, periods):
        batch_size = periods.size(0)
        progression = torch.arange(1, 160 + 1, dtype=periods.dtype, device=periods.device).view((1, -1))
        progression = torch.repeat_interleave(progression, batch_size, 0)

        phase0 = torch.zeros(batch_size, dtype=periods.dtype, device=periods.device).unsqueeze(-1)
        chunks = []
        for sframe in range(periods.size(1)):
            f = (2.0 * torch.pi / periods[:, sframe]).unsqueeze(-1)

            chunk_sin = torch.sin(f * progression + phase0)
            chunk_sin = chunk_sin.reshape(chunk_sin.size(0), -1, 32)

            chunk_cos = torch.cos(f * progression + phase0)
            chunk_cos = chunk_cos.reshape(chunk_cos.size(0), -1, 32)

            chunk = torch.cat((chunk_sin, chunk_cos), dim=-1)

            phase0 = phase0 + 160 * f

            chunks.append(chunk)

        phase_signals = torch.cat(chunks, dim=1)
        return phase_signals

    def gain_multiply(self, x, c0):
        gain = 10**(0.5*c0/np.sqrt(18.0))
        gain = torch.repeat_interleave(gain, 160, dim=-1)
        gain = gain.reshape(gain.size(0), 1, -1).squeeze(1)
        return x * gain

    def forward(self, pitch_period, bfcc_with_corr, x0):
        # This should create a latent representation of shape [Batch_dim, 500 frames, 256 elements per frame]
        p_embed = self.create_phase_signals(pitch_period).permute(0, 2, 1).contiguous()
        envelope = self.bfcc_with_corr_upsampler(bfcc_with_corr.permute(0, 2, 1).contiguous())
        feat_in = torch.cat((p_embed, envelope), dim=1)

        wav_latent = self.feat_in_nl(self.feat_in_conv(feat_in).permute(0, 2, 1).contiguous())

        # Generation with continuation using history samples x0 starts from here:
        rnn_out = self.rnn(wav_latent, x0)

        fwc1_out = self.fwc1(rnn_out, x0)
        fwc2_out = self.fwc2(fwc1_out, x0)
        fwc3_out = self.fwc3(fwc2_out, x0)
        fwc4_out = self.fwc4(fwc3_out, x0)
        fwc5_out = self.fwc5(fwc4_out, x0)
        fwc6_out = self.fwc6(fwc5_out, x0)
        fwc7_out = self.fwc7(fwc6_out, x0)

        waveform_unscaled = fwc7_out.reshape(fwc7_out.size(0), 1, -1).squeeze(1)
        waveform = self.gain_multiply(waveform_unscaled, bfcc_with_corr[:, :, :1])

        return waveform
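
Both models use the same continuous-phase pitch embedding; only the sub-frame size differs (40 samples in FWGAN400ContLarge, 32 here). A short sketch with an assumed constant period, chosen only for illustration, shows what create_phase_signals produces: sin and cos of 2*pi*n/p with the phase carried across frames, cut into sub-frames.

import torch

model = FWGAN500Cont()
p = 80
periods = torch.full((1, 3), p, dtype=torch.long)  # 3 feature frames of 160 samples
emb = model.create_phase_signals(periods)          # -> (1, 15, 64): 5 sub-frames per frame, sin || cos
n = torch.arange(1, 3 * 160 + 1, dtype=torch.float32)
ref = torch.sin(2.0 * torch.pi * n / p)
print(torch.allclose(emb[0, :, :32].reshape(-1), ref, atol=1e-4))  # True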