fixed decoder bug (non-quantized input)

This commit is contained in:
Jan Buethe 2022-09-29 21:13:30 +02:00
parent 589e674116
commit 97ffa94d5c

View file

@ -121,7 +121,7 @@ quant_gru_state_dec = pvq_quantize(gru_state_dec, 30)
# rate estimate
hard_distr_embed = tf.math.sigmoid(quant_embed_dec[:, :, 4 * nsymbols : ]).numpy()
rate_input = np.concatenate((symbols, hard_distr_embed, enc_lambda), axis=-1)
rate_input = np.concatenate((qsymbols, hard_distr_embed, enc_lambda), axis=-1)
rates = sq_rate_metric(None, rate_input, reduce=False).numpy()
# run decoder
@ -133,7 +133,7 @@ packet_sizes = []
for i in range(offset, num_frames):
print(f"processing frame {i - offset}...")
features = decoder.predict([symbols[:, i - 2 * input_length + 2 : i + 1 : 2, :], quant_embed_dec[:, i - 2 * input_length + 2 : i + 1 : 2, :], quant_gru_state_dec[:, i, :]])
features = decoder.predict([qsymbols[:, i - 2 * input_length + 2 : i + 1 : 2, :], quant_embed_dec[:, i - 2 * input_length + 2 : i + 1 : 2, :], quant_gru_state_dec[:, i, :]])
packets.append(features)
packet_size = 8 * int((np.sum(rates[:, i - 2 * input_length + 2 : i + 1 : 2]) + 7) / 8) + 64
packet_sizes.append(packet_size)