added LPCNet torch implementation

Signed-off-by: Jan Buethe <jbuethe@amazon.de>
This commit is contained in:
Jan Buethe 2023-09-05 12:29:38 +02:00
parent 90a171c1c2
commit 35ee397e06
No known key found for this signature in database
GPG key ID: 9E32027A35B36314
38 changed files with 3200 additions and 0 deletions

View file

@ -0,0 +1,158 @@
import torch
from .common import sparsify_matrix
class GRUSparsifier:
def __init__(self, task_list, start, stop, interval, exponent=3):
""" Sparsifier for torch.nn.GRUs
Parameters:
-----------
task_list : list
task_list contains a list of tuples (gru, sparsify_dict), where gru is an instance
of torch.nn.GRU and sparsify_dic is a dictionary with keys in {'W_ir', 'W_iz', 'W_in',
'W_hr', 'W_hz', 'W_hn'} corresponding to the input and recurrent weights for the reset,
update, and new gate. The values of sparsify_dict are tuples (density, [m, n], keep_diagonal),
where density is the target density in [0, 1], [m, n] is the shape sub-blocks to which
sparsification is applied and keep_diagonal is a bool variable indicating whether the diagonal
should be kept.
start : int
training step after which sparsification will be started.
stop : int
training step after which sparsification will be completed.
interval : int
sparsification interval for steps between start and stop. After stop sparsification will be
carried out after every call to GRUSparsifier.step()
exponent : float
Interpolation exponent for sparsification interval. In step i sparsification will be carried out
with density (alpha + target_density * (1 * alpha)), where
alpha = ((stop - i) / (start - stop)) ** exponent
Example:
--------
>>> import torch
>>> gru = torch.nn.GRU(10, 20)
>>> sparsify_dict = {
... 'W_ir' : (0.5, [2, 2], False),
... 'W_iz' : (0.6, [2, 2], False),
... 'W_in' : (0.7, [2, 2], False),
... 'W_hr' : (0.1, [4, 4], True),
... 'W_hz' : (0.2, [4, 4], True),
... 'W_hn' : (0.3, [4, 4], True),
... }
>>> sparsifier = GRUSparsifier([(gru, sparsify_dict)], 0, 100, 50)
>>> for i in range(100):
... sparsifier.step()
"""
# just copying parameters...
self.start = start
self.stop = stop
self.interval = interval
self.exponent = exponent
self.task_list = task_list
# ... and setting counter to 0
self.step_counter = 0
self.last_masks = {key : None for key in ['W_ir', 'W_in', 'W_iz', 'W_hr', 'W_hn', 'W_hz']}
def step(self, verbose=False):
""" carries out sparsification step
Call this function after optimizer.step in your
training loop.
Parameters:
----------
verbose : bool
if true, densities are printed out
Returns:
--------
None
"""
# compute current interpolation factor
self.step_counter += 1
if self.step_counter < self.start:
return
elif self.step_counter < self.stop:
# update only every self.interval-th interval
if self.step_counter % self.interval:
return
alpha = ((self.stop - self.step_counter) / (self.stop - self.start)) ** self.exponent
else:
alpha = 0
with torch.no_grad():
for gru, params in self.task_list:
hidden_size = gru.hidden_size
# input weights
for i, key in enumerate(['W_ir', 'W_iz', 'W_in']):
if key in params:
density = alpha + (1 - alpha) * params[key][0]
if verbose:
print(f"[{self.step_counter}]: {key} density: {density}")
gru.weight_ih_l0[i * hidden_size : (i+1) * hidden_size, : ], new_mask = sparsify_matrix(
gru.weight_ih_l0[i * hidden_size : (i + 1) * hidden_size, : ],
density, # density
params[key][1], # block_size
params[key][2], # keep_diagonal (might want to set this to False)
return_mask=True
)
if type(self.last_masks[key]) != type(None):
if not torch.all(self.last_masks[key] == new_mask) and self.step_counter > self.stop:
print(f"sparsification mask {key} changed for gru {gru}")
self.last_masks[key] = new_mask
# recurrent weights
for i, key in enumerate(['W_hr', 'W_hz', 'W_hn']):
if key in params:
density = alpha + (1 - alpha) * params[key][0]
if verbose:
print(f"[{self.step_counter}]: {key} density: {density}")
gru.weight_hh_l0[i * hidden_size : (i+1) * hidden_size, : ], new_mask = sparsify_matrix(
gru.weight_hh_l0[i * hidden_size : (i + 1) * hidden_size, : ],
density,
params[key][1], # block_size
params[key][2], # keep_diagonal (might want to set this to False)
return_mask=True
)
if type(self.last_masks[key]) != type(None):
if not torch.all(self.last_masks[key] == new_mask) and self.step_counter > self.stop:
print(f"sparsification mask {key} changed for gru {gru}")
self.last_masks[key] = new_mask
if __name__ == "__main__":
print("Testing sparsifier")
gru = torch.nn.GRU(10, 20)
sparsify_dict = {
'W_ir' : (0.5, [2, 2], False),
'W_iz' : (0.6, [2, 2], False),
'W_in' : (0.7, [2, 2], False),
'W_hr' : (0.1, [4, 4], True),
'W_hz' : (0.2, [4, 4], True),
'W_hn' : (0.3, [4, 4], True),
}
sparsifier = GRUSparsifier([(gru, sparsify_dict)], 0, 100, 10)
for i in range(100):
sparsifier.step(verbose=True)