opus/dnn/torch/osce/utils/layers/fir.py
2024-04-24 12:17:51 +02:00

27 lines
No EOL
754 B
Python

import numpy as np
import scipy.signal
import torch
from torch import nn
import torch.nn.functional as F
class FIR(nn.Module):
    """Fixed (non-trainable) least-squares FIR filter applied per channel.

    The taps are designed once at construction time with
    ``scipy.signal.firls`` and then applied to every input channel
    independently through a grouped 1-D convolution.

    Args:
        numtaps: Number of filter taps. ``firls`` requires an odd count, so
            an even value is incremented by one (a warning is printed).
        bands: Band edges, forwarded unchanged to ``scipy.signal.firls``.
        desired: Desired gain at each band edge, forwarded unchanged to
            ``scipy.signal.firls``.
        fs: Sampling frequency used to interpret ``bands``; forwarded
            unchanged to ``scipy.signal.firls``.
    """

    def __init__(self, numtaps, bands, desired, fs=2):
        super().__init__()

        if numtaps % 2 == 0:
            print(f"warning: numtaps must be odd, increasing numtaps to {numtaps + 1}")
            numtaps += 1

        a = scipy.signal.firls(numtaps, bands, desired, fs=fs)

        # Register the taps as a buffer so that ``.to()``, ``.cuda()``,
        # ``.double()`` etc. carry them along with the rest of the model.
        # ``persistent=False`` keeps the state_dict identical to what it was
        # when the taps were a plain attribute, so old checkpoints still load.
        self.register_buffer(
            "weight",
            torch.from_numpy(a.astype(np.float32)),
            persistent=False,
        )

    def forward(self, x):
        """Filter ``x`` channel-wise with the designed taps.

        Args:
            x: Input tensor whose dimension 1 is the channel dimension and
                whose last dimension is filtered (the layout expected by
                ``F.conv1d``).

        Returns:
            The filtered tensor. No padding is applied, so the last
            dimension shrinks by ``numtaps - 1``.
        """
        num_channels = x.size(1)

        # Align the taps with the input so the convolution does not fail on
        # a device or dtype mismatch (e.g. CUDA or float64/float16 inputs).
        taps = self.weight.to(device=x.device, dtype=x.dtype)

        # One identical single-channel kernel per group: (channels, 1, taps).
        weight = torch.repeat_interleave(taps.view(1, 1, -1), num_channels, 0)

        y = F.conv1d(x, weight, groups=num_channels)

        return y