mirror of
https://github.com/xiph/opus.git
synced 2025-06-02 08:37:43 +00:00
497 lines
17 KiB
Python
497 lines
17 KiB
Python
"""
|
|
/* Copyright (c) 2023 Amazon
|
|
Written by Jan Buethe */
|
|
/*
|
|
Redistribution and use in source and binary forms, with or without
|
|
modification, are permitted provided that the following conditions
|
|
are met:
|
|
|
|
- Redistributions of source code must retain the above copyright
|
|
notice, this list of conditions and the following disclaimer.
|
|
|
|
- Redistributions in binary form must reproduce the above copyright
|
|
notice, this list of conditions and the following disclaimer in the
|
|
documentation and/or other materials provided with the distribution.
|
|
|
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
|
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
|
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
|
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
|
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
|
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
|
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
|
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
|
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
|
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
|
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
*/
|
|
"""
|
|
|
|
from re import sub
|
|
import torch
|
|
from torch import nn
|
|
|
|
|
|
|
|
|
|
def get_subconditioner(method,
                       number_of_subsamples,
                       pcm_embedding_size,
                       state_size,
                       pcm_levels,
                       number_of_signals,
                       **kwargs):
    """ factory for subconditioner instances

    Parameters:
    -----------
    method : str
        one of 'additive', 'concatenative' or 'modulative'
    number_of_subsamples : int
        upsampling factor s
    pcm_embedding_size : int
        size of the pcm embeddings
    state_size : int
        size of the states to be subconditioned
    pcm_levels : int
        number of pcm quantization levels
    number_of_signals : int
        number of pcm signals used for subconditioning
    **kwargs
        forwarded to the selected subconditioner constructor

    Returns:
    --------
    Subconditioner
        instance of the requested subconditioner class

    Raises:
    -------
    ValueError
        if method is not a known subconditioning method
    """

    subconditioner_dict = {
        'additive'      : AdditiveSubconditioner,
        'concatenative' : ConcatenativeSubconditioner,
        'modulative'    : ModulativeSubconditioner
    }

    try:
        cls = subconditioner_dict[method]
    except KeyError:
        # fail with a descriptive error instead of a bare KeyError
        raise ValueError(f"unknown subconditioning method '{method}'; expected one of "
                         + ", ".join(subconditioner_dict)) from None

    return cls(number_of_subsamples, pcm_embedding_size, state_size,
               pcm_levels, number_of_signals, **kwargs)
|
|
|
|
|
|
class Subconditioner(nn.Module):
    """ upsampling by subconditioning

    Upsamples a sequence of states conditioning on pcm signals and
    optionally a feature vector. Subclasses must implement forward,
    single_step and get_output_dim.
    """

    def __init__(self):
        super(Subconditioner, self).__init__()

    def forward(self, states, signals, features=None):
        # NotImplementedError is the idiomatic marker for abstract methods;
        # it derives from RuntimeError, so callers catching Exception still work.
        raise NotImplementedError("Base class should not be called")

    def single_step(self, index, state, signals, features):
        raise NotImplementedError("Base class should not be called")

    def get_output_dim(self, index):
        raise NotImplementedError("Base class should not be called")
|
|
|
|
|
|
class AdditiveSubconditioner(Subconditioner):
    def __init__(self,
                 number_of_subsamples,
                 pcm_embedding_size,
                 state_size,
                 pcm_levels,
                 number_of_signals,
                 **kwargs):
        """ subconditioning by addition

        Embeds the pcm signals and adds the signal-summed embeddings onto
        the running state, so the embedding size must equal the state size.

        Parameters:
        -----------
        number_of_subsamples : int
            upsampling factor s
        pcm_embedding_size : int
            size of the pcm embeddings; must equal state_size
        state_size : int
            size of the states to be subconditioned
        pcm_levels : int
            number of pcm quantization levels (embedding table size)
        number_of_signals : int
            number of pcm signals per time step

        Raises:
        -------
        ValueError
            if pcm_embedding_size and state_size differ
        """

        super(AdditiveSubconditioner, self).__init__()

        self.number_of_subsamples = number_of_subsamples
        self.pcm_embedding_size = pcm_embedding_size
        self.state_size = state_size
        self.pcm_levels = pcm_levels
        self.number_of_signals = number_of_signals

        if self.pcm_embedding_size != self.state_size:
            # fixed duplicated word ("but but") in the original error message
            raise ValueError('For additive subconditioning state and embedding '
                + f'sizes must match but got {self.state_size} and {self.pcm_embedding_size}')

        # index 0 has no embedding: the first sub-state is the input state itself
        self.embeddings = [None]
        for i in range(1, self.number_of_subsamples):
            embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
            # register explicitly since self.embeddings is a plain list,
            # which nn.Module does not track
            self.add_module('pcm_embedding_' + str(i), embedding)
            self.embeddings.append(embedding)

    def forward(self, states, signals):
        """ creates list of subconditioned states

        Parameters:
        -----------
        states : torch.tensor
            states of shape (batch, seq_length // s, state_size)
        signals : torch.tensor
            signals of shape (batch, seq_length, number_of_signals)

        Returns:
        --------
        c_states : list of torch.tensor
            list of s subconditioned states
        """

        s = self.number_of_subsamples

        c_states = [states]
        new_states = states
        for i in range(1, self.number_of_subsamples):
            # signals[:, i::s] picks the i-th sub-sample of every time step
            embed = self.embeddings[i](signals[:, i::s])
            # reduce signal dimension
            embed = torch.sum(embed, dim=2)

            # accumulate: each sub-state adds its embedding on top of the previous one
            new_states = new_states + embed
            c_states.append(new_states)

        return c_states

    def single_step(self, index, state, signals):
        """ carry out single step for inference

        Parameters:
        -----------
        index : int
            position in subconditioning batch

        state : torch.tensor
            state to sub-condition

        signals : torch.tensor
            signals for subconditioning, all but the last dimensions
            must match those of state

        Returns:
        c_state : torch.tensor
            subconditioned state
        """

        if index == 0:
            # position 0 carries no conditioning
            c_state = state
        else:
            embed_signals = self.embeddings[index](signals)
            c = torch.sum(embed_signals, dim=-2)
            c_state = state + c

        return c_state

    def get_output_dim(self, index):
        # addition preserves the state size at every position
        return self.state_size

    def get_average_flops_per_step(self):
        # positions 1..s-1 each add number_of_signals embeddings of size
        # pcm_embedding_size; averaged over all s positions
        s = self.number_of_subsamples
        flops = (s - 1) / s * self.number_of_signals * self.pcm_embedding_size
        return flops
|
|
|
|
|
|
class ConcatenativeSubconditioner(Subconditioner):
    def __init__(self,
                 number_of_subsamples,
                 pcm_embedding_size,
                 state_size,
                 pcm_levels,
                 number_of_signals,
                 recurrent=True,
                 **kwargs):
        """ subconditioning by concatenation

        Parameters:
        -----------
        number_of_subsamples : int
            upsampling factor s
        pcm_embedding_size : int
            size of the pcm embeddings
        state_size : int
            size of the states to be subconditioned
        pcm_levels : int
            number of pcm quantization levels (embedding table size)
        number_of_signals : int
            number of pcm signals per time step
        recurrent : bool
            if True, embeddings accumulate onto the previous sub-state;
            if False, each sub-state concatenates onto the base state
        """

        super(ConcatenativeSubconditioner, self).__init__()

        self.number_of_subsamples = number_of_subsamples
        self.pcm_embedding_size = pcm_embedding_size
        self.state_size = state_size
        self.pcm_levels = pcm_levels
        self.number_of_signals = number_of_signals
        self.recurrent = recurrent

        self.embeddings = []
        start_index = 0
        if self.recurrent:
            # in recurrent mode position 0 is the raw state; no embedding needed
            start_index = 1
            self.embeddings.append(None)

        for i in range(start_index, self.number_of_subsamples):
            embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
            # register explicitly since self.embeddings is a plain list
            self.add_module('pcm_embedding_' + str(i), embedding)
            self.embeddings.append(embedding)

    def forward(self, states, signals):
        """ creates list of subconditioned states

        Parameters:
        -----------
        states : torch.tensor
            states of shape (batch, seq_length // s, state_size)
        signals : torch.tensor
            signals of shape (batch, seq_length, number_of_signals)

        Returns:
        --------
        c_states : list of torch.tensor
            list of s subconditioned states
        """
        s = self.number_of_subsamples

        if self.recurrent:
            c_states = [states]
            start = 1
        else:
            c_states = []
            start = 0

        new_states = states
        for i in range(start, self.number_of_subsamples):
            embed = self.embeddings[i](signals[:, i::s])
            # reduce signal dimension
            embed = torch.flatten(embed, -2)

            if self.recurrent:
                # grow the state: each step appends to the previous sub-state
                new_states = torch.cat((new_states, embed), dim=-1)
            else:
                # fixed width: each step appends to the base state
                new_states = torch.cat((states, embed), dim=-1)

            c_states.append(new_states)

        return c_states

    def single_step(self, index, state, signals):
        """ carry out single step for inference

        Parameters:
        -----------
        index : int
            position in subconditioning batch

        state : torch.tensor
            state to sub-condition

        signals : torch.tensor
            signals for subconditioning, all but the last dimensions
            must match those of state

        Returns:
        c_state : torch.tensor
            subconditioned state
        """

        # the original had a second, redundant `return c_state` after the
        # if/else; restructured to a single exit per path
        if index == 0 and self.recurrent:
            return state

        embed_signals = self.embeddings[index](signals)
        c = torch.flatten(embed_signals, -2)
        if not self.recurrent and index > 0:
            # overwrite previous conditioning vector
            return torch.cat((state[..., :self.state_size], c), dim=-1)
        return torch.cat((state, c), dim=-1)

    def get_average_flops_per_step(self):
        # concatenation itself is counted as free (embedding lookup, no MACs)
        return 0

    def get_output_dim(self, index):
        if self.recurrent:
            # width grows by one embedding block per position
            return self.state_size + index * self.pcm_embedding_size * self.number_of_signals
        else:
            return self.state_size + self.pcm_embedding_size * self.number_of_signals
|
|
|
|
class ModulativeSubconditioner(Subconditioner):
    def __init__(self,
                 number_of_subsamples,
                 pcm_embedding_size,
                 state_size,
                 pcm_levels,
                 number_of_signals,
                 state_recurrent=False,
                 **kwargs):
        """ subconditioning by modulation

        Derives per-position scale (alpha) and shift (beta) vectors from
        embedded pcm signals and applies them to the state:
        new_state = tanh((1 + alpha) * state + beta).

        Parameters:
        -----------
        number_of_subsamples : int
            upsampling factor s
        pcm_embedding_size : int
            size of the pcm embeddings
        state_size : int
            size of the states to be subconditioned
        pcm_levels : int
            number of pcm quantization levels (embedding table size)
        number_of_signals : int
            number of pcm signals per time step
        state_recurrent : bool
            if True, a compressed copy of the current state is fed back
            into the alpha/beta computation
        """

        super(ModulativeSubconditioner, self).__init__()

        self.number_of_subsamples = number_of_subsamples
        self.pcm_embedding_size = pcm_embedding_size
        self.state_size = state_size
        self.pcm_levels = pcm_levels
        self.number_of_signals = number_of_signals
        self.state_recurrent = state_recurrent

        # input width of the alpha/beta dense layers
        self.hidden_size = self.pcm_embedding_size * self.number_of_signals

        if self.state_recurrent:
            self.hidden_size += self.pcm_embedding_size
            # compresses the state to embedding size before feedback
            self.state_transform = nn.Linear(self.state_size, self.pcm_embedding_size)

        # index 0 has no modulation: the first sub-state is the input state
        self.embeddings = [None]
        self.alphas = [None]
        self.betas = [None]

        for i in range(1, self.number_of_subsamples):
            embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
            # register explicitly since these are plain lists
            self.add_module('pcm_embedding_' + str(i), embedding)
            self.embeddings.append(embedding)

            self.alphas.append(nn.Linear(self.hidden_size, self.state_size))
            self.add_module('alpha_dense_' + str(i), self.alphas[-1])

            self.betas.append(nn.Linear(self.hidden_size, self.state_size))
            self.add_module('beta_dense_' + str(i), self.betas[-1])

    def forward(self, states, signals):
        """ creates list of subconditioned states

        Parameters:
        -----------
        states : torch.tensor
            states of shape (batch, seq_length // s, state_size)
        signals : torch.tensor
            signals of shape (batch, seq_length, number_of_signals)

        Returns:
        --------
        c_states : list of torch.tensor
            list of s subconditioned states
        """
        s = self.number_of_subsamples

        c_states = [states]
        new_states = states
        for i in range(1, self.number_of_subsamples):
            embed = self.embeddings[i](signals[:, i::s])
            # reduce signal dimension
            embed = torch.flatten(embed, -2)

            if self.state_recurrent:
                comp_states = self.state_transform(new_states)
                embed = torch.cat((embed, comp_states), dim=-1)

            alpha = torch.tanh(self.alphas[i](embed))
            beta = torch.tanh(self.betas[i](embed))

            # new state obtained by modulating previous state
            new_states = torch.tanh((1 + alpha) * new_states + beta)

            c_states.append(new_states)

        return c_states

    def single_step(self, index, state, signals):
        """ carry out single step for inference

        Parameters:
        -----------
        index : int
            position in subconditioning batch

        state : torch.tensor
            state to sub-condition

        signals : torch.tensor
            signals for subconditioning, all but the last dimensions
            must match those of state

        Returns:
        c_state : torch.tensor
            subconditioned state
        """

        # the original had a second, redundant `return c_state` after the
        # if/else; restructured to a single exit per path
        if index == 0:
            return state

        embed_signals = self.embeddings[index](signals)
        c = torch.flatten(embed_signals, -2)
        if self.state_recurrent:
            r_state = self.state_transform(state)
            c = torch.cat((c, r_state), dim=-1)
        alpha = torch.tanh(self.alphas[index](c))
        beta = torch.tanh(self.betas[index](c))
        return torch.tanh((1 + alpha) * state + beta)

    def get_output_dim(self, index):
        # modulation preserves the state size at every position
        return self.state_size

    def get_average_flops_per_step(self):
        s = self.number_of_subsamples

        # estimate activation by 10 flops
        # c_state = torch.tanh((1 + alpha) * state + beta)
        flops = 13 * self.state_size

        # hidden size
        hidden_size = self.number_of_signals * self.pcm_embedding_size
        if self.state_recurrent:
            hidden_size += self.pcm_embedding_size

        # counting 2 * A * B flops for Linear(A, B)
        # alpha = torch.tanh(self.alphas[index](c))
        # beta = torch.tanh(self.betas[index](c))
        flops += 4 * hidden_size * self.state_size + 20 * self.state_size

        # r_state = self.state_transform(state)
        if self.state_recurrent:
            flops += 2 * self.state_size * self.pcm_embedding_size

        # average over steps: only positions 1..s-1 do any work
        flops *= (s - 1) / s

        return flops
|
|
|
|
class ComparitiveSubconditioner(Subconditioner):
    def __init__(self,
                 number_of_subsamples,
                 pcm_embedding_size,
                 state_size,
                 pcm_levels,
                 number_of_signals,
                 error_index=-1,
                 apply_gate=True,
                 normalize=False):
        """ subconditioning by comparison

        Parameters:
        -----------
        number_of_subsamples : int
            upsampling factor s
        pcm_embedding_size : int
            size of the pcm embeddings; also used as comparison size
        state_size : int
            size of the states to be subconditioned
        pcm_levels : int
            number of pcm quantization levels (embedding table size)
        number_of_signals : int
            number of pcm signals per time step
        error_index : int
            stored as self.error_position; not used in forward
        apply_gate : bool
            if True, an (unused in forward) gate layer is created
        normalize : bool
            stored flag; not used in forward
        """

        super(ComparitiveSubconditioner, self).__init__()

        # Bug fix: the constructor arguments were never stored, but the code
        # below reads them via self (e.g. self.pcm_embedding_size), which
        # raised AttributeError on construction.
        self.number_of_subsamples = number_of_subsamples
        self.pcm_embedding_size = pcm_embedding_size
        self.state_size = state_size
        self.pcm_levels = pcm_levels
        # Bug fix: was misspelled 'self.number_of_signales' in four places.
        self.number_of_signals = number_of_signals

        self.comparison_size = self.pcm_embedding_size
        self.error_position = error_index
        self.apply_gate = apply_gate
        self.normalize = normalize

        self.state_transform = nn.Linear(self.state_size, self.comparison_size)

        # shared alpha/beta layers; forward uses these, not the per-index
        # lists created below
        self.alpha_dense = nn.Linear(self.number_of_signals * self.pcm_embedding_size, self.state_size)
        self.beta_dense = nn.Linear(self.number_of_signals * self.pcm_embedding_size, self.state_size)

        if self.apply_gate:
            self.gate_dense = nn.Linear(self.pcm_embedding_size, self.state_size)

        # embeddings and state transforms (index 0 unused)
        self.embeddings = [None]
        self.alpha_denses = [None]
        self.beta_denses = [None]
        self.state_transforms = [nn.Linear(self.state_size, self.comparison_size)]
        self.add_module('state_transform_0', self.state_transforms[0])

        for i in range(1, self.number_of_subsamples):
            embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
            self.add_module('pcm_embedding_' + str(i), embedding)
            self.embeddings.append(embedding)

            state_transform = nn.Linear(self.state_size, self.comparison_size)
            self.add_module('state_transform_' + str(i), state_transform)
            self.state_transforms.append(state_transform)

            # NOTE(review): these per-index alpha/beta layers are registered
            # but never used in forward, which uses the shared
            # self.alpha_dense / self.beta_dense instead — confirm intent.
            self.alpha_denses.append(nn.Linear(self.number_of_signals * self.pcm_embedding_size, self.state_size))
            self.add_module('alpha_dense_' + str(i), self.alpha_denses[-1])

            self.beta_denses.append(nn.Linear(self.number_of_signals * self.pcm_embedding_size, self.state_size))
            self.add_module('beta_dense_' + str(i), self.beta_denses[-1])

    def forward(self, states, signals):
        """ creates list of subconditioned states

        Parameters:
        -----------
        states : torch.tensor
            states of shape (batch, seq_length // s, state_size)
        signals : torch.tensor
            signals of shape (batch, seq_length, number_of_signals)

        Returns:
        --------
        c_states : list of torch.tensor
            list of s subconditioned states
        """
        s = self.number_of_subsamples

        c_states = [states]
        new_states = states
        for i in range(1, self.number_of_subsamples):
            embed = self.embeddings[i](signals[:, i::s])
            # reduce signal dimension
            embed = torch.flatten(embed, -2)

            comp_states = self.state_transforms[i](new_states)

            alpha = torch.tanh(self.alpha_dense(embed))
            beta = torch.tanh(self.beta_dense(embed))

            # new state obtained by modulating previous state
            # NOTE(review): alpha/beta have size state_size while comp_states
            # has size comparison_size (= pcm_embedding_size); this only
            # broadcasts when the two are equal — confirm against callers.
            new_states = torch.tanh((1 + alpha) * comp_states + beta)

            c_states.append(new_states)

        return c_states
|