mirror of
https://github.com/xiph/opus.git
synced 2025-05-28 14:19:13 +00:00
added some bwe-related stuff
This commit is contained in:
parent
5667867fa2
commit
0dc559f060
3 changed files with 89 additions and 0 deletions
34
dnn/torch/osce/losses/td_lowpass.py
Normal file
34
dnn/torch/osce/losses/td_lowpass.py
Normal file
|
@ -0,0 +1,34 @@
|
|||
import torch
|
||||
import scipy.signal
|
||||
|
||||
|
||||
from utils.layers.fir import FIR
|
||||
|
||||
class TDLowpass(torch.nn.Module):
    """Loss on the low-pass filtered difference between two signals.

    A FIR low-pass filter is designed once with ``scipy.signal.firwin`` and
    applied to ``y_true - y_pred``; the loss is the mean of
    ``|filtered_difference ** power|``.

    Args:
        numtaps: number of filter taps, passed to ``scipy.signal.firwin``.
        cutoff: cutoff frequency, passed to ``scipy.signal.firwin`` (so it is
            interpreted with that function's default frequency scaling).
        power: exponent applied to the filtered difference before averaging.
    """

    def __init__(self, numtaps, cutoff, power=2):
        super().__init__()

        # Keep the numpy coefficients around for get_freqz().
        self.b = scipy.signal.firwin(numtaps, cutoff)

        # Register the kernel as a buffer so that it follows the module through
        # .to()/.cuda()/.double(); a plain tensor attribute would stay on the
        # CPU in float32 and make conv1d fail. persistent=False keeps the
        # state_dict identical to what it was with a plain attribute.
        self.register_buffer(
            "weight",
            torch.from_numpy(self.b).float().view(1, 1, -1),
            persistent=False,
        )
        self.power = power

    def forward(self, y_true, y_pred):
        """Return ``(loss, diff_lp)`` for two 3-dimensional tensors.

        The kernel has shape ``(1, 1, numtaps)``, so ``conv1d`` requires the
        inputs to have exactly one channel in dimension 1. No padding is
        applied, hence ``diff_lp`` is ``numtaps - 1`` samples shorter than the
        inputs along the last dimension.
        """
        assert len(y_true.shape) == 3 and len(y_pred.shape) == 3

        diff = y_true - y_pred
        diff_lp = torch.nn.functional.conv1d(diff, self.weight)

        loss = torch.mean(torch.abs(diff_lp ** self.power))

        return loss, diff_lp

    def get_freqz(self):
        """Return ``scipy.signal.freqz`` output (frequencies, response) of the filter."""
        freq, response = scipy.signal.freqz(self.b)

        return freq, response
|
||||
|
||||
|
||||
|
||||
|
||||
|
28
dnn/torch/osce/silk_16_to_48.py
Normal file
28
dnn/torch/osce/silk_16_to_48.py
Normal file
|
@ -0,0 +1,28 @@
|
|||
import argparse
|
||||
|
||||
from scipy.io import wavfile
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from utils.layers.silk_upsampler import SilkUpsampler
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("input", type=str, help="input wave file")
|
||||
parser.add_argument("output", type=str, help="output wave file")
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
fs, x = wavfile.read(args.input)
|
||||
|
||||
# being lazy for now
|
||||
assert fs == 16000 and x.dtype == np.int16
|
||||
|
||||
x = torch.from_numpy(x.astype(np.float32)).view(1, 1, -1)
|
||||
|
||||
upsampler = SilkUpsampler()
|
||||
y = upsampler(x)
|
||||
|
||||
y = y.squeeze().numpy().astype(np.int16)
|
||||
|
||||
wavfile.write(args.output, 48000, y[13:])
|
27
dnn/torch/osce/utils/layers/fir.py
Normal file
27
dnn/torch/osce/utils/layers/fir.py
Normal file
|
@ -0,0 +1,27 @@
|
|||
import numpy as np
|
||||
import scipy.signal
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class FIR(nn.Module):
    """Fixed FIR filter applied independently to every channel of the input.

    The coefficients are designed once with ``scipy.signal.firls`` and are not
    trainable.

    Args:
        numtaps: number of filter taps; an even value is increased by one
            (with a printed warning) before the design call.
        bands: band edges, passed to ``scipy.signal.firls``.
        desired: desired gains at the band edges, passed to ``scipy.signal.firls``.
        fs: sampling frequency, passed to ``scipy.signal.firls``.
    """

    def __init__(self, numtaps, bands, desired, fs=2):
        super().__init__()

        if numtaps % 2 == 0:
            print(f"warning: numtaps must be odd, increasing numtaps to {numtaps + 1}")
            numtaps += 1

        a = scipy.signal.firls(numtaps, bands, desired, fs=fs)

        # Register the taps as a buffer so they follow the module through
        # .to()/.cuda()/.double(); a plain tensor attribute would be left
        # behind and make conv1d fail. persistent=False keeps the state_dict
        # identical to what it was with a plain attribute.
        self.register_buffer(
            "weight",
            torch.from_numpy(a.astype(np.float32)),
            persistent=False,
        )

    def forward(self, x):
        """Filter ``x`` along its last dimension, one channel at a time.

        ``x.size(1)`` is taken as the channel count. No padding is applied, so
        the output is ``numtaps - 1`` samples shorter than the input.
        """
        num_channels = x.size(1)

        # One copy of the kernel per channel; groups=num_channels makes conv1d
        # apply kernel i to channel i only.
        weight = torch.repeat_interleave(self.weight.view(1, 1, -1), num_channels, 0)

        y = F.conv1d(x, weight, groups=num_channels)

        return y
|
Loading…
Add table
Add a link
Reference in a new issue