mirror of
https://github.com/xiph/opus.git
synced 2025-05-17 00:48:29 +00:00
92 lines
No EOL
2.8 KiB
Python
92 lines
No EOL
2.8 KiB
Python
import torch
|
|
|
|
def sparsify_matrix(matrix : torch.Tensor, density : float, block_size : list[int], keep_diagonal : bool=False, return_mask : bool=False):
    """ sparsifies matrix with specified block size

        Zeroes out entire (block_size[0] x block_size[1]) sub-blocks, keeping only the
        round(density * num_blocks) blocks with the largest squared-Frobenius energy.

        Parameters:
        -----------
        matrix : torch.Tensor
            matrix to sparsify (2-dimensional, shape (m, n))

        density : float
            target density, i.e. fraction of blocks to keep (0 <= density <= 1)

        block_size : [int, int]
            block size dimensions; must divide the matrix dimensions exactly

        keep_diagonal : bool
            If true, the diagonal will be kept. This option requires block_size[0] == block_size[1] and defaults to False

        return_mask : bool
            If true, also returns the 0/1 block mask (same shape as matrix). Defaults to False

        Returns:
        --------
        masked matrix, or (masked matrix, mask) if return_mask is True

        Raises:
        -------
        ValueError
            if block_size does not divide the matrix shape, or if keep_diagonal is
            requested for a non-square matrix
    """

    m, n = matrix.shape
    m1, n1 = block_size

    if m % m1 or n % n1:
        raise ValueError(f"block size {(m1, n1)} does not divide matrix size {(m, n)}")

    # extract diagonal if keep_diagonal = True so it survives the masking below
    if keep_diagonal:
        if m != n:
            raise ValueError("Attempting to sparsify non-square matrix with keep_diagonal=True")

        to_spare = torch.diag(torch.diag(matrix))
        matrix = matrix - to_spare
    else:
        to_spare = torch.zeros_like(matrix)

    # calculate energy (sum of squares) in sub-blocks
    x = torch.reshape(matrix, (m // m1, m1, n // n1, n1))
    x = x ** 2
    block_energies = torch.sum(torch.sum(x, dim=3), dim=1)

    number_of_blocks = (m * n) // (m1 * n1)
    number_of_survivors = round(number_of_blocks * density)

    # masking threshold: blocks with energy below it are zeroed.
    # With zero survivors every block must be dropped, so the threshold is +inf
    # (a threshold of 0 would keep everything, since energies are always >= 0).
    if number_of_survivors == 0:
        threshold = torch.inf
    else:
        threshold = torch.sort(torch.flatten(block_energies)).values[-number_of_survivors]

    # create per-block 0/1 mask and expand it to element resolution
    mask = torch.ones_like(block_energies)
    mask[block_energies < threshold] = 0
    mask = torch.repeat_interleave(mask, m1, dim=0)
    mask = torch.repeat_interleave(mask, n1, dim=1)

    # perform masking; re-add the spared diagonal (zeros if keep_diagonal=False)
    masked_matrix = mask * matrix + to_spare

    if return_mask:
        return masked_matrix, mask
    else:
        return masked_matrix
|
|
|
|
def calculate_gru_flops_per_step(gru, sparsification_dict=None, drop_input=False):
    """ estimates FLOPs for a single time step of a GRU

        Parameters:
        -----------
        gru
            GRU module; only its input_size and hidden_size attributes are read

        sparsification_dict : dict, optional
            maps weight names ('W_ir', 'W_in', 'W_iz', 'W_hr', 'W_hn', 'W_hz') to
            entries whose first element is the density of that weight matrix;
            missing weights default to density 1 (dense)

        drop_input : bool
            if True, the input matrix-vector products are excluded from the count

        Returns:
        --------
        estimated FLOPs per step (float if any density is fractional)
    """
    # avoid a mutable default argument; an empty dict means fully dense
    if sparsification_dict is None:
        sparsification_dict = dict()

    input_size = gru.input_size
    hidden_size = gru.hidden_size
    flops = 0

    # average density over the three input-to-hidden weight matrices
    input_density = (
        sparsification_dict.get('W_ir', [1])[0]
        + sparsification_dict.get('W_in', [1])[0]
        + sparsification_dict.get('W_iz', [1])[0]
    ) / 3

    # average density over the three hidden-to-hidden weight matrices
    recurrent_density = (
        sparsification_dict.get('W_hr', [1])[0]
        + sparsification_dict.get('W_hn', [1])[0]
        + sparsification_dict.get('W_hz', [1])[0]
    ) / 3

    # input matrix vector multiplications (2 FLOPs per MAC, 3 gates)
    if not drop_input:
        flops += 2 * 3 * input_size * hidden_size * input_density

    # recurrent matrix vector multiplications
    flops += 2 * 3 * hidden_size * hidden_size * recurrent_density

    # biases (two bias vectors per gate)
    flops += 6 * hidden_size

    # activations estimated by 10 flops per activation
    flops += 30 * hidden_size

    return flops