From 01baf1a0fceccdc7ea762ef883f43c78e64102b5 Mon Sep 17 00:00:00 2001 From: jbuethe Date: Tue, 27 Sep 2022 16:31:04 +0000 Subject: [PATCH] added missing dead-zone to encode_rdovae.py --- dnn/training_tf2/encode_rdovae.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/dnn/training_tf2/encode_rdovae.py b/dnn/training_tf2/encode_rdovae.py index 00a44305..1d76a0ff 100644 --- a/dnn/training_tf2/encode_rdovae.py +++ b/dnn/training_tf2/encode_rdovae.py @@ -48,6 +48,8 @@ args = parser.parse_args() import importlib rdovae = importlib.import_module(args.model) +from rdovae import apply_dead_zone + import sys import numpy as np from tensorflow.keras.optimizers import Adam @@ -105,10 +107,15 @@ print(gru_state_dec.shape) features.astype('float32').tofile(args.output + "-input.f32") #quant_out.astype('float32').tofile(args.output + "-enc_dec.f32") +nbits=80 +dead_zone = tf.math.softplus(quant_embed_dec[:, :, nbits : 2 * nbits]) +symbols = apply_dead_zone([bits, dead_zone]).numpy() np.round(bits).astype('int16').tofile(args.output + "-bits.s16") quant_embed_dec.astype('float32').tofile(args.output + "-quant.f32") gru_state_dec = gru_state_dec[:,-1,:] dec_out = decoder([bits[:,1::2,:], quant_embed_dec[:,1::2,:], gru_state_dec]) +print(dec_out.shape) + dec_out.numpy().astype('float32').tofile(args.output + "-dec_out.f32")