"""
This module implements the SILK upsampler from 16kHz to 24 or 48 kHz
"""

import torch
from torch import nn
import torch.nn.functional as F

import numpy as np

# Table of 12 interpolation filters with 8 taps each. The integer taps are
# scaled by 2**-15 and stored as float32. The table is mirror symmetric
# (row k reversed equals row 11 - k).
# NOTE(review): the attribute names used in SilkUpsampler (row 0 -> "01_24",
# row 4 -> "09_24", row 8 -> "17_24") suggest that row k corresponds to a
# fractional offset of (2k + 1)/24 -- inferred from naming, confirm.
frac_fir = np.array(
    [
        [189, -600, 617, 30567, 2996, -1375, 425, -46],
        [117, -159, -1070, 29704, 5784, -2143, 611, -71],
        [52, 221, -2392, 28276, 8798, -2865, 773, -91],
        [-4, 529, -3350, 26341, 11950, -3487, 896, -103],
        [-48, 758, -3956, 23973, 15143, -3957, 967, -107],
        [-80, 905, -4235, 21254, 18278, -4222, 972, -99],
        [-99, 972, -4222, 18278, 21254, -4235, 905, -80],
        [-107, 967, -3957, 15143, 23973, -3956, 758, -48],
        [-103, 896, -3487, 11950, 26341, -3350, 529, -4],
        [-91, 773, -2865, 8798, 28276, -2392, 221, 52],
        [-71, 611, -2143, 5784, 29704, -1070, -159, 117],
        [-46, 425, -1375, 2996, 30567, 617, -600, 189],
    ],
    dtype=np.float32
) / 2**15

# Coefficients of the two branches of the 2x upsampler, scaled by 2**-16.
# The third entry of each list is stored with 65536 subtracted (so it is
# negative); get_impz adds 1 back before using it.
hq_2x_up_c_even = [x / 2**16 for x in [1746, 14986, 39083 - 65536]]
hq_2x_up_c_odd = [x / 2**16 for x in [6854, 25769, 55542 - 65536]]


def get_impz(coeffs, n):
    """Return the first `n` impulse-response samples of a 3-section cascade.

    Each section is a first-order allpass with transfer function
    (g + z^-1) / (1 + g z^-1). The section gains are coeffs[0], coeffs[1]
    and 1 + coeffs[2], applied in that order.

    Args:
        coeffs: sequence with (at least) three section coefficients.
        n: number of output samples to compute.

    Returns:
        numpy array of length `n` (default float64 from np.zeros).
    """
    gains = (coeffs[0], coeffs[1], 1 + coeffs[2])
    state = [0, 0, 0]
    y = np.zeros(n)

    excitation = 1  # unit impulse: 1 for the first sample, 0 afterwards
    for i in range(n):
        signal = excitation
        for k, gain in enumerate(gains):
            scaled = (signal - state[k]) * gain
            section_out = state[k] + scaled
            state[k] = signal + scaled
            signal = section_out
        y[i] = signal
        excitation = 0

    return y


class SilkUpsampler(nn.Module):
    """Fixed (non-trainable) upsampler from 16 kHz to 24 or 48 kHz.

    The signal is first doubled in rate by `hq_2x_up` and then resampled by
    a factor 3/2 by `interpolate_3_2`, giving three times the input rate.
    For a 24 kHz target every second sample of that result is kept.

    Inputs are expected as (batch, channels, samples); every channel is
    filtered independently with the same kernels.

    All filter kernels are registered as `nn.Parameter` with
    `requires_grad=False`, so they appear in `state_dict()` and follow the
    module across devices but are not trained.
    """

    SUPPORTED_TARGET_RATES = {24000, 48000}
    SUPPORTED_SOURCE_RATES = {16000}

    def __init__(self, fs_in=16000, fs_out=48000):
        """Build the filter kernels.

        Args:
            fs_in: source sampling rate in Hz; must be in
                SUPPORTED_SOURCE_RATES.
            fs_out: target sampling rate in Hz; must be in
                SUPPORTED_TARGET_RATES.

        Raises:
            ValueError: if either rate is not supported.
        """
        super().__init__()
        self.fs_in = fs_in
        self.fs_out = fs_out

        if fs_in not in self.SUPPORTED_SOURCE_RATES:
            raise ValueError(f'SilkUpsampler currently only supports upsampling from {self.SUPPORTED_SOURCE_RATES} Hz')

        if fs_out not in self.SUPPORTED_TARGET_RATES:
            raise ValueError(f'SilkUpsampler currently only supports upsampling to {self.SUPPORTED_TARGET_RATES} Hz')

        # hq 2x upsampler as FIR approximation: the recursive filter is
        # replaced by the first 128 samples of its impulse response. The
        # response is time-reversed because conv1d computes a
        # cross-correlation; .copy() makes the reversed view contiguous so
        # that torch.from_numpy accepts it.
        hq_2x_up_even = get_impz(hq_2x_up_c_even, 128)[::-1].copy()
        hq_2x_up_odd = get_impz(hq_2x_up_c_odd, 128)[::-1].copy()

        self.hq_2x_up_even = nn.Parameter(
            torch.from_numpy(hq_2x_up_even).float().view(1, 1, -1),
            requires_grad=False
        )
        self.hq_2x_up_odd = nn.Parameter(
            torch.from_numpy(hq_2x_up_odd).float().view(1, 1, -1),
            requires_grad=False
        )
        # 127 zeros on the left, none on the right: with a 128-tap kernel
        # the filter output has the same length as its input.
        self.hq_2x_up_padding = [127, 0]

        # interpolation filters
        # torch.from_numpy shares memory with the numpy array and frac_fir
        # is already float32, so without .clone() these parameters would be
        # views into the module-level table: an in-place update of one
        # instance (e.g. load_state_dict) would then modify frac_fir and
        # every other instance as well.
        self.frac_01_24 = nn.Parameter(
            torch.from_numpy(frac_fir[0]).clone().view(1, 1, -1),
            requires_grad=False
        )
        self.frac_17_24 = nn.Parameter(
            torch.from_numpy(frac_fir[8]).clone().view(1, 1, -1),
            requires_grad=False
        )
        self.frac_09_24 = nn.Parameter(
            torch.from_numpy(frac_fir[4]).clone().view(1, 1, -1),
            requires_grad=False
        )

        # decimation applied to the 3x signal in forward()
        self.stride = 1 if fs_out == 48000 else 2

    @staticmethod
    def _depthwise_fir(x, kernel, stride=1):
        """Filter every channel of `x` with the same (1, 1, taps) kernel."""
        num_channels = x.size(1)
        weight = torch.repeat_interleave(kernel, num_channels, 0)
        return F.conv1d(x, weight, stride=stride, groups=num_channels)

    def hq_2x_up(self, x):
        """Double the sampling rate of `x`.

        Args:
            x: tensor of shape (batch, channels, T).

        Returns:
            Tensor of shape (batch, channels, 2 * T) in which even-indexed
            samples come from the "even" filter and odd-indexed samples
            from the "odd" filter.
        """
        x_pad = F.pad(x, self.hq_2x_up_padding)

        y_even = self._depthwise_fir(x_pad, self.hq_2x_up_even)
        y_odd = self._depthwise_fir(x_pad, self.hq_2x_up_odd)

        # (batch, channels, T, 2) -> (batch, channels, 2 * T), interleaved
        return torch.stack((y_even, y_odd), dim=-1).flatten(2)

    def interpolate_3_2(self, x):
        """Resample `x` by a factor 3/2 (three outputs per two inputs).

        Args:
            x: tensor of shape (batch, channels, L).

        Returns:
            Tensor of shape (batch, channels, 3 * (L // 2)).
        """
        x_pad = F.pad(x, [8, 0])
        # Advancing the signal by one sample is done with a circular roll.
        # The element that wraps around to the end can only be reached by
        # the final filter position, whose outputs are dropped below.
        x_pad_sh1 = torch.roll(x_pad, -1, -1)

        y_01_24 = self._depthwise_fir(x_pad, self.frac_01_24, stride=2)
        y_17_24 = self._depthwise_fir(x_pad, self.frac_17_24, stride=2)
        y_09_24_sh1 = self._depthwise_fir(x_pad_sh1, self.frac_09_24, stride=2)

        # interleave the three phases along time
        y = torch.stack((y_01_24, y_17_24, y_09_24_sh1), dim=-1).flatten(2)

        # drop the last group of three samples
        return y[..., :-3]

    def forward(self, x):
        """Upsample `x` from `fs_in` to `fs_out`.

        Args:
            x: tensor of shape (batch, channels, T).

        Returns:
            Tensor of shape (batch, channels, 3 * T) when `fs_out` is 48000,
            or every second sample of that signal when `fs_out` is 24000.
        """
        y_2x = self.hq_2x_up(x)
        y_3x = self.interpolate_3_2(y_2x)

        return y_3x[:, :, ::self.stride]