mirror of
https://github.com/xiph/opus.git
synced 2025-06-05 23:10:54 +00:00
Adds skip connections
This commit is contained in:
parent
fb570ed8bb
commit
d54b9fb49a
1 changed file with 20 additions and 15 deletions
|
@@ -140,7 +140,7 @@ class FARGANSub(nn.Module):
|
||||||
|
|
||||||
print("has_gain:", self.has_gain)
|
print("has_gain:", self.has_gain)
|
||||||
print("passthrough_size:", self.passthrough_size)
|
print("passthrough_size:", self.passthrough_size)
|
||||||
self.sig_dense1 = nn.Linear(4*self.subframe_size+self.passthrough_size+self.cond_size+4, self.cond_size, bias=False)
|
self.sig_dense1 = nn.Linear(4*self.subframe_size+self.passthrough_size+self.cond_size, self.cond_size, bias=False)
|
||||||
self.sig_dense2 = nn.Linear(self.cond_size, self.cond_size, bias=False)
|
self.sig_dense2 = nn.Linear(self.cond_size, self.cond_size, bias=False)
|
||||||
self.gru1 = nn.GRUCell(self.cond_size, self.cond_size, bias=False)
|
self.gru1 = nn.GRUCell(self.cond_size, self.cond_size, bias=False)
|
||||||
self.gru2 = nn.GRUCell(self.cond_size, self.cond_size, bias=False)
|
self.gru2 = nn.GRUCell(self.cond_size, self.cond_size, bias=False)
|
||||||
|
@@ -151,11 +151,11 @@ class FARGANSub(nn.Module):
|
||||||
self.gru1_glu = GLU(self.cond_size)
|
self.gru1_glu = GLU(self.cond_size)
|
||||||
self.gru2_glu = GLU(self.cond_size)
|
self.gru2_glu = GLU(self.cond_size)
|
||||||
self.gru3_glu = GLU(self.cond_size)
|
self.gru3_glu = GLU(self.cond_size)
|
||||||
self.ptaps_dense = nn.Linear(self.cond_size, 5)
|
self.ptaps_dense = nn.Linear(4*self.cond_size, 5)
|
||||||
|
|
||||||
self.sig_dense_out = nn.Linear(self.cond_size, self.subframe_size+self.passthrough_size, bias=False)
|
self.sig_dense_out = nn.Linear(4*self.cond_size, self.subframe_size+self.passthrough_size, bias=False)
|
||||||
if self.has_gain:
|
if self.has_gain:
|
||||||
self.gain_dense_out = nn.Linear(self.cond_size, 1)
|
self.gain_dense_out = nn.Linear(4*self.cond_size, 1)
|
||||||
|
|
||||||
|
|
||||||
self.apply(init_weights)
|
self.apply(init_weights)
|
||||||
|
@@ -173,30 +173,35 @@ class FARGANSub(nn.Module):
|
||||||
pred = pred/(1e-5+gain)
|
pred = pred/(1e-5+gain)
|
||||||
|
|
||||||
prev = prev/(1e-5+gain)
|
prev = prev/(1e-5+gain)
|
||||||
#prev = prev*0
|
|
||||||
dump_signal(prev, 'pitch_exc.f32')
|
dump_signal(prev, 'pitch_exc.f32')
|
||||||
dump_signal(exc_mem, 'exc_mem.f32')
|
dump_signal(exc_mem, 'exc_mem.f32')
|
||||||
|
|
||||||
passthrough = states[3]
|
passthrough = states[3]
|
||||||
tmp = torch.cat((cond, pred, prev, passthrough, phase), 1)
|
tmp = torch.cat((cond, pred[:,2:-2], prev, passthrough, phase), 1)
|
||||||
|
|
||||||
tmp = self.dense1_glu(torch.tanh(self.sig_dense1(tmp)))
|
tmp = self.dense1_glu(torch.tanh(self.sig_dense1(tmp)))
|
||||||
tmp = self.dense2_glu(torch.tanh(self.sig_dense2(tmp)))
|
dense2_out = self.dense2_glu(torch.tanh(self.sig_dense2(tmp)))
|
||||||
gru1_state = self.gru1(tmp, states[0])
|
gru1_state = self.gru1(dense2_out, states[0])
|
||||||
gru2_state = self.gru2(self.gru1_glu(gru1_state), states[1])
|
gru1_out = self.gru1_glu(gru1_state)
|
||||||
gru3_state = self.gru3(self.gru2_glu(gru2_state), states[2])
|
#gru1_out = torch.cat([gru1_out, fpitch], 1)
|
||||||
|
gru2_state = self.gru2(gru1_out, states[1])
|
||||||
|
gru2_out = self.gru2_glu(gru2_state)
|
||||||
|
#gru2_out = torch.cat([gru2_out, fpitch], 1)
|
||||||
|
gru3_state = self.gru3(gru2_out, states[2])
|
||||||
gru3_out = self.gru3_glu(gru3_state)
|
gru3_out = self.gru3_glu(gru3_state)
|
||||||
|
gru3_out = torch.cat([gru1_out, gru2_out, gru3_out, dense2_out], 1)
|
||||||
sig_out = torch.tanh(self.sig_dense_out(gru3_out))
|
sig_out = torch.tanh(self.sig_dense_out(gru3_out))
|
||||||
if self.passthrough_size != 0:
|
if self.passthrough_size != 0:
|
||||||
passthrough = sig_out[:,self.subframe_size:]
|
passthrough = sig_out[:,self.subframe_size:]
|
||||||
sig_out = sig_out[:,:self.subframe_size]
|
sig_out = sig_out[:,:self.subframe_size]
|
||||||
dump_signal(sig_out, 'exc_out.f32')
|
dump_signal(sig_out, 'exc_out.f32')
|
||||||
|
taps = self.ptaps_dense(gru3_out)
|
||||||
|
taps = .2*taps + torch.exp(taps)
|
||||||
|
taps = taps / (1e-2 + torch.sum(torch.abs(taps), dim=-1, keepdim=True))
|
||||||
|
dump_signal(taps, 'taps.f32')
|
||||||
|
fpitch = taps[:,0:1]*pred[:,:-4] + taps[:,1:2]*pred[:,1:-3] + taps[:,2:3]*pred[:,2:-2] + taps[:,3:4]*pred[:,3:-1] + taps[:,4:]*pred[:,4:]
|
||||||
|
|
||||||
if self.has_gain:
|
if self.has_gain:
|
||||||
taps = self.ptaps_dense(gru3_out)
|
|
||||||
taps = .2*taps + torch.exp(taps)
|
|
||||||
taps = taps / (1e-2 + torch.sum(torch.abs(taps), dim=-1, keepdim=True))
|
|
||||||
dump_signal(taps, 'taps.f32')
|
|
||||||
fpitch = taps[:,0:1]*pred[:,:-4] + taps[:,1:2]*pred[:,1:-3] + taps[:,2:3]*pred[:,2:-2] + taps[:,3:4]*pred[:,3:-1] + taps[:,4:]*pred[:,4:]
|
|
||||||
pitch_gain = torch.exp(self.gain_dense_out(gru3_out))
|
pitch_gain = torch.exp(self.gain_dense_out(gru3_out))
|
||||||
dump_signal(pitch_gain, 'pgain.f32')
|
dump_signal(pitch_gain, 'pgain.f32')
|
||||||
sig_out = (sig_out + pitch_gain*fpitch) * gain
|
sig_out = (sig_out + pitch_gain*fpitch) * gain
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue