Updated LACE and NoLACE models to version 2
This commit is contained in:
parent
4f311a1ad4
commit
299e38cab7
57 changed files with 4793 additions and 109 deletions
|
@ -28,6 +28,7 @@
|
|||
"""
|
||||
|
||||
import torch
|
||||
from torch.nn.utils import remove_weight_norm
|
||||
|
||||
def count_parameters(model, verbose=False):
|
||||
total = 0
|
||||
|
@ -41,7 +42,17 @@ def count_parameters(model, verbose=False):
|
|||
|
||||
return total
|
||||
|
||||
def count_nonzero_parameters(model, verbose=False):
|
||||
total = 0
|
||||
for name, p in model.named_parameters():
|
||||
count = torch.count_nonzero(p).item()
|
||||
|
||||
if verbose:
|
||||
print(f"{name}: {count} non-zero parameters")
|
||||
|
||||
total += count
|
||||
|
||||
return total
|
||||
def retain_grads(module):
|
||||
for p in module.parameters():
|
||||
if p.requires_grad:
|
||||
|
@ -62,4 +73,23 @@ def create_weights(s_real, s_gen, alpha):
|
|||
weight = torch.exp(alpha * (sr[-1] - sg[-1]))
|
||||
weights.append(weight)
|
||||
|
||||
return weights
|
||||
return weights
|
||||
|
||||
|
||||
def _get_candidates(module: torch.nn.Module):
|
||||
candidates = []
|
||||
for key in module.__dict__.keys():
|
||||
if hasattr(module, key + '_v'):
|
||||
candidates.append(key)
|
||||
return candidates
|
||||
|
||||
def remove_all_weight_norm(model : torch.nn.Module, verbose=False):
|
||||
for name, m in model.named_modules():
|
||||
candidates = _get_candidates(m)
|
||||
|
||||
for candidate in candidates:
|
||||
try:
|
||||
remove_weight_norm(m, name=candidate)
|
||||
if verbose: print(f'removed weight norm on weight {name}.{candidate}')
|
||||
except:
|
||||
pass
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue