Using sparse GRUs in DRED decoder

Saves ~270 kB of weights in the decoder
This commit is contained in:
Jean-Marc Valin 2023-11-15 04:08:50 -05:00
parent 58923f61c2
commit b0620c0bf9
No known key found for this signature in database
GPG key ID: 531A52533318F00A
7 changed files with 105 additions and 19 deletions

View file

@ -225,10 +225,15 @@ f"""
# decoder
decoder_dense_layers = [
('core_decoder.module.dense_1' , 'dec_dense1', 'TANH', False),
('core_decoder.module.output' , 'dec_output', 'LINEAR', True),
('core_decoder.module.dense_1' , 'dec_dense1', 'TANH', False),
('core_decoder.module.glu1.gate' , 'dec_glu1', 'TANH', True),
('core_decoder.module.glu2.gate' , 'dec_glu2', 'TANH', True),
('core_decoder.module.glu3.gate' , 'dec_glu3', 'TANH', True),
('core_decoder.module.glu4.gate' , 'dec_glu4', 'TANH', True),
('core_decoder.module.glu5.gate' , 'dec_glu5', 'TANH', True),
('core_decoder.module.output' , 'dec_output', 'LINEAR', True),
('core_decoder.module.hidden_init' , 'dec_hidden_init', 'TANH', False),
('core_decoder.module.gru_init' , 'dec_gru_init', 'TANH', True),
('core_decoder.module.gru_init' , 'dec_gru_init','TANH', True),
]
for name, export_name, _, quantize in decoder_dense_layers:
@ -338,6 +343,13 @@ if __name__ == "__main__":
checkpoint = torch.load(args.checkpoint, map_location='cpu')
model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs'])
missing_keys, unmatched_keys = model.load_state_dict(checkpoint['state_dict'], strict=False)
def _remove_weight_norm(m):
try:
torch.nn.utils.remove_weight_norm(m)
except ValueError: # this module didn't have weight norm
return
model.apply(_remove_weight_norm)
if len(missing_keys) > 0:
raise ValueError(f"error: missing keys in state dict")