Updated LACE and NoLACE models to version 2
This commit is contained in:
parent
4f311a1ad4
commit
299e38cab7
57 changed files with 4793 additions and 109 deletions
|
@ -9,7 +9,7 @@ set -e
|
|||
srcdir=`dirname $0`
|
||||
test -n "$srcdir" && cd "$srcdir"
|
||||
|
||||
dnn/download_model.sh caca188
|
||||
dnn/download_model.sh 88477f4
|
||||
|
||||
echo "Updating build configuration files, please wait...."
|
||||
|
||||
|
|
10
dnn/nndsp.c
10
dnn/nndsp.c
|
@ -340,7 +340,8 @@ void adashape_process_frame(
|
|||
float *x_out,
|
||||
const float *x_in,
|
||||
const float *features,
|
||||
const LinearLayer *alpha1,
|
||||
const LinearLayer *alpha1f,
|
||||
const LinearLayer *alpha1t,
|
||||
const LinearLayer *alpha2,
|
||||
int feature_dim,
|
||||
int frame_size,
|
||||
|
@ -350,6 +351,7 @@ void adashape_process_frame(
|
|||
{
|
||||
float in_buffer[ADASHAPE_MAX_INPUT_DIM + ADASHAPE_MAX_FRAME_SIZE];
|
||||
float out_buffer[ADASHAPE_MAX_FRAME_SIZE];
|
||||
float tmp_buffer[ADASHAPE_MAX_FRAME_SIZE];
|
||||
int i, k;
|
||||
int tenv_size;
|
||||
float mean;
|
||||
|
@ -389,14 +391,16 @@ void adashape_process_frame(
|
|||
#ifdef DEBUG_NNDSP
|
||||
print_float_vector("alpha1_in", in_buffer, feature_dim + tenv_size + 1);
|
||||
#endif
|
||||
compute_generic_conv1d(alpha1, out_buffer, hAdaShape->conv_alpha1_state, in_buffer, feature_dim + tenv_size + 1, ACTIVATION_LINEAR, arch);
|
||||
compute_generic_conv1d(alpha1f, out_buffer, hAdaShape->conv_alpha1f_state, in_buffer, feature_dim, ACTIVATION_LINEAR, arch);
|
||||
compute_generic_conv1d(alpha1t, tmp_buffer, hAdaShape->conv_alpha1t_state, tenv, tenv_size + 1, ACTIVATION_LINEAR, arch);
|
||||
#ifdef DEBUG_NNDSP
|
||||
print_float_vector("alpha1_out", out_buffer, frame_size);
|
||||
#endif
|
||||
/* compute leaky ReLU by hand. ToDo: try tanh activation */
|
||||
for (i = 0; i < frame_size; i ++)
|
||||
{
|
||||
in_buffer[i] = out_buffer[i] >= 0 ? out_buffer[i] : 0.2f * out_buffer[i];
|
||||
float tmp = out_buffer[i] + tmp_buffer[i];
|
||||
in_buffer[i] = tmp >= 0 ? tmp : 0.2 * tmp;
|
||||
}
|
||||
#ifdef DEBUG_NNDSP
|
||||
print_float_vector("post_alpha1", in_buffer, frame_size);
|
||||
|
|
|
@ -71,7 +71,8 @@ typedef struct {
|
|||
|
||||
|
||||
typedef struct {
|
||||
float conv_alpha1_state[ADASHAPE_MAX_INPUT_DIM];
|
||||
float conv_alpha1f_state[ADASHAPE_MAX_INPUT_DIM];
|
||||
float conv_alpha1t_state[ADASHAPE_MAX_INPUT_DIM];
|
||||
float conv_alpha2_state[ADASHAPE_MAX_FRAME_SIZE];
|
||||
} AdaShapeState;
|
||||
|
||||
|
@ -130,7 +131,8 @@ void adashape_process_frame(
|
|||
float *x_out,
|
||||
const float *x_in,
|
||||
const float *features,
|
||||
const LinearLayer *alpha1,
|
||||
const LinearLayer *alpha1f,
|
||||
const LinearLayer *alpha1t,
|
||||
const LinearLayer *alpha2,
|
||||
int feature_dim,
|
||||
int frame_size,
|
||||
|
|
17
dnn/osce.c
17
dnn/osce.c
|
@ -155,7 +155,7 @@ static void lace_feature_net(
|
|||
&hLACE->layers.lace_fnet_tconv,
|
||||
output_buffer,
|
||||
input_buffer,
|
||||
ACTIVATION_LINEAR,
|
||||
ACTIVATION_TANH,
|
||||
arch
|
||||
);
|
||||
|
||||
|
@ -426,7 +426,7 @@ static void nolace_feature_net(
|
|||
&hNoLACE->layers.nolace_fnet_tconv,
|
||||
output_buffer,
|
||||
input_buffer,
|
||||
ACTIVATION_LINEAR,
|
||||
ACTIVATION_TANH,
|
||||
arch
|
||||
);
|
||||
|
||||
|
@ -633,7 +633,8 @@ static void nolace_process_20ms_frame(
|
|||
x_buffer2 + i_subframe * NOLACE_AF1_OUT_CHANNELS * NOLACE_FRAME_SIZE + NOLACE_FRAME_SIZE,
|
||||
x_buffer2 + i_subframe * NOLACE_AF1_OUT_CHANNELS * NOLACE_FRAME_SIZE + NOLACE_FRAME_SIZE,
|
||||
feature_buffer + i_subframe * NOLACE_COND_DIM,
|
||||
&layers->nolace_tdshape1_alpha1,
|
||||
&layers->nolace_tdshape1_alpha1_f,
|
||||
&layers->nolace_tdshape1_alpha1_t,
|
||||
&layers->nolace_tdshape1_alpha2,
|
||||
NOLACE_TDSHAPE1_FEATURE_DIM,
|
||||
NOLACE_TDSHAPE1_FRAME_SIZE,
|
||||
|
@ -688,7 +689,8 @@ static void nolace_process_20ms_frame(
|
|||
x_buffer1 + i_subframe * NOLACE_AF2_OUT_CHANNELS * NOLACE_FRAME_SIZE + NOLACE_FRAME_SIZE,
|
||||
x_buffer1 + i_subframe * NOLACE_AF2_OUT_CHANNELS * NOLACE_FRAME_SIZE + NOLACE_FRAME_SIZE,
|
||||
feature_buffer + i_subframe * NOLACE_COND_DIM,
|
||||
&layers->nolace_tdshape2_alpha1,
|
||||
&layers->nolace_tdshape2_alpha1_f,
|
||||
&layers->nolace_tdshape2_alpha1_t,
|
||||
&layers->nolace_tdshape2_alpha2,
|
||||
NOLACE_TDSHAPE2_FEATURE_DIM,
|
||||
NOLACE_TDSHAPE2_FRAME_SIZE,
|
||||
|
@ -739,7 +741,8 @@ static void nolace_process_20ms_frame(
|
|||
x_buffer2 + i_subframe * NOLACE_AF3_OUT_CHANNELS * NOLACE_FRAME_SIZE + NOLACE_FRAME_SIZE,
|
||||
x_buffer2 + i_subframe * NOLACE_AF3_OUT_CHANNELS * NOLACE_FRAME_SIZE + NOLACE_FRAME_SIZE,
|
||||
feature_buffer + i_subframe * NOLACE_COND_DIM,
|
||||
&layers->nolace_tdshape3_alpha1,
|
||||
&layers->nolace_tdshape3_alpha1_f,
|
||||
&layers->nolace_tdshape3_alpha1_t,
|
||||
&layers->nolace_tdshape3_alpha2,
|
||||
NOLACE_TDSHAPE3_FEATURE_DIM,
|
||||
NOLACE_TDSHAPE3_FRAME_SIZE,
|
||||
|
@ -884,7 +887,7 @@ int osce_load_models(OSCEModel *model, const unsigned char *data, int len)
|
|||
if (ret == 0) {ret = init_lace(&model->lace, list);}
|
||||
#endif
|
||||
|
||||
#ifndef DISABLE_LACE
|
||||
#ifndef DISABLE_NOLACE
|
||||
if (ret == 0) {ret = init_nolace(&model->nolace, list);}
|
||||
#endif
|
||||
|
||||
|
@ -898,7 +901,7 @@ int osce_load_models(OSCEModel *model, const unsigned char *data, int len)
|
|||
if (ret == 0) {ret = init_lace(&model->lace, lacelayers_arrays);}
|
||||
#endif
|
||||
|
||||
#ifndef DISABLE_LACE
|
||||
#ifndef DISABLE_NOLACE
|
||||
if (ret == 0) {ret = init_nolace(&model->nolace, nolacelayers_arrays);}
|
||||
#endif
|
||||
|
||||
|
|
|
@ -41,7 +41,7 @@
|
|||
|
||||
#define OSCE_PREEMPH 0.85f
|
||||
|
||||
#define OSCE_PITCH_HANGOVER 8
|
||||
#define OSCE_PITCH_HANGOVER 0
|
||||
|
||||
#define OSCE_CLEAN_SPEC_START 0
|
||||
#define OSCE_CLEAN_SPEC_LENGTH 64
|
||||
|
|
|
@ -296,6 +296,7 @@ static void calculate_acorr(float *acorr, float *signal, int lag)
|
|||
static int pitch_postprocessing(OSCEFeatureState *psFeatures, int lag, int type)
|
||||
{
|
||||
int new_lag;
|
||||
int modulus;
|
||||
|
||||
#ifdef OSCE_HANGOVER_BUGFIX
|
||||
#define TESTBIT 1
|
||||
|
@ -303,6 +304,9 @@ static int pitch_postprocessing(OSCEFeatureState *psFeatures, int lag, int type)
|
|||
#define TESTBIT 0
|
||||
#endif
|
||||
|
||||
modulus = OSCE_PITCH_HANGOVER;
|
||||
if (modulus == 0) modulus ++;
|
||||
|
||||
/* hangover is currently disabled to reflect a bug in the python code. ToDo: re-evaluate hangover */
|
||||
if (type != TYPE_VOICED && psFeatures->last_type == TYPE_VOICED && TESTBIT)
|
||||
/* enter hangover */
|
||||
|
@ -311,14 +315,14 @@ static int pitch_postprocessing(OSCEFeatureState *psFeatures, int lag, int type)
|
|||
if (psFeatures->pitch_hangover_count < OSCE_PITCH_HANGOVER)
|
||||
{
|
||||
new_lag = psFeatures->last_lag;
|
||||
psFeatures->pitch_hangover_count = (psFeatures->pitch_hangover_count + 1) % OSCE_PITCH_HANGOVER;
|
||||
psFeatures->pitch_hangover_count = (psFeatures->pitch_hangover_count + 1) % modulus;
|
||||
}
|
||||
}
|
||||
else if (type != TYPE_VOICED && psFeatures->pitch_hangover_count && TESTBIT)
|
||||
/* continue hangover */
|
||||
{
|
||||
new_lag = psFeatures->last_lag;
|
||||
psFeatures->pitch_hangover_count = (psFeatures->pitch_hangover_count + 1) % OSCE_PITCH_HANGOVER;
|
||||
psFeatures->pitch_hangover_count = (psFeatures->pitch_hangover_count + 1) % modulus;
|
||||
}
|
||||
else if (type != TYPE_VOICED)
|
||||
/* unvoiced frame after hangover */
|
||||
|
@ -376,11 +380,7 @@ void osce_calculate_features(
|
|||
/* smooth bit count */
|
||||
psFeatures->numbits_smooth = 0.9f * psFeatures->numbits_smooth + 0.1f * num_bits;
|
||||
numbits[0] = num_bits;
|
||||
#ifdef OSCE_NUMBITS_BUGFIX
|
||||
numbits[1] = psFeatures->numbits_smooth;
|
||||
#else
|
||||
numbits[1] = num_bits;
|
||||
#endif
|
||||
|
||||
for (n = 0; n < num_samples; n++)
|
||||
{
|
||||
|
|
2
dnn/torch/dnntools/dnntools/__init__.py
Normal file
2
dnn/torch/dnntools/dnntools/__init__.py
Normal file
|
@ -0,0 +1,2 @@
|
|||
from . import quantization
|
||||
from . import sparsification
|
1
dnn/torch/dnntools/dnntools/quantization/__init__.py
Normal file
1
dnn/torch/dnntools/dnntools/quantization/__init__.py
Normal file
|
@ -0,0 +1 @@
|
|||
from .softquant import soft_quant, remove_soft_quant
|
113
dnn/torch/dnntools/dnntools/quantization/softquant.py
Normal file
113
dnn/torch/dnntools/dnntools/quantization/softquant.py
Normal file
|
@ -0,0 +1,113 @@
|
|||
import torch
|
||||
|
||||
@torch.no_grad()
def compute_optimal_scale(weight):
    """Computes one quantization scale per output row of a 2-D weight matrix.

    For every row the scale is the larger of
      * max(|w|) / 127, and
      * max(|w[2j] + w[2j + 1]|) / 129 (sum over adjacent column pairs).

    Args:
        weight (torch.Tensor): matrix of shape (n_out, n_in); n_in must be a
            multiple of 4.

    Returns:
        torch.Tensor: tensor of shape (n_out,) holding the per-row scales.

    Raises:
        AssertionError: if n_in is not a multiple of 4.
    """
    n_out, n_in = weight.shape
    assert n_in % 4 == 0

    # Note: the previous implementation zero-padded the rows when n_out was not
    # a multiple of 8 (with a wrong pad count) and sliced the padding off again.
    # All reductions below are per row, so the padding never influenced the
    # result and has been dropped.
    magnitude = torch.abs(weight)
    pair_sums = torch.abs(weight[:, : n_in : 2] + weight[:, 1 : n_in : 2])

    weight_max_abs, _ = torch.max(magnitude, dim=1)
    weight_max_sum, _ = torch.max(pair_sums, dim=1)

    scale_max = weight_max_abs / 127
    scale_sum = weight_max_sum / 129

    return torch.maximum(scale_max, scale_sum)
|
||||
|
||||
@torch.no_grad()
def q_scaled_noise(module, weight):
    """Draws uniform noise in [-0.5, 0.5) scaled by the per-row quantization scale.

    The weight is first rearranged into a 2-D matrix (layout depends on the
    module type), noise is drawn for that matrix, entries belonging to
    exactly-zero weights are cleared, each row is multiplied by the scale from
    compute_optimal_scale, and the result is rearranged back to the layout of
    ``weight``.

    Args:
        module (torch.nn.Module): module owning the weight; selects the layout.
        weight (torch.Tensor): the weight tensor to create noise for.

    Returns:
        torch.Tensor: noise tensor with the same shape as ``weight``.

    Raises:
        ValueError: if module is neither Conv1d nor ConvTranspose1d and the
            weight is not 2-dimensional.
    """
    if isinstance(module, torch.nn.Conv1d):
        matrix = weight.permute(0, 2, 1).flatten(1)
        dim0, dim1, dim2 = weight.size(0), weight.size(1), weight.size(2)

        def restore(flat):
            return flat.reshape(dim0, dim2, dim1).permute(0, 2, 1)
    elif isinstance(module, torch.nn.ConvTranspose1d):
        i, o, k = weight.shape
        matrix = weight.permute(2, 1, 0).reshape(k * o, i)

        def restore(flat):
            return flat.reshape(k, o, i).permute(2, 1, 0)
    elif len(weight.shape) == 2:
        matrix = weight

        def restore(flat):
            return flat
    else:
        raise ValueError('unknown quantization setting')

    noise = torch.rand_like(matrix) - 0.5
    # ignore zero entries from sparsification
    noise[matrix == 0] = 0
    row_scale = compute_optimal_scale(matrix)
    noise = noise * row_scale.unsqueeze(-1)

    return restore(noise)
|
||||
|
||||
class SoftQuant:
    """Forward hook pair that perturbs weights with quantization-like noise.

    While the module is in training mode, noise is added to the named weights
    right before the forward pass and subtracted again right after it.
    """
    # names of the module attributes (weights) that receive noise
    names: list

    def __init__(self, names, scale) -> None:
        """
        Args:
            names (list[str]): attribute names of the weights to perturb.
            scale (float or None): fixed noise scale; if None the noise is
                created by q_scaled_noise.
        """
        self.names = names
        self.quantization_noise = None
        self.scale = scale

    def __call__(self, module, inputs, *args, before=True):
        """Adds (before=True) or removes (before=False) the noise in place.

        Does nothing when the module is not in training mode.
        """
        if not module.training:
            return

        if before:
            self.quantization_noise = dict()
            for name in self.names:
                weight = getattr(module, name)
                if self.scale is None:
                    noise = q_scaled_noise(module, weight)
                else:
                    noise = self.scale * (torch.rand_like(weight) - 0.5)
                # remember the noise so the forward hook can undo it
                self.quantization_noise[name] = noise
                with torch.no_grad():
                    weight.data[:] = weight + noise
        else:
            for name in self.names:
                weight = getattr(module, name)
                with torch.no_grad():
                    weight.data[:] = weight - self.quantization_noise[name]
            self.quantization_noise = None

    @staticmethod
    def apply(module, names=['weight'], scale=None):
        """Registers the pre-forward and forward hooks on ``module``.

        Args:
            module (torch.nn.Module): module to attach the hooks to.
            names (list[str]): attribute names of the weights to perturb.
            scale (float or None): fixed noise scale or None (see __init__).

        Returns:
            SoftQuant: the hook object shared by both registered hooks.

        Raises:
            ValueError: if the module lacks one of the named attributes.
        """
        fn = SoftQuant(names, scale)

        for name in names:
            if not hasattr(module, name):
                raise ValueError(f"module {type(module).__name__} has no attribute '{name}' to apply soft quantization to")

        fn_before = lambda *x : fn(*x, before=True)
        fn_after = lambda *x : fn(*x, before=False)
        # tag the hooks so that remove_soft_quant can find them again
        setattr(fn_before, 'sqm', fn)
        setattr(fn_after, 'sqm', fn)

        module.register_forward_pre_hook(fn_before)
        module.register_forward_hook(fn_after)

        return fn
|
||||
|
||||
|
||||
def soft_quant(module, names=['weight'], scale=None):
    """Attaches soft-quantization hooks to ``module`` via SoftQuant.apply.

    Args:
        module (torch.nn.Module): module to modify.
        names (list[str]): attribute names of the weights to perturb.
        scale (float or None): noise scale forwarded to SoftQuant.apply.

    Returns:
        torch.nn.Module: the same module object that was passed in.
    """
    SoftQuant.apply(module, names, scale)
    return module
|
||||
|
||||
def remove_soft_quant(module, names=['weight']):
    """Removes the SoftQuant hooks registered for ``names`` from ``module``.

    A hook is removed when it carries an ``sqm`` attribute that is a SoftQuant
    instance whose ``names`` equal the given ``names``.

    Args:
        module (torch.nn.Module): module to strip the hooks from.
        names (list[str]): weight names the hooks were registered with.

    Returns:
        torch.nn.Module: the same module object that was passed in.
    """
    for hooks in (module._forward_pre_hooks, module._forward_hooks):
        # collect the keys first: deleting entries while iterating over the
        # hook dictionary raises a RuntimeError
        matching = [
            k for k, hook in hooks.items()
            if hasattr(hook, 'sqm')
            and isinstance(hook.sqm, SoftQuant)
            and hook.sqm.names == names
        ]
        for k in matching:
            del hooks[k]

    return module
|
2
dnn/torch/dnntools/dnntools/relegance/__init__.py
Normal file
2
dnn/torch/dnntools/dnntools/relegance/__init__.py
Normal file
|
@ -0,0 +1,2 @@
|
|||
from .relegance import relegance_gradient_weighting, relegance_create_tconv_kernel, relegance_map_relevance_to_input_domain, relegance_resize_relevance_to_input_size
|
||||
from .meta_critic import MetaCritic
|
85
dnn/torch/dnntools/dnntools/relegance/meta_critic.py
Normal file
85
dnn/torch/dnntools/dnntools/relegance/meta_critic.py
Normal file
|
@ -0,0 +1,85 @@
|
|||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
class MetaCritic():
    def __init__(self, normalize=False, gamma=0.9, beta=0.0, joint_stats=False):
        """ Class for assessing relevance of discriminator scores

        Args:
            normalize (bool, optional): if True, score differences are normalized with tracked mean/std statistics. Defaults to False.
            gamma (float, optional): update rate for tracking discriminator stats. Defaults to 0.9.
            beta (float, optional): Minimum confidence related threshold. Defaults to 0.0.
            joint_stats (bool, optional): if True, all discriminators share one set of statistics (stored under key 0). Defaults to False.
        """
        self.normalize = normalize
        self.gamma = gamma
        self.beta = beta
        self.joint_stats = joint_stats

        # maps statistics key -> {'std': float, 'mean': float}
        self.disc_stats = dict()

    def __call__(self, disc_id, real_scores, generated_scores):
        """ calculates relevance from normalized scores

        Args:
            disc_id (any valid key): id for tracking discriminator statistics
            real_scores (torch.tensor): scores for real data
            generated_scores (torch.tensor): scores for generated data; expecting device to match real_scores.device

        Returns:
            torch.tensor: output-domain relevance
        """

        if self.normalize:
            real_std = torch.std(real_scores.detach()).cpu().item()
            gen_std = torch.std(generated_scores.detach()).cpu().item()
            std = (real_std**2 + gen_std**2) ** .5
            mean = torch.mean(real_scores.detach()).cpu().item() - torch.mean(generated_scores.detach()).cpu().item()

            key = 0 if self.joint_stats else disc_id

            if key in self.disc_stats:
                # exponential moving average of the tracked statistics
                self.disc_stats[key]['std'] = self.gamma * self.disc_stats[key]['std'] + (1 - self.gamma) * std
                self.disc_stats[key]['mean'] = self.gamma * self.disc_stats[key]['mean'] + (1 - self.gamma) * mean
            else:
                # first observation; the small offset keeps the initial std strictly positive
                self.disc_stats[key] = {
                    'std': std + 1e-5,
                    'mean': mean
                }

            std = self.disc_stats[key]['std']
            mean = self.disc_stats[key]['mean']
        else:
            mean, std = 0, 1

        relevance = torch.relu((real_scores - generated_scores - mean) / std + mean - self.beta)

        return relevance
|
449
dnn/torch/dnntools/dnntools/relegance/relegance.py
Normal file
449
dnn/torch/dnntools/dnntools/relegance/relegance.py
Normal file
|
@ -0,0 +1,449 @@
|
|||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def view_one_hot(index, length):
    """Returns a list of ``length`` ones with the entry at ``index`` set to -1.

    The result is used as a shape argument for ``Tensor.view``.
    """
    shape = [1] * length
    shape[index] = -1
    return shape
|
||||
|
||||
def create_smoothing_kernel(widths, gamma=1.5):
    """ creates a truncated gaussian smoothing kernel for the given widths

    Parameters:
    -----------
    widths: list[Int] or torch.LongTensor
        specifies the shape of the smoothing kernel, entries must be > 0.

    gamma: float, optional
        decay factor for gaussian relative to kernel size

    Returns:
    --------
    kernel: torch.FloatTensor
        kernel of shape ``widths`` whose entries sum to one
    """

    widths = torch.LongTensor(widths)
    num_dims = len(widths)

    assert(widths.min() > 0)

    centers = widths.float() / 2 - 0.5
    sigmas = gamma * (centers + 1)

    # accumulate the squared, sigma-normalized distance from the center,
    # one axis at a time (broadcast over all other axes)
    exponent = 0
    for axis in range(num_dims):
        offsets = ((torch.arange(widths[axis]) - centers[axis]) / sigmas[axis]) ** 2
        exponent = exponent + offsets.view(view_one_hot(axis, num_dims))

    kernel = torch.exp(- exponent)
    return kernel / kernel.sum()
|
||||
|
||||
|
||||
def create_partition_kernel(widths, strides):
    """ creates a partition kernel for mapping a convolutional network output back to the input domain

    Given a fully convolutional network with receptive field of shape widths and the given strides, this
    function constructs an interpolation kernel whose translations by multiples of the given strides form
    a partition of one on the input domain.

    Parameter:
    ----------
    widths: list[Int] or torch.LongTensor
        shape of receptive field

    strides: list[Int] or torch.LongTensor
        total strides of convolutional network

    Returns:
    kernel: torch.FloatTensor
    """

    num_dims = len(widths)
    assert num_dims == len(strides) and num_dims in {1, 2, 3}

    conv = {1 : F.conv1d, 2 : F.conv2d, 3 : F.conv3d}[num_dims]

    widths = torch.LongTensor(widths)
    strides = torch.LongTensor(strides)

    # box of ones with side lengths min(stride, width) per dimension
    box = torch.ones(torch.minimum(strides, widths).tolist())

    # interpolation kernel eta, at least one sample wide in every dimension
    eta_widths = widths - strides + 1
    if eta_widths.min() <= 0:
        print("[create_partition_kernel] warning: receptive field does not cover input domain")
        eta_widths = torch.maximum(eta_widths, torch.ones_like(eta_widths))

    eta = create_smoothing_kernel(eta_widths).view(1, 1, *eta_widths.tolist())

    # ordering of dimensions for padding and convolution functions is reversed in torch
    pad_spec = torch.repeat_interleave(eta_widths - 1, 2, 0).tolist()[::-1]
    box = F.pad(box, pad_spec)
    box = box.view(1, 1, *box.shape)

    return conv(box, eta)
|
||||
|
||||
|
||||
def receptive_field(conv_model, input_shape, output_position):
    """ estimates boundaries of receptive field connected to output_position via autograd

    A random input is pushed through conv_model, the output at output_position
    (summed over channels) is back-propagated, and the extent of the non-zero
    input gradient is reported.

    Parameters:
    -----------
    conv_model: nn.Module or autograd function
        function or model implementing fully convolutional model

    input_shape: List[Int]
        input shape ignoring batch dimension, i.e. [num_channels, dim1, dim2, ...]

    output_position: List[Int]
        output position for which the receptive field is determined; the function raises an exception
        if output_position is out of bounds for the given input_shape.

    Returns:
    --------
    low: List[Int]
        smallest input index with non-zero gradient, per dimension

    high: List[Int]
        largest input index with non-zero gradient, per dimension
    """

    probe = torch.randn((1,) + tuple(input_shape), requires_grad=True)

    # collapse channels and remove batch dimension
    response = torch.sum(conv_model(probe), 1)[0]

    # selector that is one at output_position and zero elsewhere
    selector = torch.zeros_like(response)
    position = [torch.tensor(p) for p in output_position]
    try:
        selector.index_put_(position, torch.tensor(1, dtype=selector.dtype))
    except IndexError:
        raise ValueError('output_position out of bounds')

    (selector * response).sum().backward()

    # sum over channels and remove batch dimension
    sensitivity = torch.sum(probe.grad, dim=1)[0]
    touched = torch.nonzero(sensitivity, as_tuple=True)

    low = [indices.min().item() for indices in touched]
    high = [indices.max().item() for indices in touched]

    return low, high
|
||||
|
||||
def estimate_conv_parameters(model, num_channels, num_dims, width, max_stride=10):
    """ attempts to estimate receptive field size, strides and left paddings for given model

    The receptive fields of two diagonally adjacent output positions are
    measured with receptive_field; their common size gives the receptive field
    size and the shift between them gives the stride.

    Parameters:
    -----------
    model: nn.Module or autograd function
        fully convolutional model for which parameters are estimated

    num_channels: Int
        number of input channels for model

    num_dims: Int
        number of input dimensions for model (without channel dimension)

    width: Int
        width of the input tensor (a hyper-square) on which the receptive fields are derived via autograd

    max_stride: Int, optional
        assumed maximal stride of the model for any dimension, when set too low the function may fail for
        any value of width

    Returns:
    --------
    receptive_field_size: List[Int]
        receptive field size in all dimension

    strides: List[Int]
        stride in all dimensions

    left_paddings: List[Int]
        left padding in all dimensions; this is relevant for aligning the receptive field on the input plane

    Raises:
    -------
    ValueError, KeyError
    """

    input_shape = [num_channels] + num_dims * [width]
    base_index = width // (2 * max_stride)
    first_position = num_dims * [base_index]
    second_position = num_dims * [base_index + 1]

    first_low, first_high = receptive_field(model, input_shape, first_position)
    second_low, second_high = receptive_field(model, input_shape, second_position)

    first_widths = [hi - lo + 1 for lo, hi in zip(first_low, first_high)]
    second_widths = [hi - lo + 1 for lo, hi in zip(second_low, second_high)]

    # both fields must have equal size and must be shifted in every dimension
    same_size = all(a == b for a, b in zip(first_widths, second_widths))
    shifted = all(a != b for a, b in zip(first_low, second_low))
    if not (same_size and shifted):
        raise ValueError("[estimate_strides]: widths to small to determine strides")

    strides = [b - a for a, b in zip(first_low, second_low)]
    left_paddings = [s * p - lo for lo, s, p in zip(first_low, strides, first_position)]

    return first_widths, strides, left_paddings
|
||||
|
||||
def inspect_conv_model(model, num_channels, num_dims, max_width=10000, width_hint=None, stride_hint=None, verbose=False):
    """ determines size of receptive field, strides and padding probabilistically

    Repeatedly calls estimate_conv_parameters with growing input width and
    growing assumed maximal stride until an estimate succeeds or the search
    limits derived from max_width are exceeded.

    Parameters:
    -----------
    model: nn.Module or autograd function
        fully convolutional model for which parameters are estimated

    num_channels: Int
        number of input channels for model

    num_dims: Int
        number of input dimensions for model (without channel dimension)

    max_width: Int
        maximum width of the input tensor (a hyper-square) on which the receptive fields are derived via autograd

    width_hint: Int or None, optional
        if given, the search starts with an input width of 2 * width_hint

    stride_hint: Int or None, optional
        if given, the search starts with this assumed maximal stride

    verbose: bool, optional
        if true, the function prints parameters for individual trials

    Returns:
    --------
    receptive_field_size: List[Int]
        receptive field size in all dimension

    strides: List[Int]
        stride in all dimensions

    left_paddings: List[Int]
        left padding in all dimensions; this is relevant for aligning the receptive field on the input plane

    Raises:
    -------
    ValueError
        if no trial within the search limits succeeds
    """

    max_stride = max_width // 2

    # Start values are clamped to at least 1: a start stride of 0 (which
    # max_stride // 100 yields for max_width < 200) can never grow by doubling
    # and previously made the search loop run forever.
    stride = max(1, max_stride // 100)
    width = max(1, max_width // 100)

    if width_hint is not None: width = max(1, 2 * width_hint)
    if stride_hint is not None: stride = max(1, stride_hint)

    while width < max_width and stride < max_stride:
        if verbose: print(f"[inspect_conv_model] trying parameters {width=}, {stride=}")
        try:
            return estimate_conv_parameters(model, num_channels, num_dims, width, stride)
        except Exception:
            # a failed trial only means the current width/stride were
            # unsuitable; keep searching (KeyboardInterrupt etc. propagate)
            pass

        width *= 2
        if width >= max_width and stride < max_stride:
            # widths exhausted for this stride: assume a larger stride and restart
            stride *= 2
            width = 2 * stride

    raise ValueError(f'could not determine conv parameter with given max_width={max_width}')
|
||||
|
||||
|
||||
class GradWeight(torch.autograd.Function):
    """Autograd function: copies the input in the forward pass and multiplies
    the incoming gradient by a weight tensor in the backward pass."""

    def __init__(self):
        super().__init__()

    @staticmethod
    def forward(ctx, x, weight):
        # keep the weight for the backward pass; the output is a copy of x
        ctx.save_for_backward(weight)
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        (scaling,) = ctx.saved_tensors
        # gradient w.r.t. x is re-weighted; weight itself receives no gradient
        return grad_output * scaling, None
|
||||
|
||||
|
||||
# API
|
||||
|
||||
def relegance_gradient_weighting(x, weight):
    """ applies relevance weighting to the gradient of x

    Args:
        x (torch.tensor): input tensor
        weight (torch.tensor or None): weight tensor for gradients of x; if None, no gradient weighting will be applied in backward pass

    Returns:
        torch.tensor: x itself if weight is None, otherwise the result of GradWeight.apply(x, weight)
    """
    if weight is not None:
        return GradWeight.apply(x, weight)
    return x
|
||||
|
||||
|
||||
|
||||
def relegance_create_tconv_kernel(model, num_channels, num_dims, width_hint=None, stride_hint=None, verbose=False):
    """ creates parameters for mapping back output domain relevance to input domain

    Args:
        model (nn.Module or autograd.Function): fully convolutional model
        num_channels (int): number of input channels to model
        num_dims (int): number of input dimensions of model (without channel and batch dimension)
        width_hint(int or None): optional hint at maximal width of receptive field
        stride_hint(int or None): optional hint at maximal stride
        verbose (bool): forwarded to inspect_conv_model

    Returns:
        dict: contains kernel, kernel dimensions, strides and left paddings for transposed convolution

    Raises:
        RuntimeError: if the model parameters cannot be determined, even after retrying with a ten times larger max_width
    """

    max_width = int(100000 / (10 ** num_dims))

    try:
        receptive_field_size, strides, left_paddings = inspect_conv_model(model, num_channels, num_dims, max_width=max_width, width_hint=width_hint, stride_hint=stride_hint, verbose=verbose)
    except Exception:
        # try once again with larger max_width
        max_width *= 10
        try:
            receptive_field_size, strides, left_paddings = inspect_conv_model(model, num_channels, num_dims, max_width=max_width, width_hint=width_hint, stride_hint=stride_hint, verbose=verbose)
        except Exception as err:
            raise RuntimeError("could not determine parameters within given compute budget") from err

    partition_kernel = create_partition_kernel(receptive_field_size, strides)
    # one copy of the kernel per input channel along dimension 1
    partition_kernel = torch.repeat_interleave(partition_kernel, num_channels, 1)

    tconv_parameters = {
        'kernel': partition_kernel,
        'receptive_field_shape': receptive_field_size,
        'stride': strides,
        'left_padding': left_paddings,
        'num_dims': num_dims
    }

    return tconv_parameters
|
||||
|
||||
|
||||
|
||||
def relegance_map_relevance_to_input_domain(od_relevance, tconv_parameters):
    """ maps output-domain relevance to input-domain relevance via transpose convolution

    Args:
        od_relevance (torch.tensor): output-domain relevance
        tconv_parameters (dict): parameter dict as created by relegance_create_tconv_kernel

    Returns:
        torch.tensor: input-domain relevance. The tensor is left aligned, i.e. the all-zero index of the output corresponds to the all-zero index of the discriminator input.
            Otherwise, the size of the output tensor does not need to match the size of the discriminator input. Use relegance_resize_relevance_to_input_size for a
            convenient way to adjust the output to the correct size.

    Raises:
        ValueError: if number of dimensions is not supported
    """

    kernel = tconv_parameters['kernel'].to(od_relevance.device)
    rf_shape = tconv_parameters['receptive_field_shape']
    stride = tconv_parameters['stride']
    left_padding = tconv_parameters['left_padding']

    # the kernel carries two leading (channel) dimensions
    num_dims = len(kernel.shape) - 2

    # the 3-dimensional case previously called conv_transpose2d by mistake
    tconvs = {1 : F.conv_transpose1d, 2 : F.conv_transpose2d, 3 : F.conv_transpose3d}
    if num_dims not in tconvs:
        raise ValueError(f'[relegance_map_to_input_domain] error: num_dims = {num_dims} not supported')

    # repeat boundary values
    od_padding = [rf_shape[i] // stride[i] + 1 for i in range(num_dims)]
    # F.pad expects (left, right) pairs starting with the last dimension
    pad_spec = [p for p in reversed(od_padding) for _ in range(2)]
    padded_od_relevance = F.pad(od_relevance, pad_spec, mode='replicate')

    # apply mapping and left trimming
    id_relevance = tconvs[num_dims](padded_od_relevance, kernel, stride=stride)
    trim = tuple(slice(left_padding[i] + stride[i] * od_padding[i], None) for i in range(num_dims))
    id_relevance = id_relevance[(Ellipsis,) + trim]

    return id_relevance
|
||||
|
||||
|
||||
def relegance_resize_relevance_to_input_size(reference_input, relevance):
    """ adjusts size of relevance tensor to reference input size

    Along every dimension after the first two, the relevance tensor is cropped
    on the right if it is longer than reference_input and zero-padded on the
    right if it is shorter, so index 0 stays aligned and the result always has
    the shape of reference_input.

    Args:
        reference_input (torch.tensor): discriminator input tensor for reference
        relevance (torch.tensor): input-domain relevance corresponding to input tensor reference_input

    Returns:
        torch.tensor: resized relevance

    Raises:
        ValueError: if number of dimensions is not supported
    """
    # the first two dimensions are not resized; 1 to 3 trailing dimensions are supported
    num_dims = len(reference_input.shape) - 2
    if num_dims not in (1, 2, 3):
        raise ValueError(f'[relegance_resize_relevance_to_input_size] error: num_dims = {num_dims} not supported')

    resized_relevance = torch.zeros_like(reference_input)

    # region shared by both tensors along each resized dimension
    overlap = tuple(
        slice(0, min(reference_input.size(dim), relevance.size(dim)))
        for dim in range(-num_dims, 0)
    )
    index = (Ellipsis,) + overlap

    with torch.no_grad():
        # copy only the overlap: assigning to the full tensor would fail whenever
        # relevance is smaller than the reference input in some dimension
        resized_relevance[index] = relevance[index]

    return resized_relevance
|
6
dnn/torch/dnntools/dnntools/sparsification/__init__.py
Normal file
6
dnn/torch/dnntools/dnntools/sparsification/__init__.py
Normal file
|
@ -0,0 +1,6 @@
|
|||
from .gru_sparsifier import GRUSparsifier
|
||||
from .conv1d_sparsifier import Conv1dSparsifier
|
||||
from .conv_transpose1d_sparsifier import ConvTranspose1dSparsifier
|
||||
from .linear_sparsifier import LinearSparsifier
|
||||
from .common import sparsify_matrix, calculate_gru_flops_per_step
|
||||
from .utils import mark_for_sparsification, create_sparsifier
|
|
@ -0,0 +1,58 @@
|
|||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
class BaseSparsifier:
    """ Common scheduling logic for the sparsifier classes

    The class keeps a step counter and turns it into an interpolation factor
    alpha, which is handed to self.sparsify(alpha, verbose=...). The sparsify
    method itself is not defined here and has to be supplied by a subclass.
    """

    def __init__(self, task_list, start, stop, interval, exponent=3):
        """ stores the schedule parameters and resets the step counter

            Parameters:
            -----------
            task_list :
                stored as is; not interpreted by this class
            start : int
                calls to step() are ignored while the step counter is below start
            stop : int
                from this step on, sparsify is called with alpha = 0 on every step
            interval : int
                below stop, sparsify is only called on steps that are multiples of interval
            exponent : float
                exponent applied to the linear ramp between start and stop
        """
        self.task_list = task_list
        self.start, self.stop = start, stop
        self.interval = interval
        self.exponent = exponent

        # number of step() calls seen so far
        self.step_counter = 0

    def step(self, verbose=False):
        """ advances the step counter and triggers sparsification when due

            Parameters:
            -----------
            verbose : bool
                forwarded to self.sparsify
        """
        self.step_counter += 1
        count = self.step_counter

        # sparsification has not started yet
        if count < self.start:
            return

        ramping = count < self.stop

        # while ramping, only every interval-th step triggers an update
        if ramping and count % self.interval:
            return

        if ramping:
            # goes from 1 at start towards 0 at stop
            alpha = ((self.stop - count) / (self.stop - self.start)) ** self.exponent
        else:
            alpha = 0

        self.sparsify(alpha, verbose=verbose)
|
123
dnn/torch/dnntools/dnntools/sparsification/common.py
Normal file
123
dnn/torch/dnntools/dnntools/sparsification/common.py
Normal file
|
@ -0,0 +1,123 @@
|
|||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
# Module-level switch imported by the sparsifier modules: when True, they print a
# "weight resurrection" notice whenever a new mask keeps a block that the previous
# mask had zeroed out.
debug=True
|
||||
|
||||
def sparsify_matrix(matrix : torch.tensor, density : float, block_size, keep_diagonal : bool=False, return_mask : bool=False):
    """ sparsifies matrix with specified block size

        The matrix is partitioned into blocks of shape block_size. Blocks are ranked
        by their energy (sum of squared entries) and only the round(density * number
        of blocks) strongest blocks are kept; all other blocks are set to zero.
        Blocks whose energy ties with the weakest surviving block are kept as well.

        Parameters:
        -----------
        matrix : torch.tensor
            two-dimensional matrix to sparsify
        density : float
            target density, i.e. fraction of blocks to keep
        block_size : [int, int]
            block size dimensions; they must divide the matrix dimensions
        keep_diagonal : bool
            If true, the diagonal is excluded from the energy computation and always
            kept. This option requires a square matrix and defaults to False
        return_mask : bool
            If true, the element-wise block mask is returned in addition to the
            sparsified matrix. Defaults to False

        Returns:
        --------
        torch.tensor or (torch.tensor, torch.tensor)
            the sparsified matrix, followed by the 0/1 mask if return_mask is True.
            The mask does not reflect diagonal entries spared by keep_diagonal.

        Raises:
        -------
        ValueError
            if block_size does not divide the matrix size or if keep_diagonal is
            requested for a non-square matrix
    """

    m, n = matrix.shape
    m1, n1 = block_size

    if m % m1 or n % n1:
        raise ValueError(f"block size {(m1, n1)} does not divide matrix size {(m, n)}")

    # extract diagonal if keep_diagonal = True
    if keep_diagonal:
        if m != n:
            raise ValueError("Attempting to sparsify non-square matrix with keep_diagonal=True")

        to_spare = torch.diag(torch.diag(matrix))
        matrix = matrix - to_spare
    else:
        to_spare = torch.zeros_like(matrix)

    # calculate energy in sub-blocks
    x = torch.reshape(matrix, (m // m1, m1, n // n1, n1))
    x = x ** 2
    block_energies = torch.sum(torch.sum(x, dim=3), dim=1)

    number_of_blocks = (m * n) // (m1 * n1)
    # clamp so that rounding noise in density can neither index out of range
    # nor produce a negative count
    number_of_survivors = min(number_of_blocks, max(0, round(number_of_blocks * density)))

    # create block mask
    if number_of_survivors == 0:
        # no block survives; a threshold of 0 would keep everything here because
        # energies are never negative
        mask = torch.zeros_like(block_energies)
    else:
        # energy of the weakest surviving block
        threshold = torch.sort(torch.flatten(block_energies)).values[-number_of_survivors]
        mask = torch.ones_like(block_energies)
        mask[block_energies < threshold] = 0

    # expand block mask to element resolution
    mask = torch.repeat_interleave(mask, m1, dim=0)
    mask = torch.repeat_interleave(mask, n1, dim=1)

    # perform masking and restore the spared diagonal
    masked_matrix = mask * matrix + to_spare

    if return_mask:
        return masked_matrix, mask
    else:
        return masked_matrix
|
||||
|
||||
def calculate_gru_flops_per_step(gru, sparsification_dict=dict(), drop_input=False):
    """ estimates the number of floating point operations of one GRU step

        Parameters:
        -----------
        gru :
            object providing input_size and hidden_size
        sparsification_dict : dict
            maps gate names ('W_ir', 'W_iz', 'W_in', 'W_hr', 'W_hz', 'W_hn') to sequences
            whose first entry is the density of that gate matrix; missing gates count as dense
        drop_input : bool
            if True, the input matrix-vector products are left out of the estimate

        Returns:
        --------
        estimated number of operations per step
    """

    def density_of(gate):
        # gates without an entry are treated as fully dense
        return sparsification_dict.get(gate, [1])[0]

    n_in, n_hidden = gru.input_size, gru.hidden_size

    # average density over the three input and the three recurrent gate matrices
    input_density = (density_of('W_ir') + density_of('W_in') + density_of('W_iz')) / 3
    recurrent_density = (density_of('W_hr') + density_of('W_hn') + density_of('W_hz')) / 3

    # matrix-vector products: 2 operations per kept weight, 3 gates each
    input_flops = 0 if drop_input else 2 * 3 * n_in * n_hidden * input_density
    recurrent_flops = 2 * 3 * n_hidden * n_hidden * recurrent_density

    # bias additions
    bias_flops = 6 * n_hidden

    # activations estimated by 10 flops per activation
    activation_flops = 30 * n_hidden

    return input_flops + recurrent_flops + bias_flops + activation_flops
|
133
dnn/torch/dnntools/dnntools/sparsification/conv1d_sparsifier.py
Normal file
133
dnn/torch/dnntools/dnntools/sparsification/conv1d_sparsifier.py
Normal file
|
@ -0,0 +1,133 @@
|
|||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from .base_sparsifier import BaseSparsifier
|
||||
from .common import sparsify_matrix, debug
|
||||
|
||||
|
||||
class Conv1dSparsifier(BaseSparsifier):
    def __init__(self, task_list, start, stop, interval, exponent=3):
        """ Sparsifier for torch.nn.Conv1d layers

            Parameters:
            -----------
            task_list : list
                task_list contains a list of tuples (conv1d, params), where conv1d is an instance
                of torch.nn.Conv1d and params is a tuple (density, [m, n]),
                where density is the target density in [0, 1], [m, n] is the shape sub-blocks to which
                sparsification is applied.

            start : int
                training step after which sparsification will be started.

            stop : int
                training step after which sparsification will be completed.

            interval : int
                sparsification interval for steps between start and stop. After stop sparsification will be
                carried out after every call to Conv1dSparsifier.step()

            exponent : float
                Interpolation exponent for sparsification interval. In step i sparsification will be carried out
                with density (alpha + target_density * (1 - alpha)), where
                alpha = ((stop - i) / (stop - start)) ** exponent

            Example:
            --------
            >>> import torch
            >>> conv = torch.nn.Conv1d(8, 16, 8)
            >>> params = (0.2, [8, 4])
            >>> sparsifier = Conv1dSparsifier([(conv, params)], 0, 100, 50)
            >>> for i in range(100):
            ...         sparsifier.step()
        """
        # forward the caller's exponent instead of a hard-coded value
        super().__init__(task_list, start, stop, interval, exponent=exponent)

        # most recently computed mask
        self.last_mask = None
        # mask of the previous sparsification pass for every task, keyed by the
        # task's position in task_list, so that layers of different shape are
        # never compared with each other
        self.last_masks = {}

    def sparsify(self, alpha, verbose=False):
        """ carries out sparsification step

            Call this function after optimizer.step in your
            training loop.

            Parameters:
            ----------
            alpha : float
                density interpolation parameter (1: dense, 0: target density)
            verbose : bool
                if true, densities are printed out

            Returns:
            --------
            None

        """

        with torch.no_grad():
            for task_index, (conv, params) in enumerate(self.task_list):
                # use weight_v if the layer has one (presumably the direction tensor
                # of torch.nn.utils.weight_norm -- confirm), otherwise the plain weight
                if hasattr(conv, 'weight_v'):
                    weight = conv.weight_v
                else:
                    weight = conv.weight

                # Conv1d weights are laid out as (out_channels, in_channels, kernel_size);
                # flatten to a matrix with one row per output channel
                out_channels, in_channels, kernel_size = weight.shape
                w = weight.permute(0, 2, 1).flatten(1)

                target_density, block_size = params
                density = alpha + (1 - alpha) * target_density
                w, new_mask = sparsify_matrix(w, density, block_size, return_mask=True)

                # undo the flattening and write the result back in place
                w = w.reshape(out_channels, kernel_size, in_channels).permute(0, 2, 1)
                weight[:] = w

                # a block that is kept now but was masked in the previous pass has been "resurrected"
                previous_mask = self.last_masks.get(task_index)
                if previous_mask is not None:
                    if not torch.all(previous_mask * new_mask == new_mask) and debug:
                        print("weight resurrection in conv.weight")

                self.last_masks[task_index] = new_mask
                self.last_mask = new_mask

                if verbose:
                    print(f"conv1d_sparsier[{self.step_counter}]: {density=}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # smoke test: sparsify a small Conv1d layer over 100 steps and show the result
    import torch

    print("Testing sparsifier")

    conv = torch.nn.Conv1d(8, 16, 8)
    params = (0.2, [8, 4])
    sparsifier = Conv1dSparsifier([(conv, params)], start=0, stop=100, interval=5)

    for _ in range(100):
        sparsifier.step(verbose=True)

    print(conv.weight)
|
|
@ -0,0 +1,134 @@
|
|||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
from .base_sparsifier import BaseSparsifier
|
||||
from .common import sparsify_matrix, debug
|
||||
|
||||
|
||||
class ConvTranspose1dSparsifier(BaseSparsifier):
    def __init__(self, task_list, start, stop, interval, exponent=3):
        """ Sparsifier for torch.nn.ConvTranspose1d layers

            Parameters:
            -----------
            task_list : list
                task_list contains a list of tuples (conv, params), where conv is an instance
                of torch.nn.ConvTranspose1d and params is a tuple (density, [m, n]),
                where density is the target density in [0, 1], [m, n] is the shape sub-blocks to which
                sparsification is applied.

            start : int
                training step after which sparsification will be started.

            stop : int
                training step after which sparsification will be completed.

            interval : int
                sparsification interval for steps between start and stop. After stop sparsification will be
                carried out after every call to ConvTranspose1dSparsifier.step()

            exponent : float
                Interpolation exponent for sparsification interval. In step i sparsification will be carried out
                with density (alpha + target_density * (1 - alpha)), where
                alpha = ((stop - i) / (stop - start)) ** exponent

            Example:
            --------
            >>> import torch
            >>> conv = torch.nn.ConvTranspose1d(8, 16, 8)
            >>> params = (0.2, [8, 4])
            >>> sparsifier = ConvTranspose1dSparsifier([(conv, params)], 0, 100, 50)
            >>> for i in range(100):
            ...         sparsifier.step()
        """

        # forward the caller's exponent instead of a hard-coded value
        super().__init__(task_list, start, stop, interval, exponent=exponent)

        # most recently computed mask
        self.last_mask = None
        # mask of the previous sparsification pass for every task, keyed by the
        # task's position in task_list, so that layers of different shape are
        # never compared with each other
        self.last_masks = {}

    def sparsify(self, alpha, verbose=False):
        """ carries out sparsification step

            Call this function after optimizer.step in your
            training loop.

            Parameters:
            ----------
            alpha : float
                density interpolation parameter (1: dense, 0: target density)
            verbose : bool
                if true, densities are printed out

            Returns:
            --------
            None

        """

        with torch.no_grad():
            for task_index, (conv, params) in enumerate(self.task_list):
                # use weight_v if the layer has one (presumably the direction tensor
                # of torch.nn.utils.weight_norm -- confirm), otherwise the plain weight
                if hasattr(conv, 'weight_v'):
                    weight = conv.weight_v
                else:
                    weight = conv.weight

                # ConvTranspose1d weights are laid out as (in_channels, out_channels, kernel_size);
                # flatten to a matrix of shape (kernel_size * out_channels, in_channels)
                i, o, k = weight.shape
                w = weight.permute(2, 1, 0).reshape(k * o, i)

                target_density, block_size = params
                density = alpha + (1 - alpha) * target_density
                w, new_mask = sparsify_matrix(w, density, block_size, return_mask=True)

                # undo the flattening and write the result back in place
                w = w.reshape(k, o, i).permute(2, 1, 0)
                weight[:] = w

                # a block that is kept now but was masked in the previous pass has been "resurrected"
                previous_mask = self.last_masks.get(task_index)
                if previous_mask is not None:
                    if not torch.all(previous_mask * new_mask == new_mask) and debug:
                        print("weight resurrection in conv.weight")

                self.last_masks[task_index] = new_mask
                self.last_mask = new_mask

                if verbose:
                    print(f"convtrans1d_sparsier[{self.step_counter}]: {density=}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # smoke test: sparsify a small ConvTranspose1d layer over 100 steps and show the result
    import torch

    print("Testing sparsifier")

    conv = torch.nn.ConvTranspose1d(8, 16, 4, 4)
    params = (0.2, [8, 4])
    sparsifier = ConvTranspose1dSparsifier([(conv, params)], start=0, stop=100, interval=5)

    for _ in range(100):
        sparsifier.step(verbose=True)

    print(conv.weight)
|
178
dnn/torch/dnntools/dnntools/sparsification/gru_sparsifier.py
Normal file
178
dnn/torch/dnntools/dnntools/sparsification/gru_sparsifier.py
Normal file
|
@ -0,0 +1,178 @@
|
|||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from .base_sparsifier import BaseSparsifier
|
||||
from .common import sparsify_matrix, debug
|
||||
|
||||
|
||||
class GRUSparsifier(BaseSparsifier):
    def __init__(self, task_list, start, stop, interval, exponent=3):
        """ Sparsifier for torch.nn.GRUs

            Parameters:
            -----------
            task_list : list
                task_list contains a list of tuples (gru, sparsify_dict), where gru is an instance
                of torch.nn.GRU and sparsify_dict is a dictionary with keys in {'W_ir', 'W_iz', 'W_in',
                'W_hr', 'W_hz', 'W_hn'} corresponding to the input and recurrent weights for the reset,
                update, and new gate. The values of sparsify_dict are tuples (density, [m, n], keep_diagonal),
                where density is the target density in [0, 1], [m, n] is the shape sub-blocks to which
                sparsification is applied and keep_diagonal is a bool variable indicating whether the diagonal
                should be kept.

            start : int
                training step after which sparsification will be started.

            stop : int
                training step after which sparsification will be completed.

            interval : int
                sparsification interval for steps between start and stop. After stop sparsification will be
                carried out after every call to GRUSparsifier.step()

            exponent : float
                Interpolation exponent for sparsification interval. In step i sparsification will be carried out
                with density (alpha + target_density * (1 - alpha)), where
                alpha = ((stop - i) / (stop - start)) ** exponent

            Example:
            --------
            >>> import torch
            >>> gru = torch.nn.GRU(10, 20)
            >>> sparsify_dict = {
            ...         'W_ir' : (0.5, [2, 2], False),
            ...         'W_iz' : (0.6, [2, 2], False),
            ...         'W_in' : (0.7, [2, 2], False),
            ...         'W_hr' : (0.1, [4, 4], True),
            ...         'W_hz' : (0.2, [4, 4], True),
            ...         'W_hn' : (0.3, [4, 4], True),
            ...     }
            >>> sparsifier = GRUSparsifier([(gru, sparsify_dict)], 0, 100, 50)
            >>> for i in range(100):
            ...         sparsifier.step()
        """
        # forward the caller's exponent instead of a hard-coded value
        super().__init__(task_list, start, stop, interval, exponent=exponent)

        # most recently computed mask per gate
        self.last_masks = {key : None for key in ['W_ir', 'W_in', 'W_iz', 'W_hr', 'W_hn', 'W_hz']}
        # masks of the previous pass keyed by (task position, gate), so that GRUs
        # of different size are never compared with each other
        self._task_masks = {}

    def _sparsify_gates(self, task_index, gru, params, weight_name, gate_keys, alpha, verbose):
        """ sparsifies, in place, the listed gate sub-matrices of one stacked GRU weight

            gate_keys must be given in the order in which the gate matrices are
            stacked along the first dimension of the weight named weight_name.
        """
        hidden_size = gru.hidden_size

        for i, key in enumerate(gate_keys):
            if key not in params:
                continue

            # use the *_v tensor if the GRU has one (presumably the direction tensor
            # of torch.nn.utils.weight_norm -- confirm), otherwise the plain weight
            if hasattr(gru, weight_name + '_v'):
                weight = getattr(gru, weight_name + '_v')
            else:
                weight = getattr(gru, weight_name)

            density = alpha + (1 - alpha) * params[key][0]
            if verbose:
                print(f"[{self.step_counter}]: {key} density: {density}")

            # rows belonging to the i-th gate
            rows = slice(i * hidden_size, (i + 1) * hidden_size)
            weight[rows, : ], new_mask = sparsify_matrix(
                weight[rows, : ],
                density,
                params[key][1], # block_size
                params[key][2], # keep_diagonal (might want to set this to False)
                return_mask=True
            )

            # a block that is kept now but was masked in the previous pass has been "resurrected"
            previous_mask = self._task_masks.get((task_index, key))
            if previous_mask is not None:
                if not torch.all(previous_mask * new_mask == new_mask) and debug:
                    print(f"weight resurrection in {weight_name}_v")

            self._task_masks[(task_index, key)] = new_mask
            self.last_masks[key] = new_mask

    def sparsify(self, alpha, verbose=False):
        """ carries out sparsification step

            Call this function after optimizer.step in your
            training loop.

            Parameters:
            ----------
            alpha : float
                density interpolation parameter (1: dense, 0: target density)
            verbose : bool
                if true, densities are printed out

            Returns:
            --------
            None

        """

        with torch.no_grad():
            for task_index, (gru, params) in enumerate(self.task_list):
                # input weights
                self._sparsify_gates(task_index, gru, params, 'weight_ih_l0', ['W_ir', 'W_iz', 'W_in'], alpha, verbose)

                # recurrent weights
                self._sparsify_gates(task_index, gru, params, 'weight_hh_l0', ['W_hr', 'W_hz', 'W_hn'], alpha, verbose)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # smoke test: sparsify a small GRU over 100 steps and show the recurrent weights
    print("Testing sparsifier")

    gru = torch.nn.GRU(10, 20)

    # (density, block_size, keep_diagonal) per gate
    input_gates = {'W_ir': 0.5, 'W_iz': 0.6, 'W_in': 0.7}
    recurrent_gates = {'W_hr': 0.1, 'W_hz': 0.2, 'W_hn': 0.3}
    sparsify_dict = {gate: (density, [2, 2], False) for gate, density in input_gates.items()}
    sparsify_dict.update({gate: (density, [4, 4], True) for gate, density in recurrent_gates.items()})

    sparsifier = GRUSparsifier([(gru, sparsify_dict)], start=0, stop=100, interval=10)

    for _ in range(100):
        sparsifier.step(verbose=True)

    print(gru.weight_hh_l0)
|
128
dnn/torch/dnntools/dnntools/sparsification/linear_sparsifier.py
Normal file
128
dnn/torch/dnntools/dnntools/sparsification/linear_sparsifier.py
Normal file
|
@ -0,0 +1,128 @@
|
|||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from .base_sparsifier import BaseSparsifier
|
||||
from .common import sparsify_matrix
|
||||
|
||||
|
||||
class LinearSparsifier(BaseSparsifier):
    def __init__(self, task_list, start, stop, interval, exponent=3):
        """ Sparsifier for torch.nn.Linear layers

            Parameters:
            -----------
            task_list : list
                task_list contains a list of tuples (linear, params), where linear is an instance
                of torch.nn.Linear and params is a tuple (density, [m, n]),
                where density is the target density in [0, 1], [m, n] is the shape sub-blocks to which
                sparsification is applied.

            start : int
                training step after which sparsification will be started.

            stop : int
                training step after which sparsification will be completed.

            interval : int
                sparsification interval for steps between start and stop. After stop sparsification will be
                carried out after every call to LinearSparsifier.step()

            exponent : float
                Interpolation exponent for sparsification interval. In step i sparsification will be carried out
                with density (alpha + target_density * (1 - alpha)), where
                alpha = ((stop - i) / (stop - start)) ** exponent

            Example:
            --------
            >>> import torch
            >>> linear = torch.nn.Linear(8, 16)
            >>> params = (0.2, [8, 4])
            >>> sparsifier = LinearSparsifier([(linear, params)], 0, 100, 50)
            >>> for i in range(100):
            ...         sparsifier.step()
        """

        # forward the caller's exponent instead of a hard-coded value
        super().__init__(task_list, start, stop, interval, exponent=exponent)

        # most recently computed mask
        self.last_mask = None
        # mask of the previous sparsification pass for every task, keyed by the
        # task's position in task_list, so that layers of different shape are
        # never compared with each other
        self.last_masks = {}

    def sparsify(self, alpha, verbose=False):
        """ carries out sparsification step

            Call this function after optimizer.step in your
            training loop.

            Parameters:
            ----------
            alpha : float
                density interpolation parameter (1: dense, 0: target density)
            verbose : bool
                if true, densities are printed out

            Returns:
            --------
            None

        """
        # the module header only imports sparsify_matrix from .common, so the
        # debug flag used below is brought into scope here
        from .common import debug

        with torch.no_grad():
            for task_index, (linear, params) in enumerate(self.task_list):
                # use weight_v if the layer has one (presumably the direction tensor
                # of torch.nn.utils.weight_norm -- confirm), otherwise the plain weight
                if hasattr(linear, 'weight_v'):
                    weight = linear.weight_v
                else:
                    weight = linear.weight

                target_density, block_size = params
                density = alpha + (1 - alpha) * target_density
                weight[:], new_mask = sparsify_matrix(weight, density, block_size, return_mask=True)

                # a block that is kept now but was masked in the previous pass has been "resurrected"
                previous_mask = self.last_masks.get(task_index)
                if previous_mask is not None:
                    if not torch.all(previous_mask * new_mask == new_mask) and debug:
                        print("weight resurrection in linear.weight")

                self.last_masks[task_index] = new_mask
                self.last_mask = new_mask

                if verbose:
                    print(f"linear_sparsifier[{self.step_counter}]: {density=}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # smoke test: sparsify a small Linear layer over 100 steps and show the result
    import torch

    print("Testing sparsifier")

    linear = torch.nn.Linear(8, 16)
    params = (0.2, [4, 2])
    sparsifier = LinearSparsifier([(linear, params)], start=0, stop=100, interval=5)

    for _ in range(100):
        sparsifier.step(verbose=True)

    print(linear.weight)
|
64
dnn/torch/dnntools/dnntools/sparsification/utils.py
Normal file
64
dnn/torch/dnntools/dnntools/sparsification/utils.py
Normal file
|
@ -0,0 +1,64 @@
|
|||
import torch
|
||||
|
||||
from dnntools.sparsification import GRUSparsifier, LinearSparsifier, Conv1dSparsifier, ConvTranspose1dSparsifier
|
||||
|
||||
def mark_for_sparsification(module, params):
    """ tags a module for sparsification

        Sets the attribute sparsify to True and stores params in the attribute
        sparsification_params of the given module.

        Parameters:
        -----------
        module :
            object to tag
        params :
            sparsification parameters to attach; stored without inspection

        Returns:
        --------
        the same module object that was passed in
    """
    module.sparsify = True
    module.sparsification_params = params
    return module
|
||||
|
||||
def create_sparsifier(module, start, stop, interval):
    """ builds one sparsifier for every sub-module carrying a sparsify attribute

        Parameters:
        -----------
        module :
            object whose modules() are scanned for the attribute sparsify
        start, stop, interval :
            passed unchanged to the constructor of every sparsifier

        Returns:
        --------
        function sparsify(verbose=False) that calls step(verbose) on all
        sparsifiers created here
    """

    def build(layer):
        """ returns the sparsifier matching the layer type, or None after printing a warning """
        if isinstance(layer, torch.nn.GRU):
            sparsifier_class = GRUSparsifier
        elif isinstance(layer, torch.nn.Linear):
            sparsifier_class = LinearSparsifier
        elif isinstance(layer, torch.nn.Conv1d):
            sparsifier_class = Conv1dSparsifier
        elif isinstance(layer, torch.nn.ConvTranspose1d):
            sparsifier_class = ConvTranspose1dSparsifier
        else:
            print(f"[create_sparsifier] warning: module {layer} marked for sparsification but no suitable sparsifier exists.")
            return None

        # each sparsifier handles exactly one module
        return sparsifier_class([(layer, layer.sparsification_params)], start, stop, interval)

    marked_layers = (m for m in module.modules() if hasattr(m, 'sparsify'))
    sparsifier_list = [s for s in map(build, marked_layers) if s is not None]

    def sparsify(verbose=False):
        for sparsifier in sparsifier_list:
            sparsifier.step(verbose)

    return sparsify
|
||||
|
||||
|
||||
def count_parameters(model, verbose=False):
    """Count the scalar parameters of a model.

    Args:
        model: torch.nn.Module whose `named_parameters()` are counted.
        verbose: if True, print the element count of every parameter tensor.

    Returns:
        int: total number of elements over all parameter tensors.
    """
    total = 0
    for name, p in model.named_parameters():
        # numel() is exact for every dtype and allocates nothing; summing a
        # tensor of ones is done in the parameter's dtype and loses precision
        # (or overflows) for half-precision parameters.
        count = p.numel()

        if verbose:
            print(f"{name}: {count} parameters")

        total += count

    return total
|
||||
|
||||
def estimate_nonzero_parameters(module):
    """Estimate how many parameters of `module` stay non-zero after sparsification.

    The estimate is the total element count of `module.parameters()` minus the
    number of weights expected to be zeroed. A module without a `sparsify`
    attribute is treated as dense. The first entry of each parameter tuple in
    `module.sparsification_params` is taken to be the target density (fraction
    of weights kept), so a weight tensor contributes `size * (1 - density)`
    zeros.

    For GRUs only `input_size` and `hidden_size` enter the estimate, i.e. the
    zero count covers a single uni-directional layer.

    Args:
        module: a torch.nn.Module, optionally tagged with `sparsify` and
            `sparsification_params`.

    Returns:
        Estimated number of non-zero parameters (float when the module is
        marked for sparsification, int otherwise).

    Raises:
        ValueError: if the module is marked for sparsification but its type
            is not Conv1d, ConvTranspose1d, GRU or Linear.
    """
    num_parameters = sum(p.numel() for p in module.parameters())
    num_zero_parameters = 0

    if hasattr(module, 'sparsify'):
        params = module.sparsification_params
        if isinstance(module, (torch.nn.Conv1d, torch.nn.ConvTranspose1d)):
            num_zero_parameters = module.weight.numel() * (1 - params[0])
        elif isinstance(module, torch.nn.GRU):
            # three gates each for the input and the recurrent weights
            num_zero_parameters = module.input_size * module.hidden_size * (3 - params['W_ir'][0] - params['W_iz'][0] - params['W_in'][0])
            num_zero_parameters += module.hidden_size * module.hidden_size * (3 - params['W_hr'][0] - params['W_hz'][0] - params['W_hn'][0])
        elif isinstance(module, torch.nn.Linear):
            # zeros are the complement of the density, as in the other branches
            num_zero_parameters = module.in_features * module.out_features * (1 - params[0])
        else:
            raise ValueError(f'unknown sparsification method for module of type {type(module)}')

    return num_parameters - num_zero_parameters
|
1
dnn/torch/dnntools/requirements.txt
Normal file
1
dnn/torch/dnntools/requirements.txt
Normal file
|
@ -0,0 +1 @@
|
|||
torch
|
48
dnn/torch/dnntools/setup.py
Normal file
48
dnn/torch/dnntools/setup.py
Normal file
|
@ -0,0 +1,48 @@
|
|||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
#!/usr/bin/env python
import os
from setuptools import setup

# requirements.txt is looked up next to this file so that the script works
# regardless of the current working directory.
package_root = os.path.dirname(os.path.realpath(__file__))
requirements_path = os.path.join(package_root, 'requirements.txt')

with open(requirements_path, 'r') as req_file:
    install_requires = req_file.read().splitlines()

print(install_requires)

setup(
    name='dnntools',
    version='1.0',
    author='Jan Buethe',
    author_email='jbuethe@amazon.de',
    description='Non-Standard tools for deep neural network training with PyTorch',
    packages=['dnntools', 'dnntools.sparsification', 'dnntools.quantization'],
    install_requires=install_requires,
)
|
|
@ -111,7 +111,7 @@ os.makedirs(checkpoint_dir, exist_ok=True)
|
|||
if has_git:
|
||||
working_dir = os.path.split(__file__)[0]
|
||||
try:
|
||||
repo = git.Repo(working_dir)
|
||||
repo = git.Repo(working_dir, search_parent_directories=True)
|
||||
setup['repo'] = dict()
|
||||
hash = repo.head.object.hexsha
|
||||
urls = list(repo.remote().urls)
|
||||
|
@ -408,6 +408,10 @@ for ep in range(1, epochs + 1):
|
|||
|
||||
optimizer.step()
|
||||
|
||||
# sparsification
|
||||
if hasattr(model, 'sparsifier'):
|
||||
model.sparsifier()
|
||||
|
||||
running_model_grad_norm += get_grad_norm(model).detach().cpu().item()
|
||||
running_adv_loss += gen_loss.detach().cpu().item()
|
||||
running_disc_loss += disc_loss.detach().cpu().item()
|
||||
|
|
|
@ -111,7 +111,7 @@ os.makedirs(checkpoint_dir, exist_ok=True)
|
|||
if has_git:
|
||||
working_dir = os.path.split(__file__)[0]
|
||||
try:
|
||||
repo = git.Repo(working_dir)
|
||||
repo = git.Repo(working_dir, search_parent_directories=True)
|
||||
setup['repo'] = dict()
|
||||
hash = repo.head.object.hexsha
|
||||
urls = list(repo.remote().urls)
|
||||
|
|
|
@ -46,6 +46,10 @@ def train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler,
|
|||
# update learning rate
|
||||
scheduler.step()
|
||||
|
||||
# sparsification
|
||||
if hasattr(model, 'sparsifier'):
|
||||
model.sparsifier()
|
||||
|
||||
# update running loss
|
||||
running_loss += float(loss.cpu())
|
||||
|
||||
|
@ -73,8 +77,6 @@ def evaluate(model, criterion, dataloader, device, log_interval=10):
|
|||
|
||||
for i, batch in enumerate(tepoch):
|
||||
|
||||
|
||||
|
||||
# push batch to device
|
||||
for key in batch:
|
||||
batch[key] = batch[key].to(device)
|
||||
|
|
|
@ -43,6 +43,7 @@ from models import model_dict
|
|||
from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d
|
||||
from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
|
||||
from utils.layers.td_shaper import TDShaper
|
||||
from utils.misc import remove_all_weight_norm
|
||||
from wexchange.torch import dump_torch_weights
|
||||
|
||||
|
||||
|
@ -58,30 +59,30 @@ schedules = {
|
|||
'nolace': [
|
||||
('pitch_embedding', dict()),
|
||||
('feature_net.conv1', dict()),
|
||||
('feature_net.conv2', dict(quantize=True, scale=None)),
|
||||
('feature_net.tconv', dict(quantize=True, scale=None)),
|
||||
('feature_net.gru', dict()),
|
||||
('feature_net.conv2', dict(quantize=True, scale=None, sparse=True)),
|
||||
('feature_net.tconv', dict(quantize=True, scale=None, sparse=True)),
|
||||
('feature_net.gru', dict(quantize=True, scale=None, recurrent_scale=None, input_sparse=True, recurrent_sparse=True)),
|
||||
('cf1', dict(quantize=True, scale=None)),
|
||||
('cf2', dict(quantize=True, scale=None)),
|
||||
('af1', dict(quantize=True, scale=None)),
|
||||
('tdshape1', dict()),
|
||||
('tdshape2', dict()),
|
||||
('tdshape3', dict()),
|
||||
('tdshape1', dict(quantize=True, scale=None)),
|
||||
('tdshape2', dict(quantize=True, scale=None)),
|
||||
('tdshape3', dict(quantize=True, scale=None)),
|
||||
('af2', dict(quantize=True, scale=None)),
|
||||
('af3', dict(quantize=True, scale=None)),
|
||||
('af4', dict(quantize=True, scale=None)),
|
||||
('post_cf1', dict(quantize=True, scale=None)),
|
||||
('post_cf2', dict(quantize=True, scale=None)),
|
||||
('post_af1', dict(quantize=True, scale=None)),
|
||||
('post_af2', dict(quantize=True, scale=None)),
|
||||
('post_af3', dict(quantize=True, scale=None))
|
||||
('post_cf1', dict(quantize=True, scale=None, sparse=True)),
|
||||
('post_cf2', dict(quantize=True, scale=None, sparse=True)),
|
||||
('post_af1', dict(quantize=True, scale=None, sparse=True)),
|
||||
('post_af2', dict(quantize=True, scale=None, sparse=True)),
|
||||
('post_af3', dict(quantize=True, scale=None, sparse=True))
|
||||
],
|
||||
'lace' : [
|
||||
('pitch_embedding', dict()),
|
||||
('feature_net.conv1', dict()),
|
||||
('feature_net.conv2', dict(quantize=True, scale=None)),
|
||||
('feature_net.tconv', dict(quantize=True, scale=None)),
|
||||
('feature_net.gru', dict()),
|
||||
('feature_net.conv2', dict(quantize=True, scale=None, sparse=True)),
|
||||
('feature_net.tconv', dict(quantize=True, scale=None, sparse=True)),
|
||||
('feature_net.gru', dict(quantize=True, scale=None, recurrent_scale=None, input_sparse=True, recurrent_sparse=True)),
|
||||
('cf1', dict(quantize=True, scale=None)),
|
||||
('cf2', dict(quantize=True, scale=None)),
|
||||
('af1', dict(quantize=True, scale=None))
|
||||
|
@ -140,6 +141,7 @@ if __name__ == "__main__":
|
|||
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
||||
model = model_dict[checkpoint['setup']['model']['name']](*checkpoint['setup']['model']['args'], **checkpoint['setup']['model']['kwargs'])
|
||||
model.load_state_dict(checkpoint['state_dict'])
|
||||
remove_all_weight_norm(model, verbose=True)
|
||||
|
||||
# CWriter
|
||||
model_name = checkpoint['setup']['model']['name']
|
||||
|
|
|
@ -41,6 +41,12 @@ from models.silk_feature_net_pl import SilkFeatureNetPL
|
|||
from models.silk_feature_net import SilkFeatureNet
|
||||
from .scale_embedding import ScaleEmbedding
|
||||
|
||||
import sys
|
||||
sys.path.append('../dnntools')
|
||||
|
||||
from dnntools.sparsification import create_sparsifier
|
||||
|
||||
|
||||
class LACE(NNSBase):
|
||||
""" Linear-Adaptive Coding Enhancer """
|
||||
FRAME_SIZE=80
|
||||
|
@ -60,7 +66,12 @@ class LACE(NNSBase):
|
|||
numbits_embedding_dim=8,
|
||||
hidden_feature_dim=64,
|
||||
partial_lookahead=True,
|
||||
norm_p=2):
|
||||
norm_p=2,
|
||||
softquant=False,
|
||||
sparsify=False,
|
||||
sparsification_schedule=[10000, 30000, 100],
|
||||
sparsification_density=0.5,
|
||||
apply_weight_norm=False):
|
||||
|
||||
super().__init__(skip=skip, preemph=preemph)
|
||||
|
||||
|
@ -85,18 +96,21 @@ class LACE(NNSBase):
|
|||
|
||||
# feature net
|
||||
if partial_lookahead:
|
||||
self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim)
|
||||
self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim, softquant=softquant, sparsify=sparsify, sparsification_density=sparsification_density, apply_weight_norm=apply_weight_norm)
|
||||
else:
|
||||
self.feature_net = SilkFeatureNet(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim)
|
||||
|
||||
# comb filters
|
||||
left_pad = self.kernel_size // 2
|
||||
right_pad = self.kernel_size - 1 - left_pad
|
||||
self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
|
||||
self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
|
||||
self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
|
||||
self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
|
||||
|
||||
# spectral shaping
|
||||
self.af1 = LimitedAdaptiveConv1d(1, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
|
||||
self.af1 = LimitedAdaptiveConv1d(1, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
|
||||
|
||||
if sparsify:
|
||||
self.sparsifier = create_sparsifier(self, *sparsification_schedule)
|
||||
|
||||
def flop_count(self, rate=16000, verbose=False):
|
||||
|
||||
|
|
|
@ -27,9 +27,13 @@
|
|||
*/
|
||||
"""
|
||||
|
||||
import numbers
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils import weight_norm
|
||||
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
@ -43,6 +47,11 @@ from models.silk_feature_net_pl import SilkFeatureNetPL
|
|||
from models.silk_feature_net import SilkFeatureNet
|
||||
from .scale_embedding import ScaleEmbedding
|
||||
|
||||
import sys
|
||||
sys.path.append('../dnntools')
|
||||
from dnntools.quantization import soft_quant
|
||||
from dnntools.sparsification import create_sparsifier, mark_for_sparsification
|
||||
|
||||
class NoLACE(NNSBase):
|
||||
""" Non-Linear Adaptive Coding Enhancer """
|
||||
FRAME_SIZE=80
|
||||
|
@ -64,11 +73,15 @@ class NoLACE(NNSBase):
|
|||
partial_lookahead=True,
|
||||
norm_p=2,
|
||||
avg_pool_k=4,
|
||||
pool_after=False):
|
||||
pool_after=False,
|
||||
softquant=False,
|
||||
sparsify=False,
|
||||
sparsification_schedule=[100, 1000, 100],
|
||||
sparsification_density=0.5,
|
||||
apply_weight_norm=False):
|
||||
|
||||
super().__init__(skip=skip, preemph=preemph)
|
||||
|
||||
|
||||
self.num_features = num_features
|
||||
self.cond_dim = cond_dim
|
||||
self.pitch_max = pitch_max
|
||||
|
@ -81,6 +94,11 @@ class NoLACE(NNSBase):
|
|||
self.hidden_feature_dim = hidden_feature_dim
|
||||
self.partial_lookahead = partial_lookahead
|
||||
|
||||
if isinstance(sparsification_density, numbers.Number):
|
||||
sparsification_density = 10 * [sparsification_density]
|
||||
|
||||
norm = weight_norm if apply_weight_norm else lambda x, name=None: x
|
||||
|
||||
# pitch embedding
|
||||
self.pitch_embedding = nn.Embedding(pitch_max + 1, pitch_embedding_dim)
|
||||
|
||||
|
@ -89,36 +107,52 @@ class NoLACE(NNSBase):
|
|||
|
||||
# feature net
|
||||
if partial_lookahead:
|
||||
self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim)
|
||||
self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim, softquant=softquant, sparsify=sparsify, sparsification_density=sparsification_density, apply_weight_norm=apply_weight_norm)
|
||||
else:
|
||||
self.feature_net = SilkFeatureNet(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim)
|
||||
|
||||
# comb filters
|
||||
left_pad = self.kernel_size // 2
|
||||
right_pad = self.kernel_size - 1 - left_pad
|
||||
self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
|
||||
self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
|
||||
self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
|
||||
self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
|
||||
|
||||
# spectral shaping
|
||||
self.af1 = LimitedAdaptiveConv1d(1, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
|
||||
self.af1 = LimitedAdaptiveConv1d(1, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
|
||||
|
||||
# non-linear transforms
|
||||
self.tdshape1 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after)
|
||||
self.tdshape2 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after)
|
||||
self.tdshape3 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after)
|
||||
self.tdshape1 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after, softquant=softquant, apply_weight_norm=apply_weight_norm)
|
||||
self.tdshape2 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after, softquant=softquant, apply_weight_norm=apply_weight_norm)
|
||||
self.tdshape3 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after, softquant=softquant, apply_weight_norm=apply_weight_norm)
|
||||
|
||||
# combinators
|
||||
self.af2 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
|
||||
self.af3 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
|
||||
self.af4 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
|
||||
self.af2 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
|
||||
self.af3 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
|
||||
self.af4 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
|
||||
|
||||
# feature transforms
|
||||
self.post_cf1 = nn.Conv1d(cond_dim, cond_dim, 2)
|
||||
self.post_cf2 = nn.Conv1d(cond_dim, cond_dim, 2)
|
||||
self.post_af1 = nn.Conv1d(cond_dim, cond_dim, 2)
|
||||
self.post_af2 = nn.Conv1d(cond_dim, cond_dim, 2)
|
||||
self.post_af3 = nn.Conv1d(cond_dim, cond_dim, 2)
|
||||
self.post_cf1 = norm(nn.Conv1d(cond_dim, cond_dim, 2))
|
||||
self.post_cf2 = norm(nn.Conv1d(cond_dim, cond_dim, 2))
|
||||
self.post_af1 = norm(nn.Conv1d(cond_dim, cond_dim, 2))
|
||||
self.post_af2 = norm(nn.Conv1d(cond_dim, cond_dim, 2))
|
||||
self.post_af3 = norm(nn.Conv1d(cond_dim, cond_dim, 2))
|
||||
|
||||
if softquant:
|
||||
self.post_cf1 = soft_quant(self.post_cf1)
|
||||
self.post_cf2 = soft_quant(self.post_cf2)
|
||||
self.post_af1 = soft_quant(self.post_af1)
|
||||
self.post_af2 = soft_quant(self.post_af2)
|
||||
self.post_af3 = soft_quant(self.post_af3)
|
||||
|
||||
|
||||
if sparsify:
|
||||
mark_for_sparsification(self.post_cf1, (sparsification_density[4], [8, 4]))
|
||||
mark_for_sparsification(self.post_cf2, (sparsification_density[5], [8, 4]))
|
||||
mark_for_sparsification(self.post_af1, (sparsification_density[6], [8, 4]))
|
||||
mark_for_sparsification(self.post_af2, (sparsification_density[7], [8, 4]))
|
||||
mark_for_sparsification(self.post_af3, (sparsification_density[8], [8, 4]))
|
||||
|
||||
self.sparsifier = create_sparsifier(self, *sparsification_schedule)
|
||||
|
||||
def flop_count(self, rate=16000, verbose=False):
|
||||
|
||||
|
@ -141,9 +175,12 @@ class NoLACE(NNSBase):
|
|||
return feature_net_flops + comb_flops + af_flops + feature_flops + shape_flops
|
||||
|
||||
def feature_transform(self, f, layer):
|
||||
f = f.permute(0, 2, 1)
|
||||
f = F.pad(f, [1, 0])
|
||||
f = torch.tanh(layer(f))
|
||||
f0 = f.permute(0, 2, 1)
|
||||
f = F.pad(f0, [1, 0])
|
||||
if self.residual_in_feature_transform:
|
||||
f = torch.tanh(layer(f) + f0)
|
||||
else:
|
||||
f = torch.tanh(layer(f))
|
||||
return f.permute(0, 2, 1)
|
||||
|
||||
def forward(self, x, features, periods, numbits, debug=False):
|
||||
|
|
|
@ -26,36 +26,74 @@
|
|||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
import sys
|
||||
sys.path.append('../dnntools')
|
||||
import numbers
|
||||
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils import weight_norm
|
||||
|
||||
from utils.complexity import _conv1d_flop_count
|
||||
|
||||
from dnntools.quantization.softquant import soft_quant
|
||||
from dnntools.sparsification import mark_for_sparsification
|
||||
|
||||
class SilkFeatureNetPL(nn.Module):
|
||||
""" feature net with partial lookahead """
|
||||
def __init__(self,
|
||||
feature_dim=47,
|
||||
num_channels=256,
|
||||
hidden_feature_dim=64):
|
||||
hidden_feature_dim=64,
|
||||
softquant=False,
|
||||
sparsify=True,
|
||||
sparsification_density=0.5,
|
||||
apply_weight_norm=False):
|
||||
|
||||
super(SilkFeatureNetPL, self).__init__()
|
||||
|
||||
if isinstance(sparsification_density, numbers.Number):
|
||||
sparsification_density = 4 * [sparsification_density]
|
||||
|
||||
self.feature_dim = feature_dim
|
||||
self.num_channels = num_channels
|
||||
self.hidden_feature_dim = hidden_feature_dim
|
||||
|
||||
self.conv1 = nn.Conv1d(feature_dim, self.hidden_feature_dim, 1)
|
||||
self.conv2 = nn.Conv1d(4 * self.hidden_feature_dim, num_channels, 2)
|
||||
self.tconv = nn.ConvTranspose1d(num_channels, num_channels, 4, 4)
|
||||
norm = weight_norm if apply_weight_norm else lambda x, name=None: x
|
||||
|
||||
self.conv1 = norm(nn.Conv1d(feature_dim, self.hidden_feature_dim, 1))
|
||||
self.conv2 = norm(nn.Conv1d(4 * self.hidden_feature_dim, num_channels, 2))
|
||||
self.tconv = norm(nn.ConvTranspose1d(num_channels, num_channels, 4, 4))
|
||||
gru_input_dim = num_channels + self.repeat_upsamp_dim if self.repeat_upsamp else num_channels
|
||||
self.gru = norm(norm(nn.GRU(gru_input_dim, num_channels, batch_first=True), name='weight_hh_l0'), name='weight_ih_l0')
|
||||
|
||||
if softquant:
|
||||
self.conv2 = soft_quant(self.conv2)
|
||||
if not self.repeat_upsamp: self.tconv = soft_quant(self.tconv)
|
||||
self.gru = soft_quant(self.gru, names=['weight_hh_l0', 'weight_ih_l0'])
|
||||
|
||||
|
||||
if sparsify:
|
||||
mark_for_sparsification(self.conv2, (sparsification_density[0], [8, 4]))
|
||||
if not self.repeat_upsamp: mark_for_sparsification(self.tconv, (sparsification_density[1], [8, 4]))
|
||||
mark_for_sparsification(
|
||||
self.gru,
|
||||
{
|
||||
'W_ir' : (sparsification_density[2], [8, 4], False),
|
||||
'W_iz' : (sparsification_density[2], [8, 4], False),
|
||||
'W_in' : (sparsification_density[2], [8, 4], False),
|
||||
'W_hr' : (sparsification_density[3], [8, 4], True),
|
||||
'W_hz' : (sparsification_density[3], [8, 4], True),
|
||||
'W_hn' : (sparsification_density[3], [8, 4], True),
|
||||
}
|
||||
)
|
||||
|
||||
self.gru = nn.GRU(num_channels, num_channels, batch_first=True)
|
||||
|
||||
def flop_count(self, rate=200):
|
||||
count = 0
|
||||
for conv in self.conv1, self.conv2, self.tconv:
|
||||
for conv in [self.conv1, self.conv2] if self.repeat_upsamp else [self.conv1, self.conv2, self.tconv]:
|
||||
count += _conv1d_flop_count(conv, rate)
|
||||
|
||||
count += 2 * (3 * self.gru.input_size * self.gru.hidden_size + 3 * self.gru.hidden_size * self.gru.hidden_size) * rate
|
||||
|
@ -82,7 +120,7 @@ class SilkFeatureNetPL(nn.Module):
|
|||
c = torch.tanh(self.conv2(F.pad(c, [1, 0])))
|
||||
|
||||
# upsampling
|
||||
c = self.tconv(c)
|
||||
c = torch.tanh(self.tconv(c))
|
||||
c = c.permute(0, 2, 1)
|
||||
|
||||
c, _ = self.gru(c, state)
|
||||
|
|
25
dnn/torch/osce/stndrd/evaluation/create_input_data.sh
Normal file
25
dnn/torch/osce/stndrd/evaluation/create_input_data.sh
Normal file
|
@ -0,0 +1,25 @@
|
|||
#!/bin/bash
# For every .wav file below $INPUT: convert it with sox to 16 kHz, 16-bit
# signed raw audio and, for each bitrate in BITRATES, create a folder
# ${OUTPUT}/<name>_<bitrate>.se containing the converted signal (clean.s16)
# and the output of $OPUSDEMO run on it (noisy.s16).

INPUT="dataset/LibriSpeech"
OUTPUT="testdata"
OPUSDEMO="/local/experiments/ietf_enhancement_studies/bin/opus_demo_patched"
BITRATES=( 6000 7500 ) # 9000 12000 15000 18000 24000 32000 )

mkdir -p "$OUTPUT"

# NUL-delimited file names survive whitespace in paths. The list is read from
# fd 3 so that sox / opus_demo cannot consume it through their stdin.
while IFS= read -r -d '' -u 3 fn
do
    name=$(basename "${fn%.wav}")
    if ! sox "$fn" -r 16000 -b 16 -e signed-integer "${OUTPUT}/tmp.raw"
    then
        # without a valid tmp.raw there is nothing to encode for this item
        echo "warning: sox failed for ${fn}, skipping" >&2
        rm -f "${OUTPUT}/tmp.raw"
        continue
    fi
    for br in "${BITRATES[@]}"
    do
        folder="${OUTPUT}/${name}_${br}.se"
        echo "creating ${folder}..."
        mkdir -p "$folder"
        cp "${OUTPUT}/tmp.raw" "${folder}/clean.s16"
        (cd "${folder}" && "$OPUSDEMO" voip 16000 1 "$br" clean.s16 noisy.s16)
    done
    rm -f "${OUTPUT}/tmp.raw"
done 3< <(find "$INPUT" -name "*.wav" -print0)
|
7
dnn/torch/osce/stndrd/evaluation/env.rc
Normal file
7
dnn/torch/osce/stndrd/evaluation/env.rc
Normal file
|
@ -0,0 +1,7 @@
|
|||
#!/bin/bash
|
||||
|
||||
export PYTHON=/home/ubuntu/opt/miniconda3/envs/torch/bin/python
|
||||
export LACE="/local/experiments/ietf_enhancement_studies/checkpoints/lace_checkpoint.pth"
|
||||
export NOLACE="/local/experiments/ietf_enhancement_studies/checkpoints/nolace_checkpoint.pth"
|
||||
export TESTMODEL="/local/experiments/ietf_enhancement_studies/opus/dnn/torch/osce/test_model.py"
|
||||
export OPUSDEMO="/local/experiments/ietf_enhancement_studies/bin/opus_demo_patched"
|
113
dnn/torch/osce/stndrd/evaluation/evaluate.py
Normal file
113
dnn/torch/osce/stndrd/evaluation/evaluate.py
Normal file
|
@ -0,0 +1,113 @@
|
|||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
|
||||
|
||||
from scipy.io import wavfile
|
||||
from pesq import pesq
|
||||
import numpy as np
|
||||
from moc import compare
|
||||
from moc2 import compare as compare2
|
||||
#from warpq import compute_WAPRQ as warpq
|
||||
from lace_loss_metric import compare as laceloss_compare
|
||||
|
||||
|
||||
# Command line interface: <folder> <metric>
parser = argparse.ArgumentParser()
parser.add_argument(
    'folder',
    type=str,
    help='folder with processed items',
)
parser.add_argument(
    'metric',
    type=str,
    choices=['pesq', 'moc', 'moc2', 'laceloss'],
    help='metric to be used for evaluation',
)
|
||||
|
||||
|
||||
def get_bitrates(folder):
    """Read the evaluated bitrates from `<folder>/bitrates.txt`.

    The file is expected to hold whitespace-separated integers.

    Returns:
        list of int, in file order.
    """
    listing = os.path.join(folder, 'bitrates.txt')
    with open(listing) as f:
        tokens = f.read().split()

    return list(map(int, tokens))
|
||||
|
||||
def get_itemlist(folder):
    """Read the item names from `<folder>/items.txt`.

    Each non-empty line contributes its first whitespace-separated token;
    anything after it on the line is ignored. Blank lines (e.g. a trailing
    empty line) are skipped instead of raising an IndexError.

    Returns:
        list of str, in file order.
    """
    with open(os.path.join(folder, 'items.txt')) as f:
        rows = (line.split() for line in f)
        items = [fields[0] for fields in rows if fields]

    return items
|
||||
|
||||
|
||||
def process_item(folder, item, bitrate, metric):
    """Score the opus, lace and nolace versions of one item against the clean signal.

    Reads `<folder>/<system>/<item>_<bitrate>_<system>.wav` for the systems
    clean, opus, lace and nolace, scales the samples by 1 / 2**15 (assumes
    16-bit PCM input -- TODO confirm) and applies the selected metric.

    Args:
        folder: root folder of the processed items.
        item: item name.
        bitrate: bitrate used in the file names.
        metric: one of 'pesq', 'moc', 'moc2', 'laceloss'.

    Returns:
        list with three scores in the order [opus, lace, nolace].

    Raises:
        ValueError: for an unknown metric (raised after the files were read).
    """
    signals = {}
    fs = None
    for system in ('clean', 'opus', 'lace', 'nolace'):
        wav_path = os.path.join(folder, system, f"{item}_{bitrate}_{system}.wav")
        # fs ends up being the sampling rate of the last file read (nolace)
        fs, samples = wavfile.read(wav_path)
        signals[system] = samples.astype(np.float32) / 2**15

    reference = signals['clean']
    candidates = [signals['opus'], signals['lace'], signals['nolace']]

    if metric == 'pesq':
        return [pesq(fs, reference, y) for y in candidates]
    if metric == 'moc':
        return [compare(reference, y) for y in candidates]
    if metric == 'moc2':
        return [compare2(reference, y) for y in candidates]
    if metric == 'laceloss':
        return [laceloss_compare(reference, y) for y in candidates]

    raise ValueError(f'unknown metric {metric}')
|
||||
|
||||
def process_bitrate(folder, items, bitrate, metric):
    """Evaluate all items at one bitrate.

    Returns:
        numpy array of shape (len(items), 3); row i holds the scores returned
        by `process_item` for items[i].
    """
    num_items = len(items)
    results = np.zeros((num_items, 3))

    for row in range(num_items):
        scores = process_item(folder, items[row], bitrate, metric)
        results[row, :] = np.array(scores)

    return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
    args = parser.parse_args()
    root = args.folder

    items = get_itemlist(root)
    bitrates = get_bitrates(root)

    # maps bitrate -> (num_items, 3) score array
    results = dict()
    for br in bitrates:
        print(f"processing bitrate {br}...")
        results[br] = process_bitrate(root, items, br, args.metric)

    # A dict is stored as a pickled object array, so reading it back requires
    # np.load(..., allow_pickle=True).
    output_path = os.path.join(root, f'results_{args.metric}.npy')
    np.save(output_path, results)

    print("Done.")
|
330
dnn/torch/osce/stndrd/evaluation/lace_loss_metric.py
Normal file
330
dnn/torch/osce/stndrd/evaluation/lace_loss_metric.py
Normal file
|
@ -0,0 +1,330 @@
|
|||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
"""STFT-based Loss modules."""
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
import numpy as np
|
||||
import torchaudio
|
||||
|
||||
|
||||
def get_window(win_name, win_length, *args, **kwargs):
    """Create a window tensor from the name of a torch window function.

    Args:
        win_name (str): one of 'bartlett_window', 'blackman_window',
            'hamming_window', 'hann_window' or 'kaiser_window'.
        win_length (int): number of samples in the window.
        *args, **kwargs: forwarded unchanged to the torch window function.

    Returns:
        Tensor: the window returned by the selected torch function.

    Raises:
        ValueError: if ``win_name`` is not a supported window name.
    """
    window_dict = {
        'bartlett_window' : torch.bartlett_window,
        'blackman_window' : torch.blackman_window,
        'hamming_window' : torch.hamming_window,
        'hann_window' : torch.hann_window,
        'kaiser_window' : torch.kaiser_window
    }

    if win_name not in window_dict:
        # name the offending value so that configuration typos are easy to spot
        raise ValueError(
            f"unknown window type '{win_name}', expected one of {sorted(window_dict)}"
        )

    return window_dict[win_name](win_length, *args, **kwargs)
|
||||
|
||||
|
||||
def stft(x, fft_size, hop_size, win_length, window):
    """Perform STFT and convert to magnitude spectrogram.

    Args:
        x (Tensor): Input signal tensor (B, T).
        fft_size (int): FFT size.
        hop_size (int): Hop size.
        win_length (int): Window length.
        window (str): Window function type (a name accepted by get_window).

    Returns:
        Tensor: Magnitude spectrogram (B, fft_size // 2 + 1, #frames),
            clamped from below at 1e-7. The output of torch.stft is not
            transposed, so frequency bins come before frames.
    """
    # the window is created on the default device and moved next to the signal
    win = get_window(window, win_length).to(x.device)
    x_stft = torch.stft(x, fft_size, hop_size, win_length, win, return_complex=True)

    # clamp so that downstream logarithms and divisions never see exact zeros
    return torch.clamp(torch.abs(x_stft), min=1e-7)
|
||||
|
||||
def spectral_convergence_loss(Y_true, Y_pred):
    """Batch-averaged norm of the magnitude error relative to the norm of Y_pred.

    Norms are taken over every dimension except the first one.
    """
    reduce_dims = [d for d in range(1, Y_pred.dim())]

    magnitude_error = torch.abs(Y_true) - torch.abs(Y_pred)
    error_norm = torch.norm(magnitude_error, p="fro", dim=reduce_dims)
    pred_norm = torch.norm(Y_pred, p="fro", dim=reduce_dims)

    # 1e-6 guards against division by zero for an all-zero prediction
    return torch.mean(error_norm / (pred_norm + 1e-6))
|
||||
|
||||
|
||||
def log_magnitude_loss(Y_true, Y_pred):
    """Mean absolute difference between the log magnitudes of two spectrograms."""

    def _log_abs(Y):
        # the small offset keeps the logarithm finite for zero-valued bins
        return torch.log(torch.abs(Y) + 1e-15)

    log_difference = _log_abs(Y_true) - _log_abs(Y_pred)

    return torch.mean(torch.abs(log_difference))
|
||||
|
||||
def spectral_xcorr_loss(Y_true, Y_pred):
    """One minus the batch-averaged normalized cross-correlation of magnitudes.

    Sums run over every dimension except the first one.
    """
    mag_true = Y_true.abs()
    mag_pred = Y_pred.abs()
    reduce_dims = list(range(1, len(mag_pred.shape)))

    inner = torch.sum(mag_true * mag_pred, dim=reduce_dims)
    energy_true = torch.sum(mag_true ** 2, dim=reduce_dims)
    energy_pred = torch.sum(mag_pred ** 2, dim=reduce_dims)

    # 1e-9 under the square root avoids a division by zero for silent items
    xcorr = inner / torch.sqrt(energy_true * energy_pred + 1e-9)

    return 1 - xcorr.mean()
|
||||
|
||||
|
||||
|
||||
class MRLogMelLoss(nn.Module):
    """Multi-resolution log-mel loss.

    One torchaudio MelSpectrogram transform is created per FFT size; the loss
    is the log_magnitude_loss between the mel spectrograms of both signals,
    averaged over all resolutions.

    Args:
        fft_sizes (list of int): FFT sizes, one resolution per entry.
        overlap (float): frame overlap as a fraction of the FFT size.
        fs (int): sampling rate passed to MelSpectrogram.
        n_mels (int): number of mel bands; halved for FFT sizes below 128.
    """

    def __init__(self,
                 fft_sizes=[512, 256, 128, 64],
                 overlap=0.5,
                 fs=16000,
                 n_mels=18
                 ):

        # nn.Module must be initialized before any attribute is assigned
        super().__init__()

        self.fft_sizes = fft_sizes
        self.overlap = overlap
        self.fs = fs
        self.n_mels = n_mels

        self.mel_specs = []
        for fft_size in fft_sizes:
            hop_size = int(round(fft_size * (1 - self.overlap)))

            n_mels = self.n_mels
            if fft_size < 128:
                n_mels //= 2

            self.mel_specs.append(torchaudio.transforms.MelSpectrogram(fs, fft_size, hop_length=hop_size, n_mels=n_mels))

        # mel_specs is a plain list, so each transform is registered explicitly
        # under the name mel_spec_<index starting at 1>
        for i, mel_spec in enumerate(self.mel_specs):
            self.add_module(f'mel_spec_{i+1}', mel_spec)

    def forward(self, y_true, y_pred):
        """Return the log-mel loss averaged over all resolutions (tensor of shape (1,))."""

        loss = torch.zeros(1, device=y_true.device)

        for mel_spec in self.mel_specs:
            Y_true = mel_spec(y_true)
            Y_pred = mel_spec(y_pred)
            loss = loss + log_magnitude_loss(Y_true, Y_pred)

        loss = loss / len(self.mel_specs)

        return loss
|
||||
|
||||
def create_weight_matrix(num_bins, bins_per_band=10):
    """Build a (num_bins, num_bins) band-averaging matrix.

    Row i carries weight over a band of bins_per_band bins around bin i. Where
    the band would extend past either edge, the overhanging count is added
    again at that edge. The matrix is finally divided by bins_per_band.
    """
    weights = torch.zeros((num_bins, num_bins), dtype=torch.float32)

    below = bins_per_band // 2
    above = bins_per_band - below

    for row in range(num_bins):
        lo = max(row - below, 0)
        hi = min(row + above, num_bins)
        weights[row, lo: hi] += 1

        # part of the band that fell below bin 0
        if row < below:
            weights[row, :below - row] += 1

        # part of the band that fell beyond the last bin (negative start index)
        if row > num_bins - above:
            weights[row, num_bins - above - row:] += 1

    return weights / bins_per_band
|
||||
|
||||
def weighted_spectral_convergence(Y_true, Y_pred, w):
    """Spectral convergence with weights derived from the reference spectrum.

    Y_true is smoothed with w in the log and the linear domain; the ratio
    exp(smoothed log) / smoothed magnitude is used as a flatness measure and
    sqrt(relu(1 - flatness)) as the weight. Dimensions 1 and 2 of the inputs
    are swapped before multiplying with w, so the inputs must be 3-D.
    """
    true_mag = torch.abs(Y_true)
    log_mag = torch.log(true_mag + 1e-9)

    # calculate sfm based weights
    smoothed_log = torch.matmul(log_mag.transpose(1, 2), w)
    smoothed_mag = torch.matmul(true_mag.transpose(1, 2), w)
    flatness = torch.exp(smoothed_log) / (smoothed_mag + 1e-9)
    weight = (torch.relu(1 - flatness) ** .5).transpose(1, 2)

    weighted_error = torch.mean(weight * torch.abs(true_mag - torch.abs(Y_pred)), dim=[1, 2])
    weighted_reference = torch.mean(weight * true_mag, dim=[1, 2])

    return torch.mean(weighted_error / (weighted_reference + 1e-9))
|
||||
|
||||
def gen_filterbank(N, Fs=16000):
|
||||
in_freq = (np.arange(N+1, dtype='float32')/N*Fs/2)[None,:]
|
||||
out_freq = (np.arange(N, dtype='float32')/N*Fs/2)[:,None]
|
||||
#ERB from B.C.J Moore, An Introduction to the Psychology of Hearing, 5th Ed., page 73.
|
||||
ERB_N = 24.7 + .108*in_freq
|
||||
delta = np.abs(in_freq-out_freq)/ERB_N
|
||||
center = (delta<.5).astype('float32')
|
||||
R = -12*center*delta**2 + (1-center)*(3-12*delta)
|
||||
RE = 10.**(R/10.)
|
||||
norm = np.sum(RE, axis=1)
|
||||
RE = RE/norm[:, np.newaxis]
|
||||
return torch.from_numpy(RE)
|
||||
|
||||
def smooth_log_mag(Y_true, Y_pred, filterbank):
    """Mean absolute log difference of filterbank-smoothed magnitudes."""

    def _smoothed_log(Y):
        smoothed = torch.matmul(filterbank, torch.abs(Y))
        # the offset keeps the logarithm finite for empty bands
        return torch.log(smoothed + 1e-9)

    difference = torch.abs(_smoothed_log(Y_true) - _smoothed_log(Y_pred))

    return difference.mean()
|
||||
|
||||
class MRSTFTLoss(nn.Module):
    """Multi-resolution STFT loss combining up to five spectral loss terms.

    For every FFT size the magnitude spectrograms of both signals are computed
    and each loss term with a positive weight is accumulated. The weighted sum
    of the accumulated terms is divided by the number of resolutions.

    Args:
        fft_sizes (list of int): FFT sizes, one resolution per entry.
        overlap (float): frame overlap as a fraction of the FFT size.
        window (str): window name understood by get_window.
        fs (int): sampling rate, used to size the SFM smoothing bands.
        log_mag_weight: weight of log_magnitude_loss.
        sc_weight: weight of spectral_convergence_loss.
        wsc_weight: weight of weighted_spectral_convergence.
        smooth_log_mag_weight: weight of smooth_log_mag.
        sxcorr_weight: weight of spectral_xcorr_loss.
    """

    def __init__(self,
                 fft_sizes=[2048, 1024, 512, 256, 128, 64],
                 overlap=0.5,
                 window='hann_window',
                 fs=16000,
                 log_mag_weight=0,
                 sc_weight=0,
                 wsc_weight=0,
                 smooth_log_mag_weight=2,
                 sxcorr_weight=1):
        super().__init__()

        self.fft_sizes = fft_sizes
        self.overlap = overlap
        self.window = window
        self.log_mag_weight = log_mag_weight
        self.sc_weight = sc_weight
        self.wsc_weight = wsc_weight
        self.smooth_log_mag_weight = smooth_log_mag_weight
        self.sxcorr_weight = sxcorr_weight
        self.fs = fs

        # weights for SFM weighted spectral convergence loss
        self.wsc_weights = torch.nn.ParameterDict()
        for fft_size in fft_sizes:
            # band width in bins: 1000 * fft_size / fs rounded, capped at 11,
            # then bumped to the next even number
            width = min(11, int(1000 * fft_size / self.fs + .5))
            width += width % 2
            self.wsc_weights[str(fft_size)] = torch.nn.Parameter(
                create_weight_matrix(fft_size // 2 + 1, width),
                requires_grad=False
            )

        # filterbanks for smooth log magnitude loss
        self.filterbanks = torch.nn.ParameterDict()
        for fft_size in fft_sizes:
            self.filterbanks[str(fft_size)] = torch.nn.Parameter(
                gen_filterbank(fft_size//2),
                requires_grad=False
            )

    # Implemented as forward (not __call__) so that nn.Module.__call__ keeps
    # running hooks; calling the module instance behaves as before.
    def forward(self, y_true, y_pred):
        """Return the combined loss for a reference and a predicted signal (tensor of shape (1,))."""

        lm_loss = torch.zeros(1, device=y_true.device)
        sc_loss = torch.zeros(1, device=y_true.device)
        wsc_loss = torch.zeros(1, device=y_true.device)
        slm_loss = torch.zeros(1, device=y_true.device)
        sxcorr_loss = torch.zeros(1, device=y_true.device)

        for fft_size in self.fft_sizes:
            hop_size = int(round(fft_size * (1 - self.overlap)))
            win_size = fft_size

            Y_true = stft(y_true, fft_size, hop_size, win_size, self.window)
            Y_pred = stft(y_pred, fft_size, hop_size, win_size, self.window)

            # terms with a non-positive weight are skipped entirely
            if self.log_mag_weight > 0:
                lm_loss = lm_loss + log_magnitude_loss(Y_true, Y_pred)

            if self.sc_weight > 0:
                sc_loss = sc_loss + spectral_convergence_loss(Y_true, Y_pred)

            if self.wsc_weight > 0:
                wsc_loss = wsc_loss + weighted_spectral_convergence(Y_true, Y_pred, self.wsc_weights[str(fft_size)])

            if self.smooth_log_mag_weight > 0:
                slm_loss = slm_loss + smooth_log_mag(Y_true, Y_pred, self.filterbanks[str(fft_size)])

            if self.sxcorr_weight > 0:
                sxcorr_loss = sxcorr_loss + spectral_xcorr_loss(Y_true, Y_pred)

        total_loss = (self.log_mag_weight * lm_loss + self.sc_weight * sc_loss
                      + self.wsc_weight * wsc_loss + self.smooth_log_mag_weight * slm_loss
                      + self.sxcorr_weight * sxcorr_loss) / len(self.fft_sizes)

        return total_loss
|
||||
|
||||
|
||||
def td_l2_norm(y_true, y_pred):
    """Mean squared error divided by the RMS of y_pred, averaged over the batch.

    Means are taken over every dimension except the first one.
    """
    reduce_dims = list(range(1, len(y_true.shape)))

    squared_error = torch.mean((y_true - y_pred) ** 2, dim=reduce_dims)
    pred_rms = torch.mean(y_pred ** 2, dim=reduce_dims) ** .5

    # 1e-6 avoids a division by zero for an all-zero prediction
    per_item = squared_error / (pred_rms + 1e-6)

    return per_item.mean()
|
||||
|
||||
|
||||
class LaceLoss(nn.Module):
    """Combination of a multi-resolution STFT loss and a time-domain loss."""

    def __init__(self):
        super().__init__()

        self.stftloss = MRSTFTLoss(log_mag_weight=0, sc_weight=0, wsc_weight=0, smooth_log_mag_weight=2, sxcorr_weight=1)

    def forward(self, x, y):
        """Return (spectral loss + 10 * time-domain loss) / 13."""
        spectral_term = self.stftloss(x, y)
        time_domain_term = td_l2_norm(x, y)

        return (spectral_term + 10 * time_domain_term) / 13

    def compare(self, x_ref, x_deg):
        """Score a degraded signal against a reference; both are numpy arrays.

        Returns 10 times the forward loss as a Python scalar.
        """
        # trim items to same size (copies, so the caller's arrays stay untouched)
        common_len = min(len(x_ref), len(x_deg))
        ref = x_ref[:common_len].copy()
        deg = x_deg[:common_len].copy()

        # pre-emphasis
        ref[1:] -= 0.85 * ref[:-1]
        deg[1:] -= 0.85 * deg[:-1]

        # run on whatever device the module's parameters live on
        target_device = next(iter(self.parameters())).device
        ref_tensor = torch.from_numpy(ref).to(target_device)
        deg_tensor = torch.from_numpy(deg).to(target_device)

        with torch.no_grad():
            dist = 10 * self.forward(ref_tensor, deg_tensor)

        return dist.cpu().numpy().item()
|
||||
|
||||
|
||||
# Module-level metric instance, created at import time and placed on the GPU
# when one is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
lace_loss = LaceLoss()
lace_loss.to(device)
|
||||
|
||||
def compare(x, y):
    """Score signal y against reference x with the module-level LaceLoss instance."""
    score = lace_loss.compare(x, y)
    return score
|
116
dnn/torch/osce/stndrd/evaluation/make_boxplots.py
Normal file
116
dnn/torch/osce/stndrd/evaluation/make_boxplots.py
Normal file
|
@ -0,0 +1,116 @@
|
|||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from prettytable import PrettyTable
|
||||
from matplotlib.patches import Patch
|
||||
|
||||
# Command line interface. 'moc2' is accepted because load_data also picks up
# results_moc2.npy.
parser = argparse.ArgumentParser()
parser.add_argument('folder', type=str, help='path to folder with pre-calculated metrics')
parser.add_argument('--metric', choices=['pesq', 'moc', 'moc2', 'warpq', 'nomad', 'laceloss', 'all'], default='all', help='default: all')
parser.add_argument('--output', type=str, default=None, help='alternative output folder, default: folder')
|
||||
|
||||
def load_data(folder):
    """Load every available results_<metric>.npy file from folder.

    Returns a dict keyed by metric name; metrics without a file are omitted.
    """
    data = dict()

    for metric in ('moc', 'moc2', 'pesq', 'warpq', 'nomad', 'laceloss'):
        path = os.path.join(folder, f'results_{metric}.npy')
        if os.path.isfile(path):
            # NOTE(review): allow_pickle=True unpickles the file content; only
            # load result files from trusted sources.
            data[metric] = np.load(path, allow_pickle=True).item()

    return data
|
||||
|
||||
def plot_data(filename, data, title=None):
    """Draw notched box plots of one metric and save them to filename.

    Args:
        filename: output path handed to Figure.savefig.
        data: dict keyed by bitrate; each value is a 2-D array whose columns
            0, 1 and 2 are plotted as Opus, LACE and NoLACE.
        title: optional plot title.
    """
    compare_dict = dict()
    for br in data.keys():
        compare_dict[f'Opus {br/1000:.1f} kb/s'] = data[br][:, 0]
        compare_dict[f'LACE {br/1000:.1f} kb/s'] = data[br][:, 1]
        compare_dict[f'NoLACE {br/1000:.1f} kb/s'] = data[br][:, 2]

    plt.rcParams.update({
        "text.usetex": True,
        "font.family": "Helvetica",
        "font.size": 32
    })

    black = '#000000'
    red = '#ff5745'
    blue = '#007dbc'
    colors = [black, red, blue]
    legend_elements = [Patch(facecolor=colors[0], label='Opus SILK'),
                       Patch(facecolor=colors[1], label='LACE'),
                       Patch(facecolor=colors[2], label='NoLACE')]

    fig, ax = plt.subplots()
    try:
        fig.set_size_inches(40, 20)
        bplot = ax.boxplot(list(compare_dict.values()), showfliers=False, notch=True, patch_artist=True)

        # boxes come in (Opus, LACE, NoLACE) triples, one triple per bitrate
        for i, patch in enumerate(bplot['boxes']):
            patch.set_facecolor(colors[i%3])

        ax.set_xticklabels(compare_dict.keys(), rotation=290)

        if title is not None:
            ax.set_title(title)

        ax.legend(handles=legend_elements)

        fig.savefig(filename, bbox_inches='tight')
    finally:
        # pyplot keeps figures alive until closed; release this one so that
        # plotting several metrics in a row does not accumulate figures
        plt.close(fig)
|
||||
|
||||
if __name__ == "__main__":
    args = parser.parse_args()
    data = load_data(args.folder)

    metrics = list(data.keys()) if args.metric == 'all' else [args.metric]

    # fail with a readable message instead of a KeyError when the requested
    # metric has no results file in the folder
    missing = [metric for metric in metrics if metric not in data]
    if missing:
        parser.error(f"no pre-calculated results found in {args.folder} for: {', '.join(missing)}")

    folder = args.folder if args.output is None else args.output
    os.makedirs(folder, exist_ok=True)

    for metric in metrics:
        print(f"Plotting data for {metric} metric...")
        plot_data(os.path.join(folder, f"boxplot_{metric}.png"), data[metric], title=metric.upper())

    print("Done.")
|
109
dnn/torch/osce/stndrd/evaluation/make_boxplots_moctest.py
Normal file
109
dnn/torch/osce/stndrd/evaluation/make_boxplots_moctest.py
Normal file
|
@ -0,0 +1,109 @@
|
|||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from prettytable import PrettyTable
|
||||
from matplotlib.patches import Patch
|
||||
|
||||
# Command line interface: input folder, metric selection and optional output folder.
parser = argparse.ArgumentParser()
parser.add_argument(
    'folder',
    type=str,
    help='path to folder with pre-calculated metrics',
)
parser.add_argument(
    '--metric',
    choices=['pesq', 'moc', 'warpq', 'nomad', 'laceloss', 'all'],
    default='all',
    help='default: all',
)
parser.add_argument(
    '--output',
    type=str,
    default=None,
    help='alternative output folder, default: folder',
)
|
||||
|
||||
def load_data(folder):
    """Collect the results_<metric>.npy files present in folder.

    Returns a dict keyed by metric name; metrics without a file are left out.
    """
    metric_names = ['moc', 'pesq', 'warpq', 'nomad', 'laceloss']
    data = dict()

    for name in metric_names:
        candidate = os.path.join(folder, f'results_{name}.npy')
        if not os.path.isfile(candidate):
            continue
        # NOTE(review): allow_pickle=True unpickles the file content; only
        # load result files from trusted sources.
        data[name] = np.load(candidate, allow_pickle=True).item()

    return data
|
||||
|
||||
def plot_data(filename, data, title=None):
    """Draw notched box plots of one metric and save them to filename.

    Args:
        filename: output path handed to Figure.savefig.
        data: dict keyed by bitrate; each value is a 2-D array whose columns
            0, 1 and 2 are plotted as Opus, LACE (MOC only) and LACE (MOC + TD).
        title: optional plot title.
    """
    compare_dict = dict()
    for br in data.keys():
        compare_dict[f'Opus {br/1000:.1f} kb/s'] = data[br][:, 0]
        compare_dict[f'LACE (MOC only) {br/1000:.1f} kb/s'] = data[br][:, 1]
        compare_dict[f'LACE (MOC + TD) {br/1000:.1f} kb/s'] = data[br][:, 2]

    plt.rcParams.update({
        "text.usetex": True,
        "font.family": "Helvetica",
        "font.size": 32
    })
    colors = ['pink', 'lightblue', 'lightgreen']
    legend_elements = [Patch(facecolor=colors[0], label='Opus SILK'),
                       Patch(facecolor=colors[1], label='MOC loss only'),
                       Patch(facecolor=colors[2], label='MOC + TD loss')]

    fig, ax = plt.subplots()
    try:
        fig.set_size_inches(40, 20)
        bplot = ax.boxplot(list(compare_dict.values()), showfliers=False, notch=True, patch_artist=True)

        # boxes come in triples, one triple per bitrate
        for i, patch in enumerate(bplot['boxes']):
            patch.set_facecolor(colors[i%3])

        ax.set_xticklabels(compare_dict.keys(), rotation=290)

        if title is not None:
            ax.set_title(title)

        ax.legend(handles=legend_elements)

        fig.savefig(filename, bbox_inches='tight')
    finally:
        # pyplot keeps figures alive until closed; release this one so that
        # plotting several metrics in a row does not accumulate figures
        plt.close(fig)
|
||||
|
||||
if __name__ == "__main__":
    args = parser.parse_args()
    data = load_data(args.folder)

    metrics = list(data.keys()) if args.metric == 'all' else [args.metric]

    # fail with a readable message instead of a KeyError when the requested
    # metric has no results file in the folder
    missing = [metric for metric in metrics if metric not in data]
    if missing:
        parser.error(f"no pre-calculated results found in {args.folder} for: {', '.join(missing)}")

    folder = args.folder if args.output is None else args.output
    os.makedirs(folder, exist_ok=True)

    for metric in metrics:
        print(f"Plotting data for {metric} metric...")
        plot_data(os.path.join(folder, f"boxplot_{metric}.png"), data[metric], title=metric.upper())

    print("Done.")
|
124
dnn/torch/osce/stndrd/evaluation/make_tables.py
Normal file
124
dnn/torch/osce/stndrd/evaluation/make_tables.py
Normal file
|
@ -0,0 +1,124 @@
|
|||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from prettytable import PrettyTable
|
||||
from matplotlib.patches import Patch
|
||||
|
||||
# Command line interface. 'moc2' is accepted because load_data also picks up
# results_moc2.npy.
parser = argparse.ArgumentParser()
parser.add_argument('folder', type=str, help='path to folder with pre-calculated metrics')
parser.add_argument('--metric', choices=['pesq', 'moc', 'moc2', 'warpq', 'nomad', 'laceloss', 'all'], default='all', help='default: all')
parser.add_argument('--output', type=str, default=None, help='alternative output folder, default: folder')
|
||||
|
||||
def load_data(folder):
    """Read the pre-calculated metric dictionaries found in folder.

    A metric appears in the returned dict only if its results_<metric>.npy
    file exists.
    """
    def _result_path(metric):
        return os.path.join(folder, f'results_{metric}.npy')

    data = dict()
    for metric in ['moc', 'moc2', 'pesq', 'warpq', 'nomad', 'laceloss']:
        if os.path.isfile(_result_path(metric)):
            # NOTE(review): allow_pickle=True unpickles the file content; only
            # load result files from trusted sources.
            data[metric] = np.load(_result_path(metric), allow_pickle=True).item()

    return data
|
||||
|
||||
def make_table(filename, data, title=None):
    """Write a mean (std) table per bitrate as .txt, .html and .csv and print it.

    filename is the output path without extension; data is keyed by bitrate
    and columns 0, 1, 2 of each value are reported as Opus, LACE and NoLACE.
    title is accepted but not used.
    """

    def _summary(values):
        return f"{float(values.mean()):.3f} ({float(values.std()):.2f})"

    # mean values
    tbl = PrettyTable()
    tbl.field_names = ['bitrate (bps)', 'Opus', 'LACE', 'NoLACE']
    for br in data.keys():
        scores = data[br]
        tbl.add_row([br] + [_summary(scores[:, col]) for col in range(3)])

    renderers = (
        (".txt", lambda: str(tbl)),
        (".html", tbl.get_html_string),
        (".csv", tbl.get_csv_string),
    )
    for extension, render in renderers:
        with open(filename + extension, "w") as f:
            f.write(render())

    print(tbl)
|
||||
|
||||
|
||||
def make_diff_table(filename, data, title=None):
    """Write a table of score differences to Opus as .txt, .html and .csv and print it.

    filename is the output path without extension; for each bitrate the mean
    (std) of column 1 minus column 0 and of column 2 minus column 0 is
    reported. title is accepted but not used.
    """

    def _summary(values):
        return f"{float(values.mean()):.3f} ({float(values.std()):.2f})"

    # mean values
    tbl = PrettyTable()
    tbl.field_names = ['bitrate (bps)', 'LACE - Opus', 'NoLACE - Opus']
    for br in data.keys():
        scores = data[br]
        baseline = scores[:, 0]
        tbl.add_row([br] + [_summary(scores[:, col] - baseline) for col in (1, 2)])

    renderers = (
        (".txt", lambda: str(tbl)),
        (".html", tbl.get_html_string),
        (".csv", tbl.get_csv_string),
    )
    for extension, render in renderers:
        with open(filename + extension, "w") as f:
            f.write(render())

    print(tbl)
|
||||
|
||||
if __name__ == "__main__":
    args = parser.parse_args()
    data = load_data(args.folder)

    metrics = list(data.keys()) if args.metric == 'all' else [args.metric]

    # fail with a readable message instead of a KeyError when the requested
    # metric has no results file in the folder
    missing = [metric for metric in metrics if metric not in data]
    if missing:
        parser.error(f"no pre-calculated results found in {args.folder} for: {', '.join(missing)}")

    folder = args.folder if args.output is None else args.output
    os.makedirs(folder, exist_ok=True)

    for metric in metrics:
        print(f"Plotting data for {metric} metric...")
        make_table(os.path.join(folder, f"table_{metric}"), data[metric])
        make_diff_table(os.path.join(folder, f"table_diff_{metric}"), data[metric])

    print("Done.")
|
121
dnn/torch/osce/stndrd/evaluation/make_tables_moctest.py
Normal file
121
dnn/torch/osce/stndrd/evaluation/make_tables_moctest.py
Normal file
|
@ -0,0 +1,121 @@
|
|||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from prettytable import PrettyTable
|
||||
from matplotlib.patches import Patch
|
||||
|
||||
# Command line interface: input folder, metric selection and optional output folder.
parser = argparse.ArgumentParser()
parser.add_argument(
    'folder',
    type=str,
    help='path to folder with pre-calculated metrics',
)
parser.add_argument(
    '--metric',
    choices=['pesq', 'moc', 'warpq', 'nomad', 'laceloss', 'all'],
    default='all',
    help='default: all',
)
parser.add_argument(
    '--output',
    type=str,
    default=None,
    help='alternative output folder, default: folder',
)
|
||||
|
||||
def load_data(folder):
    """Return a dict of the metric results stored in folder.

    Each results_<metric>.npy file that exists contributes one entry keyed by
    the metric name.
    """
    candidates = {
        metric: os.path.join(folder, f'results_{metric}.npy')
        for metric in ('moc', 'pesq', 'warpq', 'nomad', 'laceloss')
    }

    data = dict()
    for metric, path in candidates.items():
        if os.path.isfile(path):
            # NOTE(review): allow_pickle=True unpickles the file content; only
            # load result files from trusted sources.
            data[metric] = np.load(path, allow_pickle=True).item()

    return data
|
||||
|
||||
def make_table(filename, data, title=None):
    """Write a mean (std) table per bitrate as .txt, .html and .csv and print it.

    filename is the output path without extension; data is keyed by bitrate
    and columns 0, 1, 2 of each value fill the 'Opus', 'LACE' and 'NoLACE'
    table columns. title is accepted but not used.
    """

    def _cell(column):
        return f"{float(column.mean()):.3f} ({float(column.std()):.2f})"

    # mean values
    tbl = PrettyTable()
    tbl.field_names = ['bitrate (bps)', 'Opus', 'LACE', 'NoLACE']
    for br in data.keys():
        per_item = data[br]
        row = [br]
        for col in range(3):
            row.append(_cell(per_item[:, col]))
        tbl.add_row(row)

    outputs = [
        (".txt", lambda: str(tbl)),
        (".html", tbl.get_html_string),
        (".csv", tbl.get_csv_string),
    ]
    for suffix, render in outputs:
        with open(filename + suffix, "w") as f:
            f.write(render())

    print(tbl)
|
||||
|
||||
|
||||
def make_diff_table(filename, data, title=None):
    """Write a table of score differences to Opus as .txt, .html and .csv and print it.

    filename is the output path without extension; for each bitrate the mean
    (std) of column 1 minus column 0 and of column 2 minus column 0 is
    reported. title is accepted but not used.
    """

    def _cell(column):
        return f"{float(column.mean()):.3f} ({float(column.std()):.2f})"

    # mean values
    tbl = PrettyTable()
    tbl.field_names = ['bitrate (bps)', 'LACE - Opus', 'NoLACE - Opus']
    for br in data.keys():
        per_item = data[br]
        reference = per_item[:, 0]
        row = [br]
        for col in (1, 2):
            row.append(_cell(per_item[:, col] - reference))
        tbl.add_row(row)

    outputs = [
        (".txt", lambda: str(tbl)),
        (".html", tbl.get_html_string),
        (".csv", tbl.get_csv_string),
    ]
    for suffix, render in outputs:
        with open(filename + suffix, "w") as f:
            f.write(render())

    print(tbl)
|
||||
|
||||
if __name__ == "__main__":
    args = parser.parse_args()
    data = load_data(args.folder)

    metrics = list(data.keys()) if args.metric == 'all' else [args.metric]

    # fail with a readable message instead of a KeyError when the requested
    # metric has no results file in the folder
    missing = [metric for metric in metrics if metric not in data]
    if missing:
        parser.error(f"no pre-calculated results found in {args.folder} for: {', '.join(missing)}")

    folder = args.folder if args.output is None else args.output
    os.makedirs(folder, exist_ok=True)

    for metric in metrics:
        print(f"Plotting data for {metric} metric...")
        make_table(os.path.join(folder, f"table_{metric}"), data[metric])
        make_diff_table(os.path.join(folder, f"table_diff_{metric}"), data[metric])

    print("Done.")
|
182
dnn/torch/osce/stndrd/evaluation/moc.py
Normal file
182
dnn/torch/osce/stndrd/evaluation/moc.py
Normal file
|
@ -0,0 +1,182 @@
|
|||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import scipy.signal
|
||||
|
||||
def compute_vad_mask(x, fs, stop_db=-70):
    """ Energy-based activity mask with one value per sample

    The signal is cut into frames of ceil(fs / 50) samples (20 ms). A frame is
    marked inactive when its energy, smoothed over 5 frames, falls below the
    maximum frame energy multiplied by 10 ** (stop_db / 20). The resulting 0/1
    sequence is then smoothed with a normalized half-sine kernel of one frame
    length.

    Args:
        x (np.ndarray): one-dimensional input signal
        fs (int): sampling rate in Hz
        stop_db (float): relative threshold, see above

    Returns:
        tuple: (x truncated to a whole number of frames, mask of the same length)
    """
    samples_per_frame = (fs + 49) // 50
    num_full_frames = len(x) // samples_per_frame
    x = x[: samples_per_frame * num_full_frames]

    framed = x.reshape(-1, samples_per_frame)
    energy = np.sum(framed ** 2, axis=1)
    energy_smoothed = np.convolve(energy, np.ones(5) / 5, mode='same')

    # threshold is relative to the loudest (unsmoothed) frame
    threshold = energy.max() * 10 ** (stop_db/20)
    is_inactive = energy_smoothed < threshold

    activity = np.ones_like(framed)
    activity[is_inactive, :] = 0
    activity = activity.reshape(-1)

    # normalized half-sine kernel softens the frame-wise on/off transitions
    kernel = np.sin(np.arange(samples_per_frame) * np.pi / (samples_per_frame - 1))
    kernel = kernel / kernel.sum()

    return x, np.convolve(activity, kernel, mode='same')
|
||||
|
||||
def convert_mask(mask, num_frames, frame_size=160, hop_size=40):
    """ Converts a per-sample mask to a per-frame mask

    The mask is zero-padded or truncated to the number of samples spanned by
    num_frames overlapping frames; each output value is the mean of the mask
    over one frame.

    Args:
        mask (np.ndarray): one-dimensional per-sample mask
        num_frames (int): number of output values
        frame_size (int): frame length in samples
        hop_size (int): frame advance in samples

    Returns:
        np.ndarray: per-frame mask with num_frames entries
    """
    required = frame_size + (num_frames - 1) * hop_size
    shortfall = required - len(mask)

    if shortfall > 0:
        mask = np.concatenate((mask, np.zeros(shortfall)), dtype=mask.dtype)
    else:
        mask = mask[:required]

    frame_means = []
    for k in range(num_frames):
        first = k * hop_size
        frame_means.append(np.mean(mask[first : first + frame_size]))

    return np.array(frame_means)
|
||||
|
||||
def power_spectrum(x, window_size=160, hop_size=40, window='hamming'):
    """ Short-time power spectrum of a one-dimensional signal

    Args:
        x (np.ndarray): input signal
        window_size (int): analysis window length in samples
        hop_size (int): frame advance in samples
        window (str): window name passed to scipy.signal.get_window

    Returns:
        np.ndarray: squared FFT magnitudes of shape
            (number of frames, window_size // 2 + 1)
    """
    frame_count = (len(x) - window_size - hop_size) // hop_size
    taps = scipy.signal.get_window(window, window_size)
    last_bin = window_size // 2

    segments = []
    for k in range(frame_count):
        first = k * hop_size
        segments.append(x[np.newaxis, first : first + window_size])
    windowed = np.concatenate(segments) * taps

    # keep the non-negative frequencies only
    spectrum = np.fft.fft(windowed, axis=1)[:, :last_bin + 1]

    return np.abs(spectrum) ** 2
|
||||
|
||||
|
||||
def frequency_mask(num_bands, up_factor, down_factor):
    """ Builds the band-to-band spreading matrix used for frequency masking

    The result is the product of an upper-triangular matrix with entries
    down_factor ** (col - row) and a lower-triangular matrix with entries
    up_factor ** (row - col).

    Args:
        num_bands (int): number of bands
        up_factor (float): per-band factor of the lower-triangular part
        down_factor (float): per-band factor of the upper-triangular part

    Returns:
        np.ndarray: (num_bands, num_bands) spreading matrix
    """
    band = np.arange(num_bands)
    # signed band distance, row index minus column index
    distance = band[:, np.newaxis] - band[np.newaxis, :]

    spread_up = np.where(distance >= 0, up_factor ** np.maximum(distance, 0), 0.0)
    spread_down = np.where(distance <= 0, down_factor ** np.maximum(-distance, 0), 0.0)

    return spread_down @ spread_up
|
||||
|
||||
|
||||
def rect_fb(band_limits, num_bins=None):
    """ Rectangular filter bank

    Args:
        band_limits (list): bin indices delimiting the bands; band i covers
            bins band_limits[i] up to but excluding band_limits[i + 1]
        num_bins (int, optional): number of columns; defaults to band_limits[-1]

    Returns:
        np.ndarray: (len(band_limits) - 1, num_bins) matrix of zeros and ones
    """
    if num_bins is None:
        num_bins = band_limits[-1]

    bank = np.zeros((len(band_limits) - 1, num_bins))
    for row, (lower, upper) in enumerate(zip(band_limits[:-1], band_limits[1:])):
        bank[row, lower:upper] = 1

    return bank
|
||||
|
||||
|
||||
def compare(x, y, apply_vad=False):
    """ Modified version of opus_compare for 16 kHz mono signals

    Args:
        x (np.ndarray): reference input signal scaled to [-1, 1]
        y (np.ndarray): test signal scaled to [-1, 1]
        apply_vad (bool): if True, only frames whose activity mask (computed
            from x) exceeds 1e-6 enter the final average

    Returns:
        float: perceptually weighted error
    """
    # filter bank: bark scale with minimum-2-bin bands and cutoff at 7.5 kHz
    band_limits = [0, 2, 4, 6, 7, 9, 11, 13, 15, 18, 22, 26, 31, 36, 43, 51, 60, 75]
    fb = rect_fb(band_limits, num_bins=81)
    bins_per_band = np.sum(fb, axis=1)

    # trim samples to same size and rescale to the 16-bit range
    common_length = min(len(x), len(y))
    x = x[:common_length] * 2**15
    y = y[:common_length] * 2**15

    # power spectra with a constant offset
    psd_x = power_spectrum(x) + 100000
    psd_y = power_spectrum(y) + 100000

    # average band energies of the reference, spread across bands (frequency masking)
    spreading = frequency_mask(len(band_limits) - 1, 0.1, 0.03)
    mask_x = ((psd_x @ fb.T) / bins_per_band) @ spreading.T

    # temporal masking: first-order recursion over frames
    for t in range(1, psd_x.shape[0]):
        mask_x[t, :] += 0.5 * mask_x[t-1, :]

    # apply mask: the reference mask is added to both spectra
    mask_per_bin = 0.1 * (mask_x @ fb)
    masked_psd_x = psd_x + mask_per_bin
    masked_psd_y = psd_y + mask_per_bin

    # 2-frame average (sum of neighbouring frames)
    masked_psd_x = masked_psd_x[1:] + masked_psd_x[:-1]
    masked_psd_y = masked_psd_y[1:] + masked_psd_y[:-1]

    # distortion metric: squared log ratio, averaged per band, then over bands
    squared_log_ratio = np.log(masked_psd_y / masked_psd_x) ** 2
    Eb = (squared_log_ratio @ fb.T) / bins_per_band
    Ef = np.mean(Eb, axis=1)

    if apply_vad:
        _, mask = compute_vad_mask(x, 16000)
        mask = convert_mask(mask, Ef.shape[0])
    else:
        mask = np.ones_like(Ef)

    return float(np.mean(np.abs(Ef[mask > 1e-6]) ** 3) ** (1/6))
|
||||
|
||||
if __name__ == "__main__":
    # command line entry point: prints the MOC score of a degraded file
    # relative to a reference file
    import argparse
    from scipy.io import wavfile

    parser = argparse.ArgumentParser()
    parser.add_argument('ref', type=str, help='reference wav file')
    parser.add_argument('deg', type=str, help='degraded wav file')
    parser.add_argument('--apply-vad', action='store_true')
    args = parser.parse_args()

    fs1, x = wavfile.read(args.ref)
    fs2, y = wavfile.read(args.deg)

    # both files must be sampled at 16 kHz; checking only the larger rate
    # would let a file with a lower rate slip through
    if fs1 != 16000 or fs2 != 16000:
        raise ValueError('error: encountered sampling frequency diffrent from 16kHz')

    # NOTE(review): the 2**15 scaling assumes 16-bit PCM input -- confirm
    x = x.astype(np.float32) / 2**15
    y = y.astype(np.float32) / 2**15

    err = compare(x, y, apply_vad=args.apply_vad)

    print(f"MOC: {err}")
|
190
dnn/torch/osce/stndrd/evaluation/moc2.py
Normal file
190
dnn/torch/osce/stndrd/evaluation/moc2.py
Normal file
|
@ -0,0 +1,190 @@
|
|||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import scipy.signal
|
||||
|
||||
def compute_vad_mask(x, fs, stop_db=-70):
    """ Energy-based activity mask with one value per sample

    The signal is cut into frames of ceil(fs / 50) samples (20 ms). A frame is
    marked inactive when its energy, smoothed over 5 frames, falls below the
    maximum frame energy multiplied by 10 ** (stop_db / 20). The resulting 0/1
    sequence is then smoothed with a normalized half-sine kernel of one frame
    length.

    Args:
        x (np.ndarray): one-dimensional input signal
        fs (int): sampling rate in Hz
        stop_db (float): relative threshold, see above

    Returns:
        tuple: (x truncated to a whole number of frames, mask of the same length)
    """
    samples_per_frame = (fs + 49) // 50
    num_full_frames = len(x) // samples_per_frame
    x = x[: samples_per_frame * num_full_frames]

    framed = x.reshape(-1, samples_per_frame)
    energy = np.sum(framed ** 2, axis=1)
    energy_smoothed = np.convolve(energy, np.ones(5) / 5, mode='same')

    # threshold is relative to the loudest (unsmoothed) frame
    threshold = energy.max() * 10 ** (stop_db/20)
    is_inactive = energy_smoothed < threshold

    activity = np.ones_like(framed)
    activity[is_inactive, :] = 0
    activity = activity.reshape(-1)

    # normalized half-sine kernel softens the frame-wise on/off transitions
    kernel = np.sin(np.arange(samples_per_frame) * np.pi / (samples_per_frame - 1))
    kernel = kernel / kernel.sum()

    return x, np.convolve(activity, kernel, mode='same')
|
||||
|
||||
def convert_mask(mask, num_frames, frame_size=160, hop_size=40):
    """ Converts a per-sample mask to a per-frame mask

    The mask is zero-padded or truncated to the number of samples spanned by
    num_frames overlapping frames; each output value is the mean of the mask
    over one frame.

    Args:
        mask (np.ndarray): one-dimensional per-sample mask
        num_frames (int): number of output values
        frame_size (int): frame length in samples
        hop_size (int): frame advance in samples

    Returns:
        np.ndarray: per-frame mask with num_frames entries
    """
    required = frame_size + (num_frames - 1) * hop_size
    shortfall = required - len(mask)

    if shortfall > 0:
        mask = np.concatenate((mask, np.zeros(shortfall)), dtype=mask.dtype)
    else:
        mask = mask[:required]

    frame_means = []
    for k in range(num_frames):
        first = k * hop_size
        frame_means.append(np.mean(mask[first : first + frame_size]))

    return np.array(frame_means)
|
||||
|
||||
def power_spectrum(x, window_size=160, hop_size=40, window='hamming'):
    """ Short-time power spectrum of a one-dimensional signal

    Args:
        x (np.ndarray): input signal
        window_size (int): analysis window length in samples
        hop_size (int): frame advance in samples
        window (str): window name passed to scipy.signal.get_window

    Returns:
        np.ndarray: squared FFT magnitudes of shape
            (number of frames, window_size // 2 + 1)
    """
    frame_count = (len(x) - window_size - hop_size) // hop_size
    taps = scipy.signal.get_window(window, window_size)
    last_bin = window_size // 2

    segments = []
    for k in range(frame_count):
        first = k * hop_size
        segments.append(x[np.newaxis, first : first + window_size])
    windowed = np.concatenate(segments) * taps

    # keep the non-negative frequencies only
    spectrum = np.fft.fft(windowed, axis=1)[:, :last_bin + 1]

    return np.abs(spectrum) ** 2
|
||||
|
||||
|
||||
def frequency_mask(num_bands, up_factor, down_factor):
    """ Builds the band-to-band spreading matrix used for frequency masking

    The result is the product of an upper-triangular matrix with entries
    down_factor ** (col - row) and a lower-triangular matrix with entries
    up_factor ** (row - col).

    Args:
        num_bands (int): number of bands
        up_factor (float): per-band factor of the lower-triangular part
        down_factor (float): per-band factor of the upper-triangular part

    Returns:
        np.ndarray: (num_bands, num_bands) spreading matrix
    """
    band = np.arange(num_bands)
    # signed band distance, row index minus column index
    distance = band[:, np.newaxis] - band[np.newaxis, :]

    spread_up = np.where(distance >= 0, up_factor ** np.maximum(distance, 0), 0.0)
    spread_down = np.where(distance <= 0, down_factor ** np.maximum(-distance, 0), 0.0)

    return spread_down @ spread_up
|
||||
|
||||
|
||||
def rect_fb(band_limits, num_bins=None):
    """ Rectangular filter bank

    Args:
        band_limits (list): bin indices delimiting the bands; band i covers
            bins band_limits[i] up to but excluding band_limits[i + 1]
        num_bins (int, optional): number of columns; defaults to band_limits[-1]

    Returns:
        np.ndarray: (len(band_limits) - 1, num_bins) matrix of zeros and ones
    """
    if num_bins is None:
        num_bins = band_limits[-1]

    bank = np.zeros((len(band_limits) - 1, num_bins))
    for row, (lower, upper) in enumerate(zip(band_limits[:-1], band_limits[1:])):
        bank[row, lower:upper] = 1

    return bank
|
||||
|
||||
|
||||
def _compare(x, y, apply_vad=False, factor=1):
    """ Modified version of opus_compare for 16 kHz mono signals

    Args:
        x (np.ndarray): reference input signal scaled to [-1, 1]
        y (np.ndarray): test signal scaled to [-1, 1]
        apply_vad (bool): if True, only frames whose activity mask (computed
            from x) exceeds 1e-6 enter the final average
        factor (int): scales the analysis window, the hop size and the band
            limits; the temporal masking coefficient becomes 0.5 ** factor

    Returns:
        float: perceptually weighted error
    """
    # filter bank: bark scale with minimum-2-bin bands and cutoff at 7.5 kHz
    band_limits = [factor * b for b in [0, 2, 4, 6, 7, 9, 11, 13, 15, 18, 22, 26, 31, 36, 43, 51, 60, 75]]
    window_size = factor * 160
    hop_size = factor * 40
    num_bins = window_size // 2 + 1
    num_bands = len(band_limits) - 1
    fb = rect_fb(band_limits, num_bins=num_bins)
    bins_per_band = np.sum(fb, axis=1)

    # trim samples to same size and rescale to the 16-bit range
    num_samples = min(len(x), len(y))
    x = x[:num_samples].copy() * 2**15
    y = y[:num_samples].copy() * 2**15

    # power spectra with a constant offset
    psd_x = power_spectrum(x, window_size=window_size, hop_size=hop_size) + 100000
    psd_y = power_spectrum(y, window_size=window_size, hop_size=hop_size) + 100000

    num_frames = psd_x.shape[0]

    # average band energies
    be_x = (psd_x @ fb.T) / bins_per_band

    # frequency masking
    f_mask = frequency_mask(num_bands, 0.1, 0.03)
    mask_x = be_x @ f_mask.T

    # temporal masking: first-order recursion over frames
    for i in range(1, num_frames):
        mask_x[i, :] += (0.5 ** factor) * mask_x[i-1, :]

    # apply mask: the reference mask is added to both spectra
    mask_per_bin = 0.1 * (mask_x @ fb)
    masked_psd_x = psd_x + mask_per_bin
    masked_psd_y = psd_y + mask_per_bin

    # 2-frame average (sum of neighbouring frames)
    masked_psd_x = masked_psd_x[1:] + masked_psd_x[:-1]
    masked_psd_y = masked_psd_y[1:] + masked_psd_y[:-1]

    # distortion metric: squared log ratio, averaged per band, then over bands
    re = masked_psd_y / masked_psd_x
    im = np.log(re) ** 2
    Eb = (im @ fb.T) / bins_per_band
    Ef = np.mean(Eb, axis=1)

    if apply_vad:
        _, mask = compute_vad_mask(x, 16000)
        # the per-frame mask has to follow the analysis grid actually used
        # above, which differs from convert_mask's defaults when factor != 1
        mask = convert_mask(mask, Ef.shape[0], frame_size=window_size, hop_size=hop_size)
    else:
        mask = np.ones_like(Ef)

    err = np.mean(np.abs(Ef[mask > 1e-6]) ** 3) ** (1/6)

    return float(err)
|
||||
|
||||
def compare(x, y, apply_vad=False):
    """ Public entry point: 2-norm over the per-scale errors of _compare

    Only a single scale (factor=1) is evaluated.

    Args:
        x (np.ndarray): reference input signal scaled to [-1, 1]
        y (np.ndarray): test signal scaled to [-1, 1]
        apply_vad (bool): forwarded to _compare

    Returns:
        float: perceptually weighted error
    """
    per_scale_errors = [_compare(x, y, apply_vad=apply_vad, factor=1)]
    return np.linalg.norm(per_scale_errors, ord=2)
|
||||
|
||||
if __name__ == "__main__":
    # command line entry point: prints the MOC score of a degraded file
    # relative to a reference file
    import argparse
    from scipy.io import wavfile

    parser = argparse.ArgumentParser()
    parser.add_argument('ref', type=str, help='reference wav file')
    parser.add_argument('deg', type=str, help='degraded wav file')
    parser.add_argument('--apply-vad', action='store_true')
    args = parser.parse_args()

    fs1, x = wavfile.read(args.ref)
    fs2, y = wavfile.read(args.deg)

    # both files must be sampled at 16 kHz; checking only the larger rate
    # would let a file with a lower rate slip through
    if fs1 != 16000 or fs2 != 16000:
        raise ValueError('error: encountered sampling frequency diffrent from 16kHz')

    # NOTE(review): the 2**15 scaling assumes 16-bit PCM input -- confirm
    x = x.astype(np.float32) / 2**15
    y = y.astype(np.float32) / 2**15

    err = compare(x, y, apply_vad=args.apply_vad)

    print(f"MOC: {err}")
|
98
dnn/torch/osce/stndrd/evaluation/process_dataset.sh
Executable file
98
dnn/torch/osce/stndrd/evaluation/process_dataset.sh
Executable file
|
@ -0,0 +1,98 @@
|
|||
#!/bin/bash
# Runs every wav file below <input folder> through opus_demo at a set of
# bitrates and enhances the decoded signal with the LACE and NoLACE models.
#
# usage: process_dataset.sh <input folder> <output folder>
#
# Required environment variables (each must name an existing file):
#   PYTHON     python executable
#   TESTMODEL  test_model.py
#   OPUSDEMO   patched opus_demo binary
#   LACE       LACE checkpoint
#   NOLACE     NoLACE checkpoint
# Optional:
#   BITRATES   whitespace separated list of bitrates (a default list is used
#              when it is empty)

# Aborts with a hint unless the variable named $1 holds the path of a file.
require_file() {
    if [ ! -f "${!1}" ]
    then
        echo "$1 variable does not link to a file. $2"
        exit 1
    fi
}

require_file PYTHON "Please point it to your python executable."
require_file TESTMODEL "Please point it to your copy of test_model.py"
require_file OPUSDEMO "Please point it to your patched version of opus_demo."
require_file LACE "Please point it to your copy of the LACE checkpoint."
require_file NOLACE "Please point it to your copy of the NOLACE checkpoint."

case $# in
    2) INPUT=$1; OUTPUT=$2;;
    *) echo "process_dataset.sh <input folder> <output folder>"; exit 1;;
esac

# item ids must be unique: without a generator every item would get an empty
# id and overwrite the files of the previous one
if command -v uuid > /dev/null 2>&1
then
    UUIDGEN=uuid
elif command -v uuidgen > /dev/null 2>&1
then
    UUIDGEN=uuidgen
else
    echo "neither uuid nor uuidgen found, aborting..."
    exit 1
fi

if [ -d "$OUTPUT" ]
then
    echo "output folder $OUTPUT exists, aborting..."
    exit 1
fi

mkdir -p "$OUTPUT"

if [ -z "$BITRATES" ]
then
    BITRATES=( 6000 7500 9000 12000 15000 18000 24000 32000 )
    echo "BITRATES variable not defined. Proceeding with default bitrates ${BITRATES[@]}."
fi

echo "LACE=${LACE}" > "${OUTPUT}/info.txt"
echo "NOLACE=${NOLACE}" >> "${OUTPUT}/info.txt"

ITEMFILE=${OUTPUT}/items.txt
BITRATEFILE=${OUTPUT}/bitrates.txt

FPROCESSING=${OUTPUT}/processing
FCLEAN=${OUTPUT}/clean
FOPUS=${OUTPUT}/opus
FLACE=${OUTPUT}/lace
FNOLACE=${OUTPUT}/nolace

mkdir -p "$FPROCESSING" "$FCLEAN" "$FOPUS" "$FLACE" "$FNOLACE"

echo "${BITRATES[@]}" > "$BITRATEFILE"

# file names are read NUL-separated on a dedicated descriptor so that paths
# containing whitespace survive and child processes cannot consume the list
while IFS= read -r -d '' fn <&3
do
    UUID=$($UUIDGEN)
    echo "$UUID $fn" >> "$ITEMFILE"
    PIDS=( )
    # BITRATES is left unquoted on purpose: when it comes from the environment
    # it is a plain string and has to be split on whitespace
    for br in ${BITRATES[@]}
    do
        # run opus (inside the processing folder, where clean.s16 is located
        # and noisy.s16 is written)
        pfolder=${FPROCESSING}/${UUID}_${br}
        mkdir -p "$pfolder"
        sox "$fn" -c 1 -r 16000 -b 16 -e signed-integer "$pfolder/clean.s16"
        (cd "${pfolder}" && "$OPUSDEMO" voip 16000 1 "$br" clean.s16 noisy.s16)

        # copy clean and opus
        sox -c 1 -r 16000 -b 16 -e signed-integer "$pfolder/clean.s16" "$FCLEAN/${UUID}_${br}_clean.wav"
        sox -c 1 -r 16000 -b 16 -e signed-integer "$pfolder/noisy.s16" "$FOPUS/${UUID}_${br}_opus.wav"

        # run LACE
        "$PYTHON" "$TESTMODEL" "$pfolder" "$LACE" "$FLACE/${UUID}_${br}_lace.wav" &
        PIDS+=( "$!" )

        # run NoLACE
        "$PYTHON" "$TESTMODEL" "$pfolder" "$NOLACE" "$FNOLACE/${UUID}_${br}_nolace.wav" &
        PIDS+=( "$!" )
    done

    # all enhancement jobs of this item must finish before the next one starts
    for pid in "${PIDS[@]}"
    do
        wait "$pid"
    done
done 3< <(find "$INPUT" -type f -name "*.wav" -print0)
|
138
dnn/torch/osce/stndrd/evaluation/run_nomad.py
Normal file
138
dnn/torch/osce/stndrd/evaluation/run_nomad.py
Normal file
|
@ -0,0 +1,138 @@
|
|||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import tempfile
|
||||
import shutil
|
||||
|
||||
import pandas as pd
|
||||
from scipy.spatial.distance import cdist
|
||||
from scipy.io import wavfile
|
||||
import numpy as np
|
||||
|
||||
from nomad_audio.nomad import Nomad
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('folder', type=str, help='folder with processed items')
|
||||
parser.add_argument('--full-reference', action='store_true', help='use NOMAD as full-reference metric')
|
||||
parser.add_argument('--device', type=str, default=None, help='device for Nomad')
|
||||
|
||||
|
||||
def get_bitrates(folder):
    """ Reads the whitespace-separated integer bitrates stored in <folder>/bitrates.txt """
    bitrate_path = os.path.join(folder, 'bitrates.txt')

    with open(bitrate_path) as f:
        content = f.read()

    return list(map(int, content.rstrip('\n').split()))
|
||||
|
||||
def get_itemlist(folder):
    """ Returns the item ids listed in <folder>/items.txt

    Every non-empty line contributes its first whitespace-separated field.
    Blank lines are skipped instead of raising an IndexError.
    """
    with open(os.path.join(folder, 'items.txt')) as f:
        split_lines = [line.split() for line in f]

    return [fields[0] for fields in split_lines if fields]
|
||||
|
||||
|
||||
def _folder_embeddings(model, folder):
    """ Computes NOMAD embeddings for all files in folder (sorted by name), indexed by file path """
    file_table = pd.DataFrame(sorted(os.listdir(folder)))
    file_table.columns = ['filename']
    file_table['filename'] = [os.path.join(folder, x) for x in file_table['filename']]

    return model.get_embeddings_csv(model.model, file_table).set_index('filename')


def nomad_wrapper(ref_folder, deg_folder, full_reference=False, ref_embeddings=None, device=None):
    """ Scores the files in deg_folder with NOMAD

    Args:
        ref_folder (str): folder with reference files
        deg_folder (str): folder with degraded files
        full_reference (bool): if False, model.predict is called with
            ref_folder as non-matching reference set. If True, the score of a
            degraded file is the Euclidean distance between its embedding and
            the embedding of the reference file at the same position in the
            sorted folder listing.
        ref_embeddings: reference embeddings returned by a previous
            full-reference call; computed from ref_folder when None
        device: forwarded to the Nomad constructor

    Returns:
        tuple: (dict mapping file name without extension to score,
                reference embeddings or None when full_reference is False)
    """
    model = Nomad(device=device)

    if not full_reference:
        results = model.predict(nmr=ref_folder, deg=deg_folder)[0].to_dict()['NOMAD']
        return results, None

    if ref_embeddings is None:
        print(f"Computing reference embeddings from {ref_folder}")
        ref_embeddings = _folder_embeddings(model, ref_folder)

    print(f"Computing degraded embeddings from {deg_folder}")
    deg_embeddings = _folder_embeddings(model, deg_folder)

    # only matching pairs are needed, so the distance is taken row by row
    # instead of building the full pairwise distance matrix
    difference = np.asarray(ref_embeddings, dtype=float) - np.asarray(deg_embeddings, dtype=float)
    dist = np.linalg.norm(difference, axis=1)

    test_files = [os.path.basename(x).split('.')[0] for x in deg_embeddings.index]
    results = dict(zip(test_files, dist))

    return results, ref_embeddings
|
||||
|
||||
|
||||
|
||||
|
||||
def nomad_process_all(folder, full_reference=False, device=None):
    """ Computes NOMAD scores for all items and bitrates of a processed dataset

    Args:
        folder (str): dataset folder containing bitrates.txt, items.txt and the
            sub-folders clean, opus, lace and nolace
        full_reference (bool): forwarded to nomad_wrapper
        device: forwarded to nomad_wrapper

    Returns:
        dict: maps each bitrate to an array of shape (number of items, 3)
            whose columns hold the opus, lace and nolace scores
    """
    bitrates = get_bitrates(folder)
    items = get_itemlist(folder)
    conditions = ['clean', 'opus', 'lace', 'nolace']

    with tempfile.TemporaryDirectory() as tmpdir:
        cleandir = os.path.join(tmpdir, 'clean')
        opusdir = os.path.join(tmpdir, 'opus')
        lacedir = os.path.join(tmpdir, 'lace')
        nolacedir = os.path.join(tmpdir, 'nolace')

        # prepare files: the condition suffix is dropped so that matching files
        # carry the same name in every condition folder
        for d in [cleandir, opusdir, lacedir, nolacedir]:
            os.makedirs(d)
        for br in bitrates:
            for item in items:
                for cond in conditions:
                    shutil.copyfile(os.path.join(folder, cond, f"{item}_{br}_{cond}.wav"), os.path.join(tmpdir, cond, f"{item}_{br}.wav"))

        # the requested device is passed on; reference embeddings are computed
        # once and reused for the remaining conditions
        nomad_opus, ref_embeddings = nomad_wrapper(cleandir, opusdir, full_reference=full_reference, ref_embeddings=None, device=device)
        nomad_lace, ref_embeddings = nomad_wrapper(cleandir, lacedir, full_reference=full_reference, ref_embeddings=ref_embeddings, device=device)
        nomad_nolace, ref_embeddings = nomad_wrapper(cleandir, nolacedir, full_reference=full_reference, ref_embeddings=ref_embeddings, device=device)

    results = dict()
    for br in bitrates:
        results[br] = np.zeros((len(items), 3))
        for i, item in enumerate(items):
            key = f"{item}_{br}"
            results[br][i, 0] = nomad_opus[key]
            results[br][i, 1] = nomad_lace[key]
            results[br][i, 2] = nomad_nolace[key]

    return results
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
    args = parser.parse_args()

    # item list and bitrates are read inside nomad_process_all
    results = nomad_process_all(args.folder, full_reference=args.full_reference, device=args.device)

    # results is a dict, which np.save stores as a pickled object array;
    # reading it back requires np.load(..., allow_pickle=True).item()
    np.save(os.path.join(args.folder, 'results_nomad.npy'), results)

    print("Done.")
|
205
dnn/torch/osce/stndrd/presentation/endoscopy.py
Normal file
205
dnn/torch/osce/stndrd/presentation/endoscopy.py
Normal file
|
@ -0,0 +1,205 @@
|
|||
""" module for inspecting models during inference """
|
||||
|
||||
import os
|
||||
|
||||
import yaml
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.animation as animation
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
# stores entries {key : {'fid' : fid, 'fs' : fs, 'dim' : dim, 'dtype' : dtype}}
|
||||
_state = dict()
|
||||
_folder = 'endoscopy'
|
||||
|
||||
def get_gru_gates(gru, input, state):
    """ Computes the gate activations of the first GRU layer for a single step

    Args:
        gru: module providing hidden_size, weight_ih_l0, weight_hh_l0,
            bias_ih_l0 and bias_hh_l0 (e.g. torch.nn.GRU)
        input (torch.Tensor): input vector, squeezed before use
        state (torch.Tensor): hidden state vector, squeezed before use

    Returns:
        dict: tensors under the keys 'reset_gate', 'update_gate' and 'new_gate'
    """
    size = gru.hidden_size

    from_input = torch.matmul(gru.weight_ih_l0, input.squeeze())
    from_state = torch.matmul(gru.weight_hh_l0, state.squeeze())

    # the stacked weights hold the reset, update and new gate in this order
    reset_rows = slice(0 * size, 1 * size)
    update_rows = slice(1 * size, 2 * size)
    new_rows = slice(2 * size, 3 * size)

    def pre_activation(rows):
        return from_input[rows] + gru.bias_ih_l0[rows] + from_state[rows] + gru.bias_hh_l0[rows]

    reset_gate = torch.sigmoid(pre_activation(reset_rows))
    update_gate = torch.sigmoid(pre_activation(update_rows))

    # for the new gate the recurrent contribution is scaled by the reset gate
    recurrent_part = from_state[new_rows] + gru.bias_hh_l0[new_rows]
    new_gate = torch.tanh(from_input[new_rows] + gru.bias_ih_l0[new_rows] + reset_gate * recurrent_part)

    return {'reset_gate' : reset_gate, 'update_gate' : update_gate, 'new_gate' : new_gate}
|
||||
|
||||
|
||||
def init(folder='endoscopy'):
    """ sets up output folder for endoscopy data

    The folder is remembered as the target of subsequent write_data calls. It
    is created if missing; if it already exists only a warning is printed.
    """
    global _folder
    _folder = folder

    if os.path.exists(folder):
        print(f"warning: endoscopy folder {folder} exists. Content may be lost or inconsistent results may occur.")
        return

    os.makedirs(folder)
|
||||
|
||||
def write_data(key, data, fs):
    """ appends data to previous data written under key

    On the first call for a key, <key>.bin is opened for binary writing in the
    endoscopy folder and <key>.yml with sampling rate, shape and dtype is
    written next to it. Later calls must use the same fs, dtype and shape.

    Args:
        key (str): name of the data stream, also used as file name stem
        data (torch.Tensor or np.ndarray): one chunk of data
        fs: sampling rate stored alongside the data

    Raises:
        ValueError: if fs, dtype or shape differ from the first call for key
    """

    global _state

    # convert to numpy if torch.Tensor is given; the tensor is moved to host
    # memory first since Tensor.numpy() only works for CPU tensors
    if isinstance(data, torch.Tensor):
        data = data.detach().cpu().numpy()

    if key not in _state:
        _state[key] = {
            'fid' : open(os.path.join(_folder, key + '.bin'), 'wb'),
            'fs' : fs,
            'dim' : tuple(data.shape),
            'dtype' : str(data.dtype)
        }

        with open(os.path.join(_folder, key + '.yml'), 'w') as f:
            f.write(yaml.dump({'fs' : fs, 'dim' : tuple(data.shape), 'dtype' : str(data.dtype).split('.')[-1]}))
    else:
        if _state[key]['fs'] != fs:
            raise ValueError(f"fs changed for key {key}: {_state[key]['fs']} vs. {fs}")
        if _state[key]['dtype'] != str(data.dtype):
            raise ValueError(f"dtype changed for key {key}: {_state[key]['dtype']} vs. {str(data.dtype)}")
        if _state[key]['dim'] != tuple(data.shape):
            raise ValueError(f"dim changed for key {key}: {_state[key]['dim']} vs. {tuple(data.shape)}")

    _state[key]['fid'].write(data.tobytes())
|
||||
|
||||
def close(folder='endoscopy'):
    """ clean up

    Closes all files opened by write_data and forgets them, so that a later
    write_data call opens a fresh file instead of writing to a closed handle.
    The folder argument is not used.
    """
    for entry in _state.values():
        entry['fid'].close()

    _state.clear()
|
||||
|
||||
|
||||
def read_data(folder='endoscopy'):
    """ retrieves written data as numpy arrays

    For every <key>.yml in folder the metadata is loaded and the content of
    <key>.bin is attached under 'data', reshaped to (-1,) + dim.

    Returns:
        dict: maps key to a dict with the entries of the yml file plus 'data'
    """
    suffix = '.yml'
    loaded = dict()

    for name in os.listdir(folder):
        if not name.endswith(suffix):
            continue
        key = name[:-len(suffix)]

        with open(os.path.join(folder, key + '.yml'), 'r') as f:
            entry = yaml.load(f.read(), yaml.FullLoader)

        with open(os.path.join(folder, key + '.bin'), 'rb') as f:
            flat = np.frombuffer(f.read(), dtype=entry['dtype'])

        # leading axis enumerates the chunks appended by write_data
        entry['data'] = flat.reshape((-1,) + entry['dim'])
        loaded[key] = entry

    return loaded
|
||||
|
||||
def get_best_reshape(shape, target_ratio=1):
    """ calculates the best 2d reshape of shape given the target ratio (rows/cols)

    The column count starts at floor(sqrt(pixel_count / target_ratio)) and is
    decreased until it divides the total number of elements.

    Args:
        shape (tuple): original shape
        target_ratio (float): desired ratio of rows to columns

    Returns:
        tuple: (rows, columns), or (1,) for a single element
    """
    pixel_count = 1
    for s in shape:
        pixel_count *= s

    if pixel_count == 1:
        return (1,)

    # at least one column: the truncated square root can be 0 for large
    # target ratios, which would make the modulo below divide by zero
    num_columns = max(1, int((pixel_count / target_ratio)**.5))

    while (pixel_count % num_columns):
        num_columns -= 1

    num_rows = pixel_count // num_columns

    return (num_rows, num_columns)
|
||||
|
||||
def get_type_and_shape(shape):
    """ Decides how data of the given shape is displayed

    Args:
        shape (tuple): shape of a single data chunk

    Returns:
        tuple: ('plot', (1,)) for a single value; otherwise 'image' together
            with the 2d shape used for display
    """

    # can happen if data is one dimensional
    if len(shape) == 0:
        shape = (1,)

    # calculate pixel count
    pixel_count = 1
    for s in shape:
        pixel_count *= s

    if pixel_count == 1:
        return 'plot', (1, )

    # stay with shape if already 2-dimensional, i.e. if neither axis holds all
    # elements on its own (with `or` this test was true for every shape and
    # 1 x N data was never reshaped)
    if len(shape) == 2:
        if (shape[0] != pixel_count) and (shape[1] != pixel_count):
            return 'image', shape

    return 'image', get_best_reshape(shape)
|
||||
|
||||
def make_animation(data, filename, start_index=80, stop_index=-80, interval=20, half_signal_window_length=80):
    """ Renders the recorded data streams as an mp4 animation with one subplot per key

    Args:
        data (dict): maps key to a dict with entries 'fs', 'dim' and 'data'
            (the format returned by read_data)
        filename (str): output file; '.mp4' is appended if missing
        start_index (int): first animated index on the axis of the stream with
            the highest sampling rate
        stop_index (int): end index (exclusive); negative values count from the end
        interval (int): passed to matplotlib's ArtistAnimation
        half_signal_window_length (int): number of values shown on each side
            of the current index for 'plot' streams
    """

    # determine plot setup; at least one row, and squeeze=False keeps axs
    # two-dimensional even for a single row or column
    num_keys = len(data.keys())
    num_rows = max(1, int((num_keys * 3/4) ** .5))
    num_cols = (num_keys + num_rows - 1) // num_rows

    fig, axs = plt.subplots(num_rows, num_cols, squeeze=False)
    fig.set_size_inches(num_cols * 5, num_rows * 5)

    display = dict()

    fs_max = max([val['fs'] for val in data.values()])
    num_samples = max([val['data'].shape[0] for val in data.values()])

    keys = sorted(data.keys())

    # inspect data
    for i, key in enumerate(keys):
        axs[i // num_cols, i % num_cols].title.set_text(key)

        display[key] = dict()
        display[key]['type'], display[key]['shape'] = get_type_and_shape(data[key]['dim'])
        display[key]['down_factor'] = data[key]['fs'] / fs_max

    start_index = max(start_index, half_signal_window_length)
    while stop_index < 0:
        stop_index += num_samples

    stop_index = min(stop_index, num_samples - half_signal_window_length)

    # actual plotting
    frames = []
    for index in range(start_index, stop_index):
        ims = []
        for i, key in enumerate(keys):
            ax = axs[i // num_cols, i % num_cols]

            if display[key]['type'] == 'plot':
                # NOTE(review): the signal window is indexed at the highest rate;
                # confirm that 'plot' streams are always recorded at fs_max
                ims.append(ax.plot(data[key]['data'][index - half_signal_window_length : index + half_signal_window_length], marker='P', markevery=[half_signal_window_length], animated=True, color='blue')[0])

            elif display[key]['type'] == 'image':
                # streams with a lower rate have fewer entries, so the index
                # is scaled to their rate and clamped to the last entry
                feature_index = int(round(index * display[key]['down_factor']))
                feature_index = min(feature_index, data[key]['data'].shape[0] - 1)
                ims.append(ax.imshow(data[key]['data'][feature_index].reshape(display[key]['shape']), animated=True))

        frames.append(ims)

    ani = animation.ArtistAnimation(fig, frames, interval=interval, blit=True, repeat_delay=1000)

    if not filename.endswith('.mp4'):
        filename += '.mp4'

    ani.save(filename)
|
313
dnn/torch/osce/stndrd/presentation/lace_demo.ipynb
Normal file
313
dnn/torch/osce/stndrd/presentation/lace_demo.ipynb
Normal file
File diff suppressed because one or more lines are too long
320
dnn/torch/osce/stndrd/presentation/linear_prediction.ipynb
Normal file
320
dnn/torch/osce/stndrd/presentation/linear_prediction.ipynb
Normal file
File diff suppressed because one or more lines are too long
25
dnn/torch/osce/stndrd/presentation/playback.py
Normal file
25
dnn/torch/osce/stndrd/presentation/playback.py
Normal file
|
@ -0,0 +1,25 @@
|
|||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.animation
|
||||
|
||||
def make_playback_animation(savepath, spec, duration_ms, vmin=20, vmax=90):
    """ Saves an animation of a spectrogram with a moving vertical playback cursor

    The first and the last frame show the spectrogram without cursor; in the
    frames in between the cursor moves from the left to the right edge.

    Args:
        savepath (str): output file passed to Animation.save
        spec (np.ndarray): two-dimensional spectrogram (height, width)
        duration_ms (float): duration of the animation in milliseconds
        vmin (float): lower color limit passed to imshow
        vmax (float): upper color limit passed to imshow
    """
    fig, axs = plt.subplots()
    axs.set_axis_off()
    fig.set_size_inches((duration_ms / 1000 * 5, 5))
    frames = []
    frame_duration=20
    num_frames = int(duration_ms / frame_duration + .99)

    spec_height, spec_width = spec.shape

    # number of cursor steps between the first and the last cursor frame; at
    # least 1 so that a single cursor frame (num_frames == 3) does not divide
    # by zero
    num_steps = max(num_frames - 3, 1)

    for i in range(num_frames):
        new_frame = axs.imshow(spec, cmap='inferno', origin='lower', aspect='auto', vmin=vmin, vmax=vmax)
        if i in {0, num_frames - 1}:
            frames.append([new_frame])
        else:
            xpos = (i - 1) / num_steps * (spec_width - 1)
            line = axs.plot([xpos, xpos], [0, spec_height-1], color='white', alpha=0.8)[0]
            frames.append([new_frame, line])

    ani = matplotlib.animation.ArtistAnimation(fig, frames, blit=True, interval=frame_duration)
    ani.save(savepath, dpi=720)
|
275
dnn/torch/osce/stndrd/presentation/postfilter.ipynb
Normal file
275
dnn/torch/osce/stndrd/presentation/postfilter.ipynb
Normal file
File diff suppressed because one or more lines are too long
173
dnn/torch/osce/stndrd/presentation/spectrogram.ipynb
Normal file
173
dnn/torch/osce/stndrd/presentation/spectrogram.ipynb
Normal file
File diff suppressed because one or more lines are too long
|
@ -27,9 +27,13 @@
|
|||
*/
|
||||
"""
|
||||
|
||||
seed=1888
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import sys
|
||||
import random
|
||||
random.seed(seed)
|
||||
|
||||
import yaml
|
||||
|
||||
|
@ -40,9 +44,12 @@ except:
|
|||
has_git = False
|
||||
|
||||
import torch
|
||||
torch.manual_seed(seed)
|
||||
torch.backends.cudnn.benchmark = False
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
|
||||
import numpy as np
|
||||
np.random.seed(seed)
|
||||
|
||||
from scipy.io import wavfile
|
||||
|
||||
|
@ -54,7 +61,7 @@ from engine.engine import train_one_epoch, evaluate
|
|||
|
||||
|
||||
from utils.silk_features import load_inference_data
|
||||
from utils.misc import count_parameters
|
||||
from utils.misc import count_parameters, count_nonzero_parameters
|
||||
|
||||
from losses.stft_loss import MRSTFTLoss, MRLogMelLoss
|
||||
|
||||
|
@ -71,6 +78,7 @@ parser.add_argument('--no-redirect', action='store_true', help='disables re-dire
|
|||
args = parser.parse_args()
|
||||
|
||||
|
||||
|
||||
torch.set_num_threads(4)
|
||||
|
||||
with open(args.setup, 'r') as f:
|
||||
|
@ -98,7 +106,7 @@ if os.path.exists(args.output):
|
|||
reply = input('continue? (y/n): ')
|
||||
|
||||
if reply == 'n':
|
||||
os._exit()
|
||||
os._exit(0)
|
||||
else:
|
||||
os.makedirs(args.output, exist_ok=True)
|
||||
|
||||
|
@ -109,7 +117,7 @@ os.makedirs(checkpoint_dir, exist_ok=True)
|
|||
if has_git:
|
||||
working_dir = os.path.split(__file__)[0]
|
||||
try:
|
||||
repo = git.Repo(working_dir)
|
||||
repo = git.Repo(working_dir, search_parent_directories=True)
|
||||
setup['repo'] = dict()
|
||||
hash = repo.head.object.hexsha
|
||||
urls = list(repo.remote().urls)
|
||||
|
@ -117,6 +125,8 @@ if has_git:
|
|||
|
||||
if is_dirty:
|
||||
print("warning: repo is dirty")
|
||||
with open(os.path.join(args.output, 'repo.diff'), "w") as f:
|
||||
f.write(repo.git.execute(["git", "diff"]))
|
||||
|
||||
setup['repo']['hash'] = hash
|
||||
setup['repo']['urls'] = urls
|
||||
|
@ -292,6 +302,6 @@ for ep in range(1, epochs + 1):
|
|||
torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_last.pth'))
|
||||
|
||||
|
||||
print()
|
||||
print(f"non-zero parameters: {count_nonzero_parameters(model)}\n")
|
||||
|
||||
print('Done')
|
||||
|
|
|
@ -107,7 +107,7 @@ os.makedirs(checkpoint_dir, exist_ok=True)
|
|||
if has_git:
|
||||
working_dir = os.path.split(__file__)[0]
|
||||
try:
|
||||
repo = git.Repo(working_dir)
|
||||
repo = git.Repo(working_dir, search_parent_directories=True)
|
||||
setup['repo'] = dict()
|
||||
hash = repo.head.object.hexsha
|
||||
urls = list(repo.remote().urls)
|
||||
|
|
|
@ -32,6 +32,7 @@ from torch import nn
|
|||
import torch.nn.functional as F
|
||||
|
||||
from utils.endoscopy import write_data
|
||||
from utils.softquant import soft_quant
|
||||
|
||||
class LimitedAdaptiveComb1d(nn.Module):
|
||||
COUNTER = 1
|
||||
|
@ -47,6 +48,8 @@ class LimitedAdaptiveComb1d(nn.Module):
|
|||
gain_limit_db=10,
|
||||
global_gain_limits_db=[-6, 6],
|
||||
norm_p=2,
|
||||
softquant=False,
|
||||
apply_weight_norm=False,
|
||||
**kwargs):
|
||||
"""
|
||||
|
||||
|
@ -97,17 +100,22 @@ class LimitedAdaptiveComb1d(nn.Module):
|
|||
else:
|
||||
self.name = name
|
||||
|
||||
norm = torch.nn.utils.weight_norm if apply_weight_norm else lambda x, name=None: x
|
||||
|
||||
# network for generating convolution weights
|
||||
self.conv_kernel = nn.Linear(feature_dim, kernel_size)
|
||||
self.conv_kernel = norm(nn.Linear(feature_dim, kernel_size))
|
||||
|
||||
if softquant:
|
||||
self.conv_kernel = soft_quant(self.conv_kernel)
|
||||
|
||||
|
||||
# comb filter gain
|
||||
self.filter_gain = nn.Linear(feature_dim, 1)
|
||||
self.filter_gain = norm(nn.Linear(feature_dim, 1))
|
||||
self.log_gain_limit = gain_limit_db * 0.11512925464970229
|
||||
with torch.no_grad():
|
||||
self.filter_gain.bias[:] = max(0.1, 4 + self.log_gain_limit)
|
||||
|
||||
self.global_filter_gain = nn.Linear(feature_dim, 1)
|
||||
self.global_filter_gain = norm(nn.Linear(feature_dim, 1))
|
||||
log_min, log_max = global_gain_limits_db[0] * 0.11512925464970229, global_gain_limits_db[1] * 0.11512925464970229
|
||||
self.filter_gain_a = (log_max - log_min) / 2
|
||||
self.filter_gain_b = (log_max + log_min) / 2
|
||||
|
|
|
@ -34,7 +34,7 @@ import torch.nn.functional as F
|
|||
from utils.endoscopy import write_data
|
||||
|
||||
from utils.ada_conv import adaconv_kernel
|
||||
|
||||
from utils.softquant import soft_quant
|
||||
|
||||
class LimitedAdaptiveConv1d(nn.Module):
|
||||
COUNTER = 1
|
||||
|
@ -51,6 +51,8 @@ class LimitedAdaptiveConv1d(nn.Module):
|
|||
gain_limits_db=[-6, 6],
|
||||
shape_gain_db=0,
|
||||
norm_p=2,
|
||||
softquant=False,
|
||||
apply_weight_norm=False,
|
||||
**kwargs):
|
||||
"""
|
||||
|
||||
|
@ -100,12 +102,16 @@ class LimitedAdaptiveConv1d(nn.Module):
|
|||
else:
|
||||
self.name = name
|
||||
|
||||
norm = torch.nn.utils.weight_norm if apply_weight_norm else lambda x, name=None: x
|
||||
|
||||
# network for generating convolution weights
|
||||
self.conv_kernel = nn.Linear(feature_dim, in_channels * out_channels * kernel_size)
|
||||
self.conv_kernel = norm(nn.Linear(feature_dim, in_channels * out_channels * kernel_size))
|
||||
if softquant:
|
||||
self.conv_kernel = soft_quant(self.conv_kernel)
|
||||
|
||||
self.shape_gain = min(1, 10**(shape_gain_db / 20))
|
||||
|
||||
self.filter_gain = nn.Linear(feature_dim, out_channels)
|
||||
self.filter_gain = norm(nn.Linear(feature_dim, out_channels))
|
||||
log_min, log_max = gain_limits_db[0] * 0.11512925464970229, gain_limits_db[1] * 0.11512925464970229
|
||||
self.filter_gain_a = (log_max - log_min) / 2
|
||||
self.filter_gain_b = (log_max + log_min) / 2
|
||||
|
|
|
@ -3,6 +3,7 @@ from torch import nn
|
|||
import torch.nn.functional as F
|
||||
|
||||
from utils.complexity import _conv1d_flop_count
|
||||
from utils.softquant import soft_quant
|
||||
|
||||
class TDShaper(nn.Module):
|
||||
COUNTER = 1
|
||||
|
@ -12,7 +13,9 @@ class TDShaper(nn.Module):
|
|||
frame_size=160,
|
||||
avg_pool_k=4,
|
||||
innovate=False,
|
||||
pool_after=False
|
||||
pool_after=False,
|
||||
softquant=False,
|
||||
apply_weight_norm=False
|
||||
):
|
||||
"""
|
||||
|
||||
|
@ -45,23 +48,29 @@ class TDShaper(nn.Module):
|
|||
assert frame_size % avg_pool_k == 0
|
||||
self.env_dim = frame_size // avg_pool_k + 1
|
||||
|
||||
norm = torch.nn.utils.weight_norm if apply_weight_norm else lambda x, name=None: x
|
||||
|
||||
# feature transform
|
||||
self.feature_alpha1 = nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2)
|
||||
self.feature_alpha2 = nn.Conv1d(frame_size, frame_size, 2)
|
||||
self.feature_alpha1_f = norm(nn.Conv1d(self.feature_dim, frame_size, 2))
|
||||
self.feature_alpha1_t = norm(nn.Conv1d(self.env_dim, frame_size, 2))
|
||||
self.feature_alpha2 = norm(nn.Conv1d(frame_size, frame_size, 2))
|
||||
|
||||
if softquant:
|
||||
self.feature_alpha1_f = soft_quant(self.feature_alpha1_f)
|
||||
|
||||
if self.innovate:
|
||||
self.feature_alpha1b = nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2)
|
||||
self.feature_alpha1c = nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2)
|
||||
self.feature_alpha1b = norm(nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2))
|
||||
self.feature_alpha1c = norm(nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2))
|
||||
|
||||
self.feature_alpha2b = nn.Conv1d(frame_size, frame_size, 2)
|
||||
self.feature_alpha2c = nn.Conv1d(frame_size, frame_size, 2)
|
||||
self.feature_alpha2b = norm(nn.Conv1d(frame_size, frame_size, 2))
|
||||
self.feature_alpha2c = norm(nn.Conv1d(frame_size, frame_size, 2))
|
||||
|
||||
|
||||
def flop_count(self, rate):
|
||||
|
||||
frame_rate = rate / self.frame_size
|
||||
|
||||
shape_flops = sum([_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1, self.feature_alpha2)]) + 11 * frame_rate * self.frame_size
|
||||
shape_flops = sum([_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1_f, self.feature_alpha1_t, self.feature_alpha2)]) + 11 * frame_rate * self.frame_size
|
||||
|
||||
if self.innovate:
|
||||
inno_flops = sum([_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1b, self.feature_alpha2b, self.feature_alpha1c, self.feature_alpha2c)]) + 22 * frame_rate * self.frame_size
|
||||
|
@ -110,9 +119,10 @@ class TDShaper(nn.Module):
|
|||
tenv = self.envelope_transform(x)
|
||||
|
||||
# feature path
|
||||
f = torch.cat((features, tenv), dim=-1)
|
||||
f = F.pad(f.permute(0, 2, 1), [1, 0])
|
||||
alpha = F.leaky_relu(self.feature_alpha1(f), 0.2)
|
||||
f = F.pad(features.permute(0, 2, 1), [1, 0])
|
||||
t = F.pad(tenv.permute(0, 2, 1), [1, 0])
|
||||
alpha = self.feature_alpha1_f(f) + self.feature_alpha1_t(t)
|
||||
alpha = F.leaky_relu(alpha, 0.2)
|
||||
alpha = torch.exp(self.feature_alpha2(F.pad(alpha, [1, 0])))
|
||||
alpha = alpha.permute(0, 2, 1)
|
||||
|
||||
|
|
|
@ -28,6 +28,7 @@
|
|||
"""
|
||||
|
||||
import torch
|
||||
from torch.nn.utils import remove_weight_norm
|
||||
|
||||
def count_parameters(model, verbose=False):
|
||||
total = 0
|
||||
|
@ -41,7 +42,17 @@ def count_parameters(model, verbose=False):
|
|||
|
||||
return total
|
||||
|
||||
def count_nonzero_parameters(model, verbose=False):
    """Return the number of non-zero entries over all parameters of ``model``.

    Args:
        model: object exposing ``named_parameters()`` (e.g. a ``torch.nn.Module``).
        verbose: when true, print one line per parameter with its non-zero count.

    Returns:
        Sum of the per-parameter non-zero counts (0 if there are no parameters).
    """
    per_parameter = [
        (param_name, torch.count_nonzero(tensor).item())
        for param_name, tensor in model.named_parameters()
    ]

    if verbose:
        for param_name, nonzero in per_parameter:
            print(f"{param_name}: {nonzero} non-zero parameters")

    return sum(nonzero for _, nonzero in per_parameter)
|
||||
def retain_grads(module):
|
||||
for p in module.parameters():
|
||||
if p.requires_grad:
|
||||
|
@ -62,4 +73,23 @@ def create_weights(s_real, s_gen, alpha):
|
|||
weight = torch.exp(alpha * (sr[-1] - sg[-1]))
|
||||
weights.append(weight)
|
||||
|
||||
return weights
|
||||
return weights
|
||||
|
||||
|
||||
def _get_candidates(module: torch.nn.Module):
|
||||
candidates = []
|
||||
for key in module.__dict__.keys():
|
||||
if hasattr(module, key + '_v'):
|
||||
candidates.append(key)
|
||||
return candidates
|
||||
|
||||
def remove_all_weight_norm(model : torch.nn.Module, verbose=False):
    """Best-effort removal of weight normalization from every sub-module of ``model``.

    For each module returned by ``model.named_modules()`` the attribute names
    reported by ``_get_candidates`` are passed to ``remove_weight_norm``.
    Candidates that turn out not to carry a weight-norm hook are skipped.

    Args:
        model: module tree to process (modified in place).
        verbose: when true, print a line for every weight whose norm was removed.
    """
    for name, m in model.named_modules():
        for candidate in _get_candidates(m):
            try:
                remove_weight_norm(m, name=candidate)
            except ValueError:
                # remove_weight_norm signals "no weight-norm hook for this name"
                # with ValueError; that is expected for false-positive candidates.
                # Anything else (including KeyboardInterrupt) must propagate.
                continue

            if verbose: print(f'removed weight norm on weight {name}.{candidate}')
|
||||
|
|
110
dnn/torch/osce/utils/softquant.py
Normal file
110
dnn/torch/osce/utils/softquant.py
Normal file
|
@ -0,0 +1,110 @@
|
|||
import torch
|
||||
|
||||
@torch.no_grad()
def compute_optimal_scale(weight):
    """Compute one quantization step size per output row of a 2-D weight matrix.

    For every row the returned scale is the larger of

    * ``max_j |w[j]| / 127`` and
    * ``max_j |w[2j] + w[2j + 1]| / 129`` (adjacent column pairs).

    The constants presumably reflect the int8 range and the pairwise
    accumulation limit of the target kernels -- TODO confirm against the
    C-side quantizer.

    Args:
        weight: tensor of shape ``(n_out, n_in)``; ``n_in`` must be a multiple of 4.

    Returns:
        Tensor of shape ``(n_out,)`` with the per-row scale (no gradient tracking).

    Raises:
        AssertionError: if ``n_in`` is not a multiple of 4.
    """
    n_out, n_in = weight.shape
    assert n_in % 4 == 0

    # Both maxima are reduced along dim=1, i.e. independently per row, so the
    # rows need no zero-padding to a multiple of 8: padded rows could not
    # influence the result and would be sliced off again.
    weight_max_abs = torch.max(torch.abs(weight), dim=1).values
    pair_sums = weight[:, 0 : n_in : 2] + weight[:, 1 : n_in : 2]
    weight_max_sum = torch.max(torch.abs(pair_sums), dim=1).values

    scale_max = weight_max_abs / 127
    scale_sum = weight_max_sum / 129

    return torch.maximum(scale_max, scale_sum)
|
||||
|
||||
@torch.no_grad()
def q_scaled_noise(module, weight):
    """Draw uniform noise in [-0.5, 0.5) scaled per row by ``compute_optimal_scale``.

    The weight is first rearranged into a 2-D matrix depending on the module type:

    * ``Conv1d``: ``weight.permute(0, 2, 1).flatten(1)``
    * ``ConvTranspose1d``: ``weight.permute(2, 1, 0).reshape(k * o, i)``
    * any other module with a 2-D weight: used as is

    The noise is generated in that 2-D layout, multiplied row-wise by the
    scale, and mapped back to the layout of ``weight``.

    Raises:
        ValueError: if the module/weight combination is none of the above.
    """
    if isinstance(module, torch.nn.Conv1d):
        matrix = weight.permute(0, 2, 1).flatten(1)
        dim0, dim1, dim2 = weight.size(0), weight.size(1), weight.size(2)

        def restore(t):
            return t.reshape(dim0, dim2, dim1).permute(0, 2, 1)
    elif isinstance(module, torch.nn.ConvTranspose1d):
        i, o, k = weight.shape
        matrix = weight.permute(2, 1, 0).reshape(k * o, i)

        def restore(t):
            return t.reshape(k, o, i).permute(2, 1, 0)
    elif len(weight.shape) == 2:
        matrix = weight

        def restore(t):
            return t
    else:
        raise ValueError('unknown quantization setting')

    # same three steps for every layout: draw, compute row scales, apply
    uniform = torch.rand_like(matrix) - 0.5
    row_scale = compute_optimal_scale(matrix)
    return restore(uniform * row_scale.unsqueeze(-1))
|
||||
|
||||
class SoftQuant:
    """Forward-hook pair that perturbs module weights with noise during training.

    Before the forward pass (``before=True``) uniform noise is added in place to
    every attribute listed in ``names``; after the forward pass
    (``before=False``) the same noise is subtracted again. Nothing happens
    while ``module.training`` is false.

    Attributes:
        names: attribute names of the tensors to perturb (e.g. ``['weight']``).
        scale: ``None`` to use ``q_scaled_noise`` (per-row scale), otherwise the
            noise is ``scale * max|w| * U(-0.5, 0.5)``.
        quantization_noise: dict name -> noise tensor while a forward pass is in
            flight, ``None`` otherwise.
    """

    def __init__(self, names: list, scale: float) -> None:
        self.names = names
        self.quantization_noise = None
        self.scale = scale

    def __call__(self, module, inputs, *args, before=True):
        """Hook entry point; ``*args`` absorbs the output passed to forward hooks."""
        if not module.training: return

        if before:
            self.quantization_noise = dict()
            for name in self.names:
                weight = getattr(module, name)
                # noise is a constant w.r.t. the loss: build and apply it without autograd
                with torch.no_grad():
                    if self.scale is None:
                        noise = q_scaled_noise(module, weight)
                    else:
                        noise = self.scale * weight.abs().max() * (torch.rand_like(weight) - 0.5)
                    self.quantization_noise[name] = noise
                    weight.data[:] = weight + noise
        else:
            if self.quantization_noise is None:
                # no noise was injected for this forward pass, nothing to undo
                return
            for name in self.names:
                weight = getattr(module, name)
                with torch.no_grad():
                    weight.data[:] = weight - self.quantization_noise[name]
            self.quantization_noise = None

    @staticmethod
    def apply(module, names=['weight'], scale=None):
        """Attach soft-quantization hooks to ``module`` and return the SoftQuant object.

        Both registered hooks carry an ``sqm`` attribute referencing the returned
        object so they can be identified later.

        Raises:
            ValueError: if ``module`` lacks one of the attributes in ``names``.
        """
        for name in names:
            if not hasattr(module, name):
                raise ValueError(
                    f"cannot apply soft quantization: {type(module).__name__} has no attribute '{name}'"
                )

        fn = SoftQuant(names, scale)

        fn_before = lambda *x : fn(*x, before=True)
        fn_after = lambda *x : fn(*x, before=False)
        setattr(fn_before, 'sqm', fn)
        setattr(fn_after, 'sqm', fn)

        module.register_forward_pre_hook(fn_before)
        module.register_forward_hook(fn_after)

        return fn
|
||||
|
||||
|
||||
def soft_quant(module, names=['weight'], scale=None):
    """Attach SoftQuant hooks to ``module`` and return the module itself.

    Thin wrapper around ``SoftQuant.apply`` that makes the call chainable,
    e.g. ``layer = soft_quant(nn.Linear(...))``.
    """
    SoftQuant.apply(module, names, scale)
    return module
|
||||
|
||||
def remove_soft_quant(module, names=['weight']):
    """Remove the SoftQuant pre-forward and forward hooks registered for ``names``.

    A hook is removed when it carries an ``sqm`` attribute that is a
    ``SoftQuant`` instance whose ``names`` equals ``names``.

    Args:
        module: module whose hook tables are edited in place.
        names: attribute-name list the hooks were registered with.

    Returns:
        The same ``module`` for chaining.
    """
    for hooks in (module._forward_pre_hooks, module._forward_hooks):
        # iterate over a snapshot: deleting from the dict while iterating it
        # raises RuntimeError as soon as another hook follows the removed one
        for key, hook in list(hooks.items()):
            sqm = getattr(hook, 'sqm', None)
            if isinstance(sqm, SoftQuant) and sqm.names == names:
                del hooks[key]

    return module
|
|
@ -50,7 +50,11 @@ lace_setup = {
|
|||
'pitch_embedding_dim': 64,
|
||||
'pitch_max': 300,
|
||||
'preemph': 0.85,
|
||||
'skip': 91
|
||||
'skip': 91,
|
||||
'softquant': True,
|
||||
'sparsify': False,
|
||||
'sparsification_density': 0.4,
|
||||
'sparsification_schedule': [10000, 40000, 200]
|
||||
}
|
||||
},
|
||||
'data': {
|
||||
|
@ -63,7 +67,7 @@ lace_setup = {
|
|||
'num_bands_clean_spec': 64,
|
||||
'num_bands_noisy_spec': 18,
|
||||
'noisy_spec_scale': 'opus',
|
||||
'pitch_hangover': 8,
|
||||
'pitch_hangover': 0,
|
||||
},
|
||||
'training': {
|
||||
'batch_size': 256,
|
||||
|
@ -106,7 +110,11 @@ nolace_setup = {
|
|||
'pitch_embedding_dim': 64,
|
||||
'pitch_max': 300,
|
||||
'preemph': 0.85,
|
||||
'skip': 91
|
||||
'skip': 91,
|
||||
'softquant': True,
|
||||
'sparsify': False,
|
||||
'sparsification_density': 0.4,
|
||||
'sparsification_schedule': [10000, 40000, 200]
|
||||
}
|
||||
},
|
||||
'data': {
|
||||
|
@ -119,7 +127,7 @@ nolace_setup = {
|
|||
'num_bands_clean_spec': 64,
|
||||
'num_bands_noisy_spec': 18,
|
||||
'noisy_spec_scale': 'opus',
|
||||
'pitch_hangover': 8,
|
||||
'pitch_hangover': 0,
|
||||
},
|
||||
'training': {
|
||||
'batch_size': 256,
|
||||
|
@ -160,7 +168,11 @@ nolace_setup_adv = {
|
|||
'pitch_embedding_dim': 64,
|
||||
'pitch_max': 300,
|
||||
'preemph': 0.85,
|
||||
'skip': 91
|
||||
'skip': 91,
|
||||
'softquant': True,
|
||||
'sparsify': False,
|
||||
'sparsification_density': 0.4,
|
||||
'sparsification_schedule': [0, 0, 200]
|
||||
}
|
||||
},
|
||||
'data': {
|
||||
|
@ -173,7 +185,7 @@ nolace_setup_adv = {
|
|||
'num_bands_clean_spec': 64,
|
||||
'num_bands_noisy_spec': 18,
|
||||
'noisy_spec_scale': 'opus',
|
||||
'pitch_hangover': 8,
|
||||
'pitch_hangover': 0,
|
||||
},
|
||||
'discriminator': {
|
||||
'args': [],
|
||||
|
|
|
@ -282,7 +282,8 @@ def print_conv1d_layer(writer : CWriter,
|
|||
bias : np.ndarray,
|
||||
scale=1/128,
|
||||
format : str = 'torch',
|
||||
quantize=False):
|
||||
quantize=False,
|
||||
sparse=False):
|
||||
|
||||
|
||||
if format == "torch":
|
||||
|
@ -290,7 +291,7 @@ def print_conv1d_layer(writer : CWriter,
|
|||
weight = np.transpose(weight, (2, 1, 0))
|
||||
|
||||
lin_weight = np.reshape(weight, (-1, weight.shape[-1]))
|
||||
print_linear_layer(writer, name, lin_weight, bias, scale=scale, sparse=False, diagonal=False, quantize=quantize)
|
||||
print_linear_layer(writer, name, lin_weight, bias, scale=scale, sparse=sparse, diagonal=False, quantize=quantize)
|
||||
|
||||
|
||||
writer.header.write(f"\n#define {name.upper()}_OUT_SIZE {weight.shape[2]}\n")
|
||||
|
@ -369,7 +370,8 @@ def print_tconv1d_layer(writer : CWriter,
|
|||
bias : np.ndarray,
|
||||
stride: int,
|
||||
scale=1/128,
|
||||
quantize=False):
|
||||
quantize=False,
|
||||
sparse=False):
|
||||
|
||||
in_channels, out_channels, kernel_size = weight.shape
|
||||
|
||||
|
@ -377,7 +379,7 @@ def print_tconv1d_layer(writer : CWriter,
|
|||
linear_weight = weight.transpose(2, 1, 0).reshape(kernel_size * out_channels, in_channels).transpose(1, 0)
|
||||
linear_bias = np.repeat(bias[np.newaxis, :], kernel_size, 0).flatten()
|
||||
|
||||
print_linear_layer(writer, name, linear_weight, linear_bias, scale=scale, quantize=quantize)
|
||||
print_linear_layer(writer, name, linear_weight, linear_bias, scale=scale, quantize=quantize, sparse=sparse)
|
||||
|
||||
writer.header.write(f"\n#define {name.upper()}_KERNEL_SIZE {kernel_size}\n")
|
||||
writer.header.write(f"\n#define {name.upper()}_STRIDE {stride}\n")
|
||||
|
|
|
@ -153,7 +153,7 @@ def dump_torch_adaptive_comb1d_weights(where, adaconv, name='adaconv', scale=1/1
|
|||
np.save(where, 'weight_global_gain.npy', w_global_gain)
|
||||
np.save(where, 'bias_global_gain.npy', b_global_gain)
|
||||
|
||||
def dump_torch_tdshaper(where, shaper, name='tdshaper'):
|
||||
def dump_torch_tdshaper(where, shaper, name='tdshaper', quantize=False, scale=1/128):
|
||||
|
||||
if isinstance(where, CWriter):
|
||||
where.header.write(f"""
|
||||
|
@ -165,7 +165,8 @@ def dump_torch_tdshaper(where, shaper, name='tdshaper'):
|
|||
"""
|
||||
)
|
||||
|
||||
dump_torch_conv1d_weights(where, shaper.feature_alpha1, name + "_alpha1")
|
||||
dump_torch_conv1d_weights(where, shaper.feature_alpha1_f, name + "_alpha1_f", quantize=quantize, scale=scale)
|
||||
dump_torch_conv1d_weights(where, shaper.feature_alpha1_t, name + "_alpha1_t")
|
||||
dump_torch_conv1d_weights(where, shaper.feature_alpha2, name + "_alpha2")
|
||||
|
||||
if shaper.innovate:
|
||||
|
@ -274,7 +275,7 @@ def load_torch_dense_weights(where, dense):
|
|||
dense.bias.set_(torch.from_numpy(b))
|
||||
|
||||
|
||||
def dump_torch_conv1d_weights(where, conv, name='conv', scale=1/128, quantize=False):
|
||||
def dump_torch_conv1d_weights(where, conv, name='conv', scale=1/128, quantize=False, sparse=False):
|
||||
|
||||
w = conv.weight.detach().cpu().numpy().copy()
|
||||
if conv.bias is None:
|
||||
|
@ -284,7 +285,7 @@ def dump_torch_conv1d_weights(where, conv, name='conv', scale=1/128, quantize=Fa
|
|||
|
||||
if isinstance(where, CWriter):
|
||||
|
||||
return print_conv1d_layer(where, name, w, b, scale=scale, format='torch', quantize=quantize)
|
||||
return print_conv1d_layer(where, name, w, b, scale=scale, format='torch', quantize=quantize, sparse=sparse)
|
||||
else:
|
||||
os.makedirs(where, exist_ok=True)
|
||||
|
||||
|
@ -304,7 +305,7 @@ def load_torch_conv1d_weights(where, conv):
|
|||
conv.bias.set_(torch.from_numpy(b))
|
||||
|
||||
|
||||
def dump_torch_tconv1d_weights(where, conv, name='conv', scale=1/128, quantize=False):
|
||||
def dump_torch_tconv1d_weights(where, conv, name='conv', scale=1/128, quantize=False, sparse=False):
|
||||
|
||||
w = conv.weight.detach().cpu().numpy().copy()
|
||||
if conv.bias is None:
|
||||
|
@ -314,7 +315,7 @@ def dump_torch_tconv1d_weights(where, conv, name='conv', scale=1/128, quantize=F
|
|||
|
||||
if isinstance(where, CWriter):
|
||||
|
||||
return print_tconv1d_layer(where, name, w, b, conv.stride[0], scale=scale, quantize=quantize)
|
||||
return print_tconv1d_layer(where, name, w, b, conv.stride[0], scale=scale, quantize=quantize, sparse=sparse)
|
||||
else:
|
||||
os.makedirs(where, exist_ok=True)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue