mirror of
https://github.com/xiph/opus.git
synced 2025-05-16 16:38:30 +00:00
15 lines
531 B
Python
15 lines
531 B
Python
import torch
|
|
from torch import nn
|
|
|
|
class DualFC(nn.Module):
    """Weighted sum of two tanh-activated linear projections of one input.

    The module owns two independent ``nn.Linear`` layers, ``dense1`` and
    ``dense2``, each mapping ``input_dim`` features to ``output_dim``
    features, plus two learnable one-element scale parameters, ``alpha``
    and ``beta``, both initialised to 0.5.
    """

    def __init__(self, input_dim: int, output_dim: int) -> None:
        """Build both linear branches and their mixing weights.

        Args:
            input_dim: ``in_features`` passed to both linear layers.
            output_dim: ``out_features`` passed to both linear layers.
        """
        super().__init__()

        # Two separately-parameterised projections of the same input.
        # dense1 is created before dense2 so seeded initialisation is stable.
        self.dense1 = nn.Linear(input_dim, output_dim)
        self.dense2 = nn.Linear(input_dim, output_dim)

        # One-element scales; they broadcast against each branch's output.
        self.alpha = nn.Parameter(torch.tensor([0.5]), requires_grad=True)
        self.beta = nn.Parameter(torch.tensor([0.5]), requires_grad=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``alpha * tanh(dense1(x)) + beta * tanh(dense2(x))``.

        Args:
            x: Tensor accepted by ``nn.Linear(input_dim, output_dim)``.

        Returns:
            The scaled sum of the two tanh-activated branch outputs.
        """
        first_branch = torch.tanh(self.dense1(x))
        second_branch = torch.tanh(self.dense2(x))
        return self.alpha * first_branch + self.beta * second_branch
|