added LPCNet torch implementation
Signed-off-by: Jan Buethe <jbuethe@amazon.de>
This commit is contained in:
parent
90a171c1c2
commit
35ee397e06
38 changed files with 3200 additions and 0 deletions
158
dnn/torch/lpcnet/utils/sparsification/gru_sparsifier.py
Normal file
158
dnn/torch/lpcnet/utils/sparsification/gru_sparsifier.py
Normal file
|
@ -0,0 +1,158 @@
|
|||
import torch
|
||||
|
||||
from .common import sparsify_matrix
|
||||
|
||||
|
||||
class GRUSparsifier:
|
||||
def __init__(self, task_list, start, stop, interval, exponent=3):
|
||||
""" Sparsifier for torch.nn.GRUs
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
task_list : list
|
||||
task_list contains a list of tuples (gru, sparsify_dict), where gru is an instance
|
||||
of torch.nn.GRU and sparsify_dic is a dictionary with keys in {'W_ir', 'W_iz', 'W_in',
|
||||
'W_hr', 'W_hz', 'W_hn'} corresponding to the input and recurrent weights for the reset,
|
||||
update, and new gate. The values of sparsify_dict are tuples (density, [m, n], keep_diagonal),
|
||||
where density is the target density in [0, 1], [m, n] is the shape sub-blocks to which
|
||||
sparsification is applied and keep_diagonal is a bool variable indicating whether the diagonal
|
||||
should be kept.
|
||||
|
||||
start : int
|
||||
training step after which sparsification will be started.
|
||||
|
||||
stop : int
|
||||
training step after which sparsification will be completed.
|
||||
|
||||
interval : int
|
||||
sparsification interval for steps between start and stop. After stop sparsification will be
|
||||
carried out after every call to GRUSparsifier.step()
|
||||
|
||||
exponent : float
|
||||
Interpolation exponent for sparsification interval. In step i sparsification will be carried out
|
||||
with density (alpha + target_density * (1 * alpha)), where
|
||||
alpha = ((stop - i) / (start - stop)) ** exponent
|
||||
|
||||
Example:
|
||||
--------
|
||||
>>> import torch
|
||||
>>> gru = torch.nn.GRU(10, 20)
|
||||
>>> sparsify_dict = {
|
||||
... 'W_ir' : (0.5, [2, 2], False),
|
||||
... 'W_iz' : (0.6, [2, 2], False),
|
||||
... 'W_in' : (0.7, [2, 2], False),
|
||||
... 'W_hr' : (0.1, [4, 4], True),
|
||||
... 'W_hz' : (0.2, [4, 4], True),
|
||||
... 'W_hn' : (0.3, [4, 4], True),
|
||||
... }
|
||||
>>> sparsifier = GRUSparsifier([(gru, sparsify_dict)], 0, 100, 50)
|
||||
>>> for i in range(100):
|
||||
... sparsifier.step()
|
||||
"""
|
||||
# just copying parameters...
|
||||
self.start = start
|
||||
self.stop = stop
|
||||
self.interval = interval
|
||||
self.exponent = exponent
|
||||
self.task_list = task_list
|
||||
|
||||
# ... and setting counter to 0
|
||||
self.step_counter = 0
|
||||
|
||||
self.last_masks = {key : None for key in ['W_ir', 'W_in', 'W_iz', 'W_hr', 'W_hn', 'W_hz']}
|
||||
|
||||
def step(self, verbose=False):
|
||||
""" carries out sparsification step
|
||||
|
||||
Call this function after optimizer.step in your
|
||||
training loop.
|
||||
|
||||
Parameters:
|
||||
----------
|
||||
verbose : bool
|
||||
if true, densities are printed out
|
||||
|
||||
Returns:
|
||||
--------
|
||||
None
|
||||
|
||||
"""
|
||||
# compute current interpolation factor
|
||||
self.step_counter += 1
|
||||
|
||||
if self.step_counter < self.start:
|
||||
return
|
||||
elif self.step_counter < self.stop:
|
||||
# update only every self.interval-th interval
|
||||
if self.step_counter % self.interval:
|
||||
return
|
||||
|
||||
alpha = ((self.stop - self.step_counter) / (self.stop - self.start)) ** self.exponent
|
||||
else:
|
||||
alpha = 0
|
||||
|
||||
|
||||
with torch.no_grad():
|
||||
for gru, params in self.task_list:
|
||||
hidden_size = gru.hidden_size
|
||||
|
||||
# input weights
|
||||
for i, key in enumerate(['W_ir', 'W_iz', 'W_in']):
|
||||
if key in params:
|
||||
density = alpha + (1 - alpha) * params[key][0]
|
||||
if verbose:
|
||||
print(f"[{self.step_counter}]: {key} density: {density}")
|
||||
|
||||
gru.weight_ih_l0[i * hidden_size : (i+1) * hidden_size, : ], new_mask = sparsify_matrix(
|
||||
gru.weight_ih_l0[i * hidden_size : (i + 1) * hidden_size, : ],
|
||||
density, # density
|
||||
params[key][1], # block_size
|
||||
params[key][2], # keep_diagonal (might want to set this to False)
|
||||
return_mask=True
|
||||
)
|
||||
|
||||
if type(self.last_masks[key]) != type(None):
|
||||
if not torch.all(self.last_masks[key] == new_mask) and self.step_counter > self.stop:
|
||||
print(f"sparsification mask {key} changed for gru {gru}")
|
||||
|
||||
self.last_masks[key] = new_mask
|
||||
|
||||
# recurrent weights
|
||||
for i, key in enumerate(['W_hr', 'W_hz', 'W_hn']):
|
||||
if key in params:
|
||||
density = alpha + (1 - alpha) * params[key][0]
|
||||
if verbose:
|
||||
print(f"[{self.step_counter}]: {key} density: {density}")
|
||||
gru.weight_hh_l0[i * hidden_size : (i+1) * hidden_size, : ], new_mask = sparsify_matrix(
|
||||
gru.weight_hh_l0[i * hidden_size : (i + 1) * hidden_size, : ],
|
||||
density,
|
||||
params[key][1], # block_size
|
||||
params[key][2], # keep_diagonal (might want to set this to False)
|
||||
return_mask=True
|
||||
)
|
||||
|
||||
if type(self.last_masks[key]) != type(None):
|
||||
if not torch.all(self.last_masks[key] == new_mask) and self.step_counter > self.stop:
|
||||
print(f"sparsification mask {key} changed for gru {gru}")
|
||||
|
||||
self.last_masks[key] = new_mask
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Testing sparsifier")
|
||||
|
||||
gru = torch.nn.GRU(10, 20)
|
||||
sparsify_dict = {
|
||||
'W_ir' : (0.5, [2, 2], False),
|
||||
'W_iz' : (0.6, [2, 2], False),
|
||||
'W_in' : (0.7, [2, 2], False),
|
||||
'W_hr' : (0.1, [4, 4], True),
|
||||
'W_hz' : (0.2, [4, 4], True),
|
||||
'W_hn' : (0.3, [4, 4], True),
|
||||
}
|
||||
|
||||
sparsifier = GRUSparsifier([(gru, sparsify_dict)], 0, 100, 10)
|
||||
|
||||
for i in range(100):
|
||||
sparsifier.step(verbose=True)
|
Loading…
Add table
Add a link
Reference in a new issue