bugfix in SilkFeatureNetPL

This commit is contained in:
Jan Buethe 2024-01-22 15:12:52 +01:00
parent 5f8201c71e
commit ec04a94eb2
No known key found for this signature in database
GPG key ID: 9E32027A35B36314

View file

@ -66,18 +66,17 @@ class SilkFeatureNetPL(nn.Module):
self.conv1 = norm(nn.Conv1d(feature_dim, self.hidden_feature_dim, 1))
self.conv2 = norm(nn.Conv1d(4 * self.hidden_feature_dim, num_channels, 2))
self.tconv = norm(nn.ConvTranspose1d(num_channels, num_channels, 4, 4))
gru_input_dim = num_channels + self.repeat_upsamp_dim if self.repeat_upsamp else num_channels
self.gru = norm(norm(nn.GRU(gru_input_dim, num_channels, batch_first=True), name='weight_hh_l0'), name='weight_ih_l0')
self.gru = norm(norm(nn.GRU(num_channels, num_channels, batch_first=True), name='weight_hh_l0'), name='weight_ih_l0')
if softquant:
self.conv2 = soft_quant(self.conv2)
if not self.repeat_upsamp: self.tconv = soft_quant(self.tconv)
self.tconv = soft_quant(self.tconv)
self.gru = soft_quant(self.gru, names=['weight_hh_l0', 'weight_ih_l0'])
if sparsify:
mark_for_sparsification(self.conv2, (sparsification_density[0], [8, 4]))
if not self.repeat_upsamp: mark_for_sparsification(self.tconv, (sparsification_density[1], [8, 4]))
mark_for_sparsification(self.tconv, (sparsification_density[1], [8, 4]))
mark_for_sparsification(
self.gru,
{