Updated LACE and NoLACE models to version 2

This commit is contained in:
Jan Buethe 2023-12-18 12:19:55 +01:00
parent 4f311a1ad4
commit 299e38cab7
No known key found for this signature in database
GPG key ID: 9E32027A35B36314
57 changed files with 4793 additions and 109 deletions

View file

@ -28,6 +28,7 @@
"""
import torch
from torch.nn.utils import remove_weight_norm
def count_parameters(model, verbose=False):
total = 0
@ -41,7 +42,17 @@ def count_parameters(model, verbose=False):
return total
def count_nonzero_parameters(model, verbose=False):
total = 0
for name, p in model.named_parameters():
count = torch.count_nonzero(p).item()
if verbose:
print(f"{name}: {count} non-zero parameters")
total += count
return total
def retain_grads(module):
for p in module.parameters():
if p.requires_grad:
@ -62,4 +73,23 @@ def create_weights(s_real, s_gen, alpha):
weight = torch.exp(alpha * (sr[-1] - sg[-1]))
weights.append(weight)
return weights
return weights
def _get_candidates(module: torch.nn.Module):
candidates = []
for key in module.__dict__.keys():
if hasattr(module, key + '_v'):
candidates.append(key)
return candidates
def remove_all_weight_norm(model : torch.nn.Module, verbose=False):
for name, m in model.named_modules():
candidates = _get_candidates(m)
for candidate in candidates:
try:
remove_weight_norm(m, name=candidate)
if verbose: print(f'removed weight norm on weight {name}.{candidate}')
except:
pass