removed trailing whitespace in fargan

Signed-off-by: Jan Buethe <jbuethe@amazon.de>
This commit is contained in:
Jan Buethe 2023-09-13 16:57:28 +02:00
parent e7beaec3fb
commit 82f48d368b
No known key found for this signature in database
GPG key ID: 9E32027A35B36314
5 changed files with 34 additions and 35 deletions

View file

@ -81,7 +81,7 @@ def gen_phase_embedding(periods, frame_size):
class GLU(nn.Module):
def __init__(self, feat_size):
super(GLU, self).__init__()
torch.manual_seed(5)
self.gate = weight_norm(nn.Linear(feat_size, feat_size, bias=False))
@ -89,16 +89,16 @@ class GLU(nn.Module):
self.init_weights()
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d)\
or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
nn.init.orthogonal_(m.weight.data)
def forward(self, x):
out = x * torch.sigmoid(self.gate(x))
out = x * torch.sigmoid(self.gate(x))
return out
class FWConv(nn.Module):
@ -160,21 +160,21 @@ class FARGANSub(nn.Module):
self.subframe_size = subframe_size
self.nb_subframes = nb_subframes
self.cond_size = cond_size
#self.sig_dense1 = nn.Linear(4*self.subframe_size+self.passthrough_size+self.cond_size, self.cond_size, bias=False)
self.fwc0 = FWConv(4*self.subframe_size+80, self.cond_size)
self.sig_dense2 = nn.Linear(self.cond_size, self.cond_size, bias=False)
self.gru1 = nn.GRUCell(self.cond_size, self.cond_size, bias=False)
self.gru2 = nn.GRUCell(self.cond_size, self.cond_size, bias=False)
self.gru3 = nn.GRUCell(self.cond_size, self.cond_size, bias=False)
self.dense1_glu = GLU(self.cond_size)
self.dense2_glu = GLU(self.cond_size)
self.gru1_glu = GLU(self.cond_size)
self.gru2_glu = GLU(self.cond_size)
self.gru3_glu = GLU(self.cond_size)
self.ptaps_dense = nn.Linear(4*self.cond_size, 5)
self.sig_dense_out = nn.Linear(4*self.cond_size, self.subframe_size, bias=False)
self.gain_dense_out = nn.Linear(4*self.cond_size, 1)
@ -184,7 +184,7 @@ class FARGANSub(nn.Module):
def forward(self, cond, prev, exc_mem, phase, period, states, gain=None):
device = exc_mem.device
#print(cond.shape, prev.shape)
dump_signal(prev, 'prev_in.f32')
idx = 256-torch.clamp(period[:,None], min=self.subframe_size+2, max=254)
@ -283,4 +283,3 @@ class FARGAN(nn.Module):
prev = out
states = [s.detach() for s in states]
return sig, states