mirror of
https://github.com/xiph/opus.git
synced 2025-05-29 22:57:41 +00:00
27 lines
No EOL
754 B
Python
27 lines
No EOL
754 B
Python
import numpy as np
|
|
import scipy.signal
|
|
import torch
|
|
from torch import nn
|
|
import torch.nn.functional as F
|
|
|
|
|
|
class FIR(nn.Module):
    """Fixed (non-trainable) FIR filter applied independently to every channel.

    The taps are designed once at construction time with
    ``scipy.signal.firls`` and are then applied with a grouped 1-D
    convolution, so each input channel is filtered by the same kernel.

    Args:
        numtaps: Number of filter taps. ``firls`` requires an odd count, so an
            even value is bumped to ``numtaps + 1`` (with a printed warning).
        bands: Band edges, forwarded unchanged to ``scipy.signal.firls``.
        desired: Desired gain at each band edge, forwarded unchanged to
            ``scipy.signal.firls``.
        fs: Sampling frequency forwarded to ``scipy.signal.firls``; with the
            default of 2 the band edges are expressed relative to Nyquist = 1.
    """

    def __init__(self, numtaps, bands, desired, fs=2):
        super().__init__()

        if numtaps % 2 == 0:
            print(f"warning: numtaps must be odd, increasing numtaps to {numtaps + 1}")
            numtaps += 1

        a = scipy.signal.firls(numtaps, bands, desired, fs=fs)

        # Register the taps as a buffer so that Module.to()/.cuda()/.half()
        # carry them along with the rest of the model.  persistent=False keeps
        # them out of state_dict, so checkpoints saved before this change
        # (when `weight` was a plain attribute) still load with strict=True.
        self.register_buffer(
            "weight",
            torch.from_numpy(a.astype(np.float32)),
            persistent=False,
        )

    def forward(self, x):
        """Filter ``x`` along its last dimension.

        Args:
            x: Tensor laid out as (batch, channels, time), as expected by
                ``F.conv1d``.

        Returns:
            Tensor of shape (batch, channels, time - numtaps + 1); no padding
            is applied, so the output is shorter than the input.
        """
        num_channels = x.size(1)

        # Align the taps with the input so the module also works when the
        # caller feeds a tensor on another device or in another float dtype
        # without having moved the module first (no-op when they already match).
        taps = self.weight.to(device=x.device, dtype=x.dtype)

        # One identical single-channel kernel per input channel; groups=C makes
        # conv1d apply kernel i to channel i only (depthwise filtering).
        weight = torch.repeat_interleave(taps.view(1, 1, -1), num_channels, 0)

        y = F.conv1d(x, weight, groups=num_channels)

        return y