diff --git a/autogen.sh b/autogen.sh
index 2cac2083..0c6c7f90 100755
--- a/autogen.sh
+++ b/autogen.sh
@@ -9,7 +9,7 @@ set -e
srcdir=`dirname $0`
test -n "$srcdir" && cd "$srcdir"
-dnn/download_model.sh caca188
+dnn/download_model.sh 88477f4
echo "Updating build configuration files, please wait...."
diff --git a/dnn/nndsp.c b/dnn/nndsp.c
index bfbf5735..caa77038 100644
--- a/dnn/nndsp.c
+++ b/dnn/nndsp.c
@@ -340,7 +340,8 @@ void adashape_process_frame(
float *x_out,
const float *x_in,
const float *features,
- const LinearLayer *alpha1,
+ const LinearLayer *alpha1f,
+ const LinearLayer *alpha1t,
const LinearLayer *alpha2,
int feature_dim,
int frame_size,
@@ -350,6 +351,7 @@ void adashape_process_frame(
{
float in_buffer[ADASHAPE_MAX_INPUT_DIM + ADASHAPE_MAX_FRAME_SIZE];
float out_buffer[ADASHAPE_MAX_FRAME_SIZE];
+ float tmp_buffer[ADASHAPE_MAX_FRAME_SIZE];
int i, k;
int tenv_size;
float mean;
@@ -389,14 +391,16 @@ void adashape_process_frame(
#ifdef DEBUG_NNDSP
print_float_vector("alpha1_in", in_buffer, feature_dim + tenv_size + 1);
#endif
- compute_generic_conv1d(alpha1, out_buffer, hAdaShape->conv_alpha1_state, in_buffer, feature_dim + tenv_size + 1, ACTIVATION_LINEAR, arch);
+ compute_generic_conv1d(alpha1f, out_buffer, hAdaShape->conv_alpha1f_state, in_buffer, feature_dim, ACTIVATION_LINEAR, arch);
+ compute_generic_conv1d(alpha1t, tmp_buffer, hAdaShape->conv_alpha1t_state, tenv, tenv_size + 1, ACTIVATION_LINEAR, arch);
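+    /* alpha1 is split into a feature branch (alpha1f) and a temporal-envelope branch
+       (alpha1t); their outputs are summed in the activation loop below */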
#ifdef DEBUG_NNDSP
print_float_vector("alpha1_out", out_buffer, frame_size);
#endif
/* compute leaky ReLU by hand. ToDo: try tanh activation */
for (i = 0; i < frame_size; i ++)
{
- in_buffer[i] = out_buffer[i] >= 0 ? out_buffer[i] : 0.2f * out_buffer[i];
+ float tmp = out_buffer[i] + tmp_buffer[i];
+         in_buffer[i] = tmp >= 0 ? tmp : 0.2f * tmp;
}
#ifdef DEBUG_NNDSP
print_float_vector("post_alpha1", in_buffer, frame_size);
diff --git a/dnn/nndsp.h b/dnn/nndsp.h
index f00094b6..6021250f 100644
--- a/dnn/nndsp.h
+++ b/dnn/nndsp.h
@@ -71,7 +71,8 @@ typedef struct {
typedef struct {
- float conv_alpha1_state[ADASHAPE_MAX_INPUT_DIM];
+ float conv_alpha1f_state[ADASHAPE_MAX_INPUT_DIM];
+ float conv_alpha1t_state[ADASHAPE_MAX_INPUT_DIM];
float conv_alpha2_state[ADASHAPE_MAX_FRAME_SIZE];
} AdaShapeState;
@@ -130,7 +131,8 @@ void adashape_process_frame(
float *x_out,
const float *x_in,
const float *features,
- const LinearLayer *alpha1,
+ const LinearLayer *alpha1f,
+ const LinearLayer *alpha1t,
const LinearLayer *alpha2,
int feature_dim,
int frame_size,
diff --git a/dnn/osce.c b/dnn/osce.c
index 2a78a6ea..aca45500 100644
--- a/dnn/osce.c
+++ b/dnn/osce.c
@@ -155,7 +155,7 @@ static void lace_feature_net(
&hLACE->layers.lace_fnet_tconv,
output_buffer,
input_buffer,
- ACTIVATION_LINEAR,
+ ACTIVATION_TANH,
arch
);
@@ -426,7 +426,7 @@ static void nolace_feature_net(
&hNoLACE->layers.nolace_fnet_tconv,
output_buffer,
input_buffer,
- ACTIVATION_LINEAR,
+ ACTIVATION_TANH,
arch
);
@@ -633,7 +633,8 @@ static void nolace_process_20ms_frame(
x_buffer2 + i_subframe * NOLACE_AF1_OUT_CHANNELS * NOLACE_FRAME_SIZE + NOLACE_FRAME_SIZE,
x_buffer2 + i_subframe * NOLACE_AF1_OUT_CHANNELS * NOLACE_FRAME_SIZE + NOLACE_FRAME_SIZE,
feature_buffer + i_subframe * NOLACE_COND_DIM,
- &layers->nolace_tdshape1_alpha1,
+ &layers->nolace_tdshape1_alpha1_f,
+ &layers->nolace_tdshape1_alpha1_t,
&layers->nolace_tdshape1_alpha2,
NOLACE_TDSHAPE1_FEATURE_DIM,
NOLACE_TDSHAPE1_FRAME_SIZE,
@@ -688,7 +689,8 @@ static void nolace_process_20ms_frame(
x_buffer1 + i_subframe * NOLACE_AF2_OUT_CHANNELS * NOLACE_FRAME_SIZE + NOLACE_FRAME_SIZE,
x_buffer1 + i_subframe * NOLACE_AF2_OUT_CHANNELS * NOLACE_FRAME_SIZE + NOLACE_FRAME_SIZE,
feature_buffer + i_subframe * NOLACE_COND_DIM,
- &layers->nolace_tdshape2_alpha1,
+ &layers->nolace_tdshape2_alpha1_f,
+ &layers->nolace_tdshape2_alpha1_t,
&layers->nolace_tdshape2_alpha2,
NOLACE_TDSHAPE2_FEATURE_DIM,
NOLACE_TDSHAPE2_FRAME_SIZE,
@@ -739,7 +741,8 @@ static void nolace_process_20ms_frame(
x_buffer2 + i_subframe * NOLACE_AF3_OUT_CHANNELS * NOLACE_FRAME_SIZE + NOLACE_FRAME_SIZE,
x_buffer2 + i_subframe * NOLACE_AF3_OUT_CHANNELS * NOLACE_FRAME_SIZE + NOLACE_FRAME_SIZE,
feature_buffer + i_subframe * NOLACE_COND_DIM,
- &layers->nolace_tdshape3_alpha1,
+ &layers->nolace_tdshape3_alpha1_f,
+ &layers->nolace_tdshape3_alpha1_t,
&layers->nolace_tdshape3_alpha2,
NOLACE_TDSHAPE3_FEATURE_DIM,
NOLACE_TDSHAPE3_FRAME_SIZE,
@@ -884,7 +887,7 @@ int osce_load_models(OSCEModel *model, const unsigned char *data, int len)
if (ret == 0) {ret = init_lace(&model->lace, list);}
#endif
-#ifndef DISABLE_LACE
+#ifndef DISABLE_NOLACE
if (ret == 0) {ret = init_nolace(&model->nolace, list);}
#endif
@@ -898,7 +901,7 @@ int osce_load_models(OSCEModel *model, const unsigned char *data, int len)
if (ret == 0) {ret = init_lace(&model->lace, lacelayers_arrays);}
#endif
-#ifndef DISABLE_LACE
+#ifndef DISABLE_NOLACE
if (ret == 0) {ret = init_nolace(&model->nolace, nolacelayers_arrays);}
#endif
diff --git a/dnn/osce_config.h b/dnn/osce_config.h
index de94fe2f..1543b803 100644
--- a/dnn/osce_config.h
+++ b/dnn/osce_config.h
@@ -41,7 +41,7 @@
#define OSCE_PREEMPH 0.85f
-#define OSCE_PITCH_HANGOVER 8
+#define OSCE_PITCH_HANGOVER 0
#define OSCE_CLEAN_SPEC_START 0
#define OSCE_CLEAN_SPEC_LENGTH 64
diff --git a/dnn/osce_features.c b/dnn/osce_features.c
index 0466f132..bcd48016 100644
--- a/dnn/osce_features.c
+++ b/dnn/osce_features.c
@@ -296,6 +296,7 @@ static void calculate_acorr(float *acorr, float *signal, int lag)
static int pitch_postprocessing(OSCEFeatureState *psFeatures, int lag, int type)
{
int new_lag;
+ int modulus;
#ifdef OSCE_HANGOVER_BUGFIX
#define TESTBIT 1
@@ -303,6 +304,9 @@ static int pitch_postprocessing(OSCEFeatureState *psFeatures, int lag, int type)
#define TESTBIT 0
#endif
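+    /* keep the modulus positive so the counter updates below stay well-defined
+       even when OSCE_PITCH_HANGOVER is 0 (hangover disabled) */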
+ modulus = OSCE_PITCH_HANGOVER;
+ if (modulus == 0) modulus ++;
+
/* hangover is currently disabled to reflect a bug in the python code. ToDo: re-evaluate hangover */
if (type != TYPE_VOICED && psFeatures->last_type == TYPE_VOICED && TESTBIT)
/* enter hangover */
@@ -311,14 +315,14 @@ static int pitch_postprocessing(OSCEFeatureState *psFeatures, int lag, int type)
if (psFeatures->pitch_hangover_count < OSCE_PITCH_HANGOVER)
{
new_lag = psFeatures->last_lag;
- psFeatures->pitch_hangover_count = (psFeatures->pitch_hangover_count + 1) % OSCE_PITCH_HANGOVER;
+ psFeatures->pitch_hangover_count = (psFeatures->pitch_hangover_count + 1) % modulus;
}
}
else if (type != TYPE_VOICED && psFeatures->pitch_hangover_count && TESTBIT)
/* continue hangover */
{
new_lag = psFeatures->last_lag;
- psFeatures->pitch_hangover_count = (psFeatures->pitch_hangover_count + 1) % OSCE_PITCH_HANGOVER;
+ psFeatures->pitch_hangover_count = (psFeatures->pitch_hangover_count + 1) % modulus;
}
else if (type != TYPE_VOICED)
/* unvoiced frame after hangover */
@@ -376,11 +380,7 @@ void osce_calculate_features(
/* smooth bit count */
psFeatures->numbits_smooth = 0.9f * psFeatures->numbits_smooth + 0.1f * num_bits;
numbits[0] = num_bits;
-#ifdef OSCE_NUMBITS_BUGFIX
numbits[1] = psFeatures->numbits_smooth;
-#else
- numbits[1] = num_bits;
-#endif
for (n = 0; n < num_samples; n++)
{
diff --git a/dnn/torch/dnntools/dnntools/__init__.py b/dnn/torch/dnntools/dnntools/__init__.py
new file mode 100644
index 00000000..117597ab
--- /dev/null
+++ b/dnn/torch/dnntools/dnntools/__init__.py
@@ -0,0 +1,2 @@
+from . import quantization
+from . import sparsification
\ No newline at end of file
diff --git a/dnn/torch/dnntools/dnntools/quantization/__init__.py b/dnn/torch/dnntools/dnntools/quantization/__init__.py
new file mode 100644
index 00000000..3b46a2e0
--- /dev/null
+++ b/dnn/torch/dnntools/dnntools/quantization/__init__.py
@@ -0,0 +1 @@
+from .softquant import soft_quant, remove_soft_quant
\ No newline at end of file
diff --git a/dnn/torch/dnntools/dnntools/quantization/softquant.py b/dnn/torch/dnntools/dnntools/quantization/softquant.py
new file mode 100644
index 00000000..877c6450
--- /dev/null
+++ b/dnn/torch/dnntools/dnntools/quantization/softquant.py
@@ -0,0 +1,113 @@
+import torch
+
+@torch.no_grad()
+def compute_optimal_scale(weight):
+ with torch.no_grad():
+ n_out, n_in = weight.shape
+ assert n_in % 4 == 0
+ if n_out % 8:
+            # pad the rows up to the next multiple of 8
+            pad = 8 - n_out % 8
+ weight = torch.cat((weight, torch.zeros((pad, n_in), dtype=weight.dtype, device=weight.device)), dim=0)
+
+ weight_max_abs, _ = torch.max(torch.abs(weight), dim=1)
+ weight_max_sum, _ = torch.max(torch.abs(weight[:, : n_in : 2] + weight[:, 1 : n_in : 2]), dim=1)
+ scale_max = weight_max_abs / 127
+ scale_sum = weight_max_sum / 129
+
+ scale = torch.maximum(scale_max, scale_sum)
+
+ return scale[:n_out]
+
+@torch.no_grad()
+def q_scaled_noise(module, weight):
+ if isinstance(module, torch.nn.Conv1d):
+ w = weight.permute(0, 2, 1).flatten(1)
+ noise = torch.rand_like(w) - 0.5
+ noise[w == 0] = 0 # ignore zero entries from sparsification
+ scale = compute_optimal_scale(w)
+ noise = noise * scale.unsqueeze(-1)
+ noise = noise.reshape(weight.size(0), weight.size(2), weight.size(1)).permute(0, 2, 1)
+ elif isinstance(module, torch.nn.ConvTranspose1d):
+ i, o, k = weight.shape
+ w = weight.permute(2, 1, 0).reshape(k * o, i)
+ noise = torch.rand_like(w) - 0.5
+ noise[w == 0] = 0 # ignore zero entries from sparsification
+ scale = compute_optimal_scale(w)
+ noise = noise * scale.unsqueeze(-1)
+ noise = noise.reshape(k, o, i).permute(2, 1, 0)
+ elif len(weight.shape) == 2:
+ noise = torch.rand_like(weight) - 0.5
+ noise[weight == 0] = 0 # ignore zero entries from sparsification
+ scale = compute_optimal_scale(weight)
+ noise = noise * scale.unsqueeze(-1)
+ else:
+ raise ValueError('unknown quantization setting')
+
+ return noise
+
+class SoftQuant:
+    names: list
+
+    def __init__(self, names: list, scale: float) -> None:
+ self.names = names
+ self.quantization_noise = None
+ self.scale = scale
+
+ def __call__(self, module, inputs, *args, before=True):
+ if not module.training: return
+
+ if before:
+ self.quantization_noise = dict()
+ for name in self.names:
+ weight = getattr(module, name)
+ if self.scale is None:
+ self.quantization_noise[name] = q_scaled_noise(module, weight)
+ else:
+ self.quantization_noise[name] = \
+ self.scale * (torch.rand_like(weight) - 0.5)
+ with torch.no_grad():
+ weight.data[:] = weight + self.quantization_noise[name]
+ else:
+ for name in self.names:
+ weight = getattr(module, name)
+ with torch.no_grad():
+ weight.data[:] = weight - self.quantization_noise[name]
+ self.quantization_noise = None
+
+    @staticmethod
+    def apply(module, names=['weight'], scale=None):
+ fn = SoftQuant(names, scale)
+
+ for name in names:
+ if not hasattr(module, name):
+                raise ValueError(f"module has no attribute named {name}")
+
+ fn_before = lambda *x : fn(*x, before=True)
+ fn_after = lambda *x : fn(*x, before=False)
+ setattr(fn_before, 'sqm', fn)
+ setattr(fn_after, 'sqm', fn)
+
+        module.register_forward_pre_hook(fn_before)
+        module.register_forward_hook(fn_after)
+
+ return fn
+
+
+def soft_quant(module, names=['weight'], scale=None):
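+    """ adds soft-quantization hooks to module
+
+    Example:
+    --------
+    >>> import torch
+    >>> linear = torch.nn.Linear(8, 16)
+    >>> linear = soft_quant(linear)         # adds quantization noise to the weight during training forward passes
+    >>> linear = remove_soft_quant(linear)  # removes the hooks again
+    """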
+ fn = SoftQuant.apply(module, names, scale)
+ return module
+
+def remove_soft_quant(module, names=['weight']):
+    for k, hook in list(module._forward_pre_hooks.items()):
+        if hasattr(hook, 'sqm'):
+            if isinstance(hook.sqm, SoftQuant) and hook.sqm.names == names:
+                del module._forward_pre_hooks[k]
+    for k, hook in list(module._forward_hooks.items()):
+        if hasattr(hook, 'sqm'):
+            if isinstance(hook.sqm, SoftQuant) and hook.sqm.names == names:
+                del module._forward_hooks[k]
+
+ return module
\ No newline at end of file
diff --git a/dnn/torch/dnntools/dnntools/relegance/__init__.py b/dnn/torch/dnntools/dnntools/relegance/__init__.py
new file mode 100644
index 00000000..cee0143b
--- /dev/null
+++ b/dnn/torch/dnntools/dnntools/relegance/__init__.py
@@ -0,0 +1,2 @@
+from .relegance import relegance_gradient_weighting, relegance_create_tconv_kernel, relegance_map_relevance_to_input_domain, relegance_resize_relevance_to_input_size
+from .meta_critic import MetaCritic
\ No newline at end of file
diff --git a/dnn/torch/dnntools/dnntools/relegance/meta_critic.py b/dnn/torch/dnntools/dnntools/relegance/meta_critic.py
new file mode 100644
index 00000000..1af0f8ff
--- /dev/null
+++ b/dnn/torch/dnntools/dnntools/relegance/meta_critic.py
@@ -0,0 +1,85 @@
+"""
+/* Copyright (c) 2023 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import torch
+
+class MetaCritic():
+ def __init__(self, normalize=False, gamma=0.9, beta=0.0, joint_stats=False):
+ """ Class for assessing relevance of discriminator scores
+
+        Args:
+            normalize (bool, optional): whether to normalize score differences by tracked discriminator statistics. Defaults to False.
+            gamma (float, optional): update rate for tracking discriminator stats. Defaults to 0.9.
+            beta (float, optional): minimum confidence related threshold. Defaults to 0.0.
+            joint_stats (bool, optional): if True, one set of statistics is shared across all discriminators. Defaults to False.
+        """
+ self.normalize = normalize
+ self.gamma = gamma
+ self.beta = beta
+ self.joint_stats = joint_stats
+
+ self.disc_stats = dict()
+
+ def __call__(self, disc_id, real_scores, generated_scores):
+ """ calculates relevance from normalized scores
+
+ Args:
+ disc_id (any valid key): id for tracking discriminator statistics
+ real_scores (torch.tensor): scores for real data
+ generated_scores (torch.tensor): scores for generated data; expecting device to match real_scores.device
+
+ Returns:
+ torch.tensor: output-domain relevance
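+
+        Example (shapes are illustrative):
+        --------
+        >>> critic = MetaCritic(normalize=True)
+        >>> real_scores = torch.randn(4, 100)
+        >>> generated_scores = torch.randn(4, 100)
+        >>> relevance = critic('disc0', real_scores, generated_scores)  # same shape as the scores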
+ """
+
+ if self.normalize:
+ real_std = torch.std(real_scores.detach()).cpu().item()
+ gen_std = torch.std(generated_scores.detach()).cpu().item()
+ std = (real_std**2 + gen_std**2) ** .5
+ mean = torch.mean(real_scores.detach()).cpu().item() - torch.mean(generated_scores.detach()).cpu().item()
+
+ key = 0 if self.joint_stats else disc_id
+
+ if key in self.disc_stats:
+ self.disc_stats[key]['std'] = self.gamma * self.disc_stats[key]['std'] + (1 - self.gamma) * std
+ self.disc_stats[key]['mean'] = self.gamma * self.disc_stats[key]['mean'] + (1 - self.gamma) * mean
+ else:
+ self.disc_stats[key] = {
+ 'std': std + 1e-5,
+ 'mean': mean
+ }
+
+ std = self.disc_stats[key]['std']
+ mean = self.disc_stats[key]['mean']
+ else:
+ mean, std = 0, 1
+
+ relevance = torch.relu((real_scores - generated_scores - mean) / std + mean - self.beta)
+
+ if False: print(f"relevance({disc_id}): {relevance.min()=} {relevance.max()=} {relevance.mean()=}")
+
+ return relevance
\ No newline at end of file
diff --git a/dnn/torch/dnntools/dnntools/relegance/relegance.py b/dnn/torch/dnntools/dnntools/relegance/relegance.py
new file mode 100644
index 00000000..29c5be23
--- /dev/null
+++ b/dnn/torch/dnntools/dnntools/relegance/relegance.py
@@ -0,0 +1,449 @@
+"""
+/* Copyright (c) 2023 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import torch
+import torch.nn.functional as F
+
+
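+# broadcasting helper: a view shape of the given length that is 1 everywhere
+# except at position index, which is set to -1 (inferred size)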
+def view_one_hot(index, length):
+ vec = length * [1]
+ vec[index] = -1
+ return vec
+
+def create_smoothing_kernel(widths, gamma=1.5):
+ """ creates a truncated gaussian smoothing kernel for the given widths
+
+ Parameters:
+ -----------
+ widths: list[Int] or torch.LongTensor
+ specifies the shape of the smoothing kernel, entries must be > 0.
+
+ gamma: float, optional
+ decay factor for gaussian relative to kernel size
+
+ Returns:
+ --------
+ kernel: torch.FloatTensor
+ """
+
+ widths = torch.LongTensor(widths)
+ num_dims = len(widths)
+
+ assert(widths.min() > 0)
+
+ centers = widths.float() / 2 - 0.5
+ sigmas = gamma * (centers + 1)
+
+    vals = [((torch.arange(widths[i]) - centers[i]) / sigmas[i]) ** 2 for i in range(num_dims)]
+    vals = sum([vals[i].view(view_one_hot(i, num_dims)) for i in range(num_dims)])
+
+ kernel = torch.exp(- vals)
+ kernel = kernel / kernel.sum()
+
+ return kernel
+
+
+def create_partition_kernel(widths, strides):
+ """ creates a partition kernel for mapping a convolutional network output back to the input domain
+
+ Given a fully convolutional network with receptive field of shape widths and the given strides, this
+    function constructs an interpolation kernel whose translations by multiples of the given strides form
+    a partition of one on the input domain.
+
+    Parameters:
+ ----------
+ widths: list[Int] or torch.LongTensor
+ shape of receptive field
+
+ strides: list[Int] or torch.LongTensor
+ total strides of convolutional network
+
+ Returns:
+ kernel: torch.FloatTensor
+ """
+
+ num_dims = len(widths)
+ assert num_dims == len(strides) and num_dims in {1, 2, 3}
+
+ convs = {1 : F.conv1d, 2 : F.conv2d, 3 : F.conv3d}
+
+ widths = torch.LongTensor(widths)
+ strides = torch.LongTensor(strides)
+
+ proto_kernel = torch.ones(torch.minimum(strides, widths).tolist())
+
+ # create interpolation kernel eta
+ eta_widths = widths - strides + 1
+ if eta_widths.min() <= 0:
+ print("[create_partition_kernel] warning: receptive field does not cover input domain")
+ eta_widths = torch.maximum(eta_widths, torch.ones_like(eta_widths))
+
+
+ eta = create_smoothing_kernel(eta_widths).view(1, 1, *eta_widths.tolist())
+
+ padding = torch.repeat_interleave(eta_widths - 1, 2, 0).tolist()[::-1] # ordering of dimensions for padding and convolution functions is reversed in torch
+ padded_proto_kernel = F.pad(proto_kernel, padding)
+ padded_proto_kernel = padded_proto_kernel.view(1, 1, *padded_proto_kernel.shape)
+ kernel = convs[num_dims](padded_proto_kernel, eta)
+
+ return kernel
+
+
+def receptive_field(conv_model, input_shape, output_position):
+ """ estimates boundaries of receptive field connected to output_position via autograd
+
+ Parameters:
+ -----------
+ conv_model: nn.Module or autograd function
+ function or model implementing fully convolutional model
+
+ input_shape: List[Int]
+ input shape ignoring batch dimension, i.e. [num_channels, dim1, dim2, ...]
+
+ output_position: List[Int]
+ output position for which the receptive field is determined; the function raises an exception
+ if output_position is out of bounds for the given input_shape.
+
+ Returns:
+ --------
+ low: List[Int]
+ start indices of receptive field
+
+ high: List[Int]
+ stop indices of receptive field
+
+ """
+
+ x = torch.randn((1,) + tuple(input_shape), requires_grad=True)
+ y = conv_model(x)
+
+ # collapse channels and remove batch dimension
+ y = torch.sum(y, 1)[0]
+
+ # create mask
+ mask = torch.zeros_like(y)
+ index = [torch.tensor(i) for i in output_position]
+ try:
+ mask.index_put_(index, torch.tensor(1, dtype=mask.dtype))
+ except IndexError:
+ raise ValueError('output_position out of bounds')
+
+ (mask * y).sum().backward()
+
+ # sum over channels and remove batch dimension
+ grad = torch.sum(x.grad, dim=1)[0]
+ tmp = torch.nonzero(grad, as_tuple=True)
+ low = [t.min().item() for t in tmp]
+ high = [t.max().item() for t in tmp]
+
+ return low, high
+
+def estimate_conv_parameters(model, num_channels, num_dims, width, max_stride=10):
+ """ attempts to estimate receptive field size, strides and left paddings for given model
+
+
+ Parameters:
+ -----------
+ model: nn.Module or autograd function
+ fully convolutional model for which parameters are estimated
+
+ num_channels: Int
+ number of input channels for model
+
+ num_dims: Int
+ number of input dimensions for model (without channel dimension)
+
+ width: Int
+ width of the input tensor (a hyper-square) on which the receptive fields are derived via autograd
+
+ max_stride: Int, optional
+        assumed maximal stride of the model in any dimension; when set too low, the function may fail
+        for any value of width
+
+ Returns:
+ --------
+ receptive_field_size: List[Int]
+        receptive field size in all dimensions
+
+ strides: List[Int]
+ stride in all dimensions
+
+ left_paddings: List[Int]
+ left padding in all dimensions; this is relevant for aligning the receptive field on the input plane
+
+ Raises:
+ -------
+ ValueError, KeyError
+
+ """
+
+ input_shape = [num_channels] + num_dims * [width]
+ output_position1 = num_dims * [width // (2 * max_stride)]
+ output_position2 = num_dims * [width // (2 * max_stride) + 1]
+
+ low1, high1 = receptive_field(model, input_shape, output_position1)
+ low2, high2 = receptive_field(model, input_shape, output_position2)
+
+ widths1 = [h - l + 1 for l, h in zip(low1, high1)]
+ widths2 = [h - l + 1 for l, h in zip(low2, high2)]
+
+ if not all([w1 - w2 == 0 for w1, w2 in zip(widths1, widths2)]) or not all([l1 != l2 for l1, l2 in zip(low1, low2)]):
+ raise ValueError("[estimate_strides]: widths to small to determine strides")
+
+ receptive_field_size = widths1
+ strides = [l2 - l1 for l1, l2 in zip(low1, low2)]
+ left_paddings = [s * p - l for l, s, p in zip(low1, strides, output_position1)]
+
+ return receptive_field_size, strides, left_paddings
+
+def inspect_conv_model(model, num_channels, num_dims, max_width=10000, width_hint=None, stride_hint=None, verbose=False):
+ """ determines size of receptive field, strides and padding probabilistically
+
+
+ Parameters:
+ -----------
+ model: nn.Module or autograd function
+ fully convolutional model for which parameters are estimated
+
+ num_channels: Int
+ number of input channels for model
+
+ num_dims: Int
+ number of input dimensions for model (without channel dimension)
+
+ max_width: Int
+        maximum width of the input tensor (a hyper-square) on which the receptive fields are derived via autograd
+
+    width_hint: Int, optional
+        hint at the maximal receptive field width, used as starting point for the search
+
+    stride_hint: Int, optional
+        hint at the maximal stride, used as starting point for the search
+
+    verbose: bool, optional
+        if true, the function prints parameters for individual trials
+
+ Returns:
+ --------
+ receptive_field_size: List[Int]
+        receptive field size in all dimensions
+
+ strides: List[Int]
+ stride in all dimensions
+
+ left_paddings: List[Int]
+ left padding in all dimensions; this is relevant for aligning the receptive field on the input plane
+
+ Raises:
+ -------
+ ValueError
+
+ """
+
+ max_stride = max_width // 2
+ stride = max_stride // 100
+ width = max_width // 100
+
+ if width_hint is not None: width = 2 * width_hint
+ if stride_hint is not None: stride = stride_hint
+
+ did_it = False
+ while width < max_width and stride < max_stride:
+ try:
+ if verbose: print(f"[inspect_conv_model] trying parameters {width=}, {stride=}")
+ receptive_field_size, strides, left_paddings = estimate_conv_parameters(model, num_channels, num_dims, width, stride)
+ did_it = True
+        except Exception:
+            # the attempt may fail for several reasons, e.g. a width too small for the model; retry with larger parameters
+            pass
+
+ if did_it: break
+
+ width *= 2
+ if width >= max_width and stride < max_stride:
+ stride *= 2
+ width = 2 * stride
+
+ if not did_it:
+ raise ValueError(f'could not determine conv parameter with given max_width={max_width}')
+
+ return receptive_field_size, strides, left_paddings
+
+
+class GradWeight(torch.autograd.Function):
+ def __init__(self):
+ super().__init__()
+
+ @staticmethod
+ def forward(ctx, x, weight):
+ ctx.save_for_backward(weight)
+ return x.clone()
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ weight, = ctx.saved_tensors
+
+ grad_input = grad_output * weight
+
+ return grad_input, None
+
+
+# API
+
+def relegance_gradient_weighting(x, weight):
+ """
+
+ Args:
+ x (torch.tensor): input tensor
+ weight (torch.tensor or None): weight tensor for gradients of x; if None, no gradient weighting will be applied in backward pass
+
+ Returns:
+ torch.tensor: the unmodified input tensor x
+    """
+ if weight is None:
+ return x
+ else:
+ return GradWeight.apply(x, weight)
+
+
+
+def relegance_create_tconv_kernel(model, num_channels, num_dims, width_hint=None, stride_hint=None, verbose=False):
+ """ creates parameters for mapping back output domain relevance to input tomain
+
+ Args:
+ model (nn.Module or autograd.Function): fully convolutional model
+ num_channels (int): number of input channels to model
+ num_dims (int): number of input dimensions of model (without channel and batch dimension)
+ width_hint(int or None): optional hint at maximal width of receptive field
+ stride_hint(int or None): optional hint at maximal stride
+
+ Returns:
+ dict: contains kernel, kernel dimensions, strides and left paddings for transposed convolution
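+
+        Example (sketch; a single strided Conv1d stands in for the fully convolutional model):
+        >>> import torch
+        >>> disc = torch.nn.Conv1d(1, 16, kernel_size=15, stride=4)
+        >>> params = relegance_create_tconv_kernel(disc, num_channels=1, num_dims=1)
+        >>> sorted(params.keys())
+        ['kernel', 'left_padding', 'num_dims', 'receptive_field_shape', 'stride']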
+ """
+
+ max_width = int(100000 / (10 ** num_dims))
+
+ did_it = False
+ try:
+ receptive_field_size, strides, left_paddings = inspect_conv_model(model, num_channels, num_dims, max_width=max_width, width_hint=width_hint, stride_hint=stride_hint, verbose=verbose)
+ did_it = True
+    except Exception:
+ # try once again with larger max_width
+ max_width *= 10
+
+    # second attempt with the larger max_width; give up with a RuntimeError if it fails again
+    try:
+        if not did_it: receptive_field_size, strides, left_paddings = inspect_conv_model(model, num_channels, num_dims, max_width=max_width, width_hint=width_hint, stride_hint=stride_hint, verbose=verbose)
+    except Exception:
+ raise RuntimeError("could not determine parameters within given compute budget")
+
+ partition_kernel = create_partition_kernel(receptive_field_size, strides)
+ partition_kernel = torch.repeat_interleave(partition_kernel, num_channels, 1)
+
+ tconv_parameters = {
+ 'kernel': partition_kernel,
+ 'receptive_field_shape': receptive_field_size,
+ 'stride': strides,
+ 'left_padding': left_paddings,
+ 'num_dims': num_dims
+ }
+
+ return tconv_parameters
+
+
+
+def relegance_map_relevance_to_input_domain(od_relevance, tconv_parameters):
+ """ maps output-domain relevance to input-domain relevance via transpose convolution
+
+ Args:
+ od_relevance (torch.tensor): output-domain relevance
+ tconv_parameters (dict): parameter dict as created by relegance_create_tconv_kernel
+
+ Returns:
+ torch.tensor: input-domain relevance. The tensor is left aligned, i.e. the all-zero index of the output corresponds to the all-zero index of the discriminator input.
+ Otherwise, the size of the output tensor does not need to match the size of the discriminator input. Use relegance_resize_relevance_to_input_size for a
+ convenient way to adjust the output to the correct size.
+
+ Raises:
+ ValueError: if number of dimensions is not supported
+ """
+
+ kernel = tconv_parameters['kernel'].to(od_relevance.device)
+ rf_shape = tconv_parameters['receptive_field_shape']
+ stride = tconv_parameters['stride']
+ left_padding = tconv_parameters['left_padding']
+
+ num_dims = len(kernel.shape) - 2
+
+ # repeat boundary values
+ od_padding = [rf_shape[i//2] // stride[i//2] + 1 for i in range(2 * num_dims)]
+ padded_od_relevance = F.pad(od_relevance, od_padding[::-1], mode='replicate')
+ od_padding = od_padding[::2]
+
+ # apply mapping and left trimming
+ if num_dims == 1:
+ id_relevance = F.conv_transpose1d(padded_od_relevance, kernel, stride=stride)
+ id_relevance = id_relevance[..., left_padding[0] + stride[0] * od_padding[0] :]
+ elif num_dims == 2:
+ id_relevance = F.conv_transpose2d(padded_od_relevance, kernel, stride=stride)
+ id_relevance = id_relevance[..., left_padding[0] + stride[0] * od_padding[0] :, left_padding[1] + stride[1] * od_padding[1]:]
+ elif num_dims == 3:
+        id_relevance = F.conv_transpose3d(padded_od_relevance, kernel, stride=stride)
+ id_relevance = id_relevance[..., left_padding[0] + stride[0] * od_padding[0] :, left_padding[1] + stride[1] * od_padding[1]:, left_padding[2] + stride[2] * od_padding[2] :]
+ else:
+        raise ValueError(f'[relegance_map_relevance_to_input_domain] error: num_dims = {num_dims} not supported')
+
+ return id_relevance
+
+
+def relegance_resize_relevance_to_input_size(reference_input, relevance):
+ """ adjusts size of relevance tensor to reference input size
+
+ Args:
+ reference_input (torch.tensor): discriminator input tensor for reference
+ relevance (torch.tensor): input-domain relevance corresponding to input tensor reference_input
+
+ Returns:
+ torch.tensor: resized relevance
+
+ Raises:
+ ValueError: if number of dimensions is not supported
+ """
+ resized_relevance = torch.zeros_like(reference_input)
+
+ num_dims = len(reference_input.shape) - 2
+ with torch.no_grad():
+ if num_dims == 1:
+ resized_relevance[:] = relevance[..., : min(reference_input.size(-1), relevance.size(-1))]
+ elif num_dims == 2:
+ resized_relevance[:] = relevance[..., : min(reference_input.size(-2), relevance.size(-2)), : min(reference_input.size(-1), relevance.size(-1))]
+ elif num_dims == 3:
+ resized_relevance[:] = relevance[..., : min(reference_input.size(-3), relevance.size(-3)), : min(reference_input.size(-2), relevance.size(-2)), : min(reference_input.size(-1), relevance.size(-1))]
+ else:
+        raise ValueError(f'[relegance_resize_relevance_to_input_size] error: num_dims = {num_dims} not supported')
+
+ return resized_relevance
\ No newline at end of file
diff --git a/dnn/torch/dnntools/dnntools/sparsification/__init__.py b/dnn/torch/dnntools/dnntools/sparsification/__init__.py
new file mode 100644
index 00000000..fcc91746
--- /dev/null
+++ b/dnn/torch/dnntools/dnntools/sparsification/__init__.py
@@ -0,0 +1,6 @@
+from .gru_sparsifier import GRUSparsifier
+from .conv1d_sparsifier import Conv1dSparsifier
+from .conv_transpose1d_sparsifier import ConvTranspose1dSparsifier
+from .linear_sparsifier import LinearSparsifier
+from .common import sparsify_matrix, calculate_gru_flops_per_step
+from .utils import mark_for_sparsification, create_sparsifier
\ No newline at end of file
diff --git a/dnn/torch/dnntools/dnntools/sparsification/base_sparsifier.py b/dnn/torch/dnntools/dnntools/sparsification/base_sparsifier.py
new file mode 100644
index 00000000..dd62f40b
--- /dev/null
+++ b/dnn/torch/dnntools/dnntools/sparsification/base_sparsifier.py
@@ -0,0 +1,58 @@
+"""
+/* Copyright (c) 2023 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+class BaseSparsifier:
+ def __init__(self, task_list, start, stop, interval, exponent=3):
+
+ # just copying parameters...
+ self.start = start
+ self.stop = stop
+ self.interval = interval
+ self.exponent = exponent
+ self.task_list = task_list
+
+ # ... and setting counter to 0
+ self.step_counter = 0
+
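+    # subclasses implement sparsify(alpha, verbose); step() below converts the step
+    # counter into an interpolation factor alpha (1: fully dense, 0: target density)
+    # and triggers the sparsification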
+ def step(self, verbose=False):
+        # advance the step counter, then compute the current interpolation factor
+ self.step_counter += 1
+
+ if self.step_counter < self.start:
+ return
+ elif self.step_counter < self.stop:
+            # update only every self.interval-th step
+ if self.step_counter % self.interval:
+ return
+
+ alpha = ((self.stop - self.step_counter) / (self.stop - self.start)) ** self.exponent
+ else:
+ alpha = 0
+
+ self.sparsify(alpha, verbose=verbose)
\ No newline at end of file
diff --git a/dnn/torch/dnntools/dnntools/sparsification/common.py b/dnn/torch/dnntools/dnntools/sparsification/common.py
new file mode 100644
index 00000000..47181800
--- /dev/null
+++ b/dnn/torch/dnntools/dnntools/sparsification/common.py
@@ -0,0 +1,123 @@
+"""
+/* Copyright (c) 2023 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import torch
+
+debug = True
+
+def sparsify_matrix(matrix : torch.tensor, density : float, block_size, keep_diagonal : bool=False, return_mask : bool=False):
+ """ sparsifies matrix with specified block size
+
+ Parameters:
+ -----------
+ matrix : torch.tensor
+ matrix to sparsify
+    density : float
+        target density
+    block_size : [int, int]
+        block size dimensions
+    keep_diagonal : bool
+        If true, the diagonal will be kept. This option requires block_size[0] == block_size[1] and defaults to False
+    return_mask : bool
+        If true, the binary mask is returned together with the sparsified matrix. Defaults to False
+ """
+
+ m, n = matrix.shape
+ m1, n1 = block_size
+
+ if m % m1 or n % n1:
+ raise ValueError(f"block size {(m1, n1)} does not divide matrix size {(m, n)}")
+
+ # extract diagonal if keep_diagonal = True
+ if keep_diagonal:
+ if m != n:
+ raise ValueError("Attempting to sparsify non-square matrix with keep_diagonal=True")
+
+ to_spare = torch.diag(torch.diag(matrix))
+ matrix = matrix - to_spare
+ else:
+ to_spare = torch.zeros_like(matrix)
+
+ # calculate energy in sub-blocks
+ x = torch.reshape(matrix, (m // m1, m1, n // n1, n1))
+ x = x ** 2
+ block_energies = torch.sum(torch.sum(x, dim=3), dim=1)
+
+ number_of_blocks = (m * n) // (m1 * n1)
+ number_of_survivors = round(number_of_blocks * density)
+
+ # masking threshold
+ if number_of_survivors == 0:
+ threshold = 0
+ else:
+ threshold = torch.sort(torch.flatten(block_energies)).values[-number_of_survivors]
+
+ # create mask
+ mask = torch.ones_like(block_energies)
+ mask[block_energies < threshold] = 0
+ mask = torch.repeat_interleave(mask, m1, dim=0)
+ mask = torch.repeat_interleave(mask, n1, dim=1)
+
+ # perform masking
+ masked_matrix = mask * matrix + to_spare
+
+ if return_mask:
+ return masked_matrix, mask
+ else:
+ return masked_matrix
+
+def calculate_gru_flops_per_step(gru, sparsification_dict=dict(), drop_input=False):
+ input_size = gru.input_size
+ hidden_size = gru.hidden_size
+ flops = 0
+
+ input_density = (
+ sparsification_dict.get('W_ir', [1])[0]
+ + sparsification_dict.get('W_in', [1])[0]
+ + sparsification_dict.get('W_iz', [1])[0]
+ ) / 3
+
+ recurrent_density = (
+ sparsification_dict.get('W_hr', [1])[0]
+ + sparsification_dict.get('W_hn', [1])[0]
+ + sparsification_dict.get('W_hz', [1])[0]
+ ) / 3
+
+ # input matrix vector multiplications
+ if not drop_input:
+ flops += 2 * 3 * input_size * hidden_size * input_density
+
+ # recurrent matrix vector multiplications
+ flops += 2 * 3 * hidden_size * hidden_size * recurrent_density
+
+ # biases
+ flops += 6 * hidden_size
+
+ # activations estimated by 10 flops per activation
+ flops += 30 * hidden_size
+
+ return flops
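+
+
+if __name__ == "__main__":
+    print("Testing sparsify_matrix and calculate_gru_flops_per_step")
+
+    # quick smoke test; matrix and GRU sizes are illustrative
+    # sparsify a 16x16 matrix to 25% block density with 4x4 blocks, keeping the diagonal
+    matrix = torch.randn(16, 16)
+    sparse_matrix = sparsify_matrix(matrix, 0.25, [4, 4], keep_diagonal=True)
+    print(f"nonzero entries: {int((sparse_matrix != 0).sum().item())} of {matrix.numel()}")
+
+    # flops per step for a small GRU, dense vs. 50% recurrent density
+    gru = torch.nn.GRU(10, 20)
+    recurrent_params = {key: (0.5, [4, 4], True) for key in ['W_hr', 'W_hz', 'W_hn']}
+    print(f"dense: {calculate_gru_flops_per_step(gru)} flops/step")
+    print(f"50% recurrent density: {calculate_gru_flops_per_step(gru, recurrent_params)} flops/step")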
diff --git a/dnn/torch/dnntools/dnntools/sparsification/conv1d_sparsifier.py b/dnn/torch/dnntools/dnntools/sparsification/conv1d_sparsifier.py
new file mode 100644
index 00000000..1ac51d0d
--- /dev/null
+++ b/dnn/torch/dnntools/dnntools/sparsification/conv1d_sparsifier.py
@@ -0,0 +1,133 @@
+"""
+/* Copyright (c) 2023 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import torch
+
+from .base_sparsifier import BaseSparsifier
+from .common import sparsify_matrix, debug
+
+
+class Conv1dSparsifier(BaseSparsifier):
+ def __init__(self, task_list, start, stop, interval, exponent=3):
+ """ Sparsifier for torch.nn.GRUs
+
+ Parameters:
+ -----------
+ task_list : list
+ task_list contains a list of tuples (conv1d, params), where conv1d is an instance
+ of torch.nn.Conv1d and params is a tuple (density, [m, n]),
+ where density is the target density in [0, 1], [m, n] is the shape sub-blocks to which
+ sparsification is applied.
+
+ start : int
+ training step after which sparsification will be started.
+
+ stop : int
+ training step after which sparsification will be completed.
+
+ interval : int
+ sparsification interval for steps between start and stop. After stop sparsification will be
+            carried out after every call to Conv1dSparsifier.step()
+
+ exponent : float
+ Interpolation exponent for sparsification interval. In step i sparsification will be carried out
+            with density (alpha + target_density * (1 - alpha)), where
+            alpha = ((stop - i) / (stop - start)) ** exponent
+
+ Example:
+ --------
+ >>> import torch
+ >>> conv = torch.nn.Conv1d(8, 16, 8)
+ >>> params = (0.2, [8, 4])
+ >>> sparsifier = Conv1dSparsifier([(conv, params)], 0, 100, 50)
+ >>> for i in range(100):
+ ... sparsifier.step()
+ """
+        super().__init__(task_list, start, stop, interval, exponent=exponent)
+
+ self.last_mask = None
+
+
+ def sparsify(self, alpha, verbose=False):
+ """ carries out sparsification step
+
+ Call this function after optimizer.step in your
+ training loop.
+
+ Parameters:
+ ----------
+ alpha : float
+ density interpolation parameter (1: dense, 0: target density)
+ verbose : bool
+ if true, densities are printed out
+
+ Returns:
+ --------
+ None
+
+ """
+
+ with torch.no_grad():
+ for conv, params in self.task_list:
+ # reshape weight
+ if hasattr(conv, 'weight_v'):
+ weight = conv.weight_v
+ else:
+ weight = conv.weight
+                o, i, k = weight.shape # Conv1d weight has shape (out_channels, in_channels, kernel_size)
+ w = weight.permute(0, 2, 1).flatten(1)
+ target_density, block_size = params
+ density = alpha + (1 - alpha) * target_density
+ w, new_mask = sparsify_matrix(w, density, block_size, return_mask=True)
+                w = w.reshape(o, k, i).permute(0, 2, 1)
+ weight[:] = w
+
+ if self.last_mask is not None:
+ if not torch.all(self.last_mask * new_mask == new_mask) and debug:
+ print("weight resurrection in conv.weight")
+
+ self.last_mask = new_mask
+
+ if verbose:
+ print(f"conv1d_sparsier[{self.step_counter}]: {density=}")
+
+
+if __name__ == "__main__":
+ print("Testing sparsifier")
+
+ import torch
+ conv = torch.nn.Conv1d(8, 16, 8)
+ params = (0.2, [8, 4])
+
+ sparsifier = Conv1dSparsifier([(conv, params)], 0, 100, 5)
+
+ for i in range(100):
+ sparsifier.step(verbose=True)
+
+ print(conv.weight)
diff --git a/dnn/torch/dnntools/dnntools/sparsification/conv_transpose1d_sparsifier.py b/dnn/torch/dnntools/dnntools/sparsification/conv_transpose1d_sparsifier.py
new file mode 100644
index 00000000..6d9398f2
--- /dev/null
+++ b/dnn/torch/dnntools/dnntools/sparsification/conv_transpose1d_sparsifier.py
@@ -0,0 +1,134 @@
+"""
+/* Copyright (c) 2023 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import torch
+
+
+from .base_sparsifier import BaseSparsifier
+from .common import sparsify_matrix, debug
+
+
+class ConvTranspose1dSparsifier(BaseSparsifier):
+ def __init__(self, task_list, start, stop, interval, exponent=3):
+ """ Sparsifier for torch.nn.GRUs
+
+ Parameters:
+ -----------
+ task_list : list
+            task_list contains a list of tuples (conv, params), where conv is an instance
+            of torch.nn.ConvTranspose1d and params is a tuple (density, [m, n]),
+ where density is the target density in [0, 1], [m, n] is the shape sub-blocks to which
+ sparsification is applied.
+
+ start : int
+ training step after which sparsification will be started.
+
+ stop : int
+ training step after which sparsification will be completed.
+
+ interval : int
+ sparsification interval for steps between start and stop. After stop sparsification will be
+            carried out after every call to ConvTranspose1dSparsifier.step()
+
+ exponent : float
+ Interpolation exponent for sparsification interval. In step i sparsification will be carried out
+            with density (alpha + target_density * (1 - alpha)), where
+            alpha = ((stop - i) / (stop - start)) ** exponent
+
+ Example:
+ --------
+ >>> import torch
+ >>> conv = torch.nn.ConvTranspose1d(8, 16, 8)
+ >>> params = (0.2, [8, 4])
+ >>> sparsifier = ConvTranspose1dSparsifier([(conv, params)], 0, 100, 50)
+ >>> for i in range(100):
+ ... sparsifier.step()
+ """
+
+        super().__init__(task_list, start, stop, interval, exponent=exponent)
+
+ self.last_mask = None
+
+ def sparsify(self, alpha, verbose=False):
+ """ carries out sparsification step
+
+ Call this function after optimizer.step in your
+ training loop.
+
+ Parameters:
+ ----------
+ alpha : float
+ density interpolation parameter (1: dense, 0: target density)
+ verbose : bool
+ if true, densities are printed out
+
+ Returns:
+ --------
+ None
+
+ """
+
+ with torch.no_grad():
+ for conv, params in self.task_list:
+ # reshape weight
+ if hasattr(conv, 'weight_v'):
+ weight = conv.weight_v
+ else:
+ weight = conv.weight
+ i, o, k = weight.shape
+ w = weight.permute(2, 1, 0).reshape(k * o, i)
+ target_density, block_size = params
+ density = alpha + (1 - alpha) * target_density
+ w, new_mask = sparsify_matrix(w, density, block_size, return_mask=True)
+ w = w.reshape(k, o, i).permute(2, 1, 0)
+ weight[:] = w
+
+ if self.last_mask is not None:
+ if not torch.all(self.last_mask * new_mask == new_mask) and debug:
+ print("weight resurrection in conv.weight")
+
+ self.last_mask = new_mask
+
+ if verbose:
+ print(f"convtrans1d_sparsier[{self.step_counter}]: {density=}")
+
+
+if __name__ == "__main__":
+ print("Testing sparsifier")
+
+ import torch
+ conv = torch.nn.ConvTranspose1d(8, 16, 4, 4)
+ params = (0.2, [8, 4])
+
+ sparsifier = ConvTranspose1dSparsifier([(conv, params)], 0, 100, 5)
+
+ for i in range(100):
+ sparsifier.step(verbose=True)
+
+ print(conv.weight)
diff --git a/dnn/torch/dnntools/dnntools/sparsification/gru_sparsifier.py b/dnn/torch/dnntools/dnntools/sparsification/gru_sparsifier.py
new file mode 100644
index 00000000..417b04be
--- /dev/null
+++ b/dnn/torch/dnntools/dnntools/sparsification/gru_sparsifier.py
@@ -0,0 +1,178 @@
+"""
+/* Copyright (c) 2023 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import torch
+
+from .base_sparsifier import BaseSparsifier
+from .common import sparsify_matrix, debug
+
+
+class GRUSparsifier(BaseSparsifier):
+ def __init__(self, task_list, start, stop, interval, exponent=3):
+ """ Sparsifier for torch.nn.GRUs
+
+ Parameters:
+ -----------
+ task_list : list
+ task_list contains a list of tuples (gru, sparsify_dict), where gru is an instance
+            of torch.nn.GRU and sparsify_dict is a dictionary with keys in {'W_ir', 'W_iz', 'W_in',
+ 'W_hr', 'W_hz', 'W_hn'} corresponding to the input and recurrent weights for the reset,
+ update, and new gate. The values of sparsify_dict are tuples (density, [m, n], keep_diagonal),
+ where density is the target density in [0, 1], [m, n] is the shape sub-blocks to which
+ sparsification is applied and keep_diagonal is a bool variable indicating whether the diagonal
+ should be kept.
+
+ start : int
+ training step after which sparsification will be started.
+
+ stop : int
+ training step after which sparsification will be completed.
+
+ interval : int
+ sparsification interval for steps between start and stop. After stop sparsification will be
+ carried out after every call to GRUSparsifier.step()
+
+ exponent : float
+ Interpolation exponent for sparsification interval. In step i sparsification will be carried out
+            with density (alpha + target_density * (1 - alpha)), where
+            alpha = ((stop - i) / (stop - start)) ** exponent
+
+ Example:
+ --------
+ >>> import torch
+ >>> gru = torch.nn.GRU(10, 20)
+ >>> sparsify_dict = {
+ ... 'W_ir' : (0.5, [2, 2], False),
+ ... 'W_iz' : (0.6, [2, 2], False),
+ ... 'W_in' : (0.7, [2, 2], False),
+ ... 'W_hr' : (0.1, [4, 4], True),
+ ... 'W_hz' : (0.2, [4, 4], True),
+ ... 'W_hn' : (0.3, [4, 4], True),
+ ... }
+ >>> sparsifier = GRUSparsifier([(gru, sparsify_dict)], 0, 100, 50)
+ >>> for i in range(100):
+ ... sparsifier.step()
+ """
+        super().__init__(task_list, start, stop, interval, exponent=exponent)
+
+ self.last_masks = {key : None for key in ['W_ir', 'W_in', 'W_iz', 'W_hr', 'W_hn', 'W_hz']}
+
+ def sparsify(self, alpha, verbose=False):
+ """ carries out sparsification step
+
+ Call this function after optimizer.step in your
+ training loop.
+
+ Parameters:
+ ----------
+ alpha : float
+ density interpolation parameter (1: dense, 0: target density)
+ verbose : bool
+ if true, densities are printed out
+
+ Returns:
+ --------
+ None
+
+ """
+
+ with torch.no_grad():
+ for gru, params in self.task_list:
+ hidden_size = gru.hidden_size
+
+ # input weights
+ for i, key in enumerate(['W_ir', 'W_iz', 'W_in']):
+ if key in params:
+ if hasattr(gru, 'weight_ih_l0_v'):
+ weight = gru.weight_ih_l0_v
+ else:
+ weight = gru.weight_ih_l0
+ density = alpha + (1 - alpha) * params[key][0]
+ if verbose:
+ print(f"[{self.step_counter}]: {key} density: {density}")
+
+ weight[i * hidden_size : (i+1) * hidden_size, : ], new_mask = sparsify_matrix(
+ weight[i * hidden_size : (i + 1) * hidden_size, : ],
+ density, # density
+ params[key][1], # block_size
+ params[key][2], # keep_diagonal (might want to set this to False)
+ return_mask=True
+ )
+
+                    if self.last_masks[key] is not None:
+ if not torch.all(self.last_masks[key] * new_mask == new_mask) and debug:
+ print("weight resurrection in weight_ih_l0_v")
+
+ self.last_masks[key] = new_mask
+
+ # recurrent weights
+ for i, key in enumerate(['W_hr', 'W_hz', 'W_hn']):
+ if key in params:
+ if hasattr(gru, 'weight_hh_l0_v'):
+ weight = gru.weight_hh_l0_v
+ else:
+ weight = gru.weight_hh_l0
+ density = alpha + (1 - alpha) * params[key][0]
+ if verbose:
+ print(f"[{self.step_counter}]: {key} density: {density}")
+ weight[i * hidden_size : (i+1) * hidden_size, : ], new_mask = sparsify_matrix(
+ weight[i * hidden_size : (i + 1) * hidden_size, : ],
+ density,
+ params[key][1], # block_size
+ params[key][2], # keep_diagonal (might want to set this to False)
+ return_mask=True
+ )
+
+                    if self.last_masks[key] is not None:
+                        if not torch.all(self.last_masks[key] * new_mask == new_mask) and debug:
+ print("weight resurrection in weight_hh_l0_v")
+
+ self.last_masks[key] = new_mask
+
+
+
+if __name__ == "__main__":
+ print("Testing sparsifier")
+
+ gru = torch.nn.GRU(10, 20)
+ sparsify_dict = {
+ 'W_ir' : (0.5, [2, 2], False),
+ 'W_iz' : (0.6, [2, 2], False),
+ 'W_in' : (0.7, [2, 2], False),
+ 'W_hr' : (0.1, [4, 4], True),
+ 'W_hz' : (0.2, [4, 4], True),
+ 'W_hn' : (0.3, [4, 4], True),
+ }
+
+ sparsifier = GRUSparsifier([(gru, sparsify_dict)], 0, 100, 10)
+
+ for i in range(100):
+ sparsifier.step(verbose=True)
+
+ print(gru.weight_hh_l0)
diff --git a/dnn/torch/dnntools/dnntools/sparsification/linear_sparsifier.py b/dnn/torch/dnntools/dnntools/sparsification/linear_sparsifier.py
new file mode 100644
index 00000000..59251ddd
--- /dev/null
+++ b/dnn/torch/dnntools/dnntools/sparsification/linear_sparsifier.py
@@ -0,0 +1,128 @@
+"""
+/* Copyright (c) 2023 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import torch
+
+from .base_sparsifier import BaseSparsifier
+from .common import sparsify_matrix, debug
+
+
+class LinearSparsifier(BaseSparsifier):
+ def __init__(self, task_list, start, stop, interval, exponent=3):
+ """ Sparsifier for torch.nn.GRUs
+
+ Parameters:
+ -----------
+ task_list : list
+ task_list contains a list of tuples (linear, params), where linear is an instance
+ of torch.nn.Linear and params is a tuple (density, [m, n]),
+ where density is the target density in [0, 1], [m, n] is the shape sub-blocks to which
+ sparsification is applied.
+
+ start : int
+ training step after which sparsification will be started.
+
+ stop : int
+ training step after which sparsification will be completed.
+
+ interval : int
+ sparsification interval for steps between start and stop. After stop sparsification will be
+            carried out after every call to LinearSparsifier.step()
+
+ exponent : float
+ Interpolation exponent for sparsification interval. In step i sparsification will be carried out
+            with density (alpha + target_density * (1 - alpha)), where
+            alpha = ((stop - i) / (stop - start)) ** exponent
+
+ Example:
+ --------
+ >>> import torch
+ >>> linear = torch.nn.Linear(8, 16)
+ >>> params = (0.2, [8, 4])
+ >>> sparsifier = LinearSparsifier([(linear, params)], 0, 100, 50)
+ >>> for i in range(100):
+ ... sparsifier.step()
+ """
+
+        super().__init__(task_list, start, stop, interval, exponent=exponent)
+
+ self.last_mask = None
+
+ def sparsify(self, alpha, verbose=False):
+ """ carries out sparsification step
+
+ Call this function after optimizer.step in your
+ training loop.
+
+ Parameters:
+ ----------
+ alpha : float
+ density interpolation parameter (1: dense, 0: target density)
+ verbose : bool
+ if true, densities are printed out
+
+ Returns:
+ --------
+ None
+
+ """
+
+ with torch.no_grad():
+ for linear, params in self.task_list:
+ if hasattr(linear, 'weight_v'):
+ weight = linear.weight_v
+ else:
+ weight = linear.weight
+ target_density, block_size = params
+ density = alpha + (1 - alpha) * target_density
+ weight[:], new_mask = sparsify_matrix(weight, density, block_size, return_mask=True)
+
+                if self.last_mask is not None:
+                    if not torch.all(self.last_mask * new_mask == new_mask) and verbose:
+                        print("weight resurrection in linear.weight")
+
+ self.last_mask = new_mask
+
+ if verbose:
+ print(f"linear_sparsifier[{self.step_counter}]: {density=}")
+
+
+if __name__ == "__main__":
+ print("Testing sparsifier")
+
+ import torch
+ linear = torch.nn.Linear(8, 16)
+ params = (0.2, [4, 2])
+
+ sparsifier = LinearSparsifier([(linear, params)], 0, 100, 5)
+
+ for i in range(100):
+ sparsifier.step(verbose=True)
+
+ print(linear.weight)
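+
+    # Hedged sketch of the documented density schedule (standalone helper, not
+    # part of the class; assumes alpha = ((stop - i) / (stop - start)) ** exponent
+    # clamped to [0, 1], as described in the constructor docstring):
+    def scheduled_density(i, start, stop, target_density, exponent=3):
+        if i <= start: return 1.0
+        if i >= stop: return target_density
+        alpha = ((stop - i) / (stop - start)) ** exponent
+        return alpha + (1 - alpha) * target_density
+
+    for i in (0, 25, 50, 75, 100):
+        print(f"step {i}: density {scheduled_density(i, 0, 100, 0.2):.3f}")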
diff --git a/dnn/torch/dnntools/dnntools/sparsification/utils.py b/dnn/torch/dnntools/dnntools/sparsification/utils.py
new file mode 100644
index 00000000..42f22353
--- /dev/null
+++ b/dnn/torch/dnntools/dnntools/sparsification/utils.py
@@ -0,0 +1,64 @@
+import torch
+
+from dnntools.sparsification import GRUSparsifier, LinearSparsifier, Conv1dSparsifier, ConvTranspose1dSparsifier
+
+def mark_for_sparsification(module, params):
+ setattr(module, 'sparsify', True)
+ setattr(module, 'sparsification_params', params)
+ return module
+
+def create_sparsifier(module, start, stop, interval):
+ sparsifier_list = []
+ for m in module.modules():
+ if hasattr(m, 'sparsify'):
+ if isinstance(m, torch.nn.GRU):
+ sparsifier_list.append(
+ GRUSparsifier([(m, m.sparsification_params)], start, stop, interval)
+ )
+ elif isinstance(m, torch.nn.Linear):
+ sparsifier_list.append(
+ LinearSparsifier([(m, m.sparsification_params)], start, stop, interval)
+ )
+ elif isinstance(m, torch.nn.Conv1d):
+ sparsifier_list.append(
+ Conv1dSparsifier([(m, m.sparsification_params)], start, stop, interval)
+ )
+ elif isinstance(m, torch.nn.ConvTranspose1d):
+ sparsifier_list.append(
+ ConvTranspose1dSparsifier([(m, m.sparsification_params)], start, stop, interval)
+ )
+ else:
+ print(f"[create_sparsifier] warning: module {m} marked for sparsification but no suitable sparsifier exists.")
+
+ def sparsify(verbose=False):
+ for sparsifier in sparsifier_list:
+ sparsifier.step(verbose)
+
+ return sparsify
+
+
+def count_parameters(model, verbose=False):
+ total = 0
+ for name, p in model.named_parameters():
+ count = torch.ones_like(p).sum().item()
+
+ if verbose:
+ print(f"{name}: {count} parameters")
+
+ total += count
+
+ return total
+
+def estimate_nonzero_parameters(module):
+ num_zero_parameters = 0
+ if hasattr(module, 'sparsify'):
+ params = module.sparsification_params
+ if isinstance(module, torch.nn.Conv1d) or isinstance(module, torch.nn.ConvTranspose1d):
+ num_zero_parameters = torch.ones_like(module.weight).sum().item() * (1 - params[0])
+ elif isinstance(module, torch.nn.GRU):
+ num_zero_parameters = module.input_size * module.hidden_size * (3 - params['W_ir'][0] - params['W_iz'][0] - params['W_in'][0])
+ num_zero_parameters += module.hidden_size * module.hidden_size * (3 - params['W_hr'][0] - params['W_hz'][0] - params['W_hn'][0])
+        elif isinstance(module, torch.nn.Linear):
+            num_zero_parameters = module.in_features * module.out_features * (1 - params[0])
+        else:
+            raise ValueError(f'unknown sparsification method for module of type {type(module)}')
+
+    # estimated number of parameters that sparsification will zero out; subtract
+    # from count_parameters(module) to estimate the nonzero count
+    return num_zero_parameters
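+
+
+if __name__ == "__main__":
+    # Minimal usage sketch (hypothetical model; densities and block shapes are
+    # illustrative): mark layers, build a single sparsify() callback, and call
+    # it after every optimizer.step().
+    model = torch.nn.Sequential(
+        mark_for_sparsification(torch.nn.Linear(64, 64), (0.5, [8, 4])),
+        torch.nn.Tanh(),
+        mark_for_sparsification(torch.nn.Linear(64, 16), (0.25, [8, 4])),
+    )
+    sparsify = create_sparsifier(model, start=0, stop=100, interval=10)
+    for step in range(200):
+        sparsify()
+    print(f"{count_parameters(model)} parameters total")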
diff --git a/dnn/torch/dnntools/requirements.txt b/dnn/torch/dnntools/requirements.txt
new file mode 100644
index 00000000..08ed5eeb
--- /dev/null
+++ b/dnn/torch/dnntools/requirements.txt
@@ -0,0 +1 @@
+torch
\ No newline at end of file
diff --git a/dnn/torch/dnntools/setup.py b/dnn/torch/dnntools/setup.py
new file mode 100644
index 00000000..bc4ef3f1
--- /dev/null
+++ b/dnn/torch/dnntools/setup.py
@@ -0,0 +1,48 @@
+"""
+/* Copyright (c) 2023 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+#!/usr/bin/env python
+import os
+from setuptools import setup
+
+lib_folder = os.path.dirname(os.path.realpath(__file__))
+
+with open(os.path.join(lib_folder, 'requirements.txt'), 'r') as f:
+ install_requires = list(f.read().splitlines())
+
+print(install_requires)
+
+setup(name='dnntools',
+ version='1.0',
+ author='Jan Buethe',
+ author_email='jbuethe@amazon.de',
+      description='Non-standard tools for deep neural network training with PyTorch',
+ packages=['dnntools', 'dnntools.sparsification', 'dnntools.quantization'],
+ install_requires=install_requires
+ )
diff --git a/dnn/torch/osce/adv_train_model.py b/dnn/torch/osce/adv_train_model.py
index 9cd32000..7db859e4 100644
--- a/dnn/torch/osce/adv_train_model.py
+++ b/dnn/torch/osce/adv_train_model.py
@@ -111,7 +111,7 @@ os.makedirs(checkpoint_dir, exist_ok=True)
if has_git:
working_dir = os.path.split(__file__)[0]
try:
- repo = git.Repo(working_dir)
+ repo = git.Repo(working_dir, search_parent_directories=True)
setup['repo'] = dict()
hash = repo.head.object.hexsha
urls = list(repo.remote().urls)
@@ -408,6 +408,10 @@ for ep in range(1, epochs + 1):
optimizer.step()
+ # sparsification
+ if hasattr(model, 'sparsifier'):
+ model.sparsifier()
+
running_model_grad_norm += get_grad_norm(model).detach().cpu().item()
running_adv_loss += gen_loss.detach().cpu().item()
running_disc_loss += disc_loss.detach().cpu().item()
diff --git a/dnn/torch/osce/adv_train_vocoder.py b/dnn/torch/osce/adv_train_vocoder.py
index 754a1529..73e3c9b0 100644
--- a/dnn/torch/osce/adv_train_vocoder.py
+++ b/dnn/torch/osce/adv_train_vocoder.py
@@ -111,7 +111,7 @@ os.makedirs(checkpoint_dir, exist_ok=True)
if has_git:
working_dir = os.path.split(__file__)[0]
try:
- repo = git.Repo(working_dir)
+ repo = git.Repo(working_dir, search_parent_directories=True)
setup['repo'] = dict()
hash = repo.head.object.hexsha
urls = list(repo.remote().urls)
diff --git a/dnn/torch/osce/engine/engine.py b/dnn/torch/osce/engine/engine.py
index 7688e9b4..0762c898 100644
--- a/dnn/torch/osce/engine/engine.py
+++ b/dnn/torch/osce/engine/engine.py
@@ -46,6 +46,10 @@ def train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler,
# update learning rate
scheduler.step()
+ # sparsification
+ if hasattr(model, 'sparsifier'):
+ model.sparsifier()
+
# update running loss
running_loss += float(loss.cpu())
@@ -73,8 +77,6 @@ def evaluate(model, criterion, dataloader, device, log_interval=10):
for i, batch in enumerate(tepoch):
-
-
# push batch to device
for key in batch:
batch[key] = batch[key].to(device)
diff --git a/dnn/torch/osce/export_model_weights.py b/dnn/torch/osce/export_model_weights.py
index f94431d3..0bec9604 100644
--- a/dnn/torch/osce/export_model_weights.py
+++ b/dnn/torch/osce/export_model_weights.py
@@ -43,6 +43,7 @@ from models import model_dict
from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d
from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
from utils.layers.td_shaper import TDShaper
+from utils.misc import remove_all_weight_norm
from wexchange.torch import dump_torch_weights
@@ -58,30 +59,30 @@ schedules = {
'nolace': [
('pitch_embedding', dict()),
('feature_net.conv1', dict()),
- ('feature_net.conv2', dict(quantize=True, scale=None)),
- ('feature_net.tconv', dict(quantize=True, scale=None)),
- ('feature_net.gru', dict()),
+ ('feature_net.conv2', dict(quantize=True, scale=None, sparse=True)),
+ ('feature_net.tconv', dict(quantize=True, scale=None, sparse=True)),
+ ('feature_net.gru', dict(quantize=True, scale=None, recurrent_scale=None, input_sparse=True, recurrent_sparse=True)),
('cf1', dict(quantize=True, scale=None)),
('cf2', dict(quantize=True, scale=None)),
('af1', dict(quantize=True, scale=None)),
- ('tdshape1', dict()),
- ('tdshape2', dict()),
- ('tdshape3', dict()),
+ ('tdshape1', dict(quantize=True, scale=None)),
+ ('tdshape2', dict(quantize=True, scale=None)),
+ ('tdshape3', dict(quantize=True, scale=None)),
('af2', dict(quantize=True, scale=None)),
('af3', dict(quantize=True, scale=None)),
('af4', dict(quantize=True, scale=None)),
- ('post_cf1', dict(quantize=True, scale=None)),
- ('post_cf2', dict(quantize=True, scale=None)),
- ('post_af1', dict(quantize=True, scale=None)),
- ('post_af2', dict(quantize=True, scale=None)),
- ('post_af3', dict(quantize=True, scale=None))
+ ('post_cf1', dict(quantize=True, scale=None, sparse=True)),
+ ('post_cf2', dict(quantize=True, scale=None, sparse=True)),
+ ('post_af1', dict(quantize=True, scale=None, sparse=True)),
+ ('post_af2', dict(quantize=True, scale=None, sparse=True)),
+ ('post_af3', dict(quantize=True, scale=None, sparse=True))
],
'lace' : [
('pitch_embedding', dict()),
('feature_net.conv1', dict()),
- ('feature_net.conv2', dict(quantize=True, scale=None)),
- ('feature_net.tconv', dict(quantize=True, scale=None)),
- ('feature_net.gru', dict()),
+ ('feature_net.conv2', dict(quantize=True, scale=None, sparse=True)),
+ ('feature_net.tconv', dict(quantize=True, scale=None, sparse=True)),
+ ('feature_net.gru', dict(quantize=True, scale=None, recurrent_scale=None, input_sparse=True, recurrent_sparse=True)),
('cf1', dict(quantize=True, scale=None)),
('cf2', dict(quantize=True, scale=None)),
('af1', dict(quantize=True, scale=None))
@@ -140,6 +141,7 @@ if __name__ == "__main__":
checkpoint = torch.load(checkpoint_path, map_location='cpu')
model = model_dict[checkpoint['setup']['model']['name']](*checkpoint['setup']['model']['args'], **checkpoint['setup']['model']['kwargs'])
model.load_state_dict(checkpoint['state_dict'])
+ remove_all_weight_norm(model, verbose=True)
# CWriter
model_name = checkpoint['setup']['model']['name']
diff --git a/dnn/torch/osce/models/lace.py b/dnn/torch/osce/models/lace.py
index 58293de4..51d65c3e 100644
--- a/dnn/torch/osce/models/lace.py
+++ b/dnn/torch/osce/models/lace.py
@@ -41,6 +41,12 @@ from models.silk_feature_net_pl import SilkFeatureNetPL
from models.silk_feature_net import SilkFeatureNet
from .scale_embedding import ScaleEmbedding
+import sys
+sys.path.append('../dnntools')
+
+from dnntools.sparsification import create_sparsifier
+
+
class LACE(NNSBase):
""" Linear-Adaptive Coding Enhancer """
FRAME_SIZE=80
@@ -60,7 +66,12 @@ class LACE(NNSBase):
numbits_embedding_dim=8,
hidden_feature_dim=64,
partial_lookahead=True,
- norm_p=2):
+ norm_p=2,
+ softquant=False,
+ sparsify=False,
+ sparsification_schedule=[10000, 30000, 100],
+ sparsification_density=0.5,
+ apply_weight_norm=False):
super().__init__(skip=skip, preemph=preemph)
@@ -85,18 +96,21 @@ class LACE(NNSBase):
# feature net
if partial_lookahead:
- self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim)
+ self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim, softquant=softquant, sparsify=sparsify, sparsification_density=sparsification_density, apply_weight_norm=apply_weight_norm)
else:
self.feature_net = SilkFeatureNet(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim)
# comb filters
left_pad = self.kernel_size // 2
right_pad = self.kernel_size - 1 - left_pad
- self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
- self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
+ self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
+ self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
# spectral shaping
- self.af1 = LimitedAdaptiveConv1d(1, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
+ self.af1 = LimitedAdaptiveConv1d(1, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
+
+ if sparsify:
+ self.sparsifier = create_sparsifier(self, *sparsification_schedule)
def flop_count(self, rate=16000, verbose=False):
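
For reference, the new constructor flags would typically come from the training setup; a hedged sketch (keyword names taken from this diff, remaining LACE kwargs left at their defaults, values illustrative):

from models.lace import LACE

model = LACE(
    softquant=True,                               # soft-quantize selected layers
    sparsify=True,                                # mark layers and attach model.sparsifier
    sparsification_schedule=[10000, 30000, 100],  # (start, stop, interval) in training steps
    sparsification_density=0.5,                   # scalar densities are broadcast per layer
    apply_weight_norm=False,
)

# the training scripts call this after optimizer.step() (see engine.py above):
if hasattr(model, 'sparsifier'):
    model.sparsifier()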
diff --git a/dnn/torch/osce/models/no_lace.py b/dnn/torch/osce/models/no_lace.py
index 0e0fb1b3..801857a4 100644
--- a/dnn/torch/osce/models/no_lace.py
+++ b/dnn/torch/osce/models/no_lace.py
@@ -27,9 +27,13 @@
*/
"""
+import numbers
+
import torch
from torch import nn
import torch.nn.functional as F
+from torch.nn.utils import weight_norm
+
import numpy as np
@@ -43,6 +47,11 @@ from models.silk_feature_net_pl import SilkFeatureNetPL
from models.silk_feature_net import SilkFeatureNet
from .scale_embedding import ScaleEmbedding
+import sys
+sys.path.append('../dnntools')
+from dnntools.quantization import soft_quant
+from dnntools.sparsification import create_sparsifier, mark_for_sparsification
+
class NoLACE(NNSBase):
""" Non-Linear Adaptive Coding Enhancer """
FRAME_SIZE=80
@@ -64,11 +73,15 @@ class NoLACE(NNSBase):
partial_lookahead=True,
norm_p=2,
avg_pool_k=4,
- pool_after=False):
+ pool_after=False,
+ softquant=False,
+ sparsify=False,
+ sparsification_schedule=[100, 1000, 100],
+ sparsification_density=0.5,
+ apply_weight_norm=False):
super().__init__(skip=skip, preemph=preemph)
-
self.num_features = num_features
self.cond_dim = cond_dim
self.pitch_max = pitch_max
@@ -81,6 +94,11 @@ class NoLACE(NNSBase):
self.hidden_feature_dim = hidden_feature_dim
self.partial_lookahead = partial_lookahead
+ if isinstance(sparsification_density, numbers.Number):
+ sparsification_density = 10 * [sparsification_density]
+
+ norm = weight_norm if apply_weight_norm else lambda x, name=None: x
+
# pitch embedding
self.pitch_embedding = nn.Embedding(pitch_max + 1, pitch_embedding_dim)
@@ -89,36 +107,52 @@ class NoLACE(NNSBase):
# feature net
if partial_lookahead:
- self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim)
+ self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim, softquant=softquant, sparsify=sparsify, sparsification_density=sparsification_density, apply_weight_norm=apply_weight_norm)
else:
self.feature_net = SilkFeatureNet(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim)
# comb filters
left_pad = self.kernel_size // 2
right_pad = self.kernel_size - 1 - left_pad
- self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
- self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
+ self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
+ self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
# spectral shaping
- self.af1 = LimitedAdaptiveConv1d(1, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
+ self.af1 = LimitedAdaptiveConv1d(1, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
# non-linear transforms
- self.tdshape1 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after)
- self.tdshape2 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after)
- self.tdshape3 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after)
+ self.tdshape1 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after, softquant=softquant, apply_weight_norm=apply_weight_norm)
+ self.tdshape2 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after, softquant=softquant, apply_weight_norm=apply_weight_norm)
+ self.tdshape3 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after, softquant=softquant, apply_weight_norm=apply_weight_norm)
# combinators
- self.af2 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
- self.af3 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
- self.af4 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
+ self.af2 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
+ self.af3 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
+ self.af4 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
# feature transforms
- self.post_cf1 = nn.Conv1d(cond_dim, cond_dim, 2)
- self.post_cf2 = nn.Conv1d(cond_dim, cond_dim, 2)
- self.post_af1 = nn.Conv1d(cond_dim, cond_dim, 2)
- self.post_af2 = nn.Conv1d(cond_dim, cond_dim, 2)
- self.post_af3 = nn.Conv1d(cond_dim, cond_dim, 2)
+ self.post_cf1 = norm(nn.Conv1d(cond_dim, cond_dim, 2))
+ self.post_cf2 = norm(nn.Conv1d(cond_dim, cond_dim, 2))
+ self.post_af1 = norm(nn.Conv1d(cond_dim, cond_dim, 2))
+ self.post_af2 = norm(nn.Conv1d(cond_dim, cond_dim, 2))
+ self.post_af3 = norm(nn.Conv1d(cond_dim, cond_dim, 2))
+ if softquant:
+ self.post_cf1 = soft_quant(self.post_cf1)
+ self.post_cf2 = soft_quant(self.post_cf2)
+ self.post_af1 = soft_quant(self.post_af1)
+ self.post_af2 = soft_quant(self.post_af2)
+ self.post_af3 = soft_quant(self.post_af3)
+
+
+ if sparsify:
+ mark_for_sparsification(self.post_cf1, (sparsification_density[4], [8, 4]))
+ mark_for_sparsification(self.post_cf2, (sparsification_density[5], [8, 4]))
+ mark_for_sparsification(self.post_af1, (sparsification_density[6], [8, 4]))
+ mark_for_sparsification(self.post_af2, (sparsification_density[7], [8, 4]))
+ mark_for_sparsification(self.post_af3, (sparsification_density[8], [8, 4]))
+
+ self.sparsifier = create_sparsifier(self, *sparsification_schedule)
def flop_count(self, rate=16000, verbose=False):
@@ -141,9 +175,12 @@ class NoLACE(NNSBase):
return feature_net_flops + comb_flops + af_flops + feature_flops + shape_flops
def feature_transform(self, f, layer):
- f = f.permute(0, 2, 1)
- f = F.pad(f, [1, 0])
- f = torch.tanh(layer(f))
+ f0 = f.permute(0, 2, 1)
+ f = F.pad(f0, [1, 0])
+        # residual_in_feature_transform is not set anywhere in __init__ above;
+        # default to the non-residual path when the attribute is absent
+        if getattr(self, 'residual_in_feature_transform', False):
+            f = torch.tanh(layer(f) + f0)
+        else:
+            f = torch.tanh(layer(f))
return f.permute(0, 2, 1)
def forward(self, x, features, periods, numbits, debug=False):
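
The residual option in feature_transform only changes the argument of the tanh; a standalone sketch of the tensor flow (assumptions: f has shape (batch, frames, cond_dim) and layer is one of the kernel-size-2 post_* convolutions defined above):

import torch
import torch.nn.functional as F

def feature_transform(f, layer, residual=False):
    f0 = f.permute(0, 2, 1)      # (B, C, T) layout for Conv1d
    f = F.pad(f0, [1, 0])        # causal pad so the k=2 conv preserves T
    y = layer(f) + f0 if residual else layer(f)
    return torch.tanh(y).permute(0, 2, 1)

layer = torch.nn.Conv1d(64, 64, 2)
f = torch.randn(1, 10, 64)
assert feature_transform(f, layer, residual=True).shape == f.shape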
diff --git a/dnn/torch/osce/models/silk_feature_net_pl.py b/dnn/torch/osce/models/silk_feature_net_pl.py
index ae37951c..c766d0ab 100644
--- a/dnn/torch/osce/models/silk_feature_net_pl.py
+++ b/dnn/torch/osce/models/silk_feature_net_pl.py
@@ -26,36 +26,74 @@
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
+import sys
+sys.path.append('../dnntools')
+import numbers
import torch
from torch import nn
import torch.nn.functional as F
+from torch.nn.utils import weight_norm
from utils.complexity import _conv1d_flop_count
+from dnntools.quantization.softquant import soft_quant
+from dnntools.sparsification import mark_for_sparsification
+
class SilkFeatureNetPL(nn.Module):
""" feature net with partial lookahead """
def __init__(self,
feature_dim=47,
num_channels=256,
- hidden_feature_dim=64):
+ hidden_feature_dim=64,
+ softquant=False,
+                 sparsify=False,
+ sparsification_density=0.5,
+ apply_weight_norm=False):
super(SilkFeatureNetPL, self).__init__()
+ if isinstance(sparsification_density, numbers.Number):
+ sparsification_density = 4 * [sparsification_density]
+
self.feature_dim = feature_dim
self.num_channels = num_channels
self.hidden_feature_dim = hidden_feature_dim
- self.conv1 = nn.Conv1d(feature_dim, self.hidden_feature_dim, 1)
- self.conv2 = nn.Conv1d(4 * self.hidden_feature_dim, num_channels, 2)
- self.tconv = nn.ConvTranspose1d(num_channels, num_channels, 4, 4)
+ norm = weight_norm if apply_weight_norm else lambda x, name=None: x
+
+ self.conv1 = norm(nn.Conv1d(feature_dim, self.hidden_feature_dim, 1))
+ self.conv2 = norm(nn.Conv1d(4 * self.hidden_feature_dim, num_channels, 2))
+ self.tconv = norm(nn.ConvTranspose1d(num_channels, num_channels, 4, 4))
+        self.gru = norm(norm(nn.GRU(num_channels, num_channels, batch_first=True), name='weight_hh_l0'), name='weight_ih_l0')
+
+ if softquant:
+ self.conv2 = soft_quant(self.conv2)
+            self.tconv = soft_quant(self.tconv)
+ self.gru = soft_quant(self.gru, names=['weight_hh_l0', 'weight_ih_l0'])
+
+
+ if sparsify:
+ mark_for_sparsification(self.conv2, (sparsification_density[0], [8, 4]))
+            mark_for_sparsification(self.tconv, (sparsification_density[1], [8, 4]))
+ mark_for_sparsification(
+ self.gru,
+ {
+ 'W_ir' : (sparsification_density[2], [8, 4], False),
+ 'W_iz' : (sparsification_density[2], [8, 4], False),
+ 'W_in' : (sparsification_density[2], [8, 4], False),
+ 'W_hr' : (sparsification_density[3], [8, 4], True),
+ 'W_hz' : (sparsification_density[3], [8, 4], True),
+ 'W_hn' : (sparsification_density[3], [8, 4], True),
+ }
+ )
- self.gru = nn.GRU(num_channels, num_channels, batch_first=True)
def flop_count(self, rate=200):
count = 0
         for conv in self.conv1, self.conv2, self.tconv:
count += _conv1d_flop_count(conv, rate)
count += 2 * (3 * self.gru.input_size * self.gru.hidden_size + 3 * self.gru.hidden_size * self.gru.hidden_size) * rate
@@ -82,7 +120,7 @@ class SilkFeatureNetPL(nn.Module):
c = torch.tanh(self.conv2(F.pad(c, [1, 0])))
# upsampling
- c = self.tconv(c)
+ c = torch.tanh(self.tconv(c))
c = c.permute(0, 2, 1)
c, _ = self.gru(c, state)
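
For reference, the W_ir/W_iz/W_in and W_hr/W_hz/W_hn keys in the sparsification dict above address the gate sub-matrices that PyTorch stacks inside the GRU weight tensors (gate order r, z, n per the torch.nn.GRU documentation):

import torch

gru = torch.nn.GRU(256, 256, batch_first=True)
H = gru.hidden_size
# weight_ih_l0 has shape (3H, input_size), weight_hh_l0 has shape (3H, H)
W_ir, W_iz, W_in = gru.weight_ih_l0.split(H, dim=0)
W_hr, W_hz, W_hn = gru.weight_hh_l0.split(H, dim=0)
print(W_ir.shape, W_hn.shape)    # torch.Size([256, 256]) for both here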
diff --git a/dnn/torch/osce/stndrd/evaluation/create_input_data.sh b/dnn/torch/osce/stndrd/evaluation/create_input_data.sh
new file mode 100644
index 00000000..54bacb88
--- /dev/null
+++ b/dnn/torch/osce/stndrd/evaluation/create_input_data.sh
@@ -0,0 +1,25 @@
+#!/bin/bash
+
+
+INPUT="dataset/LibriSpeech"
+OUTPUT="testdata"
+OPUSDEMO="/local/experiments/ietf_enhancement_studies/bin/opus_demo_patched"
+BITRATES=( 6000 7500 ) # 9000 12000 15000 18000 24000 32000 )
+
+
+mkdir -p $OUTPUT
+
+for fn in $(find $INPUT -name "*.wav")
+do
+ name=$(basename ${fn%*.wav})
+ sox $fn -r 16000 -b 16 -e signed-integer ${OUTPUT}/tmp.raw
+ for br in ${BITRATES[@]}
+ do
+ folder=${OUTPUT}/"${name}_${br}.se"
+ echo "creating ${folder}..."
+ mkdir -p $folder
+ cp ${OUTPUT}/tmp.raw ${folder}/clean.s16
+ (cd ${folder} && $OPUSDEMO voip 16000 1 $br clean.s16 noisy.s16)
+ done
+ rm -f ${OUTPUT}/tmp.raw
+done
diff --git a/dnn/torch/osce/stndrd/evaluation/env.rc b/dnn/torch/osce/stndrd/evaluation/env.rc
new file mode 100644
index 00000000..f1266b6f
--- /dev/null
+++ b/dnn/torch/osce/stndrd/evaluation/env.rc
@@ -0,0 +1,7 @@
+#!/bin/bash
+
+export PYTHON=/home/ubuntu/opt/miniconda3/envs/torch/bin/python
+export LACE="/local/experiments/ietf_enhancement_studies/checkpoints/lace_checkpoint.pth"
+export NOLACE="/local/experiments/ietf_enhancement_studies/checkpoints/nolace_checkpoint.pth"
+export TESTMODEL="/local/experiments/ietf_enhancement_studies/opus/dnn/torch/osce/test_model.py"
+export OPUSDEMO="/local/experiments/ietf_enhancement_studies/bin/opus_demo_patched"
\ No newline at end of file
diff --git a/dnn/torch/osce/stndrd/evaluation/evaluate.py b/dnn/torch/osce/stndrd/evaluation/evaluate.py
new file mode 100644
index 00000000..54700dbe
--- /dev/null
+++ b/dnn/torch/osce/stndrd/evaluation/evaluate.py
@@ -0,0 +1,113 @@
+"""
+/* Copyright (c) 2023 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import os
+import argparse
+
+
+from scipy.io import wavfile
+from pesq import pesq
+import numpy as np
+from moc import compare
+from moc2 import compare as compare2
+#from warpq import compute_WAPRQ as warpq
+from lace_loss_metric import compare as laceloss_compare
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument('folder', type=str, help='folder with processed items')
+parser.add_argument('metric', type=str, choices=['pesq', 'moc', 'moc2', 'laceloss'], help='metric to be used for evaluation')
+
+
+def get_bitrates(folder):
+ with open(os.path.join(folder, 'bitrates.txt')) as f:
+ x = f.read()
+
+ bitrates = [int(y) for y in x.rstrip('\n').split()]
+
+ return bitrates
+
+def get_itemlist(folder):
+ with open(os.path.join(folder, 'items.txt')) as f:
+ lines = f.readlines()
+
+ items = [x.split()[0] for x in lines]
+
+ return items
+
+
+def process_item(folder, item, bitrate, metric):
+ fs, x_clean = wavfile.read(os.path.join(folder, 'clean', f"{item}_{bitrate}_clean.wav"))
+ fs, x_opus = wavfile.read(os.path.join(folder, 'opus', f"{item}_{bitrate}_opus.wav"))
+ fs, x_lace = wavfile.read(os.path.join(folder, 'lace', f"{item}_{bitrate}_lace.wav"))
+ fs, x_nolace = wavfile.read(os.path.join(folder, 'nolace', f"{item}_{bitrate}_nolace.wav"))
+
+ x_clean = x_clean.astype(np.float32) / 2**15
+ x_opus = x_opus.astype(np.float32) / 2**15
+ x_lace = x_lace.astype(np.float32) / 2**15
+ x_nolace = x_nolace.astype(np.float32) / 2**15
+
+ if metric == 'pesq':
+ result = [pesq(fs, x_clean, x_opus), pesq(fs, x_clean, x_lace), pesq(fs, x_clean, x_nolace)]
+    elif metric == 'moc':
+        result = [compare(x_clean, x_opus), compare(x_clean, x_lace), compare(x_clean, x_nolace)]
+    elif metric == 'moc2':
+        result = [compare2(x_clean, x_opus), compare2(x_clean, x_lace), compare2(x_clean, x_nolace)]
+ # elif metric == 'warpq':
+ # result = [warpq(x_clean, x_opus), warpq(x_clean, x_lace), warpq(x_clean, x_nolace)]
+ elif metric == 'laceloss':
+ result = [laceloss_compare(x_clean, x_opus), laceloss_compare(x_clean, x_lace), laceloss_compare(x_clean, x_nolace)]
+ else:
+ raise ValueError(f'unknown metric {metric}')
+
+ return result
+
+def process_bitrate(folder, items, bitrate, metric):
+ results = np.zeros((len(items), 3))
+
+ for i, item in enumerate(items):
+ results[i, :] = np.array(process_item(folder, item, bitrate, metric))
+
+ return results
+
+
+if __name__ == "__main__":
+ args = parser.parse_args()
+
+ items = get_itemlist(args.folder)
+ bitrates = get_bitrates(args.folder)
+
+ results = dict()
+ for br in bitrates:
+ print(f"processing bitrate {br}...")
+ results[br] = process_bitrate(args.folder, items, br, args.metric)
+
+ np.save(os.path.join(args.folder, f'results_{args.metric}.npy'), results)
+
+ print("Done.")
diff --git a/dnn/torch/osce/stndrd/evaluation/lace_loss_metric.py b/dnn/torch/osce/stndrd/evaluation/lace_loss_metric.py
new file mode 100644
index 00000000..b0790585
--- /dev/null
+++ b/dnn/torch/osce/stndrd/evaluation/lace_loss_metric.py
@@ -0,0 +1,330 @@
+"""
+/* Copyright (c) 2023 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+"""STFT-based Loss modules."""
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+import numpy as np
+import torchaudio
+
+
+def get_window(win_name, win_length, *args, **kwargs):
+ window_dict = {
+ 'bartlett_window' : torch.bartlett_window,
+ 'blackman_window' : torch.blackman_window,
+ 'hamming_window' : torch.hamming_window,
+ 'hann_window' : torch.hann_window,
+ 'kaiser_window' : torch.kaiser_window
+ }
+
+    if win_name not in window_dict:
+        raise ValueError(f"unknown window function {win_name}")
+
+ return window_dict[win_name](win_length, *args, **kwargs)
+
+
+def stft(x, fft_size, hop_size, win_length, window):
+ """Perform STFT and convert to magnitude spectrogram.
+ Args:
+ x (Tensor): Input signal tensor (B, T).
+ fft_size (int): FFT size.
+ hop_size (int): Hop size.
+ win_length (int): Window length.
+ window (str): Window function type.
+ Returns:
+        Tensor: Magnitude spectrogram (B, fft_size // 2 + 1, #frames).
+ """
+
+ win = get_window(window, win_length).to(x.device)
+ x_stft = torch.stft(x, fft_size, hop_size, win_length, win, return_complex=True)
+
+
+ return torch.clamp(torch.abs(x_stft), min=1e-7)
+
+def spectral_convergence_loss(Y_true, Y_pred):
+ dims=list(range(1, len(Y_pred.shape)))
+ return torch.mean(torch.norm(torch.abs(Y_true) - torch.abs(Y_pred), p="fro", dim=dims) / (torch.norm(Y_pred, p="fro", dim=dims) + 1e-6))
+
+
+def log_magnitude_loss(Y_true, Y_pred):
+ Y_true_log_abs = torch.log(torch.abs(Y_true) + 1e-15)
+ Y_pred_log_abs = torch.log(torch.abs(Y_pred) + 1e-15)
+
+ return torch.mean(torch.abs(Y_true_log_abs - Y_pred_log_abs))
+
+def spectral_xcorr_loss(Y_true, Y_pred):
+ Y_true = Y_true.abs()
+ Y_pred = Y_pred.abs()
+ dims=list(range(1, len(Y_pred.shape)))
+ xcorr = torch.sum(Y_true * Y_pred, dim=dims) / torch.sqrt(torch.sum(Y_true ** 2, dim=dims) * torch.sum(Y_pred ** 2, dim=dims) + 1e-9)
+
+ return 1 - xcorr.mean()
+
+
+
+class MRLogMelLoss(nn.Module):
+ def __init__(self,
+ fft_sizes=[512, 256, 128, 64],
+ overlap=0.5,
+ fs=16000,
+ n_mels=18
+ ):
+
+ self.fft_sizes = fft_sizes
+ self.overlap = overlap
+ self.fs = fs
+ self.n_mels = n_mels
+
+ super().__init__()
+
+ self.mel_specs = []
+ for fft_size in fft_sizes:
+ hop_size = int(round(fft_size * (1 - self.overlap)))
+
+ n_mels = self.n_mels
+ if fft_size < 128:
+ n_mels //= 2
+
+ self.mel_specs.append(torchaudio.transforms.MelSpectrogram(fs, fft_size, hop_length=hop_size, n_mels=n_mels))
+
+ for i, mel_spec in enumerate(self.mel_specs):
+ self.add_module(f'mel_spec_{i+1}', mel_spec)
+
+ def forward(self, y_true, y_pred):
+
+ loss = torch.zeros(1, device=y_true.device)
+
+ for mel_spec in self.mel_specs:
+ Y_true = mel_spec(y_true)
+ Y_pred = mel_spec(y_pred)
+ loss = loss + log_magnitude_loss(Y_true, Y_pred)
+
+ loss = loss / len(self.mel_specs)
+
+ return loss
+
+def create_weight_matrix(num_bins, bins_per_band=10):
+ m = torch.zeros((num_bins, num_bins), dtype=torch.float32)
+
+ r0 = bins_per_band // 2
+ r1 = bins_per_band - r0
+
+ for i in range(num_bins):
+ i0 = max(i - r0, 0)
+ j0 = min(i + r1, num_bins)
+
+ m[i, i0: j0] += 1
+
+ if i < r0:
+ m[i, :r0 - i] += 1
+
+ if i > num_bins - r1:
+ m[i, num_bins - r1 - i:] += 1
+
+ return m / bins_per_band
+
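+# Note on the weighting below: avg_logY and avg_Y are per-band averages of the
+# log magnitude and the magnitude, so sfm = exp(avg_logY) / avg_Y is a spectral
+# flatness measure per band (close to 1 for noise-like bands, close to 0 for
+# tonal ones); weight = relu(1 - sfm) ** 0.5 therefore emphasizes tonal bands.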
+def weighted_spectral_convergence(Y_true, Y_pred, w):
+
+ # calculate sfm based weights
+ logY = torch.log(torch.abs(Y_true) + 1e-9)
+ Y = torch.abs(Y_true)
+
+ avg_logY = torch.matmul(logY.transpose(1, 2), w)
+ avg_Y = torch.matmul(Y.transpose(1, 2), w)
+
+ sfm = torch.exp(avg_logY) / (avg_Y + 1e-9)
+
+ weight = (torch.relu(1 - sfm) ** .5).transpose(1, 2)
+
+ loss = torch.mean(
+ torch.mean(weight * torch.abs(torch.abs(Y_true) - torch.abs(Y_pred)), dim=[1, 2])
+ / (torch.mean( weight * torch.abs(Y_true), dim=[1, 2]) + 1e-9)
+ )
+
+ return loss
+
+def gen_filterbank(N, Fs=16000):
+ in_freq = (np.arange(N+1, dtype='float32')/N*Fs/2)[None,:]
+ out_freq = (np.arange(N, dtype='float32')/N*Fs/2)[:,None]
+ #ERB from B.C.J Moore, An Introduction to the Psychology of Hearing, 5th Ed., page 73.
+ ERB_N = 24.7 + .108*in_freq
+ delta = np.abs(in_freq-out_freq)/ERB_N
+ center = (delta<.5).astype('float32')
+ R = -12*center*delta**2 + (1-center)*(3-12*delta)
+ RE = 10.**(R/10.)
+ norm = np.sum(RE, axis=1)
+ RE = RE/norm[:, np.newaxis]
+ return torch.from_numpy(RE)
+
+def smooth_log_mag(Y_true, Y_pred, filterbank):
+ Y_true_smooth = torch.matmul(filterbank, torch.abs(Y_true))
+ Y_pred_smooth = torch.matmul(filterbank, torch.abs(Y_pred))
+
+ loss = torch.abs(
+ torch.log(Y_true_smooth + 1e-9) - torch.log(Y_pred_smooth + 1e-9)
+ )
+
+ loss = loss.mean()
+
+ return loss
+
+class MRSTFTLoss(nn.Module):
+ def __init__(self,
+ fft_sizes=[2048, 1024, 512, 256, 128, 64],
+ overlap=0.5,
+ window='hann_window',
+ fs=16000,
+ log_mag_weight=0,
+ sc_weight=0,
+ wsc_weight=0,
+ smooth_log_mag_weight=2,
+ sxcorr_weight=1):
+ super().__init__()
+
+ self.fft_sizes = fft_sizes
+ self.overlap = overlap
+ self.window = window
+ self.log_mag_weight = log_mag_weight
+ self.sc_weight = sc_weight
+ self.wsc_weight = wsc_weight
+ self.smooth_log_mag_weight = smooth_log_mag_weight
+ self.sxcorr_weight = sxcorr_weight
+ self.fs = fs
+
+ # weights for SFM weighted spectral convergence loss
+ self.wsc_weights = torch.nn.ParameterDict()
+ for fft_size in fft_sizes:
+ width = min(11, int(1000 * fft_size / self.fs + .5))
+ width += width % 2
+ self.wsc_weights[str(fft_size)] = torch.nn.Parameter(
+ create_weight_matrix(fft_size // 2 + 1, width),
+ requires_grad=False
+ )
+
+ # filterbanks for smooth log magnitude loss
+ self.filterbanks = torch.nn.ParameterDict()
+ for fft_size in fft_sizes:
+ self.filterbanks[str(fft_size)] = torch.nn.Parameter(
+ gen_filterbank(fft_size//2),
+ requires_grad=False
+ )
+
+
+ def __call__(self, y_true, y_pred):
+
+
+ lm_loss = torch.zeros(1, device=y_true.device)
+ sc_loss = torch.zeros(1, device=y_true.device)
+ wsc_loss = torch.zeros(1, device=y_true.device)
+ slm_loss = torch.zeros(1, device=y_true.device)
+ sxcorr_loss = torch.zeros(1, device=y_true.device)
+
+ for fft_size in self.fft_sizes:
+ hop_size = int(round(fft_size * (1 - self.overlap)))
+ win_size = fft_size
+
+ Y_true = stft(y_true, fft_size, hop_size, win_size, self.window)
+ Y_pred = stft(y_pred, fft_size, hop_size, win_size, self.window)
+
+ if self.log_mag_weight > 0:
+ lm_loss = lm_loss + log_magnitude_loss(Y_true, Y_pred)
+
+ if self.sc_weight > 0:
+ sc_loss = sc_loss + spectral_convergence_loss(Y_true, Y_pred)
+
+ if self.wsc_weight > 0:
+ wsc_loss = wsc_loss + weighted_spectral_convergence(Y_true, Y_pred, self.wsc_weights[str(fft_size)])
+
+ if self.smooth_log_mag_weight > 0:
+ slm_loss = slm_loss + smooth_log_mag(Y_true, Y_pred, self.filterbanks[str(fft_size)])
+
+ if self.sxcorr_weight > 0:
+ sxcorr_loss = sxcorr_loss + spectral_xcorr_loss(Y_true, Y_pred)
+
+
+ total_loss = (self.log_mag_weight * lm_loss + self.sc_weight * sc_loss
+ + self.wsc_weight * wsc_loss + self.smooth_log_mag_weight * slm_loss
+ + self.sxcorr_weight * sxcorr_loss) / len(self.fft_sizes)
+
+ return total_loss
+
+
+def td_l2_norm(y_true, y_pred):
+ dims = list(range(1, len(y_true.shape)))
+
+ loss = torch.mean((y_true - y_pred) ** 2, dim=dims) / (torch.mean(y_pred ** 2, dim=dims) ** .5 + 1e-6)
+
+ return loss.mean()
+
+
+class LaceLoss(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+
+ self.stftloss = MRSTFTLoss(log_mag_weight=0, sc_weight=0, wsc_weight=0, smooth_log_mag_weight=2, sxcorr_weight=1)
+
+
+ def forward(self, x, y):
+ specloss = self.stftloss(x, y)
+ phaseloss = td_l2_norm(x, y)
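+        # 13 = 2 (smooth log mag) + 1 (spectral xcorr) + 10 (time-domain term),
+        # i.e. the sum of the active loss weights, so total_loss stays normalized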
+ total_loss = (specloss + 10 * phaseloss) / 13
+
+ return total_loss
+
+ def compare(self, x_ref, x_deg):
+ # trim items to same size
+ n = min(len(x_ref), len(x_deg))
+ x_ref = x_ref[:n].copy()
+ x_deg = x_deg[:n].copy()
+
+ # pre-emphasis
+ x_ref[1:] -= 0.85 * x_ref[:-1]
+ x_deg[1:] -= 0.85 * x_deg[:-1]
+
+ device = next(iter(self.parameters())).device
+
+ x = torch.from_numpy(x_ref).to(device)
+ y = torch.from_numpy(x_deg).to(device)
+
+ with torch.no_grad():
+ dist = 10 * self.forward(x, y)
+
+ return dist.cpu().numpy().item()
+
+
+lace_loss = LaceLoss()
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+lace_loss.to(device)
+
+def compare(x, y):
+
+ return lace_loss.compare(x, y)
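
A hedged usage sketch for this module (file names are placeholders; the scaling mirrors evaluate.py above):

import numpy as np
from scipy.io import wavfile
from lace_loss_metric import compare

fs, x_ref = wavfile.read('clean.wav')
fs, x_deg = wavfile.read('noisy.wav')

# 16 kHz int16 input, scaled to [-1, 1) as in evaluate.py
x_ref = x_ref.astype(np.float32) / 2**15
x_deg = x_deg.astype(np.float32) / 2**15

print(compare(x_ref, x_deg))    # LACE-loss distance; lower is closer to the reference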
diff --git a/dnn/torch/osce/stndrd/evaluation/make_boxplots.py b/dnn/torch/osce/stndrd/evaluation/make_boxplots.py
new file mode 100644
index 00000000..f7ea778a
--- /dev/null
+++ b/dnn/torch/osce/stndrd/evaluation/make_boxplots.py
@@ -0,0 +1,116 @@
+"""
+/* Copyright (c) 2023 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import os
+import argparse
+
+import numpy as np
+import matplotlib.pyplot as plt
+from prettytable import PrettyTable
+from matplotlib.patches import Patch
+
+parser = argparse.ArgumentParser()
+parser.add_argument('folder', type=str, help='path to folder with pre-calculated metrics')
+parser.add_argument('--metric', choices=['pesq', 'moc', 'warpq', 'nomad', 'laceloss', 'all'], default='all', help='default: all')
+parser.add_argument('--output', type=str, default=None, help='alternative output folder, default: folder')
+
+def load_data(folder):
+ data = dict()
+
+ if os.path.isfile(os.path.join(folder, 'results_moc.npy')):
+ data['moc'] = np.load(os.path.join(folder, 'results_moc.npy'), allow_pickle=True).item()
+
+ if os.path.isfile(os.path.join(folder, 'results_moc2.npy')):
+ data['moc2'] = np.load(os.path.join(folder, 'results_moc2.npy'), allow_pickle=True).item()
+
+ if os.path.isfile(os.path.join(folder, 'results_pesq.npy')):
+ data['pesq'] = np.load(os.path.join(folder, 'results_pesq.npy'), allow_pickle=True).item()
+
+ if os.path.isfile(os.path.join(folder, 'results_warpq.npy')):
+ data['warpq'] = np.load(os.path.join(folder, 'results_warpq.npy'), allow_pickle=True).item()
+
+ if os.path.isfile(os.path.join(folder, 'results_nomad.npy')):
+ data['nomad'] = np.load(os.path.join(folder, 'results_nomad.npy'), allow_pickle=True).item()
+
+ if os.path.isfile(os.path.join(folder, 'results_laceloss.npy')):
+ data['laceloss'] = np.load(os.path.join(folder, 'results_laceloss.npy'), allow_pickle=True).item()
+
+ return data
+
+def plot_data(filename, data, title=None):
+ compare_dict = dict()
+ for br in data.keys():
+ compare_dict[f'Opus {br/1000:.1f} kb/s'] = data[br][:, 0]
+ compare_dict[f'LACE {br/1000:.1f} kb/s'] = data[br][:, 1]
+ compare_dict[f'NoLACE {br/1000:.1f} kb/s'] = data[br][:, 2]
+
+ plt.rcParams.update({
+ "text.usetex": True,
+ "font.family": "Helvetica",
+ "font.size": 32
+ })
+
+ black = '#000000'
+ red = '#ff5745'
+ blue = '#007dbc'
+ colors = [black, red, blue]
+ legend_elements = [Patch(facecolor=colors[0], label='Opus SILK'),
+ Patch(facecolor=colors[1], label='LACE'),
+ Patch(facecolor=colors[2], label='NoLACE')]
+
+ fig, ax = plt.subplots()
+ fig.set_size_inches(40, 20)
+ bplot = ax.boxplot(compare_dict.values(), showfliers=False, notch=True, patch_artist=True)
+
+ for i, patch in enumerate(bplot['boxes']):
+ patch.set_facecolor(colors[i%3])
+
+ ax.set_xticklabels(compare_dict.keys(), rotation=290)
+
+ if title is not None:
+ ax.set_title(title)
+
+ ax.legend(handles=legend_elements)
+
+ fig.savefig(filename, bbox_inches='tight')
+
+if __name__ == "__main__":
+ args = parser.parse_args()
+ data = load_data(args.folder)
+
+
+ metrics = list(data.keys()) if args.metric == 'all' else [args.metric]
+ folder = args.folder if args.output is None else args.output
+ os.makedirs(folder, exist_ok=True)
+
+ for metric in metrics:
+ print(f"Plotting data for {metric} metric...")
+ plot_data(os.path.join(folder, f"boxplot_{metric}.png"), data[metric], title=metric.upper())
+
+ print("Done.")
\ No newline at end of file
diff --git a/dnn/torch/osce/stndrd/evaluation/make_boxplots_moctest.py b/dnn/torch/osce/stndrd/evaluation/make_boxplots_moctest.py
new file mode 100644
index 00000000..ca65aba9
--- /dev/null
+++ b/dnn/torch/osce/stndrd/evaluation/make_boxplots_moctest.py
@@ -0,0 +1,109 @@
+"""
+/* Copyright (c) 2023 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import os
+import argparse
+
+import numpy as np
+import matplotlib.pyplot as plt
+from prettytable import PrettyTable
+from matplotlib.patches import Patch
+
+parser = argparse.ArgumentParser()
+parser.add_argument('folder', type=str, help='path to folder with pre-calculated metrics')
+parser.add_argument('--metric', choices=['pesq', 'moc', 'warpq', 'nomad', 'laceloss', 'all'], default='all', help='default: all')
+parser.add_argument('--output', type=str, default=None, help='alternative output folder, default: folder')
+
+def load_data(folder):
+ data = dict()
+
+ if os.path.isfile(os.path.join(folder, 'results_moc.npy')):
+ data['moc'] = np.load(os.path.join(folder, 'results_moc.npy'), allow_pickle=True).item()
+
+ if os.path.isfile(os.path.join(folder, 'results_pesq.npy')):
+ data['pesq'] = np.load(os.path.join(folder, 'results_pesq.npy'), allow_pickle=True).item()
+
+ if os.path.isfile(os.path.join(folder, 'results_warpq.npy')):
+ data['warpq'] = np.load(os.path.join(folder, 'results_warpq.npy'), allow_pickle=True).item()
+
+ if os.path.isfile(os.path.join(folder, 'results_nomad.npy')):
+ data['nomad'] = np.load(os.path.join(folder, 'results_nomad.npy'), allow_pickle=True).item()
+
+ if os.path.isfile(os.path.join(folder, 'results_laceloss.npy')):
+ data['laceloss'] = np.load(os.path.join(folder, 'results_laceloss.npy'), allow_pickle=True).item()
+
+ return data
+
+def plot_data(filename, data, title=None):
+ compare_dict = dict()
+ for br in data.keys():
+ compare_dict[f'Opus {br/1000:.1f} kb/s'] = data[br][:, 0]
+ compare_dict[f'LACE (MOC only) {br/1000:.1f} kb/s'] = data[br][:, 1]
+ compare_dict[f'LACE (MOC + TD) {br/1000:.1f} kb/s'] = data[br][:, 2]
+
+ plt.rcParams.update({
+ "text.usetex": True,
+ "font.family": "Helvetica",
+ "font.size": 32
+ })
+ colors = ['pink', 'lightblue', 'lightgreen']
+ legend_elements = [Patch(facecolor=colors[0], label='Opus SILK'),
+ Patch(facecolor=colors[1], label='MOC loss only'),
+ Patch(facecolor=colors[2], label='MOC + TD loss')]
+
+ fig, ax = plt.subplots()
+ fig.set_size_inches(40, 20)
+ bplot = ax.boxplot(compare_dict.values(), showfliers=False, notch=True, patch_artist=True)
+
+ for i, patch in enumerate(bplot['boxes']):
+ patch.set_facecolor(colors[i%3])
+
+ ax.set_xticklabels(compare_dict.keys(), rotation=290)
+
+ if title is not None:
+ ax.set_title(title)
+
+ ax.legend(handles=legend_elements)
+
+ fig.savefig(filename, bbox_inches='tight')
+
+if __name__ == "__main__":
+ args = parser.parse_args()
+ data = load_data(args.folder)
+
+
+ metrics = list(data.keys()) if args.metric == 'all' else [args.metric]
+ folder = args.folder if args.output is None else args.output
+ os.makedirs(folder, exist_ok=True)
+
+ for metric in metrics:
+ print(f"Plotting data for {metric} metric...")
+ plot_data(os.path.join(folder, f"boxplot_{metric}.png"), data[metric], title=metric.upper())
+
+ print("Done.")
\ No newline at end of file
diff --git a/dnn/torch/osce/stndrd/evaluation/make_tables.py b/dnn/torch/osce/stndrd/evaluation/make_tables.py
new file mode 100644
index 00000000..56080127
--- /dev/null
+++ b/dnn/torch/osce/stndrd/evaluation/make_tables.py
@@ -0,0 +1,124 @@
+"""
+/* Copyright (c) 2023 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import os
+import argparse
+
+import numpy as np
+import matplotlib.pyplot as plt
+from prettytable import PrettyTable
+from matplotlib.patches import Patch
+
+parser = argparse.ArgumentParser()
+parser.add_argument('folder', type=str, help='path to folder with pre-calculated metrics')
+parser.add_argument('--metric', choices=['pesq', 'moc', 'warpq', 'nomad', 'laceloss', 'all'], default='all', help='default: all')
+parser.add_argument('--output', type=str, default=None, help='alternative output folder, default: folder')
+
+def load_data(folder):
+ data = dict()
+
+ if os.path.isfile(os.path.join(folder, 'results_moc.npy')):
+ data['moc'] = np.load(os.path.join(folder, 'results_moc.npy'), allow_pickle=True).item()
+
+ if os.path.isfile(os.path.join(folder, 'results_moc2.npy')):
+ data['moc2'] = np.load(os.path.join(folder, 'results_moc2.npy'), allow_pickle=True).item()
+
+ if os.path.isfile(os.path.join(folder, 'results_pesq.npy')):
+ data['pesq'] = np.load(os.path.join(folder, 'results_pesq.npy'), allow_pickle=True).item()
+
+ if os.path.isfile(os.path.join(folder, 'results_warpq.npy')):
+ data['warpq'] = np.load(os.path.join(folder, 'results_warpq.npy'), allow_pickle=True).item()
+
+ if os.path.isfile(os.path.join(folder, 'results_nomad.npy')):
+ data['nomad'] = np.load(os.path.join(folder, 'results_nomad.npy'), allow_pickle=True).item()
+
+ if os.path.isfile(os.path.join(folder, 'results_laceloss.npy')):
+ data['laceloss'] = np.load(os.path.join(folder, 'results_laceloss.npy'), allow_pickle=True).item()
+
+ return data
+
+def make_table(filename, data, title=None):
+
+ # mean values
+ tbl = PrettyTable()
+ tbl.field_names = ['bitrate (bps)', 'Opus', 'LACE', 'NoLACE']
+ for br in data.keys():
+ opus = data[br][:, 0]
+ lace = data[br][:, 1]
+ nolace = data[br][:, 2]
+ tbl.add_row([br, f"{float(opus.mean()):.3f} ({float(opus.std()):.2f})", f"{float(lace.mean()):.3f} ({float(lace.std()):.2f})", f"{float(nolace.mean()):.3f} ({float(nolace.std()):.2f})"])
+
+ with open(filename + ".txt", "w") as f:
+ f.write(str(tbl))
+
+ with open(filename + ".html", "w") as f:
+ f.write(tbl.get_html_string())
+
+ with open(filename + ".csv", "w") as f:
+ f.write(tbl.get_csv_string())
+
+ print(tbl)
+
+
+def make_diff_table(filename, data, title=None):
+
+ # mean values
+ tbl = PrettyTable()
+ tbl.field_names = ['bitrate (bps)', 'LACE - Opus', 'NoLACE - Opus']
+ for br in data.keys():
+ opus = data[br][:, 0]
+ lace = data[br][:, 1] - opus
+ nolace = data[br][:, 2] - opus
+ tbl.add_row([br, f"{float(lace.mean()):.3f} ({float(lace.std()):.2f})", f"{float(nolace.mean()):.3f} ({float(nolace.std()):.2f})"])
+
+ with open(filename + ".txt", "w") as f:
+ f.write(str(tbl))
+
+ with open(filename + ".html", "w") as f:
+ f.write(tbl.get_html_string())
+
+ with open(filename + ".csv", "w") as f:
+ f.write(tbl.get_csv_string())
+
+ print(tbl)
+
+if __name__ == "__main__":
+ args = parser.parse_args()
+ data = load_data(args.folder)
+
+ metrics = list(data.keys()) if args.metric == 'all' else [args.metric]
+ folder = args.folder if args.output is None else args.output
+ os.makedirs(folder, exist_ok=True)
+
+ for metric in metrics:
+ print(f"Plotting data for {metric} metric...")
+ make_table(os.path.join(folder, f"table_{metric}"), data[metric])
+ make_diff_table(os.path.join(folder, f"table_diff_{metric}"), data[metric])
+
+ print("Done.")
\ No newline at end of file
diff --git a/dnn/torch/osce/stndrd/evaluation/make_tables_moctest.py b/dnn/torch/osce/stndrd/evaluation/make_tables_moctest.py
new file mode 100644
index 00000000..37718068
--- /dev/null
+++ b/dnn/torch/osce/stndrd/evaluation/make_tables_moctest.py
@@ -0,0 +1,121 @@
+"""
+/* Copyright (c) 2023 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import os
+import argparse
+
+import numpy as np
+import matplotlib.pyplot as plt
+from prettytable import PrettyTable
+from matplotlib.patches import Patch
+
+parser = argparse.ArgumentParser()
+parser.add_argument('folder', type=str, help='path to folder with pre-calculated metrics')
+parser.add_argument('--metric', choices=['pesq', 'moc', 'warpq', 'nomad', 'laceloss', 'all'], default='all', help='default: all')
+parser.add_argument('--output', type=str, default=None, help='alternative output folder, default: folder')
+
+def load_data(folder):
+ data = dict()
+
+    for metric in ['moc', 'pesq', 'warpq', 'nomad', 'laceloss']:
+        results_file = os.path.join(folder, f'results_{metric}.npy')
+        if os.path.isfile(results_file):
+            data[metric] = np.load(results_file, allow_pickle=True).item()
+
+ return data
+
+def make_table(filename, data, title=None):
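+    # data[br]: one row per test item, columns [Opus, LACE, NoLACE]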
+
+ # mean values
+ tbl = PrettyTable()
+ tbl.field_names = ['bitrate (bps)', 'Opus', 'LACE', 'NoLACE']
+ for br in data.keys():
+ opus = data[br][:, 0]
+ lace = data[br][:, 1]
+ nolace = data[br][:, 2]
+ tbl.add_row([br, f"{float(opus.mean()):.3f} ({float(opus.std()):.2f})", f"{float(lace.mean()):.3f} ({float(lace.std()):.2f})", f"{float(nolace.mean()):.3f} ({float(nolace.std()):.2f})"])
+
+ with open(filename + ".txt", "w") as f:
+ f.write(str(tbl))
+
+ with open(filename + ".html", "w") as f:
+ f.write(tbl.get_html_string())
+
+ with open(filename + ".csv", "w") as f:
+ f.write(tbl.get_csv_string())
+
+ print(tbl)
+
+
+def make_diff_table(filename, data, title=None):
+
+ # mean values
+ tbl = PrettyTable()
+ tbl.field_names = ['bitrate (bps)', 'LACE - Opus', 'NoLACE - Opus']
+ for br in data.keys():
+ opus = data[br][:, 0]
+ lace = data[br][:, 1] - opus
+ nolace = data[br][:, 2] - opus
+ tbl.add_row([br, f"{float(lace.mean()):.3f} ({float(lace.std()):.2f})", f"{float(nolace.mean()):.3f} ({float(nolace.std()):.2f})"])
+
+ with open(filename + ".txt", "w") as f:
+ f.write(str(tbl))
+
+ with open(filename + ".html", "w") as f:
+ f.write(tbl.get_html_string())
+
+ with open(filename + ".csv", "w") as f:
+ f.write(tbl.get_csv_string())
+
+ print(tbl)
+
+if __name__ == "__main__":
+ args = parser.parse_args()
+ data = load_data(args.folder)
+
+ metrics = list(data.keys()) if args.metric == 'all' else [args.metric]
+ folder = args.folder if args.output is None else args.output
+ os.makedirs(folder, exist_ok=True)
+
+ for metric in metrics:
+ print(f"Plotting data for {metric} metric...")
+ make_table(os.path.join(folder, f"table_{metric}"), data[metric])
+ make_diff_table(os.path.join(folder, f"table_diff_{metric}"), data[metric])
+
+ print("Done.")
\ No newline at end of file
diff --git a/dnn/torch/osce/stndrd/evaluation/moc.py b/dnn/torch/osce/stndrd/evaluation/moc.py
new file mode 100644
index 00000000..bf004de9
--- /dev/null
+++ b/dnn/torch/osce/stndrd/evaluation/moc.py
@@ -0,0 +1,182 @@
+"""
+/* Copyright (c) 2023 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import numpy as np
+import scipy.signal
+
+def compute_vad_mask(x, fs, stop_db=-70):
+
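+    # 20 ms frames: ceil(fs / 50) samples; trim x to a whole number of frames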
+ frame_length = (fs + 49) // 50
+ x = x[: frame_length * (len(x) // frame_length)]
+
+ frames = x.reshape(-1, frame_length)
+ frame_energy = np.sum(frames ** 2, axis=1)
+ frame_energy_smooth = np.convolve(frame_energy, np.ones(5) / 5, mode='same')
+
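+    # threshold relative to the loudest frame; frames with smoothed energy below it count as inactive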
+ max_threshold = frame_energy.max() * 10 ** (stop_db/20)
+ vactive = np.ones_like(frames)
+ vactive[frame_energy_smooth < max_threshold, :] = 0
+ vactive = vactive.reshape(-1)
+
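+    # smooth the on/off transitions with a normalized sine window of one frame length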
+    smoothing_window = np.sin(np.arange(frame_length) * np.pi / (frame_length - 1))
+    smoothing_window = smoothing_window / smoothing_window.sum()
+
+    mask = np.convolve(vactive, smoothing_window, mode='same')
+
+ return x, mask
+
+def convert_mask(mask, num_frames, frame_size=160, hop_size=40):
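+    # average the per-sample VAD mask over each analysis window to get one value per spectral frame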
+ num_samples = frame_size + (num_frames - 1) * hop_size
+ if len(mask) < num_samples:
+ mask = np.concatenate((mask, np.zeros(num_samples - len(mask))), dtype=mask.dtype)
+ else:
+ mask = mask[:num_samples]
+
+ new_mask = np.array([np.mean(mask[i*hop_size : i*hop_size + frame_size]) for i in range(num_frames)])
+
+ return new_mask
+
+def power_spectrum(x, window_size=160, hop_size=40, window='hamming'):
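+    # single-sided power spectra of overlapping, windowed frames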
+ num_spectra = (len(x) - window_size - hop_size) // hop_size
+ window = scipy.signal.get_window(window, window_size)
+ N = window_size // 2
+
+ frames = np.concatenate([x[np.newaxis, i * hop_size : i * hop_size + window_size] for i in range(num_spectra)]) * window
+ psd = np.abs(np.fft.fft(frames, axis=1)[:, :N + 1]) ** 2
+
+ return psd
+
+
+def frequency_mask(num_bands, up_factor, down_factor):
+
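+    # band i receives masking from lower bands with geometric decay up_factor
+    # and from higher bands with decay down_factor; the matrix product combines both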
+ up_mask = np.zeros((num_bands, num_bands))
+ down_mask = np.zeros((num_bands, num_bands))
+
+ for i in range(num_bands):
+ up_mask[i, : i + 1] = up_factor ** np.arange(i, -1, -1)
+ down_mask[i, i :] = down_factor ** np.arange(num_bands - i)
+
+ return down_mask @ up_mask
+
+
+def rect_fb(band_limits, num_bins=None):
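+    # rectangular filterbank: band i covers FFT bins [band_limits[i], band_limits[i+1])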
+ num_bands = len(band_limits) - 1
+ if num_bins is None:
+ num_bins = band_limits[-1]
+
+ fb = np.zeros((num_bands, num_bins))
+ for i in range(num_bands):
+ fb[i, band_limits[i]:band_limits[i+1]] = 1
+
+ return fb
+
+
+def compare(x, y, apply_vad=False):
+ """ Modified version of opus_compare for 16 kHz mono signals
+
+ Args:
+ x (np.ndarray): reference input signal scaled to [-1, 1]
+ y (np.ndarray): test signal scaled to [-1, 1]
+
+ Returns:
+ float: perceptually weighted error
+ """
+ # filter bank: bark scale with minimum-2-bin bands and cutoff at 7.5 kHz
+ band_limits = [0, 2, 4, 6, 7, 9, 11, 13, 15, 18, 22, 26, 31, 36, 43, 51, 60, 75]
+ num_bands = len(band_limits) - 1
+ fb = rect_fb(band_limits, num_bins=81)
+
+ # trim samples to same size
+ num_samples = min(len(x), len(y))
+ x = x[:num_samples] * 2**15
+ y = y[:num_samples] * 2**15
+
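+    # power spectra with a constant floor added to every bin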
+ psd_x = power_spectrum(x) + 100000
+ psd_y = power_spectrum(y) + 100000
+
+ num_frames = psd_x.shape[0]
+
+ # average band energies
+ be_x = (psd_x @ fb.T) / np.sum(fb, axis=1)
+
+    # frequency masking
+ f_mask = frequency_mask(num_bands, 0.1, 0.03)
+ mask_x = be_x @ f_mask.T
+
+ # temporal masking
+ for i in range(1, num_frames):
+ mask_x[i, :] += 0.5 * mask_x[i-1, :]
+
+ # apply mask
+ masked_psd_x = psd_x + 0.1 * (mask_x @ fb)
+ masked_psd_y = psd_y + 0.1 * (mask_x @ fb)
+
+ # 2-frame average
+ masked_psd_x = masked_psd_x[1:] + masked_psd_x[:-1]
+ masked_psd_y = masked_psd_y[1:] + masked_psd_y[:-1]
+
+ # distortion metric
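+    # squared log power ratio per bin, averaged within bands, then across bands per frame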
+    ratio = masked_psd_y / masked_psd_x
+    log_dist = np.log(ratio) ** 2
+    Eb = (log_dist @ fb.T) / np.sum(fb, axis=1)
+    Ef = np.mean(Eb, axis=1)
+
+ if apply_vad:
+ _, mask = compute_vad_mask(x, 16000)
+ mask = convert_mask(mask, Ef.shape[0])
+ else:
+ mask = np.ones_like(Ef)
+
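+    # pool the frame errors over voice-active frames via a cubed mean with a 1/6 root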
+ err = np.mean(np.abs(Ef[mask > 1e-6]) ** 3) ** (1/6)
+
+ return float(err)
+
+if __name__ == "__main__":
+ import argparse
+ from scipy.io import wavfile
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('ref', type=str, help='reference wav file')
+ parser.add_argument('deg', type=str, help='degraded wav file')
+ parser.add_argument('--apply-vad', action='store_true')
+ args = parser.parse_args()
+
+
+ fs1, x = wavfile.read(args.ref)
+ fs2, y = wavfile.read(args.deg)
+
+    if fs1 != 16000 or fs2 != 16000:
+        raise ValueError('error: encountered sampling frequency different from 16 kHz')
+
+ x = x.astype(np.float32) / 2**15
+ y = y.astype(np.float32) / 2**15
+
+ err = compare(x, y, apply_vad=args.apply_vad)
+
+ print(f"MOC: {err}")
diff --git a/dnn/torch/osce/stndrd/evaluation/moc2.py b/dnn/torch/osce/stndrd/evaluation/moc2.py
new file mode 100644
index 00000000..7e155f01
--- /dev/null
+++ b/dnn/torch/osce/stndrd/evaluation/moc2.py
@@ -0,0 +1,190 @@
+"""
+/* Copyright (c) 2023 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import numpy as np
+import scipy.signal
+
+def compute_vad_mask(x, fs, stop_db=-70):
+
+ frame_length = (fs + 49) // 50
+ x = x[: frame_length * (len(x) // frame_length)]
+
+ frames = x.reshape(-1, frame_length)
+ frame_energy = np.sum(frames ** 2, axis=1)
+ frame_energy_smooth = np.convolve(frame_energy, np.ones(5) / 5, mode='same')
+
+ max_threshold = frame_energy.max() * 10 ** (stop_db/20)
+ vactive = np.ones_like(frames)
+ vactive[frame_energy_smooth < max_threshold, :] = 0
+ vactive = vactive.reshape(-1)
+
+    smoothing_window = np.sin(np.arange(frame_length) * np.pi / (frame_length - 1))
+    smoothing_window = smoothing_window / smoothing_window.sum()
+
+    mask = np.convolve(vactive, smoothing_window, mode='same')
+
+ return x, mask
+
+def convert_mask(mask, num_frames, frame_size=160, hop_size=40):
+ num_samples = frame_size + (num_frames - 1) * hop_size
+ if len(mask) < num_samples:
+ mask = np.concatenate((mask, np.zeros(num_samples - len(mask))), dtype=mask.dtype)
+ else:
+ mask = mask[:num_samples]
+
+ new_mask = np.array([np.mean(mask[i*hop_size : i*hop_size + frame_size]) for i in range(num_frames)])
+
+ return new_mask
+
+def power_spectrum(x, window_size=160, hop_size=40, window='hamming'):
+ num_spectra = (len(x) - window_size - hop_size) // hop_size
+ window = scipy.signal.get_window(window, window_size)
+ N = window_size // 2
+
+ frames = np.concatenate([x[np.newaxis, i * hop_size : i * hop_size + window_size] for i in range(num_spectra)]) * window
+ psd = np.abs(np.fft.fft(frames, axis=1)[:, :N + 1]) ** 2
+
+ return psd
+
+
+def frequency_mask(num_bands, up_factor, down_factor):
+
+ up_mask = np.zeros((num_bands, num_bands))
+ down_mask = np.zeros((num_bands, num_bands))
+
+ for i in range(num_bands):
+ up_mask[i, : i + 1] = up_factor ** np.arange(i, -1, -1)
+ down_mask[i, i :] = down_factor ** np.arange(num_bands - i)
+
+ return down_mask @ up_mask
+
+
+def rect_fb(band_limits, num_bins=None):
+ num_bands = len(band_limits) - 1
+ if num_bins is None:
+ num_bins = band_limits[-1]
+
+ fb = np.zeros((num_bands, num_bins))
+ for i in range(num_bands):
+ fb[i, band_limits[i]:band_limits[i+1]] = 1
+
+ return fb
+
+
+def _compare(x, y, apply_vad=False, factor=1):
+ """ Modified version of opus_compare for 16 kHz mono signals
+
+ Args:
+ x (np.ndarray): reference input signal scaled to [-1, 1]
+ y (np.ndarray): test signal scaled to [-1, 1]
+
+ Returns:
+ float: perceptually weighted error
+ """
+ # filter bank: bark scale with minimum-2-bin bands and cutoff at 7.5 kHz
+ band_limits = [factor * b for b in [0, 2, 4, 6, 7, 9, 11, 13, 15, 18, 22, 26, 31, 36, 43, 51, 60, 75]]
+ window_size = factor * 160
+ hop_size = factor * 40
+ num_bins = window_size // 2 + 1
+ num_bands = len(band_limits) - 1
+ fb = rect_fb(band_limits, num_bins=num_bins)
+
+ # trim samples to same size
+ num_samples = min(len(x), len(y))
+ x = x[:num_samples].copy() * 2**15
+ y = y[:num_samples].copy() * 2**15
+
+ psd_x = power_spectrum(x, window_size=window_size, hop_size=hop_size) + 100000
+ psd_y = power_spectrum(y, window_size=window_size, hop_size=hop_size) + 100000
+
+ num_frames = psd_x.shape[0]
+
+ # average band energies
+ be_x = (psd_x @ fb.T) / np.sum(fb, axis=1)
+
+    # frequency masking
+ f_mask = frequency_mask(num_bands, 0.1, 0.03)
+ mask_x = be_x @ f_mask.T
+
+ # temporal masking
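+    # the per-frame leak 0.5**factor keeps the decay per unit time constant as the hop grows with factor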
+ for i in range(1, num_frames):
+ mask_x[i, :] += (0.5 ** factor) * mask_x[i-1, :]
+
+ # apply mask
+ masked_psd_x = psd_x + 0.1 * (mask_x @ fb)
+ masked_psd_y = psd_y + 0.1 * (mask_x @ fb)
+
+ # 2-frame average
+ masked_psd_x = masked_psd_x[1:] + masked_psd_x[:-1]
+ masked_psd_y = masked_psd_y[1:] + masked_psd_y[:-1]
+
+ # distortion metric
+    ratio = masked_psd_y / masked_psd_x
+    #log_dist = ratio - np.log(ratio) - 1
+    log_dist = np.log(ratio) ** 2
+    Eb = (log_dist @ fb.T) / np.sum(fb, axis=1)
+    Ef = np.mean(Eb, axis=1)
+
+ if apply_vad:
+ _, mask = compute_vad_mask(x, 16000)
+ mask = convert_mask(mask, Ef.shape[0])
+ else:
+ mask = np.ones_like(Ef)
+
+ err = np.mean(np.abs(Ef[mask > 1e-6]) ** 3) ** (1/6)
+
+ return float(err)
+
+def compare(x, y, apply_vad=False):
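+    # L2 pooling over per-resolution scores; currently only the factor-1 resolution is evaluated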
+ err = np.linalg.norm([_compare(x, y, apply_vad=apply_vad, factor=1)], ord=2)
+ return err
+
+if __name__ == "__main__":
+ import argparse
+ from scipy.io import wavfile
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('ref', type=str, help='reference wav file')
+ parser.add_argument('deg', type=str, help='degraded wav file')
+ parser.add_argument('--apply-vad', action='store_true')
+ args = parser.parse_args()
+
+
+ fs1, x = wavfile.read(args.ref)
+ fs2, y = wavfile.read(args.deg)
+
+    if fs1 != 16000 or fs2 != 16000:
+        raise ValueError('error: encountered sampling frequency different from 16 kHz')
+
+ x = x.astype(np.float32) / 2**15
+ y = y.astype(np.float32) / 2**15
+
+ err = compare(x, y, apply_vad=args.apply_vad)
+
+ print(f"MOC: {err}")
diff --git a/dnn/torch/osce/stndrd/evaluation/process_dataset.sh b/dnn/torch/osce/stndrd/evaluation/process_dataset.sh
new file mode 100755
index 00000000..a490da93
--- /dev/null
+++ b/dnn/torch/osce/stndrd/evaluation/process_dataset.sh
@@ -0,0 +1,98 @@
+#!/bin/bash
+
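+# Expects the following environment variables:
+#   PYTHON    - python executable
+#   TESTMODEL - path to test_model.py
+#   OPUSDEMO  - path to the patched opus_demo binary
+#   LACE      - path to the LACE checkpoint
+#   NOLACE    - path to the NoLACE checkpoint
+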
+if [ ! -f "$PYTHON" ]
+then
+ echo "PYTHON variable does not link to a file. Please point it to your python executable."
+ exit 1
+fi
+
+if [ ! -f "$TESTMODEL" ]
+then
+ echo "TESTMODEL variable does not link to a file. Please point it to your copy of test_model.py"
+ exit 1
+fi
+
+if [ ! -f "$OPUSDEMO" ]
+then
+ echo "OPUSDEMO variable does not link to a file. Please point it to your patched version of opus_demo."
+ exit 1
+fi
+
+if [ ! -f "$LACE" ]
+then
+ echo "LACE variable does not link to a file. Please point it to your copy of the LACE checkpoint."
+ exit 1
+fi
+
+if [ ! -f "$NOLACE" ]
+then
+ echo "LACE variable does not link to a file. Please point it to your copy of the NOLACE checkpoint."
+ exit 1
+fi
+
+case $# in
+ 2) INPUT=$1; OUTPUT=$2;;
+ *) echo "process_dataset.sh