added missing dead-zone to encode_rdovae.py

This commit is contained in:
jbuethe 2022-09-27 16:31:04 +00:00
parent be42c3b514
commit 01baf1a0fc

View file

@ -48,6 +48,8 @@ args = parser.parse_args()
import importlib import importlib
rdovae = importlib.import_module(args.model) rdovae = importlib.import_module(args.model)
from rdovae import apply_dead_zone
import sys import sys
import numpy as np import numpy as np
from tensorflow.keras.optimizers import Adam from tensorflow.keras.optimizers import Adam
@ -105,10 +107,15 @@ print(gru_state_dec.shape)
features.astype('float32').tofile(args.output + "-input.f32") features.astype('float32').tofile(args.output + "-input.f32")
#quant_out.astype('float32').tofile(args.output + "-enc_dec.f32") #quant_out.astype('float32').tofile(args.output + "-enc_dec.f32")
nbits=80
dead_zone = tf.math.softplus(quant_embed_dec[:, :, nbits : 2 * nbits])
symbols = apply_dead_zone([bits, dead_zone]).numpy()
np.round(bits).astype('int16').tofile(args.output + "-bits.s16") np.round(bits).astype('int16').tofile(args.output + "-bits.s16")
quant_embed_dec.astype('float32').tofile(args.output + "-quant.f32") quant_embed_dec.astype('float32').tofile(args.output + "-quant.f32")
gru_state_dec = gru_state_dec[:,-1,:] gru_state_dec = gru_state_dec[:,-1,:]
dec_out = decoder([bits[:,1::2,:], quant_embed_dec[:,1::2,:], gru_state_dec]) dec_out = decoder([bits[:,1::2,:], quant_embed_dec[:,1::2,:], gru_state_dec])
print(dec_out.shape)
dec_out.numpy().astype('float32').tofile(args.output + "-dec_out.f32") dec_out.numpy().astype('float32').tofile(args.output + "-dec_out.f32")