"""
|
|
/* Copyright (c) 2023 Amazon
|
|
Written by Jan Buethe */
|
|
/*
|
|
Redistribution and use in source and binary forms, with or without
|
|
modification, are permitted provided that the following conditions
|
|
are met:
|
|
|
|
- Redistributions of source code must retain the above copyright
|
|
notice, this list of conditions and the following disclaimer.
|
|
|
|
- Redistributions in binary form must reproduce the above copyright
|
|
notice, this list of conditions and the following disclaimer in the
|
|
documentation and/or other materials provided with the distribution.
|
|
|
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
|
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
|
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
|
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
|
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
|
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
|
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
|
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
|
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
|
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
|
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
*/
|
|
"""
|
|
|
|
import torch
from torch import nn
import torch.nn.functional as F

from utils.endoscopy import write_data


class LimitedAdaptiveComb1d(nn.Module):
    COUNTER = 1

    def __init__(self,
                 kernel_size,
                 feature_dim,
                 frame_size=160,
                 overlap_size=40,
                 padding=None,
                 max_lag=256,
                 name=None,
                 gain_limit_db=10,
                 global_gain_limits_db=[-6, 6],
                 norm_p=2,
                 **kwargs):
"""
|
|
|
|
Parameters:
|
|
-----------
|
|
|
|
feature_dim : int
|
|
dimension of features from which kernels, biases and gains are computed
|
|
|
|
frame_size : int, optional
|
|
frame size, defaults to 160
|
|
|
|
overlap_size : int, optional
|
|
overlap size for filter cross-fade. Cross-fade is done on the first overlap_size samples of every frame, defaults to 40
|
|
|
|
use_bias : bool, optional
|
|
if true, biases will be added to output channels. Defaults to True
|
|
|
|
padding : List[int, int], optional
|
|
left and right padding. Defaults to [(kernel_size - 1) // 2, kernel_size - 1 - (kernel_size - 1) // 2]
|
|
|
|
max_lag : int, optional
|
|
maximal pitch lag, defaults to 256
|
|
|
|
have_a0 : bool, optional
|
|
If true, the filter coefficient a0 will be learned as a positive gain (requires in_channels == out_channels). Otherwise, a0 is set to 0. Defaults to False
|
|
|
|
name: str or None, optional
|
|
specifies a name attribute for the module. If None the name is auto generated as comb_1d_COUNT, where COUNT is an instance counter for LimitedAdaptiveComb1d
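
        Example (illustrative only; the feature dimension, number of frames, lag value
        and tensor shapes below are made-up values, not defaults taken from any recipe):

            comb = LimitedAdaptiveComb1d(kernel_size=15, feature_dim=64)
            x        = torch.randn(1, 1, 320)                    # (batch, in_channels, num_samples)
            features = torch.randn(1, 2, 64)                     # (batch, num_frames, feature_dim)
            lags     = torch.full((1, 2), 80, dtype=torch.long)  # one pitch lag per frame
            y        = comb(x, features, lags)                   # (1, 1, 320)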

        """

        super(LimitedAdaptiveComb1d, self).__init__()

        self.in_channels = 1
        self.out_channels = 1
        self.feature_dim = feature_dim
        self.kernel_size = kernel_size
        self.frame_size = frame_size
        self.overlap_size = overlap_size
        self.max_lag = max_lag
        self.limit_db = gain_limit_db
        self.norm_p = norm_p

        if name is None:
            self.name = "limited_adaptive_comb1d_" + str(LimitedAdaptiveComb1d.COUNTER)
            LimitedAdaptiveComb1d.COUNTER += 1
        else:
            self.name = name

        # network for generating convolution weights
        self.conv_kernel = nn.Linear(feature_dim, kernel_size)

        # comb filter gain
        self.filter_gain = nn.Linear(feature_dim, 1)
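        # 0.11512925464970229 = ln(10) / 20, i.e. conversion from dB to the natural-log domain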
        self.log_gain_limit = gain_limit_db * 0.11512925464970229
        with torch.no_grad():
            self.filter_gain.bias[:] = max(0.1, 4 + self.log_gain_limit)

        self.global_filter_gain = nn.Linear(feature_dim, 1)
        log_min, log_max = global_gain_limits_db[0] * 0.11512925464970229, global_gain_limits_db[1] * 0.11512925464970229
        self.filter_gain_a = (log_max - log_min) / 2
        self.filter_gain_b = (log_max + log_min) / 2

        if padding is None:
            self.padding = [kernel_size // 2, kernel_size - 1 - kernel_size // 2]
        else:
            self.padding = padding

        self.overlap_win = nn.Parameter(.5 + .5 * torch.cos((torch.arange(self.overlap_size) + 0.5) * torch.pi / overlap_size), requires_grad=False)

    def forward(self, x, features, lags, debug=False):
        """ adaptive 1d comb filtering

        Parameters:
        -----------
        x : torch.tensor
            input signal of shape (batch_size, in_channels, num_samples)

        features : torch.tensor
            frame-wise features of shape (batch_size, num_frames, feature_dim)

        lags : torch.LongTensor
            frame-wise lags for comb-filtering
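
        Returns:
        --------
        torch.tensor
            comb-filtered signal of shape (batch_size, out_channels, num_samples)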

        """

        batch_size = x.size(0)
        num_frames = features.size(1)
        num_samples = x.size(2)
        frame_size = self.frame_size
        overlap_size = self.overlap_size
        kernel_size = self.kernel_size
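        # raised-cosine cross-fade windows: win1 fades in the new frame, win2 fades out the previous frame's tail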
        win1 = torch.flip(self.overlap_win, [0])
        win2 = self.overlap_win

        if num_samples // self.frame_size != num_frames:
            raise ValueError('non matching sizes in LimitedAdaptiveComb1d.forward')

        conv_kernels = self.conv_kernel(features).reshape((batch_size, num_frames, self.out_channels, self.in_channels, self.kernel_size))
        conv_kernels = conv_kernels / (1e-6 + torch.norm(conv_kernels, p=self.norm_p, dim=-1, keepdim=True))

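        # frame-wise gain is upper-bounded by exp(log_gain_limit) = 10^(gain_limit_db / 20);
        # the global gain is constrained via tanh to the range given by global_gain_limits_db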
        conv_gains = torch.exp(- torch.relu(self.filter_gain(features).permute(0, 2, 1)) + self.log_gain_limit)
        # calculate gains
        global_conv_gains = torch.exp(self.filter_gain_a * torch.tanh(self.global_filter_gain(features).permute(0, 2, 1)) + self.filter_gain_b)

        if debug and batch_size == 1:
            key = self.name + "_gains"
            write_data(key, conv_gains.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
            key = self.name + "_kernels"
            write_data(key, conv_kernels.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
            key = self.name + "_lags"
            write_data(key, lags.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
            key = self.name + "_global_conv_gains"
            write_data(key, global_conv_gains.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)

        # frame-wise convolution with overlap-add
        output_frames = []
        overlap_mem = torch.zeros((batch_size, self.out_channels, self.overlap_size), device=x.device)
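        # pad for the filter kernel, then prepend max_lag look-back samples and append overlap_size look-ahead samples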
        x = F.pad(x, self.padding)
        x = F.pad(x, [self.max_lag, self.overlap_size])

        idx = torch.arange(frame_size + kernel_size - 1 + overlap_size).to(x.device).view(1, 1, -1)
        idx = torch.repeat_interleave(idx, batch_size, 0)
        idx = torch.repeat_interleave(idx, self.in_channels, 1)

        for i in range(num_frames):
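            # gather the lag-shifted input segment for frame i (frame start shifted back by the frame's pitch lag)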
            cidx = idx + i * frame_size + self.max_lag - lags[..., i].view(batch_size, 1, 1)
            xx = torch.gather(x, -1, cidx).reshape((1, batch_size * self.in_channels, -1))

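            # filter all batch elements in one call: grouped conv1d with one group per batch element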
            new_chunk = torch.conv1d(xx, conv_kernels[:, i, ...].reshape((batch_size * self.out_channels, self.in_channels, self.kernel_size)), groups=batch_size).reshape(batch_size, self.out_channels, -1)

            offset = self.max_lag + self.padding[0]
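            # comb filter output: global gain * (limited gain * filtered lag-shifted signal + dry input)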
            new_chunk = global_conv_gains[:, :, i : i + 1] * (new_chunk * conv_gains[:, :, i : i + 1] + x[..., offset + i * frame_size : offset + (i + 1) * frame_size + overlap_size])

            # overlapping part
            output_frames.append(new_chunk[:, :, : overlap_size] * win1 + overlap_mem * win2)

            # non-overlapping part
            output_frames.append(new_chunk[:, :, overlap_size : frame_size])

            # mem for next frame
            overlap_mem = new_chunk[:, :, frame_size :]

        # concatenate chunks
        output = torch.cat(output_frames, dim=-1)

        return output

    def flop_count(self, rate):
        frame_rate = rate / self.frame_size
        overlap = self.overlap_size
        overhead = overlap / self.frame_size

        count = 0

        # kernel computation and filtering
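        # (multiply-accumulates are counted as 2 flops; the (1 + overhead) factor accounts for the overlap samples that are filtered in every frame)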
        count += 2 * (frame_rate * self.feature_dim * self.kernel_size)
        count += 2 * (self.in_channels * self.out_channels * self.kernel_size * (1 + overhead) * rate)
        count += 2 * (frame_rate * self.feature_dim * self.out_channels) + rate * (1 + overhead) * self.out_channels

        # a0 computation
        count += 2 * (frame_rate * self.feature_dim * self.out_channels) + rate * (1 + overhead) * self.out_channels

        # windowing
        count += overlap * frame_rate * 3 * self.out_channels

        return count