Updated LACE and NoLACE models to version 2

This commit is contained in:
Jan Buethe 2023-12-18 12:19:55 +01:00
parent 4f311a1ad4
commit 299e38cab7
No known key found for this signature in database
GPG key ID: 9E32027A35B36314
57 changed files with 4793 additions and 109 deletions

View file

@@ -9,7 +9,7 @@ set -e
srcdir=`dirname $0`
test -n "$srcdir" && cd "$srcdir"
-dnn/download_model.sh caca188
+dnn/download_model.sh 88477f4
echo "Updating build configuration files, please wait...."

View file

@@ -340,7 +340,8 @@ void adashape_process_frame(
    float *x_out,
    const float *x_in,
    const float *features,
-    const LinearLayer *alpha1,
+    const LinearLayer *alpha1f,
+    const LinearLayer *alpha1t,
const LinearLayer *alpha2,
int feature_dim,
int frame_size,
@@ -350,6 +351,7 @@ void adashape_process_frame(
{
    float in_buffer[ADASHAPE_MAX_INPUT_DIM + ADASHAPE_MAX_FRAME_SIZE];
    float out_buffer[ADASHAPE_MAX_FRAME_SIZE];
+    float tmp_buffer[ADASHAPE_MAX_FRAME_SIZE];
int i, k;
int tenv_size;
float mean;
@@ -389,14 +391,16 @@ void adashape_process_frame(
#ifdef DEBUG_NNDSP
    print_float_vector("alpha1_in", in_buffer, feature_dim + tenv_size + 1);
#endif
-    compute_generic_conv1d(alpha1, out_buffer, hAdaShape->conv_alpha1_state, in_buffer, feature_dim + tenv_size + 1, ACTIVATION_LINEAR, arch);
+    compute_generic_conv1d(alpha1f, out_buffer, hAdaShape->conv_alpha1f_state, in_buffer, feature_dim, ACTIVATION_LINEAR, arch);
+    compute_generic_conv1d(alpha1t, tmp_buffer, hAdaShape->conv_alpha1t_state, tenv, tenv_size + 1, ACTIVATION_LINEAR, arch);
#ifdef DEBUG_NNDSP
    print_float_vector("alpha1_out", out_buffer, frame_size);
#endif
    /* compute leaky ReLU by hand. ToDo: try tanh activation */
    for (i = 0; i < frame_size; i ++)
    {
-        in_buffer[i] = out_buffer[i] >= 0 ? out_buffer[i] : 0.2f * out_buffer[i];
+        float tmp = out_buffer[i] + tmp_buffer[i];
+        in_buffer[i] = tmp >= 0 ? tmp : 0.2f * tmp;
}
#ifdef DEBUG_NNDSP
print_float_vector("post_alpha1", in_buffer, frame_size);

View file

@@ -71,7 +71,8 @@ typedef struct {
typedef struct {
-    float conv_alpha1_state[ADASHAPE_MAX_INPUT_DIM];
+    float conv_alpha1f_state[ADASHAPE_MAX_INPUT_DIM];
+    float conv_alpha1t_state[ADASHAPE_MAX_INPUT_DIM];
float conv_alpha2_state[ADASHAPE_MAX_FRAME_SIZE];
} AdaShapeState;
@@ -130,7 +131,8 @@ void adashape_process_frame(
    float *x_out,
    const float *x_in,
    const float *features,
-    const LinearLayer *alpha1,
+    const LinearLayer *alpha1f,
+    const LinearLayer *alpha1t,
const LinearLayer *alpha2,
int feature_dim,
int frame_size,

View file

@@ -155,7 +155,7 @@ static void lace_feature_net(
        &hLACE->layers.lace_fnet_tconv,
        output_buffer,
        input_buffer,
-        ACTIVATION_LINEAR,
+        ACTIVATION_TANH,
arch
);
@@ -426,7 +426,7 @@ static void nolace_feature_net(
        &hNoLACE->layers.nolace_fnet_tconv,
        output_buffer,
        input_buffer,
-        ACTIVATION_LINEAR,
+        ACTIVATION_TANH,
arch
);
@@ -633,7 +633,8 @@ static void nolace_process_20ms_frame(
            x_buffer2 + i_subframe * NOLACE_AF1_OUT_CHANNELS * NOLACE_FRAME_SIZE + NOLACE_FRAME_SIZE,
            x_buffer2 + i_subframe * NOLACE_AF1_OUT_CHANNELS * NOLACE_FRAME_SIZE + NOLACE_FRAME_SIZE,
            feature_buffer + i_subframe * NOLACE_COND_DIM,
-            &layers->nolace_tdshape1_alpha1,
+            &layers->nolace_tdshape1_alpha1_f,
+            &layers->nolace_tdshape1_alpha1_t,
&layers->nolace_tdshape1_alpha2,
NOLACE_TDSHAPE1_FEATURE_DIM,
NOLACE_TDSHAPE1_FRAME_SIZE,
@@ -688,7 +689,8 @@ static void nolace_process_20ms_frame(
            x_buffer1 + i_subframe * NOLACE_AF2_OUT_CHANNELS * NOLACE_FRAME_SIZE + NOLACE_FRAME_SIZE,
            x_buffer1 + i_subframe * NOLACE_AF2_OUT_CHANNELS * NOLACE_FRAME_SIZE + NOLACE_FRAME_SIZE,
            feature_buffer + i_subframe * NOLACE_COND_DIM,
-            &layers->nolace_tdshape2_alpha1,
+            &layers->nolace_tdshape2_alpha1_f,
+            &layers->nolace_tdshape2_alpha1_t,
&layers->nolace_tdshape2_alpha2,
NOLACE_TDSHAPE2_FEATURE_DIM,
NOLACE_TDSHAPE2_FRAME_SIZE,
@@ -739,7 +741,8 @@ static void nolace_process_20ms_frame(
            x_buffer2 + i_subframe * NOLACE_AF3_OUT_CHANNELS * NOLACE_FRAME_SIZE + NOLACE_FRAME_SIZE,
            x_buffer2 + i_subframe * NOLACE_AF3_OUT_CHANNELS * NOLACE_FRAME_SIZE + NOLACE_FRAME_SIZE,
            feature_buffer + i_subframe * NOLACE_COND_DIM,
-            &layers->nolace_tdshape3_alpha1,
+            &layers->nolace_tdshape3_alpha1_f,
+            &layers->nolace_tdshape3_alpha1_t,
&layers->nolace_tdshape3_alpha2,
NOLACE_TDSHAPE3_FEATURE_DIM,
NOLACE_TDSHAPE3_FRAME_SIZE,
@@ -884,7 +887,7 @@ int osce_load_models(OSCEModel *model, const unsigned char *data, int len)
    if (ret == 0) {ret = init_lace(&model->lace, list);}
#endif
-#ifndef DISABLE_LACE
+#ifndef DISABLE_NOLACE
if (ret == 0) {ret = init_nolace(&model->nolace, list);}
#endif
@@ -898,7 +901,7 @@ int osce_load_models(OSCEModel *model, const unsigned char *data, int len)
    if (ret == 0) {ret = init_lace(&model->lace, lacelayers_arrays);}
#endif
-#ifndef DISABLE_LACE
+#ifndef DISABLE_NOLACE
if (ret == 0) {ret = init_nolace(&model->nolace, nolacelayers_arrays);}
#endif

View file

@@ -41,7 +41,7 @@
#define OSCE_PREEMPH 0.85f
-#define OSCE_PITCH_HANGOVER 8
+#define OSCE_PITCH_HANGOVER 0
#define OSCE_CLEAN_SPEC_START 0
#define OSCE_CLEAN_SPEC_LENGTH 64

View file

@@ -296,6 +296,7 @@ static void calculate_acorr(float *acorr, float *signal, int lag)
static int pitch_postprocessing(OSCEFeatureState *psFeatures, int lag, int type)
{
    int new_lag;
+    int modulus;
#ifdef OSCE_HANGOVER_BUGFIX
#define TESTBIT 1
@@ -303,6 +304,9 @@ static int pitch_postprocessing(OSCEFeatureState *psFeatures, int lag, int type)
#define TESTBIT 0
#endif
+    modulus = OSCE_PITCH_HANGOVER;
+    if (modulus == 0) modulus ++;
/* hangover is currently disabled to reflect a bug in the python code. ToDo: re-evaluate hangover */
if (type != TYPE_VOICED && psFeatures->last_type == TYPE_VOICED && TESTBIT)
/* enter hangover */
@@ -311,14 +315,14 @@ static int pitch_postprocessing(OSCEFeatureState *psFeatures, int lag, int type)
        if (psFeatures->pitch_hangover_count < OSCE_PITCH_HANGOVER)
        {
            new_lag = psFeatures->last_lag;
-            psFeatures->pitch_hangover_count = (psFeatures->pitch_hangover_count + 1) % OSCE_PITCH_HANGOVER;
+            psFeatures->pitch_hangover_count = (psFeatures->pitch_hangover_count + 1) % modulus;
}
}
else if (type != TYPE_VOICED && psFeatures->pitch_hangover_count && TESTBIT)
/* continue hangover */
{
new_lag = psFeatures->last_lag;
-        psFeatures->pitch_hangover_count = (psFeatures->pitch_hangover_count + 1) % OSCE_PITCH_HANGOVER;
+        psFeatures->pitch_hangover_count = (psFeatures->pitch_hangover_count + 1) % modulus;
}
else if (type != TYPE_VOICED)
/* unvoiced frame after hangover */
@@ -376,11 +380,7 @@ void osce_calculate_features(
    /* smooth bit count */
    psFeatures->numbits_smooth = 0.9f * psFeatures->numbits_smooth + 0.1f * num_bits;
    numbits[0] = num_bits;
-#ifdef OSCE_NUMBITS_BUGFIX
    numbits[1] = psFeatures->numbits_smooth;
-#else
-    numbits[1] = num_bits;
-#endif
for (n = 0; n < num_samples; n++)
{

View file

@@ -0,0 +1,2 @@
from . import quantization
from . import sparsification

View file

@@ -0,0 +1 @@
from .softquant import soft_quant, remove_soft_quant

View file

@@ -0,0 +1,113 @@
import torch
@torch.no_grad()
def compute_optimal_scale(weight):
with torch.no_grad():
n_out, n_in = weight.shape
assert n_in % 4 == 0
if n_out % 8:
# add padding
            pad = 8 - n_out % 8
weight = torch.cat((weight, torch.zeros((pad, n_in), dtype=weight.dtype, device=weight.device)), dim=0)
weight_max_abs, _ = torch.max(torch.abs(weight), dim=1)
weight_max_sum, _ = torch.max(torch.abs(weight[:, : n_in : 2] + weight[:, 1 : n_in : 2]), dim=1)
scale_max = weight_max_abs / 127
scale_sum = weight_max_sum / 129
scale = torch.maximum(scale_max, scale_sum)
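        # assumption: dividing by 127 keeps individual quantized weights inside the
        # int8 range, while 129 appears to bound sums of adjacent weight pairs as
        # consumed by the vectorized int8 kernels; the larger scale satisfies both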
return scale[:n_out]
@torch.no_grad()
def q_scaled_noise(module, weight):
if isinstance(module, torch.nn.Conv1d):
w = weight.permute(0, 2, 1).flatten(1)
noise = torch.rand_like(w) - 0.5
noise[w == 0] = 0 # ignore zero entries from sparsification
scale = compute_optimal_scale(w)
noise = noise * scale.unsqueeze(-1)
noise = noise.reshape(weight.size(0), weight.size(2), weight.size(1)).permute(0, 2, 1)
elif isinstance(module, torch.nn.ConvTranspose1d):
i, o, k = weight.shape
w = weight.permute(2, 1, 0).reshape(k * o, i)
noise = torch.rand_like(w) - 0.5
noise[w == 0] = 0 # ignore zero entries from sparsification
scale = compute_optimal_scale(w)
noise = noise * scale.unsqueeze(-1)
noise = noise.reshape(k, o, i).permute(2, 1, 0)
elif len(weight.shape) == 2:
noise = torch.rand_like(weight) - 0.5
noise[weight == 0] = 0 # ignore zero entries from sparsification
scale = compute_optimal_scale(weight)
noise = noise * scale.unsqueeze(-1)
else:
raise ValueError('unknown quantization setting')
return noise
class SoftQuant:
    names: list
    def __init__(self, names: list, scale: float) -> None:
self.names = names
self.quantization_noise = None
self.scale = scale
def __call__(self, module, inputs, *args, before=True):
if not module.training: return
if before:
self.quantization_noise = dict()
for name in self.names:
weight = getattr(module, name)
if self.scale is None:
self.quantization_noise[name] = q_scaled_noise(module, weight)
else:
self.quantization_noise[name] = \
self.scale * (torch.rand_like(weight) - 0.5)
with torch.no_grad():
weight.data[:] = weight + self.quantization_noise[name]
else:
for name in self.names:
weight = getattr(module, name)
with torch.no_grad():
weight.data[:] = weight - self.quantization_noise[name]
self.quantization_noise = None
    @staticmethod
    def apply(module, names=['weight'], scale=None):
fn = SoftQuant(names, scale)
for name in names:
if not hasattr(module, name):
raise ValueError("")
fn_before = lambda *x : fn(*x, before=True)
fn_after = lambda *x : fn(*x, before=False)
setattr(fn_before, 'sqm', fn)
setattr(fn_after, 'sqm', fn)
module.register_forward_pre_hook(fn_before)
module.register_forward_hook(fn_after)
        return fn
def soft_quant(module, names=['weight'], scale=None):
fn = SoftQuant.apply(module, names, scale)
return module
def remove_soft_quant(module, names=['weight']):
    for k, hook in list(module._forward_pre_hooks.items()):
if hasattr(hook, 'sqm'):
if isinstance(hook.sqm, SoftQuant) and hook.sqm.names == names:
del module._forward_pre_hooks[k]
    for k, hook in list(module._forward_hooks.items()):
if hasattr(hook, 'sqm'):
if isinstance(hook.sqm, SoftQuant) and hook.sqm.names == names:
del module._forward_hooks[k]
return module
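
A usage sketch for the machinery above (layer size and the training step are illustrative):

import torch

linear = torch.nn.Linear(64, 64)
soft_quant(linear)                # scale=None selects the per-row optimal scale

linear.train()                    # the hooks only act in training mode
y = linear(torch.randn(16, 64))   # pre-hook adds noise, post-hook removes it
y.sum().backward()                # gradients are taken at the noisy weights

remove_soft_quant(linear)         # strip the hooks again for export or evaluation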

View file

@@ -0,0 +1,2 @@
from .relegance import relegance_gradient_weighting, relegance_create_tconv_kernel, relegance_map_relevance_to_input_domain, relegance_resize_relevance_to_input_size
from .meta_critic import MetaCritic

View file

@@ -0,0 +1,85 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
class MetaCritic():
def __init__(self, normalize=False, gamma=0.9, beta=0.0, joint_stats=False):
""" Class for assessing relevance of discriminator scores
Args:
gamma (float, optional): update rate for tracking discriminator stats. Defaults to 0.9.
beta (float, optional): Miminum confidence related threshold. Defaults to 0.0.
"""
self.normalize = normalize
self.gamma = gamma
self.beta = beta
self.joint_stats = joint_stats
self.disc_stats = dict()
def __call__(self, disc_id, real_scores, generated_scores):
""" calculates relevance from normalized scores
Args:
disc_id (any valid key): id for tracking discriminator statistics
real_scores (torch.tensor): scores for real data
generated_scores (torch.tensor): scores for generated data; expecting device to match real_scores.device
Returns:
torch.tensor: output-domain relevance
"""
if self.normalize:
real_std = torch.std(real_scores.detach()).cpu().item()
gen_std = torch.std(generated_scores.detach()).cpu().item()
std = (real_std**2 + gen_std**2) ** .5
mean = torch.mean(real_scores.detach()).cpu().item() - torch.mean(generated_scores.detach()).cpu().item()
key = 0 if self.joint_stats else disc_id
if key in self.disc_stats:
self.disc_stats[key]['std'] = self.gamma * self.disc_stats[key]['std'] + (1 - self.gamma) * std
self.disc_stats[key]['mean'] = self.gamma * self.disc_stats[key]['mean'] + (1 - self.gamma) * mean
else:
self.disc_stats[key] = {
'std': std + 1e-5,
'mean': mean
}
std = self.disc_stats[key]['std']
mean = self.disc_stats[key]['mean']
else:
mean, std = 0, 1
relevance = torch.relu((real_scores - generated_scores - mean) / std + mean - self.beta)
if False: print(f"relevance({disc_id}): {relevance.min()=} {relevance.max()=} {relevance.mean()=}")
return relevance
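
An illustrative call pattern for MetaCritic as defined above (the scores are random stand-ins for discriminator outputs):

import torch

critic = MetaCritic(normalize=True, gamma=0.9, beta=0.0)
real_scores = torch.randn(8, 1, 100) + 0.5   # stand-in scores for real data
gen_scores = torch.randn(8, 1, 100)          # stand-in scores for generated data
relevance = critic('disc0', real_scores, gen_scores)   # stats tracked under key 'disc0'
assert (relevance >= 0).all()   # the ReLU keeps only regions scored confidently real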

View file

@@ -0,0 +1,449 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
import torch.nn.functional as F
def view_one_hot(index, length):
vec = length * [1]
vec[index] = -1
return vec
def create_smoothing_kernel(widths, gamma=1.5):
""" creates a truncated gaussian smoothing kernel for the given widths
Parameters:
-----------
widths: list[Int] or torch.LongTensor
specifies the shape of the smoothing kernel, entries must be > 0.
gamma: float, optional
decay factor for gaussian relative to kernel size
Returns:
--------
kernel: torch.FloatTensor
"""
widths = torch.LongTensor(widths)
num_dims = len(widths)
assert(widths.min() > 0)
centers = widths.float() / 2 - 0.5
sigmas = gamma * (centers + 1)
    vals = [((torch.arange(widths[i]) - centers[i]) / sigmas[i]) ** 2 for i in range(num_dims)]
vals = sum([vals[i].view(view_one_hot(i, num_dims)) for i in range(num_dims)])
kernel = torch.exp(- vals)
kernel = kernel / kernel.sum()
return kernel
def create_partition_kernel(widths, strides):
""" creates a partition kernel for mapping a convolutional network output back to the input domain
Given a fully convolutional network with receptive field of shape widths and the given strides, this
    function constructs an interpolation kernel whose translations by multiples of the given strides form
    a partition of unity on the input domain.
    Parameters:
----------
widths: list[Int] or torch.LongTensor
shape of receptive field
strides: list[Int] or torch.LongTensor
total strides of convolutional network
    Returns:
    --------
kernel: torch.FloatTensor
"""
num_dims = len(widths)
assert num_dims == len(strides) and num_dims in {1, 2, 3}
convs = {1 : F.conv1d, 2 : F.conv2d, 3 : F.conv3d}
widths = torch.LongTensor(widths)
strides = torch.LongTensor(strides)
proto_kernel = torch.ones(torch.minimum(strides, widths).tolist())
# create interpolation kernel eta
eta_widths = widths - strides + 1
if eta_widths.min() <= 0:
print("[create_partition_kernel] warning: receptive field does not cover input domain")
eta_widths = torch.maximum(eta_widths, torch.ones_like(eta_widths))
eta = create_smoothing_kernel(eta_widths).view(1, 1, *eta_widths.tolist())
padding = torch.repeat_interleave(eta_widths - 1, 2, 0).tolist()[::-1] # ordering of dimensions for padding and convolution functions is reversed in torch
padded_proto_kernel = F.pad(proto_kernel, padding)
padded_proto_kernel = padded_proto_kernel.view(1, 1, *padded_proto_kernel.shape)
kernel = convs[num_dims](padded_proto_kernel, eta)
return kernel
def receptive_field(conv_model, input_shape, output_position):
""" estimates boundaries of receptive field connected to output_position via autograd
Parameters:
-----------
conv_model: nn.Module or autograd function
function or model implementing fully convolutional model
input_shape: List[Int]
input shape ignoring batch dimension, i.e. [num_channels, dim1, dim2, ...]
output_position: List[Int]
output position for which the receptive field is determined; the function raises an exception
if output_position is out of bounds for the given input_shape.
Returns:
--------
low: List[Int]
start indices of receptive field
high: List[Int]
stop indices of receptive field
"""
x = torch.randn((1,) + tuple(input_shape), requires_grad=True)
y = conv_model(x)
# collapse channels and remove batch dimension
y = torch.sum(y, 1)[0]
# create mask
mask = torch.zeros_like(y)
index = [torch.tensor(i) for i in output_position]
try:
mask.index_put_(index, torch.tensor(1, dtype=mask.dtype))
except IndexError:
raise ValueError('output_position out of bounds')
(mask * y).sum().backward()
# sum over channels and remove batch dimension
grad = torch.sum(x.grad, dim=1)[0]
tmp = torch.nonzero(grad, as_tuple=True)
low = [t.min().item() for t in tmp]
high = [t.max().item() for t in tmp]
return low, high
def estimate_conv_parameters(model, num_channels, num_dims, width, max_stride=10):
""" attempts to estimate receptive field size, strides and left paddings for given model
Parameters:
-----------
model: nn.Module or autograd function
fully convolutional model for which parameters are estimated
num_channels: Int
number of input channels for model
num_dims: Int
number of input dimensions for model (without channel dimension)
width: Int
width of the input tensor (a hyper-square) on which the receptive fields are derived via autograd
max_stride: Int, optional
assumed maximal stride of the model for any dimension, when set too low the function may fail for
any value of width
Returns:
--------
receptive_field_size: List[Int]
receptive field size in all dimension
strides: List[Int]
stride in all dimensions
left_paddings: List[Int]
left padding in all dimensions; this is relevant for aligning the receptive field on the input plane
Raises:
-------
ValueError, KeyError
"""
input_shape = [num_channels] + num_dims * [width]
output_position1 = num_dims * [width // (2 * max_stride)]
output_position2 = num_dims * [width // (2 * max_stride) + 1]
low1, high1 = receptive_field(model, input_shape, output_position1)
low2, high2 = receptive_field(model, input_shape, output_position2)
widths1 = [h - l + 1 for l, h in zip(low1, high1)]
widths2 = [h - l + 1 for l, h in zip(low2, high2)]
if not all([w1 - w2 == 0 for w1, w2 in zip(widths1, widths2)]) or not all([l1 != l2 for l1, l2 in zip(low1, low2)]):
raise ValueError("[estimate_strides]: widths to small to determine strides")
receptive_field_size = widths1
strides = [l2 - l1 for l1, l2 in zip(low1, low2)]
left_paddings = [s * p - l for l, s, p in zip(low1, strides, output_position1)]
return receptive_field_size, strides, left_paddings
def inspect_conv_model(model, num_channels, num_dims, max_width=10000, width_hint=None, stride_hint=None, verbose=False):
""" determines size of receptive field, strides and padding probabilistically
Parameters:
-----------
model: nn.Module or autograd function
fully convolutional model for which parameters are estimated
num_channels: Int
number of input channels for model
num_dims: Int
number of input dimensions for model (without channel dimension)
max_width: Int
maximum width of the input tensor (a hyper-square) on which the receptive fields are derived via autograd
verbose: bool, optional
if true, the function prints parameters for individual trials
Returns:
--------
receptive_field_size: List[Int]
receptive field size in all dimension
strides: List[Int]
stride in all dimensions
left_paddings: List[Int]
left padding in all dimensions; this is relevant for aligning the receptive field on the input plane
Raises:
-------
ValueError
"""
max_stride = max_width // 2
stride = max_stride // 100
width = max_width // 100
if width_hint is not None: width = 2 * width_hint
if stride_hint is not None: stride = stride_hint
did_it = False
while width < max_width and stride < max_stride:
try:
if verbose: print(f"[inspect_conv_model] trying parameters {width=}, {stride=}")
receptive_field_size, strides, left_paddings = estimate_conv_parameters(model, num_channels, num_dims, width, stride)
did_it = True
        except (ValueError, KeyError):
pass
if did_it: break
width *= 2
if width >= max_width and stride < max_stride:
stride *= 2
width = 2 * stride
if not did_it:
raise ValueError(f'could not determine conv parameter with given max_width={max_width}')
return receptive_field_size, strides, left_paddings
class GradWeight(torch.autograd.Function):
def __init__(self):
super().__init__()
@staticmethod
def forward(ctx, x, weight):
ctx.save_for_backward(weight)
return x.clone()
@staticmethod
def backward(ctx, grad_output):
weight, = ctx.saved_tensors
grad_input = grad_output * weight
return grad_input, None
# API
def relegance_gradient_weighting(x, weight):
"""
Args:
x (torch.tensor): input tensor
weight (torch.tensor or None): weight tensor for gradients of x; if None, no gradient weighting will be applied in backward pass
Returns:
torch.tensor: the unmodified input tensor x
"""
if weight is None:
return x
else:
return GradWeight.apply(x, weight)
def relegance_create_tconv_kernel(model, num_channels, num_dims, width_hint=None, stride_hint=None, verbose=False):
""" creates parameters for mapping back output domain relevance to input tomain
Args:
model (nn.Module or autograd.Function): fully convolutional model
num_channels (int): number of input channels to model
num_dims (int): number of input dimensions of model (without channel and batch dimension)
width_hint(int or None): optional hint at maximal width of receptive field
stride_hint(int or None): optional hint at maximal stride
Returns:
        dict: contains kernel, kernel dimensions, strides and left paddings for transposed convolution
    Raises:
        RuntimeError: if estimation of parameters fails within the given compute budget
    """
max_width = int(100000 / (10 ** num_dims))
did_it = False
try:
receptive_field_size, strides, left_paddings = inspect_conv_model(model, num_channels, num_dims, max_width=max_width, width_hint=width_hint, stride_hint=stride_hint, verbose=verbose)
did_it = True
    except ValueError:
# try once again with larger max_width
max_width *= 10
# crash if exception is raised
try:
if not did_it: receptive_field_size, strides, left_paddings = inspect_conv_model(model, num_channels, num_dims, max_width=max_width, width_hint=width_hint, stride_hint=stride_hint, verbose=verbose)
    except ValueError:
raise RuntimeError("could not determine parameters within given compute budget")
partition_kernel = create_partition_kernel(receptive_field_size, strides)
partition_kernel = torch.repeat_interleave(partition_kernel, num_channels, 1)
tconv_parameters = {
'kernel': partition_kernel,
'receptive_field_shape': receptive_field_size,
'stride': strides,
'left_padding': left_paddings,
'num_dims': num_dims
}
return tconv_parameters
def relegance_map_relevance_to_input_domain(od_relevance, tconv_parameters):
""" maps output-domain relevance to input-domain relevance via transpose convolution
Args:
od_relevance (torch.tensor): output-domain relevance
tconv_parameters (dict): parameter dict as created by relegance_create_tconv_kernel
Returns:
torch.tensor: input-domain relevance. The tensor is left aligned, i.e. the all-zero index of the output corresponds to the all-zero index of the discriminator input.
                      However, the size of the output tensor need not match the size of the discriminator input. Use relegance_resize_relevance_to_input_size for a
convenient way to adjust the output to the correct size.
Raises:
ValueError: if number of dimensions is not supported
"""
kernel = tconv_parameters['kernel'].to(od_relevance.device)
rf_shape = tconv_parameters['receptive_field_shape']
stride = tconv_parameters['stride']
left_padding = tconv_parameters['left_padding']
num_dims = len(kernel.shape) - 2
# repeat boundary values
od_padding = [rf_shape[i//2] // stride[i//2] + 1 for i in range(2 * num_dims)]
padded_od_relevance = F.pad(od_relevance, od_padding[::-1], mode='replicate')
od_padding = od_padding[::2]
# apply mapping and left trimming
if num_dims == 1:
id_relevance = F.conv_transpose1d(padded_od_relevance, kernel, stride=stride)
id_relevance = id_relevance[..., left_padding[0] + stride[0] * od_padding[0] :]
elif num_dims == 2:
id_relevance = F.conv_transpose2d(padded_od_relevance, kernel, stride=stride)
id_relevance = id_relevance[..., left_padding[0] + stride[0] * od_padding[0] :, left_padding[1] + stride[1] * od_padding[1]:]
elif num_dims == 3:
        id_relevance = F.conv_transpose3d(padded_od_relevance, kernel, stride=stride)
id_relevance = id_relevance[..., left_padding[0] + stride[0] * od_padding[0] :, left_padding[1] + stride[1] * od_padding[1]:, left_padding[2] + stride[2] * od_padding[2] :]
else:
raise ValueError(f'[relegance_map_to_input_domain] error: num_dims = {num_dims} not supported')
return id_relevance
def relegance_resize_relevance_to_input_size(reference_input, relevance):
""" adjusts size of relevance tensor to reference input size
Args:
reference_input (torch.tensor): discriminator input tensor for reference
relevance (torch.tensor): input-domain relevance corresponding to input tensor reference_input
Returns:
torch.tensor: resized relevance
Raises:
ValueError: if number of dimensions is not supported
"""
resized_relevance = torch.zeros_like(reference_input)
num_dims = len(reference_input.shape) - 2
with torch.no_grad():
if num_dims == 1:
resized_relevance[:] = relevance[..., : min(reference_input.size(-1), relevance.size(-1))]
elif num_dims == 2:
resized_relevance[:] = relevance[..., : min(reference_input.size(-2), relevance.size(-2)), : min(reference_input.size(-1), relevance.size(-1))]
elif num_dims == 3:
resized_relevance[:] = relevance[..., : min(reference_input.size(-3), relevance.size(-3)), : min(reference_input.size(-2), relevance.size(-2)), : min(reference_input.size(-1), relevance.size(-1))]
else:
            raise ValueError(f'[relegance_resize_relevance_to_input_size] error: num_dims = {num_dims} not supported')
return resized_relevance
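
An end-to-end sketch of the pipeline above with a toy 1-D convolutional discriminator (model, shapes and the relevance source are illustrative; in training the output-domain relevance would come from MetaCritic):

import torch

# activations omitted so the autograd probe sees the full receptive field
disc = torch.nn.Sequential(
    torch.nn.Conv1d(1, 8, 15, stride=4),
    torch.nn.Conv1d(8, 1, 5, stride=2),
)
# one-time setup: probe receptive field and strides, build the partition kernel
tconv = relegance_create_tconv_kernel(disc, num_channels=1, num_dims=1)

x = torch.randn(4, 1, 4096, requires_grad=True)   # stand-in generator output
od_rel = torch.relu(disc(x)).detach()             # stand-in output-domain relevance
id_rel = relegance_map_relevance_to_input_domain(od_rel, tconv)
id_rel = relegance_resize_relevance_to_input_size(x, id_rel)
y = relegance_gradient_weighting(x, id_rel)       # backward grads of y are scaled by id_rel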

View file

@@ -0,0 +1,6 @@
from .gru_sparsifier import GRUSparsifier
from .conv1d_sparsifier import Conv1dSparsifier
from .conv_transpose1d_sparsifier import ConvTranspose1dSparsifier
from .linear_sparsifier import LinearSparsifier
from .common import sparsify_matrix, calculate_gru_flops_per_step
from .utils import mark_for_sparsification, create_sparsifier

View file

@@ -0,0 +1,58 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
class BaseSparsifier:
def __init__(self, task_list, start, stop, interval, exponent=3):
# just copying parameters...
self.start = start
self.stop = stop
self.interval = interval
self.exponent = exponent
self.task_list = task_list
# ... and setting counter to 0
self.step_counter = 0
def step(self, verbose=False):
# compute current interpolation factor
self.step_counter += 1
if self.step_counter < self.start:
return
elif self.step_counter < self.stop:
            # update only every self.interval-th step
if self.step_counter % self.interval:
return
alpha = ((self.stop - self.step_counter) / (self.stop - self.start)) ** self.exponent
else:
alpha = 0
self.sparsify(alpha, verbose=verbose)
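
The schedule itself can be sanity-checked with a stub subclass (illustrative only):

class PrintSparsifier(BaseSparsifier):
    def sparsify(self, alpha, verbose=False):
        print(f"step {self.step_counter}: alpha={alpha:.3f}")

s = PrintSparsifier(task_list=[], start=10, stop=50, interval=10)
for _ in range(60):
    s.step()
# alpha decays from 1 at step 10 to 0 at step 50: the stub fires at
# steps 10, 20, 30, 40, then on every step from 50 onward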

View file

@@ -0,0 +1,123 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
debug=True
def sparsify_matrix(matrix : torch.tensor, density : float, block_size, keep_diagonal : bool=False, return_mask : bool=False):
""" sparsifies matrix with specified block size
Parameters:
-----------
matrix : torch.tensor
matrix to sparsify
    density : float
        target density
block_size : [int, int]
block size dimensions
keep_diagonal : bool
If true, the diagonal will be kept. This option requires block_size[0] == block_size[1] and defaults to False
"""
m, n = matrix.shape
m1, n1 = block_size
if m % m1 or n % n1:
raise ValueError(f"block size {(m1, n1)} does not divide matrix size {(m, n)}")
# extract diagonal if keep_diagonal = True
if keep_diagonal:
if m != n:
raise ValueError("Attempting to sparsify non-square matrix with keep_diagonal=True")
to_spare = torch.diag(torch.diag(matrix))
matrix = matrix - to_spare
else:
to_spare = torch.zeros_like(matrix)
# calculate energy in sub-blocks
x = torch.reshape(matrix, (m // m1, m1, n // n1, n1))
x = x ** 2
block_energies = torch.sum(torch.sum(x, dim=3), dim=1)
number_of_blocks = (m * n) // (m1 * n1)
number_of_survivors = round(number_of_blocks * density)
# masking threshold
if number_of_survivors == 0:
threshold = 0
else:
threshold = torch.sort(torch.flatten(block_energies)).values[-number_of_survivors]
# create mask
mask = torch.ones_like(block_energies)
mask[block_energies < threshold] = 0
mask = torch.repeat_interleave(mask, m1, dim=0)
mask = torch.repeat_interleave(mask, n1, dim=1)
# perform masking
masked_matrix = mask * matrix + to_spare
if return_mask:
return masked_matrix, mask
else:
return masked_matrix
def calculate_gru_flops_per_step(gru, sparsification_dict=dict(), drop_input=False):
input_size = gru.input_size
hidden_size = gru.hidden_size
flops = 0
input_density = (
sparsification_dict.get('W_ir', [1])[0]
+ sparsification_dict.get('W_in', [1])[0]
+ sparsification_dict.get('W_iz', [1])[0]
) / 3
recurrent_density = (
sparsification_dict.get('W_hr', [1])[0]
+ sparsification_dict.get('W_hn', [1])[0]
+ sparsification_dict.get('W_hz', [1])[0]
) / 3
# input matrix vector multiplications
if not drop_input:
flops += 2 * 3 * input_size * hidden_size * input_density
# recurrent matrix vector multiplications
flops += 2 * 3 * hidden_size * hidden_size * recurrent_density
# biases
flops += 6 * hidden_size
# activations estimated by 10 flops per activation
flops += 30 * hidden_size
return flops
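
A quick sanity check of the two helpers above (shapes and densities are arbitrary):

import torch

W = torch.randn(16, 32)
W_sparse, mask = sparsify_matrix(W, density=0.25, block_size=[4, 8], return_mask=True)
# 16 blocks of 4x8; round(16 * 0.25) = 4 blocks survive the energy threshold
assert mask.sum().item() == 4 * 4 * 8

gru = torch.nn.GRU(64, 128)
print(calculate_gru_flops_per_step(gru))   # dense estimate when no sparsification dict is given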

View file

@@ -0,0 +1,133 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from .base_sparsifier import BaseSparsifier
from .common import sparsify_matrix, debug
class Conv1dSparsifier(BaseSparsifier):
def __init__(self, task_list, start, stop, interval, exponent=3):
""" Sparsifier for torch.nn.GRUs
Parameters:
-----------
task_list : list
task_list contains a list of tuples (conv1d, params), where conv1d is an instance
of torch.nn.Conv1d and params is a tuple (density, [m, n]),
            where density is the target density in [0, 1] and [m, n] is the shape of the sub-blocks to which
sparsification is applied.
start : int
training step after which sparsification will be started.
stop : int
training step after which sparsification will be completed.
interval : int
sparsification interval for steps between start and stop. After stop sparsification will be
            carried out after every call to Conv1dSparsifier.step()
exponent : float
Interpolation exponent for sparsification interval. In step i sparsification will be carried out
            with density (alpha + target_density * (1 - alpha)), where
            alpha = ((stop - i) / (stop - start)) ** exponent
Example:
--------
>>> import torch
>>> conv = torch.nn.Conv1d(8, 16, 8)
>>> params = (0.2, [8, 4])
>>> sparsifier = Conv1dSparsifier([(conv, params)], 0, 100, 50)
>>> for i in range(100):
... sparsifier.step()
"""
        super().__init__(task_list, start, stop, interval, exponent=exponent)
self.last_mask = None
def sparsify(self, alpha, verbose=False):
""" carries out sparsification step
Call this function after optimizer.step in your
training loop.
Parameters:
----------
alpha : float
density interpolation parameter (1: dense, 0: target density)
verbose : bool
if true, densities are printed out
Returns:
--------
None
"""
with torch.no_grad():
for conv, params in self.task_list:
# reshape weight
if hasattr(conv, 'weight_v'):
weight = conv.weight_v
else:
weight = conv.weight
                o, i, k = weight.shape # (out_channels, in_channels, kernel_size)
w = weight.permute(0, 2, 1).flatten(1)
target_density, block_size = params
density = alpha + (1 - alpha) * target_density
w, new_mask = sparsify_matrix(w, density, block_size, return_mask=True)
                w = w.reshape(o, k, i).permute(0, 2, 1)
weight[:] = w
if self.last_mask is not None:
if not torch.all(self.last_mask * new_mask == new_mask) and debug:
print("weight resurrection in conv.weight")
self.last_mask = new_mask
if verbose:
print(f"conv1d_sparsier[{self.step_counter}]: {density=}")
if __name__ == "__main__":
print("Testing sparsifier")
import torch
conv = torch.nn.Conv1d(8, 16, 8)
params = (0.2, [8, 4])
sparsifier = Conv1dSparsifier([(conv, params)], 0, 100, 5)
for i in range(100):
sparsifier.step(verbose=True)
print(conv.weight)

View file

@@ -0,0 +1,134 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from .base_sparsifier import BaseSparsifier
from .common import sparsify_matrix, debug
class ConvTranspose1dSparsifier(BaseSparsifier):
def __init__(self, task_list, start, stop, interval, exponent=3):
""" Sparsifier for torch.nn.GRUs
Parameters:
-----------
task_list : list
            task_list contains a list of tuples (conv, params), where conv is an instance
            of torch.nn.ConvTranspose1d and params is a tuple (density, [m, n]),
            where density is the target density in [0, 1] and [m, n] is the shape of the sub-blocks to which
sparsification is applied.
start : int
training step after which sparsification will be started.
stop : int
training step after which sparsification will be completed.
interval : int
sparsification interval for steps between start and stop. After stop sparsification will be
            carried out after every call to ConvTranspose1dSparsifier.step()
exponent : float
Interpolation exponent for sparsification interval. In step i sparsification will be carried out
            with density (alpha + target_density * (1 - alpha)), where
            alpha = ((stop - i) / (stop - start)) ** exponent
Example:
--------
>>> import torch
>>> conv = torch.nn.ConvTranspose1d(8, 16, 8)
>>> params = (0.2, [8, 4])
>>> sparsifier = ConvTranspose1dSparsifier([(conv, params)], 0, 100, 50)
>>> for i in range(100):
... sparsifier.step()
"""
        super().__init__(task_list, start, stop, interval, exponent=exponent)
self.last_mask = None
def sparsify(self, alpha, verbose=False):
""" carries out sparsification step
Call this function after optimizer.step in your
training loop.
Parameters:
----------
alpha : float
density interpolation parameter (1: dense, 0: target density)
verbose : bool
if true, densities are printed out
Returns:
--------
None
"""
with torch.no_grad():
for conv, params in self.task_list:
# reshape weight
if hasattr(conv, 'weight_v'):
weight = conv.weight_v
else:
weight = conv.weight
i, o, k = weight.shape
w = weight.permute(2, 1, 0).reshape(k * o, i)
target_density, block_size = params
density = alpha + (1 - alpha) * target_density
w, new_mask = sparsify_matrix(w, density, block_size, return_mask=True)
w = w.reshape(k, o, i).permute(2, 1, 0)
weight[:] = w
if self.last_mask is not None:
if not torch.all(self.last_mask * new_mask == new_mask) and debug:
print("weight resurrection in conv.weight")
self.last_mask = new_mask
if verbose:
print(f"convtrans1d_sparsier[{self.step_counter}]: {density=}")
if __name__ == "__main__":
print("Testing sparsifier")
import torch
conv = torch.nn.ConvTranspose1d(8, 16, 4, 4)
params = (0.2, [8, 4])
sparsifier = ConvTranspose1dSparsifier([(conv, params)], 0, 100, 5)
for i in range(100):
sparsifier.step(verbose=True)
print(conv.weight)

View file

@@ -0,0 +1,178 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from .base_sparsifier import BaseSparsifier
from .common import sparsify_matrix, debug
class GRUSparsifier(BaseSparsifier):
def __init__(self, task_list, start, stop, interval, exponent=3):
""" Sparsifier for torch.nn.GRUs
Parameters:
-----------
task_list : list
task_list contains a list of tuples (gru, sparsify_dict), where gru is an instance
            of torch.nn.GRU and sparsify_dict is a dictionary with keys in {'W_ir', 'W_iz', 'W_in',
'W_hr', 'W_hz', 'W_hn'} corresponding to the input and recurrent weights for the reset,
update, and new gate. The values of sparsify_dict are tuples (density, [m, n], keep_diagonal),
where density is the target density in [0, 1], [m, n] is the shape sub-blocks to which
sparsification is applied and keep_diagonal is a bool variable indicating whether the diagonal
should be kept.
start : int
training step after which sparsification will be started.
stop : int
training step after which sparsification will be completed.
interval : int
sparsification interval for steps between start and stop. After stop sparsification will be
carried out after every call to GRUSparsifier.step()
exponent : float
Interpolation exponent for sparsification interval. In step i sparsification will be carried out
            with density (alpha + target_density * (1 - alpha)), where
            alpha = ((stop - i) / (stop - start)) ** exponent
Example:
--------
>>> import torch
>>> gru = torch.nn.GRU(10, 20)
>>> sparsify_dict = {
... 'W_ir' : (0.5, [2, 2], False),
... 'W_iz' : (0.6, [2, 2], False),
... 'W_in' : (0.7, [2, 2], False),
... 'W_hr' : (0.1, [4, 4], True),
... 'W_hz' : (0.2, [4, 4], True),
... 'W_hn' : (0.3, [4, 4], True),
... }
>>> sparsifier = GRUSparsifier([(gru, sparsify_dict)], 0, 100, 50)
>>> for i in range(100):
... sparsifier.step()
"""
        super().__init__(task_list, start, stop, interval, exponent=exponent)
self.last_masks = {key : None for key in ['W_ir', 'W_in', 'W_iz', 'W_hr', 'W_hn', 'W_hz']}
def sparsify(self, alpha, verbose=False):
""" carries out sparsification step
Call this function after optimizer.step in your
training loop.
Parameters:
----------
alpha : float
density interpolation parameter (1: dense, 0: target density)
verbose : bool
if true, densities are printed out
Returns:
--------
None
"""
with torch.no_grad():
for gru, params in self.task_list:
hidden_size = gru.hidden_size
# input weights
for i, key in enumerate(['W_ir', 'W_iz', 'W_in']):
if key in params:
if hasattr(gru, 'weight_ih_l0_v'):
weight = gru.weight_ih_l0_v
else:
weight = gru.weight_ih_l0
density = alpha + (1 - alpha) * params[key][0]
if verbose:
print(f"[{self.step_counter}]: {key} density: {density}")
weight[i * hidden_size : (i+1) * hidden_size, : ], new_mask = sparsify_matrix(
weight[i * hidden_size : (i + 1) * hidden_size, : ],
density, # density
params[key][1], # block_size
params[key][2], # keep_diagonal (might want to set this to False)
return_mask=True
)
                        if self.last_masks[key] is not None:
if not torch.all(self.last_masks[key] * new_mask == new_mask) and debug:
print("weight resurrection in weight_ih_l0_v")
self.last_masks[key] = new_mask
# recurrent weights
for i, key in enumerate(['W_hr', 'W_hz', 'W_hn']):
if key in params:
if hasattr(gru, 'weight_hh_l0_v'):
weight = gru.weight_hh_l0_v
else:
weight = gru.weight_hh_l0
density = alpha + (1 - alpha) * params[key][0]
if verbose:
print(f"[{self.step_counter}]: {key} density: {density}")
weight[i * hidden_size : (i+1) * hidden_size, : ], new_mask = sparsify_matrix(
weight[i * hidden_size : (i + 1) * hidden_size, : ],
density,
params[key][1], # block_size
params[key][2], # keep_diagonal (might want to set this to False)
return_mask=True
)
                        if self.last_masks[key] is not None:
                            if not torch.all(self.last_masks[key] * new_mask == new_mask) and debug:
print("weight resurrection in weight_hh_l0_v")
self.last_masks[key] = new_mask
if __name__ == "__main__":
print("Testing sparsifier")
gru = torch.nn.GRU(10, 20)
sparsify_dict = {
'W_ir' : (0.5, [2, 2], False),
'W_iz' : (0.6, [2, 2], False),
'W_in' : (0.7, [2, 2], False),
'W_hr' : (0.1, [4, 4], True),
'W_hz' : (0.2, [4, 4], True),
'W_hn' : (0.3, [4, 4], True),
}
sparsifier = GRUSparsifier([(gru, sparsify_dict)], 0, 100, 10)
for i in range(100):
sparsifier.step(verbose=True)
print(gru.weight_hh_l0)

View file

@@ -0,0 +1,128 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from .base_sparsifier import BaseSparsifier
from .common import sparsify_matrix, debug
class LinearSparsifier(BaseSparsifier):
def __init__(self, task_list, start, stop, interval, exponent=3):
""" Sparsifier for torch.nn.GRUs
Parameters:
-----------
task_list : list
task_list contains a list of tuples (linear, params), where linear is an instance
of torch.nn.Linear and params is a tuple (density, [m, n]),
            where density is the target density in [0, 1] and [m, n] is the shape of the sub-blocks to which
sparsification is applied.
start : int
training step after which sparsification will be started.
stop : int
training step after which sparsification will be completed.
interval : int
sparsification interval for steps between start and stop. After stop sparsification will be
            carried out after every call to LinearSparsifier.step()
exponent : float
Interpolation exponent for sparsification interval. In step i sparsification will be carried out
            with density (alpha + target_density * (1 - alpha)), where
            alpha = ((stop - i) / (stop - start)) ** exponent
Example:
--------
>>> import torch
>>> linear = torch.nn.Linear(8, 16)
>>> params = (0.2, [8, 4])
>>> sparsifier = LinearSparsifier([(linear, params)], 0, 100, 50)
>>> for i in range(100):
... sparsifier.step()
"""
        super().__init__(task_list, start, stop, interval, exponent=exponent)
self.last_mask = None
def sparsify(self, alpha, verbose=False):
""" carries out sparsification step
Call this function after optimizer.step in your
training loop.
Parameters:
----------
alpha : float
density interpolation parameter (1: dense, 0: target density)
verbose : bool
if true, densities are printed out
Returns:
--------
None
"""
with torch.no_grad():
for linear, params in self.task_list:
if hasattr(linear, 'weight_v'):
weight = linear.weight_v
else:
weight = linear.weight
target_density, block_size = params
density = alpha + (1 - alpha) * target_density
weight[:], new_mask = sparsify_matrix(weight, density, block_size, return_mask=True)
if self.last_mask is not None:
if not torch.all(self.last_mask * new_mask == new_mask) and debug:
print("weight resurrection in conv.weight")
self.last_mask = new_mask
if verbose:
print(f"linear_sparsifier[{self.step_counter}]: {density=}")
if __name__ == "__main__":
print("Testing sparsifier")
import torch
linear = torch.nn.Linear(8, 16)
params = (0.2, [4, 2])
sparsifier = LinearSparsifier([(linear, params)], 0, 100, 5)
for i in range(100):
sparsifier.step(verbose=True)
print(linear.weight)

View file

@@ -0,0 +1,64 @@
import torch
from dnntools.sparsification import GRUSparsifier, LinearSparsifier, Conv1dSparsifier, ConvTranspose1dSparsifier
def mark_for_sparsification(module, params):
setattr(module, 'sparsify', True)
setattr(module, 'sparsification_params', params)
return module
def create_sparsifier(module, start, stop, interval):
sparsifier_list = []
for m in module.modules():
if hasattr(m, 'sparsify'):
if isinstance(m, torch.nn.GRU):
sparsifier_list.append(
GRUSparsifier([(m, m.sparsification_params)], start, stop, interval)
)
elif isinstance(m, torch.nn.Linear):
sparsifier_list.append(
LinearSparsifier([(m, m.sparsification_params)], start, stop, interval)
)
elif isinstance(m, torch.nn.Conv1d):
sparsifier_list.append(
Conv1dSparsifier([(m, m.sparsification_params)], start, stop, interval)
)
elif isinstance(m, torch.nn.ConvTranspose1d):
sparsifier_list.append(
ConvTranspose1dSparsifier([(m, m.sparsification_params)], start, stop, interval)
)
else:
print(f"[create_sparsifier] warning: module {m} marked for sparsification but no suitable sparsifier exists.")
def sparsify(verbose=False):
for sparsifier in sparsifier_list:
sparsifier.step(verbose)
return sparsify
def count_parameters(model, verbose=False):
total = 0
for name, p in model.named_parameters():
count = torch.ones_like(p).sum().item()
if verbose:
print(f"{name}: {count} parameters")
total += count
return total
def estimate_nonzero_parameters(module):
num_zero_parameters = 0
if hasattr(module, 'sparsify'):
params = module.sparsification_params
if isinstance(module, torch.nn.Conv1d) or isinstance(module, torch.nn.ConvTranspose1d):
num_zero_parameters = torch.ones_like(module.weight).sum().item() * (1 - params[0])
elif isinstance(module, torch.nn.GRU):
num_zero_parameters = module.input_size * module.hidden_size * (3 - params['W_ir'][0] - params['W_iz'][0] - params['W_in'][0])
num_zero_parameters += module.hidden_size * module.hidden_size * (3 - params['W_hr'][0] - params['W_hz'][0] - params['W_hn'][0])
elif isinstance(module, torch.nn.Linear):
            num_zero_parameters = module.in_features * module.out_features * (1 - params[0])
        else:
            raise ValueError(f'unknown sparsification method for module of type {type(module)}')
    return num_zero_parameters
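
Typical wiring for these helpers (model and densities are stand-ins); this also produces the model.sparsifier attribute that the updated training loops in this commit check for:

import torch
from dnntools.sparsification import mark_for_sparsification, create_sparsifier

model = torch.nn.Sequential(
    mark_for_sparsification(torch.nn.Linear(64, 64), (0.5, [8, 4])),
    torch.nn.ReLU(),
    mark_for_sparsification(torch.nn.Linear(64, 16), (0.25, [4, 4])),
)
model.sparsifier = create_sparsifier(model, start=100, stop=1000, interval=50)

# inside the training loop, after optimizer.step():
model.sparsifier(verbose=False)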

View file

@@ -0,0 +1 @@
torch

View file

@@ -0,0 +1,48 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
#!/usr/bin/env python
import os
from setuptools import setup
lib_folder = os.path.dirname(os.path.realpath(__file__))
with open(os.path.join(lib_folder, 'requirements.txt'), 'r') as f:
install_requires = list(f.read().splitlines())
print(install_requires)
setup(name='dnntools',
version='1.0',
author='Jan Buethe',
author_email='jbuethe@amazon.de',
description='Non-Standard tools for deep neural network training with PyTorch',
packages=['dnntools', 'dnntools.sparsification', 'dnntools.quantization'],
install_requires=install_requires
)

View file

@@ -111,7 +111,7 @@ os.makedirs(checkpoint_dir, exist_ok=True)
if has_git:
    working_dir = os.path.split(__file__)[0]
    try:
-        repo = git.Repo(working_dir)
+        repo = git.Repo(working_dir, search_parent_directories=True)
setup['repo'] = dict()
hash = repo.head.object.hexsha
urls = list(repo.remote().urls)
@@ -408,6 +408,10 @@ for ep in range(1, epochs + 1):
            optimizer.step()
+            # sparsification
+            if hasattr(model, 'sparsifier'):
+                model.sparsifier()
running_model_grad_norm += get_grad_norm(model).detach().cpu().item()
running_adv_loss += gen_loss.detach().cpu().item()
running_disc_loss += disc_loss.detach().cpu().item()

View file

@@ -111,7 +111,7 @@ os.makedirs(checkpoint_dir, exist_ok=True)
if has_git:
    working_dir = os.path.split(__file__)[0]
    try:
-        repo = git.Repo(working_dir)
+        repo = git.Repo(working_dir, search_parent_directories=True)
setup['repo'] = dict()
hash = repo.head.object.hexsha
urls = list(repo.remote().urls)

View file

@@ -46,6 +46,10 @@ def train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler,
                # update learning rate
                scheduler.step()
+                # sparsification
+                if hasattr(model, 'sparsifier'):
+                    model.sparsifier()
# update running loss
running_loss += float(loss.cpu())
@@ -73,8 +77,6 @@ def evaluate(model, criterion, dataloader, device, log_interval=10):
for i, batch in enumerate(tepoch):
# push batch to device
for key in batch:
batch[key] = batch[key].to(device)

View file

@@ -43,6 +43,7 @@ from models import model_dict
from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d
from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
from utils.layers.td_shaper import TDShaper
+from utils.misc import remove_all_weight_norm
from wexchange.torch import dump_torch_weights
@@ -58,30 +59,30 @@ schedules = {
    'nolace': [
        ('pitch_embedding', dict()),
        ('feature_net.conv1', dict()),
-        ('feature_net.conv2', dict(quantize=True, scale=None)),
-        ('feature_net.tconv', dict(quantize=True, scale=None)),
-        ('feature_net.gru', dict()),
+        ('feature_net.conv2', dict(quantize=True, scale=None, sparse=True)),
+        ('feature_net.tconv', dict(quantize=True, scale=None, sparse=True)),
+        ('feature_net.gru', dict(quantize=True, scale=None, recurrent_scale=None, input_sparse=True, recurrent_sparse=True)),
        ('cf1', dict(quantize=True, scale=None)),
        ('cf2', dict(quantize=True, scale=None)),
        ('af1', dict(quantize=True, scale=None)),
-        ('tdshape1', dict()),
-        ('tdshape2', dict()),
-        ('tdshape3', dict()),
+        ('tdshape1', dict(quantize=True, scale=None)),
+        ('tdshape2', dict(quantize=True, scale=None)),
+        ('tdshape3', dict(quantize=True, scale=None)),
        ('af2', dict(quantize=True, scale=None)),
        ('af3', dict(quantize=True, scale=None)),
        ('af4', dict(quantize=True, scale=None)),
-        ('post_cf1', dict(quantize=True, scale=None)),
-        ('post_cf2', dict(quantize=True, scale=None)),
-        ('post_af1', dict(quantize=True, scale=None)),
-        ('post_af2', dict(quantize=True, scale=None)),
-        ('post_af3', dict(quantize=True, scale=None))
+        ('post_cf1', dict(quantize=True, scale=None, sparse=True)),
+        ('post_cf2', dict(quantize=True, scale=None, sparse=True)),
+        ('post_af1', dict(quantize=True, scale=None, sparse=True)),
+        ('post_af2', dict(quantize=True, scale=None, sparse=True)),
+        ('post_af3', dict(quantize=True, scale=None, sparse=True))
    ],
    'lace' : [
        ('pitch_embedding', dict()),
        ('feature_net.conv1', dict()),
-        ('feature_net.conv2', dict(quantize=True, scale=None)),
-        ('feature_net.tconv', dict(quantize=True, scale=None)),
-        ('feature_net.gru', dict()),
+        ('feature_net.conv2', dict(quantize=True, scale=None, sparse=True)),
+        ('feature_net.tconv', dict(quantize=True, scale=None, sparse=True)),
+        ('feature_net.gru', dict(quantize=True, scale=None, recurrent_scale=None, input_sparse=True, recurrent_sparse=True)),
('cf1', dict(quantize=True, scale=None)),
('cf2', dict(quantize=True, scale=None)),
('af1', dict(quantize=True, scale=None))
@@ -140,6 +141,7 @@ if __name__ == "__main__":
checkpoint = torch.load(checkpoint_path, map_location='cpu')
model = model_dict[checkpoint['setup']['model']['name']](*checkpoint['setup']['model']['args'], **checkpoint['setup']['model']['kwargs'])
model.load_state_dict(checkpoint['state_dict'])
remove_all_weight_norm(model, verbose=True)
# CWriter
model_name = checkpoint['setup']['model']['name']

View file

@@ -41,6 +41,12 @@ from models.silk_feature_net_pl import SilkFeatureNetPL
from models.silk_feature_net import SilkFeatureNet
from .scale_embedding import ScaleEmbedding
import sys
sys.path.append('../dnntools')
from dnntools.sparsification import create_sparsifier
class LACE(NNSBase):
""" Linear-Adaptive Coding Enhancer """
FRAME_SIZE=80
@@ -60,7 +66,12 @@
numbits_embedding_dim=8,
hidden_feature_dim=64,
partial_lookahead=True,
norm_p=2):
norm_p=2,
softquant=False,
sparsify=False,
sparsification_schedule=[10000, 30000, 100],
sparsification_density=0.5,
apply_weight_norm=False):
super().__init__(skip=skip, preemph=preemph)
@@ -85,18 +96,21 @@
# feature net
if partial_lookahead:
self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim)
self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim, softquant=softquant, sparsify=sparsify, sparsification_density=sparsification_density, apply_weight_norm=apply_weight_norm)
else:
self.feature_net = SilkFeatureNet(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim)
# comb filters
left_pad = self.kernel_size // 2
right_pad = self.kernel_size - 1 - left_pad
self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
# spectral shaping
self.af1 = LimitedAdaptiveConv1d(1, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
self.af1 = LimitedAdaptiveConv1d(1, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
if sparsify:
self.sparsifier = create_sparsifier(self, *sparsification_schedule)
def flop_count(self, rate=16000, verbose=False):

View file

@@ -27,9 +27,13 @@
*/
"""
import numbers
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
import numpy as np
@@ -43,6 +47,11 @@ from models.silk_feature_net_pl import SilkFeatureNetPL
from models.silk_feature_net import SilkFeatureNet
from .scale_embedding import ScaleEmbedding
import sys
sys.path.append('../dnntools')
from dnntools.quantization import soft_quant
from dnntools.sparsification import create_sparsifier, mark_for_sparsification
class NoLACE(NNSBase):
""" Non-Linear Adaptive Coding Enhancer """
FRAME_SIZE=80
@@ -64,11 +73,15 @@
partial_lookahead=True,
norm_p=2,
avg_pool_k=4,
pool_after=False):
pool_after=False,
softquant=False,
sparsify=False,
sparsification_schedule=[100, 1000, 100],
sparsification_density=0.5,
apply_weight_norm=False):
super().__init__(skip=skip, preemph=preemph)
self.num_features = num_features
self.cond_dim = cond_dim
self.pitch_max = pitch_max
@@ -81,6 +94,11 @@
self.hidden_feature_dim = hidden_feature_dim
self.partial_lookahead = partial_lookahead
if isinstance(sparsification_density, numbers.Number):
sparsification_density = 10 * [sparsification_density]
norm = weight_norm if apply_weight_norm else lambda x, name=None: x
# pitch embedding
self.pitch_embedding = nn.Embedding(pitch_max + 1, pitch_embedding_dim)
@@ -89,36 +107,52 @@
# feature net
if partial_lookahead:
self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim)
self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim, softquant=softquant, sparsify=sparsify, sparsification_density=sparsification_density, apply_weight_norm=apply_weight_norm)
else:
self.feature_net = SilkFeatureNet(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim)
# comb filters
left_pad = self.kernel_size // 2
right_pad = self.kernel_size - 1 - left_pad
self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
# spectral shaping
self.af1 = LimitedAdaptiveConv1d(1, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
self.af1 = LimitedAdaptiveConv1d(1, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
# non-linear transforms
self.tdshape1 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after)
self.tdshape2 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after)
self.tdshape3 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after)
self.tdshape1 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after, softquant=softquant, apply_weight_norm=apply_weight_norm)
self.tdshape2 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after, softquant=softquant, apply_weight_norm=apply_weight_norm)
self.tdshape3 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after, softquant=softquant, apply_weight_norm=apply_weight_norm)
# combinators
self.af2 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
self.af3 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
self.af4 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
self.af2 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
self.af3 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
self.af4 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
# feature transforms
self.post_cf1 = nn.Conv1d(cond_dim, cond_dim, 2)
self.post_cf2 = nn.Conv1d(cond_dim, cond_dim, 2)
self.post_af1 = nn.Conv1d(cond_dim, cond_dim, 2)
self.post_af2 = nn.Conv1d(cond_dim, cond_dim, 2)
self.post_af3 = nn.Conv1d(cond_dim, cond_dim, 2)
self.post_cf1 = norm(nn.Conv1d(cond_dim, cond_dim, 2))
self.post_cf2 = norm(nn.Conv1d(cond_dim, cond_dim, 2))
self.post_af1 = norm(nn.Conv1d(cond_dim, cond_dim, 2))
self.post_af2 = norm(nn.Conv1d(cond_dim, cond_dim, 2))
self.post_af3 = norm(nn.Conv1d(cond_dim, cond_dim, 2))
if softquant:
self.post_cf1 = soft_quant(self.post_cf1)
self.post_cf2 = soft_quant(self.post_cf2)
self.post_af1 = soft_quant(self.post_af1)
self.post_af2 = soft_quant(self.post_af2)
self.post_af3 = soft_quant(self.post_af3)
if sparsify:
mark_for_sparsification(self.post_cf1, (sparsification_density[4], [8, 4]))
mark_for_sparsification(self.post_cf2, (sparsification_density[5], [8, 4]))
mark_for_sparsification(self.post_af1, (sparsification_density[6], [8, 4]))
mark_for_sparsification(self.post_af2, (sparsification_density[7], [8, 4]))
mark_for_sparsification(self.post_af3, (sparsification_density[8], [8, 4]))
self.sparsifier = create_sparsifier(self, *sparsification_schedule)
def flop_count(self, rate=16000, verbose=False):
@@ -141,9 +175,12 @@
return feature_net_flops + comb_flops + af_flops + feature_flops + shape_flops
def feature_transform(self, f, layer):
f = f.permute(0, 2, 1)
f = F.pad(f, [1, 0])
f = torch.tanh(layer(f))
f0 = f.permute(0, 2, 1)
f = F.pad(f0, [1, 0])
if self.residual_in_feature_transform:
f = torch.tanh(layer(f) + f0)
else:
f = torch.tanh(layer(f))
return f.permute(0, 2, 1)
def forward(self, x, features, periods, numbits, debug=False):

View file

@@ -26,36 +26,74 @@
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import sys
sys.path.append('../dnntools')
import numbers
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
from utils.complexity import _conv1d_flop_count
from dnntools.quantization.softquant import soft_quant
from dnntools.sparsification import mark_for_sparsification
class SilkFeatureNetPL(nn.Module):
""" feature net with partial lookahead """
def __init__(self,
feature_dim=47,
num_channels=256,
hidden_feature_dim=64):
hidden_feature_dim=64,
softquant=False,
sparsify=True,
sparsification_density=0.5,
apply_weight_norm=False):
super(SilkFeatureNetPL, self).__init__()
if isinstance(sparsification_density, numbers.Number):
sparsification_density = 4 * [sparsification_density]
self.feature_dim = feature_dim
self.num_channels = num_channels
self.hidden_feature_dim = hidden_feature_dim
self.conv1 = nn.Conv1d(feature_dim, self.hidden_feature_dim, 1)
self.conv2 = nn.Conv1d(4 * self.hidden_feature_dim, num_channels, 2)
self.tconv = nn.ConvTranspose1d(num_channels, num_channels, 4, 4)
norm = weight_norm if apply_weight_norm else lambda x, name=None: x
self.conv1 = norm(nn.Conv1d(feature_dim, self.hidden_feature_dim, 1))
self.conv2 = norm(nn.Conv1d(4 * self.hidden_feature_dim, num_channels, 2))
self.tconv = norm(nn.ConvTranspose1d(num_channels, num_channels, 4, 4))
gru_input_dim = num_channels + self.repeat_upsamp_dim if self.repeat_upsamp else num_channels
self.gru = norm(norm(nn.GRU(gru_input_dim, num_channels, batch_first=True), name='weight_hh_l0'), name='weight_ih_l0')
if softquant:
self.conv2 = soft_quant(self.conv2)
if not self.repeat_upsamp: self.tconv = soft_quant(self.tconv)
self.gru = soft_quant(self.gru, names=['weight_hh_l0', 'weight_ih_l0'])
if sparsify:
mark_for_sparsification(self.conv2, (sparsification_density[0], [8, 4]))
if not self.repeat_upsamp: mark_for_sparsification(self.tconv, (sparsification_density[1], [8, 4]))
mark_for_sparsification(
self.gru,
{
'W_ir' : (sparsification_density[2], [8, 4], False),
'W_iz' : (sparsification_density[2], [8, 4], False),
'W_in' : (sparsification_density[2], [8, 4], False),
'W_hr' : (sparsification_density[3], [8, 4], True),
'W_hz' : (sparsification_density[3], [8, 4], True),
'W_hn' : (sparsification_density[3], [8, 4], True),
}
)
self.gru = nn.GRU(num_channels, num_channels, batch_first=True)
def flop_count(self, rate=200):
count = 0
for conv in self.conv1, self.conv2, self.tconv:
for conv in [self.conv1, self.conv2] if self.repeat_upsamp else [self.conv1, self.conv2, self.tconv]:
count += _conv1d_flop_count(conv, rate)
count += 2 * (3 * self.gru.input_size * self.gru.hidden_size + 3 * self.gru.hidden_size * self.gru.hidden_size) * rate
@@ -82,7 +120,7 @@ class SilkFeatureNetPL(nn.Module):
c = torch.tanh(self.conv2(F.pad(c, [1, 0])))
# upsampling
c = self.tconv(c)
c = torch.tanh(self.tconv(c))
c = c.permute(0, 2, 1)
c, _ = self.gru(c, state)
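The `(density, [8, 4])` tuples passed to mark_for_sparsification above pair a target density with a block size. A sketch of the block-structured masking this implies — assuming 8x4 blocks are kept or dropped whole by L1 magnitude; the real dnntools implementation is not part of this diff:

import torch

def block_sparsity_mask(w, density, block=(8, 4)):
    # score each block by the sum of absolute weights it contains
    bh, bw = block
    H, W = w.shape
    scores = w.abs().reshape(H // bh, bh, W // bw, bw).sum(dim=(1, 3))
    k = max(1, int(density * scores.numel()))
    thr = scores.flatten().kthvalue(scores.numel() - k + 1).values
    # keep the top-density fraction of blocks, expand back to element level
    mask = (scores >= thr).to(w.dtype)
    return mask.repeat_interleave(bh, dim=0).repeat_interleave(bw, dim=1)

# usage: w.data *= block_sparsity_mask(w.data, 0.5)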

View file

@@ -0,0 +1,25 @@
#!/bin/bash
INPUT="dataset/LibriSpeech"
OUTPUT="testdata"
OPUSDEMO="/local/experiments/ietf_enhancement_studies/bin/opus_demo_patched"
BITRATES=( 6000 7500 ) # 9000 12000 15000 18000 24000 32000 )
mkdir -p $OUTPUT
for fn in $(find $INPUT -name "*.wav")
do
name=$(basename ${fn%*.wav})
sox $fn -r 16000 -b 16 -e signed-integer ${OUTPUT}/tmp.raw
for br in ${BITRATES[@]}
do
folder=${OUTPUT}/"${name}_${br}.se"
echo "creating ${folder}..."
mkdir -p $folder
cp ${OUTPUT}/tmp.raw ${folder}/clean.s16
(cd ${folder} && $OPUSDEMO voip 16000 1 $br clean.s16 noisy.s16)
done
rm -f ${OUTPUT}/tmp.raw
done

View file

@@ -0,0 +1,7 @@
#!/bin/bash
export PYTHON=/home/ubuntu/opt/miniconda3/envs/torch/bin/python
export LACE="/local/experiments/ietf_enhancement_studies/checkpoints/lace_checkpoint.pth"
export NOLACE="/local/experiments/ietf_enhancement_studies/checkpoints/nolace_checkpoint.pth"
export TESTMODEL="/local/experiments/ietf_enhancement_studies/opus/dnn/torch/osce/test_model.py"
export OPUSDEMO="/local/experiments/ietf_enhancement_studies/bin/opus_demo_patched"

View file

@@ -0,0 +1,113 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
import argparse
from scipy.io import wavfile
from pesq import pesq
import numpy as np
from moc import compare
from moc2 import compare as compare2
#from warpq import compute_WAPRQ as warpq
from lace_loss_metric import compare as laceloss_compare
parser = argparse.ArgumentParser()
parser.add_argument('folder', type=str, help='folder with processed items')
parser.add_argument('metric', type=str, choices=['pesq', 'moc', 'moc2', 'laceloss'], help='metric to be used for evaluation')
def get_bitrates(folder):
with open(os.path.join(folder, 'bitrates.txt')) as f:
x = f.read()
bitrates = [int(y) for y in x.rstrip('\n').split()]
return bitrates
def get_itemlist(folder):
with open(os.path.join(folder, 'items.txt')) as f:
lines = f.readlines()
items = [x.split()[0] for x in lines]
return items
def process_item(folder, item, bitrate, metric):
fs, x_clean = wavfile.read(os.path.join(folder, 'clean', f"{item}_{bitrate}_clean.wav"))
fs, x_opus = wavfile.read(os.path.join(folder, 'opus', f"{item}_{bitrate}_opus.wav"))
fs, x_lace = wavfile.read(os.path.join(folder, 'lace', f"{item}_{bitrate}_lace.wav"))
fs, x_nolace = wavfile.read(os.path.join(folder, 'nolace', f"{item}_{bitrate}_nolace.wav"))
x_clean = x_clean.astype(np.float32) / 2**15
x_opus = x_opus.astype(np.float32) / 2**15
x_lace = x_lace.astype(np.float32) / 2**15
x_nolace = x_nolace.astype(np.float32) / 2**15
if metric == 'pesq':
result = [pesq(fs, x_clean, x_opus), pesq(fs, x_clean, x_lace), pesq(fs, x_clean, x_nolace)]
elif metric == 'moc':
result = [compare(x_clean, x_opus), compare(x_clean, x_lace), compare(x_clean, x_nolace)]
elif metric == 'moc2':
result = [compare2(x_clean, x_opus), compare2(x_clean, x_lace), compare2(x_clean, x_nolace)]
# elif metric == 'warpq':
# result = [warpq(x_clean, x_opus), warpq(x_clean, x_lace), warpq(x_clean, x_nolace)]
elif metric == 'laceloss':
result = [laceloss_compare(x_clean, x_opus), laceloss_compare(x_clean, x_lace), laceloss_compare(x_clean, x_nolace)]
else:
raise ValueError(f'unknown metric {metric}')
return result
def process_bitrate(folder, items, bitrate, metric):
results = np.zeros((len(items), 3))
for i, item in enumerate(items):
results[i, :] = np.array(process_item(folder, item, bitrate, metric))
return results
if __name__ == "__main__":
args = parser.parse_args()
items = get_itemlist(args.folder)
bitrates = get_bitrates(args.folder)
results = dict()
for br in bitrates:
print(f"processing bitrate {br}...")
results[br] = process_bitrate(args.folder, items, br, args.metric)
np.save(os.path.join(args.folder, f'results_{args.metric}.npy'), results)
print("Done.")

View file

@@ -0,0 +1,330 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
"""STFT-based Loss modules."""
import torch
import torch.nn.functional as F
from torch import nn
import numpy as np
import torchaudio
def get_window(win_name, win_length, *args, **kwargs):
window_dict = {
'bartlett_window' : torch.bartlett_window,
'blackman_window' : torch.blackman_window,
'hamming_window' : torch.hamming_window,
'hann_window' : torch.hann_window,
'kaiser_window' : torch.kaiser_window
}
if win_name not in window_dict:
raise ValueError(f"unknown window type {win_name}")
return window_dict[win_name](win_length, *args, **kwargs)
def stft(x, fft_size, hop_size, win_length, window):
"""Perform STFT and convert to magnitude spectrogram.
Args:
x (Tensor): Input signal tensor (B, T).
fft_size (int): FFT size.
hop_size (int): Hop size.
win_length (int): Window length.
window (str): Window function type.
Returns:
Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
"""
win = get_window(window, win_length).to(x.device)
x_stft = torch.stft(x, fft_size, hop_size, win_length, win, return_complex=True)
return torch.clamp(torch.abs(x_stft), min=1e-7)
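A quick shape check of the helper above (the frame count assumes torch.stft's default center padding):

import torch
S = stft(torch.randn(2, 16000), 512, 256, 512, 'hann_window')
# S.shape == (2, 257, 63): (batch, fft_size // 2 + 1, frames)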
def spectral_convergence_loss(Y_true, Y_pred):
dims=list(range(1, len(Y_pred.shape)))
return torch.mean(torch.norm(torch.abs(Y_true) - torch.abs(Y_pred), p="fro", dim=dims) / (torch.norm(Y_pred, p="fro", dim=dims) + 1e-6))
def log_magnitude_loss(Y_true, Y_pred):
Y_true_log_abs = torch.log(torch.abs(Y_true) + 1e-15)
Y_pred_log_abs = torch.log(torch.abs(Y_pred) + 1e-15)
return torch.mean(torch.abs(Y_true_log_abs - Y_pred_log_abs))
def spectral_xcorr_loss(Y_true, Y_pred):
Y_true = Y_true.abs()
Y_pred = Y_pred.abs()
dims=list(range(1, len(Y_pred.shape)))
xcorr = torch.sum(Y_true * Y_pred, dim=dims) / torch.sqrt(torch.sum(Y_true ** 2, dim=dims) * torch.sum(Y_pred ** 2, dim=dims) + 1e-9)
return 1 - xcorr.mean()
class MRLogMelLoss(nn.Module):
def __init__(self,
fft_sizes=[512, 256, 128, 64],
overlap=0.5,
fs=16000,
n_mels=18
):
self.fft_sizes = fft_sizes
self.overlap = overlap
self.fs = fs
self.n_mels = n_mels
super().__init__()
self.mel_specs = []
for fft_size in fft_sizes:
hop_size = int(round(fft_size * (1 - self.overlap)))
n_mels = self.n_mels
if fft_size < 128:
n_mels //= 2
self.mel_specs.append(torchaudio.transforms.MelSpectrogram(fs, fft_size, hop_length=hop_size, n_mels=n_mels))
for i, mel_spec in enumerate(self.mel_specs):
self.add_module(f'mel_spec_{i+1}', mel_spec)
def forward(self, y_true, y_pred):
loss = torch.zeros(1, device=y_true.device)
for mel_spec in self.mel_specs:
Y_true = mel_spec(y_true)
Y_pred = mel_spec(y_pred)
loss = loss + log_magnitude_loss(Y_true, Y_pred)
loss = loss / len(self.mel_specs)
return loss
def create_weight_matrix(num_bins, bins_per_band=10):
m = torch.zeros((num_bins, num_bins), dtype=torch.float32)
r0 = bins_per_band // 2
r1 = bins_per_band - r0
for i in range(num_bins):
i0 = max(i - r0, 0)
j0 = min(i + r1, num_bins)
m[i, i0: j0] += 1
if i < r0:
m[i, :r0 - i] += 1
if i > num_bins - r1:
m[i, num_bins - r1 - i:] += 1
return m / bins_per_band
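A small worked example of the smoothing matrix, showing the edge handling; each row averages roughly bins_per_band neighboring bins and sums to 1:

m = create_weight_matrix(4, bins_per_band=2)
# [[1. , 0. , 0. , 0. ],
#  [0.5, 0.5, 0. , 0. ],
#  [0. , 0.5, 0.5, 0. ],
#  [0. , 0. , 0.5, 0.5]]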
def weighted_spectral_convergence(Y_true, Y_pred, w):
# calculate sfm based weights
logY = torch.log(torch.abs(Y_true) + 1e-9)
Y = torch.abs(Y_true)
avg_logY = torch.matmul(logY.transpose(1, 2), w)
avg_Y = torch.matmul(Y.transpose(1, 2), w)
sfm = torch.exp(avg_logY) / (avg_Y + 1e-9)
weight = (torch.relu(1 - sfm) ** .5).transpose(1, 2)
loss = torch.mean(
torch.mean(weight * torch.abs(torch.abs(Y_true) - torch.abs(Y_pred)), dim=[1, 2])
/ (torch.mean( weight * torch.abs(Y_true), dim=[1, 2]) + 1e-9)
)
return loss
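For reference, the weighting above is a local spectral flatness measure: with w averaging over neighboring bins, sfm = exp(mean_w(log |Y_true|)) / mean_w(|Y_true|), i.e. the ratio of a local geometric mean to a local arithmetic mean. It is close to 1 for flat, noise-like spectra and close to 0 for tonal ones, so weight = sqrt(relu(1 - sfm)) emphasizes errors in tonal regions.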
def gen_filterbank(N, Fs=16000):
in_freq = (np.arange(N+1, dtype='float32')/N*Fs/2)[None,:]
out_freq = (np.arange(N, dtype='float32')/N*Fs/2)[:,None]
#ERB from B.C.J Moore, An Introduction to the Psychology of Hearing, 5th Ed., page 73.
ERB_N = 24.7 + .108*in_freq
delta = np.abs(in_freq-out_freq)/ERB_N
center = (delta<.5).astype('float32')
R = -12*center*delta**2 + (1-center)*(3-12*delta)
RE = 10.**(R/10.)
norm = np.sum(RE, axis=1)
RE = RE/norm[:, np.newaxis]
return torch.from_numpy(RE)
def smooth_log_mag(Y_true, Y_pred, filterbank):
Y_true_smooth = torch.matmul(filterbank, torch.abs(Y_true))
Y_pred_smooth = torch.matmul(filterbank, torch.abs(Y_pred))
loss = torch.abs(
torch.log(Y_true_smooth + 1e-9) - torch.log(Y_pred_smooth + 1e-9)
)
loss = loss.mean()
return loss
class MRSTFTLoss(nn.Module):
def __init__(self,
fft_sizes=[2048, 1024, 512, 256, 128, 64],
overlap=0.5,
window='hann_window',
fs=16000,
log_mag_weight=0,
sc_weight=0,
wsc_weight=0,
smooth_log_mag_weight=2,
sxcorr_weight=1):
super().__init__()
self.fft_sizes = fft_sizes
self.overlap = overlap
self.window = window
self.log_mag_weight = log_mag_weight
self.sc_weight = sc_weight
self.wsc_weight = wsc_weight
self.smooth_log_mag_weight = smooth_log_mag_weight
self.sxcorr_weight = sxcorr_weight
self.fs = fs
# weights for SFM weighted spectral convergence loss
self.wsc_weights = torch.nn.ParameterDict()
for fft_size in fft_sizes:
width = min(11, int(1000 * fft_size / self.fs + .5))
width += width % 2
self.wsc_weights[str(fft_size)] = torch.nn.Parameter(
create_weight_matrix(fft_size // 2 + 1, width),
requires_grad=False
)
# filterbanks for smooth log magnitude loss
self.filterbanks = torch.nn.ParameterDict()
for fft_size in fft_sizes:
self.filterbanks[str(fft_size)] = torch.nn.Parameter(
gen_filterbank(fft_size//2),
requires_grad=False
)
def __call__(self, y_true, y_pred):
lm_loss = torch.zeros(1, device=y_true.device)
sc_loss = torch.zeros(1, device=y_true.device)
wsc_loss = torch.zeros(1, device=y_true.device)
slm_loss = torch.zeros(1, device=y_true.device)
sxcorr_loss = torch.zeros(1, device=y_true.device)
for fft_size in self.fft_sizes:
hop_size = int(round(fft_size * (1 - self.overlap)))
win_size = fft_size
Y_true = stft(y_true, fft_size, hop_size, win_size, self.window)
Y_pred = stft(y_pred, fft_size, hop_size, win_size, self.window)
if self.log_mag_weight > 0:
lm_loss = lm_loss + log_magnitude_loss(Y_true, Y_pred)
if self.sc_weight > 0:
sc_loss = sc_loss + spectral_convergence_loss(Y_true, Y_pred)
if self.wsc_weight > 0:
wsc_loss = wsc_loss + weighted_spectral_convergence(Y_true, Y_pred, self.wsc_weights[str(fft_size)])
if self.smooth_log_mag_weight > 0:
slm_loss = slm_loss + smooth_log_mag(Y_true, Y_pred, self.filterbanks[str(fft_size)])
if self.sxcorr_weight > 0:
sxcorr_loss = sxcorr_loss + spectral_xcorr_loss(Y_true, Y_pred)
total_loss = (self.log_mag_weight * lm_loss + self.sc_weight * sc_loss
+ self.wsc_weight * wsc_loss + self.smooth_log_mag_weight * slm_loss
+ self.sxcorr_weight * sxcorr_loss) / len(self.fft_sizes)
return total_loss
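A minimal usage sketch of the loss above with its default weights (smooth log-magnitude plus spectral cross-correlation terms); shapes follow the (B, T) convention of stft:

import torch

criterion = MRSTFTLoss()
y_ref = torch.randn(4, 16000)   # batch of 1 s signals at 16 kHz
y_hat = torch.randn(4, 16000)
loss = criterion(y_ref, y_hat)  # averaged over the six FFT sizes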
def td_l2_norm(y_true, y_pred):
dims = list(range(1, len(y_true.shape)))
loss = torch.mean((y_true - y_pred) ** 2, dim=dims) / (torch.mean(y_pred ** 2, dim=dims) ** .5 + 1e-6)
return loss.mean()
class LaceLoss(nn.Module):
def __init__(self):
super().__init__()
self.stftloss = MRSTFTLoss(log_mag_weight=0, sc_weight=0, wsc_weight=0, smooth_log_mag_weight=2, sxcorr_weight=1)
def forward(self, x, y):
specloss = self.stftloss(x, y)
phaseloss = td_l2_norm(x, y)
total_loss = (specloss + 10 * phaseloss) / 13
return total_loss
def compare(self, x_ref, x_deg):
# trim items to same size
n = min(len(x_ref), len(x_deg))
x_ref = x_ref[:n].copy()
x_deg = x_deg[:n].copy()
# pre-emphasis
x_ref[1:] -= 0.85 * x_ref[:-1]
x_deg[1:] -= 0.85 * x_deg[:-1]
device = next(iter(self.parameters())).device
x = torch.from_numpy(x_ref).to(device)
y = torch.from_numpy(x_deg).to(device)
with torch.no_grad():
dist = 10 * self.forward(x, y)
return dist.cpu().numpy().item()
lace_loss = LaceLoss()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
lace_loss.to(device)
def compare(x, y):
return lace_loss.compare(x, y)
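Usage sketch for the module-level wrapper; file names are hypothetical, and inputs are expected to be 16 kHz mono scaled to [-1, 1]:

import numpy as np
from scipy.io import wavfile

fs, x_ref = wavfile.read('ref.wav')   # hypothetical paths
fs, x_deg = wavfile.read('deg.wav')
x_ref = x_ref.astype(np.float32) / 2**15
x_deg = x_deg.astype(np.float32) / 2**15
print(f"laceloss: {compare(x_ref, x_deg):.3f}")  # lower is better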

View file

@@ -0,0 +1,116 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
import argparse
import numpy as np
import matplotlib.pyplot as plt
from prettytable import PrettyTable
from matplotlib.patches import Patch
parser = argparse.ArgumentParser()
parser.add_argument('folder', type=str, help='path to folder with pre-calculated metrics')
parser.add_argument('--metric', choices=['pesq', 'moc', 'warpq', 'nomad', 'laceloss', 'all'], default='all', help='default: all')
parser.add_argument('--output', type=str, default=None, help='alternative output folder, default: folder')
def load_data(folder):
data = dict()
if os.path.isfile(os.path.join(folder, 'results_moc.npy')):
data['moc'] = np.load(os.path.join(folder, 'results_moc.npy'), allow_pickle=True).item()
if os.path.isfile(os.path.join(folder, 'results_moc2.npy')):
data['moc2'] = np.load(os.path.join(folder, 'results_moc2.npy'), allow_pickle=True).item()
if os.path.isfile(os.path.join(folder, 'results_pesq.npy')):
data['pesq'] = np.load(os.path.join(folder, 'results_pesq.npy'), allow_pickle=True).item()
if os.path.isfile(os.path.join(folder, 'results_warpq.npy')):
data['warpq'] = np.load(os.path.join(folder, 'results_warpq.npy'), allow_pickle=True).item()
if os.path.isfile(os.path.join(folder, 'results_nomad.npy')):
data['nomad'] = np.load(os.path.join(folder, 'results_nomad.npy'), allow_pickle=True).item()
if os.path.isfile(os.path.join(folder, 'results_laceloss.npy')):
data['laceloss'] = np.load(os.path.join(folder, 'results_laceloss.npy'), allow_pickle=True).item()
return data
def plot_data(filename, data, title=None):
compare_dict = dict()
for br in data.keys():
compare_dict[f'Opus {br/1000:.1f} kb/s'] = data[br][:, 0]
compare_dict[f'LACE {br/1000:.1f} kb/s'] = data[br][:, 1]
compare_dict[f'NoLACE {br/1000:.1f} kb/s'] = data[br][:, 2]
plt.rcParams.update({
"text.usetex": True,
"font.family": "Helvetica",
"font.size": 32
})
black = '#000000'
red = '#ff5745'
blue = '#007dbc'
colors = [black, red, blue]
legend_elements = [Patch(facecolor=colors[0], label='Opus SILK'),
Patch(facecolor=colors[1], label='LACE'),
Patch(facecolor=colors[2], label='NoLACE')]
fig, ax = plt.subplots()
fig.set_size_inches(40, 20)
bplot = ax.boxplot(compare_dict.values(), showfliers=False, notch=True, patch_artist=True)
for i, patch in enumerate(bplot['boxes']):
patch.set_facecolor(colors[i%3])
ax.set_xticklabels(compare_dict.keys(), rotation=290)
if title is not None:
ax.set_title(title)
ax.legend(handles=legend_elements)
fig.savefig(filename, bbox_inches='tight')
if __name__ == "__main__":
args = parser.parse_args()
data = load_data(args.folder)
metrics = list(data.keys()) if args.metric == 'all' else [args.metric]
folder = args.folder if args.output is None else args.output
os.makedirs(folder, exist_ok=True)
for metric in metrics:
print(f"Plotting data for {metric} metric...")
plot_data(os.path.join(folder, f"boxplot_{metric}.png"), data[metric], title=metric.upper())
print("Done.")

View file

@@ -0,0 +1,109 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
import argparse
import numpy as np
import matplotlib.pyplot as plt
from prettytable import PrettyTable
from matplotlib.patches import Patch
parser = argparse.ArgumentParser()
parser.add_argument('folder', type=str, help='path to folder with pre-calculated metrics')
parser.add_argument('--metric', choices=['pesq', 'moc', 'warpq', 'nomad', 'laceloss', 'all'], default='all', help='default: all')
parser.add_argument('--output', type=str, default=None, help='alternative output folder, default: folder')
def load_data(folder):
data = dict()
if os.path.isfile(os.path.join(folder, 'results_moc.npy')):
data['moc'] = np.load(os.path.join(folder, 'results_moc.npy'), allow_pickle=True).item()
if os.path.isfile(os.path.join(folder, 'results_pesq.npy')):
data['pesq'] = np.load(os.path.join(folder, 'results_pesq.npy'), allow_pickle=True).item()
if os.path.isfile(os.path.join(folder, 'results_warpq.npy')):
data['warpq'] = np.load(os.path.join(folder, 'results_warpq.npy'), allow_pickle=True).item()
if os.path.isfile(os.path.join(folder, 'results_nomad.npy')):
data['nomad'] = np.load(os.path.join(folder, 'results_nomad.npy'), allow_pickle=True).item()
if os.path.isfile(os.path.join(folder, 'results_laceloss.npy')):
data['laceloss'] = np.load(os.path.join(folder, 'results_laceloss.npy'), allow_pickle=True).item()
return data
def plot_data(filename, data, title=None):
compare_dict = dict()
for br in data.keys():
compare_dict[f'Opus {br/1000:.1f} kb/s'] = data[br][:, 0]
compare_dict[f'LACE (MOC only) {br/1000:.1f} kb/s'] = data[br][:, 1]
compare_dict[f'LACE (MOC + TD) {br/1000:.1f} kb/s'] = data[br][:, 2]
plt.rcParams.update({
"text.usetex": True,
"font.family": "Helvetica",
"font.size": 32
})
colors = ['pink', 'lightblue', 'lightgreen']
legend_elements = [Patch(facecolor=colors[0], label='Opus SILK'),
Patch(facecolor=colors[1], label='MOC loss only'),
Patch(facecolor=colors[2], label='MOC + TD loss')]
fig, ax = plt.subplots()
fig.set_size_inches(40, 20)
bplot = ax.boxplot(compare_dict.values(), showfliers=False, notch=True, patch_artist=True)
for i, patch in enumerate(bplot['boxes']):
patch.set_facecolor(colors[i%3])
ax.set_xticklabels(compare_dict.keys(), rotation=290)
if title is not None:
ax.set_title(title)
ax.legend(handles=legend_elements)
fig.savefig(filename, bbox_inches='tight')
if __name__ == "__main__":
args = parser.parse_args()
data = load_data(args.folder)
metrics = list(data.keys()) if args.metric == 'all' else [args.metric]
folder = args.folder if args.output is None else args.output
os.makedirs(folder, exist_ok=True)
for metric in metrics:
print(f"Plotting data for {metric} metric...")
plot_data(os.path.join(folder, f"boxplot_{metric}.png"), data[metric], title=metric.upper())
print("Done.")

View file

@@ -0,0 +1,124 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
import argparse
import numpy as np
import matplotlib.pyplot as plt
from prettytable import PrettyTable
from matplotlib.patches import Patch
parser = argparse.ArgumentParser()
parser.add_argument('folder', type=str, help='path to folder with pre-calculated metrics')
parser.add_argument('--metric', choices=['pesq', 'moc', 'warpq', 'nomad', 'laceloss', 'all'], default='all', help='default: all')
parser.add_argument('--output', type=str, default=None, help='alternative output folder, default: folder')
def load_data(folder):
data = dict()
if os.path.isfile(os.path.join(folder, 'results_moc.npy')):
data['moc'] = np.load(os.path.join(folder, 'results_moc.npy'), allow_pickle=True).item()
if os.path.isfile(os.path.join(folder, 'results_moc2.npy')):
data['moc2'] = np.load(os.path.join(folder, 'results_moc2.npy'), allow_pickle=True).item()
if os.path.isfile(os.path.join(folder, 'results_pesq.npy')):
data['pesq'] = np.load(os.path.join(folder, 'results_pesq.npy'), allow_pickle=True).item()
if os.path.isfile(os.path.join(folder, 'results_warpq.npy')):
data['warpq'] = np.load(os.path.join(folder, 'results_warpq.npy'), allow_pickle=True).item()
if os.path.isfile(os.path.join(folder, 'results_nomad.npy')):
data['nomad'] = np.load(os.path.join(folder, 'results_nomad.npy'), allow_pickle=True).item()
if os.path.isfile(os.path.join(folder, 'results_laceloss.npy')):
data['laceloss'] = np.load(os.path.join(folder, 'results_laceloss.npy'), allow_pickle=True).item()
return data
def make_table(filename, data, title=None):
# mean values
tbl = PrettyTable()
tbl.field_names = ['bitrate (bps)', 'Opus', 'LACE', 'NoLACE']
for br in data.keys():
opus = data[br][:, 0]
lace = data[br][:, 1]
nolace = data[br][:, 2]
tbl.add_row([br, f"{float(opus.mean()):.3f} ({float(opus.std()):.2f})", f"{float(lace.mean()):.3f} ({float(lace.std()):.2f})", f"{float(nolace.mean()):.3f} ({float(nolace.std()):.2f})"])
with open(filename + ".txt", "w") as f:
f.write(str(tbl))
with open(filename + ".html", "w") as f:
f.write(tbl.get_html_string())
with open(filename + ".csv", "w") as f:
f.write(tbl.get_csv_string())
print(tbl)
def make_diff_table(filename, data, title=None):
# mean values
tbl = PrettyTable()
tbl.field_names = ['bitrate (bps)', 'LACE - Opus', 'NoLACE - Opus']
for br in data.keys():
opus = data[br][:, 0]
lace = data[br][:, 1] - opus
nolace = data[br][:, 2] - opus
tbl.add_row([br, f"{float(lace.mean()):.3f} ({float(lace.std()):.2f})", f"{float(nolace.mean()):.3f} ({float(nolace.std()):.2f})"])
with open(filename + ".txt", "w") as f:
f.write(str(tbl))
with open(filename + ".html", "w") as f:
f.write(tbl.get_html_string())
with open(filename + ".csv", "w") as f:
f.write(tbl.get_csv_string())
print(tbl)
if __name__ == "__main__":
args = parser.parse_args()
data = load_data(args.folder)
metrics = list(data.keys()) if args.metric == 'all' else [args.metric]
folder = args.folder if args.output is None else args.output
os.makedirs(folder, exist_ok=True)
for metric in metrics:
print(f"Plotting data for {metric} metric...")
make_table(os.path.join(folder, f"table_{metric}"), data[metric])
make_diff_table(os.path.join(folder, f"table_diff_{metric}"), data[metric])
print("Done.")

View file

@@ -0,0 +1,121 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
import argparse
import numpy as np
import matplotlib.pyplot as plt
from prettytable import PrettyTable
from matplotlib.patches import Patch
parser = argparse.ArgumentParser()
parser.add_argument('folder', type=str, help='path to folder with pre-calculated metrics')
parser.add_argument('--metric', choices=['pesq', 'moc', 'warpq', 'nomad', 'laceloss', 'all'], default='all', help='default: all')
parser.add_argument('--output', type=str, default=None, help='alternative output folder, default: folder')
def load_data(folder):
data = dict()
if os.path.isfile(os.path.join(folder, 'results_moc.npy')):
data['moc'] = np.load(os.path.join(folder, 'results_moc.npy'), allow_pickle=True).item()
if os.path.isfile(os.path.join(folder, 'results_pesq.npy')):
data['pesq'] = np.load(os.path.join(folder, 'results_pesq.npy'), allow_pickle=True).item()
if os.path.isfile(os.path.join(folder, 'results_warpq.npy')):
data['warpq'] = np.load(os.path.join(folder, 'results_warpq.npy'), allow_pickle=True).item()
if os.path.isfile(os.path.join(folder, 'results_nomad.npy')):
data['nomad'] = np.load(os.path.join(folder, 'results_nomad.npy'), allow_pickle=True).item()
if os.path.isfile(os.path.join(folder, 'results_laceloss.npy')):
data['laceloss'] = np.load(os.path.join(folder, 'results_laceloss.npy'), allow_pickle=True).item()
return data
def make_table(filename, data, title=None):
# mean values
tbl = PrettyTable()
tbl.field_names = ['bitrate (bps)', 'Opus', 'LACE', 'NoLACE']
for br in data.keys():
opus = data[br][:, 0]
lace = data[br][:, 1]
nolace = data[br][:, 2]
tbl.add_row([br, f"{float(opus.mean()):.3f} ({float(opus.std()):.2f})", f"{float(lace.mean()):.3f} ({float(lace.std()):.2f})", f"{float(nolace.mean()):.3f} ({float(nolace.std()):.2f})"])
with open(filename + ".txt", "w") as f:
f.write(str(tbl))
with open(filename + ".html", "w") as f:
f.write(tbl.get_html_string())
with open(filename + ".csv", "w") as f:
f.write(tbl.get_csv_string())
print(tbl)
def make_diff_table(filename, data, title=None):
# mean values
tbl = PrettyTable()
tbl.field_names = ['bitrate (bps)', 'LACE - Opus', 'NoLACE - Opus']
for br in data.keys():
opus = data[br][:, 0]
lace = data[br][:, 1] - opus
nolace = data[br][:, 2] - opus
tbl.add_row([br, f"{float(lace.mean()):.3f} ({float(lace.std()):.2f})", f"{float(nolace.mean()):.3f} ({float(nolace.std()):.2f})"])
with open(filename + ".txt", "w") as f:
f.write(str(tbl))
with open(filename + ".html", "w") as f:
f.write(tbl.get_html_string())
with open(filename + ".csv", "w") as f:
f.write(tbl.get_csv_string())
print(tbl)
if __name__ == "__main__":
args = parser.parse_args()
data = load_data(args.folder)
metrics = list(data.keys()) if args.metric == 'all' else [args.metric]
folder = args.folder if args.output is None else args.output
os.makedirs(folder, exist_ok=True)
for metric in metrics:
print(f"Plotting data for {metric} metric...")
make_table(os.path.join(folder, f"table_{metric}"), data[metric])
make_diff_table(os.path.join(folder, f"table_diff_{metric}"), data[metric])
print("Done.")

View file

@@ -0,0 +1,182 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import numpy as np
import scipy.signal
def compute_vad_mask(x, fs, stop_db=-70):
frame_length = (fs + 49) // 50
x = x[: frame_length * (len(x) // frame_length)]
frames = x.reshape(-1, frame_length)
frame_energy = np.sum(frames ** 2, axis=1)
frame_energy_smooth = np.convolve(frame_energy, np.ones(5) / 5, mode='same')
max_threshold = frame_energy.max() * 10 ** (stop_db/20)
vactive = np.ones_like(frames)
vactive[frame_energy_smooth < max_threshold, :] = 0
vactive = vactive.reshape(-1)
filter = np.sin(np.arange(frame_length) * np.pi / (frame_length - 1))
filter = filter / filter.sum()
mask = np.convolve(vactive, filter, mode='same')
return x, mask
def convert_mask(mask, num_frames, frame_size=160, hop_size=40):
num_samples = frame_size + (num_frames - 1) * hop_size
if len(mask) < num_samples:
mask = np.concatenate((mask, np.zeros(num_samples - len(mask))), dtype=mask.dtype)
else:
mask = mask[:num_samples]
new_mask = np.array([np.mean(mask[i*hop_size : i*hop_size + frame_size]) for i in range(num_frames)])
return new_mask
def power_spectrum(x, window_size=160, hop_size=40, window='hamming'):
num_spectra = (len(x) - window_size - hop_size) // hop_size
window = scipy.signal.get_window(window, window_size)
N = window_size // 2
frames = np.concatenate([x[np.newaxis, i * hop_size : i * hop_size + window_size] for i in range(num_spectra)]) * window
psd = np.abs(np.fft.fft(frames, axis=1)[:, :N + 1]) ** 2
return psd
def frequency_mask(num_bands, up_factor, down_factor):
up_mask = np.zeros((num_bands, num_bands))
down_mask = np.zeros((num_bands, num_bands))
for i in range(num_bands):
up_mask[i, : i + 1] = up_factor ** np.arange(i, -1, -1)
down_mask[i, i :] = down_factor ** np.arange(num_bands - i)
return down_mask @ up_mask
def rect_fb(band_limits, num_bins=None):
num_bands = len(band_limits) - 1
if num_bins is None:
num_bins = band_limits[-1]
fb = np.zeros((num_bands, num_bins))
for i in range(num_bands):
fb[i, band_limits[i]:band_limits[i+1]] = 1
return fb
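For example, three rectangular bands over 7 bins:

fb = rect_fb([0, 2, 4, 7])
# [[1., 1., 0., 0., 0., 0., 0.],
#  [0., 0., 1., 1., 0., 0., 0.],
#  [0., 0., 0., 0., 1., 1., 1.]]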
def compare(x, y, apply_vad=False):
""" Modified version of opus_compare for 16 kHz mono signals
Args:
x (np.ndarray): reference input signal scaled to [-1, 1]
y (np.ndarray): test signal scaled to [-1, 1]
Returns:
float: perceptually weighted error
"""
# filter bank: bark scale with minimum-2-bin bands and cutoff at 7.5 kHz
band_limits = [0, 2, 4, 6, 7, 9, 11, 13, 15, 18, 22, 26, 31, 36, 43, 51, 60, 75]
num_bands = len(band_limits) - 1
fb = rect_fb(band_limits, num_bins=81)
# trim samples to same size
num_samples = min(len(x), len(y))
x = x[:num_samples] * 2**15
y = y[:num_samples] * 2**15
psd_x = power_spectrum(x) + 100000
psd_y = power_spectrum(y) + 100000
num_frames = psd_x.shape[0]
# average band energies
be_x = (psd_x @ fb.T) / np.sum(fb, axis=1)
# frequency masking
f_mask = frequency_mask(num_bands, 0.1, 0.03)
mask_x = be_x @ f_mask.T
# temporal masking
for i in range(1, num_frames):
mask_x[i, :] += 0.5 * mask_x[i-1, :]
# apply mask
masked_psd_x = psd_x + 0.1 * (mask_x @ fb)
masked_psd_y = psd_y + 0.1 * (mask_x @ fb)
# 2-frame average
masked_psd_x = masked_psd_x[1:] + masked_psd_x[:-1]
masked_psd_y = masked_psd_y[1:] + masked_psd_y[:-1]
# distortion metric
re = masked_psd_y / masked_psd_x
im = np.log(re) ** 2
Eb = ((im @ fb.T) / np.sum(fb, axis=1))
Ef = np.mean(Eb, axis=1)
if apply_vad:
_, mask = compute_vad_mask(x, 16000)
mask = convert_mask(mask, Ef.shape[0])
else:
mask = np.ones_like(Ef)
err = np.mean(np.abs(Ef[mask > 1e-6]) ** 3) ** (1/6)
return float(err)
if __name__ == "__main__":
import argparse
from scipy.io import wavfile
parser = argparse.ArgumentParser()
parser.add_argument('ref', type=str, help='reference wav file')
parser.add_argument('deg', type=str, help='degraded wav file')
parser.add_argument('--apply-vad', action='store_true')
args = parser.parse_args()
fs1, x = wavfile.read(args.ref)
fs2, y = wavfile.read(args.deg)
if fs1 != 16000 or fs2 != 16000:
raise ValueError('error: encountered sampling frequency different from 16 kHz')
x = x.astype(np.float32) / 2**15
y = y.astype(np.float32) / 2**15
err = compare(x, y, apply_vad=args.apply_vad)
print(f"MOC: {err}")

View file

@@ -0,0 +1,190 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import numpy as np
import scipy.signal
def compute_vad_mask(x, fs, stop_db=-70):
frame_length = (fs + 49) // 50
x = x[: frame_length * (len(x) // frame_length)]
frames = x.reshape(-1, frame_length)
frame_energy = np.sum(frames ** 2, axis=1)
frame_energy_smooth = np.convolve(frame_energy, np.ones(5) / 5, mode='same')
max_threshold = frame_energy.max() * 10 ** (stop_db/20)
vactive = np.ones_like(frames)
vactive[frame_energy_smooth < max_threshold, :] = 0
vactive = vactive.reshape(-1)
filter = np.sin(np.arange(frame_length) * np.pi / (frame_length - 1))
filter = filter / filter.sum()
mask = np.convolve(vactive, filter, mode='same')
return x, mask
def convert_mask(mask, num_frames, frame_size=160, hop_size=40):
num_samples = frame_size + (num_frames - 1) * hop_size
if len(mask) < num_samples:
mask = np.concatenate((mask, np.zeros(num_samples - len(mask))), dtype=mask.dtype)
else:
mask = mask[:num_samples]
new_mask = np.array([np.mean(mask[i*hop_size : i*hop_size + frame_size]) for i in range(num_frames)])
return new_mask
def power_spectrum(x, window_size=160, hop_size=40, window='hamming'):
num_spectra = (len(x) - window_size - hop_size) // hop_size
window = scipy.signal.get_window(window, window_size)
N = window_size // 2
frames = np.concatenate([x[np.newaxis, i * hop_size : i * hop_size + window_size] for i in range(num_spectra)]) * window
psd = np.abs(np.fft.fft(frames, axis=1)[:, :N + 1]) ** 2
return psd
def frequency_mask(num_bands, up_factor, down_factor):
up_mask = np.zeros((num_bands, num_bands))
down_mask = np.zeros((num_bands, num_bands))
for i in range(num_bands):
up_mask[i, : i + 1] = up_factor ** np.arange(i, -1, -1)
down_mask[i, i :] = down_factor ** np.arange(num_bands - i)
return down_mask @ up_mask
def rect_fb(band_limits, num_bins=None):
num_bands = len(band_limits) - 1
if num_bins is None:
num_bins = band_limits[-1]
fb = np.zeros((num_bands, num_bins))
for i in range(num_bands):
fb[i, band_limits[i]:band_limits[i+1]] = 1
return fb
def _compare(x, y, apply_vad=False, factor=1):
""" Modified version of opus_compare for 16 kHz mono signals
Args:
x (np.ndarray): reference input signal scaled to [-1, 1]
y (np.ndarray): test signal scaled to [-1, 1]
Returns:
float: perceptually weighted error
"""
# filter bank: bark scale with minimum-2-bin bands and cutoff at 7.5 kHz
band_limits = [factor * b for b in [0, 2, 4, 6, 7, 9, 11, 13, 15, 18, 22, 26, 31, 36, 43, 51, 60, 75]]
window_size = factor * 160
hop_size = factor * 40
num_bins = window_size // 2 + 1
num_bands = len(band_limits) - 1
fb = rect_fb(band_limits, num_bins=num_bins)
# trim samples to same size
num_samples = min(len(x), len(y))
x = x[:num_samples].copy() * 2**15
y = y[:num_samples].copy() * 2**15
psd_x = power_spectrum(x, window_size=window_size, hop_size=hop_size) + 100000
psd_y = power_spectrum(y, window_size=window_size, hop_size=hop_size) + 100000
num_frames = psd_x.shape[0]
# average band energies
be_x = (psd_x @ fb.T) / np.sum(fb, axis=1)
# frequency masking
f_mask = frequency_mask(num_bands, 0.1, 0.03)
mask_x = be_x @ f_mask.T
# temporal masking
for i in range(1, num_frames):
mask_x[i, :] += (0.5 ** factor) * mask_x[i-1, :]
# apply mask
masked_psd_x = psd_x + 0.1 * (mask_x @ fb)
masked_psd_y = psd_y + 0.1 * (mask_x @ fb)
# 2-frame average
masked_psd_x = masked_psd_x[1:] + masked_psd_x[:-1]
masked_psd_y = masked_psd_y[1:] + masked_psd_y[:-1]
# distortion metric
re = masked_psd_y / masked_psd_x
#im = re - np.log(re) - 1
im = np.log(re) ** 2
Eb = ((im @ fb.T) / np.sum(fb, axis=1))
Ef = np.mean(Eb ** 1, axis=1)
if apply_vad:
_, mask = compute_vad_mask(x, 16000)
mask = convert_mask(mask, Ef.shape[0])
else:
mask = np.ones_like(Ef)
err = np.mean(np.abs(Ef[mask > 1e-6]) ** 3) ** (1/6)
return float(err)
def compare(x, y, apply_vad=False):
err = np.linalg.norm([_compare(x, y, apply_vad=apply_vad, factor=1)], ord=2)
return err
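# Programmatic use (mirrors the CLI entry point below):
#   err = compare(ref, deg)  # ref, deg: float arrays scaled to [-1, 1] at 16 kHz
# Lower values indicate less perceptually weighted distortion.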
if __name__ == "__main__":
import argparse
from scipy.io import wavfile
parser = argparse.ArgumentParser()
parser.add_argument('ref', type=str, help='reference wav file')
parser.add_argument('deg', type=str, help='degraded wav file')
parser.add_argument('--apply-vad', action='store_true')
args = parser.parse_args()
fs1, x = wavfile.read(args.ref)
fs2, y = wavfile.read(args.deg)
if fs1 != 16000 or fs2 != 16000:
raise ValueError('error: encountered sampling frequency different from 16 kHz')
x = x.astype(np.float32) / 2**15
y = y.astype(np.float32) / 2**15
err = compare(x, y, apply_vad=args.apply_vad)
print(f"MOC: {err}")

View file

@ -0,0 +1,98 @@
#!/bin/bash
if [ ! -f "$PYTHON" ]
then
echo "PYTHON variable does not link to a file. Please point it to your python executable."
exit 1
fi
if [ ! -f "$TESTMODEL" ]
then
echo "TESTMODEL variable does not link to a file. Please point it to your copy of test_model.py"
exit 1
fi
if [ ! -f "$OPUSDEMO" ]
then
echo "OPUSDEMO variable does not link to a file. Please point it to your patched version of opus_demo."
exit 1
fi
if [ ! -f "$LACE" ]
then
echo "LACE variable does not link to a file. Please point it to your copy of the LACE checkpoint."
exit 1
fi
if [ ! -f "$NOLACE" ]
then
echo "LACE variable does not link to a file. Please point it to your copy of the NOLACE checkpoint."
exit 1
fi
case $# in
2) INPUT=$1; OUTPUT=$2;;
*) echo "process_dataset.sh <input folder> <output folder>"; exit 1;;
esac
if [ -d $OUTPUT ]
then
echo "output folder $OUTPUT exists, aborting..."
exit 1
fi
mkdir -p $OUTPUT
if [ "$BITRATES" == "" ]
then
BITRATES=( 6000 7500 9000 12000 15000 18000 24000 32000 )
echo "BITRATES variable not defined. Proceeding with default bitrates ${BITRATES[@]}."
fi
echo "LACE=${LACE}" > ${OUTPUT}/info.txt
echo "NOLACE=${NOLACE}" >> ${OUTPUT}/info.txt
ITEMFILE=${OUTPUT}/items.txt
BITRATEFILE=${OUTPUT}/bitrates.txt
FPROCESSING=${OUTPUT}/processing
FCLEAN=${OUTPUT}/clean
FOPUS=${OUTPUT}/opus
FLACE=${OUTPUT}/lace
FNOLACE=${OUTPUT}/nolace
mkdir -p $FPROCESSING $FCLEAN $FOPUS $FLACE $FNOLACE
echo "${BITRATES[@]}" > $BITRATEFILE
for fn in $(find $INPUT -type f -name "*.wav")
do
UUID=$(uuid)
echo "$UUID $fn" >> $ITEMFILE
PIDS=( )
for br in ${BITRATES[@]}
do
# run opus
pfolder=${FPROCESSING}/${UUID}_${br}
mkdir -p $pfolder
sox $fn -c 1 -r 16000 -b 16 -e signed-integer $pfolder/clean.s16
(cd ${pfolder} && $OPUSDEMO voip 16000 1 $br clean.s16 noisy.s16)
# copy clean and opus
sox -c 1 -r 16000 -b 16 -e signed-integer $pfolder/clean.s16 $FCLEAN/${UUID}_${br}_clean.wav
sox -c 1 -r 16000 -b 16 -e signed-integer $pfolder/noisy.s16 $FOPUS/${UUID}_${br}_opus.wav
# run LACE
$PYTHON $TESTMODEL $pfolder $LACE $FLACE/${UUID}_${br}_lace.wav &
PIDS+=( "$!" )
# run NoLACE
$PYTHON $TESTMODEL $pfolder $NOLACE $FNOLACE/${UUID}_${br}_nolace.wav &
PIDS+=( "$!" )
done
for pid in ${PIDS[@]}
do
wait $pid
done
done
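# Example invocation (illustrative; all paths below are placeholders):
#
#   PYTHON=/usr/bin/python3 \
#   TESTMODEL=/path/to/test_model.py \
#   OPUSDEMO=/path/to/patched/opus_demo \
#   LACE=/path/to/lace_checkpoint.pth \
#   NOLACE=/path/to/nolace_checkpoint.pth \
#   ./process_dataset.sh input_folder output_folder
#
# The output folder will contain items.txt, bitrates.txt and the clean/, opus/,
# lace/ and nolace/ subfolders consumed by the evaluation scripts.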

View file

@ -0,0 +1,138 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
import argparse
import tempfile
import shutil
import pandas as pd
from scipy.spatial.distance import cdist
from scipy.io import wavfile
import numpy as np
from nomad_audio.nomad import Nomad
parser = argparse.ArgumentParser()
parser.add_argument('folder', type=str, help='folder with processed items')
parser.add_argument('--full-reference', action='store_true', help='use NOMAD as full-reference metric')
parser.add_argument('--device', type=str, default=None, help='device for Nomad')
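# Expected input layout (as produced by process_dataset.sh; shown for clarity):
#   <folder>/bitrates.txt                       space-separated bitrates
#   <folder>/items.txt                          one "<uuid> <source path>" per line
#   <folder>/{clean,opus,lace,nolace}/<uuid>_<bitrate>_<cond>.wav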
def get_bitrates(folder):
with open(os.path.join(folder, 'bitrates.txt')) as f:
x = f.read()
bitrates = [int(y) for y in x.rstrip('\n').split()]
return bitrates
def get_itemlist(folder):
with open(os.path.join(folder, 'items.txt')) as f:
lines = f.readlines()
items = [x.split()[0] for x in lines]
return items
def nomad_wrapper(ref_folder, deg_folder, full_reference=False, ref_embeddings=None, device=None):
model = Nomad(device=device)
if not full_reference:
results = model.predict(nmr=ref_folder, deg=deg_folder)[0].to_dict()['NOMAD']
return results, None
else:
if ref_embeddings is None:
print(f"Computing reference embeddings from {ref_folder}")
ref_data = pd.DataFrame(sorted(os.listdir(ref_folder)))
ref_data.columns = ['filename']
ref_data['filename'] = [os.path.join(ref_folder, x) for x in ref_data['filename']]
ref_embeddings = model.get_embeddings_csv(model.model, ref_data).set_index('filename')
print(f"Computing degraded embeddings from {deg_folder}")
deg_data = pd.DataFrame(sorted(os.listdir(deg_folder)))
deg_data.columns = ['filename']
deg_data['filename'] = [os.path.join(deg_folder, x) for x in deg_data['filename']]
deg_embeddings = model.get_embeddings_csv(model.model, deg_data).set_index('filename')
dist = np.diag(cdist(ref_embeddings, deg_embeddings)) # wasteful
test_files = [x.split('/')[-1].split('.')[0] for x in deg_embeddings.index]
results = dict(zip(test_files, dist))
return results, ref_embeddings
def nomad_process_all(folder, full_reference=False, device=None):
bitrates = get_bitrates(folder)
items = get_itemlist(folder)
with tempfile.TemporaryDirectory() as tmpdir:
cleandir = os.path.join(tmpdir, 'clean')
opusdir = os.path.join(tmpdir, 'opus')
lacedir = os.path.join(tmpdir, 'lace')
nolacedir = os.path.join(tmpdir, 'nolace')
# prepare files
for d in [cleandir, opusdir, lacedir, nolacedir]: os.makedirs(d)
for br in bitrates:
for item in items:
for cond in ['clean', 'opus', 'lace', 'nolace']:
shutil.copyfile(os.path.join(folder, cond, f"{item}_{br}_{cond}.wav"), os.path.join(tmpdir, cond, f"{item}_{br}.wav"))
nomad_opus, ref_embeddings = nomad_wrapper(cleandir, opusdir, full_reference=full_reference, ref_embeddings=None)
nomad_lace, ref_embeddings = nomad_wrapper(cleandir, lacedir, full_reference=full_reference, ref_embeddings=ref_embeddings)
nomad_nolace, ref_embeddings = nomad_wrapper(cleandir, nolacedir, full_reference=full_reference, ref_embeddings=ref_embeddings)
results = dict()
for br in bitrates:
results[br] = np.zeros((len(items), 3))
for i, item in enumerate(items):
key = f"{item}_{br}"
results[br][i, 0] = nomad_opus[key]
results[br][i, 1] = nomad_lace[key]
results[br][i, 2] = nomad_nolace[key]
return results
if __name__ == "__main__":
args = parser.parse_args()
items = get_itemlist(args.folder)
bitrates = get_bitrates(args.folder)
results = nomad_process_all(args.folder, full_reference=args.full_reference, device=args.device)
np.save(os.path.join(args.folder, 'results_nomad.npy'), results)
print("Done.")

View file

@ -0,0 +1,205 @@
""" module for inspecting models during inference """
import os
import yaml
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import torch
import numpy as np
# stores entries {key : {'fid' : fid, 'fs' : fs, 'dim' : dim, 'dtype' : dtype}}
_state = dict()
_folder = 'endoscopy'
def get_gru_gates(gru, input, state):
hidden_size = gru.hidden_size
direct = torch.matmul(gru.weight_ih_l0, input.squeeze())
recurrent = torch.matmul(gru.weight_hh_l0, state.squeeze())
# reset gate
start, stop = 0 * hidden_size, 1 * hidden_size
reset_gate = torch.sigmoid(direct[start : stop] + gru.bias_ih_l0[start : stop] + recurrent[start : stop] + gru.bias_hh_l0[start : stop])
# update gate
start, stop = 1 * hidden_size, 2 * hidden_size
update_gate = torch.sigmoid(direct[start : stop] + gru.bias_ih_l0[start : stop] + recurrent[start : stop] + gru.bias_hh_l0[start : stop])
# new gate
start, stop = 2 * hidden_size, 3 * hidden_size
new_gate = torch.tanh(direct[start : stop] + gru.bias_ih_l0[start : stop] + reset_gate * (recurrent[start : stop] + gru.bias_hh_l0[start : stop]))
return {'reset_gate' : reset_gate, 'update_gate' : update_gate, 'new_gate' : new_gate}
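# Illustrative consistency check (sizes are arbitrary): the updated hidden state
# follows from the returned gates via torch's GRU update h' = (1 - z) * n + z * h:
#   gru = torch.nn.GRU(4, 8)
#   x, h = torch.randn(4), torch.randn(8)
#   g = get_gru_gates(gru, x, h)
#   h_new = (1 - g['update_gate']) * g['new_gate'] + g['update_gate'] * h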
def init(folder='endoscopy'):
""" sets up output folder for endoscopy data """
global _folder
_folder = folder
if not os.path.exists(folder):
os.makedirs(folder)
else:
print(f"warning: endoscopy folder {folder} exists. Content may be lost or inconsistent results may occur.")
def write_data(key, data, fs):
""" appends data to previous data written under key """
global _state
# convert to numpy if torch.Tensor is given
if isinstance(data, torch.Tensor):
data = data.detach().numpy()
if key not in _state:
_state[key] = {
'fid' : open(os.path.join(_folder, key + '.bin'), 'wb'),
'fs' : fs,
'dim' : tuple(data.shape),
'dtype' : str(data.dtype)
}
with open(os.path.join(_folder, key + '.yml'), 'w') as f:
f.write(yaml.dump({'fs' : fs, 'dim' : tuple(data.shape), 'dtype' : str(data.dtype).split('.')[-1]}))
else:
if _state[key]['fs'] != fs:
raise ValueError(f"fs changed for key {key}: {_state[key]['fs']} vs. {fs}")
if _state[key]['dtype'] != str(data.dtype):
raise ValueError(f"dtype changed for key {key}: {_state[key]['dtype']} vs. {str(data.dtype)}")
if _state[key]['dim'] != tuple(data.shape):
raise ValueError(f"dim changed for key {key}: {_state[key]['dim']} vs. {tuple(data.shape)}")
_state[key]['fid'].write(data.tobytes())
def close(folder='endoscopy'):
""" clean up """
for key in _state.keys():
_state[key]['fid'].close()
def read_data(folder='endoscopy'):
""" retrieves written data as numpy arrays """
keys = [name[:-4] for name in os.listdir(folder) if name.endswith('.yml')]
return_dict = dict()
for key in keys:
with open(os.path.join(folder, key + '.yml'), 'r') as f:
value = yaml.load(f.read(), yaml.FullLoader)
with open(os.path.join(folder, key + '.bin'), 'rb') as f:
data = np.frombuffer(f.read(), dtype=value['dtype'])
value['data'] = data.reshape((-1,) + value['dim'])
return_dict[key] = value
return return_dict
def get_best_reshape(shape, target_ratio=1):
""" calculated the best 2d reshape of shape given the target ratio (rows/cols)"""
if len(shape) > 1:
pixel_count = 1
for s in shape:
pixel_count *= s
else:
pixel_count = shape[0]
if pixel_count == 1:
return (1,)
num_columns = int((pixel_count / target_ratio)**.5)
while (pixel_count % num_columns):
num_columns -= 1
num_rows = pixel_count // num_columns
return (num_rows, num_columns)
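# e.g. get_best_reshape((64,)) -> (8, 8) and get_best_reshape((3, 5)) -> (5, 3) for target_ratio=1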
def get_type_and_shape(shape):
# can happen if data is one dimensional
if len(shape) == 0:
shape = (1,)
# calculate pixel count
if len(shape) > 1:
pixel_count = 1
for s in shape:
pixel_count *= s
else:
pixel_count = shape[0]
if pixel_count == 1:
return 'plot', (1, )
# stay with shape if already 2-dimensional
if len(shape) == 2:
if (shape[0] != pixel_count) or (shape[1] != pixel_count):
return 'image', shape
return 'image', get_best_reshape(shape)
def make_animation(data, filename, start_index=80, stop_index=-80, interval=20, half_signal_window_length=80):
# determine plot setup
num_keys = len(data.keys())
num_rows = int((num_keys * 3/4) ** .5)
num_cols = (num_keys + num_rows - 1) // num_rows
fig, axs = plt.subplots(num_rows, num_cols)
fig.set_size_inches(num_cols * 5, num_rows * 5)
display = dict()
fs_max = max([val['fs'] for val in data.values()])
num_samples = max([val['data'].shape[0] for val in data.values()])
keys = sorted(data.keys())
# inspect data
for i, key in enumerate(keys):
axs[i // num_cols, i % num_cols].title.set_text(key)
display[key] = dict()
display[key]['type'], display[key]['shape'] = get_type_and_shape(data[key]['dim'])
display[key]['down_factor'] = data[key]['fs'] / fs_max
start_index = max(start_index, half_signal_window_length)
while stop_index < 0:
stop_index += num_samples
stop_index = min(stop_index, num_samples - half_signal_window_length)
# actual plotting
frames = []
for index in range(start_index, stop_index):
ims = []
for i, key in enumerate(keys):
feature_index = int(round(index * display[key]['down_factor']))
if display[key]['type'] == 'plot':
ims.append(axs[i // num_cols, i % num_cols].plot(data[key]['data'][index - half_signal_window_length : index + half_signal_window_length], marker='P', markevery=[half_signal_window_length], animated=True, color='blue')[0])
elif display[key]['type'] == 'image':
ims.append(axs[i // num_cols, i % num_cols].imshow(data[key]['data'][index].reshape(display[key]['shape']), animated=True))
frames.append(ims)
ani = animation.ArtistAnimation(fig, frames, interval=interval, blit=True, repeat_delay=1000)
if not filename.endswith('.mp4'):
filename += '.mp4'
ani.save(filename)
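if __name__ == "__main__":
    # Minimal self-test sketch (added for illustration; key names and sizes are
    # arbitrary): record a few frames of random data, close the files and read
    # them back.
    init(folder='endoscopy_demo')
    for _ in range(10):
        write_data('hidden_state', torch.randn(8), fs=100)
        write_data('spectrum', np.random.rand(4, 4).astype(np.float32), fs=50)
    close()
    recovered = read_data(folder='endoscopy_demo')
    for key, value in recovered.items():
        print(key, value['fs'], value['data'].shape)  # e.g. hidden_state 100 (10, 8)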

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -0,0 +1,25 @@
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.animation
def make_playback_animation(savepath, spec, duration_ms, vmin=20, vmax=90):
fig, axs = plt.subplots()
axs.set_axis_off()
fig.set_size_inches((duration_ms / 1000 * 5, 5))
frames = []
frame_duration=20
num_frames = int(duration_ms / frame_duration + .99)
spec_height, spec_width = spec.shape
for i in range(num_frames):
xpos = (i - 1) / (num_frames - 3) * (spec_width - 1)
new_frame = axs.imshow(spec, cmap='inferno', origin='lower', aspect='auto', vmin=vmin, vmax=vmax)
if i in {0, num_frames - 1}:
frames.append([new_frame])
else:
line = axs.plot([xpos, xpos], [0, spec_height-1], color='white', alpha=0.8)[0]
frames.append([new_frame, line])
ani = matplotlib.animation.ArtistAnimation(fig, frames, blit=True, interval=frame_duration)
ani.save(savepath, dpi=720)
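# Example use (illustrative): spec is a (freq_bins, frames) spectrogram, assumed
# here to be in dB given the vmin/vmax defaults; a white cursor sweeps across the
# image over duration_ms milliseconds:
#   make_playback_animation('demo.mp4', spec_db, duration_ms=2000)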

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -27,9 +27,13 @@
*/
"""
seed=1888
import os
import argparse
import sys
import random
random.seed(seed)
import yaml
@ -40,9 +44,12 @@ except:
has_git = False
import torch
torch.manual_seed(seed)
torch.backends.cudnn.benchmark = False
from torch.optim.lr_scheduler import LambdaLR
import numpy as np
np.random.seed(seed)
from scipy.io import wavfile
@ -54,7 +61,7 @@ from engine.engine import train_one_epoch, evaluate
from utils.silk_features import load_inference_data
from utils.misc import count_parameters
from utils.misc import count_parameters, count_nonzero_parameters
from losses.stft_loss import MRSTFTLoss, MRLogMelLoss
@ -71,6 +78,7 @@ parser.add_argument('--no-redirect', action='store_true', help='disables re-dire
args = parser.parse_args()
torch.set_num_threads(4)
with open(args.setup, 'r') as f:
@ -98,7 +106,7 @@ if os.path.exists(args.output):
reply = input('continue? (y/n): ')
if reply == 'n':
os._exit()
os._exit(0)
else:
os.makedirs(args.output, exist_ok=True)
@ -109,7 +117,7 @@ os.makedirs(checkpoint_dir, exist_ok=True)
if has_git:
working_dir = os.path.split(__file__)[0]
try:
repo = git.Repo(working_dir)
repo = git.Repo(working_dir, search_parent_directories=True)
setup['repo'] = dict()
hash = repo.head.object.hexsha
urls = list(repo.remote().urls)
@ -117,6 +125,8 @@ if has_git:
if is_dirty:
print("warning: repo is dirty")
with open(os.path.join(args.output, 'repo.diff'), "w") as f:
f.write(repo.git.execute(["git", "diff"]))
setup['repo']['hash'] = hash
setup['repo']['urls'] = urls
@ -292,6 +302,6 @@ for ep in range(1, epochs + 1):
torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_last.pth'))
print()
print(f"non-zero parameters: {count_nonzero_parameters(model)}\n")
print('Done')

View file

@ -107,7 +107,7 @@ os.makedirs(checkpoint_dir, exist_ok=True)
if has_git:
working_dir = os.path.split(__file__)[0]
try:
repo = git.Repo(working_dir)
repo = git.Repo(working_dir, search_parent_directories=True)
setup['repo'] = dict()
hash = repo.head.object.hexsha
urls = list(repo.remote().urls)

View file

@ -32,6 +32,7 @@ from torch import nn
import torch.nn.functional as F
from utils.endoscopy import write_data
from utils.softquant import soft_quant
class LimitedAdaptiveComb1d(nn.Module):
COUNTER = 1
@ -47,6 +48,8 @@ class LimitedAdaptiveComb1d(nn.Module):
gain_limit_db=10,
global_gain_limits_db=[-6, 6],
norm_p=2,
softquant=False,
apply_weight_norm=False,
**kwargs):
"""
@ -97,17 +100,22 @@ class LimitedAdaptiveComb1d(nn.Module):
else:
self.name = name
norm = torch.nn.utils.weight_norm if apply_weight_norm else lambda x, name=None: x
# network for generating convolution weights
self.conv_kernel = nn.Linear(feature_dim, kernel_size)
self.conv_kernel = norm(nn.Linear(feature_dim, kernel_size))
if softquant:
self.conv_kernel = soft_quant(self.conv_kernel)
# comb filter gain
self.filter_gain = nn.Linear(feature_dim, 1)
self.filter_gain = norm(nn.Linear(feature_dim, 1))
self.log_gain_limit = gain_limit_db * 0.11512925464970229
with torch.no_grad():
self.filter_gain.bias[:] = max(0.1, 4 + self.log_gain_limit)
self.global_filter_gain = nn.Linear(feature_dim, 1)
self.global_filter_gain = norm(nn.Linear(feature_dim, 1))
log_min, log_max = global_gain_limits_db[0] * 0.11512925464970229, global_gain_limits_db[1] * 0.11512925464970229
self.filter_gain_a = (log_max - log_min) / 2
self.filter_gain_b = (log_max + log_min) / 2

View file

@ -34,7 +34,7 @@ import torch.nn.functional as F
from utils.endoscopy import write_data
from utils.ada_conv import adaconv_kernel
from utils.softquant import soft_quant
class LimitedAdaptiveConv1d(nn.Module):
COUNTER = 1
@ -51,6 +51,8 @@ class LimitedAdaptiveConv1d(nn.Module):
gain_limits_db=[-6, 6],
shape_gain_db=0,
norm_p=2,
softquant=False,
apply_weight_norm=False,
**kwargs):
"""
@ -100,12 +102,16 @@ class LimitedAdaptiveConv1d(nn.Module):
else:
self.name = name
norm = torch.nn.utils.weight_norm if apply_weight_norm else lambda x, name=None: x
# network for generating convolution weights
self.conv_kernel = nn.Linear(feature_dim, in_channels * out_channels * kernel_size)
self.conv_kernel = norm(nn.Linear(feature_dim, in_channels * out_channels * kernel_size))
if softquant:
self.conv_kernel = soft_quant(self.conv_kernel)
self.shape_gain = min(1, 10**(shape_gain_db / 20))
self.filter_gain = nn.Linear(feature_dim, out_channels)
self.filter_gain = norm(nn.Linear(feature_dim, out_channels))
log_min, log_max = gain_limits_db[0] * 0.11512925464970229, gain_limits_db[1] * 0.11512925464970229
self.filter_gain_a = (log_max - log_min) / 2
self.filter_gain_b = (log_max + log_min) / 2

View file

@ -3,6 +3,7 @@ from torch import nn
import torch.nn.functional as F
from utils.complexity import _conv1d_flop_count
from utils.softquant import soft_quant
class TDShaper(nn.Module):
COUNTER = 1
@ -12,7 +13,9 @@ class TDShaper(nn.Module):
frame_size=160,
avg_pool_k=4,
innovate=False,
pool_after=False
pool_after=False,
softquant=False,
apply_weight_norm=False
):
"""
@ -45,23 +48,29 @@ class TDShaper(nn.Module):
assert frame_size % avg_pool_k == 0
self.env_dim = frame_size // avg_pool_k + 1
norm = torch.nn.utils.weight_norm if apply_weight_norm else lambda x, name=None: x
# feature transform
self.feature_alpha1 = nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2)
self.feature_alpha2 = nn.Conv1d(frame_size, frame_size, 2)
self.feature_alpha1_f = norm(nn.Conv1d(self.feature_dim, frame_size, 2))
self.feature_alpha1_t = norm(nn.Conv1d(self.env_dim, frame_size, 2))
self.feature_alpha2 = norm(nn.Conv1d(frame_size, frame_size, 2))
if softquant:
self.feature_alpha1_f = soft_quant(self.feature_alpha1_f)
if self.innovate:
self.feature_alpha1b = nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2)
self.feature_alpha1c = nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2)
self.feature_alpha1b = norm(nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2))
self.feature_alpha1c = norm(nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2))
self.feature_alpha2b = nn.Conv1d(frame_size, frame_size, 2)
self.feature_alpha2c = nn.Conv1d(frame_size, frame_size, 2)
self.feature_alpha2b = norm(nn.Conv1d(frame_size, frame_size, 2))
self.feature_alpha2c = norm(nn.Conv1d(frame_size, frame_size, 2))
def flop_count(self, rate):
frame_rate = rate / self.frame_size
shape_flops = sum([_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1, self.feature_alpha2)]) + 11 * frame_rate * self.frame_size
shape_flops = sum([_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1_f, self.feature_alpha1_t, self.feature_alpha2)]) + 11 * frame_rate * self.frame_size
if self.innovate:
inno_flops = sum([_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1b, self.feature_alpha2b, self.feature_alpha1c, self.feature_alpha2c)]) + 22 * frame_rate * self.frame_size
@ -110,9 +119,10 @@ class TDShaper(nn.Module):
tenv = self.envelope_transform(x)
# feature path
f = torch.cat((features, tenv), dim=-1)
f = F.pad(f.permute(0, 2, 1), [1, 0])
alpha = F.leaky_relu(self.feature_alpha1(f), 0.2)
f = F.pad(features.permute(0, 2, 1), [1, 0])
t = F.pad(tenv.permute(0, 2, 1), [1, 0])
alpha = self.feature_alpha1_f(f) + self.feature_alpha1_t(t)
alpha = F.leaky_relu(alpha, 0.2)
alpha = torch.exp(self.feature_alpha2(F.pad(alpha, [1, 0])))
alpha = alpha.permute(0, 2, 1)

View file

@ -28,6 +28,7 @@
"""
import torch
from torch.nn.utils import remove_weight_norm
def count_parameters(model, verbose=False):
total = 0
@ -41,7 +42,17 @@ def count_parameters(model, verbose=False):
return total
def count_nonzero_parameters(model, verbose=False):
total = 0
for name, p in model.named_parameters():
count = torch.count_nonzero(p).item()
if verbose:
print(f"{name}: {count} non-zero parameters")
total += count
return total
def retain_grads(module):
for p in module.parameters():
if p.requires_grad:
@ -62,4 +73,23 @@ def create_weights(s_real, s_gen, alpha):
weight = torch.exp(alpha * (sr[-1] - sg[-1]))
weights.append(weight)
return weights
return weights
def _get_candidates(module: torch.nn.Module):
candidates = []
for key in module.__dict__.keys():
if hasattr(module, key + '_v'):
candidates.append(key)
return candidates
def remove_all_weight_norm(model : torch.nn.Module, verbose=False):
for name, m in model.named_modules():
candidates = _get_candidates(m)
for candidate in candidates:
try:
remove_weight_norm(m, name=candidate)
if verbose: print(f'removed weight norm on weight {name}.{candidate}')
except ValueError:
pass

View file

@ -0,0 +1,110 @@
import torch
@torch.no_grad()
def compute_optimal_scale(weight):
with torch.no_grad():
n_out, n_in = weight.shape
assert n_in % 4 == 0
if n_out % 8:
# add padding
pad = n_out - n_out % 8
weight = torch.cat((weight, torch.zeros((pad, n_in), dtype=weight.dtype, device=weight.device)), dim=0)
weight_max_abs, _ = torch.max(torch.abs(weight), dim=1)
weight_max_sum, _ = torch.max(torch.abs(weight[:, : n_in : 2] + weight[:, 1 : n_in : 2]), dim=1)
scale_max = weight_max_abs / 127
scale_sum = weight_max_sum / 129
scale = torch.maximum(scale_max, scale_sum)
return scale[:n_out]
@torch.no_grad()
def q_scaled_noise(module, weight):
if isinstance(module, torch.nn.Conv1d):
w = weight.permute(0, 2, 1).flatten(1)
noise = torch.rand_like(w) - 0.5
scale = compute_optimal_scale(w)
noise = noise * scale.unsqueeze(-1)
noise = noise.reshape(weight.size(0), weight.size(2), weight.size(1)).permute(0, 2, 1)
elif isinstance(module, torch.nn.ConvTranspose1d):
i, o, k = weight.shape
w = weight.permute(2, 1, 0).reshape(k * o, i)
noise = torch.rand_like(w) - 0.5
scale = compute_optimal_scale(w)
noise = noise * scale.unsqueeze(-1)
noise = noise.reshape(k, o, i).permute(2, 1, 0)
elif len(weight.shape) == 2:
noise = torch.rand_like(weight) - 0.5
scale = compute_optimal_scale(weight)
noise = noise * scale.unsqueeze(-1)
else:
raise ValueError('unknown quantization setting')
return noise
class SoftQuant:
names: list
def __init__(self, names: list, scale: float) -> None:
self.names = names
self.quantization_noise = None
self.scale = scale
def __call__(self, module, inputs, *args, before=True):
if not module.training: return
if before:
self.quantization_noise = dict()
for name in self.names:
weight = getattr(module, name)
if self.scale is None:
self.quantization_noise[name] = q_scaled_noise(module, weight)
else:
self.quantization_noise[name] = \
self.scale * weight.abs().max() * (torch.rand_like(weight) - 0.5)
with torch.no_grad():
weight.data[:] = weight + self.quantization_noise[name]
else:
for name in self.names:
weight = getattr(module, name)
with torch.no_grad():
weight.data[:] = weight - self.quantization_noise[name]
self.quantization_noise = None
def apply(module, names=['weight'], scale=None):
fn = SoftQuant(names, scale)
for name in names:
if not hasattr(module, name):
raise ValueError("")
fn_before = lambda *x : fn(*x, before=True)
fn_after = lambda *x : fn(*x, before=False)
setattr(fn_before, 'sqm', fn)
setattr(fn_after, 'sqm', fn)
module.register_forward_pre_hook(fn_before)
module.register_forward_hook(fn_after)
return fn
def soft_quant(module, names=['weight'], scale=None):
SoftQuant.apply(module, names, scale)
return module
def remove_soft_quant(module, names=['weight']):
for k, hook in module._forward_pre_hooks.items():
if hasattr(hook, 'sqm'):
if isinstance(hook.sqm, SoftQuant) and hook.sqm.names == names:
del module._forward_pre_hooks[k]
for k, hook in module._forward_hooks.items():
if hasattr(hook, 'sqm'):
if isinstance(hook.sqm, SoftQuant) and hook.sqm.names == names:
del module._forward_hooks[k]
return module
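if __name__ == "__main__":
    # Minimal sketch (added for illustration): wrap a Linear layer, then verify
    # that training-mode forward passes see quantization noise while the stored
    # weights are restored after each call.
    torch.manual_seed(0)
    lin = soft_quant(torch.nn.Linear(16, 16))
    x = torch.randn(2, 16)
    w_before = lin.weight.detach().clone()
    lin.train()
    y_noisy = lin(x)
    lin.eval()
    y_clean = lin(x)
    print('weights restored:', torch.allclose(lin.weight, w_before, atol=1e-6))
    print('noise applied in training:', not torch.allclose(y_noisy, y_clean))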

View file

@ -50,7 +50,11 @@ lace_setup = {
'pitch_embedding_dim': 64,
'pitch_max': 300,
'preemph': 0.85,
'skip': 91
'skip': 91,
'softquant': True,
'sparsify': False,
'sparsification_density': 0.4,
'sparsification_schedule': [10000, 40000, 200]
}
},
'data': {
@ -63,7 +67,7 @@ lace_setup = {
'num_bands_clean_spec': 64,
'num_bands_noisy_spec': 18,
'noisy_spec_scale': 'opus',
'pitch_hangover': 8,
'pitch_hangover': 0,
},
'training': {
'batch_size': 256,
@ -106,7 +110,11 @@ nolace_setup = {
'pitch_embedding_dim': 64,
'pitch_max': 300,
'preemph': 0.85,
'skip': 91
'skip': 91,
'softquant': True,
'sparsify': False,
'sparsification_density': 0.4,
'sparsification_schedule': [10000, 40000, 200]
}
},
'data': {
@ -119,7 +127,7 @@ nolace_setup = {
'num_bands_clean_spec': 64,
'num_bands_noisy_spec': 18,
'noisy_spec_scale': 'opus',
'pitch_hangover': 8,
'pitch_hangover': 0,
},
'training': {
'batch_size': 256,
@ -160,7 +168,11 @@ nolace_setup_adv = {
'pitch_embedding_dim': 64,
'pitch_max': 300,
'preemph': 0.85,
'skip': 91
'skip': 91,
'softquant': True,
'sparsify': False,
'sparsification_density': 0.4,
'sparsification_schedule': [0, 0, 200]
}
},
'data': {
@ -173,7 +185,7 @@ nolace_setup_adv = {
'num_bands_clean_spec': 64,
'num_bands_noisy_spec': 18,
'noisy_spec_scale': 'opus',
'pitch_hangover': 8,
'pitch_hangover': 0,
},
'discriminator': {
'args': [],

View file

@ -282,7 +282,8 @@ def print_conv1d_layer(writer : CWriter,
bias : np.ndarray,
scale=1/128,
format : str = 'torch',
quantize=False):
quantize=False,
sparse=False):
if format == "torch":
@ -290,7 +291,7 @@ def print_conv1d_layer(writer : CWriter,
weight = np.transpose(weight, (2, 1, 0))
lin_weight = np.reshape(weight, (-1, weight.shape[-1]))
print_linear_layer(writer, name, lin_weight, bias, scale=scale, sparse=False, diagonal=False, quantize=quantize)
print_linear_layer(writer, name, lin_weight, bias, scale=scale, sparse=sparse, diagonal=False, quantize=quantize)
writer.header.write(f"\n#define {name.upper()}_OUT_SIZE {weight.shape[2]}\n")
@ -369,7 +370,8 @@ def print_tconv1d_layer(writer : CWriter,
bias : np.ndarray,
stride: int,
scale=1/128,
quantize=False):
quantize=False,
sparse=False):
in_channels, out_channels, kernel_size = weight.shape
@ -377,7 +379,7 @@ def print_tconv1d_layer(writer : CWriter,
linear_weight = weight.transpose(2, 1, 0).reshape(kernel_size * out_channels, in_channels).transpose(1, 0)
linear_bias = np.repeat(bias[np.newaxis, :], kernel_size, 0).flatten()
print_linear_layer(writer, name, linear_weight, linear_bias, scale=scale, quantize=quantize)
print_linear_layer(writer, name, linear_weight, linear_bias, scale=scale, quantize=quantize, sparse=sparse)
writer.header.write(f"\n#define {name.upper()}_KERNEL_SIZE {kernel_size}\n")
writer.header.write(f"\n#define {name.upper()}_STRIDE {stride}\n")

View file

@ -153,7 +153,7 @@ def dump_torch_adaptive_comb1d_weights(where, adaconv, name='adaconv', scale=1/1
np.save(os.path.join(where, 'weight_global_gain.npy'), w_global_gain)
np.save(os.path.join(where, 'bias_global_gain.npy'), b_global_gain)
def dump_torch_tdshaper(where, shaper, name='tdshaper'):
def dump_torch_tdshaper(where, shaper, name='tdshaper', quantize=False, scale=1/128):
if isinstance(where, CWriter):
where.header.write(f"""
@ -165,7 +165,8 @@ def dump_torch_tdshaper(where, shaper, name='tdshaper'):
"""
)
dump_torch_conv1d_weights(where, shaper.feature_alpha1, name + "_alpha1")
dump_torch_conv1d_weights(where, shaper.feature_alpha1_f, name + "_alpha1_f", quantize=quantize, scale=scale)
dump_torch_conv1d_weights(where, shaper.feature_alpha1_t, name + "_alpha1_t")
dump_torch_conv1d_weights(where, shaper.feature_alpha2, name + "_alpha2")
if shaper.innovate:
@ -274,7 +275,7 @@ def load_torch_dense_weights(where, dense):
dense.bias.set_(torch.from_numpy(b))
def dump_torch_conv1d_weights(where, conv, name='conv', scale=1/128, quantize=False):
def dump_torch_conv1d_weights(where, conv, name='conv', scale=1/128, quantize=False, sparse=False):
w = conv.weight.detach().cpu().numpy().copy()
if conv.bias is None:
@ -284,7 +285,7 @@ def dump_torch_conv1d_weights(where, conv, name='conv', scale=1/128, quantize=Fa
if isinstance(where, CWriter):
return print_conv1d_layer(where, name, w, b, scale=scale, format='torch', quantize=quantize)
return print_conv1d_layer(where, name, w, b, scale=scale, format='torch', quantize=quantize, sparse=sparse)
else:
os.makedirs(where, exist_ok=True)
@ -304,7 +305,7 @@ def load_torch_conv1d_weights(where, conv):
conv.bias.set_(torch.from_numpy(b))
def dump_torch_tconv1d_weights(where, conv, name='conv', scale=1/128, quantize=False):
def dump_torch_tconv1d_weights(where, conv, name='conv', scale=1/128, quantize=False, sparse=False):
w = conv.weight.detach().cpu().numpy().copy()
if conv.bias is None:
@ -314,7 +315,7 @@ def dump_torch_tconv1d_weights(where, conv, name='conv', scale=1/128, quantize=F
if isinstance(where, CWriter):
return print_tconv1d_layer(where, name, w, b, conv.stride[0], scale=scale, quantize=quantize)
return print_tconv1d_layer(where, name, w, b, conv.stride[0], scale=scale, quantize=quantize, sparse=sparse)
else:
os.makedirs(where, exist_ok=True)