diff --git a/dnn/torch/osce/models/silk_feature_net_pl.py b/dnn/torch/osce/models/silk_feature_net_pl.py
index c766d0ab..5efa7e70 100644
--- a/dnn/torch/osce/models/silk_feature_net_pl.py
+++ b/dnn/torch/osce/models/silk_feature_net_pl.py
@@ -66,18 +66,17 @@ class SilkFeatureNetPL(nn.Module):
         self.conv1 = norm(nn.Conv1d(feature_dim, self.hidden_feature_dim, 1))
         self.conv2 = norm(nn.Conv1d(4 * self.hidden_feature_dim, num_channels, 2))
         self.tconv = norm(nn.ConvTranspose1d(num_channels, num_channels, 4, 4))
 
-        gru_input_dim = num_channels + self.repeat_upsamp_dim if self.repeat_upsamp else num_channels
-        self.gru = norm(norm(nn.GRU(gru_input_dim, num_channels, batch_first=True), name='weight_hh_l0'), name='weight_ih_l0')
+        self.gru = norm(norm(nn.GRU(num_channels, num_channels, batch_first=True), name='weight_hh_l0'), name='weight_ih_l0')
 
         if softquant:
             self.conv2 = soft_quant(self.conv2)
-            if not self.repeat_upsamp: self.tconv = soft_quant(self.tconv)
+            self.tconv = soft_quant(self.tconv)
             self.gru = soft_quant(self.gru, names=['weight_hh_l0', 'weight_ih_l0'])
 
         if sparsify:
             mark_for_sparsification(self.conv2, (sparsification_density[0], [8, 4]))
-            if not self.repeat_upsamp: mark_for_sparsification(self.tconv, (sparsification_density[1], [8, 4]))
+            mark_for_sparsification(self.tconv, (sparsification_density[1], [8, 4]))
             mark_for_sparsification(
                 self.gru,
                 {
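
Note on the change above: with the `repeat_upsamp` branch removed, the GRU input width is always `num_channels`, so the guard conditions around `soft_quant` and `mark_for_sparsification` for `self.tconv` are no longer needed. The sketch below illustrates what the doubly weight-normalized GRU construction amounts to; it assumes `norm` behaves like `torch.nn.utils.weight_norm` and uses a hypothetical `num_channels = 256`, neither of which is taken from the patch itself.

```python
# Illustrative sketch only, not the repository's code.
# Assumption: `norm` in the patch behaves like torch.nn.utils.weight_norm.
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm

num_channels = 256  # hypothetical width for illustration

# After the change, the GRU input width always equals num_channels, and weight
# norm is applied to both the input-hidden and hidden-hidden weight matrices.
gru = weight_norm(
    weight_norm(nn.GRU(num_channels, num_channels, batch_first=True),
                name='weight_hh_l0'),
    name='weight_ih_l0')

features = torch.randn(1, 50, num_channels)  # (batch, frames, channels)
out, _ = gru(features)
print(out.shape)  # torch.Size([1, 50, 256])
```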