Using sparse GRUs in DRED decoder
Saves ~270 kB of weights in the decoder
This commit is contained in:
parent
58923f61c2
commit
b0620c0bf9
7 changed files with 105 additions and 19 deletions
|
@ -225,10 +225,15 @@ f"""
|
|||
|
||||
# decoder
|
||||
decoder_dense_layers = [
|
||||
('core_decoder.module.dense_1' , 'dec_dense1', 'TANH', False),
|
||||
('core_decoder.module.output' , 'dec_output', 'LINEAR', True),
|
||||
('core_decoder.module.dense_1' , 'dec_dense1', 'TANH', False),
|
||||
('core_decoder.module.glu1.gate' , 'dec_glu1', 'TANH', True),
|
||||
('core_decoder.module.glu2.gate' , 'dec_glu2', 'TANH', True),
|
||||
('core_decoder.module.glu3.gate' , 'dec_glu3', 'TANH', True),
|
||||
('core_decoder.module.glu4.gate' , 'dec_glu4', 'TANH', True),
|
||||
('core_decoder.module.glu5.gate' , 'dec_glu5', 'TANH', True),
|
||||
('core_decoder.module.output' , 'dec_output', 'LINEAR', True),
|
||||
('core_decoder.module.hidden_init' , 'dec_hidden_init', 'TANH', False),
|
||||
('core_decoder.module.gru_init' , 'dec_gru_init', 'TANH', True),
|
||||
('core_decoder.module.gru_init' , 'dec_gru_init','TANH', True),
|
||||
]
|
||||
|
||||
for name, export_name, _, quantize in decoder_dense_layers:
|
||||
|
@ -338,6 +343,13 @@ if __name__ == "__main__":
|
|||
checkpoint = torch.load(args.checkpoint, map_location='cpu')
|
||||
model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs'])
|
||||
missing_keys, unmatched_keys = model.load_state_dict(checkpoint['state_dict'], strict=False)
|
||||
def _remove_weight_norm(m):
|
||||
try:
|
||||
torch.nn.utils.remove_weight_norm(m)
|
||||
except ValueError: # this module didn't have weight norm
|
||||
return
|
||||
model.apply(_remove_weight_norm)
|
||||
|
||||
|
||||
if len(missing_keys) > 0:
|
||||
raise ValueError(f"error: missing keys in state dict")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue